mirror of
https://github.com/home-assistant/core.git
synced 2025-07-30 02:38:10 +02:00
Fail tests if recorder creates nested sessions (#122579)
* Fail tests if recorder creates nested sessions * Adjust import order * Move get_instance
This commit is contained in:
@ -3,13 +3,25 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable, Generator
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
import functools
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.util.hass_dict import HassKey
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy.orm.session import Session
|
||||
|
||||
from homeassistant.components.recorder import Recorder
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
DOMAIN: HassKey[RecorderData] = HassKey("recorder")
|
||||
DATA_INSTANCE: HassKey[Recorder] = HassKey("recorder_instance")
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
@ -56,3 +68,45 @@ async def async_wait_recorder(hass: HomeAssistant) -> bool:
|
||||
if DOMAIN not in hass.data:
|
||||
return False
|
||||
return await hass.data[DOMAIN].db_connected
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=1)
|
||||
def get_instance(hass: HomeAssistant) -> Recorder:
|
||||
"""Get the recorder instance."""
|
||||
return hass.data[DATA_INSTANCE]
|
||||
|
||||
|
||||
@contextmanager
|
||||
def session_scope(
|
||||
*,
|
||||
hass: HomeAssistant | None = None,
|
||||
session: Session | None = None,
|
||||
exception_filter: Callable[[Exception], bool] | None = None,
|
||||
read_only: bool = False,
|
||||
) -> Generator[Session]:
|
||||
"""Provide a transactional scope around a series of operations.
|
||||
|
||||
read_only is used to indicate that the session is only used for reading
|
||||
data and that no commit is required. It does not prevent the session
|
||||
from writing and is not a security measure.
|
||||
"""
|
||||
if session is None and hass is not None:
|
||||
session = get_instance(hass).get_session()
|
||||
|
||||
if session is None:
|
||||
raise RuntimeError("Session required")
|
||||
|
||||
need_rollback = False
|
||||
try:
|
||||
yield session
|
||||
if not read_only and session.get_transaction():
|
||||
need_rollback = True
|
||||
session.commit()
|
||||
except Exception as err:
|
||||
_LOGGER.exception("Error executing query")
|
||||
if need_rollback:
|
||||
session.rollback()
|
||||
if not exception_filter or not exception_filter(err):
|
||||
raise
|
||||
finally:
|
||||
session.close()
|
||||
|
Reference in New Issue
Block a user