Implement client caching based on ETag header

Jules 2025-02-16 20:40:44 +01:00
parent 48fca3197f
commit 93bda52ac8
Signed by: jdejaegh
GPG key ID: 99D6D184CA66933A
2 changed files with 42 additions and 18 deletions
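
For reference, the change below applies the standard HTTP conditional-request pattern: the client remembers the ETag returned with a response, sends it back in an If-None-Match header on the next request for the same URL, and reuses the cached body when the server answers 304 Not Modified. A minimal standalone sketch of that pattern with aiohttp follows; the fetch_with_etag helper, the module-level _CACHE dict and the example URL are illustrative only, not part of the integration's code.

import asyncio
import time

import aiohttp

# url -> {'etag': str, 'body': bytes, 'timestamp': float}; illustrative only
_CACHE: dict[str, dict] = {}

async def fetch_with_etag(session: aiohttp.ClientSession, url: str) -> bytes:
    headers = {}
    if url in _CACHE:
        # Ask the server to skip the body if our copy is still current.
        headers['If-None-Match'] = _CACHE[url]['etag']
    async with session.get(url, headers=headers) as response:
        if response.status == 304:
            # Not Modified: reuse the cached body and refresh its timestamp.
            _CACHE[url]['timestamp'] = time.time()
            return _CACHE[url]['body']
        response.raise_for_status()
        body = await response.read()
        if 'ETag' in response.headers:
            _CACHE[url] = {'etag': response.headers['ETag'], 'body': body, 'timestamp': time.time()}
        return body

async def main() -> None:
    async with aiohttp.ClientSession() as session:
        # Hypothetical endpoint, fetched twice to exercise the cached path.
        data = await fetch_with_etag(session, "https://example.com/forecast")
        data = await fetch_with_etag(session, "https://example.com/forecast")
        print(len(data))

asyncio.run(main())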


@@ -3,13 +3,14 @@ from __future__ import annotations
 import asyncio
 import hashlib
+import json
 import logging
 import socket
+import time
 from datetime import datetime
 
 import aiohttp
 import async_timeout
-from aiohttp import ClientResponse
 
 from .const import USER_AGENT
 
 _LOGGER = logging.getLogger(__name__)
@@ -35,6 +36,8 @@ def _api_key(method_name: str) -> str:
 class IrmKmiApiClient:
     """API client for IRM KMI weather data"""
     COORD_DECIMALS = 6
+    cache_max_age = 60 * 60 * 2  # Drop items from the cache if they have not been hit in the last 2 hours
+    cache = {}
 
     def __init__(self, session: aiohttp.ClientSession) -> None:
         self._session = session
@@ -47,18 +50,18 @@ class IrmKmiApiClient:
         coord['lat'] = round(coord['lat'], self.COORD_DECIMALS)
         coord['long'] = round(coord['long'], self.COORD_DECIMALS)
-        response = await self._api_wrapper(params={"s": "getForecasts", "k": _api_key("getForecasts")} | coord)
-        return await response.json()
+        response: bytes = await self._api_wrapper(params={"s": "getForecasts", "k": _api_key("getForecasts")} | coord)
+        return json.loads(response)
 
     async def get_image(self, url, params: dict | None = None) -> bytes:
         """Get the image at the specified url with the parameters"""
-        r: ClientResponse = await self._api_wrapper(base_url=url, params={} if params is None else params)
-        return await r.read()
+        r: bytes = await self._api_wrapper(base_url=url, params={} if params is None else params)
+        return r
 
     async def get_svg(self, url, params: dict | None = None) -> str:
         """Get SVG as str at the specified url with the parameters"""
-        r: ClientResponse = await self._api_wrapper(base_url=url, params={} if params is None else params)
-        return await r.text()
+        r: bytes = await self._api_wrapper(base_url=url, params={} if params is None else params)
+        return r.decode()
 
     async def _api_wrapper(
         self,
@@ -68,24 +71,41 @@ class IrmKmiApiClient:
         method: str = "get",
         data: dict | None = None,
         headers: dict | None = None,
-    ) -> any:
+    ) -> bytes:
         """Get information from the API."""
+        url = f"{self._base_url if base_url is None else base_url}{path}"
+
         if headers is None:
             headers = {'User-Agent': USER_AGENT}
         else:
             headers['User-Agent'] = USER_AGENT
 
+        if url in self.cache:
+            headers['If-None-Match'] = self.cache[url]['etag']
+
         try:
             async with async_timeout.timeout(60):
                 response = await self._session.request(
                     method=method,
-                    url=f"{self._base_url if base_url is None else base_url}{path}",
+                    url=url,
                     headers=headers,
                     json=data,
                     params=params
                 )
                 response.raise_for_status()
-                return response
+
+                if response.status == 304:
+                    _LOGGER.debug(f"Cache hit for {url}")
+                    self.cache[url]['timestamp'] = time.time()
+                    return self.cache[url]['response']
+
+                if 'ETag' in response.headers:
+                    _LOGGER.debug(f"Saving in cache {url}")
+                    r = await response.read()
+                    self.cache[url] = {'etag': response.headers['ETag'], 'response': r, 'timestamp': time.time()}
+                    return r
+
+                return await response.read()
 
         except asyncio.TimeoutError as exception:
             raise IrmKmiApiCommunicationError("Timeout error fetching information") from exception
@@ -93,3 +113,13 @@ class IrmKmiApiClient:
             raise IrmKmiApiCommunicationError("Error fetching information") from exception
         except Exception as exception:  # pylint: disable=broad-except
             raise IrmKmiApiError(f"Something really wrong happened! {exception}") from exception
+
+    def expire_cache(self):
+        now = time.time()
+        keys_to_delete = set()
+        for key, value in self.cache.items():
+            if now - value['timestamp'] > self.cache_max_age:
+                keys_to_delete.add(key)
+        for key in keys_to_delete:
+            del self.cache[key]
+        _LOGGER.info(f"Expired {len(keys_to_delete)} elements from API cache")
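
With the first file changed as above, callers of IrmKmiApiClient get the caching transparently: _api_wrapper sends If-None-Match whenever it already holds an ETag for the URL, returns the stored bytes on a 304 reply, and expire_cache() drops entries that have not been hit for cache_max_age seconds. Note that cache is a class attribute, so all client instances in the process share the same dict. A rough usage sketch; the import of IrmKmiApiClient is omitted because the file path is not shown in this view, and the radar URL is made up:

import aiohttp

async def fetch_radar_frame() -> bytes:
    async with aiohttp.ClientSession() as session:
        client = IrmKmiApiClient(session)  # import path omitted, see note above
        # Drop entries that have not been hit within cache_max_age (2 hours).
        client.expire_cache()
        # The first call stores the body and its ETag; later identical calls
        # send If-None-Match and reuse the cached bytes on a 304 reply.
        return await client.get_image("https://example.org/radar/frame_0.png")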


@@ -1,9 +1,8 @@
 """DataUpdateCoordinator for the IRM KMI integration."""
-import asyncio
 import logging
 from datetime import datetime, timedelta
 from statistics import mean
-from typing import Any, List, Tuple, Coroutine
+from typing import List
 import urllib.parse
 
 import async_timeout
@@ -68,7 +67,7 @@ class IrmKmiCoordinator(TimestampDataUpdateCoordinator):
         This is the place to pre-process the data to lookup tables
         so entities can quickly look up their data.
         """
-        _LOGGER.info("Updating weather data")
+        self._api_client.expire_cache()
         if (zone := self.hass.states.get(self._zone)) is None:
             raise UpdateFailed(f"Zone '{self._zone}' not found")
         try:
@@ -118,8 +117,6 @@ class IrmKmiCoordinator(TimestampDataUpdateCoordinator):
     async def _async_animation_data(self, api_data: dict) -> RainGraph | None:
         """From the API data passed in, call the API to get all the images and create the radar animation data object.
         Frames from the API are merged with the background map and the location marker to create each frame."""
-        _LOGGER.debug("_async_animation_data")
-
         animation_data = api_data.get('animation', {}).get('sequence')
         localisation_layer_url = api_data.get('animation', {}).get('localisationLayer')
         country = api_data.get('country', '')
@@ -140,9 +137,6 @@ class IrmKmiCoordinator(TimestampDataUpdateCoordinator):
             location=localisation
         )
         rain_graph: RainGraph = await self.create_rain_graph(radar_animation, animation_data, country, images_from_api)
-        # radar_animation['svg_animated'] = rain_graph.get_svg_string()
-        # radar_animation['svg_still'] = rain_graph.get_svg_string(still_image=True)
-        _LOGGER.debug(f"Return rain_graph from coordinator {rain_graph.get_hint()}")
         return rain_graph
 
     @staticmethod