Skip to content

Commit 4228cad

Browse files
committed
Add wrangler
- Wrangling works - Unwrangling awkward data needs work
1 parent 5a72167 commit 4228cad

File tree

3 files changed

+251
-1
lines changed

3 files changed

+251
-1
lines changed

imas/backends/netcdf/ids_tensorizer.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33
"""Tensorization logic to convert IDSs to netCDF files and/or xarray Datasets."""
44

55
from collections import deque
6-
from typing import List
6+
from typing import List, Tuple
77

88
import numpy
9+
import awkward as ak
910

1011
from imas.backends.netcdf.iterators import indexed_tree_iter
1112
from imas.backends.netcdf.nc_metadata import NCMetadata
@@ -203,3 +204,41 @@ def tensorize(self, path, fillvalue):
203204
tmp_var[aos_coords + tuple(map(slice, node.shape))] = node.value
204205

205206
return tmp_var
207+
208+
def recursively_convert_to_list(self, path: str, inactive_index: Tuple,
                                shape: Tuple, i_dim: int):
    """Build a nested list of the values stored for ``path``.

    One nesting level is produced per entry of ``shape``. Complete index
    tuples are looked up in ``self.filled_data[path]`` and the ``.value`` of
    the stored node is placed in the innermost lists.

    Args:
        path: Key into ``self.filled_data``.
        inactive_index: Indices already fixed by the enclosing recursion
            levels (``()`` for the outermost call).
        shape: Size of every index dimension to walk.
        i_dim: Position in ``shape`` handled by this call.

    Returns:
        A list with one entry per index of dimension ``i_dim``; entries are
        node values at the last dimension and nested lists otherwise.
    """
    nodes = self.filled_data[path]
    is_last_dim = i_dim == len(shape) - 1
    entry = []
    # Walk the indices of the current dimension. Iterating over ``path``
    # itself would walk the characters of the path string.
    for index in range(shape[i_dim]):
        new_index = inactive_index + (index,)
        if not is_last_dim:
            entry.append(
                self.recursively_convert_to_list(path, new_index, shape, i_dim + 1)
            )
        elif new_index in nodes:
            entry.append(nodes[new_index].value)
        # Index tuples without a stored node are skipped so that rows may be
        # shorter than shape[i_dim].
        # NOTE(review): assumes ``shape`` holds the largest extent per
        # dimension for ragged data -- confirm against determine_data_shapes.
    return entry
219+
220+
def awkward_tensorize(self, path: str):
    """Collect the data stored at ``path`` without padding it to a tensor.

    Args:
        path: The path to the data in the IDS.

    Returns:
        The bare ``.value`` of the stored node when the entries for ``path``
        are keyed by the empty index tuple; otherwise an ``ak.Array`` built
        from the nested list returned by ``recursively_convert_to_list``.

    Raises:
        IndexError: If ``self.filled_data[path]`` holds no entries.
    """
    if path in self.shapes:
        shape = self.shapes[path]
    else:
        dimensions = self.ncmeta.get_dimensions(path, self.homogeneous_time)
        shape = tuple(self.dimension_size[dim] for dim in dimensions)

    nodes = self.filled_data[path]
    if not nodes:
        raise IndexError(f"No filled data available for path {path!r}")

    # Get the split between HDF5 indices and stored matrices
    # i.e. equilibrium.time_slice.profiles_2d <-> psi
    # The length of any index key gives the number of leading entries of
    # ``shape`` that are indices into ``nodes``; only the first key is read.
    hdf5_dim = len(next(iter(nodes)))
    if hdf5_dim == 0:
        return nodes[()].value
    return ak.Array(
        self.recursively_convert_to_list(path, (), shape[:hdf5_dim], 0)
    )
244+

imas/test/test_wrangle.py

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
import pytest
2+
import numpy as np
3+
import awkward as ak
4+
5+
from imas.wrangler import wrangle, unwrangle
6+
from imas.ids_factory import IDSFactory
7+
from imas.util import idsdiffgen
8+
9+
@pytest.fixture
def test_data():
    """Raw numbers shared by the ``flat`` and ``test_ids_dict`` fixtures.

    Returns a dict with an ``"equilibrium"`` entry (regular data) and a
    ``"thomson_scattering"`` entry whose two channel groups have different
    numbers of time points.
    """
    n_time, n_radial, n_grid = 100, 100, 1
    r = np.linspace(1.0, 2.0, n_radial)
    z = np.linspace(-1.0, 1.0, n_radial)
    r_grid, z_grid = np.meshgrid(r, z, indexing="ij")
    equilibrium = {
        "N_time": n_time,
        "N_radial": n_radial,
        "N_grid": n_grid,
        "time": np.linspace(0.0, 5.0, n_time),
        "psi_1d": np.linspace(0.0, 1.0, n_radial),
        "r": r,
        "z": z,
        "psi_2d": (r_grid - 1.5) ** 2 + z_grid**2,
    }

    # Two groups of channels: (number of channels, number of time points)
    # are (20, 100) for the first group and (10, 300) for the second.
    n_ch = (20, 10)
    n_time_ts = (100, 300)
    z_channels = np.concatenate(
        [np.linspace(-1.0, 1.0, n_ch[0]), np.linspace(-1.0, 1.0, n_ch[1])]
    )
    thomson_scattering = {
        "N_ch": n_ch,
        "N_time": n_time_ts,
        "r": np.concatenate([np.ones(n_ch[0]) * 1.6, np.ones(n_ch[1]) * 1.7]),
        "z": z_channels,
        "t_e": z_channels**2 * 5.e3,
        "n_e": z_channels**2 * 5.e19,
        "time": (
            np.linspace(0, 5.0, n_time_ts[0]),
            np.linspace(0, 5.0, n_time_ts[1]),
        ),
    }

    return {"equilibrium": equilibrium, "thomson_scattering": thomson_scattering}
35+
36+
@pytest.fixture
def flat(test_data):
    """Flat ``{dotted path: values}`` view of ``test_data``.

    Equilibrium entries are regular numpy arrays whose leading axes run over
    ``time_slice`` (and ``profiles_2d`` where present). The Thomson
    scattering ``time``/``data`` entries have one row per channel with a
    group-dependent length, so they are assembled with ``ak.concatenate``.
    """
    eq = test_data["equilibrium"]
    ts = test_data["thomson_scattering"]
    n_time, n_grid, n_radial = eq["N_time"], eq["N_grid"], eq["N_radial"]

    flat = {}
    # Equilibrium test data
    flat["equilibrium.time"] = eq["time"]
    flat["equilibrium.time_slice.time"] = eq["time"]
    flat["equilibrium.ids_properties.homogeneous_time"] = 1
    flat["equilibrium.time_slice.profiles_1d.psi"] = np.broadcast_to(
        eq["psi_1d"], (n_time, n_radial)
    ).copy()
    flat["equilibrium.time_slice.profiles_2d.grid.dim1"] = np.broadcast_to(
        eq["r"], (n_time, n_grid, n_radial)
    ).copy()
    flat["equilibrium.time_slice.profiles_2d.grid.dim2"] = np.broadcast_to(
        eq["z"], (n_time, n_grid, n_radial)
    ).copy()
    flat["equilibrium.time_slice.profiles_2d.psi"] = np.broadcast_to(
        eq["psi_2d"], (n_time, n_grid, n_radial, n_radial)
    ).copy()

    # Thomson scattering test data (ragged)
    n_first, n_second = ts["N_ch"]
    nt_first, nt_second = ts["N_time"]

    def channel_times():
        # Every channel of a group shares that group's time base.
        return ak.concatenate(
            [
                np.tile(ts["time"][0], (n_first, 1)),
                np.tile(ts["time"][1], (n_second, 1)),
            ]
        )

    def channel_signal(name):
        # ``ts[name]`` holds one value per channel (all channels in a single
        # array), unlike ``ts["time"]`` which is indexed by group. Each
        # channel's value is repeated over that channel's time base.
        per_channel = ts[name]
        return ak.concatenate(
            [
                np.tile(per_channel[:n_first, None], (1, nt_first)),
                np.tile(per_channel[n_first:, None], (1, nt_second)),
            ]
        )

    flat["thomson_scattering.ids_properties.homogeneous_time"] = 0
    flat["thomson_scattering.channel.t_e.time"] = channel_times()
    flat["thomson_scattering.channel.t_e.data"] = channel_signal("t_e")
    flat["thomson_scattering.channel.n_e.time"] = channel_times()
    flat["thomson_scattering.channel.n_e.data"] = channel_signal("n_e")
    flat["thomson_scattering.channel.position.r"] = ts["r"]
    flat["thomson_scattering.channel.position.z"] = ts["z"]
    return flat
90+
91+
@pytest.fixture
def test_ids_dict(test_data):
    """Build the reference IDSs holding the same numbers as ``test_data``.

    Returns:
        ``{"equilibrium": ..., "thomson_scattering": ...}`` with IDSs created
        by ``IDSFactory("3.41.0")``.
    """
    factory = IDSFactory("3.41.0")

    eq_data = test_data["equilibrium"]
    equilibrium = factory.equilibrium()
    equilibrium.time = eq_data["time"]
    equilibrium.time_slice.resize(eq_data["N_time"])
    equilibrium.ids_properties.homogeneous_time = 1
    for i in range(eq_data["N_time"]):
        time_slice = equilibrium.time_slice[i]
        time_slice.time = eq_data["time"][i]
        time_slice.profiles_1d.psi = eq_data["psi_1d"]
        time_slice.profiles_2d.resize(1)
        time_slice.profiles_2d[0].grid.dim1 = eq_data["r"]
        time_slice.profiles_2d[0].grid.dim2 = eq_data["z"]
        time_slice.profiles_2d[0].psi = eq_data["psi_2d"]

    ts_data = test_data["thomson_scattering"]
    thomson_scattering = factory.thomson_scattering()
    thomson_scattering.ids_properties.homogeneous_time = 0
    n_first, n_second = ts_data["N_ch"]
    n_channels = n_first + n_second
    thomson_scattering.channel.resize(n_channels)
    for i in range(n_channels):
        # Channels [0, n_first) belong to group 0, the rest to group 1; the
        # group selects the time base and hence the length of the signals.
        group = 0 if i < n_first else 1
        n_points = ts_data["N_time"][group]
        channel = thomson_scattering.channel[i]
        channel.t_e.time = ts_data["time"][group]
        channel.t_e.data = np.tile(ts_data["t_e"][i], n_points)
        channel.n_e.time = ts_data["time"][group]
        # Density comes from "n_e" (previously filled from "t_e" by mistake).
        channel.n_e.data = np.tile(ts_data["n_e"][i], n_points)
        channel.position.r = ts_data["r"][i]
        channel.position.z = ts_data["z"][i]

    return {"equilibrium": equilibrium, "thomson_scattering": thomson_scattering}
124+
125+
126+
def test_wrangle(test_ids_dict, flat):
    """Wrangling the flat data must reproduce the hand-built IDSs."""
    wrangled = wrangle(flat)
    for key in test_ids_dict:
        # Materialise the differences once: the assertion message would
        # otherwise show an already-consumed iterator.
        differences = list(idsdiffgen(wrangled[key], test_ids_dict[key]))
        assert not differences, differences
131+
132+
def test_unwrangle(test_ids_dict, flat):
    """Unwrangling the hand-built IDSs must reproduce the flat data."""
    requested = list(flat.keys())
    result = unwrangle(requested, test_ids_dict)
    for key, expected in flat.items():
        np.testing.assert_allclose(result[key], expected)

imas/wrangler.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
from typing import Dict, List
2+
import awkward as ak
3+
import numpy as np
4+
from . import IDSFactory
5+
from .ids_toplevel import IDSToplevel
6+
from .backends.netcdf.ids_tensorizer import IDSTensorizer
7+
8+
def recursively_put(location, value, ids):
    """Store ``value`` at the dotted ``location`` below ``ids``.

    The first component of ``location`` is resolved with ``getattr``. When
    the resolved child exposes a ``size`` attribute it is treated as an
    indexable container: it is resized to ``len(value)`` if empty, and
    ``value[i]`` is put into child ``i`` for the remaining location.
    Otherwise the recursion continues on the child with the full ``value``.
    A location without a dot is assigned with ``setattr``.

    Args:
        location: Dotted path relative to ``ids``,
            e.g. ``time_slice.profiles_1d.psi``.
        value: Data to store; its leading axis runs over the entries of each
            sized container met along ``location``.
        ids: Object to fill.

    Returns:
        ``ids`` (modified in place).

    Raises:
        ValueError: If a non-empty container's size differs from
            ``len(value)``.
    """
    if "." not in location:
        setattr(ids, location, value)
        return ids

    position, sub_location = location.split(".", 1)
    sub_ids = getattr(ids, position)
    if hasattr(sub_ids, "size"):
        n_entries = len(value)
        if sub_ids.size == 0:
            sub_ids.resize(n_entries)
        elif sub_ids.size != n_entries:
            # Report the size of the container itself: the parent ``ids``
            # need not have a ``size`` attribute at all.
            raise ValueError(
                f"Inconsistent size across flat entries {location}: "
                f"{n_entries} (flat) vs. {sub_ids.size} (ids)!"
            )
        # Need to iterate over indices (e.g. equilibrium.time_slice[:].)
        for index in range(n_entries):
            recursively_put(sub_location, value[index], sub_ids[index])
    else:
        # Plain attribute: descend without consuming an axis of ``value``.
        recursively_put(sub_location, value, sub_ids)
    return ids
33+
34+
35+
def wrangle(flat: Dict, version="3.41.0") -> Dict[str, IDSToplevel]:
    """Turn a flat ``{"<ids>.<dotted path>": values}`` mapping into IDSs.

    Args:
        flat: Mapping whose keys start with the IDS name followed by the
            dotted location inside that IDS.
        version: Version string handed to ``IDSFactory``.

    Returns:
        One entry per IDS name found in the keys of ``flat``, created through
        the factory and filled with ``recursively_put``.
    """
    factory = IDSFactory(version)
    wrangled = {}
    for key, values in flat.items():
        ids_name, location = key.split(".", 1)
        try:
            toplevel = wrangled[ids_name]
        except KeyError:
            # First key for this IDS: create it through the factory.
            toplevel = getattr(factory, ids_name)()
        wrangled[ids_name] = recursively_put(location, values, toplevel)
    return wrangled
44+
45+
def split_location_across_ids(locations: List[str]) -> Dict[str, List[str]]:
    """Group dotted locations by their leading IDS name.

    Args:
        locations: Entries of the form ``"<ids>.<dotted path>"``.

    Returns:
        ``{ids: [paths]}`` in order of first appearance, with the dots of the
        remaining path replaced by ``/``.
    """
    grouped = {}
    for entry in locations:
        ids_name, dotted_path = entry.split(".", 1)
        grouped.setdefault(ids_name, []).append(dotted_path.replace(".", "/"))
    return grouped
53+
54+
def unwrangle(
    locations: List[str], ids_dict: Dict[str, IDSToplevel], version="3.41.0"
) -> "Dict[str, ak.Array | np.ndarray]":
    """Extract the requested locations from IDSs into a flat mapping.

    Args:
        locations: Entries of the form ``"<ids>.<dotted path>"``.
        ids_dict: IDS objects keyed by the IDS name used in ``locations``.
        version: Currently unused; kept for interface compatibility.

    Returns:
        ``{location: values}``. Values returned as ``ak.Array`` by the
        tensorizer are converted to numpy when ``np.asarray`` accepts them
        and stay ``ak.Array`` otherwise; anything else is stored as returned.
    """
    flat = {}
    for key, paths in split_location_across_ids(locations).items():
        tensorizer = IDSTensorizer(ids_dict[key], paths)
        tensorizer.include_coordinate_paths()
        tensorizer.collect_filled_data()
        tensorizer.determine_data_shapes()
        # Add IDS conversion
        for ids_location in paths:
            location = key + "." + ids_location.replace("/", ".")
            values = tensorizer.awkward_tensorize(ids_location)
            # Non-array results (e.g. homogeneous_time) are stored untouched.
            if isinstance(values, ak.Array):
                try:
                    values = np.asarray(values)
                except ValueError:
                    # Presumably ragged data that has no regular numpy
                    # representation: keep the awkward array as it is.
                    pass
            flat[location] = values
    return flat

0 commit comments

Comments
 (0)