Only apply to functions

This commit is contained in:
epenet
2024-05-27 08:28:26 +00:00
parent 2ea50a438b
commit 84f836029f
2 changed files with 26 additions and 29 deletions

View File

@@ -432,7 +432,7 @@ class HassImportsFormatChecker(BaseChecker):
"from the component root", "from the component root",
), ),
"W7425": ( "W7425": (
"Helper import should be using the namespace", "%s should be used via the namespace: %s",
"hass-helper-namespace-import", "hass-helper-namespace-import",
"Used when a helper should be used via the namespace", "Used when a helper should be used via the namespace",
), ),
@@ -452,22 +452,11 @@ class HassImportsFormatChecker(BaseChecker):
# Strip name of the current module # Strip name of the current module
self.current_package = node.name[: node.name.rfind(".")] self.current_package = node.name[: node.name.rfind(".")]
def _check_namespace_import(
self, node: nodes.Import, module: str, alias: str | None
) -> bool:
for helper, shorthand in _NAMESPACE_IMPORT.items():
if module.startswith(helper) and alias != shorthand:
self.add_message("hass-helper-namespace-import", node=node)
return False
return True
def visit_import(self, node: nodes.Import) -> None: def visit_import(self, node: nodes.Import) -> None:
"""Check for improper `import _` invocations.""" """Check for improper `import _` invocations."""
if self.current_package is None: if self.current_package is None:
return return
for module, _alias in node.names: for module, _alias in node.names:
if not self._check_namespace_import(node, module, _alias):
continue
if module.startswith("{self.current_package}."): if module.startswith("{self.current_package}."):
self.add_message("hass-relative-import", node=node) self.add_message("hass-relative-import", node=node)
continue continue
@@ -550,10 +539,24 @@ class HassImportsFormatChecker(BaseChecker):
args=(import_match.string, obsolete_import.reason), args=(import_match.string, obsolete_import.reason),
) )
for name in node.names: for name in node.names:
if not self._check_namespace_import( if self._has_invalid_namespace_import(node, node.modname, name[0], name[1]):
node, f"{node.modname}.{name[0]}", name[1] return
):
continue def _has_invalid_namespace_import(
self, node: nodes.ImportFrom, module: str, name: str, alias: str | None
) -> bool:
# Rule only applies to function imports
if not name[0].islower():
return False
for helper, shorthand in _NAMESPACE_IMPORT.items():
if module.startswith(helper) and alias != shorthand:
self.add_message(
"hass-helper-namespace-import",
node=node,
args=(name, f"{shorthand}.{name}"),
)
return True
return False
def register(linter: PyLinter) -> None: def register(linter: PyLinter) -> None:

View File

@@ -258,11 +258,11 @@ def test_bad_root_import(
("import_node", "module_name"), ("import_node", "module_name"),
[ [
( (
"import homeassistant.helpers.issue_registry as ir", "from homeassistant.helpers.issue_registry import AClass",
"tests.components.pylint_test.climate", "tests.components.pylint_test.climate",
), ),
( (
"from homeassistant.helpers import issue_registry as ir", "from homeassistant.helpers.issue_registry import A_CONSTANT",
"tests.components.pylint_test.climate", "tests.components.pylint_test.climate",
), ),
], ],
@@ -289,19 +289,12 @@ def test_good_namespace_import(
@pytest.mark.parametrize( @pytest.mark.parametrize(
("import_node", "module_name"), ("import_node", "module_name", "expected_args"),
[ [
( (
"import homeassistant.helpers.issue_registry as issue_registry", "from homeassistant.helpers.issue_registry import a_function",
"tests.components.pylint_test.climate",
),
(
"from homeassistant.helpers import issue_registry",
"tests.components.pylint_test.climate",
),
(
"from homeassistant.helpers.issue_registry import IssueSeverity",
"tests.components.pylint_test.climate", "tests.components.pylint_test.climate",
("a_function", "ir.a_function"),
), ),
], ],
) )
@@ -310,6 +303,7 @@ def test_bad_namespace_import(
imports_checker: BaseChecker, imports_checker: BaseChecker,
import_node: str, import_node: str,
module_name: str, module_name: str,
expected_args: tuple[str, ...],
) -> None: ) -> None:
"""Ensure bad namespace imports are rejected.""" """Ensure bad namespace imports are rejected."""
@@ -324,7 +318,7 @@ def test_bad_namespace_import(
pylint.testutils.MessageTest( pylint.testutils.MessageTest(
msg_id="hass-helper-namespace-import", msg_id="hass-helper-namespace-import",
node=node, node=node,
args=None, args=expected_args,
line=1, line=1,
col_offset=0, col_offset=0,
end_line=1, end_line=1,