From 2c09eec3e11b95d26ab787253f91bf7d7ad7c4dd Mon Sep 17 00:00:00 2001 From: Jules Dejaeghere Date: Thu, 20 Jun 2024 20:33:23 +0200 Subject: [PATCH] Refactor classes --- src/open_irceline/api.py | 32 ++++++++++++++++++++------------ src/open_irceline/data.py | 8 ++++++-- tests/test_api_forecasts.py | 6 +++--- tests/test_api_rio.py | 2 +- 4 files changed, 30 insertions(+), 18 deletions(-) diff --git a/src/open_irceline/api.py b/src/open_irceline/api.py index b46f470..7a5a989 100644 --- a/src/open_irceline/api.py +++ b/src/open_irceline/api.py @@ -1,6 +1,7 @@ import asyncio import csv import socket +from abc import ABC, abstractmethod from datetime import datetime, timedelta, date from io import StringIO from itertools import product @@ -12,7 +13,7 @@ import async_timeout from aiohttp import ClientResponse from . import rio_wfs_base_url, user_agent, forecast_base_url -from .data import RioFeature, FeatureValue, ForecastFeature +from .data import RioFeature, FeatureValue, ForecastFeature, IrcelineFeature from .utils import SizedDict, epsg_transform, round_coordinates @@ -20,11 +21,18 @@ class IrcelineApiError(Exception): """Exception to indicate an API error.""" -class IrcelineBaseClient: +class IrcelineBaseClient(ABC): def __init__(self, session: aiohttp.ClientSession, cache_size: int = 20) -> None: self._session = session self._cache = SizedDict(cache_size) + @abstractmethod + async def get_data(self, + timestamp: datetime | date, + features: List[IrcelineFeature], + position: Tuple[float, float]) -> dict: + pass + async def _api_wrapper(self, url: str, querystring: dict = None, headers: dict = None, method: str = 'GET'): """ Call the URL with the specified query string. Raises exception for >= 400 response code @@ -80,11 +88,11 @@ class IrcelineBaseClient: class IrcelineRioClient(IrcelineBaseClient): """API client for RIO interpolated IRCEL - CELINE open data""" - async def get_rio_value(self, - timestamp: datetime | date, - features: List[RioFeature], - position: Tuple[float, float] - ) -> Dict[RioFeature, FeatureValue]: + async def get_data(self, + timestamp: datetime | date, + features: List[RioFeature], + position: Tuple[float, float] + ) -> Dict[RioFeature, FeatureValue]: """ Call the WFS API to get the interpolated level of RioFeature. Raises exception upon API error :param timestamp: datetime for which to get the data for @@ -195,11 +203,11 @@ class IrcelineRioClient(IrcelineBaseClient): class IrcelineForecastClient(IrcelineBaseClient): """API client for forecast IRCEL - CELINE open data""" - async def get_forecasts(self, - day: date, - features: List[ForecastFeature], - position: Tuple[float, float] - ) -> Dict[Tuple[ForecastFeature, date], FeatureValue]: + async def get_data(self, + day: date, + features: List[ForecastFeature], + position: Tuple[float, float] + ) -> Dict[Tuple[ForecastFeature, date], FeatureValue]: """ Get forecasted concentrations for the given features at the given position. The forecasts are downloaded for the specified day and the 4 next days as well diff --git a/src/open_irceline/data.py b/src/open_irceline/data.py index 2972f25..7912493 100644 --- a/src/open_irceline/data.py +++ b/src/open_irceline/data.py @@ -3,7 +3,11 @@ from enum import StrEnum from typing import TypedDict -class RioFeature(StrEnum): +class IrcelineFeature(StrEnum): + pass + + +class RioFeature(IrcelineFeature): BC_24HMEAN = 'rio:bc_24hmean' BC_DMEAN = 'rio:bc_dmean' BC_HMEAN = 'rio:bc_hmean' @@ -26,7 +30,7 @@ class RioFeature(StrEnum): SO2_HMEAN = 'rio:so2_hmean' -class ForecastFeature(StrEnum): +class ForecastFeature(IrcelineFeature): NO2_MAXHMEAN = 'chimere_no2_maxhmean' O3_MAXHMEAN = 'chimere_o3_maxhmean' PM10_DMEAN = 'chimere_pm10_dmean' diff --git a/tests/test_api_forecasts.py b/tests/test_api_forecasts.py index df61357..85eeecf 100644 --- a/tests/test_api_forecasts.py +++ b/tests/test_api_forecasts.py @@ -22,7 +22,7 @@ async def test_cached_calls(): session = get_mock_session_many_csv() client = IrcelineForecastClient(session) - _ = await client.get_forecasts( + _ = await client.get_data( day=date(2024, 6, 19), features=[ForecastFeature.NO2_MAXHMEAN], position=(50.45, 4.85) @@ -39,7 +39,7 @@ async def test_cached_calls(): assert session.request.call_count == 5 session.request.assert_has_calls(calls) - _ = await client.get_forecasts( + _ = await client.get_data( day=date(2024, 6, 19), features=[ForecastFeature.NO2_MAXHMEAN], position=(50.45, 4.85) @@ -61,7 +61,7 @@ async def test_missed_cached_calls(): session = get_mock_session_many_csv() client = IrcelineForecastClient(session) - r = await client.get_forecasts( + r = await client.get_data( day=date(2024, 6, 21), features=[ForecastFeature.NO2_MAXHMEAN], position=(50.45, 4.85) diff --git a/tests/test_api_rio.py b/tests/test_api_rio.py index df18014..68a719e 100644 --- a/tests/test_api_rio.py +++ b/tests/test_api_rio.py @@ -99,7 +99,7 @@ async def test_api_rio(): d = date(2024, 6, 18) features = [RioFeature.NO2_HMEAN, RioFeature.O3_HMEAN] - _ = await client.get_rio_value(d, features, pos) + _ = await client.get_data(d, features, pos) session.request.assert_called_once_with( method='GET', url=rio_wfs_base_url,