use nn.Sequential to remove python control flow from autoencoder up/downsampling #33
@yorickvP interesting that there are no graph breaks; I wonder if the compiler is smart enough to realize that the if statements will always evaluate to true/false depending on the loop iteration. agreed that we should get perf numbers & test that outputs are unchanged
unfortunately this changes the state_dict keys:

Got 180 missing keys:

encoder.down.0.norm2.weight encoder.down.0.norm2.bias encoder.down.0.conv2.weight encoder.down.0.conv2.bias encoder.down.1.norm1.weight encoder.down.1.norm1.bias encoder.down.1.conv1.weight encoder.down.1.conv1.bias encoder.down.1.norm2.weight encoder.down.1.norm2.bias encoder.down.1.conv2.weight encoder.down.1.conv2.bias encoder.down.2.conv.weight encoder.down.2.conv.bias encoder.down.3.norm1.weight encoder.down.3.norm1.bias encoder.down.3.conv1.weight encoder.down.3.conv1.bias encoder.down.3.norm2.weight encoder.down.3.norm2.bias encoder.down.3.conv2.weight encoder.down.3.conv2.bias encoder.down.3.nin_shortcut.weight encoder.down.3.nin_shortcut.bias encoder.down.4.norm1.weight encoder.down.4.norm1.bias encoder.down.4.conv1.weight encoder.down.4.conv1.bias encoder.down.4.norm2.weight encoder.down.4.norm2.bias encoder.down.4.conv2.weight encoder.down.4.conv2.bias encoder.down.5.conv.weight encoder.down.5.conv.bias encoder.down.6.norm1.weight encoder.down.6.norm1.bias encoder.down.6.conv1.weight encoder.down.6.conv1.bias encoder.down.6.norm2.weight encoder.down.6.norm2.bias encoder.down.6.conv2.weight encoder.down.6.conv2.bias encoder.down.6.nin_shortcut.weight encoder.down.6.nin_shortcut.bias encoder.down.7.norm1.weight encoder.down.7.norm1.bias encoder.down.7.conv1.weight encoder.down.7.conv1.bias encoder.down.7.norm2.weight encoder.down.7.norm2.bias encoder.down.7.conv2.weight encoder.down.7.conv2.bias encoder.down.8.conv.weight encoder.down.8.conv.bias encoder.down.9.norm1.weight encoder.down.9.norm1.bias encoder.down.9.conv1.weight encoder.down.9.conv1.bias encoder.down.9.norm2.weight encoder.down.9.norm2.bias encoder.down.9.conv2.weight encoder.down.9.conv2.bias encoder.down.10.norm1.weight encoder.down.10.norm1.bias encoder.down.10.conv1.weight encoder.down.10.conv1.bias encoder.down.10.norm2.weight encoder.down.10.norm2.bias encoder.down.10.conv2.weight encoder.down.10.conv2.bias
decoder.up.0.norm1.weight decoder.up.0.norm1.bias decoder.up.0.conv1.weight decoder.up.0.conv1.bias decoder.up.0.norm2.weight decoder.up.0.norm2.bias decoder.up.0.conv2.weight decoder.up.0.conv2.bias decoder.up.0.nin_shortcut.weight decoder.up.0.nin_shortcut.bias decoder.up.1.norm1.weight decoder.up.1.norm1.bias decoder.up.1.conv1.weight decoder.up.1.conv1.bias decoder.up.1.norm2.weight decoder.up.1.norm2.bias decoder.up.1.conv2.weight decoder.up.1.conv2.bias decoder.up.2.norm1.weight decoder.up.2.norm1.bias decoder.up.2.conv1.weight decoder.up.2.conv1.bias decoder.up.2.norm2.weight decoder.up.2.norm2.bias decoder.up.2.conv2.weight decoder.up.2.conv2.bias decoder.up.3.norm1.weight decoder.up.3.norm1.bias decoder.up.3.conv1.weight decoder.up.3.conv1.bias decoder.up.3.norm2.weight decoder.up.3.norm2.bias decoder.up.3.conv2.weight decoder.up.3.conv2.bias decoder.up.3.nin_shortcut.weight decoder.up.3.nin_shortcut.bias decoder.up.4.norm1.weight decoder.up.4.norm1.bias decoder.up.4.conv1.weight decoder.up.4.conv1.bias decoder.up.4.norm2.weight decoder.up.4.norm2.bias decoder.up.4.conv2.weight decoder.up.4.conv2.bias decoder.up.5.norm1.weight decoder.up.5.norm1.bias decoder.up.5.conv1.weight decoder.up.5.conv1.bias decoder.up.5.norm2.weight decoder.up.5.norm2.bias decoder.up.5.conv2.weight decoder.up.5.conv2.bias decoder.up.6.conv.weight decoder.up.6.conv.bias decoder.up.7.norm1.weight decoder.up.7.norm1.bias decoder.up.7.conv1.weight decoder.up.7.conv1.bias decoder.up.7.norm2.weight decoder.up.7.norm2.bias decoder.up.7.conv2.weight decoder.up.7.conv2.bias decoder.up.8.norm1.weight decoder.up.8.norm1.bias decoder.up.8.conv1.weight decoder.up.8.conv1.bias decoder.up.8.norm2.weight decoder.up.8.norm2.bias decoder.up.8.conv2.weight decoder.up.8.conv2.bias decoder.up.9.norm1.weight decoder.up.9.norm1.bias decoder.up.9.conv1.weight decoder.up.9.conv1.bias decoder.up.9.norm2.weight decoder.up.9.norm2.bias decoder.up.9.conv2.weight decoder.up.9.conv2.bias 
decoder.up.10.conv.weight decoder.up.10.conv.bias decoder.up.11.norm1.weight decoder.up.11.norm1.bias decoder.up.11.conv1.weight decoder.up.11.conv1.bias decoder.up.11.norm2.weight decoder.up.11.norm2.bias decoder.up.11.conv2.weight decoder.up.11.conv2.bias decoder.up.12.norm1.weight decoder.up.12.norm1.bias decoder.up.12.conv1.weight decoder.up.12.conv1.bias decoder.up.12.norm2.weight decoder.up.12.norm2.bias decoder.up.12.conv2.weight decoder.up.12.conv2.bias decoder.up.13.norm1.weight decoder.up.13.norm1.bias decoder.up.13.conv1.weight decoder.up.13.conv1.bias decoder.up.13.norm2.weight decoder.up.13.norm2.bias decoder.up.13.conv2.weight decoder.up.13.conv2.bias decoder.up.14.conv.weight decoder.up.14.conv.bias

Got 180 unexpected keys:

encoder.down.0.block.0.norm1.bias encoder.down.0.block.0.norm1.weight encoder.down.0.block.0.norm2.bias encoder.down.0.block.0.norm2.weight encoder.down.0.block.1.conv1.bias encoder.down.0.block.1.conv1.weight encoder.down.0.block.1.conv2.bias encoder.down.0.block.1.conv2.weight encoder.down.0.block.1.norm1.bias encoder.down.0.block.1.norm1.weight encoder.down.0.block.1.norm2.bias encoder.down.0.block.1.norm2.weight encoder.down.0.downsample.conv.bias encoder.down.0.downsample.conv.weight encoder.down.1.block.0.conv1.bias encoder.down.1.block.0.conv1.weight encoder.down.1.block.0.conv2.bias encoder.down.1.block.0.conv2.weight encoder.down.1.block.0.nin_shortcut.bias encoder.down.1.block.0.nin_shortcut.weight encoder.down.1.block.0.norm1.bias encoder.down.1.block.0.norm1.weight encoder.down.1.block.0.norm2.bias encoder.down.1.block.0.norm2.weight encoder.down.1.block.1.conv1.bias encoder.down.1.block.1.conv1.weight encoder.down.1.block.1.conv2.bias encoder.down.1.block.1.conv2.weight encoder.down.1.block.1.norm1.bias encoder.down.1.block.1.norm1.weight encoder.down.1.block.1.norm2.bias encoder.down.1.block.1.norm2.weight encoder.down.1.downsample.conv.bias encoder.down.1.downsample.conv.weight
encoder.down.2.block.0.conv1.bias encoder.down.2.block.0.conv1.weight encoder.down.2.block.0.conv2.bias encoder.down.2.block.0.conv2.weight encoder.down.2.block.0.nin_shortcut.bias encoder.down.2.block.0.nin_shortcut.weight encoder.down.2.block.0.norm1.bias encoder.down.2.block.0.norm1.weight encoder.down.2.block.0.norm2.bias encoder.down.2.block.0.norm2.weight encoder.down.2.block.1.conv1.bias encoder.down.2.block.1.conv1.weight encoder.down.2.block.1.conv2.bias encoder.down.2.block.1.conv2.weight encoder.down.2.block.1.norm1.bias encoder.down.2.block.1.norm1.weight encoder.down.2.block.1.norm2.bias encoder.down.2.block.1.norm2.weight encoder.down.2.downsample.conv.bias encoder.down.2.downsample.conv.weight encoder.down.3.block.0.conv1.bias encoder.down.3.block.0.conv1.weight encoder.down.3.block.0.conv2.bias encoder.down.3.block.0.conv2.weight encoder.down.3.block.0.norm1.bias encoder.down.3.block.0.norm1.weight encoder.down.3.block.0.norm2.bias encoder.down.3.block.0.norm2.weight encoder.down.3.block.1.conv1.bias encoder.down.3.block.1.conv1.weight encoder.down.3.block.1.conv2.bias encoder.down.3.block.1.conv2.weight encoder.down.3.block.1.norm1.bias encoder.down.3.block.1.norm1.weight encoder.down.3.block.1.norm2.bias encoder.down.3.block.1.norm2.weight decoder.up.0.block.0.conv1.bias decoder.up.0.block.0.conv1.weight decoder.up.0.block.0.conv2.bias decoder.up.0.block.0.conv2.weight decoder.up.0.block.0.nin_shortcut.bias decoder.up.0.block.0.nin_shortcut.weight decoder.up.0.block.0.norm1.bias decoder.up.0.block.0.norm1.weight decoder.up.0.block.0.norm2.bias decoder.up.0.block.0.norm2.weight decoder.up.0.block.1.conv1.bias decoder.up.0.block.1.conv1.weight decoder.up.0.block.1.conv2.bias decoder.up.0.block.1.conv2.weight decoder.up.0.block.1.norm1.bias decoder.up.0.block.1.norm1.weight decoder.up.0.block.1.norm2.bias decoder.up.0.block.1.norm2.weight decoder.up.0.block.2.conv1.bias decoder.up.0.block.2.conv1.weight decoder.up.0.block.2.conv2.bias 
decoder.up.0.block.2.conv2.weight decoder.up.0.block.2.norm1.bias decoder.up.0.block.2.norm1.weight decoder.up.0.block.2.norm2.bias decoder.up.0.block.2.norm2.weight decoder.up.1.block.0.conv1.bias decoder.up.1.block.0.conv1.weight decoder.up.1.block.0.conv2.bias decoder.up.1.block.0.conv2.weight decoder.up.1.block.0.nin_shortcut.bias decoder.up.1.block.0.nin_shortcut.weight decoder.up.1.block.0.norm1.bias decoder.up.1.block.0.norm1.weight decoder.up.1.block.0.norm2.bias decoder.up.1.block.0.norm2.weight decoder.up.1.block.1.conv1.bias decoder.up.1.block.1.conv1.weight decoder.up.1.block.1.conv2.bias decoder.up.1.block.1.conv2.weight decoder.up.1.block.1.norm1.bias decoder.up.1.block.1.norm1.weight decoder.up.1.block.1.norm2.bias decoder.up.1.block.1.norm2.weight decoder.up.1.block.2.conv1.bias decoder.up.1.block.2.conv1.weight decoder.up.1.block.2.conv2.bias decoder.up.1.block.2.conv2.weight decoder.up.1.block.2.norm1.bias decoder.up.1.block.2.norm1.weight decoder.up.1.block.2.norm2.bias decoder.up.1.block.2.norm2.weight decoder.up.1.upsample.conv.bias decoder.up.1.upsample.conv.weight decoder.up.2.block.0.conv1.bias decoder.up.2.block.0.conv1.weight decoder.up.2.block.0.conv2.bias decoder.up.2.block.0.conv2.weight decoder.up.2.block.0.norm1.bias decoder.up.2.block.0.norm1.weight decoder.up.2.block.0.norm2.bias decoder.up.2.block.0.norm2.weight decoder.up.2.block.1.conv1.bias decoder.up.2.block.1.conv1.weight decoder.up.2.block.1.conv2.bias decoder.up.2.block.1.conv2.weight decoder.up.2.block.1.norm1.bias decoder.up.2.block.1.norm1.weight decoder.up.2.block.1.norm2.bias decoder.up.2.block.1.norm2.weight decoder.up.2.block.2.conv1.bias decoder.up.2.block.2.conv1.weight decoder.up.2.block.2.conv2.bias decoder.up.2.block.2.conv2.weight decoder.up.2.block.2.norm1.bias decoder.up.2.block.2.norm1.weight decoder.up.2.block.2.norm2.bias decoder.up.2.block.2.norm2.weight decoder.up.2.upsample.conv.bias decoder.up.2.upsample.conv.weight decoder.up.3.block.0.conv1.bias 
decoder.up.3.block.0.conv1.weight decoder.up.3.block.0.conv2.bias decoder.up.3.block.0.conv2.weight decoder.up.3.block.0.norm1.bias decoder.up.3.block.0.norm1.weight decoder.up.3.block.0.norm2.bias decoder.up.3.block.0.norm2.weight decoder.up.3.block.1.conv1.bias decoder.up.3.block.1.conv1.weight decoder.up.3.block.1.conv2.bias decoder.up.3.block.1.conv2.weight decoder.up.3.block.1.norm1.bias decoder.up.3.block.1.norm1.weight decoder.up.3.block.1.norm2.bias decoder.up.3.block.1.norm2.weight decoder.up.3.block.2.conv1.bias decoder.up.3.block.2.conv1.weight decoder.up.3.block.2.conv2.bias decoder.up.3.block.2.conv2.weight decoder.up.3.block.2.norm1.bias decoder.up.3.block.2.norm1.weight decoder.up.3.block.2.norm2.bias decoder.up.3.block.2.norm2.weight decoder.up.3.upsample.conv.bias decoder.up.3.upsample.conv.weight

The unexpected keys are the ones in the ae.sft file; the expected keys are the ones from using flat
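One possible way to bridge the two key layouts is to rename the checkpoint keys before loading. The sketch below is only an illustration: the per-level layout tables are inferred from the key listing above (for example, `encoder.down.2.conv` sitting where `encoder.down.0.downsample.conv` used to be), not read from the model code, and `remap_key` / `remap_state_dict` are hypothetical helper names.

```python
import re

# ASSUMPTION: layout inferred from the key listing in this thread.
# Each entry is (number of resnet blocks, level has a resampling conv).
ENCODER_LEVELS = [(2, True), (2, True), (2, True), (2, False)]
DECODER_LEVELS = [(3, False), (3, True), (3, True), (3, True)]

# Matches nested checkpoint keys such as
#   encoder.down.1.block.0.norm1.weight  or  decoder.up.3.upsample.conv.bias
_NESTED = re.compile(
    r"^(encoder\.down|decoder\.up)\.(\d+)\."
    r"(?:block\.(\d+)|(downsample|upsample))\.(.+)$"
)


def flat_index(levels, level, block):
    """Index inside the flat container; block=None means the resampling layer."""
    # Every earlier level contributes its blocks plus an optional resampling layer.
    offset = sum(n + int(has_resample) for n, has_resample in levels[:level])
    n_blocks, has_resample = levels[level]
    if block is None:
        if not has_resample:
            raise KeyError(f"level {level} has no resampling layer")
        return offset + n_blocks
    if block >= n_blocks:
        raise KeyError(f"level {level} has no block {block}")
    return offset + block


def remap_key(key):
    """Translate one nested key to the flat layout; other keys pass through."""
    m = _NESTED.match(key)
    if m is None:
        return key
    prefix, level, block, _resample, rest = m.groups()
    levels = ENCODER_LEVELS if prefix == "encoder.down" else DECODER_LEVELS
    idx = flat_index(levels, int(level), None if block is None else int(block))
    return f"{prefix}.{idx}.{rest}"


def remap_state_dict(state_dict):
    """Return a copy of state_dict with nested keys renamed to flat keys."""
    return {remap_key(k): v for k, v in state_dict.items()}
```

With this mapping, the result of `remap_state_dict` on the checkpoint's state dict would be what gets passed to `load_state_dict`; whether that is preferable to keeping the nested module names is a design choice for the PR.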
It might be nicer to instead override
the current autoencoder implementation causes graph breaks, likely due to python control flow. the hs list constructed in the encoder is also egregious.

performance improvement numbers TBD

this does not make an improvement with normal eager torch. if we decide to compile the autoencoder, it will fix the graph breaks with tensorrt, but torch.compile may see through this structure.
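The restructuring described above can be illustrated without any framework. In this minimal sketch, plain callables stand in for the real modules (`make_block` and `make_downsample` are hypothetical placeholders, and the level and block counts are assumptions chosen to match the encoder key listing in this thread). The per-level `if` runs on every call in the nested version, but only once, at construction time, in the flat version.

```python
from functools import reduce

NUM_LEVELS, NUM_BLOCKS = 4, 2  # assumed sizes, for illustration only


def make_block(level, idx):
    # Hypothetical stand-in for a resnet block.
    return lambda x: x * 2 + level + idx


def make_downsample(level):
    # Hypothetical stand-in for a downsampling layer.
    return lambda x: x / 2 - level


def forward_nested(x):
    """Nested style: the branch is evaluated during every forward call."""
    for level in range(NUM_LEVELS):
        for idx in range(NUM_BLOCKS):
            x = make_block(level, idx)(x)
        if level != NUM_LEVELS - 1:  # control flow inside forward
            x = make_downsample(level)(x)
    return x


def build_flat():
    """Flat style: the same branch runs once, while building the layer list."""
    layers = []
    for level in range(NUM_LEVELS):
        layers.extend(make_block(level, idx) for idx in range(NUM_BLOCKS))
        if level != NUM_LEVELS - 1:
            layers.append(make_downsample(level))
    return layers


def forward_flat(layers, x):
    """Sequential application, analogous to calling an nn.Sequential."""
    return reduce(lambda acc, layer: layer(acc), layers, x)
```

Comparing `forward_nested(x)` against `forward_flat(build_flat(), x)` on a few inputs is the same kind of check as testing that the real model's outputs are unchanged after the refactor.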