Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 35 additions & 1 deletion climada/util/forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,40 @@
Define Forecast base class.
"""

import numpy as np


class Forecast:
pass
"""Mixin class for forecast data.

Attributes
----------
lead_time : np.ndarray
Array of forecast lead times, given as datetime64 objects.
Represents the time points for which forecasts are made.
member : np.ndarray
Array of ensemble member identifiers, given as integers.
Represents different forecast ensemble members.
"""

def __init__(
self,
lead_time: np.ndarray | None = None,
member: np.ndarray | None = None,
**kwargs,
):
"""Initialize Forecast.

Parameters
----------
lead_time : np.ndarray or None, optional
Forecast lead times. Default is empty array.
member : np.ndarray or None, optional
Ensemble member identifiers. Default is empty array.
"""

self.lead_time = (
np.asarray(lead_time) if lead_time is not None else np.array([])
)
self.member = np.asarray(member) if member is not None else np.array([])
super().__init__(**kwargs)
32 changes: 32 additions & 0 deletions climada/util/test/test_forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,35 @@

Tests for Forecast base class.
"""

import numpy as np
import numpy.testing as npt

from climada.util.forecast import Forecast


def test_forecast_init():
"""Test initialization of Forecast class."""
forecast = Forecast()
npt.assert_array_equal(forecast.lead_time, np.array([]))
npt.assert_array_equal(forecast.member, np.array([]))

forecast = Forecast(member=np.array([1, 2]))
npt.assert_array_equal(forecast.member, np.array([1, 2]), strict=True)

forecast = Forecast(lead_time=np.array([1, 2]))
npt.assert_array_equal(forecast.lead_time, np.array([1, 2]), strict=True)

forecast = Forecast(lead_time=np.array([1, 2]), member=[3, 4])
npt.assert_array_equal(forecast.lead_time, np.array([1, 2]), strict=True)
npt.assert_array_equal(forecast.member, np.array([3, 4]), strict=True)
assert isinstance(forecast.member, np.ndarray)

# Test with datetime64 including seconds
lead_times_seconds = np.array(
["2024-01-01T00:00:00", "2024-01-01T00:01:00", "2024-01-01"],
dtype="datetime64[s]",
)
forecast = Forecast(lead_time=lead_times_seconds, member=[1, 2, 3])
npt.assert_array_equal(forecast.lead_time, lead_times_seconds, strict=True)
assert forecast.lead_time.dtype == np.dtype("datetime64[s]")
Loading