diff --git a/pylint/plugins/hass_imports.py b/pylint/plugins/hass_imports.py index d8f85df011f..b11ef233bc2 100644 --- a/pylint/plugins/hass_imports.py +++ b/pylint/plugins/hass_imports.py @@ -394,6 +394,12 @@ _OBSOLETE_IMPORT: dict[str, list[ObsoleteImportMatch]] = { ], } +# Should be gradually synchronised with pyproject.toml +# [tool.ruff.lint.flake8-import-conventions.extend-aliases] +_NAMESPACE_IMPORT: dict[str, str] = { + "homeassistant.helpers.issue_registry": "ir", +} + class HassImportsFormatChecker(BaseChecker): """Checker for imports.""" @@ -422,6 +428,11 @@ class HassImportsFormatChecker(BaseChecker): "Used when an import from another component should be " "from the component root", ), + "W7425": ( + "Helper import should be using the namespace", + "hass-helper-namespace-import", + "Used when a helper should be used via the namespace", + ), } options = () @@ -438,12 +449,23 @@ class HassImportsFormatChecker(BaseChecker): # Strip name of the current module self.current_package = node.name[: node.name.rfind(".")] + def _check_namespace_import( + self, node: nodes.Import, module: str, alias: str | None + ) -> bool: + for helper, shorthand in _NAMESPACE_IMPORT.items(): + if module.startswith(helper) and alias != shorthand: + self.add_message("hass-helper-namespace-import", node=node) + return False + return True + def visit_import(self, node: nodes.Import) -> None: """Check for improper `import _` invocations.""" if self.current_package is None: return for module, _alias in node.names: - if module.startswith(f"{self.current_package}."): + if not self._check_namespace_import(node, module, _alias): + continue + if module.startswith("{self.current_package}."): self.add_message("hass-relative-import", node=node) continue if module.startswith("homeassistant.components.") and module.endswith( @@ -524,6 +546,11 @@ class HassImportsFormatChecker(BaseChecker): node=node, args=(import_match.string, obsolete_import.reason), ) + for name in node.names: + if not self._check_namespace_import( + node, f"{node.modname}.{name[0]}", name[1] + ): + continue def register(linter: PyLinter) -> None: diff --git a/tests/pylint/test_imports.py b/tests/pylint/test_imports.py index 5f1d4d86840..72ba7b7f057 100644 --- a/tests/pylint/test_imports.py +++ b/tests/pylint/test_imports.py @@ -252,3 +252,86 @@ def test_bad_root_import( imports_checker.visit_import(node) if import_node.startswith("from"): imports_checker.visit_importfrom(node) + + +@pytest.mark.parametrize( + ("import_node", "module_name"), + [ + ( + "import homeassistant.helpers.issue_registry as ir", + "tests.components.pylint_test.climate", + ), + ( + "from homeassistant.helpers import issue_registry as ir", + "tests.components.pylint_test.climate", + ), + ], +) +def test_good_namespace_import( + linter: UnittestLinter, + imports_checker: BaseChecker, + import_node: str, + module_name: str, +) -> None: + """Ensure good namespace imports are accepted.""" + + node = astroid.extract_node( + f"{import_node} #@", + module_name, + ) + imports_checker.visit_module(node.parent) + + with assert_no_messages(linter): + if import_node.startswith("import"): + imports_checker.visit_import(node) + if import_node.startswith("from"): + imports_checker.visit_importfrom(node) + + +@pytest.mark.parametrize( + ("import_node", "module_name"), + [ + ( + "import homeassistant.helpers.issue_registry as issue_registry", + "tests.components.pylint_test.climate", + ), + ( + "from homeassistant.helpers import issue_registry", + "tests.components.pylint_test.climate", + ), + ( + "from homeassistant.helpers.issue_registry import IssueSeverity", + "tests.components.pylint_test.climate", + ), + ], +) +def test_bad_namespace_import( + linter: UnittestLinter, + imports_checker: BaseChecker, + import_node: str, + module_name: str, +) -> None: + """Ensure bad namespace imports are rejected.""" + + node = astroid.extract_node( + f"{import_node} #@", + module_name, + ) + imports_checker.visit_module(node.parent) + + with assert_adds_messages( + linter, + pylint.testutils.MessageTest( + msg_id="hass-helper-namespace-import", + node=node, + args=None, + line=1, + col_offset=0, + end_line=1, + end_col_offset=len(import_node), + ), + ): + if import_node.startswith("import"): + imports_checker.visit_import(node) + if import_node.startswith("from"): + imports_checker.visit_importfrom(node)