Enforce async_load_fixture in async test functions

This commit is contained in:
epenet
2025-05-27 14:46:02 +00:00
parent 8364d8a2e3
commit a2ab53fa7c
3 changed files with 78 additions and 5 deletions

View File

@@ -0,0 +1,64 @@
"""Plugin for logger invocations."""
from __future__ import annotations
from astroid import nodes
from pylint.checkers import BaseChecker
from pylint.lint import PyLinter
FUNCTION_NAMES = ("load_fixture",)
class HassLoadFixturesChecker(BaseChecker):
"""Checker for I/O load fixtures."""
name = "hass_async_load_fixtures"
priority = -1
msgs = {
"W7481": (
"Fixtures should be loaded asynchronously in test modules",
"hass-async-load-fixtures",
"Used when a fixture is loaded synchronously",
),
}
options = ()
_function_queue: list[nodes.FunctionDef | nodes.AsyncFunctionDef]
_in_test_module: bool
def visit_module(self, node: nodes.Module) -> None:
"""Populate matchers for a Module node."""
self._in_test_module = node.name.startswith("tests.")
self._function_queue = []
def visit_functiondef(self, node: nodes.FunctionDef) -> None:
"""Visit a function definition."""
self._function_queue.append(node)
def leave_functiondef(self, node: nodes.FunctionDef) -> None:
"""Leave a function definition."""
self._function_queue.pop()
visit_asyncfunctiondef = visit_functiondef
leave_asyncfunctiondef = leave_functiondef
def visit_call(self, node: nodes.Call) -> None:
"""Check for sync I/O in load_fixture."""
if (
# Ensure we are in a test module
not self._in_test_module
# Ensure we are in an async function context
or not self._function_queue
or not isinstance(self._function_queue[-1], nodes.AsyncFunctionDef)
# Check function name
or not isinstance(node.func, nodes.Name)
or node.func.name not in FUNCTION_NAMES
):
return
self.add_message("hass-async-load-fixtures", node=node)
def register(linter: PyLinter) -> None:
"""Register the checker."""
linter.register_checker(HassLoadFixturesChecker(linter))

View File

@@ -161,6 +161,7 @@ init-hook = """\
load-plugins = [
"pylint.extensions.code_style",
"pylint.extensions.typing",
"hass_async_load_fixtures",
"hass_decorator",
"hass_enforce_class_module",
"hass_enforce_sorted_platforms",

View File

@@ -9,7 +9,7 @@ from homeassistant.core import HomeAssistant
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.util import location as location_util
from tests.common import load_fixture
from tests.common import async_load_fixture
from tests.test_util.aiohttp import AiohttpClientMocker
# Paris
@@ -77,10 +77,14 @@ def test_get_miles() -> None:
async def test_detect_location_info_whoami(
aioclient_mock: AiohttpClientMocker, session: aiohttp.ClientSession
hass: HomeAssistant,
aioclient_mock: AiohttpClientMocker,
session: aiohttp.ClientSession,
) -> None:
"""Test detect location info using services.home-assistant.io/whoami."""
aioclient_mock.get(location_util.WHOAMI_URL, text=load_fixture("whoami.json"))
aioclient_mock.get(
location_util.WHOAMI_URL, text=await async_load_fixture(hass, "whoami.json")
)
with patch("homeassistant.util.location.HA_VERSION", "1.0"):
info = await location_util.async_detect_location_info(session, _test_real=True)
@@ -101,10 +105,14 @@ async def test_detect_location_info_whoami(
async def test_dev_url(
aioclient_mock: AiohttpClientMocker, session: aiohttp.ClientSession
hass: HomeAssistant,
aioclient_mock: AiohttpClientMocker,
session: aiohttp.ClientSession,
) -> None:
"""Test usage of dev URL."""
aioclient_mock.get(location_util.WHOAMI_URL_DEV, text=load_fixture("whoami.json"))
aioclient_mock.get(
location_util.WHOAMI_URL_DEV, text=await async_load_fixture(hass, "whoami.json")
)
with patch("homeassistant.util.location.HA_VERSION", "1.0.dev0"):
info = await location_util.async_detect_location_info(session, _test_real=True)