Fail tests if recorder creates nested sessions (#122579)

* Fail tests if recorder creates nested sessions

* Adjust import order

* Move get_instance
This commit is contained in:
Erik Montnemery
2024-07-25 21:18:55 +02:00
committed by GitHub
parent 32a0463f47
commit 5dbd7684ce
6 changed files with 139 additions and 49 deletions

View File

@ -3,13 +3,25 @@
from __future__ import annotations
import asyncio
from collections.abc import Callable, Generator
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Any
import functools
import logging
from typing import TYPE_CHECKING, Any
from homeassistant.core import HomeAssistant, callback
from homeassistant.util.hass_dict import HassKey
if TYPE_CHECKING:
from sqlalchemy.orm.session import Session
from homeassistant.components.recorder import Recorder
_LOGGER = logging.getLogger(__name__)
DOMAIN: HassKey[RecorderData] = HassKey("recorder")
DATA_INSTANCE: HassKey[Recorder] = HassKey("recorder_instance")
@dataclass(slots=True)
@ -56,3 +68,45 @@ async def async_wait_recorder(hass: HomeAssistant) -> bool:
if DOMAIN not in hass.data:
return False
return await hass.data[DOMAIN].db_connected
@functools.lru_cache(maxsize=1)
def get_instance(hass: HomeAssistant) -> Recorder:
"""Get the recorder instance."""
return hass.data[DATA_INSTANCE]
@contextmanager
def session_scope(
*,
hass: HomeAssistant | None = None,
session: Session | None = None,
exception_filter: Callable[[Exception], bool] | None = None,
read_only: bool = False,
) -> Generator[Session]:
"""Provide a transactional scope around a series of operations.
read_only is used to indicate that the session is only used for reading
data and that no commit is required. It does not prevent the session
from writing and is not a security measure.
"""
if session is None and hass is not None:
session = get_instance(hass).get_session()
if session is None:
raise RuntimeError("Session required")
need_rollback = False
try:
yield session
if not read_only and session.get_transaction():
need_rollback = True
session.commit()
except Exception as err:
_LOGGER.exception("Error executing query")
if need_rollback:
session.rollback()
if not exception_filter or not exception_filter(err):
raise
finally:
session.close()