Skip to content

Commit eb2ca73

Browse files
authored
Fixing array type casting (#279)
1 parent 3efb92a commit eb2ca73

1 file changed

Lines changed: 19 additions & 10 deletions

File tree

src/tracksdata/graph/_rustworkx_graph.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,20 @@ def _maybe_fill_null(s: pl.Series, schema: AttrSchema) -> pl.Series:
6868
return s
6969

7070

71+
def _list_to_pl_series(key: str, values: list[Any], schema: AttrSchema) -> pl.Series:
    """
    Build a polars Series named `key` from `values`, cast to `schema.dtype`.

    When `schema.dtype` is a `pl.Array`, the values are first converted with
    `np.asarray` and the resulting array is handed to polars. If that
    conversion raises `ValueError`, the original list is passed through
    unchanged. For every other dtype the list is used as-is.

    Parameters
    ----------
    key : str
        Name given to the resulting series.
    values : list[Any]
        Raw attribute values, one entry per row.
    schema : AttrSchema
        Attribute schema; its `dtype` is used as the series dtype and the
        schema is forwarded to `_maybe_fill_null`.

    Returns
    -------
    pl.Series
        The series after being passed through `_maybe_fill_null`.
    """
    # Keep the parameter untouched; `data` is what actually goes to polars.
    data: Any = values
    if isinstance(schema.dtype, pl.Array):
        try:
            data = np.asarray(values)
        except ValueError:
            # catches when it fails with None in `values`;
            # fall back to the plain list and let polars do the casting.
            pass
    series = pl.Series(name=key, values=data, dtype=schema.dtype)
    return _maybe_fill_null(series, schema)
83+
84+
7185
def _create_filter_func(
7286
attr_comps: Sequence[AttrComparison],
7387
schema: dict[str, AttrSchema],
@@ -189,9 +203,7 @@ def _edge_attrs(self) -> pl.DataFrame:
189203

190204
for k in data.keys():
191205
schema = self._graph._edge_attr_schemas()[k]
192-
s = pl.Series(name=k, values=data[k], dtype=schema.dtype)
193-
s = _maybe_fill_null(s, schema)
194-
data[k] = s
206+
data[k] = _list_to_pl_series(k, data[k], schema)
195207

196208
data[DEFAULT_ATTR_KEYS.EDGE_SOURCE] = pl.Series(
197209
name=DEFAULT_ATTR_KEYS.EDGE_SOURCE, values=sources, dtype=pl.Int64
@@ -1109,9 +1121,7 @@ def _node_attrs_from_node_ids(
11091121

11101122
for key in attr_keys:
11111123
schema = node_attr_schemas[key]
1112-
s = pl.Series(name=key, values=columns[key], dtype=schema.dtype)
1113-
s = _maybe_fill_null(s, schema)
1114-
columns[key] = s
1124+
columns[key] = _list_to_pl_series(key, columns[key], schema)
11151125

11161126
# Create DataFrame and set node_id as index in one shot
11171127
df = pl.DataFrame(columns)
@@ -1172,11 +1182,10 @@ def edge_attrs(
11721182
columns[DEFAULT_ATTR_KEYS.EDGE_SOURCE] = source
11731183
columns[DEFAULT_ATTR_KEYS.EDGE_TARGET] = target
11741184

1185+
edge_attr_schemas = self._edge_attr_schemas()
11751186
for key in attr_keys:
1176-
schema = self._edge_attr_schemas()[key]
1177-
s = pl.Series(name=key, values=columns[key], dtype=schema.dtype)
1178-
s = _maybe_fill_null(s, schema)
1179-
columns[key] = s
1187+
schema = edge_attr_schemas[key]
1188+
columns[key] = _list_to_pl_series(key, columns[key], schema)
11801189

11811190
df = pl.DataFrame(columns)
11821191
if unpack:

0 commit comments

Comments
 (0)