@@ -164,7 +164,32 @@ def interact_pattern(n_persons, select_persons, tag):
164164 return re .compile (pattern )
165165
166166
167- def cdap_interaction_utility (model , n_persons , alts , interaction_coef , coefficients ):
167+ def cdap_interaction_utility (
168+ model : lx .Model ,
169+ n_persons : int ,
170+ alts : dict ,
171+ interaction_coef : pd .DataFrame ,
172+ coefficients : pd .DataFrame ,
173+ ):
174+ """
175+ Build the interaction utility for each pattern.
176+
177+ Parameters
178+ ----------
179+ model : larch.Model
180+ n_persons : int
181+ alts : dict
182+ The keys are the names of the patterns, and
183+ the values are the alternative code numbers,
184+ as created by `generate_alternatives`.
185+ interaction_coef : pandas.DataFrame
186+ The interaction coefficients provided by
187+ the ActivitySim framework. Should include columns
188+ "cardinality", "activity", "interaction_ptypes", and "coefficient".
189+ coefficients : pandas.DataFrame
190+ The full set of coefficients provided by
191+ the ActivitySim framework.
192+ """
168193 person_numbers = list (range (1 , n_persons + 1 ))
169194
170195 matcher = re .compile ("coef_[HMN]_.*" )
@@ -174,8 +199,11 @@ def cdap_interaction_utility(model, n_persons, alts, interaction_coef, coefficie
174199 c_split = c .split ("_" )
175200 for j in c_split [2 :]:
176201 interact_coef_map [(c_split [1 ], j )] = c
177- if all ((i == "x" for i in j )): # wildcards also map to empty
178- interact_coef_map [(c_split [1 ], "" )] = c
202+ # previously, wildcards also mapped empty here, but this caused a clash
203+ # as all wildcards would map to the same coefficient name no matter the
204+ # cardinality, so instead we only map the exact wildcard case, and later
205+ # check that empty interaction_ptypes maps to the correct coefficient name
206+ # based on cardinality.
179207
180208 for (cardinality , activity ), coefs in interaction_coef .groupby (
181209 ["cardinality" , "activity" ]
@@ -194,17 +222,21 @@ def cdap_interaction_utility(model, n_persons, alts, interaction_coef, coefficie
194222 for (p , t ) in zip (person_numbers , row .interaction_ptypes )
195223 if t != "*"
196224 )
225+ row_interaction_ptypes = row .interaction_ptypes
226+ if not row_interaction_ptypes :
227+ # empty interaction_ptypes means all wildcards, but it needs to be the correct length
228+ row_interaction_ptypes = "x" * n_persons
197229 if expression :
198- if (activity , row . interaction_ptypes ) in interact_coef_map :
230+ if (activity , row_interaction_ptypes ) in interact_coef_map :
199231 linear_component = X (expression ) * P (
200- interact_coef_map [(activity , row . interaction_ptypes )]
232+ interact_coef_map [(activity , row_interaction_ptypes )]
201233 )
202234 else :
203235 linear_component = X (expression ) * P (row .coefficient )
204236 else :
205- if (activity , row . interaction_ptypes ) in interact_coef_map :
237+ if (activity , row_interaction_ptypes ) in interact_coef_map :
206238 linear_component = P (
207- interact_coef_map [(activity , row . interaction_ptypes )]
239+ interact_coef_map [(activity , row_interaction_ptypes )]
208240 )
209241 else :
210242 linear_component = P (row .coefficient )
@@ -377,7 +409,7 @@ def cdap_data(
377409 if not os .path .exists (edb_directory ):
378410 raise FileNotFoundError (edb_directory )
379411
380- def read_csv (filename , ** kwargs ):
412+ def read_csv (filename , ** kwargs ) -> pd . DataFrame :
381413 filename = Path (edb_directory ).joinpath (filename .format (name = name )).resolve ()
382414 if filename .with_suffix (".parquet" ).exists ():
383415 if "comment" in kwargs :
0 commit comments