Refactor classes

This commit is contained in:
Jules 2024-06-20 20:33:23 +02:00
parent 6be82c942a
commit 2c09eec3e1
Signed by: jdejaegh
GPG key ID: 99D6D184CA66933A
4 changed files with 30 additions and 18 deletions

View file

@ -1,6 +1,7 @@
import asyncio
import csv
import socket
from abc import ABC, abstractmethod
from datetime import datetime, timedelta, date
from io import StringIO
from itertools import product
@ -12,7 +13,7 @@ import async_timeout
from aiohttp import ClientResponse
from . import rio_wfs_base_url, user_agent, forecast_base_url
from .data import RioFeature, FeatureValue, ForecastFeature
from .data import RioFeature, FeatureValue, ForecastFeature, IrcelineFeature
from .utils import SizedDict, epsg_transform, round_coordinates
@ -20,11 +21,18 @@ class IrcelineApiError(Exception):
"""Exception to indicate an API error."""
class IrcelineBaseClient:
class IrcelineBaseClient(ABC):
def __init__(self, session: aiohttp.ClientSession, cache_size: int = 20) -> None:
self._session = session
self._cache = SizedDict(cache_size)
@abstractmethod
async def get_data(self,
timestamp: datetime | date,
features: List[IrcelineFeature],
position: Tuple[float, float]) -> dict:
pass
async def _api_wrapper(self, url: str, querystring: dict = None, headers: dict = None, method: str = 'GET'):
"""
Call the URL with the specified query string. Raises exception for >= 400 response code
@ -80,11 +88,11 @@ class IrcelineBaseClient:
class IrcelineRioClient(IrcelineBaseClient):
"""API client for RIO interpolated IRCEL - CELINE open data"""
async def get_rio_value(self,
timestamp: datetime | date,
features: List[RioFeature],
position: Tuple[float, float]
) -> Dict[RioFeature, FeatureValue]:
async def get_data(self,
timestamp: datetime | date,
features: List[RioFeature],
position: Tuple[float, float]
) -> Dict[RioFeature, FeatureValue]:
"""
Call the WFS API to get the interpolated level of RioFeature. Raises exception upon API error
:param timestamp: datetime for which to get the data for
@ -195,11 +203,11 @@ class IrcelineRioClient(IrcelineBaseClient):
class IrcelineForecastClient(IrcelineBaseClient):
"""API client for forecast IRCEL - CELINE open data"""
async def get_forecasts(self,
day: date,
features: List[ForecastFeature],
position: Tuple[float, float]
) -> Dict[Tuple[ForecastFeature, date], FeatureValue]:
async def get_data(self,
day: date,
features: List[ForecastFeature],
position: Tuple[float, float]
) -> Dict[Tuple[ForecastFeature, date], FeatureValue]:
"""
Get forecasted concentrations for the given features at the given position. The forecasts are downloaded for
the specified day and the 4 next days as well

View file

@ -3,7 +3,11 @@ from enum import StrEnum
from typing import TypedDict
class RioFeature(StrEnum):
class IrcelineFeature(StrEnum):
pass
class RioFeature(IrcelineFeature):
BC_24HMEAN = 'rio:bc_24hmean'
BC_DMEAN = 'rio:bc_dmean'
BC_HMEAN = 'rio:bc_hmean'
@ -26,7 +30,7 @@ class RioFeature(StrEnum):
SO2_HMEAN = 'rio:so2_hmean'
class ForecastFeature(StrEnum):
class ForecastFeature(IrcelineFeature):
NO2_MAXHMEAN = 'chimere_no2_maxhmean'
O3_MAXHMEAN = 'chimere_o3_maxhmean'
PM10_DMEAN = 'chimere_pm10_dmean'

View file

@ -22,7 +22,7 @@ async def test_cached_calls():
session = get_mock_session_many_csv()
client = IrcelineForecastClient(session)
_ = await client.get_forecasts(
_ = await client.get_data(
day=date(2024, 6, 19),
features=[ForecastFeature.NO2_MAXHMEAN],
position=(50.45, 4.85)
@ -39,7 +39,7 @@ async def test_cached_calls():
assert session.request.call_count == 5
session.request.assert_has_calls(calls)
_ = await client.get_forecasts(
_ = await client.get_data(
day=date(2024, 6, 19),
features=[ForecastFeature.NO2_MAXHMEAN],
position=(50.45, 4.85)
@ -61,7 +61,7 @@ async def test_missed_cached_calls():
session = get_mock_session_many_csv()
client = IrcelineForecastClient(session)
r = await client.get_forecasts(
r = await client.get_data(
day=date(2024, 6, 21),
features=[ForecastFeature.NO2_MAXHMEAN],
position=(50.45, 4.85)

View file

@ -99,7 +99,7 @@ async def test_api_rio():
d = date(2024, 6, 18)
features = [RioFeature.NO2_HMEAN, RioFeature.O3_HMEAN]
_ = await client.get_rio_value(d, features, pos)
_ = await client.get_data(d, features, pos)
session.request.assert_called_once_with(
method='GET',
url=rio_wfs_base_url,