Refactor classes

This commit is contained in:
Jules 2024-06-20 20:33:23 +02:00
parent 6be82c942a
commit 2c09eec3e1
Signed by: jdejaegh
GPG key ID: 99D6D184CA66933A
4 changed files with 30 additions and 18 deletions

View file

@ -1,6 +1,7 @@
import asyncio import asyncio
import csv import csv
import socket import socket
from abc import ABC, abstractmethod
from datetime import datetime, timedelta, date from datetime import datetime, timedelta, date
from io import StringIO from io import StringIO
from itertools import product from itertools import product
@ -12,7 +13,7 @@ import async_timeout
from aiohttp import ClientResponse from aiohttp import ClientResponse
from . import rio_wfs_base_url, user_agent, forecast_base_url from . import rio_wfs_base_url, user_agent, forecast_base_url
from .data import RioFeature, FeatureValue, ForecastFeature from .data import RioFeature, FeatureValue, ForecastFeature, IrcelineFeature
from .utils import SizedDict, epsg_transform, round_coordinates from .utils import SizedDict, epsg_transform, round_coordinates
@ -20,11 +21,18 @@ class IrcelineApiError(Exception):
"""Exception to indicate an API error.""" """Exception to indicate an API error."""
class IrcelineBaseClient: class IrcelineBaseClient(ABC):
def __init__(self, session: aiohttp.ClientSession, cache_size: int = 20) -> None: def __init__(self, session: aiohttp.ClientSession, cache_size: int = 20) -> None:
self._session = session self._session = session
self._cache = SizedDict(cache_size) self._cache = SizedDict(cache_size)
@abstractmethod
async def get_data(self,
timestamp: datetime | date,
features: List[IrcelineFeature],
position: Tuple[float, float]) -> dict:
pass
async def _api_wrapper(self, url: str, querystring: dict = None, headers: dict = None, method: str = 'GET'): async def _api_wrapper(self, url: str, querystring: dict = None, headers: dict = None, method: str = 'GET'):
""" """
Call the URL with the specified query string. Raises exception for >= 400 response code Call the URL with the specified query string. Raises exception for >= 400 response code
@ -80,7 +88,7 @@ class IrcelineBaseClient:
class IrcelineRioClient(IrcelineBaseClient): class IrcelineRioClient(IrcelineBaseClient):
"""API client for RIO interpolated IRCEL - CELINE open data""" """API client for RIO interpolated IRCEL - CELINE open data"""
async def get_rio_value(self, async def get_data(self,
timestamp: datetime | date, timestamp: datetime | date,
features: List[RioFeature], features: List[RioFeature],
position: Tuple[float, float] position: Tuple[float, float]
@ -195,7 +203,7 @@ class IrcelineRioClient(IrcelineBaseClient):
class IrcelineForecastClient(IrcelineBaseClient): class IrcelineForecastClient(IrcelineBaseClient):
"""API client for forecast IRCEL - CELINE open data""" """API client for forecast IRCEL - CELINE open data"""
async def get_forecasts(self, async def get_data(self,
day: date, day: date,
features: List[ForecastFeature], features: List[ForecastFeature],
position: Tuple[float, float] position: Tuple[float, float]

View file

@ -3,7 +3,11 @@ from enum import StrEnum
from typing import TypedDict from typing import TypedDict
class RioFeature(StrEnum): class IrcelineFeature(StrEnum):
pass
class RioFeature(IrcelineFeature):
BC_24HMEAN = 'rio:bc_24hmean' BC_24HMEAN = 'rio:bc_24hmean'
BC_DMEAN = 'rio:bc_dmean' BC_DMEAN = 'rio:bc_dmean'
BC_HMEAN = 'rio:bc_hmean' BC_HMEAN = 'rio:bc_hmean'
@ -26,7 +30,7 @@ class RioFeature(StrEnum):
SO2_HMEAN = 'rio:so2_hmean' SO2_HMEAN = 'rio:so2_hmean'
class ForecastFeature(StrEnum): class ForecastFeature(IrcelineFeature):
NO2_MAXHMEAN = 'chimere_no2_maxhmean' NO2_MAXHMEAN = 'chimere_no2_maxhmean'
O3_MAXHMEAN = 'chimere_o3_maxhmean' O3_MAXHMEAN = 'chimere_o3_maxhmean'
PM10_DMEAN = 'chimere_pm10_dmean' PM10_DMEAN = 'chimere_pm10_dmean'

View file

@ -22,7 +22,7 @@ async def test_cached_calls():
session = get_mock_session_many_csv() session = get_mock_session_many_csv()
client = IrcelineForecastClient(session) client = IrcelineForecastClient(session)
_ = await client.get_forecasts( _ = await client.get_data(
day=date(2024, 6, 19), day=date(2024, 6, 19),
features=[ForecastFeature.NO2_MAXHMEAN], features=[ForecastFeature.NO2_MAXHMEAN],
position=(50.45, 4.85) position=(50.45, 4.85)
@ -39,7 +39,7 @@ async def test_cached_calls():
assert session.request.call_count == 5 assert session.request.call_count == 5
session.request.assert_has_calls(calls) session.request.assert_has_calls(calls)
_ = await client.get_forecasts( _ = await client.get_data(
day=date(2024, 6, 19), day=date(2024, 6, 19),
features=[ForecastFeature.NO2_MAXHMEAN], features=[ForecastFeature.NO2_MAXHMEAN],
position=(50.45, 4.85) position=(50.45, 4.85)
@ -61,7 +61,7 @@ async def test_missed_cached_calls():
session = get_mock_session_many_csv() session = get_mock_session_many_csv()
client = IrcelineForecastClient(session) client = IrcelineForecastClient(session)
r = await client.get_forecasts( r = await client.get_data(
day=date(2024, 6, 21), day=date(2024, 6, 21),
features=[ForecastFeature.NO2_MAXHMEAN], features=[ForecastFeature.NO2_MAXHMEAN],
position=(50.45, 4.85) position=(50.45, 4.85)

View file

@ -99,7 +99,7 @@ async def test_api_rio():
d = date(2024, 6, 18) d = date(2024, 6, 18)
features = [RioFeature.NO2_HMEAN, RioFeature.O3_HMEAN] features = [RioFeature.NO2_HMEAN, RioFeature.O3_HMEAN]
_ = await client.get_rio_value(d, features, pos) _ = await client.get_data(d, features, pos)
session.request.assert_called_once_with( session.request.assert_called_once_with(
method='GET', method='GET',
url=rio_wfs_base_url, url=rio_wfs_base_url,