Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ dependencies = [
"slowapi>=0.1.9",
"grpc-requests>=0.1.21",
"gunicorn>=25.3.0",
"jinja2>=3.1.6",
]

[dependency-groups]
Expand Down
66 changes: 36 additions & 30 deletions src/quartz_api/internal/backends/dataplatform/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def set_sync_client(self, host: str, port: int) -> None:
from grpc_requests import Client
self._sync_client = Client.get_by_endpoint(f"{host}:{port}")
# Pre-warm service discovery so the first real call isn't slow.
self._sync_client.service(_DP_SERVICE)
self.svc = self._sync_client.service(_DP_SERVICE)
log.info("grpc-requests sync client initialised at %s:%s", host, port)

def _sync_snapshot(
Expand Down Expand Up @@ -90,7 +90,7 @@ def _sync_snapshot(
if not resp.get("values"):
return []

valid_ts = dt.datetime.fromisoformat(resp["timestamp_utc"].rstrip("Z")).replace(tzinfo=dt.UTC) # noqa: E501
valid_ts = dt.datetime.fromisoformat(resp["timestamp_utc"])
return [
models.PredictedGenerationValue(
power_kilowatts=float(v.get("value_fraction", 0)) * float(v["effective_capacity_watts"]) / 1000, # noqa: E501
Expand All @@ -99,8 +99,8 @@ def _sync_snapshot(
capacity_kilowatts=float(v["effective_capacity_watts"]) / 1000,
forecaster_name=forecaster_name,
forecaster_version=forecaster_version,
created_timestamp=dt.datetime.fromisoformat(v["created_timestamp_utc"].rstrip("Z")).replace(tzinfo=dt.UTC),
init_timestamp=dt.datetime.fromisoformat(v["initialization_timestamp_utc"].rstrip("Z")).replace(tzinfo=dt.UTC),
created_timestamp=dt.datetime.fromisoformat(v["created_timestamp_utc"]),
init_timestamp=dt.datetime.fromisoformat(v["initialization_timestamp_utc"]),
metadata=v.get("metadata", {}),
)
for v in resp["values"]
Expand Down Expand Up @@ -196,53 +196,59 @@ async def get_predicted_generation(
forecaster_version=forecaster_version,
)

req = dp.GetForecastAsTimeseriesRequest(
location_uuid=str(location_uuid),
energy_source=energy_type_map[energy_type],
horizon_mins=forecast_horizon_minutes,
time_window=dp.TimeWindow(
start_timestamp_utc=window_start,
end_timestamp_utc=window_end,
),
forecaster=forecaster,
pivot_timestamp_utc=created_cutoff,
)
resp = await self.dpc.get_forecast_as_timeseries(req)
req = {
"location_uuid": str(location_uuid),
"energy_source": energy_type_map[energy_type].value,
"horizon_mins": forecast_horizon_minutes,
"time_window": {
"start_timestamp_utc": window_start.strftime("%Y-%m-%dT%H:%M:%SZ"),
"end_timestamp_utc": window_end.strftime("%Y-%m-%dT%H:%M:%SZ"),
},
"forecaster": forecaster.to_dict(),
"pivot_timestamp_utc": created_cutoff.strftime("%Y-%m-%dT%H:%M:%SZ"),
}

resp = self.svc.GetForecastAsTimeseries(req)

if location_type == models.LocationType.SUBSTATION:
# Spoof the forecast values so that the capacity and id corresponds to the substation
for v in resp.values:
v.location_uuid = location_uuid
v.effective_capacity_watts = location.effective_capacity_watts
# Spoof the forecast values so that the capacity and id corresponds to the substation
for v in resp["values"]:
v["location_uuid"] = location_uuid
v["effective_capacity_watts"] = location.effective_capacity_watts

out: list[models.PredictedGenerationValue] = [
models.PredictedGenerationValue(
power_kilowatts=int(
v.effective_capacity_watts * v.p50_value_fraction / 1000,
float(v["effective_capacity_watts"]) \
* float(v.get("p50_value_fraction", 0)) / 1000,
),
valid_timestamp=v.target_timestamp_utc,
valid_timestamp=v["target_timestamp_utc"],
location_uuid=location_uuid,
capacity_kilowatts=int(v.effective_capacity_watts / 1000),
created_timestamp=v.created_timestamp_utc,
init_timestamp=v.initialization_timestamp_utc,
capacity_kilowatts=int(float(v["effective_capacity_watts"]) / 1000),
created_timestamp=v["created_timestamp_utc"],
init_timestamp=v["initialization_timestamp_utc"],
forecaster_name=forecaster.forecaster_name,
forecaster_version=forecaster.forecaster_version,
plevels_kilowatts={
"p10": int(
v.effective_capacity_watts * v.other_statistics_fractions["p10"] / 1000.0,
float(v["effective_capacity_watts"]) \
* float(v["other_statistics_fractions"].get("p10", 0)) / 1000.0,
),
"p90": int(
v.effective_capacity_watts * v.other_statistics_fractions["p90"] / 1000.0,
float(v["effective_capacity_watts"]) \
* float(v["other_statistics_fractions"].get("p90", 0)) / 1000.0,
),
}
if "p10" in v.other_statistics_fractions and "p90" in v.other_statistics_fractions
if "p10" in v.get("other_statistics_fractions", {}) and \
"p90" in v.get("other_statistics_fractions", {})
else {},
metadata=struct_to_dict(v.metadata),
metadata=v["metadata"],
)
for v in resp.values
for v in resp["values"]
]
return out


@override
async def put_predicted_generation(
self,
Expand Down
37 changes: 35 additions & 2 deletions src/quartz_api/internal/backends/dataplatform/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
import datetime as dt
import unittest
import uuid
from unittest.mock import AsyncMock, patch
from unittest.mock import AsyncMock, Mock, patch

from betterproto.lib.google.protobuf import Struct, Value
from dp_sdk.ocf import dp
from fastapi import HTTPException
from grpc_requests import Client

from quartz_api.internal import models

Expand Down Expand Up @@ -50,6 +51,27 @@ def mock_list_locations(
)


def mock_get_forecast_sync(
req: dict,
metadata: object | None = None, # noqa: ARG001
) -> dict:
return {
"values": [
{
"target_timestamp_utc": TEST_TIMESTAMP_UTC + dt.timedelta(hours=i),
"p50_value_fraction": 0.5,
"effective_capacity_watts": 1e6,
"initialization_timestamp_utc": TEST_TIMESTAMP_UTC
- dt.timedelta(minutes=req["horizon_mins"]),
"created_timestamp_utc": TEST_TIMESTAMP_UTC
- dt.timedelta(hours=1, minutes=req["horizon_mins"]),
"other_statistics_fractions": {"p90": 0.9, "p10": 0.1},
"metadata": {},
}
for i in range(5)
],
}

def mock_get_forecast(
req: dp.GetForecastAsTimeseriesRequest,
metadata: object | None = None, # noqa: ARG001
Expand Down Expand Up @@ -139,7 +161,12 @@ class TestCase:
self.assertEqual(len(resp), tc.expected_num_locations)

@patch("dp_sdk.ocf.dp.DataPlatformDataServiceStub")
async def test_get_site_forecast(self, client_mock: dp.DataPlatformDataServiceStub) -> None:
@patch("grpc_requests.Client")
async def test_get_site_forecast(
self,
client_mock: dp.DataPlatformDataServiceStub,
_: Client,
) -> None:
@dataclasses.dataclass
class TestCase:
name: str
Expand All @@ -163,10 +190,12 @@ class TestCase:
]

client = StorageClient.from_dp(client_mock)
client.set_sync_client("localhost", "50051")
for tc in testcases:
client_mock.list_locations = AsyncMock(side_effect=mock_list_locations)
client_mock.get_forecast_as_timeseries = AsyncMock(side_effect=mock_get_forecast)
client_mock.get_latest_forecasts = AsyncMock(side_effect=mock_get_latest_forecasts)
client.svc.GetForecastAsTimeseries = Mock(side_effect=mock_get_forecast_sync)

with self.subTest(tc.name):
if tc.should_error:
Expand Down Expand Up @@ -336,9 +365,11 @@ class TestCase:
self.assertEqual(len(resp), tc.number_of_locations)

@patch("dp_sdk.ocf.dp.DataPlatformDataServiceStub")
@patch("grpc_requests.Client")
async def test_get_substation_forecast(
self,
client_mock: dp.DataPlatformDataServiceStub,
_: Client,
) -> None:
@dataclasses.dataclass
class TestCase:
Expand Down Expand Up @@ -369,10 +400,12 @@ class TestCase:
]

client = StorageClient.from_dp(client_mock)
client.set_sync_client("localhost", "50051")
for tc in testcases:
client_mock.list_locations = AsyncMock(side_effect=mock_list_locations)
client_mock.get_forecast_as_timeseries = AsyncMock(side_effect=mock_get_forecast)
client_mock.get_latest_forecasts = AsyncMock(side_effect=mock_get_latest_forecasts)
client.svc.GetForecastAsTimeseries = Mock(side_effect=mock_get_forecast_sync)

with self.subTest(tc.name):
if tc.should_error:
Expand Down
49 changes: 31 additions & 18 deletions src/quartz_api/internal/service/uk_national/gsp_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,8 @@ def _build_forecast_response(
forecasts = []
for gsp_id in gsp_uuid_id_map.values():
pgv = gsp_pgv_map.get(gsp_id)
if pgv is None:
continue
location = Location.from_location(gsp_id_map[gsp_id])
location.installed_capacity_mw = \
pgv.capacity_kilowatts / 1000.0
Expand Down Expand Up @@ -322,7 +324,7 @@ async def get_all_available_forecasts(
"""
# Default (no gsp_ids): served from warm cache only. If we're here it's a cache miss —
# trigger a warm in the background and ask the client to retry.
if gsp_ids is None:
if gsp_ids is None and (start_datetime_utc != end_datetime_utc):
global _cache_warming
if not _cache_warming:
background_tasks.add_task(_warm_forecast_all_cache, request.app)
Expand All @@ -332,27 +334,38 @@ async def get_all_available_forecasts(
detail="Forecast cache is being populated, please retry in 60 seconds.",
)

# gsp_ids path: custom query, fetch live.
gsps_to_convert: dict[int, models.Location] = {
k: v for k, v in gsp_id_map.items()
if k in convert_list_of_gsp_ids(gsp_ids)
}
tasks = [
asyncio.create_task(
db.get_predicted_generation(
location_uuid=str(loc.uuid),
window_start=start_datetime_utc,
window_end=end_datetime_utc,
energy_type=models.EnergyType.SOLAR,
location_type=models.LocationType.GSP,
if gsp_ids is None:
gsps_to_convert = gsp_id_map
tasks = [
db.get_predicted_generation_snapshot(
location_uuids=[v.uuid for _,v in gsp_id_map.items()],
snapshot_timestamp_utc=start_datetime_utc,
authdata={},
forecast_horizon_minutes=0,
energy_type=models.EnergyType.SOLAR,
forecaster_name=GSP_FORECASTER_NAME,
forecaster_version=GSP_FORECASTER_VERSION,
),
)
for loc in gsps_to_convert.values()
]
]
else:
# gsp_ids path: custom query, fetch live.
gsps_to_convert: dict[int, models.Location] = {
k: v for k, v in gsp_id_map.items()
if k in convert_list_of_gsp_ids(gsp_ids)
}
tasks = [
db.get_predicted_generation(
location_uuid=str(loc.uuid),
window_start=start_datetime_utc,
window_end=end_datetime_utc,
energy_type=models.EnergyType.SOLAR,
location_type=models.LocationType.GSP,
authdata={},
forecast_horizon_minutes=0,
forecaster_name=GSP_FORECASTER_NAME,
forecaster_version=GSP_FORECASTER_VERSION,
)
for loc in gsps_to_convert.values()
]
results: list[list[models.PredictedGenerationValue] | Exception] = await asyncio.gather(
*tasks, return_exceptions=True,
)
Expand Down
1 change: 1 addition & 0 deletions src/quartz_api/tests/integration/substations/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ async def api_client_substations(
app = _create_server(config_substations)

db_instance = DataPlatformStorage.from_dp(dp_client=dp_client)
db_instance.set_sync_client(os.environ["DATA_PLATFORM_HOST"], os.environ["DATA_PLATFORM_PORT"])
app.dependency_overrides[models.get_storage_client] = lambda: db_instance

async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as ac:
Expand Down
1 change: 1 addition & 0 deletions src/quartz_api/tests/integration/uk_national/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ async def api_client_uk_national(
app = _create_server(config_uk_national)

db_instance = DataPlatformStorage.from_dp(dp_client=dp_client)
db_instance.set_sync_client(os.environ["DATA_PLATFORM_HOST"], os.environ["DATA_PLATFORM_PORT"])
app.dependency_overrides[models.get_storage_client] = lambda: db_instance

async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as ac:
Expand Down
40 changes: 40 additions & 0 deletions src/quartz_api/tests/integration/uk_national/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,46 @@ async def test_gsp_forecast_all_refresh_non_admin(
assert response.status_code == 403


# 4.3.6 Get forecast/all with for one timestamp
@pytest.mark.asyncio(loop_scope="session")
async def test_gsp_forecast_all_for_one_timestamp(
api_client_uk_national,
gsp_locations, # noqa arg001
make_forecasters, # noqa arg001
make_gsp_forecast_values, # noqa arg001
) -> None:
"""Test that the cache refresh endpoint rejects a non-admin user."""
now = (pd.Timestamp.utcnow().ceil("30min").to_pydatetime()).strftime("%Y-%m-%dT%H:00:00Z")
response = await api_client_uk_national.get(
f"/v0/solar/GB/gsp/forecast/all/?start_datetime_utc={now}&end_datetime_utc={now}",
)
assert response.status_code == 200
data = response.json()
assert isinstance(data, list)
assert len(data) == 10 # 10 gsps

# 4.3.7 Get forecast/all with for one timestamp, compact=true
@pytest.mark.asyncio(loop_scope="session")
async def test_gsp_forecast_all_for_one_timestamp_compact(
api_client_uk_national,
gsp_locations, # noqa arg001
make_forecasters, # noqa arg001
make_gsp_forecast_values, # noqa arg001
) -> None:
"""Test that the cache refresh endpoint rejects a non-admin user."""
now = (pd.Timestamp.utcnow().ceil("30min").to_pydatetime()).strftime("%Y-%m-%dT%H:00:00Z")
response = await api_client_uk_national.get(
f"/v0/solar/GB/gsp/forecast/all/?start_datetime_utc={now}&end_datetime_utc={now}&compact=true",
)
assert response.status_code == 200
data = response.json()
assert isinstance(data, list)
assert len(data) == 1





# 4.4 Check GSP pvlive route
@pytest.mark.asyncio(loop_scope="session")
async def test_gsp_pvlive_all(
Expand Down
Loading
Loading