| 
 | 1 | +import torch as tch  | 
 | 2 | +import torchvision.datasets as dt  | 
 | 3 | +import torchvision.transforms as trans  | 
 | 4 | +import torch.nn as nn  | 
 | 5 | +import matplotlib.pyplot as plt  | 
 | 6 | +from time import time  | 
 | 7 | + | 
 | 8 | +train = dt.MNIST(root="./datasets", train=True, transform=trans.ToTensor(), download=True)  | 
 | 9 | +test = dt.MNIST(root="./datasets", train=False, transform=trans.ToTensor(), download=True)  | 
 | 10 | +print("No. of Training examples: ",len(train))  | 
 | 11 | +print("No. of Test examples: ",len(test))  | 
 | 12 | + | 
 | 13 | +train_batch = tch.utils.data.DataLoader(train, batch_size=30, shuffle=True)  | 
 | 14 | + | 
 | 15 | + | 
 | 16 | +input = 784  | 
 | 17 | +hidden = 490  | 
 | 18 | +output = 10  | 
 | 19 | + | 
 | 20 | +model = nn.Sequential(nn.Linear(input, hidden),  | 
 | 21 | +                      nn.LeakyReLU(),  | 
 | 22 | +                      nn.Linear(hidden, output),  | 
 | 23 | +                      nn.LogSoftmax(dim=1))  | 
 | 24 | + | 
 | 25 | +lossfn = nn.NLLLoss()  | 
 | 26 | +images, labels = next(iter(train_batch))  | 
 | 27 | +images = images.view(images.shape[0], -1)  | 
 | 28 | + | 
 | 29 | +logps = model(images)  | 
 | 30 | +loss = lossfn(logps, labels)  | 
 | 31 | +loss.backward()  | 
 | 32 | + | 
 | 33 | +optimize = tch.optim.SGD(model.parameters(), lr=0.003, momentum=0.9)  | 
 | 34 | +time_start = time()  | 
 | 35 | +epochs = 18  | 
 | 36 | +for num in range(epochs):  | 
 | 37 | +    run=0  | 
 | 38 | +    for images, labels in train_batch:  | 
 | 39 | +        images = images.view(images.shape[0], -1)  | 
 | 40 | +        optimize.zero_grad()  | 
 | 41 | +        output = model(images)  | 
 | 42 | +        loss = lossfn(output, labels)  | 
 | 43 | +        loss.backward()  | 
 | 44 | +        optimize.step()  | 
 | 45 | +        run += loss.item()  | 
 | 46 | +    else:  | 
 | 47 | +        print("Epoch Number : {} = Loss : {}".format(num, run/len(train_batch)))  | 
 | 48 | +Elapsed=(time()-time_start)/60  | 
 | 49 | +print("\nTraining Time (in minutes) : ",Elapsed)  | 
 | 50 | + | 
 | 51 | +correct=0  | 
 | 52 | +all = 0  | 
 | 53 | +for images,labels in test:  | 
 | 54 | +  img = images.view(1, 784)  | 
 | 55 | +  with tch.no_grad():  | 
 | 56 | +    logps = model(img)     | 
 | 57 | +  ps = tch.exp(logps)  | 
 | 58 | +  probab = list(ps.numpy()[0])  | 
 | 59 | +  prediction = probab.index(max(probab))  | 
 | 60 | +  truth = labels  | 
 | 61 | +  if(truth == prediction):  | 
 | 62 | +    correct += 1  | 
 | 63 | +  all += 1  | 
 | 64 | + | 
 | 65 | +print("Number Of Images Tested : ", all)  | 
 | 66 | +print("Model Accuracy : ", (correct/all))  | 
 | 67 | + | 
 | 68 | +tch.save(model, './mnist_model.pt')  | 
0 commit comments