@@ -84,12 +84,12 @@ def forward(self, x: Dict[str, Any]) -> torch.Tensor:
8484 x .get ("continuous" , torch .empty (0 , 0 )),
8585 x .get ("categorical" , torch .empty (0 , 0 )),
8686 )
87- assert (
88- categorical_data . shape [ 1 ] == self . categorical_dim
89- ), "categorical_data must have same number of columns as categorical embedding layers"
90- assert (
91- continuous_data . shape [ 1 ] == self . continuous_dim
92- ), "continuous_data must have same number of columns as continuous dim"
87+ assert categorical_data . shape [ 1 ] == self . categorical_dim , (
88+ " categorical_data must have same number of columns as categorical embedding layers"
89+ )
90+ assert continuous_data . shape [ 1 ] == self . continuous_dim , (
91+ " continuous_data must have same number of columns as continuous dim"
92+ )
9393 embed = None
9494 if continuous_data .shape [1 ] > 0 :
9595 if self .batch_norm_continuous_input :
@@ -141,12 +141,12 @@ def forward(self, x: Dict[str, Any]) -> torch.Tensor:
141141 x .get ("continuous" , torch .empty (0 , 0 )),
142142 x .get ("categorical" , torch .empty (0 , 0 )),
143143 )
144- assert categorical_data .shape [1 ] == len (
145- self . cat_embedding_layers
146- ), "categorical_data must have same number of columns as categorical embedding layers"
147- assert (
148- continuous_data . shape [ 1 ] == self . continuous_dim
149- ), "continuous_data must have same number of columns as continuous dim"
144+ assert categorical_data .shape [1 ] == len (self . cat_embedding_layers ), (
145+ "categorical_data must have same number of columns as categorical embedding layers"
146+ )
147+ assert continuous_data . shape [ 1 ] == self . continuous_dim , (
148+ " continuous_data must have same number of columns as continuous dim"
149+ )
150150 embed = None
151151 if continuous_data .shape [1 ] > 0 :
152152 if self .batch_norm_continuous_input :
@@ -273,12 +273,12 @@ def forward(self, x: Dict[str, Any]) -> torch.Tensor:
273273 x .get ("continuous" , torch .empty (0 , 0 )),
274274 x .get ("categorical" , torch .empty (0 , 0 )),
275275 )
276- assert categorical_data .shape [1 ] == len (
277- self . cat_embedding_layers
278- ), "categorical_data must have same number of columns as categorical embedding layers"
279- assert (
280- continuous_data . shape [ 1 ] == self . continuous_dim
281- ), "continuous_data must have same number of columns as continuous dim"
276+ assert categorical_data .shape [1 ] == len (self . cat_embedding_layers ), (
277+ "categorical_data must have same number of columns as categorical embedding layers"
278+ )
279+ assert continuous_data . shape [ 1 ] == self . continuous_dim , (
280+ " continuous_data must have same number of columns as continuous dim"
281+ )
282282 embed = None
283283 if continuous_data .shape [1 ] > 0 :
284284 cont_idx = torch .arange (self .continuous_dim , device = continuous_data .device ).expand (
0 commit comments