This repository was archived by the owner on Apr 8, 2026. It is now read-only.

Size mismatch #300

@dhurtigkth

Description


Hi, I'm trying to retrain Jukebox on new samples that are only speech. I get a strange size-mismatch error when attempting to sample:

/usr/local/lib/python3.10/dist-packages/jukebox/make_models.py in make_prior(hps, vqvae, device)
182 prior.apply(_convert_conv_weights_to_fp16)
183 prior = prior.to(device)
--> 184 restore_model(hps, prior, hps.restore_prior)
185 if hps.train:
186 print_all(f"Loading prior in train mode")

/usr/local/lib/python3.10/dist-packages/jukebox/make_models.py in restore_model(hps, model, checkpoint_path)
64 # print(k, "Checkpoint:", checkpoint_hps.get(k, None), "Ours:", hps.get(k, None))
65 checkpoint['model'] = {k[7:] if k[:7] == 'module.' else k: v for k, v in checkpoint['model'].items()}
---> 66 model.load_state_dict(checkpoint['model'])
67 if 'step' in checkpoint: model.step = checkpoint['step']
68

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict, assign)
2187
2188 if len(error_msgs) > 0:
-> 2189 raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
2190 self.__class__.__name__, "\n\t".join(error_msgs)))
2191 return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for SimplePrior:
size mismatch for prior.x_emb.weight: copying a param with shape torch.Size([4096, 1024]) from checkpoint, the shape in current model is torch.Size([2127, 1024]).
size mismatch for prior.x_out.weight: copying a param with shape torch.Size([4096, 1024]) from checkpoint, the shape in current model is torch.Size([2127, 1024]).
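
For what it's worth, the 2127 in the error can be reconstructed from the hyperparameters below. This is a hedged sketch, assuming that with single_enc_dec=True Jukebox's prior embeds the lyric/character vocabulary and the VQ-VAE music-token codebook in one shared x_emb table (l_bins = 2048 is the default codebook size, not stated in this issue):

```python
# Assumption: with single_enc_dec=True, the prior's x_emb table covers both
# the lyric/character vocabulary and the VQ-VAE music-token codebook.
l_bins = 2048    # assumed default Jukebox VQ-VAE codebook size
n_vocab = 79     # from the hyperparameters in this issue

current_model_bins = n_vocab + l_bins
print(current_model_bins)  # 2127 -- the "current model" shape in the error

# The checkpoint expects 4096 rows instead, so it was evidently saved from a
# prior built with different hyperparameters than the ones being used here.
```

If that assumption holds, the mismatch means the restored checkpoint and the freshly built prior disagree on the combined vocabulary size, not that the weights are corrupted.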

I've used small_single_enc_dec_prior and these are the hyperparameters:

small_single_enc_dec_prior = Hyperparams(
    n_ctx=6144, # original: 6144, ours: 384
    prior_width=1024,
    prior_depth=48,
    heads=2,
    attn_order=12,
    blocks=64,
    init_scale=0.7,
    c_res=1,
    prime_loss_fraction=0.4,
    single_enc_dec=True,
    labels=True,
    labels_v3=True,
    y_bins=(10,100), # Set this to (genres, artists) for your dataset
    max_bow_genre_size=1,
    min_duration=24.0, # 24
    max_duration=600.0, # 600
    t_bins=64, # original: 64, ours: 16
    use_tokens=True,
    n_tokens=384, # original: 384, ours: 24
    n_vocab=79,
)

I then restore my saved checkpoint and update it so that the keys match. This is how I've set up the training:

args = {
'hps': 'vqvae,small_single_enc_dec_prior,all_fp16,cpu_ema',
'name': 'pretrained_vqvae_small_single_enc_dec_prior_labels',
'sample_length': 786432, #49152
'bs': 4,
'aug_shift': True,
'aug_blend': True,
'audio_files_dir': '/content/data',
'train': True,
'test': True,
'prior': True,
'min_duration': 24, #24
'max_duration': 600, #600
'levels': 3,
'level': 2,
'weight_decay': 0.01,
'save_iters': 10,
'copy_input': True
}
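
As a quick sanity check on these args (a sketch under the assumption that the level-2, i.e. top-level, prior sees tokens downsampled 128x from raw audio), sample_length and n_ctx are at least mutually consistent:

```python
# Assumption: raw_to_tokens = 128 for the level-2 prior
# (Jukebox's default 8x * 4x * 4x VQ-VAE downsampling stack).
sample_length = 786432   # from the args above
raw_to_tokens = 128      # assumed top-level downsampling factor
n_ctx = 6144             # from small_single_enc_dec_prior

tokens_per_sample = sample_length // raw_to_tokens
print(tokens_per_sample)           # 6144
print(tokens_per_sample == n_ctx)  # True: context and sample length agree
```

So the mismatch is unlikely to come from sample_length; it points at the embedding/vocabulary hyperparameters instead.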

Train the model:

train.run(**args)
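
If the goal is to keep the pretrained transformer weights and re-initialize only the layers whose vocabulary size changed, one common workaround is to drop the shape-incompatible checkpoint entries and load the rest with strict=False. This is a sketch, not Jukebox's API; `prior` and the checkpoint handling in the usage comment are placeholders for your own objects:

```python
def filter_matching(model_state, ckpt_state):
    """Split a checkpoint state_dict into shape-compatible and
    shape-incompatible entries relative to the current model."""
    kept, dropped = {}, []
    for name, tensor in ckpt_state.items():
        if name in model_state and model_state[name].shape == tensor.shape:
            kept[name] = tensor
        else:
            dropped.append(name)
    return kept, dropped

# Usage sketch (requires torch and a built prior; names are placeholders):
# ckpt = torch.load(hps.restore_prior, map_location="cpu")
# state = {k[7:] if k.startswith("module.") else k: v
#          for k, v in ckpt["model"].items()}
# kept, dropped = filter_matching(prior.state_dict(), state)
# print("re-initializing:", dropped)  # e.g. prior.x_emb.weight, prior.x_out.weight
# prior.load_state_dict(kept, strict=False)
```

The dropped layers then start from their fresh initialization, which is usually what you want when the token vocabulary changes between the checkpoint and the new model.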

I've tried all kinds of things but still get the same error. Does anyone have an idea what the problem could be?
