Add tests

This commit is contained in:
jbouwh
2025-02-15 16:59:51 +00:00
parent a9edfd7e97
commit 4594f7e07d
2 changed files with 58 additions and 7 deletions

View File

@@ -387,6 +387,7 @@ class MQTT:
self.loop = hass.loop self.loop = hass.loop
self.config_entry = config_entry self.config_entry = config_entry
self.conf = conf self.conf = conf
self.is_mqttv5 = conf.get(CONF_PROTOCOL, DEFAULT_PROTOCOL) == PROTOCOL_5
self._simple_subscriptions: defaultdict[str, set[Subscription]] = defaultdict( self._simple_subscriptions: defaultdict[str, set[Subscription]] = defaultdict(
set set
@@ -686,9 +687,7 @@ class MQTT:
# subscriptions) is cleared on successful connect when the # subscriptions) is cleared on successful connect when the
# clean_start flag is set. For MQTT v3.1.1, the clean_session # clean_start flag is set. For MQTT v3.1.1, the clean_session
# argument of Client should be used for similar result. # argument of Client should be used for similar result.
True True if self.is_mqttv5 else mqtt.MQTT_CLEAN_START_FIRST_ONLY,
if self._mqttc.protocol == mqtt.MQTTv5
else mqtt.MQTT_CLEAN_START_FIRST_ONLY, # clean_start
) )
except (OSError, mqtt.WebsocketConnectionError) as err: except (OSError, mqtt.WebsocketConnectionError) as err:
_LOGGER.error("Failed to connect to MQTT server due to exception: %s", err) _LOGGER.error("Failed to connect to MQTT server due to exception: %s", err)

View File

@@ -1330,7 +1330,7 @@ async def test_handle_message_callback(
@pytest.mark.parametrize( @pytest.mark.parametrize(
("mqtt_config_entry_data", "protocol"), ("mqtt_config_entry_data", "protocol", "clean_session"),
[ [
( (
{ {
@@ -1338,6 +1338,7 @@ async def test_handle_message_callback(
CONF_PROTOCOL: "3.1", CONF_PROTOCOL: "3.1",
}, },
3, 3,
True,
), ),
( (
{ {
@@ -1345,6 +1346,7 @@ async def test_handle_message_callback(
CONF_PROTOCOL: "3.1.1", CONF_PROTOCOL: "3.1.1",
}, },
4, 4,
True,
), ),
( (
{ {
@@ -1352,22 +1354,72 @@ async def test_handle_message_callback(
CONF_PROTOCOL: "5", CONF_PROTOCOL: "5",
}, },
5, 5,
None,
), ),
], ],
ids=["v3.1", "v3.1.1", "v5"],
) )
async def test_setup_mqtt_client_protocol( async def test_setup_mqtt_client_clean_session_and_protocol(
mqtt_mock_entry: MqttMockHAClientGenerator, protocol: int hass: HomeAssistant,
mqtt_mock_entry: MqttMockHAClientGenerator,
mqtt_client_mock: MqttMockPahoClient,
protocol: int,
clean_session: bool | None,
) -> None: ) -> None:
"""Test MQTT client protocol setup.""" """Test MQTT client clean_session and protocol setup."""
with patch( with patch(
"homeassistant.components.mqtt.async_client.AsyncMQTTClient" "homeassistant.components.mqtt.async_client.AsyncMQTTClient"
) as mock_client: ) as mock_client:
await mqtt_mock_entry() await mqtt_mock_entry()
# check if clean_session was correctly
assert mock_client.call_args[1]["clean_session"] == clean_session
# check if protocol setup was correctly # check if protocol setup was correctly
assert mock_client.call_args[1]["protocol"] == protocol assert mock_client.call_args[1]["protocol"] == protocol
@pytest.mark.parametrize(
("mqtt_config_entry_data", "connect_args"),
[
(
{
mqtt.CONF_BROKER: "mock-broker",
CONF_PROTOCOL: "3.1",
},
call("mock-broker", 1883, 60, "", 0, 3),
),
(
{
mqtt.CONF_BROKER: "mock-broker",
CONF_PROTOCOL: "3.1.1",
},
call("mock-broker", 1883, 60, "", 0, 3),
),
(
{
mqtt.CONF_BROKER: "mock-broker",
CONF_PROTOCOL: "5",
},
call("mock-broker", 1883, 60, "", 0, True),
),
],
ids=["v3.1", "v3.1.1", "v5"],
)
async def test_setup_mqtt_client_clean_start(
hass: HomeAssistant,
mqtt_mock_entry: MqttMockHAClientGenerator,
mqtt_client_mock: MqttMockPahoClient,
connect_args: tuple[Any],
) -> None:
"""Test MQTT client protocol connects with `clean_start` set correctly."""
await mqtt_mock_entry()
# check if clean_start was set correctly
assert len(mqtt_client_mock.connect.mock_calls) == 1
assert mqtt_client_mock.connect.mock_calls[0] == connect_args
@patch("homeassistant.components.mqtt.client.TIMEOUT_ACK", 0.2) @patch("homeassistant.components.mqtt.client.TIMEOUT_ACK", 0.2)
async def test_handle_mqtt_timeout_on_callback( async def test_handle_mqtt_timeout_on_callback(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture, mock_debouncer: asyncio.Event hass: HomeAssistant, caplog: pytest.LogCaptureFixture, mock_debouncer: asyncio.Event