This commit is contained in:
J. Nick Koston
2023-03-09 15:27:22 -10:00
parent 7afe07bc08
commit 89b43bab3d
2 changed files with 15 additions and 9 deletions

View File

@@ -1319,15 +1319,16 @@ def migrate_event_type_ids(instance: Recorder) -> bool:
db_event_type.event_type db_event_type.event_type
] = db_event_type.event_type_id ] = db_event_type.event_type_id
for event_id, event_type in events: session.execute(
session.execute( update(Events),
update(Events), [
{ {
"event_id": event_id, "event_id": event_id,
"event_type_id": event_type_to_id[event_type], "event_type_id": event_type_to_id[event_type],
}, }
) for event_id, event_type in events
session.commit() ],
)
# If there is more work to do return False # If there is more work to do return False
# so that we can be called again # so that we can be called again

View File

@@ -11,6 +11,7 @@ from sqlalchemy.orm import Session
from homeassistant.components import recorder from homeassistant.components import recorder
from homeassistant.components.recorder import SQLITE_URL_PREFIX, core, statistics from homeassistant.components.recorder import SQLITE_URL_PREFIX, core, statistics
from homeassistant.components.recorder.queries import select_event_type_ids
from homeassistant.components.recorder.util import session_scope from homeassistant.components.recorder.util import session_scope
from homeassistant.core import EVENT_STATE_CHANGED, Event, EventOrigin, State from homeassistant.core import EVENT_STATE_CHANGED, Event, EventOrigin, State
from homeassistant.helpers import recorder as recorder_helper from homeassistant.helpers import recorder as recorder_helper
@@ -87,7 +88,9 @@ def test_migrate_times(caplog: pytest.LogCaptureFixture, tmpdir) -> None:
with patch.object(recorder, "db_schema", old_db_schema), patch.object( with patch.object(recorder, "db_schema", old_db_schema), patch.object(
recorder.migration, "SCHEMA_VERSION", old_db_schema.SCHEMA_VERSION recorder.migration, "SCHEMA_VERSION", old_db_schema.SCHEMA_VERSION
), patch.object(core, "EventData", old_db_schema.EventData), patch.object( ), patch.object(core, "EventTypes", old_db_schema.EventTypes), patch.object(
core, "EventData", old_db_schema.EventData
), patch.object(
core, "States", old_db_schema.States core, "States", old_db_schema.States
), patch.object( ), patch.object(
core, "Events", old_db_schema.Events core, "Events", old_db_schema.Events
@@ -117,8 +120,10 @@ def test_migrate_times(caplog: pytest.LogCaptureFixture, tmpdir) -> None:
wait_recording_done(hass) wait_recording_done(hass)
with session_scope(hass=hass) as session: with session_scope(hass=hass) as session:
result = list( result = list(
session.query(recorder.db_schema.Events).where( session.query(recorder.db_schema.Events).filter(
recorder.db_schema.Events.event_type == "custom_event" recorder.db_schema.Events.event_type_id.in_(
select_event_type_ids(("custom_event",))
)
) )
) )
assert len(result) == 1 assert len(result) == 1