Update Amberelectric to use amberelectric version 2.0.12 (#125701)

* Add price descriptor attribute to price sensors

* Adding price descriptor sensor

* Use correct number of sensors in spike sensor tests

* Add tests for normalize_descriptor

* Removing debug message

* Removing price_descriptor attribute from the current sensor

* Refactoring everything to use the new API

* Use SiteStatus object, fix some typnig issues

* fixing test

* Adding predicted price to attributes

* Fix advanced price in forecast

* Testing advanced forecasts

* WIP: Adding advanced forecast sensor. need to add attributes, and tests

* Add advanced price attributes

* Adding forecasts to the advanced price sensor

* Appending forecasts corectly

* Appending forecasts correctly. Again

* Removing sensor for the moment. Will do in another PR

* Fix failing test that had the wrong sign

* Adding test to improve coverage on config_flow test

* Bumping amberelectric dependency to version 2

* Remove advanced code from helpers

* Use f-strings

* Bumping to version 2.0.1

* Bumping amberelectric to version 2.0.2

* Bumping amberelectric to version 2.0.2

* Bumping verion amberelectric.py to 2.0.3. Using correct enums

* Bumping amberelectric.py version to 2.0.4

* Bump version to 2.0.5

* Fix formatting

* fixing mocks to include interval_length

* Bumping to 2.0.6

* Bumping to 2.0.7

* Bumping to 2.0.8

* Bumping to 2.0.9

* Bumping version 2.0.12
This commit is contained in:
Myles Eftos
2024-11-20 21:27:24 +11:00
committed by GitHub
parent 2cfacd8bc5
commit 621c66a214
12 changed files with 352 additions and 254 deletions

View File

@ -1,7 +1,6 @@
"""Support for Amber Electric."""
from amberelectric import Configuration
from amberelectric.api import amber_api
import amberelectric
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_API_TOKEN
@ -15,8 +14,9 @@ type AmberConfigEntry = ConfigEntry[AmberUpdateCoordinator]
async def async_setup_entry(hass: HomeAssistant, entry: AmberConfigEntry) -> bool:
"""Set up Amber Electric from a config entry."""
configuration = Configuration(access_token=entry.data[CONF_API_TOKEN])
api_instance = amber_api.AmberApi.create(configuration)
configuration = amberelectric.Configuration(access_token=entry.data[CONF_API_TOKEN])
api_client = amberelectric.ApiClient(configuration)
api_instance = amberelectric.AmberApi(api_client)
site_id = entry.data[CONF_SITE_ID]
coordinator = AmberUpdateCoordinator(hass, api_instance, site_id)

View File

@ -3,8 +3,8 @@
from __future__ import annotations
import amberelectric
from amberelectric.api import amber_api
from amberelectric.model.site import Site, SiteStatus
from amberelectric.models.site import Site
from amberelectric.models.site_status import SiteStatus
import voluptuous as vol
from homeassistant.config_entries import ConfigFlow, ConfigFlowResult
@ -23,11 +23,15 @@ API_URL = "https://app.amber.com.au/developers"
def generate_site_selector_name(site: Site) -> str:
"""Generate the name to show in the site drop down in the configuration flow."""
# For some reason the generated API key returns this as any, not a string. Thanks pydantic
nmi = str(site.nmi)
if site.status == SiteStatus.CLOSED:
return site.nmi + " (Closed: " + site.closed_on.isoformat() + ")" # type: ignore[no-any-return]
if site.closed_on is None:
return f"{nmi} (Closed)"
return f"{nmi} (Closed: {site.closed_on.isoformat()})"
if site.status == SiteStatus.PENDING:
return site.nmi + " (Pending)" # type: ignore[no-any-return]
return site.nmi # type: ignore[no-any-return]
return f"{nmi} (Pending)"
return nmi
def filter_sites(sites: list[Site]) -> list[Site]:
@ -35,7 +39,7 @@ def filter_sites(sites: list[Site]) -> list[Site]:
filtered: list[Site] = []
filtered_nmi: set[str] = set()
for site in sorted(sites, key=lambda site: site.status.value):
for site in sorted(sites, key=lambda site: site.status):
if site.status == SiteStatus.ACTIVE or site.nmi not in filtered_nmi:
filtered.append(site)
filtered_nmi.add(site.nmi)
@ -56,7 +60,8 @@ class AmberElectricConfigFlow(ConfigFlow, domain=DOMAIN):
def _fetch_sites(self, token: str) -> list[Site] | None:
configuration = amberelectric.Configuration(access_token=token)
api: amber_api.AmberApi = amber_api.AmberApi.create(configuration)
api_client = amberelectric.ApiClient(configuration)
api = amberelectric.AmberApi(api_client)
try:
sites: list[Site] = filter_sites(api.get_sites())

View File

@ -5,13 +5,13 @@ from __future__ import annotations
from datetime import timedelta
from typing import Any
from amberelectric import ApiException
from amberelectric.api import amber_api
from amberelectric.model.actual_interval import ActualInterval
from amberelectric.model.channel import ChannelType
from amberelectric.model.current_interval import CurrentInterval
from amberelectric.model.forecast_interval import ForecastInterval
from amberelectric.model.interval import Descriptor
import amberelectric
from amberelectric.models.actual_interval import ActualInterval
from amberelectric.models.channel import ChannelType
from amberelectric.models.current_interval import CurrentInterval
from amberelectric.models.forecast_interval import ForecastInterval
from amberelectric.models.price_descriptor import PriceDescriptor
from amberelectric.rest import ApiException
from homeassistant.core import HomeAssistant
from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed
@ -31,22 +31,22 @@ def is_forecast(interval: ActualInterval | CurrentInterval | ForecastInterval) -
def is_general(interval: ActualInterval | CurrentInterval | ForecastInterval) -> bool:
"""Return true if the supplied interval is on the general channel."""
return interval.channel_type == ChannelType.GENERAL # type: ignore[no-any-return]
return interval.channel_type == ChannelType.GENERAL
def is_controlled_load(
interval: ActualInterval | CurrentInterval | ForecastInterval,
) -> bool:
"""Return true if the supplied interval is on the controlled load channel."""
return interval.channel_type == ChannelType.CONTROLLED_LOAD # type: ignore[no-any-return]
return interval.channel_type == ChannelType.CONTROLLEDLOAD
def is_feed_in(interval: ActualInterval | CurrentInterval | ForecastInterval) -> bool:
"""Return true if the supplied interval is on the feed in channel."""
return interval.channel_type == ChannelType.FEED_IN # type: ignore[no-any-return]
return interval.channel_type == ChannelType.FEEDIN
def normalize_descriptor(descriptor: Descriptor) -> str | None:
def normalize_descriptor(descriptor: PriceDescriptor | None) -> str | None:
"""Return the snake case versions of descriptor names. Returns None if the name is not recognized."""
if descriptor is None:
return None
@ -71,7 +71,7 @@ class AmberUpdateCoordinator(DataUpdateCoordinator):
"""AmberUpdateCoordinator - In charge of downloading the data for a site, which all the sensors read."""
def __init__(
self, hass: HomeAssistant, api: amber_api.AmberApi, site_id: str
self, hass: HomeAssistant, api: amberelectric.AmberApi, site_id: str
) -> None:
"""Initialise the data service."""
super().__init__(
@ -93,12 +93,13 @@ class AmberUpdateCoordinator(DataUpdateCoordinator):
"grid": {},
}
try:
data = self._api.get_current_price(self.site_id, next=48)
data = self._api.get_current_prices(self.site_id, next=48)
intervals = [interval.actual_instance for interval in data]
except ApiException as api_exception:
raise UpdateFailed("Missing price data, skipping update") from api_exception
current = [interval for interval in data if is_current(interval)]
forecasts = [interval for interval in data if is_forecast(interval)]
current = [interval for interval in intervals if is_current(interval)]
forecasts = [interval for interval in intervals if is_forecast(interval)]
general = [interval for interval in current if is_general(interval)]
if len(general) == 0:
@ -137,7 +138,7 @@ class AmberUpdateCoordinator(DataUpdateCoordinator):
interval for interval in forecasts if is_feed_in(interval)
]
LOGGER.debug("Fetched new Amber data: %s", data)
LOGGER.debug("Fetched new Amber data: %s", intervals)
return result
async def _async_update_data(self) -> dict[str, Any]:

View File

@ -6,5 +6,5 @@
"documentation": "https://www.home-assistant.io/integrations/amberelectric",
"iot_class": "cloud_polling",
"loggers": ["amberelectric"],
"requirements": ["amberelectric==1.1.1"]
"requirements": ["amberelectric==2.0.12"]
}

View File

@ -8,9 +8,9 @@ from __future__ import annotations
from typing import Any
from amberelectric.model.channel import ChannelType
from amberelectric.model.current_interval import CurrentInterval
from amberelectric.model.forecast_interval import ForecastInterval
from amberelectric.models.channel import ChannelType
from amberelectric.models.current_interval import CurrentInterval
from amberelectric.models.forecast_interval import ForecastInterval
from homeassistant.components.sensor import (
SensorEntity,
@ -52,7 +52,7 @@ class AmberSensor(CoordinatorEntity[AmberUpdateCoordinator], SensorEntity):
self,
coordinator: AmberUpdateCoordinator,
description: SensorEntityDescription,
channel_type: ChannelType,
channel_type: str,
) -> None:
"""Initialize the Sensor."""
super().__init__(coordinator)
@ -73,7 +73,7 @@ class AmberPriceSensor(AmberSensor):
"""Return the current price in $/kWh."""
interval = self.coordinator.data[self.entity_description.key][self.channel_type]
if interval.channel_type == ChannelType.FEED_IN:
if interval.channel_type == ChannelType.FEEDIN:
return format_cents_to_dollars(interval.per_kwh) * -1
return format_cents_to_dollars(interval.per_kwh)
@ -87,9 +87,9 @@ class AmberPriceSensor(AmberSensor):
return data
data["duration"] = interval.duration
data["date"] = interval.date.isoformat()
data["date"] = interval.var_date.isoformat()
data["per_kwh"] = format_cents_to_dollars(interval.per_kwh)
if interval.channel_type == ChannelType.FEED_IN:
if interval.channel_type == ChannelType.FEEDIN:
data["per_kwh"] = data["per_kwh"] * -1
data["nem_date"] = interval.nem_time.isoformat()
data["spot_per_kwh"] = format_cents_to_dollars(interval.spot_per_kwh)
@ -120,7 +120,7 @@ class AmberForecastSensor(AmberSensor):
return None
interval = intervals[0]
if interval.channel_type == ChannelType.FEED_IN:
if interval.channel_type == ChannelType.FEEDIN:
return format_cents_to_dollars(interval.per_kwh) * -1
return format_cents_to_dollars(interval.per_kwh)
@ -142,10 +142,10 @@ class AmberForecastSensor(AmberSensor):
for interval in intervals:
datum = {}
datum["duration"] = interval.duration
datum["date"] = interval.date.isoformat()
datum["date"] = interval.var_date.isoformat()
datum["nem_date"] = interval.nem_time.isoformat()
datum["per_kwh"] = format_cents_to_dollars(interval.per_kwh)
if interval.channel_type == ChannelType.FEED_IN:
if interval.channel_type == ChannelType.FEEDIN:
datum["per_kwh"] = datum["per_kwh"] * -1
datum["spot_per_kwh"] = format_cents_to_dollars(interval.spot_per_kwh)
datum["start_time"] = interval.start_time.isoformat()

View File

@ -447,7 +447,7 @@ airtouch5py==0.2.10
alpha-vantage==2.3.1
# homeassistant.components.amberelectric
amberelectric==1.1.1
amberelectric==2.0.12
# homeassistant.components.amcrest
amcrest==1.9.8

View File

@ -426,7 +426,7 @@ airtouch4pyapi==1.0.5
airtouch5py==0.2.10
# homeassistant.components.amberelectric
amberelectric==1.1.1
amberelectric==2.0.12
# homeassistant.components.androidtv
androidtv[async]==0.0.73

View File

@ -2,73 +2,82 @@
from datetime import datetime, timedelta
from amberelectric.model.actual_interval import ActualInterval
from amberelectric.model.channel import ChannelType
from amberelectric.model.current_interval import CurrentInterval
from amberelectric.model.forecast_interval import ForecastInterval
from amberelectric.model.interval import Descriptor, SpikeStatus
from amberelectric.models.actual_interval import ActualInterval
from amberelectric.models.channel import ChannelType
from amberelectric.models.current_interval import CurrentInterval
from amberelectric.models.forecast_interval import ForecastInterval
from amberelectric.models.interval import Interval
from amberelectric.models.price_descriptor import PriceDescriptor
from amberelectric.models.spike_status import SpikeStatus
from dateutil import parser
def generate_actual_interval(
channel_type: ChannelType, end_time: datetime
) -> ActualInterval:
def generate_actual_interval(channel_type: ChannelType, end_time: datetime) -> Interval:
"""Generate a mock actual interval."""
start_time = end_time - timedelta(minutes=30)
return ActualInterval(
duration=30,
spot_per_kwh=1.0,
per_kwh=8.0,
date=start_time.date(),
nem_time=end_time,
start_time=start_time,
end_time=end_time,
renewables=50,
channel_type=channel_type.value,
spike_status=SpikeStatus.NO_SPIKE.value,
descriptor=Descriptor.LOW.value,
return Interval(
ActualInterval(
type="ActualInterval",
duration=30,
spot_per_kwh=1.0,
per_kwh=8.0,
date=start_time.date(),
nem_time=end_time,
start_time=start_time,
end_time=end_time,
renewables=50,
channel_type=channel_type,
spike_status=SpikeStatus.NONE,
descriptor=PriceDescriptor.LOW,
)
)
def generate_current_interval(
channel_type: ChannelType, end_time: datetime
) -> CurrentInterval:
) -> Interval:
"""Generate a mock current price."""
start_time = end_time - timedelta(minutes=30)
return CurrentInterval(
duration=30,
spot_per_kwh=1.0,
per_kwh=8.0,
date=start_time.date(),
nem_time=end_time,
start_time=start_time,
end_time=end_time,
renewables=50.6,
channel_type=channel_type.value,
spike_status=SpikeStatus.NO_SPIKE.value,
descriptor=Descriptor.EXTREMELY_LOW.value,
estimate=True,
return Interval(
CurrentInterval(
type="CurrentInterval",
duration=30,
spot_per_kwh=1.0,
per_kwh=8.0,
date=start_time.date(),
nem_time=end_time,
start_time=start_time,
end_time=end_time,
renewables=50.6,
channel_type=channel_type,
spike_status=SpikeStatus.NONE,
descriptor=PriceDescriptor.EXTREMELYLOW,
estimate=True,
)
)
def generate_forecast_interval(
channel_type: ChannelType, end_time: datetime
) -> ForecastInterval:
) -> Interval:
"""Generate a mock forecast interval."""
start_time = end_time - timedelta(minutes=30)
return ForecastInterval(
duration=30,
spot_per_kwh=1.1,
per_kwh=8.8,
date=start_time.date(),
nem_time=end_time,
start_time=start_time,
end_time=end_time,
renewables=50,
channel_type=channel_type.value,
spike_status=SpikeStatus.NO_SPIKE.value,
descriptor=Descriptor.VERY_LOW.value,
estimate=True,
return Interval(
ForecastInterval(
type="ForecastInterval",
duration=30,
spot_per_kwh=1.1,
per_kwh=8.8,
date=start_time.date(),
nem_time=end_time,
start_time=start_time,
end_time=end_time,
renewables=50,
channel_type=channel_type,
spike_status=SpikeStatus.NONE,
descriptor=PriceDescriptor.VERYLOW,
estimate=True,
)
)
@ -94,31 +103,31 @@ GENERAL_CHANNEL = [
CONTROLLED_LOAD_CHANNEL = [
generate_current_interval(
ChannelType.CONTROLLED_LOAD, parser.parse("2021-09-21T08:30:00+10:00")
ChannelType.CONTROLLEDLOAD, parser.parse("2021-09-21T08:30:00+10:00")
),
generate_forecast_interval(
ChannelType.CONTROLLED_LOAD, parser.parse("2021-09-21T09:00:00+10:00")
ChannelType.CONTROLLEDLOAD, parser.parse("2021-09-21T09:00:00+10:00")
),
generate_forecast_interval(
ChannelType.CONTROLLED_LOAD, parser.parse("2021-09-21T09:30:00+10:00")
ChannelType.CONTROLLEDLOAD, parser.parse("2021-09-21T09:30:00+10:00")
),
generate_forecast_interval(
ChannelType.CONTROLLED_LOAD, parser.parse("2021-09-21T10:00:00+10:00")
ChannelType.CONTROLLEDLOAD, parser.parse("2021-09-21T10:00:00+10:00")
),
]
FEED_IN_CHANNEL = [
generate_current_interval(
ChannelType.FEED_IN, parser.parse("2021-09-21T08:30:00+10:00")
ChannelType.FEEDIN, parser.parse("2021-09-21T08:30:00+10:00")
),
generate_forecast_interval(
ChannelType.FEED_IN, parser.parse("2021-09-21T09:00:00+10:00")
ChannelType.FEEDIN, parser.parse("2021-09-21T09:00:00+10:00")
),
generate_forecast_interval(
ChannelType.FEED_IN, parser.parse("2021-09-21T09:30:00+10:00")
ChannelType.FEEDIN, parser.parse("2021-09-21T09:30:00+10:00")
),
generate_forecast_interval(
ChannelType.FEED_IN, parser.parse("2021-09-21T10:00:00+10:00")
ChannelType.FEEDIN, parser.parse("2021-09-21T10:00:00+10:00")
),
]

View File

@ -5,10 +5,10 @@ from __future__ import annotations
from collections.abc import AsyncGenerator
from unittest.mock import Mock, patch
from amberelectric.model.channel import ChannelType
from amberelectric.model.current_interval import CurrentInterval
from amberelectric.model.interval import SpikeStatus
from amberelectric.model.tariff_information import TariffInformation
from amberelectric.models.channel import ChannelType
from amberelectric.models.current_interval import CurrentInterval
from amberelectric.models.spike_status import SpikeStatus
from amberelectric.models.tariff_information import TariffInformation
from dateutil import parser
import pytest
@ -42,10 +42,10 @@ async def setup_no_spike(hass: HomeAssistant) -> AsyncGenerator[Mock]:
instance = Mock()
with patch(
"amberelectric.api.AmberApi.create",
"amberelectric.AmberApi",
return_value=instance,
) as mock_update:
instance.get_current_price = Mock(return_value=GENERAL_CHANNEL)
instance.get_current_prices = Mock(return_value=GENERAL_CHANNEL)
assert await async_setup_component(hass, DOMAIN, {})
await hass.async_block_till_done()
yield mock_update.return_value
@ -65,7 +65,7 @@ async def setup_potential_spike(hass: HomeAssistant) -> AsyncGenerator[Mock]:
instance = Mock()
with patch(
"amberelectric.api.AmberApi.create",
"amberelectric.AmberApi",
return_value=instance,
) as mock_update:
general_channel: list[CurrentInterval] = [
@ -73,8 +73,8 @@ async def setup_potential_spike(hass: HomeAssistant) -> AsyncGenerator[Mock]:
ChannelType.GENERAL, parser.parse("2021-09-21T08:30:00+10:00")
),
]
general_channel[0].spike_status = SpikeStatus.POTENTIAL
instance.get_current_price = Mock(return_value=general_channel)
general_channel[0].actual_instance.spike_status = SpikeStatus.POTENTIAL
instance.get_current_prices = Mock(return_value=general_channel)
assert await async_setup_component(hass, DOMAIN, {})
await hass.async_block_till_done()
yield mock_update.return_value
@ -94,7 +94,7 @@ async def setup_spike(hass: HomeAssistant) -> AsyncGenerator[Mock]:
instance = Mock()
with patch(
"amberelectric.api.AmberApi.create",
"amberelectric.AmberApi",
return_value=instance,
) as mock_update:
general_channel: list[CurrentInterval] = [
@ -102,8 +102,8 @@ async def setup_spike(hass: HomeAssistant) -> AsyncGenerator[Mock]:
ChannelType.GENERAL, parser.parse("2021-09-21T08:30:00+10:00")
),
]
general_channel[0].spike_status = SpikeStatus.SPIKE
instance.get_current_price = Mock(return_value=general_channel)
general_channel[0].actual_instance.spike_status = SpikeStatus.SPIKE
instance.get_current_prices = Mock(return_value=general_channel)
assert await async_setup_component(hass, DOMAIN, {})
await hass.async_block_till_done()
yield mock_update.return_value
@ -156,7 +156,7 @@ async def setup_inactive_demand_window(hass: HomeAssistant) -> AsyncGenerator[Mo
instance = Mock()
with patch(
"amberelectric.api.AmberApi.create",
"amberelectric.AmberApi",
return_value=instance,
) as mock_update:
general_channel: list[CurrentInterval] = [
@ -164,8 +164,10 @@ async def setup_inactive_demand_window(hass: HomeAssistant) -> AsyncGenerator[Mo
ChannelType.GENERAL, parser.parse("2021-09-21T08:30:00+10:00")
),
]
general_channel[0].tariff_information = TariffInformation(demandWindow=False)
instance.get_current_price = Mock(return_value=general_channel)
general_channel[0].actual_instance.tariff_information = TariffInformation(
demandWindow=False
)
instance.get_current_prices = Mock(return_value=general_channel)
assert await async_setup_component(hass, DOMAIN, {})
await hass.async_block_till_done()
yield mock_update.return_value
@ -185,7 +187,7 @@ async def setup_active_demand_window(hass: HomeAssistant) -> AsyncGenerator[Mock
instance = Mock()
with patch(
"amberelectric.api.AmberApi.create",
"amberelectric.AmberApi",
return_value=instance,
) as mock_update:
general_channel: list[CurrentInterval] = [
@ -193,8 +195,10 @@ async def setup_active_demand_window(hass: HomeAssistant) -> AsyncGenerator[Mock
ChannelType.GENERAL, parser.parse("2021-09-21T08:30:00+10:00")
),
]
general_channel[0].tariff_information = TariffInformation(demandWindow=True)
instance.get_current_price = Mock(return_value=general_channel)
general_channel[0].actual_instance.tariff_information = TariffInformation(
demandWindow=True
)
instance.get_current_prices = Mock(return_value=general_channel)
assert await async_setup_component(hass, DOMAIN, {})
await hass.async_block_till_done()
yield mock_update.return_value

View File

@ -5,7 +5,8 @@ from datetime import date
from unittest.mock import Mock, patch
from amberelectric import ApiException
from amberelectric.model.site import Site, SiteStatus
from amberelectric.models.site import Site
from amberelectric.models.site_status import SiteStatus
import pytest
from homeassistant.components.amberelectric.config_flow import filter_sites
@ -28,7 +29,7 @@ pytestmark = pytest.mark.usefixtures("mock_setup_entry")
def mock_invalid_key_api() -> Generator:
"""Return an authentication error."""
with patch("amberelectric.api.AmberApi.create") as mock:
with patch("amberelectric.AmberApi") as mock:
mock.return_value.get_sites.side_effect = ApiException(status=403)
yield mock
@ -36,7 +37,7 @@ def mock_invalid_key_api() -> Generator:
@pytest.fixture(name="api_error")
def mock_api_error() -> Generator:
"""Return an authentication error."""
with patch("amberelectric.api.AmberApi.create") as mock:
with patch("amberelectric.AmberApi") as mock:
mock.return_value.get_sites.side_effect = ApiException(status=500)
yield mock
@ -45,16 +46,36 @@ def mock_api_error() -> Generator:
def mock_single_site_api() -> Generator:
"""Return a single site."""
site = Site(
"01FG0AGP818PXK0DWHXJRRT2DH",
"11111111111",
[],
"Jemena",
SiteStatus.ACTIVE,
date(2002, 1, 1),
None,
id="01FG0AGP818PXK0DWHXJRRT2DH",
nmi="11111111111",
channels=[],
network="Jemena",
status=SiteStatus.ACTIVE,
active_from=date(2002, 1, 1),
closed_on=None,
interval_length=30,
)
with patch("amberelectric.api.AmberApi.create") as mock:
with patch("amberelectric.AmberApi") as mock:
mock.return_value.get_sites.return_value = [site]
yield mock
@pytest.fixture(name="single_site_closed_no_close_date_api")
def single_site_closed_no_close_date_api() -> Generator:
"""Return a single closed site with no closed date."""
site = Site(
id="01FG0AGP818PXK0DWHXJRRT2DH",
nmi="11111111111",
channels=[],
network="Jemena",
status=SiteStatus.CLOSED,
active_from=None,
closed_on=None,
interval_length=30,
)
with patch("amberelectric.AmberApi") as mock:
mock.return_value.get_sites.return_value = [site]
yield mock
@ -63,16 +84,17 @@ def mock_single_site_api() -> Generator:
def mock_single_site_pending_api() -> Generator:
"""Return a single site."""
site = Site(
"01FG0AGP818PXK0DWHXJRRT2DH",
"11111111111",
[],
"Jemena",
SiteStatus.PENDING,
None,
None,
id="01FG0AGP818PXK0DWHXJRRT2DH",
nmi="11111111111",
channels=[],
network="Jemena",
status=SiteStatus.PENDING,
active_from=None,
closed_on=None,
interval_length=30,
)
with patch("amberelectric.api.AmberApi.create") as mock:
with patch("amberelectric.AmberApi") as mock:
mock.return_value.get_sites.return_value = [site]
yield mock
@ -82,35 +104,38 @@ def mock_single_site_rejoin_api() -> Generator:
"""Return a single site."""
instance = Mock()
site_1 = Site(
"01HGD9QB72HB3DWQNJ6SSCGXGV",
"11111111111",
[],
"Jemena",
SiteStatus.CLOSED,
date(2002, 1, 1),
date(2002, 6, 1),
id="01HGD9QB72HB3DWQNJ6SSCGXGV",
nmi="11111111111",
channels=[],
network="Jemena",
status=SiteStatus.CLOSED,
active_from=date(2002, 1, 1),
closed_on=date(2002, 6, 1),
interval_length=30,
)
site_2 = Site(
"01FG0AGP818PXK0DWHXJRRT2DH",
"11111111111",
[],
"Jemena",
SiteStatus.ACTIVE,
date(2003, 1, 1),
None,
id="01FG0AGP818PXK0DWHXJRRT2DH",
nmi="11111111111",
channels=[],
network="Jemena",
status=SiteStatus.ACTIVE,
active_from=date(2003, 1, 1),
closed_on=None,
interval_length=30,
)
site_3 = Site(
"01FG0AGP818PXK0DWHXJRRT2DH",
"11111111112",
[],
"Jemena",
SiteStatus.CLOSED,
date(2003, 1, 1),
date(2003, 6, 1),
id="01FG0AGP818PXK0DWHXJRRT2DH",
nmi="11111111112",
channels=[],
network="Jemena",
status=SiteStatus.CLOSED,
active_from=date(2003, 1, 1),
closed_on=date(2003, 6, 1),
interval_length=30,
)
instance.get_sites.return_value = [site_1, site_2, site_3]
with patch("amberelectric.api.AmberApi.create", return_value=instance):
with patch("amberelectric.AmberApi", return_value=instance):
yield instance
@ -120,7 +145,7 @@ def mock_no_site_api() -> Generator:
instance = Mock()
instance.get_sites.return_value = []
with patch("amberelectric.api.AmberApi.create", return_value=instance):
with patch("amberelectric.AmberApi", return_value=instance):
yield instance
@ -188,6 +213,39 @@ async def test_single_site(hass: HomeAssistant, single_site_api: Mock) -> None:
assert data[CONF_SITE_ID] == "01FG0AGP818PXK0DWHXJRRT2DH"
async def test_single_closed_site_no_closed_date(
hass: HomeAssistant, single_site_closed_no_close_date_api: Mock
) -> None:
"""Test single closed site with no closed date."""
initial_result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": SOURCE_USER}
)
assert initial_result.get("type") is FlowResultType.FORM
assert initial_result.get("step_id") == "user"
# Test filling in API key
enter_api_key_result = await hass.config_entries.flow.async_init(
DOMAIN,
context={"source": SOURCE_USER},
data={CONF_API_TOKEN: API_KEY},
)
assert enter_api_key_result.get("type") is FlowResultType.FORM
assert enter_api_key_result.get("step_id") == "site"
select_site_result = await hass.config_entries.flow.async_configure(
enter_api_key_result["flow_id"],
{CONF_SITE_ID: "01FG0AGP818PXK0DWHXJRRT2DH", CONF_SITE_NAME: "Home"},
)
# Show available sites
assert select_site_result.get("type") is FlowResultType.CREATE_ENTRY
assert select_site_result.get("title") == "Home"
data = select_site_result.get("data")
assert data
assert data[CONF_API_TOKEN] == API_KEY
assert data[CONF_SITE_ID] == "01FG0AGP818PXK0DWHXJRRT2DH"
async def test_single_site_rejoin(
hass: HomeAssistant, single_site_rejoin_api: Mock
) -> None:

View File

@ -7,10 +7,12 @@ from datetime import date
from unittest.mock import Mock, patch
from amberelectric import ApiException
from amberelectric.model.channel import Channel, ChannelType
from amberelectric.model.current_interval import CurrentInterval
from amberelectric.model.interval import Descriptor, SpikeStatus
from amberelectric.model.site import Site, SiteStatus
from amberelectric.models.channel import Channel, ChannelType
from amberelectric.models.interval import Interval
from amberelectric.models.price_descriptor import PriceDescriptor
from amberelectric.models.site import Site
from amberelectric.models.site_status import SiteStatus
from amberelectric.models.spike_status import SpikeStatus
from dateutil import parser
import pytest
@ -38,37 +40,40 @@ def mock_api_current_price() -> Generator:
instance = Mock()
general_site = Site(
GENERAL_ONLY_SITE_ID,
"11111111111",
[Channel(identifier="E1", type=ChannelType.GENERAL, tariff="A100")],
"Jemena",
SiteStatus.ACTIVE,
date(2021, 1, 1),
None,
id=GENERAL_ONLY_SITE_ID,
nmi="11111111111",
channels=[Channel(identifier="E1", type=ChannelType.GENERAL, tariff="A100")],
network="Jemena",
status=SiteStatus("active"),
activeFrom=date(2021, 1, 1),
closedOn=None,
interval_length=30,
)
general_and_controlled_load = Site(
GENERAL_AND_CONTROLLED_SITE_ID,
"11111111112",
[
id=GENERAL_AND_CONTROLLED_SITE_ID,
nmi="11111111112",
channels=[
Channel(identifier="E1", type=ChannelType.GENERAL, tariff="A100"),
Channel(identifier="E2", type=ChannelType.CONTROLLED_LOAD, tariff="A180"),
Channel(identifier="E2", type=ChannelType.CONTROLLEDLOAD, tariff="A180"),
],
"Jemena",
SiteStatus.ACTIVE,
date(2021, 1, 1),
None,
network="Jemena",
status=SiteStatus("active"),
activeFrom=date(2021, 1, 1),
closedOn=None,
interval_length=30,
)
general_and_feed_in = Site(
GENERAL_AND_FEED_IN_SITE_ID,
"11111111113",
[
id=GENERAL_AND_FEED_IN_SITE_ID,
nmi="11111111113",
channels=[
Channel(identifier="E1", type=ChannelType.GENERAL, tariff="A100"),
Channel(identifier="E2", type=ChannelType.FEED_IN, tariff="A100"),
Channel(identifier="E2", type=ChannelType.FEEDIN, tariff="A100"),
],
"Jemena",
SiteStatus.ACTIVE,
date(2021, 1, 1),
None,
network="Jemena",
status=SiteStatus("active"),
activeFrom=date(2021, 1, 1),
closedOn=None,
interval_length=30,
)
instance.get_sites.return_value = [
general_site,
@ -76,44 +81,46 @@ def mock_api_current_price() -> Generator:
general_and_feed_in,
]
with patch("amberelectric.api.AmberApi.create", return_value=instance):
with patch("amberelectric.AmberApi", return_value=instance):
yield instance
def test_normalize_descriptor() -> None:
"""Test normalizing descriptors works correctly."""
assert normalize_descriptor(None) is None
assert normalize_descriptor(Descriptor.NEGATIVE) == "negative"
assert normalize_descriptor(Descriptor.EXTREMELY_LOW) == "extremely_low"
assert normalize_descriptor(Descriptor.VERY_LOW) == "very_low"
assert normalize_descriptor(Descriptor.LOW) == "low"
assert normalize_descriptor(Descriptor.NEUTRAL) == "neutral"
assert normalize_descriptor(Descriptor.HIGH) == "high"
assert normalize_descriptor(Descriptor.SPIKE) == "spike"
assert normalize_descriptor(PriceDescriptor.NEGATIVE) == "negative"
assert normalize_descriptor(PriceDescriptor.EXTREMELYLOW) == "extremely_low"
assert normalize_descriptor(PriceDescriptor.VERYLOW) == "very_low"
assert normalize_descriptor(PriceDescriptor.LOW) == "low"
assert normalize_descriptor(PriceDescriptor.NEUTRAL) == "neutral"
assert normalize_descriptor(PriceDescriptor.HIGH) == "high"
assert normalize_descriptor(PriceDescriptor.SPIKE) == "spike"
async def test_fetch_general_site(hass: HomeAssistant, current_price_api: Mock) -> None:
"""Test fetching a site with only a general channel."""
current_price_api.get_current_price.return_value = GENERAL_CHANNEL
current_price_api.get_current_prices.return_value = GENERAL_CHANNEL
data_service = AmberUpdateCoordinator(hass, current_price_api, GENERAL_ONLY_SITE_ID)
result = await data_service._async_update_data()
current_price_api.get_current_price.assert_called_with(
current_price_api.get_current_prices.assert_called_with(
GENERAL_ONLY_SITE_ID, next=48
)
assert result["current"].get("general") == GENERAL_CHANNEL[0]
assert result["current"].get("general") == GENERAL_CHANNEL[0].actual_instance
assert result["forecasts"].get("general") == [
GENERAL_CHANNEL[1],
GENERAL_CHANNEL[2],
GENERAL_CHANNEL[3],
GENERAL_CHANNEL[1].actual_instance,
GENERAL_CHANNEL[2].actual_instance,
GENERAL_CHANNEL[3].actual_instance,
]
assert result["current"].get("controlled_load") is None
assert result["forecasts"].get("controlled_load") is None
assert result["current"].get("feed_in") is None
assert result["forecasts"].get("feed_in") is None
assert result["grid"]["renewables"] == round(GENERAL_CHANNEL[0].renewables)
assert result["grid"]["renewables"] == round(
GENERAL_CHANNEL[0].actual_instance.renewables
)
assert result["grid"]["price_spike"] == "none"
@ -122,12 +129,12 @@ async def test_fetch_no_general_site(
) -> None:
"""Test fetching a site with no general channel."""
current_price_api.get_current_price.return_value = CONTROLLED_LOAD_CHANNEL
current_price_api.get_current_prices.return_value = CONTROLLED_LOAD_CHANNEL
data_service = AmberUpdateCoordinator(hass, current_price_api, GENERAL_ONLY_SITE_ID)
with pytest.raises(UpdateFailed):
await data_service._async_update_data()
current_price_api.get_current_price.assert_called_with(
current_price_api.get_current_prices.assert_called_with(
GENERAL_ONLY_SITE_ID, next=48
)
@ -135,41 +142,45 @@ async def test_fetch_no_general_site(
async def test_fetch_api_error(hass: HomeAssistant, current_price_api: Mock) -> None:
"""Test that the old values are maintained if a second call fails."""
current_price_api.get_current_price.return_value = GENERAL_CHANNEL
current_price_api.get_current_prices.return_value = GENERAL_CHANNEL
data_service = AmberUpdateCoordinator(hass, current_price_api, GENERAL_ONLY_SITE_ID)
result = await data_service._async_update_data()
current_price_api.get_current_price.assert_called_with(
current_price_api.get_current_prices.assert_called_with(
GENERAL_ONLY_SITE_ID, next=48
)
assert result["current"].get("general") == GENERAL_CHANNEL[0]
assert result["current"].get("general") == GENERAL_CHANNEL[0].actual_instance
assert result["forecasts"].get("general") == [
GENERAL_CHANNEL[1],
GENERAL_CHANNEL[2],
GENERAL_CHANNEL[3],
GENERAL_CHANNEL[1].actual_instance,
GENERAL_CHANNEL[2].actual_instance,
GENERAL_CHANNEL[3].actual_instance,
]
assert result["current"].get("controlled_load") is None
assert result["forecasts"].get("controlled_load") is None
assert result["current"].get("feed_in") is None
assert result["forecasts"].get("feed_in") is None
assert result["grid"]["renewables"] == round(GENERAL_CHANNEL[0].renewables)
assert result["grid"]["renewables"] == round(
GENERAL_CHANNEL[0].actual_instance.renewables
)
current_price_api.get_current_price.side_effect = ApiException(status=403)
current_price_api.get_current_prices.side_effect = ApiException(status=403)
with pytest.raises(UpdateFailed):
await data_service._async_update_data()
assert result["current"].get("general") == GENERAL_CHANNEL[0]
assert result["current"].get("general") == GENERAL_CHANNEL[0].actual_instance
assert result["forecasts"].get("general") == [
GENERAL_CHANNEL[1],
GENERAL_CHANNEL[2],
GENERAL_CHANNEL[3],
GENERAL_CHANNEL[1].actual_instance,
GENERAL_CHANNEL[2].actual_instance,
GENERAL_CHANNEL[3].actual_instance,
]
assert result["current"].get("controlled_load") is None
assert result["forecasts"].get("controlled_load") is None
assert result["current"].get("feed_in") is None
assert result["forecasts"].get("feed_in") is None
assert result["grid"]["renewables"] == round(GENERAL_CHANNEL[0].renewables)
assert result["grid"]["renewables"] == round(
GENERAL_CHANNEL[0].actual_instance.renewables
)
assert result["grid"]["price_spike"] == "none"
@ -178,7 +189,7 @@ async def test_fetch_general_and_controlled_load_site(
) -> None:
"""Test fetching a site with a general and controlled load channel."""
current_price_api.get_current_price.return_value = (
current_price_api.get_current_prices.return_value = (
GENERAL_CHANNEL + CONTROLLED_LOAD_CHANNEL
)
data_service = AmberUpdateCoordinator(
@ -186,25 +197,30 @@ async def test_fetch_general_and_controlled_load_site(
)
result = await data_service._async_update_data()
current_price_api.get_current_price.assert_called_with(
current_price_api.get_current_prices.assert_called_with(
GENERAL_AND_CONTROLLED_SITE_ID, next=48
)
assert result["current"].get("general") == GENERAL_CHANNEL[0]
assert result["current"].get("general") == GENERAL_CHANNEL[0].actual_instance
assert result["forecasts"].get("general") == [
GENERAL_CHANNEL[1],
GENERAL_CHANNEL[2],
GENERAL_CHANNEL[3],
GENERAL_CHANNEL[1].actual_instance,
GENERAL_CHANNEL[2].actual_instance,
GENERAL_CHANNEL[3].actual_instance,
]
assert result["current"].get("controlled_load") is CONTROLLED_LOAD_CHANNEL[0]
assert (
result["current"].get("controlled_load")
is CONTROLLED_LOAD_CHANNEL[0].actual_instance
)
assert result["forecasts"].get("controlled_load") == [
CONTROLLED_LOAD_CHANNEL[1],
CONTROLLED_LOAD_CHANNEL[2],
CONTROLLED_LOAD_CHANNEL[3],
CONTROLLED_LOAD_CHANNEL[1].actual_instance,
CONTROLLED_LOAD_CHANNEL[2].actual_instance,
CONTROLLED_LOAD_CHANNEL[3].actual_instance,
]
assert result["current"].get("feed_in") is None
assert result["forecasts"].get("feed_in") is None
assert result["grid"]["renewables"] == round(GENERAL_CHANNEL[0].renewables)
assert result["grid"]["renewables"] == round(
GENERAL_CHANNEL[0].actual_instance.renewables
)
assert result["grid"]["price_spike"] == "none"
@ -213,31 +229,35 @@ async def test_fetch_general_and_feed_in_site(
) -> None:
"""Test fetching a site with a general and feed_in channel."""
current_price_api.get_current_price.return_value = GENERAL_CHANNEL + FEED_IN_CHANNEL
current_price_api.get_current_prices.return_value = (
GENERAL_CHANNEL + FEED_IN_CHANNEL
)
data_service = AmberUpdateCoordinator(
hass, current_price_api, GENERAL_AND_FEED_IN_SITE_ID
)
result = await data_service._async_update_data()
current_price_api.get_current_price.assert_called_with(
current_price_api.get_current_prices.assert_called_with(
GENERAL_AND_FEED_IN_SITE_ID, next=48
)
assert result["current"].get("general") == GENERAL_CHANNEL[0]
assert result["current"].get("general") == GENERAL_CHANNEL[0].actual_instance
assert result["forecasts"].get("general") == [
GENERAL_CHANNEL[1],
GENERAL_CHANNEL[2],
GENERAL_CHANNEL[3],
GENERAL_CHANNEL[1].actual_instance,
GENERAL_CHANNEL[2].actual_instance,
GENERAL_CHANNEL[3].actual_instance,
]
assert result["current"].get("controlled_load") is None
assert result["forecasts"].get("controlled_load") is None
assert result["current"].get("feed_in") is FEED_IN_CHANNEL[0]
assert result["current"].get("feed_in") is FEED_IN_CHANNEL[0].actual_instance
assert result["forecasts"].get("feed_in") == [
FEED_IN_CHANNEL[1],
FEED_IN_CHANNEL[2],
FEED_IN_CHANNEL[3],
FEED_IN_CHANNEL[1].actual_instance,
FEED_IN_CHANNEL[2].actual_instance,
FEED_IN_CHANNEL[3].actual_instance,
]
assert result["grid"]["renewables"] == round(GENERAL_CHANNEL[0].renewables)
assert result["grid"]["renewables"] == round(
GENERAL_CHANNEL[0].actual_instance.renewables
)
assert result["grid"]["price_spike"] == "none"
@ -246,13 +266,13 @@ async def test_fetch_potential_spike(
) -> None:
"""Test fetching a site with only a general channel."""
general_channel: list[CurrentInterval] = [
general_channel: list[Interval] = [
generate_current_interval(
ChannelType.GENERAL, parser.parse("2021-09-21T08:30:00+10:00")
),
)
]
general_channel[0].spike_status = SpikeStatus.POTENTIAL
current_price_api.get_current_price.return_value = general_channel
general_channel[0].actual_instance.spike_status = SpikeStatus.POTENTIAL
current_price_api.get_current_prices.return_value = general_channel
data_service = AmberUpdateCoordinator(hass, current_price_api, GENERAL_ONLY_SITE_ID)
result = await data_service._async_update_data()
assert result["grid"]["price_spike"] == "potential"
@ -261,13 +281,13 @@ async def test_fetch_potential_spike(
async def test_fetch_spike(hass: HomeAssistant, current_price_api: Mock) -> None:
"""Test fetching a site with only a general channel."""
general_channel: list[CurrentInterval] = [
general_channel: list[Interval] = [
generate_current_interval(
ChannelType.GENERAL, parser.parse("2021-09-21T08:30:00+10:00")
),
)
]
general_channel[0].spike_status = SpikeStatus.SPIKE
current_price_api.get_current_price.return_value = general_channel
general_channel[0].actual_instance.spike_status = SpikeStatus.SPIKE
current_price_api.get_current_prices.return_value = general_channel
data_service = AmberUpdateCoordinator(hass, current_price_api, GENERAL_ONLY_SITE_ID)
result = await data_service._async_update_data()
assert result["grid"]["price_spike"] == "spike"

View File

@ -3,8 +3,9 @@
from collections.abc import AsyncGenerator
from unittest.mock import Mock, patch
from amberelectric.model.current_interval import CurrentInterval
from amberelectric.model.range import Range
from amberelectric.models.current_interval import CurrentInterval
from amberelectric.models.interval import Interval
from amberelectric.models.range import Range
import pytest
from homeassistant.components.amberelectric.const import (
@ -44,10 +45,10 @@ async def setup_general(hass: HomeAssistant) -> AsyncGenerator[Mock]:
instance = Mock()
with patch(
"amberelectric.api.AmberApi.create",
"amberelectric.AmberApi",
return_value=instance,
) as mock_update:
instance.get_current_price = Mock(return_value=GENERAL_CHANNEL)
instance.get_current_prices = Mock(return_value=GENERAL_CHANNEL)
assert await async_setup_component(hass, DOMAIN, {})
await hass.async_block_till_done()
yield mock_update.return_value
@ -68,10 +69,10 @@ async def setup_general_and_controlled_load(
instance = Mock()
with patch(
"amberelectric.api.AmberApi.create",
"amberelectric.AmberApi",
return_value=instance,
) as mock_update:
instance.get_current_price = Mock(
instance.get_current_prices = Mock(
return_value=GENERAL_CHANNEL + CONTROLLED_LOAD_CHANNEL
)
assert await async_setup_component(hass, DOMAIN, {})
@ -92,10 +93,10 @@ async def setup_general_and_feed_in(hass: HomeAssistant) -> AsyncGenerator[Mock]
instance = Mock()
with patch(
"amberelectric.api.AmberApi.create",
"amberelectric.AmberApi",
return_value=instance,
) as mock_update:
instance.get_current_price = Mock(
instance.get_current_prices = Mock(
return_value=GENERAL_CHANNEL + FEED_IN_CHANNEL
)
assert await async_setup_component(hass, DOMAIN, {})
@ -126,7 +127,7 @@ async def test_general_price_sensor(hass: HomeAssistant, setup_general: Mock) ->
assert attributes.get("range_max") is None
with_range: list[CurrentInterval] = GENERAL_CHANNEL
with_range[0].range = Range(7.8, 12.4)
with_range[0].actual_instance.range = Range(min=7.8, max=12.4)
setup_general.get_current_price.return_value = with_range
config_entry = hass.config_entries.async_entries(DOMAIN)[0]
@ -211,8 +212,8 @@ async def test_general_forecast_sensor(
assert first_forecast.get("range_min") is None
assert first_forecast.get("range_max") is None
with_range: list[CurrentInterval] = GENERAL_CHANNEL
with_range[1].range = Range(7.8, 12.4)
with_range: list[Interval] = GENERAL_CHANNEL
with_range[1].actual_instance.range = Range(min=7.8, max=12.4)
setup_general.get_current_price.return_value = with_range
config_entry = hass.config_entries.async_entries(DOMAIN)[0]