-
Notifications
You must be signed in to change notification settings - Fork 37
Open
Description
Code führt zu Fehler:
Unter 1. fehlen wichtige Importe.
Mein Code:
import torch
import os
import torchvision
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torch import optim
from torchvision.datasets import MNIST
from torchvision.utils import save_image
import matplotlib.pyplot as plt
Unter 3. werden Farbbilder normalisiert. MNIST Datensatz sind aber Graustufenbilder:
Daher muss die Compose-Funktion korrigiert werden und batch_size=batch_size geht auch nicht (ggf. muss bei dataset download=True bei 1. Ausführen gesetzt werden):
transform_image = transforms.Compose([transforms.ToTensor()
,transforms.Normalize((0.5, ), (0.5,))])
dataset = MNIST('./MNIST_data', transform=transform_image, download=False)
data_loader = DataLoader(dataset, batch_size=64,shuffle=True)
Plotfunktion auf Graustufenbilder anpassen:
def plot_img(image):
plt.imshow(image[0],cmap='gray')
Dann kann man es auch plotten:
sample_data = next(iter(data_loader))
plot_img(sample_data[0][2])
Bei 6. müsste man noch ein paar Anpassungen machen (bei dem print) und dem fehlenden Verzeichnis:
for epoch in range(number_epochs):
for data in data_loader:
image, i = data
image = image.view(image.size(0), -1)
image = Variable(image)
# Forward pass
output = model(image)
loss = criterion(output, image)
# Backward pass
optimizer.zero_grad()
loss.backward()
optimizer.step()
print('Epoch [{}/{}], Loss:{:.4f}'.format(epoch + 1,number_epochs, loss.item()))
if epoch % 10 == 0:
os.makedirs('./mlp_img/', exist_ok=True)
pic = to_image(output.cpu().data)
save_image(pic, './mlp_img/image_{}.png'.format(epoch))
torch.save(model.state_dict(), './sim_autoencoder.pth')
Dann funktioniert es schon mal. Können Sie schauen, ob es passt?
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels