mirror of
https://github.com/home-assistant/core.git
synced 2025-07-29 18:28:14 +02:00
Prevent recursive script calls from deadlocking (#67861)
* Prevent recursive script calls from deadlocking * Address review comments, improve tests * Tweak comment
This commit is contained in:
@ -27,6 +27,13 @@ from homeassistant.core import (
|
||||
from homeassistant.exceptions import ServiceNotFound
|
||||
from homeassistant.helpers import template
|
||||
from homeassistant.helpers.event import async_track_state_change
|
||||
from homeassistant.helpers.script import (
|
||||
SCRIPT_MODE_CHOICES,
|
||||
SCRIPT_MODE_PARALLEL,
|
||||
SCRIPT_MODE_QUEUED,
|
||||
SCRIPT_MODE_RESTART,
|
||||
SCRIPT_MODE_SINGLE,
|
||||
)
|
||||
from homeassistant.helpers.service import async_get_all_descriptions
|
||||
from homeassistant.setup import async_setup_component
|
||||
import homeassistant.util.dt as dt_util
|
||||
@ -790,3 +797,121 @@ async def test_script_restore_last_triggered(hass: HomeAssistant) -> None:
|
||||
state = hass.states.get("script.last_triggered")
|
||||
assert state
|
||||
assert state.attributes["last_triggered"] == time
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"script_mode,warning_msg",
|
||||
(
|
||||
(SCRIPT_MODE_PARALLEL, "Maximum number of runs exceeded"),
|
||||
(SCRIPT_MODE_QUEUED, "Disallowed recursion detected"),
|
||||
(SCRIPT_MODE_RESTART, "Disallowed recursion detected"),
|
||||
(SCRIPT_MODE_SINGLE, "Already running"),
|
||||
),
|
||||
)
|
||||
async def test_recursive_script(hass, script_mode, warning_msg, caplog):
|
||||
"""Test recursive script calls does not deadlock."""
|
||||
# Make sure we cover all script modes
|
||||
assert SCRIPT_MODE_CHOICES == [
|
||||
SCRIPT_MODE_PARALLEL,
|
||||
SCRIPT_MODE_QUEUED,
|
||||
SCRIPT_MODE_RESTART,
|
||||
SCRIPT_MODE_SINGLE,
|
||||
]
|
||||
|
||||
assert await async_setup_component(
|
||||
hass,
|
||||
"script",
|
||||
{
|
||||
"script": {
|
||||
"script1": {
|
||||
"mode": script_mode,
|
||||
"sequence": [
|
||||
{"service": "script.script1"},
|
||||
{"service": "test.script"},
|
||||
],
|
||||
},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
service_called = asyncio.Event()
|
||||
|
||||
async def async_service_handler(service):
|
||||
service_called.set()
|
||||
|
||||
hass.services.async_register("test", "script", async_service_handler)
|
||||
hass.states.async_set("input_boolean.test", "on")
|
||||
hass.states.async_set("input_boolean.test2", "off")
|
||||
|
||||
await hass.services.async_call("script", "script1")
|
||||
await asyncio.wait_for(service_called.wait(), 1)
|
||||
|
||||
assert warning_msg in caplog.text
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"script_mode,warning_msg",
|
||||
(
|
||||
(SCRIPT_MODE_PARALLEL, "Maximum number of runs exceeded"),
|
||||
(SCRIPT_MODE_QUEUED, "Disallowed recursion detected"),
|
||||
(SCRIPT_MODE_RESTART, "Disallowed recursion detected"),
|
||||
(SCRIPT_MODE_SINGLE, "Already running"),
|
||||
),
|
||||
)
|
||||
async def test_recursive_script_indirect(hass, script_mode, warning_msg, caplog):
|
||||
"""Test recursive script calls does not deadlock."""
|
||||
# Make sure we cover all script modes
|
||||
assert SCRIPT_MODE_CHOICES == [
|
||||
SCRIPT_MODE_PARALLEL,
|
||||
SCRIPT_MODE_QUEUED,
|
||||
SCRIPT_MODE_RESTART,
|
||||
SCRIPT_MODE_SINGLE,
|
||||
]
|
||||
|
||||
assert await async_setup_component(
|
||||
hass,
|
||||
"script",
|
||||
{
|
||||
"script": {
|
||||
"script1": {
|
||||
"mode": script_mode,
|
||||
"sequence": [
|
||||
{"service": "script.script2"},
|
||||
],
|
||||
},
|
||||
"script2": {
|
||||
"mode": script_mode,
|
||||
"sequence": [
|
||||
{"service": "script.script3"},
|
||||
],
|
||||
},
|
||||
"script3": {
|
||||
"mode": script_mode,
|
||||
"sequence": [
|
||||
{"service": "script.script4"},
|
||||
],
|
||||
},
|
||||
"script4": {
|
||||
"mode": script_mode,
|
||||
"sequence": [
|
||||
{"service": "script.script1"},
|
||||
{"service": "test.script"},
|
||||
],
|
||||
},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
service_called = asyncio.Event()
|
||||
|
||||
async def async_service_handler(service):
|
||||
service_called.set()
|
||||
|
||||
hass.services.async_register("test", "script", async_service_handler)
|
||||
hass.states.async_set("input_boolean.test", "on")
|
||||
hass.states.async_set("input_boolean.test2", "off")
|
||||
|
||||
await hass.services.async_call("script", "script1")
|
||||
await asyncio.wait_for(service_called.wait(), 1)
|
||||
|
||||
assert warning_msg in caplog.text
|
||||
|
Reference in New Issue
Block a user