Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ dev = [
"quarto-cli==1.5.57",
"quartodoc==0.11.1",
"netCDF4",
"dask",
"mikeio1d>=1.1.1",
]

test = [
Expand Down
10 changes: 9 additions & 1 deletion src/modelskill/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,14 @@
GridModelResult,
DfsuModelResult,
DummyModelResult,
NetworkModelResult,
)
from .obs import (
observation,
PointObservation,
TrackObservation,
NetworkLocationObservation,
)
from .obs import observation, PointObservation, TrackObservation
from .matching import from_matched, match
from .configuration import from_config
from .settings import options, get_option, set_option, reset_option, load_style
Expand Down Expand Up @@ -90,8 +96,10 @@ def load(filename: Union[str, Path]) -> Comparer | ComparerCollection:
"GridModelResult",
"DfsuModelResult",
"DummyModelResult",
"NetworkModelResult",
"observation",
"PointObservation",
"NetworkLocationObservation",
"TrackObservation",
"TimeSeries",
"match",
Expand Down
39 changes: 34 additions & 5 deletions src/modelskill/matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,14 @@
from .model.dummy import DummyModelResult
from .model.grid import GridModelResult
from .model.track import TrackModelResult
from .obs import Observation, PointObservation, TrackObservation, observation
from .model.network import NetworkModelResult
from .obs import (
Observation,
PointObservation,
TrackObservation,
NetworkLocationObservation,
observation,
)
from .timeseries import TimeSeries
from .types import Period

Expand All @@ -50,6 +57,7 @@
GridModelResult,
DfsuModelResult,
TrackModelResult,
NetworkModelResult,
DummyModelResult,
]
ObsInputType = Union[
Expand Down Expand Up @@ -274,7 +282,15 @@ def match(

if len(obs) > 1 and isinstance(mod, Collection) and len(mod) > 1:
if not all(
isinstance(m, (DfsuModelResult, GridModelResult, DummyModelResult))
isinstance(
m,
(
DfsuModelResult,
GridModelResult,
NetworkModelResult,
DummyModelResult,
),
)
for m in mod
):
raise ValueError(
Expand Down Expand Up @@ -337,7 +353,15 @@ def _match_single_obs(
raw_mod_data = {
m.name: (
m.extract(observation, spatial_method=spatial_method)
if isinstance(m, (DfsuModelResult, GridModelResult, DummyModelResult))
if isinstance(
m,
(
DfsuModelResult,
GridModelResult,
DummyModelResult,
NetworkModelResult,
),
)
else m
)
for m in model_results
Expand Down Expand Up @@ -379,6 +403,7 @@ def _match_space_time(
idxs = [m.time for m in raw_mod_data.values()]
period = _get_global_start_end(idxs)

# TODO is the trim step necessary?
observation = observation.trim(period.start, period.end, no_overlap=obs_no_overlap)
if len(observation.data.time) == 0:
return None
Expand Down Expand Up @@ -425,8 +450,10 @@ def _parse_single_obs(
obs: ObsInputType,
obs_item: Optional[int | str],
gtype: Optional[GeometryTypes],
) -> PointObservation | TrackObservation:
if isinstance(obs, (PointObservation, TrackObservation)):
) -> PointObservation | TrackObservation | NetworkLocationObservation:
if isinstance(
obs, (PointObservation, TrackObservation, NetworkLocationObservation)
):
if obs_item is not None:
raise ValueError(
"obs_item argument not allowed if obs is an modelskill.Observation type"
Expand All @@ -446,6 +473,7 @@ def _parse_single_model(
| TrackModelResult
| GridModelResult
| DfsuModelResult
| NetworkModelResult
| DummyModelResult
):
if isinstance(
Expand Down Expand Up @@ -478,6 +506,7 @@ def _parse_single_model(
TrackModelResult,
GridModelResult,
DfsuModelResult,
NetworkModelResult,
DummyModelResult,
),
)
Expand Down
2 changes: 2 additions & 0 deletions src/modelskill/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from .dfsu import DfsuModelResult
from .grid import GridModelResult
from .dummy import DummyModelResult
from .network import NetworkModelResult

__all__ = [
"PointModelResult",
Expand All @@ -29,4 +30,5 @@
"GridModelResult",
"model_result",
"DummyModelResult",
"NetworkModelResult",
]
58 changes: 58 additions & 0 deletions src/modelskill/model/network.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from typing import Any
from pathlib import Path
from typing import Optional

import pandas as pd

# from ..quantity import Quantity

from ._base import SelectedItems
from ..obs import NetworkLocationObservation, PointObservation, TrackObservation
from .point import PointModelResult


class NetworkModelResult:
def __init__(
self,
data: str | Path,
*,
name: str,
item: str | int | None = None,
# quantity: Optional[Quantity] = None,
aux_items: Optional[list[int | str]] = None,
) -> None:
import mikeio1d

df = mikeio1d.Res1D(data).to_dataframe()

self.data = df
self.time = df.index
self.name = name
self.item = item

# TODO load from file
data_vars = ["Discharge", "Water Level"]

sel_items = SelectedItems.parse(data_vars, item=item, aux_items=aux_items)

self.sel_items = sel_items

def extract(
self,
observation: PointObservation | TrackObservation | NetworkLocationObservation,
**kwargs: Any,
) -> PointModelResult:
if not isinstance(observation, NetworkLocationObservation):
raise TypeError(
"NetworkModelResult can only extract NetworkLocationObservation"
)
col = f"{self.item}:{observation.reach}:{observation.chainage}"

df = pd.DataFrame()
df[observation.name] = self.data[col]
df.index = self.time
return PointModelResult(
data=df,
name=self.name,
item=observation.name,
)
31 changes: 31 additions & 0 deletions src/modelskill/obs.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,37 @@ def z(self, value):
self.data["z"] = value


class NetworkLocationObservation(PointObservation):
def __init__(
self,
data: PointType,
*,
reach: str,
chainage: str,
item: Optional[int | str] = None,
x: Optional[float] = None,
y: Optional[float] = None,
name: Optional[str] = None,
weight: float = 1.0,
quantity: Optional[Quantity] = None,
aux_items: Optional[list[int | str]] = None,
attrs: Optional[dict] = None,
) -> None:
self.reach = reach
self.chainage = chainage
super().__init__(
data=data,
item=item,
x=x,
y=y,
name=name,
weight=weight,
quantity=quantity,
aux_items=aux_items,
attrs=attrs,
)


class TrackObservation(Observation):
"""Class for observation with locations moving in space, e.g. satellite altimetry

Expand Down
5 changes: 3 additions & 2 deletions src/modelskill/timeseries/_timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,5 +339,6 @@ def trim(
warnings.warn(msg)
case _:
pass

return self.__class__(data)
copy = deepcopy(self)
copy.data = data
return copy
11 changes: 11 additions & 0 deletions tests/model/test_network.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import modelskill as ms


def test_network_model_result_has_name():
mod = ms.NetworkModelResult(
"tests/testdata/network/Vida_1BaseDefault_Network_HD.res1d",
name="Vida",
item="Discharge",
)

assert mod.name == "Vida"
16 changes: 16 additions & 0 deletions tests/observation/test_network_obs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import modelskill as ms


def test_network_observation():
obs = ms.NetworkLocationObservation(
"tests/testdata/network/vidaa_mag_4905.dfs0",
item=0,
reach="VIDAA-MAG",
chainage="4905",
name="By the bridge",
)
assert obs.name == "By the bridge"
assert obs.quantity.name == "Discharge"
assert obs.quantity.unit == "m^3/s"
assert obs.reach == "VIDAA-MAG"
assert obs.chainage == "4905"
18 changes: 18 additions & 0 deletions tests/test_match.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,3 +616,21 @@ def test_multiple_models_same_name(tmp_path: Path) -> None:

with pytest.raises(ValueError, match="HKZN_local_2017_DutchCoast"):
ms.match(obs, [mr1, mr2])


def test_network():
obs = ms.NetworkLocationObservation(
"tests/testdata/network/vidaa_mag_4905.dfs0",
item=0,
reach="VIDAA-MAG",
chainage="4905",
)
mod = ms.NetworkModelResult(
"tests/testdata/network/Vida_1BaseDefault_Network_HD.res1d",
name="Vida",
item="Discharge",
)

cmp = ms.match(obs, mod)
assert cmp.n_points == 37
assert cmp.score()["Vida"] == pytest.approx(0.11985827)
Binary file not shown.
Binary file added tests/testdata/network/vidaa_mag_4905.dfs0
Binary file not shown.