
Input dimension mismatch when using categorical variables with >2 classes #7

@mahmoudibrahim98

Issue

When using categorical variables with more than 2 classes, there is a mismatch between the input size the model expects and the actual input size after one-hot encoding. This occurs because the current channel calculation adds one extra channel per categorical column, which effectively assumes every categorical variable is binary (2 classes). A variable with k classes expands to k one-hot dimensions, so any variable with k > 2 is undercounted.

Current Behavior

The model initialization uses:
input_channels = dataset.channels + len(dataset.categorical_cols)

This leads to errors when categorical variables have more than 2 classes. For example, with:

  • 4 numerical features
  • 6 binary categorical features (2 classes each)
  • 1 categorical feature with 4 classes

The code calculates:

  • input_channels = 11 (all columns: 4 numerical + 7 categorical) + 7 (categorical columns) = 18

But after one-hot encoding, the actual input dimension becomes:

  • 4 (numerical) + (6 × 2) + (1 × 4) = 20 dimensions

This causes the error:
RuntimeError: input.size(-1) must be equal to input_size. Expected 18, got 20
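
The arithmetic above can be reproduced with a minimal sketch. It assumes that dataset.channels counts every column once (4 numerical + 7 categorical = 11), which is inferred from the numbers in this example rather than confirmed from the code; the variable names are illustrative only.

```python
# Hypothetical feature layout matching the example above.
num_numerical = 4
categorical_num_classes = [2, 2, 2, 2, 2, 2, 4]  # six binary columns, one 4-class column

# Assumption: dataset.channels counts every column once.
dataset_channels = num_numerical + len(categorical_num_classes)

# What the current initialization computes: one extra channel per categorical column.
expected_by_model = dataset_channels + len(categorical_num_classes)

# What one-hot encoding actually produces: one channel per class.
actual_after_one_hot = num_numerical + sum(categorical_num_classes)

print(expected_by_model, actual_after_one_hot)  # 18 20
```

The two counts agree only when every entry of categorical_num_classes is 2.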

Solution

Calculate the total channels as the number of numerical features plus the total number of classes across all categorical features:

# Number of numerical columns (one channel each)
total_numerical = len(COLUMNS_DICT[data_name]["numerical"])
# One channel per class of every categorical column after one-hot encoding
total_categorical = sum(COLUMNS_DICT[data_name]["categorical_num_classes"])
total_channels = total_numerical + total_categorical

model = RNN(
    input_channels=total_channels,  # accounts for all classes after one-hot encoding
    hidden_channels=total_channels * 4 if data_name in ["stock", "energy", "mimiciv", "mimiciii", "hirid"] else 256,
    output_channels=total_channels,
    ...
)
diffusion = MixedDiffusion(
    model=model,
    channels=total_channels,
    ...
)
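
As a sanity check on the proposed calculation, the sketch below one-hot encodes a single row and compares its width to the computed channel count. The COLUMNS_DICT entry, the column names, and the helpers compute_total_channels and encode_row are hypothetical stand-ins for illustration; the repository's real dictionary layout and encoding code may differ.

```python
from typing import Dict, List, Sequence

# Hypothetical stand-in for one COLUMNS_DICT entry (names are made up).
COLUMNS_DICT: Dict[str, Dict[str, list]] = {
    "example": {
        "numerical": ["num_0", "num_1", "num_2", "num_3"],
        "categorical_num_classes": [2, 2, 2, 2, 2, 2, 4],
    }
}


def compute_total_channels(data_name: str) -> int:
    """Numerical column count plus one channel per categorical class."""
    spec = COLUMNS_DICT[data_name]
    return len(spec["numerical"]) + sum(spec["categorical_num_classes"])


def encode_row(numerical: Sequence[float], categorical: Sequence[int],
               num_classes: Sequence[int]) -> List[float]:
    """Concatenate numerical values with a one-hot vector per categorical value."""
    encoded = list(numerical)
    for value, n in zip(categorical, num_classes):
        if not 0 <= value < n:
            raise ValueError(f"class index {value} out of range for {n} classes")
        one_hot = [0.0] * n
        one_hot[value] = 1.0
        encoded.extend(one_hot)
    return encoded


spec = COLUMNS_DICT["example"]
total_channels = compute_total_channels("example")
row = encode_row(
    numerical=[0.1, 0.2, 0.3, 0.4],
    categorical=[0, 1, 0, 1, 1, 0, 3],
    num_classes=spec["categorical_num_classes"],
)

# The encoded width must match what the model is initialized with.
assert len(row) == total_channels
print(total_channels)  # 20
```

Running a check like this once at dataset construction time would surface a mismatch before the first forward pass instead of as a RuntimeError inside the RNN.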
