From 95186d1bb691eaef204d03a2939eb8977882a0a9 Mon Sep 17 00:00:00 2001 From: epenet <6771947+epenet@users.noreply.github.com> Date: Mon, 27 May 2024 08:58:24 +0000 Subject: [PATCH] Use blacklist --- pylint/plugins/hass_imports.py | 25 +++++++----------- tests/pylint/test_imports.py | 48 ++++++---------------------------- 2 files changed, 18 insertions(+), 55 deletions(-) diff --git a/pylint/plugins/hass_imports.py b/pylint/plugins/hass_imports.py index 550a091d855..0319fae7dde 100644 --- a/pylint/plugins/hass_imports.py +++ b/pylint/plugins/hass_imports.py @@ -396,11 +396,11 @@ _OBSOLETE_IMPORT: dict[str, list[ObsoleteImportMatch]] = { # Should be gradually synchronised with pyproject.toml # [tool.ruff.lint.flake8-import-conventions.extend-aliases] -_NAMESPACE_IMPORT: dict[str, str] = { - "homeassistant.helpers.area_registry": "ar", - "homeassistant.helpers.device_registry": "dr", - "homeassistant.helpers.entity_registry": "er", - "homeassistant.helpers.issue_registry": "ir", +_FORCE_NAMESPACE_IMPORT: dict[tuple[str, str], str] = { + ("homeassistant.helpers.area_registry", "async_get"): "ar.async_get", + ("homeassistant.helpers.device_registry", "async_get"): "dr.async_get", + ("homeassistant.helpers.entity_registry", "async_get"): "er.async_get", + ("homeassistant.helpers.issue_registry", "async_get"): "ir.async_get", } @@ -539,21 +539,16 @@ class HassImportsFormatChecker(BaseChecker): args=(import_match.string, obsolete_import.reason), ) for name in node.names: - if self._has_invalid_namespace_import(node, node.modname, name[0], name[1]): + if self._has_invalid_namespace_import(node, node.modname, name[0]): return def _has_invalid_namespace_import( - self, node: nodes.ImportFrom, module: str, name: str, alias: str | None + self, node: nodes.ImportFrom, module: str, name: str ) -> bool: - # Rule only applies to function imports - if not name[0].islower(): - return False - for helper, shorthand in _NAMESPACE_IMPORT.items(): - if module.startswith(helper) and alias != shorthand: + for key, value in _FORCE_NAMESPACE_IMPORT.items(): + if module == key[0] and name == key[1]: self.add_message( - "hass-helper-namespace-import", - node=node, - args=(name, f"{shorthand}.{name}"), + "hass-helper-namespace-import", node=node, args=(name, value) ) return True return False diff --git a/tests/pylint/test_imports.py b/tests/pylint/test_imports.py index 3ed307ccfbe..c8a07a7ed61 100644 --- a/tests/pylint/test_imports.py +++ b/tests/pylint/test_imports.py @@ -254,47 +254,18 @@ def test_bad_root_import( imports_checker.visit_importfrom(node) -@pytest.mark.parametrize( - ("import_node", "module_name"), - [ - ( - "from homeassistant.helpers.issue_registry import AClass", - "tests.components.pylint_test.climate", - ), - ( - "from homeassistant.helpers.issue_registry import A_CONSTANT", - "tests.components.pylint_test.climate", - ), - ], -) -def test_good_namespace_import( - linter: UnittestLinter, - imports_checker: BaseChecker, - import_node: str, - module_name: str, -) -> None: - """Ensure good namespace imports are accepted.""" - - node = astroid.extract_node( - f"{import_node} #@", - module_name, - ) - imports_checker.visit_module(node.parent) - - with assert_no_messages(linter): - if import_node.startswith("import"): - imports_checker.visit_import(node) - if import_node.startswith("from"): - imports_checker.visit_importfrom(node) - - @pytest.mark.parametrize( ("import_node", "module_name", "expected_args"), [ ( - "from homeassistant.helpers.issue_registry import a_function", + "from homeassistant.helpers.issue_registry import async_get", "tests.components.pylint_test.climate", - ("a_function", "ir.a_function"), + ("async_get", "ir.async_get"), + ), + ( + "from homeassistant.helpers.issue_registry import async_get as async_get_issue_registry", + "tests.components.pylint_test.climate", + ("async_get", "ir.async_get"), ), ], ) @@ -325,7 +296,4 @@ def test_bad_namespace_import( end_col_offset=len(import_node), ), ): - if import_node.startswith("import"): - imports_checker.visit_import(node) - if import_node.startswith("from"): - imports_checker.visit_importfrom(node) + imports_checker.visit_importfrom(node)