Skip to content

Commit 17ad3b8

Browse files
committed
Extend tests to strings
1 parent 7faa192 commit 17ad3b8

File tree

1 file changed

+18
-3
lines changed

1 file changed

+18
-3
lines changed

imas/test/test_wrangle.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ def test_data():
2222

2323
data["thomson_scattering"] = {}
2424
data["thomson_scattering"]["N_ch"] = (20,10)
25+
N = data["thomson_scattering"]["N_ch"][0] + data["thomson_scattering"]["N_ch"][1]
26+
data["thomson_scattering"]["identifier"] = np.asarray("channel_" + np.asarray(np.linspace(1,N+1,N, dtype=int),dtype="|U2"),dtype="|U10")
2527
data["thomson_scattering"]["N_time"] = (100, 300)
2628
data["thomson_scattering"]["r"] = np.concatenate([np.ones(data["thomson_scattering"]["N_ch"][0])*1.6,
2729
np.ones(data["thomson_scattering"]["N_ch"][1])*1.7])
@@ -65,9 +67,8 @@ def flat(test_data):
6567
)
6668
)
6769
flat["equilibrium.time_slice.profiles_2d.psi"][:] = test_data["equilibrium"]["psi_2d"][None, ...]
68-
6970
# Thomson scattering test data (ragged)
70-
N = test_data["thomson_scattering"]["N_ch"][0] + test_data["thomson_scattering"]["N_ch"][1]
71+
flat["thomson_scattering.channel.identifier"] = test_data["thomson_scattering"]["identifier"]
7172
flat["thomson_scattering.ids_properties.homogeneous_time"] = 0
7273
flat["thomson_scattering.channel.t_e.time"] = ak.concatenate([np.tile(test_data["thomson_scattering"]["time"][0],
7374
(test_data["thomson_scattering"]["N_ch"][0],
@@ -116,6 +117,7 @@ def test_ids_dict(test_data):
116117
for i in range(N):
117118
if i == test_data["thomson_scattering"]["N_ch"][0]:
118119
index = 1
120+
thomson_scattering.channel[i].identifier = test_data["thomson_scattering"]["identifier"][i]
119121
thomson_scattering.channel[i].t_e.time = test_data["thomson_scattering"]["time"][index]
120122
thomson_scattering.channel[i].t_e.data = np.tile(test_data["thomson_scattering"]["t_e"][i],
121123
test_data["thomson_scattering"]["N_time"][index])
@@ -134,7 +136,20 @@ def test_wrangle(test_ids_dict, flat):
134136
diff = idsdiffgen(wrangled[key],test_ids_dict[key])
135137
assert len(list(diff)) == 0, diff
136138

139+
def get_dtype(arr):
140+
"""Get dtype from either numpy or awkward array."""
141+
if isinstance(arr, ak.Array):
142+
# This is the easiest way I found to extract the numpy dtype from an awkward array
143+
return eval("np." + arr.typestr.split("*")[-1])
144+
if hasattr(arr, "dtype"):
145+
return arr.dtype
146+
else:
147+
return type(arr)
148+
137149
def test_unwrangle(test_ids_dict, flat):
138150
result = unwrangle(list(flat.keys()), test_ids_dict)
139151
for key in flat.keys():
140-
assert ak.almost_equal(result[key], flat[key])
152+
if np.issubdtype(get_dtype(result[key]), np.floating):
153+
assert ak.almost_equal(result[key], flat[key])
154+
else:
155+
assert ak.array_equal(result[key], flat[key])

0 commit comments

Comments
 (0)