From 03f6f22b71d5a8ad29f0d545550f32fa3599d8d2 Mon Sep 17 00:00:00 2001 From: Yash Pokra <79229682+yashpokra@users.noreply.github.com> Date: Fri, 28 Oct 2022 15:28:08 -0600 Subject: [PATCH 01/26] Added File For Multilayer Perceptron I'll fix up the imports and make them look cleaner. --- ...l - Julia Knet Multilayer Perceptron.ipynb | 596 ++++++++++++++++++ 1 file changed, 596 insertions(+) create mode 100644 multi-layer-perceptron/First Model - Julia Knet Multilayer Perceptron.ipynb diff --git a/multi-layer-perceptron/First Model - Julia Knet Multilayer Perceptron.ipynb b/multi-layer-perceptron/First Model - Julia Knet Multilayer Perceptron.ipynb new file mode 100644 index 0000000..66663cc --- /dev/null +++ b/multi-layer-perceptron/First Model - Julia Knet Multilayer Perceptron.ipynb @@ -0,0 +1,596 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0e969ff0", + "metadata": {}, + "source": [ + "# Imports" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "id": "061ec219", + "metadata": {}, + "outputs": [], + "source": [ + "using MLDatasets: MNIST\n", + "using Knet, IterTools, MLDatasets\n", + "using Base.Iterators: take, drop, cycle, Stateful\n", + "using Printf\n", + "using Knet:minibatch\n", + "using Knet:minimize\n", + "using Knet\n", + "using Knet: Param\n", + "using Knet: Knet, dir, accuracy, progress, sgd, gc, Data, nll, relu\n", + "using Flatten\n", + "using Flux.Data;\n", + "using Flux, Statistics\n", + "import Flatten: flattenable" + ] + }, + { + "cell_type": "markdown", + "id": "4be16139", + "metadata": {}, + "source": [ + "# Processing Data" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "id": "cb077122", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "28×28×60000 Array{Float32, 3}\n", + "60000-element Vector{Int64}\n", + "28×28×10000 Array{Float32, 3}\n", + "10000-element Vector{Int64}\n" + ] + } + ], + "source": [ + "# This loads the MNIST handwritten digit recognition dataset. This code is based off the Knet Tutorial Notebook. \n", + "xtrn,ytrn = MNIST.traindata(Float32)\n", + "xtst,ytst = MNIST.testdata(Float32)\n", + "println.(summary.((xtrn,ytrn,xtst,ytst)));" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "id": "1c9abdfc", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(\"784×60000 Matrix{Float32}\", \"784×10000 Matrix{Float32}\")\n" + ] + } + ], + "source": [ + "xtrn = reshape(xtrn, 784, 60000 ) \n", + "xtst = reshape(xtst, 784, 10000 )\n", + "println(summary.((xtrn, xtst))) # can see the data that is flattened " + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "id": "5b120fe0", + "metadata": {}, + "outputs": [], + "source": [ + "#Preprocessing targets: one hot vectors\n", + "# ytrn = onehotbatch(ytrn, 0:9)\n", + "# ytst = onehotbatch(ytst, 0:9)" + ] + }, + { + "cell_type": "markdown", + "id": "ca5b6908", + "metadata": {}, + "source": [ + "# Batch Processing" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "id": "3d8bc019", + "metadata": {}, + "outputs": [], + "source": [ + "train_loader = DataLoader((xtrn, ytrn), batchsize=128);" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "08be497f", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 39, + "id": "1d6b4034", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "DataLoader{Tuple{Matrix{Float32}, Vector{Int64}}, Random._GLOBAL_RNG, Val{nothing}}((Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], [5, 0, 4, 1, 9, 2, 1, 3, 1, 4 … 9, 2, 9, 5, 1, 8, 3, 5, 6, 8]), 128, false, true, false, false, Val{nothing}(), Random._GLOBAL_RNG())" + ] + }, + "execution_count": 39, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train_loader" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "id": "cc30960a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "784×128 Matrix{Float32}\n", + "128-element Vector{Int64}\n" + ] + } + ], + "source": [ + "(x,y) = first(train_loader) #gives the first minibatch from training dataset\n", + "println.(summary.((x,y)));" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "id": "224f849d", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "DataLoader{Tuple{Matrix{Float32}, Vector{Int64}}, Random._GLOBAL_RNG, Val{nothing}}((Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], [5, 0, 4, 1, 9, 2, 1, 3, 1, 4 … 9, 2, 9, 5, 1, 8, 3, 5, 6, 8]), 128, false, true, false, false, Val{nothing}(), Random._GLOBAL_RNG())" + ] + }, + "execution_count": 41, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train_loader" + ] + }, + { + "cell_type": "markdown", + "id": "4748aa9a", + "metadata": {}, + "source": [ + "# Define Dense Layer" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5f0043d2", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 42, + "id": "830e10d0", + "metadata": {}, + "outputs": [], + "source": [ + "struct Dense1; w; b; f; end\n", + "Dense1(i,o; f=relu) = Dense1(param(o,i), param0(o), f)\n", + "(d::Dense1)(x) = d.f.(d.w * mat(x) .+ d.b)" + ] + }, + { + "cell_type": "markdown", + "id": "7f50ff27", + "metadata": { + "scrolled": true + }, + "source": [ + "# Define Chain Layer\n" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "id": "6507022b", + "metadata": {}, + "outputs": [], + "source": [ + "# Define a chain of layers and a loss function:\n", + "struct Chain; layers; end\n", + "(c::Chain)(x) = (for l in c.layers; x = l(x); end; x)\n", + "(c::Chain)(x,y) = nll(c(x),y)" + ] + }, + { + "cell_type": "markdown", + "id": "588d08f8", + "metadata": {}, + "source": [ + "# Define the Model" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "id": "78ca61e1", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Chain((Dense1(P(Matrix{Float32}(100,784)), P(Vector{Float32}(100)), Knet.Ops20.relu), Dense1(P(Matrix{Float32}(10,100)), P(Vector{Float32}(10)), Knet.Ops20.relu)))" + ] + }, + "execution_count": 44, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model = Chain((Dense1(784, 100), Dense1(100, 10)))" + ] + }, + { + "cell_type": "markdown", + "id": "12fa672b", + "metadata": {}, + "source": [ + "# Training" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "id": "01ea3152", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "data": { + "text/plain": [ + "10×128 Matrix{Float32}:\n", + " 0.214272 0.269956 0.0 … 0.0 0.0 0.0\n", + " 0.0 0.0 0.0 0.0 0.0 0.0688322\n", + " 0.447498 0.223454 0.391645 0.0 0.0 0.0\n", + " 0.0 0.0 0.209459 0.0 0.0 0.0\n", + " 0.11944 0.0 0.0 0.124931 0.0 0.0\n", + " 0.480624 0.0618137 0.0567281 … 0.337361 0.321017 0.279554\n", + " 0.567259 0.579033 0.0 0.0177205 0.0287222 0.0821771\n", + " 0.519518 0.900291 0.177696 0.304197 0.39036 0.436304\n", + " 0.0 0.0 0.00714952 0.0 0.0 0.0\n", + " 0.0 0.0 0.0 0.0 0.0 0.0" + ] + }, + "execution_count": 45, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model(x) #checking if training is working" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "id": "0cb35eba", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "┣████████████████████┫ [100.00%, 469/469, 00:04/00:04, 122.72i/s] \n", + "┣ ┫ [0.21%, 1/469, 00:00/00:03, 162.35i/s] " + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch is 1, loss is 0.223415, accuracy is 0.935518 4.726154 seconds (327.16 k allocations: 1.256 GiB, 15.36% gc time, 2.45% compilation time)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "┣████████████████████┫ [100.00%, 469/469, 00:03/00:03, 185.82i/s] \n", + "┣ ┫ [0.21%, 1/469, 00:00/00:02, 205.09i/s] " + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch is 2, loss is 0.161276, accuracy is 0.956081 3.293366 seconds (311.97 k allocations: 1.254 GiB, 4.57% gc time)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "┣████████████████████┫ [100.00%, 469/469, 00:03/00:03, 161.56i/s] \n", + "┣ ┫ [0.21%, 1/469, 00:00/00:03, 169.43i/s] " + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch is 3, loss is 0.129508, accuracy is 0.967232 3.659111 seconds (312.32 k allocations: 1.254 GiB, 4.26% gc time)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "┣████████████████████┫ [100.00%, 469/469, 00:03/00:03, 151.47i/s] \n", + "┣ ┫ [0.21%, 1/469, 00:00/00:03, 162.74i/s] " + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch is 4, loss is 0.111773, accuracy is 0.973371 3.976988 seconds (312.50 k allocations: 1.254 GiB, 4.70% gc time)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "┣████████████████████┫ [100.00%, 469/469, 00:03/00:03, 163.64i/s] \n", + "┣ ┫ [0.21%, 1/469, 00:00/00:02, 247.97i/s] " + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch is 5, loss is 0.100974, accuracy is 0.978142 3.635032 seconds (312.30 k allocations: 1.254 GiB, 5.48% gc time)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "┣████████████████████┫ [100.00%, 469/469, 00:03/00:03, 156.73i/s] \n", + "┣ ┫ [0.21%, 1/469, 00:00/00:03, 172.15i/s] " + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch is 6, loss is 0.093855, accuracy is 0.981822 3.744225 seconds (312.30 k allocations: 1.254 GiB, 4.23% gc time)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "┣████████████████████┫ [100.00%, 469/469, 00:03/00:03, 162.35i/s] \n", + "┣ ┫ [0.21%, 1/469, 00:00/00:03, 164.20i/s] " + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch is 7, loss is 0.089960, accuracy is 0.984115 3.644892 seconds (312.31 k allocations: 1.254 GiB, 3.83% gc time)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "┣████████████████████┫ [100.00%, 469/469, 00:03/00:03, 170.55i/s] \n", + "┣ ┫ [0.21%, 1/469, 00:00/00:02, 196.92i/s] " + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch is 8, loss is 0.087794, accuracy is 0.986242 3.537420 seconds (312.31 k allocations: 1.254 GiB, 5.22% gc time)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "┣████████████████████┫ [100.00%, 469/469, 00:04/00:04, 109.85i/s] \n", + "┣ ┫ [0.21%, 1/469, 00:00/00:03, 179.85i/s] " + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch is 9, loss is 0.085906, accuracy is 0.987814 5.220644 seconds (312.67 k allocations: 1.254 GiB, 4.46% gc time)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "┣████████████████████┫ [100.00%, 469/469, 00:03/00:03, 157.98i/s] \n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch is 10, loss is 0.086418, accuracy is 0.989034 3.682666 seconds (312.32 k allocations: 1.254 GiB, 3.70% gc time)\n", + "Overall Loss: 0.086418\n", + "Overall Accuracy: 0.989034155001202\n" + ] + } + ], + "source": [ + "loss(xtst, ytst) = nll(model(xtst), ytst)\n", + "evalcb = () -> (loss(xtst, ytst)) #function that will be called to get the loss \n", + "\n", + " for epoch in 1:10\n", + " @time begin\n", + " progress!(adam(model, train_loader; lr = 1e-3))\n", + " @printf(\"epoch is %d, loss is %f, accuracy is %f\", epoch, (evalcb()), accuracy(model, train_loader))\n", + " end \n", + " end \n", + "\n", + "\n", + "println(\"Overall Loss: \", evalcb()) \n", + "println(\"Overall Accuracy: \", accuracy(model, train_loader))\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7af4c1eb", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a28d3616", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1cd9c1b7", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d030c34c", + "metadata": {}, + "outputs": [], + "source": [ + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "adc056c1", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1673271f", + "metadata": {}, + "outputs": [], + "source": [ + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "686f14b8", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a56aaea6", + "metadata": {}, + "outputs": [], + "source": [ + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "17a0901a", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "83327f51", + "metadata": {}, + "outputs": [], + "source": [ + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "68debc4d", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Julia 1.8.2", + "language": "julia", + "name": "julia-1.8" + }, + "language_info": { + "file_extension": ".jl", + "mimetype": "application/julia", + "name": "julia", + "version": "1.8.2" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From d0e3aa3d1c7281149a285843356f6f58bdc583e4 Mon Sep 17 00:00:00 2001 From: Yash Pokra <79229682+yashpokra@users.noreply.github.com> Date: Mon, 31 Oct 2022 12:21:22 -0600 Subject: [PATCH 02/26] Metrics Will need to adjust and fix some metrics. --- ...odel - Julia Knet Multilayer Perceptron.jl | 139 ++++++++++++++++++ 1 file changed, 139 insertions(+) create mode 100644 multi-layer-perceptron/First Model - Julia Knet Multilayer Perceptron.jl diff --git a/multi-layer-perceptron/First Model - Julia Knet Multilayer Perceptron.jl b/multi-layer-perceptron/First Model - Julia Knet Multilayer Perceptron.jl new file mode 100644 index 0000000..c60568d --- /dev/null +++ b/multi-layer-perceptron/First Model - Julia Knet Multilayer Perceptron.jl @@ -0,0 +1,139 @@ +using MLDatasets: MNIST +using Knet, IterTools, MLDatasets +using Dictionaries +using TimerOutputs +using TimerOutputs +using JSON +using Printf +using Knet:minibatch +using Knet:minimize +using Knet +using Knet: Param +using Knet: Knet, dir, accuracy, progress, sgd, gc, Data, nll, relu +using Flatten +using Flux.Data; +using Flux, Statistics + +# This loads the MNIST handwritten digit recognition dataset. This code is based off the Knet Tutorial Notebook. +xtrn,ytrn = MNIST.traindata(Float32) +xtst,ytst = MNIST.testdata(Float32) +println.(summary.((xtrn,ytrn,xtst,ytst))); + +xtrn = reshape(xtrn, 784, 60000 ) +xtst = reshape(xtst, 784, 10000 ) +println(summary.((xtrn, xtst))) # can see the data that is flattened + +#Preprocessing targets: one hot vectors +# ytrn = onehotbatch(ytrn, 0:9) +# ytst = onehotbatch(ytst, 0:9) + +train_loader = DataLoader((xtrn, ytrn), batchsize=128); +test_loader = DataLoader((xtst, ytst), batchsize = 128) + +length(test_loader) + + + +(x,y) = first(train_loader) #gives the first minibatch from training dataset +println.(summary.((x,y))); + + + + + +struct Dense1; w; b; f; end +Dense1(i,o; f=relu) = Dense1(param(o,i), param0(o), f) +(d::Dense1)(x) = d.f.(d.w * mat(x) .+ d.b) + +# Define a chain of layers and a loss function: +struct Chain; layers; end +(c::Chain)(x) = (for l in c.layers; x = l(x); end; x) +(c::Chain)(x,y) = nll(c(x),y) + +model = Chain((Dense1(784, 100), Dense1(100, 10))) + +model(x) #checking if training is working + + +loss(xtst, ytst) = nll(model(xtst), ytst) +evalcb = () -> (loss(xtst, ytst)) #function that will be called to get the loss +const to = TimerOutput() # creating a TimerOutput, keeps track of everything + + +@timeit to "Train Total" begin + for epoch in 1:10 + @timeit to "train_epoch" begin + progress!(adam(model, train_loader; lr = 1e-3)) + end + + @timeit to "evaluation" begin + accuracy(model, test_loader) + end + @printf("epoch is %d, loss is %f, accuracy is %f", epoch, (evalcb()), accuracy(model, test_loader)) + end +end + + + + +final_train_loss = evalcb() +final_eval_accuracy = accuracy(model, test_loader) + +# see the overall loss +println("Overall Loss: ", final_train_loss) +println("Overall Accuracy: ", final_eval_accuracy ) #see the overall accuracy + + + + +show(to, allocations = true, compact = true) #see the time it took for training and evaluating the model + + + +#average epoch training time converted to seconds from nanoseconds +average_train_epoch_time = (mean(TimerOutputs.time(to["Train Total"]["train_epoch"])))/(1e+9 *10) +total_train_time = TimerOutputs.time(to["Train Total"])/(1e+9) +average_batch_inference_time = TimerOutputs.time(to["Train Total"]["evaluation"])/(length(test_loader)*1e+9) + +average_train_epoch_time + + + +#getting dictionary to format the metrics +metrics = Dict("model_name" => "MLP", + "framework_name"=>"Knet", + "dataset" => "MNIST Digits", + "task" => "classifcation", + "average_epoch_training_time" => average_train_epoch_time, + "total_training_time" => total_train_time, + "average_batch_inference_time" => average_batch_inference_time, + "final_training_loss" => final_train_loss, + "final_evaluation_accuracy" => final_eval_accuracy +) + + +stringdata = JSON.json(metrics) + +#will allow the metrics to be entered into a file + +open("M1-Knet-mlp.json", "w") do f + write(f, stringdata) + end + +dict2 = Dict() +open("M1-Knet-mlp.json", "r") do f + global dict2 + dict2 = JSON.parse(f) +end + +pwd() #checking directory + + + + + + + + + + From 65fbd67933ac5c8fb924bd94051654cf61ac7b54 Mon Sep 17 00:00:00 2001 From: Yash Pokra <79229682+yashpokra@users.noreply.github.com> Date: Mon, 31 Oct 2022 12:22:46 -0600 Subject: [PATCH 03/26] Delete First Model - Julia Knet Multilayer Perceptron.jl --- ...odel - Julia Knet Multilayer Perceptron.jl | 139 ------------------ 1 file changed, 139 deletions(-) delete mode 100644 multi-layer-perceptron/First Model - Julia Knet Multilayer Perceptron.jl diff --git a/multi-layer-perceptron/First Model - Julia Knet Multilayer Perceptron.jl b/multi-layer-perceptron/First Model - Julia Knet Multilayer Perceptron.jl deleted file mode 100644 index c60568d..0000000 --- a/multi-layer-perceptron/First Model - Julia Knet Multilayer Perceptron.jl +++ /dev/null @@ -1,139 +0,0 @@ -using MLDatasets: MNIST -using Knet, IterTools, MLDatasets -using Dictionaries -using TimerOutputs -using TimerOutputs -using JSON -using Printf -using Knet:minibatch -using Knet:minimize -using Knet -using Knet: Param -using Knet: Knet, dir, accuracy, progress, sgd, gc, Data, nll, relu -using Flatten -using Flux.Data; -using Flux, Statistics - -# This loads the MNIST handwritten digit recognition dataset. This code is based off the Knet Tutorial Notebook. -xtrn,ytrn = MNIST.traindata(Float32) -xtst,ytst = MNIST.testdata(Float32) -println.(summary.((xtrn,ytrn,xtst,ytst))); - -xtrn = reshape(xtrn, 784, 60000 ) -xtst = reshape(xtst, 784, 10000 ) -println(summary.((xtrn, xtst))) # can see the data that is flattened - -#Preprocessing targets: one hot vectors -# ytrn = onehotbatch(ytrn, 0:9) -# ytst = onehotbatch(ytst, 0:9) - -train_loader = DataLoader((xtrn, ytrn), batchsize=128); -test_loader = DataLoader((xtst, ytst), batchsize = 128) - -length(test_loader) - - - -(x,y) = first(train_loader) #gives the first minibatch from training dataset -println.(summary.((x,y))); - - - - - -struct Dense1; w; b; f; end -Dense1(i,o; f=relu) = Dense1(param(o,i), param0(o), f) -(d::Dense1)(x) = d.f.(d.w * mat(x) .+ d.b) - -# Define a chain of layers and a loss function: -struct Chain; layers; end -(c::Chain)(x) = (for l in c.layers; x = l(x); end; x) -(c::Chain)(x,y) = nll(c(x),y) - -model = Chain((Dense1(784, 100), Dense1(100, 10))) - -model(x) #checking if training is working - - -loss(xtst, ytst) = nll(model(xtst), ytst) -evalcb = () -> (loss(xtst, ytst)) #function that will be called to get the loss -const to = TimerOutput() # creating a TimerOutput, keeps track of everything - - -@timeit to "Train Total" begin - for epoch in 1:10 - @timeit to "train_epoch" begin - progress!(adam(model, train_loader; lr = 1e-3)) - end - - @timeit to "evaluation" begin - accuracy(model, test_loader) - end - @printf("epoch is %d, loss is %f, accuracy is %f", epoch, (evalcb()), accuracy(model, test_loader)) - end -end - - - - -final_train_loss = evalcb() -final_eval_accuracy = accuracy(model, test_loader) - -# see the overall loss -println("Overall Loss: ", final_train_loss) -println("Overall Accuracy: ", final_eval_accuracy ) #see the overall accuracy - - - - -show(to, allocations = true, compact = true) #see the time it took for training and evaluating the model - - - -#average epoch training time converted to seconds from nanoseconds -average_train_epoch_time = (mean(TimerOutputs.time(to["Train Total"]["train_epoch"])))/(1e+9 *10) -total_train_time = TimerOutputs.time(to["Train Total"])/(1e+9) -average_batch_inference_time = TimerOutputs.time(to["Train Total"]["evaluation"])/(length(test_loader)*1e+9) - -average_train_epoch_time - - - -#getting dictionary to format the metrics -metrics = Dict("model_name" => "MLP", - "framework_name"=>"Knet", - "dataset" => "MNIST Digits", - "task" => "classifcation", - "average_epoch_training_time" => average_train_epoch_time, - "total_training_time" => total_train_time, - "average_batch_inference_time" => average_batch_inference_time, - "final_training_loss" => final_train_loss, - "final_evaluation_accuracy" => final_eval_accuracy -) - - -stringdata = JSON.json(metrics) - -#will allow the metrics to be entered into a file - -open("M1-Knet-mlp.json", "w") do f - write(f, stringdata) - end - -dict2 = Dict() -open("M1-Knet-mlp.json", "r") do f - global dict2 - dict2 = JSON.parse(f) -end - -pwd() #checking directory - - - - - - - - - - From 865bf81f8f51c8ffa52194769128a7bbced376de Mon Sep 17 00:00:00 2001 From: Yash Pokra <79229682+yashpokra@users.noreply.github.com> Date: Mon, 31 Oct 2022 12:28:55 -0600 Subject: [PATCH 04/26] Add files via upload The metrics based on performance might not be accurately defined. Will most likely make changes. --- ...odel - Julia Knet Multilayer Perceptron.jl | 139 ++++++++++++++++++ 1 file changed, 139 insertions(+) create mode 100644 multi-layer-perceptron/First Model - Julia Knet Multilayer Perceptron.jl diff --git a/multi-layer-perceptron/First Model - Julia Knet Multilayer Perceptron.jl b/multi-layer-perceptron/First Model - Julia Knet Multilayer Perceptron.jl new file mode 100644 index 0000000..5bfeef0 --- /dev/null +++ b/multi-layer-perceptron/First Model - Julia Knet Multilayer Perceptron.jl @@ -0,0 +1,139 @@ +using MLDatasets: MNIST +using Knet, IterTools, MLDatasets +using Dictionaries +using TimerOutputs +using TimerOutputs +using JSON +using Printf +using Knet:minibatch +using Knet:minimize +using Knet +using Knet: Param +using Knet: Knet, dir, accuracy, progress, sgd, gc, Data, nll, relu +using Flatten +using Flux.Data; +using Flux, Statistics + +# This loads the MNIST handwritten digit recognition dataset. This code is based off the Knet Tutorial Notebook. +xtrn,ytrn = MNIST.traindata(Float32) +xtst,ytst = MNIST.testdata(Float32) +println.(summary.((xtrn,ytrn,xtst,ytst))); + +xtrn = reshape(xtrn, 784, 60000 ) +xtst = reshape(xtst, 784, 10000 ) +println(summary.((xtrn, xtst))) # can see the data that is flattened + +#Preprocessing targets: one hot vectors +# ytrn = onehotbatch(ytrn, 0:9) +# ytst = onehotbatch(ytst, 0:9) + +train_loader = DataLoader((xtrn, ytrn), batchsize=128); +test_loader = DataLoader((xtst, ytst), batchsize = 128) + +length(test_loader) + + + +(x,y) = first(train_loader) #gives the first minibatch from training dataset +println.(summary.((x,y))); + + + + + +struct Dense1; w; b; f; end +Dense1(i,o; f=relu) = Dense1(param(o,i), param0(o), f) +(d::Dense1)(x) = d.f.(d.w * mat(x) .+ d.b) + +# Define a chain of layers and a loss function: +struct Chain; layers; end +(c::Chain)(x) = (for l in c.layers; x = l(x); end; x) +(c::Chain)(x,y) = nll(c(x),y) + +model = Chain((Dense1(784, 100), Dense1(100, 10))) + +model(x) #checking if training is working + + +loss(xtst, ytst) = nll(model(xtst), ytst) +evalcb = () -> (loss(xtst, ytst)) #function that will be called to get the loss +const to = TimerOutput() # creating a TimerOutput, keeps track of everything + + +@timeit to "Train Total" begin + for epoch in 1:10 + @timeit to "train_epoch" begin + progress!(adam(model, train_loader; lr = 1e-3)) + end + + @timeit to "evaluation" begin + accuracy(model, test_loader) + end + @printf("epoch is %d, loss is %f, accuracy is %f", epoch, (evalcb()), accuracy(model, test_loader)) + end +end + + + + +final_train_loss = evalcb() +final_eval_accuracy = accuracy(model, test_loader) + +# see the overall loss +println("Overall Loss: ", final_train_loss) +println("Overall Accuracy: ", final_eval_accuracy ) #see the overall accuracy + + + + +show(to, allocations = true, compact = true) #see the time it took for training and evaluating the model + + + +#average epoch training time converted to seconds from nanoseconds +average_train_epoch_time = (TimerOutputs.time(to["Train Total"]["train_epoch"]))/(1e+9 *10) +total_train_time = TimerOutputs.time(to["Train Total"])/(1e+9) +average_batch_inference_time = TimerOutputs.time(to["Train Total"]["evaluation"])/(length(test_loader)*1e+9) + +average_train_epoch_time + + + +#getting dictionary to format the metrics +metrics = Dict("model_name" => "MLP", + "framework_name"=>"Knet", + "dataset" => "MNIST Digits", + "task" => "classifcation", + "average_epoch_training_time" => average_train_epoch_time, + "total_training_time" => total_train_time, + "average_batch_inference_time" => average_batch_inference_time, + "final_training_loss" => final_train_loss, + "final_evaluation_accuracy" => final_eval_accuracy +) + + +stringdata = JSON.json(metrics) + +#will allow the metrics to be entered into a file + +open("M1-Knet-mlp.json", "w") do f + write(f, stringdata) + end + +dict2 = Dict() +open("M1-Knet-mlp.json", "r") do f + global dict2 + dict2 = JSON.parse(f) +end + +pwd() #checking directory + + + + + + + + + + From bd0ad027acdd6d0dc83d2608b7fa119816ec39aa Mon Sep 17 00:00:00 2001 From: Yash Pokra <79229682+yashpokra@users.noreply.github.com> Date: Tue, 1 Nov 2022 08:58:04 -0600 Subject: [PATCH 05/26] Script Julia Knet updated script --- multi-layer-perceptron/KNet_test.jl | 143 ++++++++++++++++++++++++++++ 1 file changed, 143 insertions(+) create mode 100644 multi-layer-perceptron/KNet_test.jl diff --git a/multi-layer-perceptron/KNet_test.jl b/multi-layer-perceptron/KNet_test.jl new file mode 100644 index 0000000..5e3f457 --- /dev/null +++ b/multi-layer-perceptron/KNet_test.jl @@ -0,0 +1,143 @@ +using MLDatasets: MNIST +using Knet, IterTools, MLDatasets +using Dictionaries +using TimerOutputs +using TimerOutputs +using JSON +using Printf +using Knet:minibatch +using Knet:minimize +using Knet +using Knet: Param +using Knet: Knet, dir, accuracy, progress, sgd, gc, Data, nll, relu +using Flatten +using Flux.Data; +using Flux, Statistics + +# This loads the MNIST handwritten digit recognition dataset. This code is based off the Knet Tutorial Notebook. +xtrn,ytrn = MNIST.traindata(Float32) +xtst,ytst = MNIST.testdata(Float32) +println.(summary.((xtrn,ytrn,xtst,ytst))); + +xtrn = reshape(xtrn, 784, 60000 ) +xtst = reshape(xtst, 784, 10000 ) +println(summary.((xtrn, xtst))) # can see the data that is flattened + +#Preprocessing targets: one hot vectors +# ytrn = onehotbatch(ytrn, 0:9) +# ytst = onehotbatch(ytst, 0:9) + +train_loader = DataLoader((xtrn, ytrn), batchsize=128); +test_loader = DataLoader((xtst, ytst), batchsize = 128) + +length(test_loader) + + + +(x,y) = first(train_loader) #gives the first minibatch from training dataset +println.(summary.((x,y))); + + + + + +struct Dense1; w; b; f; end +Dense1(i,o; f=relu) = Dense1(param(o,i), param0(o), f) +(d::Dense1)(x) = d.f.(d.w * mat(x) .+ d.b) + +# Define a chain of layers and a loss function: +struct Chain; layers; end +(c::Chain)(x) = (for l in c.layers; x = l(x); end; x) +(c::Chain)(x,y) = nll(c(x),y) + +model = Chain((Dense1(784, 100), Dense1(100, 10), identity)) + +model(x) #checking if training is working + + +loss(xtst, ytst) = nll(model(xtst), ytst) +evalcb = () -> (loss(xtst, ytst)) #function that will be called to get the loss +const to = TimerOutput() # creating a TimerOutput, keeps track of everything + + +@timeit to "Train Total" begin + for epoch in 1:10 + train_epoch = epoch > 1 ? "train_epoch" : "train_ji" + @timeit to train_epoch begin + progress!(adam(model, train_loader; lr = 1e-3)) + end + + evaluation = epoch > 1 ? "evaluation" : "eval_jit" + @timeit to evaluation begin + accuracy(model, test_loader) + end + + end +end + + + + +final_train_loss = evalcb() +final_eval_accuracy = accuracy(model, test_loader) + +# see the overall loss +println("Overall Loss: ", final_train_loss) +println("Overall Accuracy: ", final_eval_accuracy ) #see the overall accuracy + + + + +show(to, allocations = true, compact = true) #see the time it took for training and evaluating the model + + + +#average epoch training time converted to seconds from nanoseconds +average_train_epoch_time = (TimerOutputs.time(to["Train Total"]["train_epoch"]))/(1e+9 *9) +total_train_time = TimerOutputs.time(to["Train Total"])/(1e+9) +average_batch_inference_time = TimerOutputs.time(to["Train Total"]["evaluation"])/(length(test_loader)*1e+6*9) + +average_train_epoch_time + + + +#getting dictionary to format the metrics +metrics = Dict("model_name" => "MLP", + "framework_name"=>"Knet", + "dataset" => "MNIST Digits", + "task" => "classifcation", + "average_epoch_training_time" => average_train_epoch_time, + "total_training_time" => total_train_time, + "average_batch_inference_time" => average_batch_inference_time, + "final_training_loss" => final_train_loss, + "final_evaluation_accuracy" => final_eval_accuracy +) + + +stringdata = JSON.json(metrics) + +#will allow the metrics to be entered into a file + +open("M1-Knet-mlp.json", "w") do f + write(f, stringdata) + end + +dict2 = Dict() +open("M1-Knet-mlp.json", "r") do f + global dict2 + dict2 = JSON.parse(f) +end + +pwd() #checking directory + + + + + + + + + + + + From e47dcf26df9261aa892d065e5229ffd445d9bc39 Mon Sep 17 00:00:00 2001 From: Yash Pokra <79229682+yashpokra@users.noreply.github.com> Date: Tue, 1 Nov 2022 08:58:19 -0600 Subject: [PATCH 06/26] Delete First Model - Julia Knet Multilayer Perceptron.jl --- ...odel - Julia Knet Multilayer Perceptron.jl | 139 ------------------ 1 file changed, 139 deletions(-) delete mode 100644 multi-layer-perceptron/First Model - Julia Knet Multilayer Perceptron.jl diff --git a/multi-layer-perceptron/First Model - Julia Knet Multilayer Perceptron.jl b/multi-layer-perceptron/First Model - Julia Knet Multilayer Perceptron.jl deleted file mode 100644 index 5bfeef0..0000000 --- a/multi-layer-perceptron/First Model - Julia Knet Multilayer Perceptron.jl +++ /dev/null @@ -1,139 +0,0 @@ -using MLDatasets: MNIST -using Knet, IterTools, MLDatasets -using Dictionaries -using TimerOutputs -using TimerOutputs -using JSON -using Printf -using Knet:minibatch -using Knet:minimize -using Knet -using Knet: Param -using Knet: Knet, dir, accuracy, progress, sgd, gc, Data, nll, relu -using Flatten -using Flux.Data; -using Flux, Statistics - -# This loads the MNIST handwritten digit recognition dataset. This code is based off the Knet Tutorial Notebook. -xtrn,ytrn = MNIST.traindata(Float32) -xtst,ytst = MNIST.testdata(Float32) -println.(summary.((xtrn,ytrn,xtst,ytst))); - -xtrn = reshape(xtrn, 784, 60000 ) -xtst = reshape(xtst, 784, 10000 ) -println(summary.((xtrn, xtst))) # can see the data that is flattened - -#Preprocessing targets: one hot vectors -# ytrn = onehotbatch(ytrn, 0:9) -# ytst = onehotbatch(ytst, 0:9) - -train_loader = DataLoader((xtrn, ytrn), batchsize=128); -test_loader = DataLoader((xtst, ytst), batchsize = 128) - -length(test_loader) - - - -(x,y) = first(train_loader) #gives the first minibatch from training dataset -println.(summary.((x,y))); - - - - - -struct Dense1; w; b; f; end -Dense1(i,o; f=relu) = Dense1(param(o,i), param0(o), f) -(d::Dense1)(x) = d.f.(d.w * mat(x) .+ d.b) - -# Define a chain of layers and a loss function: -struct Chain; layers; end -(c::Chain)(x) = (for l in c.layers; x = l(x); end; x) -(c::Chain)(x,y) = nll(c(x),y) - -model = Chain((Dense1(784, 100), Dense1(100, 10))) - -model(x) #checking if training is working - - -loss(xtst, ytst) = nll(model(xtst), ytst) -evalcb = () -> (loss(xtst, ytst)) #function that will be called to get the loss -const to = TimerOutput() # creating a TimerOutput, keeps track of everything - - -@timeit to "Train Total" begin - for epoch in 1:10 - @timeit to "train_epoch" begin - progress!(adam(model, train_loader; lr = 1e-3)) - end - - @timeit to "evaluation" begin - accuracy(model, test_loader) - end - @printf("epoch is %d, loss is %f, accuracy is %f", epoch, (evalcb()), accuracy(model, test_loader)) - end -end - - - - -final_train_loss = evalcb() -final_eval_accuracy = accuracy(model, test_loader) - -# see the overall loss -println("Overall Loss: ", final_train_loss) -println("Overall Accuracy: ", final_eval_accuracy ) #see the overall accuracy - - - - -show(to, allocations = true, compact = true) #see the time it took for training and evaluating the model - - - -#average epoch training time converted to seconds from nanoseconds -average_train_epoch_time = (TimerOutputs.time(to["Train Total"]["train_epoch"]))/(1e+9 *10) -total_train_time = TimerOutputs.time(to["Train Total"])/(1e+9) -average_batch_inference_time = TimerOutputs.time(to["Train Total"]["evaluation"])/(length(test_loader)*1e+9) - -average_train_epoch_time - - - -#getting dictionary to format the metrics -metrics = Dict("model_name" => "MLP", - "framework_name"=>"Knet", - "dataset" => "MNIST Digits", - "task" => "classifcation", - "average_epoch_training_time" => average_train_epoch_time, - "total_training_time" => total_train_time, - "average_batch_inference_time" => average_batch_inference_time, - "final_training_loss" => final_train_loss, - "final_evaluation_accuracy" => final_eval_accuracy -) - - -stringdata = JSON.json(metrics) - -#will allow the metrics to be entered into a file - -open("M1-Knet-mlp.json", "w") do f - write(f, stringdata) - end - -dict2 = Dict() -open("M1-Knet-mlp.json", "r") do f - global dict2 - dict2 = JSON.parse(f) -end - -pwd() #checking directory - - - - - - - - - - From 0be33a03de1260676d57e35bdb5fd809249c57d3 Mon Sep 17 00:00:00 2001 From: Yash Pokra <79229682+yashpokra@users.noreply.github.com> Date: Tue, 1 Nov 2022 09:00:54 -0600 Subject: [PATCH 07/26] Delete First Model - Julia Knet Multilayer Perceptron.ipynb --- ...l - Julia Knet Multilayer Perceptron.ipynb | 596 ------------------ 1 file changed, 596 deletions(-) delete mode 100644 multi-layer-perceptron/First Model - Julia Knet Multilayer Perceptron.ipynb diff --git a/multi-layer-perceptron/First Model - Julia Knet Multilayer Perceptron.ipynb b/multi-layer-perceptron/First Model - Julia Knet Multilayer Perceptron.ipynb deleted file mode 100644 index 66663cc..0000000 --- a/multi-layer-perceptron/First Model - Julia Knet Multilayer Perceptron.ipynb +++ /dev/null @@ -1,596 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "0e969ff0", - "metadata": {}, - "source": [ - "# Imports" - ] - }, - { - "cell_type": "code", - "execution_count": 34, - "id": "061ec219", - "metadata": {}, - "outputs": [], - "source": [ - "using MLDatasets: MNIST\n", - "using Knet, IterTools, MLDatasets\n", - "using Base.Iterators: take, drop, cycle, Stateful\n", - "using Printf\n", - "using Knet:minibatch\n", - "using Knet:minimize\n", - "using Knet\n", - "using Knet: Param\n", - "using Knet: Knet, dir, accuracy, progress, sgd, gc, Data, nll, relu\n", - "using Flatten\n", - "using Flux.Data;\n", - "using Flux, Statistics\n", - "import Flatten: flattenable" - ] - }, - { - "cell_type": "markdown", - "id": "4be16139", - "metadata": {}, - "source": [ - "# Processing Data" - ] - }, - { - "cell_type": "code", - "execution_count": 35, - "id": "cb077122", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "28×28×60000 Array{Float32, 3}\n", - "60000-element Vector{Int64}\n", - "28×28×10000 Array{Float32, 3}\n", - "10000-element Vector{Int64}\n" - ] - } - ], - "source": [ - "# This loads the MNIST handwritten digit recognition dataset. This code is based off the Knet Tutorial Notebook. \n", - "xtrn,ytrn = MNIST.traindata(Float32)\n", - "xtst,ytst = MNIST.testdata(Float32)\n", - "println.(summary.((xtrn,ytrn,xtst,ytst)));" - ] - }, - { - "cell_type": "code", - "execution_count": 36, - "id": "1c9abdfc", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(\"784×60000 Matrix{Float32}\", \"784×10000 Matrix{Float32}\")\n" - ] - } - ], - "source": [ - "xtrn = reshape(xtrn, 784, 60000 ) \n", - "xtst = reshape(xtst, 784, 10000 )\n", - "println(summary.((xtrn, xtst))) # can see the data that is flattened " - ] - }, - { - "cell_type": "code", - "execution_count": 37, - "id": "5b120fe0", - "metadata": {}, - "outputs": [], - "source": [ - "#Preprocessing targets: one hot vectors\n", - "# ytrn = onehotbatch(ytrn, 0:9)\n", - "# ytst = onehotbatch(ytst, 0:9)" - ] - }, - { - "cell_type": "markdown", - "id": "ca5b6908", - "metadata": {}, - "source": [ - "# Batch Processing" - ] - }, - { - "cell_type": "code", - "execution_count": 38, - "id": "3d8bc019", - "metadata": {}, - "outputs": [], - "source": [ - "train_loader = DataLoader((xtrn, ytrn), batchsize=128);" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "08be497f", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 39, - "id": "1d6b4034", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "DataLoader{Tuple{Matrix{Float32}, Vector{Int64}}, Random._GLOBAL_RNG, Val{nothing}}((Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], [5, 0, 4, 1, 9, 2, 1, 3, 1, 4 … 9, 2, 9, 5, 1, 8, 3, 5, 6, 8]), 128, false, true, false, false, Val{nothing}(), Random._GLOBAL_RNG())" - ] - }, - "execution_count": 39, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "train_loader" - ] - }, - { - "cell_type": "code", - "execution_count": 40, - "id": "cc30960a", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "784×128 Matrix{Float32}\n", - "128-element Vector{Int64}\n" - ] - } - ], - "source": [ - "(x,y) = first(train_loader) #gives the first minibatch from training dataset\n", - "println.(summary.((x,y)));" - ] - }, - { - "cell_type": "code", - "execution_count": 41, - "id": "224f849d", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "DataLoader{Tuple{Matrix{Float32}, Vector{Int64}}, Random._GLOBAL_RNG, Val{nothing}}((Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], [5, 0, 4, 1, 9, 2, 1, 3, 1, 4 … 9, 2, 9, 5, 1, 8, 3, 5, 6, 8]), 128, false, true, false, false, Val{nothing}(), Random._GLOBAL_RNG())" - ] - }, - "execution_count": 41, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "train_loader" - ] - }, - { - "cell_type": "markdown", - "id": "4748aa9a", - "metadata": {}, - "source": [ - "# Define Dense Layer" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5f0043d2", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 42, - "id": "830e10d0", - "metadata": {}, - "outputs": [], - "source": [ - "struct Dense1; w; b; f; end\n", - "Dense1(i,o; f=relu) = Dense1(param(o,i), param0(o), f)\n", - "(d::Dense1)(x) = d.f.(d.w * mat(x) .+ d.b)" - ] - }, - { - "cell_type": "markdown", - "id": "7f50ff27", - "metadata": { - "scrolled": true - }, - "source": [ - "# Define Chain Layer\n" - ] - }, - { - "cell_type": "code", - "execution_count": 43, - "id": "6507022b", - "metadata": {}, - "outputs": [], - "source": [ - "# Define a chain of layers and a loss function:\n", - "struct Chain; layers; end\n", - "(c::Chain)(x) = (for l in c.layers; x = l(x); end; x)\n", - "(c::Chain)(x,y) = nll(c(x),y)" - ] - }, - { - "cell_type": "markdown", - "id": "588d08f8", - "metadata": {}, - "source": [ - "# Define the Model" - ] - }, - { - "cell_type": "code", - "execution_count": 44, - "id": "78ca61e1", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Chain((Dense1(P(Matrix{Float32}(100,784)), P(Vector{Float32}(100)), Knet.Ops20.relu), Dense1(P(Matrix{Float32}(10,100)), P(Vector{Float32}(10)), Knet.Ops20.relu)))" - ] - }, - "execution_count": 44, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "model = Chain((Dense1(784, 100), Dense1(100, 10)))" - ] - }, - { - "cell_type": "markdown", - "id": "12fa672b", - "metadata": {}, - "source": [ - "# Training" - ] - }, - { - "cell_type": "code", - "execution_count": 45, - "id": "01ea3152", - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "data": { - "text/plain": [ - "10×128 Matrix{Float32}:\n", - " 0.214272 0.269956 0.0 … 0.0 0.0 0.0\n", - " 0.0 0.0 0.0 0.0 0.0 0.0688322\n", - " 0.447498 0.223454 0.391645 0.0 0.0 0.0\n", - " 0.0 0.0 0.209459 0.0 0.0 0.0\n", - " 0.11944 0.0 0.0 0.124931 0.0 0.0\n", - " 0.480624 0.0618137 0.0567281 … 0.337361 0.321017 0.279554\n", - " 0.567259 0.579033 0.0 0.0177205 0.0287222 0.0821771\n", - " 0.519518 0.900291 0.177696 0.304197 0.39036 0.436304\n", - " 0.0 0.0 0.00714952 0.0 0.0 0.0\n", - " 0.0 0.0 0.0 0.0 0.0 0.0" - ] - }, - "execution_count": 45, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "model(x) #checking if training is working" - ] - }, - { - "cell_type": "code", - "execution_count": 47, - "id": "0cb35eba", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "┣████████████████████┫ [100.00%, 469/469, 00:04/00:04, 122.72i/s] \n", - "┣ ┫ [0.21%, 1/469, 00:00/00:03, 162.35i/s] " - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "epoch is 1, loss is 0.223415, accuracy is 0.935518 4.726154 seconds (327.16 k allocations: 1.256 GiB, 15.36% gc time, 2.45% compilation time)\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "┣████████████████████┫ [100.00%, 469/469, 00:03/00:03, 185.82i/s] \n", - "┣ ┫ [0.21%, 1/469, 00:00/00:02, 205.09i/s] " - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "epoch is 2, loss is 0.161276, accuracy is 0.956081 3.293366 seconds (311.97 k allocations: 1.254 GiB, 4.57% gc time)\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "┣████████████████████┫ [100.00%, 469/469, 00:03/00:03, 161.56i/s] \n", - "┣ ┫ [0.21%, 1/469, 00:00/00:03, 169.43i/s] " - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "epoch is 3, loss is 0.129508, accuracy is 0.967232 3.659111 seconds (312.32 k allocations: 1.254 GiB, 4.26% gc time)\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "┣████████████████████┫ [100.00%, 469/469, 00:03/00:03, 151.47i/s] \n", - "┣ ┫ [0.21%, 1/469, 00:00/00:03, 162.74i/s] " - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "epoch is 4, loss is 0.111773, accuracy is 0.973371 3.976988 seconds (312.50 k allocations: 1.254 GiB, 4.70% gc time)\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "┣████████████████████┫ [100.00%, 469/469, 00:03/00:03, 163.64i/s] \n", - "┣ ┫ [0.21%, 1/469, 00:00/00:02, 247.97i/s] " - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "epoch is 5, loss is 0.100974, accuracy is 0.978142 3.635032 seconds (312.30 k allocations: 1.254 GiB, 5.48% gc time)\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "┣████████████████████┫ [100.00%, 469/469, 00:03/00:03, 156.73i/s] \n", - "┣ ┫ [0.21%, 1/469, 00:00/00:03, 172.15i/s] " - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "epoch is 6, loss is 0.093855, accuracy is 0.981822 3.744225 seconds (312.30 k allocations: 1.254 GiB, 4.23% gc time)\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "┣████████████████████┫ [100.00%, 469/469, 00:03/00:03, 162.35i/s] \n", - "┣ ┫ [0.21%, 1/469, 00:00/00:03, 164.20i/s] " - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "epoch is 7, loss is 0.089960, accuracy is 0.984115 3.644892 seconds (312.31 k allocations: 1.254 GiB, 3.83% gc time)\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "┣████████████████████┫ [100.00%, 469/469, 00:03/00:03, 170.55i/s] \n", - "┣ ┫ [0.21%, 1/469, 00:00/00:02, 196.92i/s] " - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "epoch is 8, loss is 0.087794, accuracy is 0.986242 3.537420 seconds (312.31 k allocations: 1.254 GiB, 5.22% gc time)\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "┣████████████████████┫ [100.00%, 469/469, 00:04/00:04, 109.85i/s] \n", - "┣ ┫ [0.21%, 1/469, 00:00/00:03, 179.85i/s] " - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "epoch is 9, loss is 0.085906, accuracy is 0.987814 5.220644 seconds (312.67 k allocations: 1.254 GiB, 4.46% gc time)\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "┣████████████████████┫ [100.00%, 469/469, 00:03/00:03, 157.98i/s] \n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "epoch is 10, loss is 0.086418, accuracy is 0.989034 3.682666 seconds (312.32 k allocations: 1.254 GiB, 3.70% gc time)\n", - "Overall Loss: 0.086418\n", - "Overall Accuracy: 0.989034155001202\n" - ] - } - ], - "source": [ - "loss(xtst, ytst) = nll(model(xtst), ytst)\n", - "evalcb = () -> (loss(xtst, ytst)) #function that will be called to get the loss \n", - "\n", - " for epoch in 1:10\n", - " @time begin\n", - " progress!(adam(model, train_loader; lr = 1e-3))\n", - " @printf(\"epoch is %d, loss is %f, accuracy is %f\", epoch, (evalcb()), accuracy(model, train_loader))\n", - " end \n", - " end \n", - "\n", - "\n", - "println(\"Overall Loss: \", evalcb()) \n", - "println(\"Overall Accuracy: \", accuracy(model, train_loader))\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7af4c1eb", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a28d3616", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1cd9c1b7", - "metadata": { - "scrolled": true - }, - "outputs": [], - "source": [ - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d030c34c", - "metadata": {}, - "outputs": [], - "source": [ - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "adc056c1", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1673271f", - "metadata": {}, - "outputs": [], - "source": [ - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "686f14b8", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a56aaea6", - "metadata": {}, - "outputs": [], - "source": [ - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "17a0901a", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "83327f51", - "metadata": {}, - "outputs": [], - "source": [ - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "68debc4d", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Julia 1.8.2", - "language": "julia", - "name": "julia-1.8" - }, - "language_info": { - "file_extension": ".jl", - "mimetype": "application/julia", - "name": "julia", - "version": "1.8.2" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} From 3cf7d8d43a17266d8cf03851716f327f8abd7fac Mon Sep 17 00:00:00 2001 From: Yash Pokra <79229682+yashpokra@users.noreply.github.com> Date: Tue, 1 Nov 2022 09:01:10 -0600 Subject: [PATCH 08/26] Add files via upload Updated notebook --- multi-layer-perceptron/Knet_notebook.ipynb | 621 +++++++++++++++++++++ 1 file changed, 621 insertions(+) create mode 100644 multi-layer-perceptron/Knet_notebook.ipynb diff --git a/multi-layer-perceptron/Knet_notebook.ipynb b/multi-layer-perceptron/Knet_notebook.ipynb new file mode 100644 index 0000000..3ae1765 --- /dev/null +++ b/multi-layer-perceptron/Knet_notebook.ipynb @@ -0,0 +1,621 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0e969ff0", + "metadata": {}, + "source": [ + "# Imports" + ] + }, + { + "cell_type": "code", + "execution_count": 78, + "id": "061ec219", + "metadata": {}, + "outputs": [], + "source": [ + "using MLDatasets: MNIST\n", + "using Knet, IterTools, MLDatasets\n", + "using Dictionaries\n", + "using TimerOutputs\n", + "using TimerOutputs\n", + "using JSON\n", + "using Printf\n", + "using Knet:minibatch\n", + "using Knet:minimize\n", + "using Knet\n", + "using Knet: Param\n", + "using Knet: Knet, dir, accuracy, progress, sgd, gc, Data, nll, relu\n", + "using Flatten\n", + "using Flux.Data;\n", + "using Flux, Statistics" + ] + }, + { + "cell_type": "markdown", + "id": "4be16139", + "metadata": {}, + "source": [ + "# Processing Data" + ] + }, + { + "cell_type": "code", + "execution_count": 79, + "id": "cb077122", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "28×28×60000 Array{Float32, 3}\n", + "60000-element Vector{Int64}\n", + "28×28×10000 Array{Float32, 3}\n", + "10000-element Vector{Int64}\n" + ] + } + ], + "source": [ + "# This loads the MNIST handwritten digit recognition dataset. This code is based off the Knet Tutorial Notebook. \n", + "xtrn,ytrn = MNIST.traindata(Float32)\n", + "xtst,ytst = MNIST.testdata(Float32)\n", + "println.(summary.((xtrn,ytrn,xtst,ytst)));" + ] + }, + { + "cell_type": "code", + "execution_count": 80, + "id": "1c9abdfc", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(\"784×60000 Matrix{Float32}\", \"784×10000 Matrix{Float32}\")\n" + ] + } + ], + "source": [ + "xtrn = reshape(xtrn, 784, 60000 ) \n", + "xtst = reshape(xtst, 784, 10000 )\n", + "println(summary.((xtrn, xtst))) # can see the data that is flattened " + ] + }, + { + "cell_type": "code", + "execution_count": 81, + "id": "5b120fe0", + "metadata": {}, + "outputs": [], + "source": [ + "#Preprocessing targets: one hot vectors\n", + "# ytrn = onehotbatch(ytrn, 0:9)\n", + "# ytst = onehotbatch(ytst, 0:9)" + ] + }, + { + "cell_type": "markdown", + "id": "ca5b6908", + "metadata": {}, + "source": [ + "# Batch Processing" + ] + }, + { + "cell_type": "code", + "execution_count": 82, + "id": "3d8bc019", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "DataLoader{Tuple{Matrix{Float32}, Vector{Int64}}, Random._GLOBAL_RNG, Val{nothing}}((Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], [7, 2, 1, 0, 4, 1, 4, 9, 5, 9 … 7, 8, 9, 0, 1, 2, 3, 4, 5, 6]), 128, false, true, false, false, Val{nothing}(), Random._GLOBAL_RNG())" + ] + }, + "execution_count": 82, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train_loader = DataLoader((xtrn, ytrn), batchsize=128);\n", + "test_loader = DataLoader((xtst, ytst), batchsize = 128)" + ] + }, + { + "cell_type": "code", + "execution_count": 83, + "id": "08be497f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "79" + ] + }, + "execution_count": 83, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "length(test_loader)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1d6b4034", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 84, + "id": "cc30960a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "784×128 Matrix{Float32}\n", + "128-element Vector{Int64}\n" + ] + } + ], + "source": [ + "(x,y) = first(train_loader) #gives the first minibatch from training dataset\n", + "println.(summary.((x,y)));" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "224f849d", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "4748aa9a", + "metadata": {}, + "source": [ + "# Define Dense Layer" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5f0043d2", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 85, + "id": "830e10d0", + "metadata": {}, + "outputs": [], + "source": [ + "struct Dense1; w; b; f; end\n", + "Dense1(i,o; f=relu) = Dense1(param(o,i), param0(o), f)\n", + "(d::Dense1)(x) = d.f.(d.w * mat(x) .+ d.b)" + ] + }, + { + "cell_type": "markdown", + "id": "7f50ff27", + "metadata": { + "scrolled": true + }, + "source": [ + "# Define Chain Layer\n" + ] + }, + { + "cell_type": "code", + "execution_count": 86, + "id": "6507022b", + "metadata": {}, + "outputs": [], + "source": [ + "# Define a chain of layers and a loss function:\n", + "struct Chain; layers; end\n", + "(c::Chain)(x) = (for l in c.layers; x = l(x); end; x)\n", + "(c::Chain)(x,y) = nll(c(x),y)" + ] + }, + { + "cell_type": "markdown", + "id": "588d08f8", + "metadata": {}, + "source": [ + "# Define the Model" + ] + }, + { + "cell_type": "code", + "execution_count": 87, + "id": "78ca61e1", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Chain((Dense1(P(Matrix{Float32}(100,784)), P(Vector{Float32}(100)), Knet.Ops20.relu), Dense1(P(Matrix{Float32}(10,100)), P(Vector{Float32}(10)), Knet.Ops20.relu), identity))" + ] + }, + "execution_count": 87, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model = Chain((Dense1(784, 100), Dense1(100, 10), identity))" + ] + }, + { + "cell_type": "markdown", + "id": "12fa672b", + "metadata": {}, + "source": [ + "# Training and Evaluating" + ] + }, + { + "cell_type": "code", + "execution_count": 88, + "id": "01ea3152", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "data": { + "text/plain": [ + "10×128 Matrix{Float32}:\n", + " 0.146128 0.251807 0.0944941 0.0 … 0.0 0.0 0.166544\n", + " 0.908352 0.975714 0.392101 1.07251 1.10933 0.714653 1.1014\n", + " 0.0 0.0 0.0 0.0 0.0 0.37319 0.394873\n", + " 0.0 0.0403631 0.227371 0.218725 0.187911 0.0 0.0\n", + " 0.0 0.0 0.0943293 0.0 0.0 0.0 0.0\n", + " 0.0 0.0 0.0 0.415859 … 0.0 0.0 0.0\n", + " 0.0 0.0 0.0 0.0 0.0 0.459048 0.0691352\n", + " 0.189706 0.802615 0.69828 0.0941242 0.455495 0.444184 0.604847\n", + " 0.834217 0.47262 0.0 0.273323 0.575527 0.126404 0.402153\n", + " 0.0 0.0 0.0 0.0 0.0 0.0 0.0" + ] + }, + "execution_count": 88, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model(x) #checking if training is working" + ] + }, + { + "cell_type": "code", + "execution_count": 89, + "id": "0cb35eba", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING: redefinition of constant to. This may fail, cause incorrect answers, or produce other errors.\n", + "┣████████████████████┫ [100.00%, 469/469, 00:01/00:01, 350.61i/s] \n", + "┣████████████████████┫ [100.00%, 469/469, 00:01/00:01, 381.14i/s] \n", + "┣████████████████████┫ [100.00%, 469/469, 00:01/00:01, 383.71i/s] \n", + "┣████████████████████┫ [100.00%, 469/469, 00:01/00:01, 393.42i/s] \n", + "┣████████████████████┫ [100.00%, 469/469, 00:01/00:01, 362.42i/s] \n", + "┣████████████████████┫ [100.00%, 469/469, 00:01/00:01, 352.92i/s] \n", + "┣████████████████████┫ [100.00%, 469/469, 00:01/00:01, 322.49i/s] \n", + "┣████████████████████┫ [100.00%, 469/469, 00:01/00:01, 355.11i/s] \n", + "┣████████████████████┫ [100.00%, 469/469, 00:01/00:01, 357.56i/s] \n", + "┣████████████████████┫ [100.00%, 469/469, 00:01/00:01, 375.06i/s] \n" + ] + } + ], + "source": [ + "\n", + "loss(xtst, ytst) = nll(model(xtst), ytst)\n", + "evalcb = () -> (loss(xtst, ytst)) #function that will be called to get the loss \n", + "const to = TimerOutput() # creating a TimerOutput, keeps track of everything\n", + "\n", + "\n", + "@timeit to \"Train Total\" begin\n", + " for epoch in 1:10\n", + " train_epoch = epoch > 1 ? \"train_epoch\" : \"train_ji\"\n", + " @timeit to train_epoch begin\n", + " progress!(adam(model, train_loader; lr = 1e-3))\n", + " end\n", + " \n", + " evaluation = epoch > 1 ? \"evaluation\" : \"eval_jit\"\n", + " @timeit to evaluation begin\n", + " accuracy(model, test_loader)\n", + " end \n", + " \n", + " end \n", + "end \n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 90, + "id": "7af4c1eb", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Overall Loss: 0.09791139\n", + "Overall Accuracy: 0.9695121951219512\n" + ] + } + ], + "source": [ + "\n", + "final_train_loss = evalcb()\n", + "final_eval_accuracy = accuracy(model, test_loader)\n", + "\n", + "# see the overall loss\n", + "println(\"Overall Loss: \", final_train_loss) \n", + "println(\"Overall Accuracy: \", final_eval_accuracy ) #see the overall accuracy\n" + ] + }, + { + "cell_type": "markdown", + "id": "5eb13576", + "metadata": {}, + "source": [ + "# Getting Metrics" + ] + }, + { + "cell_type": "code", + "execution_count": 96, + "id": "a28d3616", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[0m\u001b[1m ─────────────────────────────────────────────────────────\u001b[22m\n", + "\u001b[0m\u001b[1m \u001b[22m Time Allocations \n", + " ─────────────── ───────────────\n", + " Total measured: 22.9s 10.6GiB \n", + "\n", + " Section ncalls time %tot alloc %tot\n", + " ─────────────────────────────────────────────────────────\n", + " Train Total 1 13.5s 100.0% 10.5GiB 100.0%\n", + " train_epoch 9 11.6s 85.8% 9.15GiB 86.7%\n", + " train_ji 1 1.35s 10.0% 1.02GiB 9.7%\n", + " evaluation 9 505ms 3.7% 352MiB 3.3%\n", + " eval_jit 1 60.6ms 0.4% 39.1MiB 0.4%\n", + "\u001b[0m\u001b[1m ─────────────────────────────────────────────────────────\u001b[22m" + ] + } + ], + "source": [ + "\n", + "\n", + "show(to, allocations = true, compact = true) #see the time it took for training and evaluating the model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1cd9c1b7", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 97, + "id": "d030c34c", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "1.2902278333333332" + ] + }, + "execution_count": 97, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "#average epoch training time converted to seconds from nanoseconds\n", + "average_train_epoch_time = (TimerOutputs.time(to[\"Train Total\"][\"train_epoch\"]))/(1e+9 *9)\n", + "total_train_time = TimerOutputs.time(to[\"Train Total\"])/(1e+9)\n", + "average_batch_inference_time = TimerOutputs.time(to[\"Train Total\"][\"evaluation\"])/(length(test_loader)*1e+6*9)\n", + "\n", + "average_train_epoch_time\n" + ] + }, + { + "cell_type": "code", + "execution_count": 98, + "id": "adc056c1", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Dict{String, Any} with 9 entries:\n", + " \"task\" => \"classifcation\"\n", + " \"framework_name\" => \"Knet\"\n", + " \"final_evaluation_accuracy\" => 0.969512\n", + " \"average_epoch_training_time\" => 1.29023\n", + " \"total_training_time\" => 13.5294\n", + " \"final_training_loss\" => 0.0979114\n", + " \"model_name\" => \"MLP\"\n", + " \"dataset\" => \"MNIST Digits\"\n", + " \"average_batch_inference_time\" => 0.709791" + ] + }, + "execution_count": 98, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "\n", + "#getting dictionary to format the metrics\n", + "metrics = Dict(\"model_name\" => \"MLP\",\n", + " \"framework_name\"=>\"Knet\",\n", + " \"dataset\" => \"MNIST Digits\", \n", + " \"task\" => \"classifcation\",\n", + " \"average_epoch_training_time\" => average_train_epoch_time,\n", + " \"total_training_time\" => total_train_time,\n", + " \"average_batch_inference_time\" => average_batch_inference_time,\n", + " \"final_training_loss\" => final_train_loss,\n", + " \"final_evaluation_accuracy\" => final_eval_accuracy\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 99, + "id": "1673271f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Dict{String, Any} with 9 entries:\n", + " \"task\" => \"classifcation\"\n", + " \"framework_name\" => \"Knet\"\n", + " \"final_evaluation_accuracy\" => 0.969512\n", + " \"average_epoch_training_time\" => 1.29023\n", + " \"total_training_time\" => 13.5294\n", + " \"final_training_loss\" => 0.0979114\n", + " \"model_name\" => \"MLP\"\n", + " \"dataset\" => \"MNIST Digits\"\n", + " \"average_batch_inference_time\" => 0.709791" + ] + }, + "execution_count": 99, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "\n", + "stringdata = JSON.json(metrics)\n", + "\n", + "#will allow the metrics to be entered into a file \n", + "\n", + "open(\"M1-Knet-mlp.json\", \"w\") do f\n", + " write(f, stringdata)\n", + " end \n", + "\n", + "dict2 = Dict()\n", + "open(\"M1-Knet-mlp.json\", \"r\") do f\n", + " global dict2\n", + " dict2 = JSON.parse(f)\n", + "end" + ] + }, + { + "cell_type": "code", + "execution_count": 100, + "id": "686f14b8", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "\"C:\\\\Users\\\\Yash\"" + ] + }, + "execution_count": 100, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pwd() #checking directory " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a56aaea6", + "metadata": {}, + "outputs": [], + "source": [ + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "17a0901a", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "83327f51", + "metadata": {}, + "outputs": [], + "source": [ + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "68debc4d", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "83728e05", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Julia 1.8.2", + "language": "julia", + "name": "julia-1.8" + }, + "language_info": { + "file_extension": ".jl", + "mimetype": "application/julia", + "name": "julia", + "version": "1.8.2" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From a6cb1b69b363781875bba06095dacd811c3c4fb0 Mon Sep 17 00:00:00 2001 From: Yash Pokra <79229682+yashpokra@users.noreply.github.com> Date: Mon, 7 Nov 2022 14:20:18 -0700 Subject: [PATCH 09/26] Update KNet_test.jl --- multi-layer-perceptron/KNet_test.jl | 26 +++++--------------------- 1 file changed, 5 insertions(+), 21 deletions(-) diff --git a/multi-layer-perceptron/KNet_test.jl b/multi-layer-perceptron/KNet_test.jl index 5e3f457..2f88ae5 100644 --- a/multi-layer-perceptron/KNet_test.jl +++ b/multi-layer-perceptron/KNet_test.jl @@ -1,43 +1,28 @@ using MLDatasets: MNIST -using Knet, IterTools, MLDatasets +using Knet, IterTools using Dictionaries using TimerOutputs -using TimerOutputs using JSON using Printf -using Knet:minibatch -using Knet:minimize -using Knet -using Knet: Param -using Knet: Knet, dir, accuracy, progress, sgd, gc, Data, nll, relu using Flatten -using Flux.Data; using Flux, Statistics # This loads the MNIST handwritten digit recognition dataset. This code is based off the Knet Tutorial Notebook. xtrn,ytrn = MNIST.traindata(Float32) xtst,ytst = MNIST.testdata(Float32) -println.(summary.((xtrn,ytrn,xtst,ytst))); + xtrn = reshape(xtrn, 784, 60000 ) xtst = reshape(xtst, 784, 10000 ) -println(summary.((xtrn, xtst))) # can see the data that is flattened -#Preprocessing targets: one hot vectors + +#Preprocessing targets: one hot vectors, commented this out, as this does not correctly with KNet # ytrn = onehotbatch(ytrn, 0:9) # ytst = onehotbatch(ytst, 0:9) train_loader = DataLoader((xtrn, ytrn), batchsize=128); test_loader = DataLoader((xtst, ytst), batchsize = 128) -length(test_loader) - - - -(x,y) = first(train_loader) #gives the first minibatch from training dataset -println.(summary.((x,y))); - - @@ -52,7 +37,6 @@ struct Chain; layers; end model = Chain((Dense1(784, 100), Dense1(100, 10), identity)) -model(x) #checking if training is working loss(xtst, ytst) = nll(model(xtst), ytst) @@ -116,7 +100,7 @@ metrics = Dict("model_name" => "MLP", stringdata = JSON.json(metrics) -#will allow the metrics to be entered into a file +#will allow the metrics to be entered into a JSON file, which can be checked open("M1-Knet-mlp.json", "w") do f write(f, stringdata) From 6d20b31e53274023690708ba7637d37365dd3daf Mon Sep 17 00:00:00 2001 From: Yash Pokra <79229682+yashpokra@users.noreply.github.com> Date: Mon, 7 Nov 2022 14:21:50 -0700 Subject: [PATCH 10/26] Update KNet_test.jl --- multi-layer-perceptron/KNet_test.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/multi-layer-perceptron/KNet_test.jl b/multi-layer-perceptron/KNet_test.jl index 2f88ae5..446aec0 100644 --- a/multi-layer-perceptron/KNet_test.jl +++ b/multi-layer-perceptron/KNet_test.jl @@ -16,7 +16,7 @@ xtrn = reshape(xtrn, 784, 60000 ) xtst = reshape(xtst, 784, 10000 ) -#Preprocessing targets: one hot vectors, commented this out, as this does not correctly with KNet +#Preprocessing targets: one hot vectors, commented this out, as this does not work correctly with KNet # ytrn = onehotbatch(ytrn, 0:9) # ytst = onehotbatch(ytst, 0:9) From 236db1682864f513a2c8c2c4f31fc68aac999a99 Mon Sep 17 00:00:00 2001 From: Yash Pokra <79229682+yashpokra@users.noreply.github.com> Date: Sat, 12 Nov 2022 20:38:42 -0700 Subject: [PATCH 11/26] Update KNet_test.jl --- multi-layer-perceptron/KNet_test.jl | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/multi-layer-perceptron/KNet_test.jl b/multi-layer-perceptron/KNet_test.jl index 446aec0..57f5632 100644 --- a/multi-layer-perceptron/KNet_test.jl +++ b/multi-layer-perceptron/KNet_test.jl @@ -1,10 +1,15 @@ using MLDatasets: MNIST -using Knet, IterTools +using Knet, IterTools, MLDatasets using Dictionaries using TimerOutputs using JSON using Printf +using Knet:minibatch +using Knet:minimize +using Knet: Param +using Knet: dir, accuracy, progress, sgd, gc, Data, nll, relu using Flatten +using Flux.Data; using Flux, Statistics # This loads the MNIST handwritten digit recognition dataset. This code is based off the Knet Tutorial Notebook. From c3dd60701836a169eb16740fb71c9ddafbf42e99 Mon Sep 17 00:00:00 2001 From: Yash Pokra <79229682+yashpokra@users.noreply.github.com> Date: Mon, 5 Dec 2022 17:40:49 -0700 Subject: [PATCH 12/26] Create convolutional neural network --- convolutional neural network | 1 + 1 file changed, 1 insertion(+) create mode 100644 convolutional neural network diff --git a/convolutional neural network b/convolutional neural network new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/convolutional neural network @@ -0,0 +1 @@ + From 6ff865a2b6356cbe329b7a7022ab3811c10501e0 Mon Sep 17 00:00:00 2001 From: Yash Pokra <79229682+yashpokra@users.noreply.github.com> Date: Mon, 5 Dec 2022 17:41:06 -0700 Subject: [PATCH 13/26] Delete convolutional neural network --- convolutional neural network | 1 - 1 file changed, 1 deletion(-) delete mode 100644 convolutional neural network diff --git a/convolutional neural network b/convolutional neural network deleted file mode 100644 index 8b13789..0000000 --- a/convolutional neural network +++ /dev/null @@ -1 +0,0 @@ - From 36948128ad2412dd321badfecf57b1bde86ccc7f Mon Sep 17 00:00:00 2001 From: Yash Pokra <79229682+yashpokra@users.noreply.github.com> Date: Mon, 5 Dec 2022 17:43:03 -0700 Subject: [PATCH 14/26] Create ResNet Model V2- Knet --- convolutional neural network/ResNet Model V2- Knet | 1 + 1 file changed, 1 insertion(+) create mode 100644 convolutional neural network/ResNet Model V2- Knet diff --git a/convolutional neural network/ResNet Model V2- Knet b/convolutional neural network/ResNet Model V2- Knet new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/convolutional neural network/ResNet Model V2- Knet @@ -0,0 +1 @@ + From 9dc417dd20df114fa866f3314cf5faacb4c28576 Mon Sep 17 00:00:00 2001 From: Yash Pokra <79229682+yashpokra@users.noreply.github.com> Date: Mon, 5 Dec 2022 17:43:17 -0700 Subject: [PATCH 15/26] Delete ResNet Model V2- Knet --- convolutional neural network/ResNet Model V2- Knet | 1 - 1 file changed, 1 deletion(-) delete mode 100644 convolutional neural network/ResNet Model V2- Knet diff --git a/convolutional neural network/ResNet Model V2- Knet b/convolutional neural network/ResNet Model V2- Knet deleted file mode 100644 index 8b13789..0000000 --- a/convolutional neural network/ResNet Model V2- Knet +++ /dev/null @@ -1 +0,0 @@ - From bf74dd1fbf13401aae9cff1077903be0398113c5 Mon Sep 17 00:00:00 2001 From: Yash Pokra <79229682+yashpokra@users.noreply.github.com> Date: Mon, 5 Dec 2022 17:45:25 -0700 Subject: [PATCH 16/26] CREATE README.md --- convolutional neural network/README.md | 1 + 1 file changed, 1 insertion(+) create mode 100644 convolutional neural network/README.md diff --git a/convolutional neural network/README.md b/convolutional neural network/README.md new file mode 100644 index 0000000..ff8e71d --- /dev/null +++ b/convolutional neural network/README.md @@ -0,0 +1 @@ +Implementing ResNet V2 Model, utilizing Julia Knet From 8ee04f235ea2a9eb012031d51ee2fd3ee9957839 Mon Sep 17 00:00:00 2001 From: Yash Pokra <79229682+yashpokra@users.noreply.github.com> Date: Mon, 5 Dec 2022 17:46:43 -0700 Subject: [PATCH 17/26] Add files via upload An issue with applying a dataset to the model. --- .../ResNet Model V2- Knet.ipynb | 294 ++++++++++++++++++ 1 file changed, 294 insertions(+) create mode 100644 convolutional neural network/ResNet Model V2- Knet.ipynb diff --git a/convolutional neural network/ResNet Model V2- Knet.ipynb b/convolutional neural network/ResNet Model V2- Knet.ipynb new file mode 100644 index 0000000..3b6a630 --- /dev/null +++ b/convolutional neural network/ResNet Model V2- Knet.ipynb @@ -0,0 +1,294 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "eeccffd3", + "metadata": {}, + "source": [ + "# Imports" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "526f2ea0", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING: using Data.Data in module Main conflicts with an existing identifier.\n" + ] + } + ], + "source": [ + "using MLDatasets: CIFAR10\n", + "using MLDataUtils\n", + "using Knet, IterTools\n", + "using Dictionaries\n", + "using TimerOutputs\n", + "using JSON\n", + "using Printf\n", + "using Knet:minibatch\n", + "using Knet:minimize\n", + "using Knet: Param\n", + "using Knet: dir, accuracy, progress, sgd, gc, Data, nll, relu, conv4\n", + "using Flatten\n", + "using Flux.Data;\n", + "using Flux, Statistics\n", + "using Statistics: mean, var\n", + "using Functors" + ] + }, + { + "cell_type": "markdown", + "id": "200c3cb6", + "metadata": {}, + "source": [ + "# Processing Data/Batch Processing" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "b0aa0ba8", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "32×32×3×45000 Array{Float32, 4}\n", + "45000-element Vector{Int64}\n", + "32×32×3×5000 Array{Float32, 4}\n", + "5000-element Vector{Int64}\n", + "32×32×3×10000 Array{Float32, 4}\n", + "10000-element Vector{Int64}\n" + ] + } + ], + "source": [ + "# This loads the CIFAR-10 Dataset for training, validation, and evaluation\n", + "xtrn,ytrn = CIFAR10.traindata(Float32, 1:45000)\n", + "xval,yval = CIFAR10.traindata(Float32, 45001:50000)\n", + "xtst,ytst = CIFAR10.testdata(Float32)\n", + "println.(summary.((xtrn,ytrn,xval, yval, xtst,ytst)));" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "7db40f98", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "DataLoader{Tuple{Array{Float32, 4}, Vector{Int64}}, Random._GLOBAL_RNG, Val{nothing}}(([0.61960787 0.59607846 … 0.23921569 0.21176471; 0.62352943 0.5921569 … 0.19215687 0.21960784; … ; 0.49411765 0.49019608 … 0.11372549 0.13333334; 0.45490196 0.46666667 … 0.078431375 0.08235294;;; 0.4392157 0.4392157 … 0.45490196 0.41960785; 0.43529412 0.43137255 … 0.4 0.4117647; … ; 0.35686275 0.35686275 … 0.32156864 0.32941177; 0.33333334 0.34509805 … 0.2509804 0.2627451;;; 0.19215687 0.2 … 0.65882355 0.627451; 0.18431373 0.15686275 … 0.5803922 0.58431375; … ; 0.14117648 0.1254902 … 0.49411765 0.5058824; 0.12941177 0.13333334 … 0.41960785 0.43137255;;;; 0.92156863 0.93333334 … 0.32156864 0.33333334; 0.90588236 0.92156863 … 0.18039216 0.24313726; … ; 0.9137255 0.9254902 … 0.7254902 0.7058824; 0.9098039 0.92156863 … 0.73333335 0.7294118;;; 0.92156863 0.93333334 … 0.3764706 0.39607844; 0.90588236 0.92156863 … 0.22352941 0.29411766; … ; 0.9137255 0.9254902 … 0.78431374 0.7647059; 0.9098039 0.92156863 … 0.7921569 0.78431374;;; 0.92156863 0.93333334 … 0.32156864 0.3254902; 0.90588236 0.92156863 … 0.14117648 0.1882353; … ; 0.9137255 0.9254902 … 0.76862746 0.7490196; 0.9098039 0.92156863 … 0.78431374 0.78039217;;;; 0.61960787 0.6666667 … 0.09019608 0.10980392; 0.61960787 0.6745098 … 0.105882354 0.11764706; … ; 0.92941177 0.9647059 … 0.015686275 0.015686275; 0.93333334 0.9647059 … 0.019607844 0.02745098;;; 0.74509805 0.78431374 … 0.13333334 0.16078432; 0.73333335 0.78039217 … 0.14901961 0.16862746; … ; 0.9372549 0.9647059 … 0.023529412 0.019607844; 0.94509804 0.96862745 … 0.02745098 0.03137255;;; 0.87058824 0.8980392 … 0.15294118 0.18431373; 0.85490197 0.8862745 … 0.16862746 0.19607843; … ; 0.9529412 0.98039216 … 0.011764706 0.011764706; 0.9647059 0.9843137 … 0.011764706 0.02745098;;;; … ;;;; 0.078431375 0.08235294 … 0.12941177 0.12156863; 0.07450981 0.078431375 … 0.13333334 0.1254902; … ; 0.047058824 0.039215688 … 0.105882354 0.101960786; 0.050980393 0.047058824 … 0.09803922 0.09803922;;; 0.05882353 0.0627451 … 0.09803922 0.09019608; 0.05490196 0.0627451 … 0.101960786 0.09411765; … ; 0.043137256 0.03529412 … 0.09411765 0.09019608; 0.047058824 0.043137256 … 0.08627451 0.078431375;;; 0.047058824 0.050980393 … 0.05490196 0.047058824; 0.043137256 0.050980393 … 0.05882353 0.050980393; … ; 0.03529412 0.02745098 … 0.21960784 0.20784314; 0.039215688 0.03529412 … 0.18431373 0.18431373;;;; 0.09803922 0.047058824 … 0.40392157 0.37254903; 0.05882353 0.078431375 … 0.40784314 0.37254903; … ; 0.36078432 0.58431375 … 0.3882353 0.37254903; 0.29411766 0.40784314 … 0.36078432 0.36078432;;; 0.15686275 0.09803922 … 0.5176471 0.49411765; 0.14117648 0.14509805 … 0.5137255 0.48235294; … ; 0.44313726 0.65882355 … 0.49803922 0.48235294; 0.34901962 0.45882353 … 0.4745098 0.47058824;;; 0.047058824 0.023529412 … 0.3254902 0.30588236; 0.011764706 0.02745098 … 0.3254902 0.29803923; … ; 0.4392157 0.69411767 … 0.32941177 0.31764707; 0.36078432 0.5137255 … 0.30980393 0.3137255;;;; 0.28627452 0.27058825 … 0.4509804 0.45490196; 0.38431373 0.32941177 … 0.48235294 0.4745098; … ; 0.5294118 0.2784314 … 0.25882354 0.26666668; 0.79607844 0.47058824 … 0.105882354 0.105882354;;; 0.30588236 0.28627452 … 0.4745098 0.47058824; 0.40392157 0.34901962 … 0.4862745 0.47843137; … ; 0.58431375 0.32156864 … 0.25490198 0.25490198; 0.84313726 0.52156866 … 0.105882354 0.101960786;;; 0.29411766 0.27450982 … 0.35686275 0.3529412; 0.44313726 0.38039216 … 0.37254903 0.36862746; … ; 0.6039216 0.3137255 … 0.23137255 0.22745098; 0.8745098 0.5294118 … 0.105882354 0.101960786], [3, 8, 8, 0, 6, 6, 1, 6, 3, 1 … 7, 0, 3, 5, 3, 8, 3, 5, 1, 7]), 256, false, true, false, false, Val{nothing}(), Random._GLOBAL_RNG())" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train_loader = DataLoader((xtrn, ytrn), batchsize=256)\n", + "val_loader = DataLoader((xval, yval), batchsize = 256)\n", + "test_loader = DataLoader((xtst, ytst), batchsize = 256)" + ] + }, + { + "cell_type": "markdown", + "id": "a50fbde6", + "metadata": {}, + "source": [ + "# Define Struct ResNetLayer" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "3fbd3c98", + "metadata": {}, + "outputs": [], + "source": [ + "mutable struct ResNetLayer\n", + " conv1::Flux.Conv\n", + " conv2::Flux.Conv\n", + " bn1::BatchNorm\n", + " bn2::BatchNorm\n", + " activation_function::Function\n", + " in_channels::Int\n", + " channels::Int\n", + " stride::Int \n", + "end \n", + "\n", + "@functor ResNetLayer (conv1, conv2, bn1, bn2)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "28360856", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "ResNetLayer" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Constructor\n", + "function ResNetLayer(in_channels::Int, channels::Int, activation_function = relu, stride = 1)\n", + " bn1 = BatchNorm(in_channels)\n", + " conv1 = Flux.Conv((3,3), in_channels => channels, activation_function; stride = stride)\n", + " bn2 = BatchNorm(channels)\n", + " conv2 = Flux.Conv((3,3), channels => channels, activation_function; stride = stride)\n", + " return ResNetLayer(conv1, conv2, bn1, bn2, activation_function, in_channels, channels, stride)\n", + "end" + ] + }, + { + "cell_type": "markdown", + "id": "779746c6", + "metadata": {}, + "source": [ + "# Define Residual Identity" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "e83efd28", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "residual_identity (generic function with 1 method)" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "function residual_identity(layer::ResNetLayer, x::AbstractArray{T, 4}) where {T<:Number}\n", + " (w, h, c, b) = size(x)\n", + " stride = layer.stride\n", + " if stride > 1\n", + " @assert ((w % stride == 0) & (h % stride == 0)) \"Spatial dimensions are not divisible by `stride`\"\n", + " \n", + " # Strided downsample\n", + " x_id = copy(x[begin:2:end, begin:2:end, :, :])\n", + " else\n", + " x_id = x\n", + " end\n", + "\n", + " channels = layer.channels\n", + " in_channels = layer.in_channels\n", + " if in_channels < channels\n", + " # Zero padding on extra channels\n", + " (w, h, c, b) = size(x_id)\n", + " pad = zeros(w, h, channels - in_channels, b)\n", + " x_id = cat(x_id, pad; dims=3)\n", + " elseif in_channels > channels\n", + " error(\"in_channels > out_channels not supported\")\n", + " end\n", + " return x_id\n", + "end" + ] + }, + { + "cell_type": "markdown", + "id": "55f1a6b8", + "metadata": {}, + "source": [ + "# Forward Function" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "a484933c", + "metadata": {}, + "outputs": [], + "source": [ + "function (self::ResNetLayer)(x::AbstractArray)\n", + " identity = residual_identity(self, x)\n", + " z = self.bn1(x)\n", + " z = self.activation_function(z)\n", + " z = self.conv1(z)\n", + " z = self.bn2(z)\n", + " z = self.activation_function(z)\n", + " z = self.conv2(z)\n", + " y = z + identity \n", + " return y\n", + "end " + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "c9d89029", + "metadata": {}, + "outputs": [ + { + "ename": "LoadError", + "evalue": "MethodError: no method matching ResNetLayer(::Int64, ::Int64; stride=2)\n\u001b[0mClosest candidates are:\n\u001b[0m ResNetLayer(::Int64, ::Int64) at In[5]:2\u001b[91m got unsupported keyword argument \"stride\"\u001b[39m\n\u001b[0m ResNetLayer(::Int64, ::Int64, \u001b[91m::Any\u001b[39m) at In[5]:2\u001b[91m got unsupported keyword argument \"stride\"\u001b[39m\n\u001b[0m ResNetLayer(::Int64, ::Int64, \u001b[91m::Any\u001b[39m, \u001b[91m::Any\u001b[39m) at In[5]:2\u001b[91m got unsupported keyword argument \"stride\"\u001b[39m\n\u001b[0m ...", + "output_type": "error", + "traceback": [ + "MethodError: no method matching ResNetLayer(::Int64, ::Int64; stride=2)\n\u001b[0mClosest candidates are:\n\u001b[0m ResNetLayer(::Int64, ::Int64) at In[5]:2\u001b[91m got unsupported keyword argument \"stride\"\u001b[39m\n\u001b[0m ResNetLayer(::Int64, ::Int64, \u001b[91m::Any\u001b[39m) at In[5]:2\u001b[91m got unsupported keyword argument \"stride\"\u001b[39m\n\u001b[0m ResNetLayer(::Int64, ::Int64, \u001b[91m::Any\u001b[39m, \u001b[91m::Any\u001b[39m) at In[5]:2\u001b[91m got unsupported keyword argument \"stride\"\u001b[39m\n\u001b[0m ...", + "", + "Stacktrace:", + " [1] top-level scope", + " @ In[8]:2", + " [2] eval", + " @ .\\boot.jl:368 [inlined]", + " [3] include_string(mapexpr::typeof(REPL.softscope), mod::Module, code::String, filename::String)", + " @ Base .\\loading.jl:1428" + ] + } + ], + "source": [ + "# Example\n", + "l = ResNetLayer(3, 10, stride = 2)\n", + "x = randn(Float32, (64, 64, 3, 2))\n", + "y = l(x)\n", + "size(y)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8d213e04", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Julia 1.8.2", + "language": "julia", + "name": "julia-1.8" + }, + "language_info": { + "file_extension": ".jl", + "mimetype": "application/julia", + "name": "julia", + "version": "1.8.2" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From 5cafcac8f6d0ba5c57ee823ffff28431e555bafb Mon Sep 17 00:00:00 2001 From: Yash Pokra <79229682+yashpokra@users.noreply.github.com> Date: Sun, 11 Dec 2022 13:09:37 -0700 Subject: [PATCH 18/26] Add files via upload --- .../ResNet Model V2- Knet.ipynb | 267 ++++++++++++++++-- 1 file changed, 239 insertions(+), 28 deletions(-) diff --git a/convolutional neural network/ResNet Model V2- Knet.ipynb b/convolutional neural network/ResNet Model V2- Knet.ipynb index 3b6a630..f6d02c2 100644 --- a/convolutional neural network/ResNet Model V2- Knet.ipynb +++ b/convolutional neural network/ResNet Model V2- Knet.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "eeccffd3", + "id": "3a44a186", "metadata": {}, "source": [ "# Imports" @@ -11,7 +11,7 @@ { "cell_type": "code", "execution_count": 1, - "id": "526f2ea0", + "id": "7a88e0c0", "metadata": {}, "outputs": [ { @@ -43,7 +43,7 @@ }, { "cell_type": "markdown", - "id": "200c3cb6", + "id": "bb72d038", "metadata": {}, "source": [ "# Processing Data/Batch Processing" @@ -52,7 +52,7 @@ { "cell_type": "code", "execution_count": 2, - "id": "b0aa0ba8", + "id": "4cf6721f", "metadata": {}, "outputs": [ { @@ -79,7 +79,7 @@ { "cell_type": "code", "execution_count": 3, - "id": "7db40f98", + "id": "f49ad3e2", "metadata": {}, "outputs": [ { @@ -101,7 +101,7 @@ }, { "cell_type": "markdown", - "id": "a50fbde6", + "id": "1f2e30a5", "metadata": {}, "source": [ "# Define Struct ResNetLayer" @@ -110,7 +110,7 @@ { "cell_type": "code", "execution_count": 4, - "id": "3fbd3c98", + "id": "0b16fa40", "metadata": {}, "outputs": [], "source": [ @@ -131,7 +131,7 @@ { "cell_type": "code", "execution_count": 5, - "id": "28360856", + "id": "055ba6b1", "metadata": {}, "outputs": [ { @@ -147,7 +147,7 @@ ], "source": [ "# Constructor\n", - "function ResNetLayer(in_channels::Int, channels::Int, activation_function = relu, stride = 1)\n", + "function ResNetLayer(in_channels::Int, channels::Int; activation_function = relu, stride = 1)\n", " bn1 = BatchNorm(in_channels)\n", " conv1 = Flux.Conv((3,3), in_channels => channels, activation_function; stride = stride)\n", " bn2 = BatchNorm(channels)\n", @@ -158,7 +158,7 @@ }, { "cell_type": "markdown", - "id": "779746c6", + "id": "1fbe7b50", "metadata": {}, "source": [ "# Define Residual Identity" @@ -167,7 +167,7 @@ { "cell_type": "code", "execution_count": 6, - "id": "e83efd28", + "id": "f54242be", "metadata": {}, "outputs": [ { @@ -210,7 +210,7 @@ }, { "cell_type": "markdown", - "id": "55f1a6b8", + "id": "8aac0568", "metadata": {}, "source": [ "# Forward Function" @@ -218,8 +218,8 @@ }, { "cell_type": "code", - "execution_count": 7, - "id": "a484933c", + "execution_count": 10, + "id": "3881663f", "metadata": {}, "outputs": [], "source": [ @@ -238,39 +238,250 @@ }, { "cell_type": "code", - "execution_count": 8, - "id": "c9d89029", + "execution_count": 11, + "id": "316c8d7f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "ResNetLayer(Conv((3, 3), 3 => 10, relu, stride=2), Conv((3, 3), 10 => 10, relu, stride=2), BatchNorm(3), BatchNorm(10), Knet.Ops20.relu, 3, 10, 2, 0)" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Example\n", + "l = ResNetLayer(3, 10; stride = 2, pad = 0)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "3db6251e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "64×64×3×2 Array{Float32, 4}:\n", + "[:, :, 1, 1] =\n", + " 1.93723 -0.0038152 1.00776 … -0.268737 0.772344 0.439853\n", + " -1.11342 0.888429 -0.394805 -0.983046 -1.55638 2.38319\n", + " -1.01276 -0.189184 0.00920959 0.082171 -0.282 0.785619\n", + " -0.392059 -0.184873 -1.81964 -0.0519429 0.106289 -0.736538\n", + " 0.19757 0.467828 0.448932 -0.36157 -0.879078 -0.65364\n", + " -0.972052 -0.00179285 -1.44282 … -0.123626 -0.00435947 -0.531083\n", + " -0.612597 -0.171173 0.14255 0.692491 0.548423 0.683022\n", + " 1.25817 -0.170542 0.845562 -0.803024 -0.623968 -2.29548\n", + " -0.266937 1.96865 -1.98532 0.816746 0.717069 -1.49578\n", + " -0.0521157 0.142611 1.59346 -2.30004 -1.42817 -1.13307\n", + " 2.12283 1.47045 -1.58077 … -0.352557 -0.600046 -0.338219\n", + " 1.05403 -0.0555598 -1.05525 -0.127607 0.857893 -1.35023\n", + " -0.856175 -0.982227 -0.988386 1.138 -0.288336 0.735353\n", + " ⋮ ⋱ \n", + " 0.755552 -1.59975 1.23963 -1.50395 0.264384 0.718516\n", + " -0.383118 -0.909808 -0.787967 -2.88865 1.75656 0.213767\n", + " -0.668381 -0.697628 -0.149039 0.930507 0.81691 -1.18663\n", + " -0.783636 -0.659775 1.69925 … 0.577181 1.49819 -0.849574\n", + " 0.57908 0.412455 2.04327 -1.81523 0.32882 0.674949\n", + " 0.705221 -0.339854 0.336234 1.93957 1.55426 0.331621\n", + " -0.402188 0.992751 -0.470351 -1.55654 -0.613687 1.43096\n", + " 0.285479 -0.611887 0.973315 0.960482 1.10192 -0.117772\n", + " 1.55644 -0.590366 0.337902 … 0.141687 0.136244 -0.431893\n", + " -0.864679 0.514239 0.847648 0.746462 -0.393826 -0.417075\n", + " -0.980209 0.300756 -0.0960876 -1.67165 -0.750691 0.759232\n", + " -0.3949 -2.43357 0.0152755 -1.10364 2.26395 0.14055\n", + "\n", + "[:, :, 2, 1] =\n", + " -2.32355 1.00156 … -0.0701956 -0.836908 -1.89917\n", + " 1.04352 0.17987 -0.957125 -1.84391 -0.546967\n", + " -0.915839 -0.489784 1.34493 1.00435 1.43219\n", + " 0.78521 1.81712 -0.629379 -0.777888 -0.1249\n", + " 0.141494 1.25472 -0.124471 1.10209 1.90448\n", + " 0.000519215 -1.0506 … 0.733363 0.808046 0.098419\n", + " -1.1189 -0.552053 0.565745 -0.368468 0.867749\n", + " -0.84709 -0.346625 -0.881431 -0.0301449 0.416381\n", + " -0.230508 -0.525351 -0.335971 -0.629214 -1.34824\n", + " 0.833687 -1.65521 0.145906 -0.559392 -1.0417\n", + " 0.303557 0.160128 … -0.133876 0.70595 -0.0710908\n", + " -0.945794 -0.372286 0.157022 0.520109 -0.553506\n", + " -0.688362 0.908274 0.365752 0.261675 1.44053\n", + " ⋮ ⋱ \n", + " 1.1257 -0.81426 0.502263 0.776487 -0.00868012\n", + " 0.516202 0.0884552 -0.645942 0.185443 0.279642\n", + " 0.262167 -1.16257 0.385125 -0.0882821 -0.883799\n", + " 0.820539 -0.567252 … -0.825821 -1.52597 0.830642\n", + " 0.404737 -0.416801 0.899569 1.0347 -0.840583\n", + " -0.699052 -0.974771 0.750472 -0.761852 2.14697\n", + " -0.412614 -0.0319316 -0.22123 0.991919 -0.973949\n", + " 0.361078 0.142957 1.24954 -1.93624 0.231768\n", + " 0.0268987 0.981564 … -1.70584 1.15414 0.866787\n", + " -1.46858 1.09673 -1.11747 0.856632 0.456886\n", + " -0.117358 -0.0955382 -0.18092 1.41447 0.497232\n", + " 0.597155 -1.56297 -0.193186 -2.90925 -0.463144\n", + "\n", + "[:, :, 3, 1] =\n", + " -0.661052 -1.43996 0.369589 … -0.140701 -0.554115 0.606915\n", + " -2.05038 0.132761 -0.765542 -0.993403 -0.613025 -0.495729\n", + " -0.0564577 1.71149 0.632053 0.0280497 1.01765 0.416239\n", + " 0.185878 2.91298 -0.158277 1.2439 0.798715 1.43122\n", + " 0.176058 0.770648 -0.0930822 1.37595 0.266198 -0.0775581\n", + " -0.801515 -0.0412963 0.996478 … -2.25544 -0.499216 0.298604\n", + " 0.0717726 -0.622027 1.12376 -0.0182525 -0.0241319 0.790553\n", + " 0.565299 -0.976237 1.10213 0.894315 -0.527155 1.63639\n", + " 1.80068 1.92752 0.63103 1.89302 -0.289506 -0.663129\n", + " 0.76663 -0.247099 -0.130537 -0.487301 -1.56555 0.899509\n", + " -0.839651 -0.225549 -0.986695 … 0.433356 -0.309599 -0.0260602\n", + " -0.299321 -0.28825 -0.244942 0.363609 -1.12806 -0.356537\n", + " -1.22345 -0.458875 1.93982 -1.54144 -1.63941 2.01241\n", + " ⋮ ⋱ \n", + " -0.534667 0.314462 -0.169809 1.831 -0.303435 1.97937\n", + " -0.16555 -1.32027 1.23676 -0.536022 -1.07766 -2.183\n", + " 1.72966 -1.38351 1.02683 -0.584877 0.644302 -0.336674\n", + " 1.04335 0.225725 -1.22778 … 1.07201 -0.639405 1.74068\n", + " -0.656627 -1.3051 -1.22059 0.917509 -0.616721 -0.686058\n", + " -0.778703 0.147619 -1.26244 0.87749 1.15391 -0.0203009\n", + " 0.568605 0.757572 0.370278 0.749421 -0.926598 -0.50433\n", + " 1.85186 -1.02303 0.602001 0.475687 -1.03261 -0.0744234\n", + " 0.44897 0.224272 0.491139 … -0.525537 -0.560198 0.727543\n", + " -1.0191 0.576765 1.76492 -0.787169 0.269278 0.724673\n", + " 2.05386 -0.223774 0.411601 1.19905 -1.51291 -0.35303\n", + " 0.999387 0.107672 1.48222 -0.308091 -0.842177 0.456322\n", + "\n", + "[:, :, 1, 2] =\n", + " -0.478833 -0.949141 -0.909521 … 0.566145 -0.806388 -0.103036\n", + " -0.752491 -0.220061 0.374108 -0.563949 -0.33751 0.291631\n", + " 1.25745 -1.56106 0.157859 0.303748 -0.130795 -0.971778\n", + " 1.12126 1.18819 0.463906 0.327253 1.23034 -1.1625\n", + " -0.540558 0.399086 0.915441 -0.192106 1.34276 -0.712125\n", + " -0.0149236 0.367581 -0.864331 … 0.72153 -1.43657 -0.534063\n", + " 0.532564 0.191133 -0.852938 0.903429 1.27837 1.80162\n", + " 1.30331 0.944996 -0.414618 0.0177562 0.621489 -0.412114\n", + " -1.69508 1.01265 0.529925 -0.475814 -1.45582 -0.012964\n", + " 0.907772 0.481818 -1.85653 -0.944343 1.62906 -0.278814\n", + " -1.28081 0.994347 -0.125699 … 0.853954 0.770404 -0.467438\n", + " 1.58481 -0.202481 -0.330101 -0.482117 -0.0351453 0.464559\n", + " -0.620726 -0.755369 0.745984 -0.786137 0.321789 -0.0923321\n", + " ⋮ ⋱ \n", + " 0.537256 0.125566 1.32509 -0.8879 1.40392 -0.798437\n", + " 2.13567 2.07036 -0.63291 -2.07044 -0.308257 0.778777\n", + " -0.109189 0.42111 -0.268277 0.932572 -0.470156 1.46397\n", + " 0.643967 1.29824 -0.328954 … 0.480872 -0.358953 0.116543\n", + " -0.0951204 -0.0243449 0.85459 -0.980945 2.15267 -0.952766\n", + " -2.00296 -0.588196 -0.389323 -0.197178 -0.341192 -1.09176\n", + " 1.2586 0.832861 -1.45609 0.244501 -0.129899 -1.59832\n", + " 0.776852 -1.49437 -0.377505 -0.0556672 -0.632541 0.401117\n", + " -0.108585 0.0491398 0.340784 … -1.03658 0.0200718 0.367884\n", + " -0.92731 -1.90718 -0.558201 0.610026 2.53304 -0.111334\n", + " -0.270401 1.04571 0.902924 -0.512312 -0.0238466 -1.63743\n", + " -0.206619 -0.512932 2.3688 -0.118872 -0.233185 0.566295\n", + "\n", + "[:, :, 2, 2] =\n", + " 0.234511 -0.010806 -0.832856 … 0.43269 0.305379 0.303762\n", + " -1.95149 0.50084 0.121275 -0.720525 -0.646199 1.69683\n", + " 1.23164 -0.673325 -0.478049 0.32442 1.27672 -0.625925\n", + " -0.25276 -1.4575 0.37511 1.10781 -1.13742 -0.277507\n", + " 0.508942 -0.820347 1.11952 2.72176 -1.81007 -0.821581\n", + " 1.90962 -1.82509 0.627871 … -1.15038 -0.738956 -0.695905\n", + " -3.59668 0.117514 -0.349758 0.789587 -1.05735 1.45537\n", + " -0.633973 1.61463 0.576952 1.1558 0.371561 -1.43009\n", + " 0.936568 -1.11766 -0.336715 0.117287 -0.63424 0.482243\n", + " -0.253183 -0.12117 -1.09076 0.222656 0.0313442 -0.62774\n", + " 0.213864 0.130833 -1.76951 … 1.82745 0.557266 2.30872\n", + " 0.338235 -0.0273923 -2.06415 1.42201 0.440293 1.00772\n", + " 0.667596 -2.13915 1.10299 -1.84244 1.0915 0.130394\n", + " ⋮ ⋱ \n", + " 0.586522 1.37729 -0.22929 0.411727 0.0435822 -0.116045\n", + " -0.153603 -1.2998 1.52197 -0.56521 -0.0838249 1.18758\n", + " -0.886851 0.258616 1.97687 -0.654801 0.194263 -0.545757\n", + " 0.799342 0.151642 -1.13255 … 0.98931 0.964001 -1.27512\n", + " 0.127431 -1.55811 0.919991 1.39099 1.62688 0.295337\n", + " 0.782026 1.02843 0.59295 -0.813902 0.522271 0.163845\n", + " 0.0650917 -1.96543 0.337413 -0.777451 -0.107169 -1.35744\n", + " 0.150674 -0.388523 1.77383 -0.883289 -0.389462 0.78941\n", + " 1.09536 0.341656 0.69347 … 0.0138373 0.488107 -1.06297\n", + " -0.270047 -1.49177 -1.69148 -0.403618 -1.77107 1.62124\n", + " 0.819439 0.192678 -0.82619 -0.0491964 -1.03829 -0.854479\n", + " 0.121415 0.910955 0.345215 1.69571 0.857381 0.507913\n", + "\n", + "[:, :, 3, 2] =\n", + " 1.54546 -1.57292 -1.22172 … -0.188748 1.28291 1.16054\n", + " -0.262547 0.598369 0.247646 -0.189469 -0.083296 0.311285\n", + " -0.746231 1.13651 0.625644 -0.979754 -0.728773 -0.094224\n", + " 0.59957 -1.43058 1.04906 -0.0492022 -0.348014 -0.208658\n", + " -0.304111 -0.215611 -0.513582 0.240893 -0.281331 -0.352183\n", + " 0.424201 1.02872 -1.27367 … 1.31776 0.76941 0.606991\n", + " -1.28498 -0.421116 -1.06405 1.1623 -0.450123 1.73071\n", + " -1.50988 -0.108995 0.914812 -0.184181 -0.722843 -0.572129\n", + " 1.42696 -0.215924 0.178053 1.54265 0.148367 0.890746\n", + " 0.139446 0.554282 0.870106 -0.532146 0.193758 -0.968426\n", + " -0.0406143 0.624965 -2.45956 … 0.568808 0.808988 1.85776\n", + " 0.0362163 1.47467 1.08824 0.920676 0.529394 -0.121682\n", + " 0.575251 -1.00362 -0.895482 -0.758474 -1.35153 1.742\n", + " ⋮ ⋱ \n", + " 0.824018 1.15677 -1.28596 0.61364 -0.250242 0.675139\n", + " -1.01877 -0.241783 0.650845 -1.139 0.238043 1.51527\n", + " 1.90379 -0.86203 -0.605472 -2.78706 2.4508 -0.204887\n", + " -0.188355 -0.686378 1.00882 … -0.119225 -0.355753 -0.849154\n", + " -1.68806 1.31191 0.0933924 -1.36017 -0.0328785 -2.0318\n", + " 0.142966 0.0406816 -0.865487 -0.624338 -0.530236 -0.897478\n", + " -0.693914 0.829991 1.02462 -0.550439 0.301905 1.42072\n", + " -1.15714 -1.47882 -0.283758 -0.191736 0.993716 1.3369\n", + " -0.635416 0.0334166 -0.970027 … 0.892278 1.03937 -0.771635\n", + " -0.881207 -0.269159 0.149887 -2.51309 0.637512 -1.40859\n", + " -0.115297 -0.723799 1.74762 -2.52757 0.402234 0.141189\n", + " 0.249918 0.0963379 0.912426 -0.165192 -1.12594 0.00358097" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "x = randn(Float32, (64, 64, 3, 2))\n", + "x\n" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "6e0743e2", "metadata": {}, "outputs": [ { "ename": "LoadError", - "evalue": "MethodError: no method matching ResNetLayer(::Int64, ::Int64; stride=2)\n\u001b[0mClosest candidates are:\n\u001b[0m ResNetLayer(::Int64, ::Int64) at In[5]:2\u001b[91m got unsupported keyword argument \"stride\"\u001b[39m\n\u001b[0m ResNetLayer(::Int64, ::Int64, \u001b[91m::Any\u001b[39m) at In[5]:2\u001b[91m got unsupported keyword argument \"stride\"\u001b[39m\n\u001b[0m ResNetLayer(::Int64, ::Int64, \u001b[91m::Any\u001b[39m, \u001b[91m::Any\u001b[39m) at In[5]:2\u001b[91m got unsupported keyword argument \"stride\"\u001b[39m\n\u001b[0m ...", + "evalue": "MethodError: no method matching Array{Float32, 4}(::Int64)\n\u001b[0mClosest candidates are:\n\u001b[0m Array{T, N}(\u001b[91m::Union{Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}} where T, Union{Base.LogicalIndex{T, <:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}}, Base.ReinterpretArray{T, N, <:Any, <:Union{SubArray{<:Any, <:Any, var\"#s14\"}, var\"#s14\"}} where var\"#s14\"<:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}, Base.ReshapedArray{T, N, <:Union{Base.ReinterpretArray{<:Any, <:Any, <:Any, <:Union{SubArray{<:Any, <:Any, var\"#s15\"}, var\"#s15\"}}, SubArray{<:Any, <:Any, var\"#s15\"}, var\"#s15\"}} where var\"#s15\"<:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}, SubArray{T, N, <:Union{Base.ReinterpretArray{<:Any, <:Any, <:Any, <:Union{SubArray{<:Any, <:Any, var\"#s16\"}, var\"#s16\"}}, Base.ReshapedArray{<:Any, <:Any, <:Union{Base.ReinterpretArray{<:Any, <:Any, <:Any, <:Union{SubArray{<:Any, <:Any, var\"#s16\"}, var\"#s16\"}}, SubArray{<:Any, <:Any, var\"#s16\"}, var\"#s16\"}}, var\"#s16\"}} where var\"#s16\"<:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}, LinearAlgebra.Adjoint{T, <:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}}, LinearAlgebra.Diagonal{T, <:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}}, LinearAlgebra.LowerTriangular{T, <:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}}, LinearAlgebra.Symmetric{T, <:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}}, LinearAlgebra.Transpose{T, <:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}}, LinearAlgebra.Tridiagonal{T, <:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}}, LinearAlgebra.UnitLowerTriangular{T, <:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}}, LinearAlgebra.UnitUpperTriangular{T, <:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}}, LinearAlgebra.UpperTriangular{T, <:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}}, PermutedDimsArray{T, N, <:Any, <:Any, <:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}}} where {T, N}}\u001b[39m) where {T, N} at C:\\Users\\Yash\\.julia\\packages\\NNlibCUDA\\kCpTE\\src\\batchedadjtrans.jl:16\n\u001b[0m Array{T, N}(\u001b[91m::BitArray{N}\u001b[39m) where {T, N} at bitarray.jl:495\n\u001b[0m Array{T, N}(\u001b[91m::FillArrays.Zeros{V, N}\u001b[39m) where {T, V, N} at C:\\Users\\Yash\\.julia\\packages\\FillArrays\\Slipo\\src\\FillArrays.jl:441\n\u001b[0m ...", "output_type": "error", "traceback": [ - "MethodError: no method matching ResNetLayer(::Int64, ::Int64; stride=2)\n\u001b[0mClosest candidates are:\n\u001b[0m ResNetLayer(::Int64, ::Int64) at In[5]:2\u001b[91m got unsupported keyword argument \"stride\"\u001b[39m\n\u001b[0m ResNetLayer(::Int64, ::Int64, \u001b[91m::Any\u001b[39m) at In[5]:2\u001b[91m got unsupported keyword argument \"stride\"\u001b[39m\n\u001b[0m ResNetLayer(::Int64, ::Int64, \u001b[91m::Any\u001b[39m, \u001b[91m::Any\u001b[39m) at In[5]:2\u001b[91m got unsupported keyword argument \"stride\"\u001b[39m\n\u001b[0m ...", + "MethodError: no method matching Array{Float32, 4}(::Int64)\n\u001b[0mClosest candidates are:\n\u001b[0m Array{T, N}(\u001b[91m::Union{Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}} where T, Union{Base.LogicalIndex{T, <:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}}, Base.ReinterpretArray{T, N, <:Any, <:Union{SubArray{<:Any, <:Any, var\"#s14\"}, var\"#s14\"}} where var\"#s14\"<:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}, Base.ReshapedArray{T, N, <:Union{Base.ReinterpretArray{<:Any, <:Any, <:Any, <:Union{SubArray{<:Any, <:Any, var\"#s15\"}, var\"#s15\"}}, SubArray{<:Any, <:Any, var\"#s15\"}, var\"#s15\"}} where var\"#s15\"<:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}, SubArray{T, N, <:Union{Base.ReinterpretArray{<:Any, <:Any, <:Any, <:Union{SubArray{<:Any, <:Any, var\"#s16\"}, var\"#s16\"}}, Base.ReshapedArray{<:Any, <:Any, <:Union{Base.ReinterpretArray{<:Any, <:Any, <:Any, <:Union{SubArray{<:Any, <:Any, var\"#s16\"}, var\"#s16\"}}, SubArray{<:Any, <:Any, var\"#s16\"}, var\"#s16\"}}, var\"#s16\"}} where var\"#s16\"<:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}, LinearAlgebra.Adjoint{T, <:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}}, LinearAlgebra.Diagonal{T, <:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}}, LinearAlgebra.LowerTriangular{T, <:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}}, LinearAlgebra.Symmetric{T, <:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}}, LinearAlgebra.Transpose{T, <:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}}, LinearAlgebra.Tridiagonal{T, <:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}}, LinearAlgebra.UnitLowerTriangular{T, <:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}}, LinearAlgebra.UnitUpperTriangular{T, <:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}}, LinearAlgebra.UpperTriangular{T, <:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}}, PermutedDimsArray{T, N, <:Any, <:Any, <:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}}} where {T, N}}\u001b[39m) where {T, N} at C:\\Users\\Yash\\.julia\\packages\\NNlibCUDA\\kCpTE\\src\\batchedadjtrans.jl:16\n\u001b[0m Array{T, N}(\u001b[91m::BitArray{N}\u001b[39m) where {T, N} at bitarray.jl:495\n\u001b[0m Array{T, N}(\u001b[91m::FillArrays.Zeros{V, N}\u001b[39m) where {T, V, N} at C:\\Users\\Yash\\.julia\\packages\\FillArrays\\Slipo\\src\\FillArrays.jl:441\n\u001b[0m ...", "", "Stacktrace:", - " [1] top-level scope", - " @ In[8]:2", - " [2] eval", + " [1] relu(x::Array{Float32, 4})", + " @ Knet.Ops20 C:\\Users\\Yash\\.julia\\packages\\Knet\\YIFWC\\src\\ops20\\activation.jl:26", + " [2] (::ResNetLayer)(x::Array{Float32, 4})", + " @ Main .\\In[10]:4", + " [3] top-level scope", + " @ In[16]:1", + " [4] eval", " @ .\\boot.jl:368 [inlined]", - " [3] include_string(mapexpr::typeof(REPL.softscope), mod::Module, code::String, filename::String)", + " [5] include_string(mapexpr::typeof(REPL.softscope), mod::Module, code::String, filename::String)", " @ Base .\\loading.jl:1428" ] } ], "source": [ - "# Example\n", - "l = ResNetLayer(3, 10, stride = 2)\n", - "x = randn(Float32, (64, 64, 3, 2))\n", - "y = l(x)\n", - "size(y)" + "y = l(x)" ] }, { "cell_type": "code", "execution_count": null, - "id": "8d213e04", + "id": "b08cda2d", "metadata": {}, "outputs": [], "source": [] From bf79488827423323252ce04bff388607f0a2cb46 Mon Sep 17 00:00:00 2001 From: Yash Pokra <79229682+yashpokra@users.noreply.github.com> Date: Tue, 17 Jan 2023 18:59:21 -0500 Subject: [PATCH 19/26] Add files via upload --- .../julia_resnetmodel_updated.ipynb | 400 ++++++++++++++++++ 1 file changed, 400 insertions(+) create mode 100644 convolutional neural network/julia_resnetmodel_updated.ipynb diff --git a/convolutional neural network/julia_resnetmodel_updated.ipynb b/convolutional neural network/julia_resnetmodel_updated.ipynb new file mode 100644 index 0000000..e356372 --- /dev/null +++ b/convolutional neural network/julia_resnetmodel_updated.ipynb @@ -0,0 +1,400 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "69f91157", + "metadata": {}, + "source": [ + "# Imports" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "9b1583d4", + "metadata": {}, + "outputs": [], + "source": [ + "using Yota;\n", + "using MLDatasets;\n", + "using NNlib;\n", + "using Statistics;\n", + "using Distributions;\n", + "using Functors;\n", + "using Optimisers;\n", + "using MLUtils: DataLoader;\n", + "using OneHotArrays: onehotbatch\n", + "using Knet:conv4\n", + "using Metrics;\n", + "using TimerOutputs;\n", + "using Flux: BatchNorm, kaiming_uniform, nfan;\n", + "using Functors\n", + "\n", + "# Model creation\n", + "using NNlib;\n", + "using Flux: BatchNorm, Chain, GlobalMeanPool, kaiming_uniform, nfan;\n", + "using Statistics;\n", + "using Distributions;\n", + "using Functors;\n" + ] + }, + { + "cell_type": "markdown", + "id": "19aff91e", + "metadata": {}, + "source": [ + "# Conv 2D" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "481a3d9a", + "metadata": {}, + "outputs": [], + "source": [ + "mutable struct Conv2D{T}\n", + " w::AbstractArray{T, 4}\n", + " b::AbstractVector{T}\n", + " use_bias::Bool\n", + "end\n", + "\n", + "@functor Conv2D (w, b)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "59da1b27", + "metadata": {}, + "outputs": [], + "source": [ + "function Conv2D(kernel_size::Tuple{Int, Int}, in_channels::Int, out_channels::Int;\n", + " bias::Bool=false)\n", + " w_size = (kernel_size..., in_channels, out_channels)\n", + " w = kaiming_uniform(w_size...)\n", + " (fan_in, fan_out) = nfan(w_size)\n", + " \n", + " if bias\n", + " # Init bias with fan_in from weights. Use gain = √2 for ReLU\n", + " bound = √3 * √2 / √fan_in\n", + " rng = Uniform(-bound, bound)\n", + " b = rand(rng, out_channels, Float32)\n", + " else\n", + " b = zeros(Float32, out_channels)\n", + " end\n", + "\n", + " return Conv2D(w, b, bias)\n", + "end\n", + "\n", + "function (self::Conv2D)(x::AbstractArray; stride::Int=1, pad::Int=0, dilation::Int=1)\n", + " y = conv4(self.w, x; stride=stride, padding=pad, dilation=dilation)\n", + " if self.use_bias\n", + " # Bias is applied channel-wise\n", + " (w, h, c, b) = size(y)\n", + " bias = reshape(self.b, (1, 1, c, 1))\n", + " y = y .+ bias\n", + " end\n", + " return y\n", + "end\n", + " " + ] + }, + { + "cell_type": "markdown", + "id": "252e934f", + "metadata": {}, + "source": [ + "# ResNetLayer" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "3e66be4f", + "metadata": {}, + "outputs": [], + "source": [ + "mutable struct ResNetLayer\n", + " conv1::Conv2D\n", + " conv2::Conv2D\n", + " bn1::BatchNorm\n", + " bn2::BatchNorm\n", + " f::Function\n", + " in_channels::Int\n", + " channels::Int\n", + " stride::Int\n", + "end\n", + "\n", + "@functor ResNetLayer (conv1, conv2, bn1, bn2)\n", + "\n", + "function residual_identity(layer::ResNetLayer, x::AbstractArray{T, 4}) where {T<:Number}\n", + " (w, h, c, b) = size(x)\n", + " stride = layer.stride\n", + " if stride > 1\n", + " @assert ((w % stride == 0) & (h % stride == 0)) \"Spatial dimensions are not divisible by `stride`\"\n", + " \n", + " # Strided downsample\n", + " x_id = copy(x[begin:2:end, begin:2:end, :, :])\n", + " else\n", + " x_id = x\n", + " end\n", + "\n", + " channels = layer.channels\n", + " in_channels = layer.in_channels\n", + " if in_channels < channels\n", + " # Zero padding on extra channels\n", + " (w, h, c, b) = size(x_id)\n", + " pad = zeros(w, h, channels - in_channels, b)\n", + " x_id = cat(x_id, pad; dims=3)\n", + " elseif in_channels > channels\n", + " error(\"in_channels > out_channels not supported\")\n", + " end\n", + " return x_id\n", + "end\n", + "\n", + "function ResNetLayer(in_channels::Int, channels::Int; stride=1, f=relu)\n", + " bn1 = BatchNorm(in_channels)\n", + " conv1 = Conv2D((3, 3), in_channels, channels, bias=false)\n", + " bn2 = BatchNorm(channels)\n", + " conv2 = Conv2D((3, 3), channels, channels, bias=false)\n", + "\n", + " return ResNetLayer(conv1, conv2, bn1, bn2, f, in_channels, channels, stride)\n", + "end\n", + "\n", + "\n", + "function (self::ResNetLayer)(x::AbstractArray)\n", + " identity = residual_identity(self, x)\n", + " z = self.bn1(x)\n", + " z = self.f(z)\n", + " z = self.conv1(z; pad=1, stride=self.stride) # pad=1 will keep same size with (3x3) kernel\n", + " z = self.bn2(z)\n", + " z = self.f(z)\n", + " z = self.conv2(z; pad=1)\n", + "\n", + " y = z + identity\n", + " return y\n", + "end" + ] + }, + { + "cell_type": "markdown", + "id": "9f06e04e", + "metadata": {}, + "source": [ + "# Testing ResNetLayer" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "7cdc72a9", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(16, 16, 10, 4)" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "\n", + "l = ResNetLayer(3, 10; stride=2);\n", + "x = randn(Float32, (32, 32, 3, 4));\n", + "y = l(x);\n", + "size(y)" + ] + }, + { + "cell_type": "markdown", + "id": "7b21b952", + "metadata": {}, + "source": [ + "# Linear Layer" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "8987f02c", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING: method definition for Linear at In[24]:22 declares type variable T but does not use it.\n" + ] + } + ], + "source": [ + "mutable struct Linear\n", + " W::AbstractMatrix{T} where T\n", + " b::AbstractVector{T} where T\n", + "end\n", + "\n", + "@functor Linear\n", + "\n", + "# Init\n", + "function Linear(in_features::Int, out_features::Int)\n", + " k_sqrt = sqrt(1 / in_features)\n", + " d = Uniform(-k_sqrt, k_sqrt)\n", + " return Linear(rand(d, out_features, in_features), rand(d, out_features))\n", + "end\n", + "Linear(in_out::Pair{Int, Int}) = Linear(in_out[1], in_out[2])\n", + "\n", + "function Base.show(io::IO, l::Linear)\n", + " o, i = size(l.W)\n", + " print(io, \"Linear(o)\")\n", + "end\n", + "\n", + "# Forward\n", + "(l::Linear)(x::AbstractArray) where T = l.W * x .+ l.b\n" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "id": "02eca287", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "ResNet20Model" + ] + }, + "execution_count": 38, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# ResNet Architecture\n", + "\n", + "mutable struct ResNet20Model\n", + " input_conv::Conv2D\n", + " resnet_blocks::Chain\n", + " pool::GlobalMeanPool\n", + " linear::Linear\n", + "end\n", + "\n", + "@functor ResNet20Model\n", + "\n", + "function ResNet20Model(in_channels::Int, num_classes::Int)\n", + " resnet_blocks = Chain(\n", + " block_1 = ResNetLayer(16, 16),\n", + " block_2 = ResNetLayer(16, 16),\n", + " block_3 = ResNetLayer(16, 16),\n", + " block_4 = ResNetLayer(16, 32; stride=2),\n", + " block_5 = ResNetLayer(32, 32),\n", + " block_6 = ResNetLayer(32, 32),\n", + " block_7 = ResNetLayer(32, 64; stride=2),\n", + " block_8 = ResNetLayer(64, 64),\n", + " block_9 = ResNetLayer(64, 64)\n", + " )\n", + " return ResNet20Model(\n", + " Conv2D((3, 3), in_channels, 16, bias=false),\n", + " resnet_blocks,\n", + " GlobalMeanPool(),\n", + " Linear(64, num_classes)\n", + " )\n", + "end" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "id": "cdef0144", + "metadata": {}, + "outputs": [], + "source": [ + "function (self::ResNet20Model)(x::AbstractArray)\n", + " z = self.input_conv(x)\n", + " z = self.resnet_blocks(z)\n", + " z = self.pool(z)\n", + " z = dropdims(z, dims=(1, 2))\n", + " y = self.linear(z)\n", + " return y\n", + "end\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "id": "25c15eb5", + "metadata": {}, + "outputs": [ + { + "ename": "LoadError", + "evalue": "AssertionError: Spatial dimensions are not divisible by `stride`", + "output_type": "error", + "traceback": [ + "AssertionError: Spatial dimensions are not divisible by `stride`", + "", + "Stacktrace:", + " [1] residual_identity(layer::ResNetLayer, x::Array{Float64, 4})", + " @ Main .\\In[22]:18", + " [2] (::ResNetLayer)(x::Array{Float64, 4})", + " @ Main .\\In[22]:50", + " [3] macro expansion", + " @ C:\\Users\\Yash\\.julia\\packages\\Flux\\4k0Ls\\src\\layers\\basic.jl:53 [inlined]", + " [4] _applychain(layers::NTuple{9, ResNetLayer}, x::Array{Float32, 4})", + " @ Flux C:\\Users\\Yash\\.julia\\packages\\Flux\\4k0Ls\\src\\layers\\basic.jl:53", + " [5] _applychain", + " @ C:\\Users\\Yash\\.julia\\packages\\Flux\\4k0Ls\\src\\layers\\basic.jl:59 [inlined]", + " [6] (::Chain{NamedTuple{(:block_1, :block_2, :block_3, :block_4, :block_5, :block_6, :block_7, :block_8, :block_9), NTuple{9, ResNetLayer}}})(x::Array{Float32, 4})", + " @ Flux C:\\Users\\Yash\\.julia\\packages\\Flux\\4k0Ls\\src\\layers\\basic.jl:51", + " [7] (::ResNet20Model)(x::Array{Float32, 4})", + " @ Main .\\In[39]:3", + " [8] top-level scope", + " @ In[40]:6", + " [9] eval", + " @ .\\boot.jl:368 [inlined]", + " [10] include_string(mapexpr::typeof(REPL.softscope), mod::Module, code::String, filename::String)", + " @ Base .\\loading.jl:1428" + ] + } + ], + "source": [ + "\n", + "# Testing ResNet20 model\n", + "# Expected output: (10, 4)\n", + "m = ResNet20Model(3, 10);\n", + "inputs = randn(Float32, (32, 32, 3, 4))\n", + "outputs = m(inputs);\n", + "size(outputs)\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "df6a846b", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Julia 1.8.2", + "language": "julia", + "name": "julia-1.8" + }, + "language_info": { + "file_extension": ".jl", + "mimetype": "application/julia", + "name": "julia", + "version": "1.8.2" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From a251bf63831789893572138b4dd2d1258fcd335f Mon Sep 17 00:00:00 2001 From: Yash Pokra <79229682+yashpokra@users.noreply.github.com> Date: Tue, 17 Jan 2023 18:59:42 -0500 Subject: [PATCH 20/26] Delete ResNet Model V2- Knet.ipynb --- .../ResNet Model V2- Knet.ipynb | 505 ------------------ 1 file changed, 505 deletions(-) delete mode 100644 convolutional neural network/ResNet Model V2- Knet.ipynb diff --git a/convolutional neural network/ResNet Model V2- Knet.ipynb b/convolutional neural network/ResNet Model V2- Knet.ipynb deleted file mode 100644 index f6d02c2..0000000 --- a/convolutional neural network/ResNet Model V2- Knet.ipynb +++ /dev/null @@ -1,505 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "3a44a186", - "metadata": {}, - "source": [ - "# Imports" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "7a88e0c0", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING: using Data.Data in module Main conflicts with an existing identifier.\n" - ] - } - ], - "source": [ - "using MLDatasets: CIFAR10\n", - "using MLDataUtils\n", - "using Knet, IterTools\n", - "using Dictionaries\n", - "using TimerOutputs\n", - "using JSON\n", - "using Printf\n", - "using Knet:minibatch\n", - "using Knet:minimize\n", - "using Knet: Param\n", - "using Knet: dir, accuracy, progress, sgd, gc, Data, nll, relu, conv4\n", - "using Flatten\n", - "using Flux.Data;\n", - "using Flux, Statistics\n", - "using Statistics: mean, var\n", - "using Functors" - ] - }, - { - "cell_type": "markdown", - "id": "bb72d038", - "metadata": {}, - "source": [ - "# Processing Data/Batch Processing" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "4cf6721f", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "32×32×3×45000 Array{Float32, 4}\n", - "45000-element Vector{Int64}\n", - "32×32×3×5000 Array{Float32, 4}\n", - "5000-element Vector{Int64}\n", - "32×32×3×10000 Array{Float32, 4}\n", - "10000-element Vector{Int64}\n" - ] - } - ], - "source": [ - "# This loads the CIFAR-10 Dataset for training, validation, and evaluation\n", - "xtrn,ytrn = CIFAR10.traindata(Float32, 1:45000)\n", - "xval,yval = CIFAR10.traindata(Float32, 45001:50000)\n", - "xtst,ytst = CIFAR10.testdata(Float32)\n", - "println.(summary.((xtrn,ytrn,xval, yval, xtst,ytst)));" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "f49ad3e2", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "DataLoader{Tuple{Array{Float32, 4}, Vector{Int64}}, Random._GLOBAL_RNG, Val{nothing}}(([0.61960787 0.59607846 … 0.23921569 0.21176471; 0.62352943 0.5921569 … 0.19215687 0.21960784; … ; 0.49411765 0.49019608 … 0.11372549 0.13333334; 0.45490196 0.46666667 … 0.078431375 0.08235294;;; 0.4392157 0.4392157 … 0.45490196 0.41960785; 0.43529412 0.43137255 … 0.4 0.4117647; … ; 0.35686275 0.35686275 … 0.32156864 0.32941177; 0.33333334 0.34509805 … 0.2509804 0.2627451;;; 0.19215687 0.2 … 0.65882355 0.627451; 0.18431373 0.15686275 … 0.5803922 0.58431375; … ; 0.14117648 0.1254902 … 0.49411765 0.5058824; 0.12941177 0.13333334 … 0.41960785 0.43137255;;;; 0.92156863 0.93333334 … 0.32156864 0.33333334; 0.90588236 0.92156863 … 0.18039216 0.24313726; … ; 0.9137255 0.9254902 … 0.7254902 0.7058824; 0.9098039 0.92156863 … 0.73333335 0.7294118;;; 0.92156863 0.93333334 … 0.3764706 0.39607844; 0.90588236 0.92156863 … 0.22352941 0.29411766; … ; 0.9137255 0.9254902 … 0.78431374 0.7647059; 0.9098039 0.92156863 … 0.7921569 0.78431374;;; 0.92156863 0.93333334 … 0.32156864 0.3254902; 0.90588236 0.92156863 … 0.14117648 0.1882353; … ; 0.9137255 0.9254902 … 0.76862746 0.7490196; 0.9098039 0.92156863 … 0.78431374 0.78039217;;;; 0.61960787 0.6666667 … 0.09019608 0.10980392; 0.61960787 0.6745098 … 0.105882354 0.11764706; … ; 0.92941177 0.9647059 … 0.015686275 0.015686275; 0.93333334 0.9647059 … 0.019607844 0.02745098;;; 0.74509805 0.78431374 … 0.13333334 0.16078432; 0.73333335 0.78039217 … 0.14901961 0.16862746; … ; 0.9372549 0.9647059 … 0.023529412 0.019607844; 0.94509804 0.96862745 … 0.02745098 0.03137255;;; 0.87058824 0.8980392 … 0.15294118 0.18431373; 0.85490197 0.8862745 … 0.16862746 0.19607843; … ; 0.9529412 0.98039216 … 0.011764706 0.011764706; 0.9647059 0.9843137 … 0.011764706 0.02745098;;;; … ;;;; 0.078431375 0.08235294 … 0.12941177 0.12156863; 0.07450981 0.078431375 … 0.13333334 0.1254902; … ; 0.047058824 0.039215688 … 0.105882354 0.101960786; 0.050980393 0.047058824 … 0.09803922 0.09803922;;; 0.05882353 0.0627451 … 0.09803922 0.09019608; 0.05490196 0.0627451 … 0.101960786 0.09411765; … ; 0.043137256 0.03529412 … 0.09411765 0.09019608; 0.047058824 0.043137256 … 0.08627451 0.078431375;;; 0.047058824 0.050980393 … 0.05490196 0.047058824; 0.043137256 0.050980393 … 0.05882353 0.050980393; … ; 0.03529412 0.02745098 … 0.21960784 0.20784314; 0.039215688 0.03529412 … 0.18431373 0.18431373;;;; 0.09803922 0.047058824 … 0.40392157 0.37254903; 0.05882353 0.078431375 … 0.40784314 0.37254903; … ; 0.36078432 0.58431375 … 0.3882353 0.37254903; 0.29411766 0.40784314 … 0.36078432 0.36078432;;; 0.15686275 0.09803922 … 0.5176471 0.49411765; 0.14117648 0.14509805 … 0.5137255 0.48235294; … ; 0.44313726 0.65882355 … 0.49803922 0.48235294; 0.34901962 0.45882353 … 0.4745098 0.47058824;;; 0.047058824 0.023529412 … 0.3254902 0.30588236; 0.011764706 0.02745098 … 0.3254902 0.29803923; … ; 0.4392157 0.69411767 … 0.32941177 0.31764707; 0.36078432 0.5137255 … 0.30980393 0.3137255;;;; 0.28627452 0.27058825 … 0.4509804 0.45490196; 0.38431373 0.32941177 … 0.48235294 0.4745098; … ; 0.5294118 0.2784314 … 0.25882354 0.26666668; 0.79607844 0.47058824 … 0.105882354 0.105882354;;; 0.30588236 0.28627452 … 0.4745098 0.47058824; 0.40392157 0.34901962 … 0.4862745 0.47843137; … ; 0.58431375 0.32156864 … 0.25490198 0.25490198; 0.84313726 0.52156866 … 0.105882354 0.101960786;;; 0.29411766 0.27450982 … 0.35686275 0.3529412; 0.44313726 0.38039216 … 0.37254903 0.36862746; … ; 0.6039216 0.3137255 … 0.23137255 0.22745098; 0.8745098 0.5294118 … 0.105882354 0.101960786], [3, 8, 8, 0, 6, 6, 1, 6, 3, 1 … 7, 0, 3, 5, 3, 8, 3, 5, 1, 7]), 256, false, true, false, false, Val{nothing}(), Random._GLOBAL_RNG())" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "train_loader = DataLoader((xtrn, ytrn), batchsize=256)\n", - "val_loader = DataLoader((xval, yval), batchsize = 256)\n", - "test_loader = DataLoader((xtst, ytst), batchsize = 256)" - ] - }, - { - "cell_type": "markdown", - "id": "1f2e30a5", - "metadata": {}, - "source": [ - "# Define Struct ResNetLayer" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "0b16fa40", - "metadata": {}, - "outputs": [], - "source": [ - "mutable struct ResNetLayer\n", - " conv1::Flux.Conv\n", - " conv2::Flux.Conv\n", - " bn1::BatchNorm\n", - " bn2::BatchNorm\n", - " activation_function::Function\n", - " in_channels::Int\n", - " channels::Int\n", - " stride::Int \n", - "end \n", - "\n", - "@functor ResNetLayer (conv1, conv2, bn1, bn2)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "055ba6b1", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "ResNetLayer" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# Constructor\n", - "function ResNetLayer(in_channels::Int, channels::Int; activation_function = relu, stride = 1)\n", - " bn1 = BatchNorm(in_channels)\n", - " conv1 = Flux.Conv((3,3), in_channels => channels, activation_function; stride = stride)\n", - " bn2 = BatchNorm(channels)\n", - " conv2 = Flux.Conv((3,3), channels => channels, activation_function; stride = stride)\n", - " return ResNetLayer(conv1, conv2, bn1, bn2, activation_function, in_channels, channels, stride)\n", - "end" - ] - }, - { - "cell_type": "markdown", - "id": "1fbe7b50", - "metadata": {}, - "source": [ - "# Define Residual Identity" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "f54242be", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "residual_identity (generic function with 1 method)" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "function residual_identity(layer::ResNetLayer, x::AbstractArray{T, 4}) where {T<:Number}\n", - " (w, h, c, b) = size(x)\n", - " stride = layer.stride\n", - " if stride > 1\n", - " @assert ((w % stride == 0) & (h % stride == 0)) \"Spatial dimensions are not divisible by `stride`\"\n", - " \n", - " # Strided downsample\n", - " x_id = copy(x[begin:2:end, begin:2:end, :, :])\n", - " else\n", - " x_id = x\n", - " end\n", - "\n", - " channels = layer.channels\n", - " in_channels = layer.in_channels\n", - " if in_channels < channels\n", - " # Zero padding on extra channels\n", - " (w, h, c, b) = size(x_id)\n", - " pad = zeros(w, h, channels - in_channels, b)\n", - " x_id = cat(x_id, pad; dims=3)\n", - " elseif in_channels > channels\n", - " error(\"in_channels > out_channels not supported\")\n", - " end\n", - " return x_id\n", - "end" - ] - }, - { - "cell_type": "markdown", - "id": "8aac0568", - "metadata": {}, - "source": [ - "# Forward Function" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "3881663f", - "metadata": {}, - "outputs": [], - "source": [ - "function (self::ResNetLayer)(x::AbstractArray)\n", - " identity = residual_identity(self, x)\n", - " z = self.bn1(x)\n", - " z = self.activation_function(z)\n", - " z = self.conv1(z)\n", - " z = self.bn2(z)\n", - " z = self.activation_function(z)\n", - " z = self.conv2(z)\n", - " y = z + identity \n", - " return y\n", - "end " - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "316c8d7f", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "ResNetLayer(Conv((3, 3), 3 => 10, relu, stride=2), Conv((3, 3), 10 => 10, relu, stride=2), BatchNorm(3), BatchNorm(10), Knet.Ops20.relu, 3, 10, 2, 0)" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# Example\n", - "l = ResNetLayer(3, 10; stride = 2, pad = 0)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "3db6251e", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "64×64×3×2 Array{Float32, 4}:\n", - "[:, :, 1, 1] =\n", - " 1.93723 -0.0038152 1.00776 … -0.268737 0.772344 0.439853\n", - " -1.11342 0.888429 -0.394805 -0.983046 -1.55638 2.38319\n", - " -1.01276 -0.189184 0.00920959 0.082171 -0.282 0.785619\n", - " -0.392059 -0.184873 -1.81964 -0.0519429 0.106289 -0.736538\n", - " 0.19757 0.467828 0.448932 -0.36157 -0.879078 -0.65364\n", - " -0.972052 -0.00179285 -1.44282 … -0.123626 -0.00435947 -0.531083\n", - " -0.612597 -0.171173 0.14255 0.692491 0.548423 0.683022\n", - " 1.25817 -0.170542 0.845562 -0.803024 -0.623968 -2.29548\n", - " -0.266937 1.96865 -1.98532 0.816746 0.717069 -1.49578\n", - " -0.0521157 0.142611 1.59346 -2.30004 -1.42817 -1.13307\n", - " 2.12283 1.47045 -1.58077 … -0.352557 -0.600046 -0.338219\n", - " 1.05403 -0.0555598 -1.05525 -0.127607 0.857893 -1.35023\n", - " -0.856175 -0.982227 -0.988386 1.138 -0.288336 0.735353\n", - " ⋮ ⋱ \n", - " 0.755552 -1.59975 1.23963 -1.50395 0.264384 0.718516\n", - " -0.383118 -0.909808 -0.787967 -2.88865 1.75656 0.213767\n", - " -0.668381 -0.697628 -0.149039 0.930507 0.81691 -1.18663\n", - " -0.783636 -0.659775 1.69925 … 0.577181 1.49819 -0.849574\n", - " 0.57908 0.412455 2.04327 -1.81523 0.32882 0.674949\n", - " 0.705221 -0.339854 0.336234 1.93957 1.55426 0.331621\n", - " -0.402188 0.992751 -0.470351 -1.55654 -0.613687 1.43096\n", - " 0.285479 -0.611887 0.973315 0.960482 1.10192 -0.117772\n", - " 1.55644 -0.590366 0.337902 … 0.141687 0.136244 -0.431893\n", - " -0.864679 0.514239 0.847648 0.746462 -0.393826 -0.417075\n", - " -0.980209 0.300756 -0.0960876 -1.67165 -0.750691 0.759232\n", - " -0.3949 -2.43357 0.0152755 -1.10364 2.26395 0.14055\n", - "\n", - "[:, :, 2, 1] =\n", - " -2.32355 1.00156 … -0.0701956 -0.836908 -1.89917\n", - " 1.04352 0.17987 -0.957125 -1.84391 -0.546967\n", - " -0.915839 -0.489784 1.34493 1.00435 1.43219\n", - " 0.78521 1.81712 -0.629379 -0.777888 -0.1249\n", - " 0.141494 1.25472 -0.124471 1.10209 1.90448\n", - " 0.000519215 -1.0506 … 0.733363 0.808046 0.098419\n", - " -1.1189 -0.552053 0.565745 -0.368468 0.867749\n", - " -0.84709 -0.346625 -0.881431 -0.0301449 0.416381\n", - " -0.230508 -0.525351 -0.335971 -0.629214 -1.34824\n", - " 0.833687 -1.65521 0.145906 -0.559392 -1.0417\n", - " 0.303557 0.160128 … -0.133876 0.70595 -0.0710908\n", - " -0.945794 -0.372286 0.157022 0.520109 -0.553506\n", - " -0.688362 0.908274 0.365752 0.261675 1.44053\n", - " ⋮ ⋱ \n", - " 1.1257 -0.81426 0.502263 0.776487 -0.00868012\n", - " 0.516202 0.0884552 -0.645942 0.185443 0.279642\n", - " 0.262167 -1.16257 0.385125 -0.0882821 -0.883799\n", - " 0.820539 -0.567252 … -0.825821 -1.52597 0.830642\n", - " 0.404737 -0.416801 0.899569 1.0347 -0.840583\n", - " -0.699052 -0.974771 0.750472 -0.761852 2.14697\n", - " -0.412614 -0.0319316 -0.22123 0.991919 -0.973949\n", - " 0.361078 0.142957 1.24954 -1.93624 0.231768\n", - " 0.0268987 0.981564 … -1.70584 1.15414 0.866787\n", - " -1.46858 1.09673 -1.11747 0.856632 0.456886\n", - " -0.117358 -0.0955382 -0.18092 1.41447 0.497232\n", - " 0.597155 -1.56297 -0.193186 -2.90925 -0.463144\n", - "\n", - "[:, :, 3, 1] =\n", - " -0.661052 -1.43996 0.369589 … -0.140701 -0.554115 0.606915\n", - " -2.05038 0.132761 -0.765542 -0.993403 -0.613025 -0.495729\n", - " -0.0564577 1.71149 0.632053 0.0280497 1.01765 0.416239\n", - " 0.185878 2.91298 -0.158277 1.2439 0.798715 1.43122\n", - " 0.176058 0.770648 -0.0930822 1.37595 0.266198 -0.0775581\n", - " -0.801515 -0.0412963 0.996478 … -2.25544 -0.499216 0.298604\n", - " 0.0717726 -0.622027 1.12376 -0.0182525 -0.0241319 0.790553\n", - " 0.565299 -0.976237 1.10213 0.894315 -0.527155 1.63639\n", - " 1.80068 1.92752 0.63103 1.89302 -0.289506 -0.663129\n", - " 0.76663 -0.247099 -0.130537 -0.487301 -1.56555 0.899509\n", - " -0.839651 -0.225549 -0.986695 … 0.433356 -0.309599 -0.0260602\n", - " -0.299321 -0.28825 -0.244942 0.363609 -1.12806 -0.356537\n", - " -1.22345 -0.458875 1.93982 -1.54144 -1.63941 2.01241\n", - " ⋮ ⋱ \n", - " -0.534667 0.314462 -0.169809 1.831 -0.303435 1.97937\n", - " -0.16555 -1.32027 1.23676 -0.536022 -1.07766 -2.183\n", - " 1.72966 -1.38351 1.02683 -0.584877 0.644302 -0.336674\n", - " 1.04335 0.225725 -1.22778 … 1.07201 -0.639405 1.74068\n", - " -0.656627 -1.3051 -1.22059 0.917509 -0.616721 -0.686058\n", - " -0.778703 0.147619 -1.26244 0.87749 1.15391 -0.0203009\n", - " 0.568605 0.757572 0.370278 0.749421 -0.926598 -0.50433\n", - " 1.85186 -1.02303 0.602001 0.475687 -1.03261 -0.0744234\n", - " 0.44897 0.224272 0.491139 … -0.525537 -0.560198 0.727543\n", - " -1.0191 0.576765 1.76492 -0.787169 0.269278 0.724673\n", - " 2.05386 -0.223774 0.411601 1.19905 -1.51291 -0.35303\n", - " 0.999387 0.107672 1.48222 -0.308091 -0.842177 0.456322\n", - "\n", - "[:, :, 1, 2] =\n", - " -0.478833 -0.949141 -0.909521 … 0.566145 -0.806388 -0.103036\n", - " -0.752491 -0.220061 0.374108 -0.563949 -0.33751 0.291631\n", - " 1.25745 -1.56106 0.157859 0.303748 -0.130795 -0.971778\n", - " 1.12126 1.18819 0.463906 0.327253 1.23034 -1.1625\n", - " -0.540558 0.399086 0.915441 -0.192106 1.34276 -0.712125\n", - " -0.0149236 0.367581 -0.864331 … 0.72153 -1.43657 -0.534063\n", - " 0.532564 0.191133 -0.852938 0.903429 1.27837 1.80162\n", - " 1.30331 0.944996 -0.414618 0.0177562 0.621489 -0.412114\n", - " -1.69508 1.01265 0.529925 -0.475814 -1.45582 -0.012964\n", - " 0.907772 0.481818 -1.85653 -0.944343 1.62906 -0.278814\n", - " -1.28081 0.994347 -0.125699 … 0.853954 0.770404 -0.467438\n", - " 1.58481 -0.202481 -0.330101 -0.482117 -0.0351453 0.464559\n", - " -0.620726 -0.755369 0.745984 -0.786137 0.321789 -0.0923321\n", - " ⋮ ⋱ \n", - " 0.537256 0.125566 1.32509 -0.8879 1.40392 -0.798437\n", - " 2.13567 2.07036 -0.63291 -2.07044 -0.308257 0.778777\n", - " -0.109189 0.42111 -0.268277 0.932572 -0.470156 1.46397\n", - " 0.643967 1.29824 -0.328954 … 0.480872 -0.358953 0.116543\n", - " -0.0951204 -0.0243449 0.85459 -0.980945 2.15267 -0.952766\n", - " -2.00296 -0.588196 -0.389323 -0.197178 -0.341192 -1.09176\n", - " 1.2586 0.832861 -1.45609 0.244501 -0.129899 -1.59832\n", - " 0.776852 -1.49437 -0.377505 -0.0556672 -0.632541 0.401117\n", - " -0.108585 0.0491398 0.340784 … -1.03658 0.0200718 0.367884\n", - " -0.92731 -1.90718 -0.558201 0.610026 2.53304 -0.111334\n", - " -0.270401 1.04571 0.902924 -0.512312 -0.0238466 -1.63743\n", - " -0.206619 -0.512932 2.3688 -0.118872 -0.233185 0.566295\n", - "\n", - "[:, :, 2, 2] =\n", - " 0.234511 -0.010806 -0.832856 … 0.43269 0.305379 0.303762\n", - " -1.95149 0.50084 0.121275 -0.720525 -0.646199 1.69683\n", - " 1.23164 -0.673325 -0.478049 0.32442 1.27672 -0.625925\n", - " -0.25276 -1.4575 0.37511 1.10781 -1.13742 -0.277507\n", - " 0.508942 -0.820347 1.11952 2.72176 -1.81007 -0.821581\n", - " 1.90962 -1.82509 0.627871 … -1.15038 -0.738956 -0.695905\n", - " -3.59668 0.117514 -0.349758 0.789587 -1.05735 1.45537\n", - " -0.633973 1.61463 0.576952 1.1558 0.371561 -1.43009\n", - " 0.936568 -1.11766 -0.336715 0.117287 -0.63424 0.482243\n", - " -0.253183 -0.12117 -1.09076 0.222656 0.0313442 -0.62774\n", - " 0.213864 0.130833 -1.76951 … 1.82745 0.557266 2.30872\n", - " 0.338235 -0.0273923 -2.06415 1.42201 0.440293 1.00772\n", - " 0.667596 -2.13915 1.10299 -1.84244 1.0915 0.130394\n", - " ⋮ ⋱ \n", - " 0.586522 1.37729 -0.22929 0.411727 0.0435822 -0.116045\n", - " -0.153603 -1.2998 1.52197 -0.56521 -0.0838249 1.18758\n", - " -0.886851 0.258616 1.97687 -0.654801 0.194263 -0.545757\n", - " 0.799342 0.151642 -1.13255 … 0.98931 0.964001 -1.27512\n", - " 0.127431 -1.55811 0.919991 1.39099 1.62688 0.295337\n", - " 0.782026 1.02843 0.59295 -0.813902 0.522271 0.163845\n", - " 0.0650917 -1.96543 0.337413 -0.777451 -0.107169 -1.35744\n", - " 0.150674 -0.388523 1.77383 -0.883289 -0.389462 0.78941\n", - " 1.09536 0.341656 0.69347 … 0.0138373 0.488107 -1.06297\n", - " -0.270047 -1.49177 -1.69148 -0.403618 -1.77107 1.62124\n", - " 0.819439 0.192678 -0.82619 -0.0491964 -1.03829 -0.854479\n", - " 0.121415 0.910955 0.345215 1.69571 0.857381 0.507913\n", - "\n", - "[:, :, 3, 2] =\n", - " 1.54546 -1.57292 -1.22172 … -0.188748 1.28291 1.16054\n", - " -0.262547 0.598369 0.247646 -0.189469 -0.083296 0.311285\n", - " -0.746231 1.13651 0.625644 -0.979754 -0.728773 -0.094224\n", - " 0.59957 -1.43058 1.04906 -0.0492022 -0.348014 -0.208658\n", - " -0.304111 -0.215611 -0.513582 0.240893 -0.281331 -0.352183\n", - " 0.424201 1.02872 -1.27367 … 1.31776 0.76941 0.606991\n", - " -1.28498 -0.421116 -1.06405 1.1623 -0.450123 1.73071\n", - " -1.50988 -0.108995 0.914812 -0.184181 -0.722843 -0.572129\n", - " 1.42696 -0.215924 0.178053 1.54265 0.148367 0.890746\n", - " 0.139446 0.554282 0.870106 -0.532146 0.193758 -0.968426\n", - " -0.0406143 0.624965 -2.45956 … 0.568808 0.808988 1.85776\n", - " 0.0362163 1.47467 1.08824 0.920676 0.529394 -0.121682\n", - " 0.575251 -1.00362 -0.895482 -0.758474 -1.35153 1.742\n", - " ⋮ ⋱ \n", - " 0.824018 1.15677 -1.28596 0.61364 -0.250242 0.675139\n", - " -1.01877 -0.241783 0.650845 -1.139 0.238043 1.51527\n", - " 1.90379 -0.86203 -0.605472 -2.78706 2.4508 -0.204887\n", - " -0.188355 -0.686378 1.00882 … -0.119225 -0.355753 -0.849154\n", - " -1.68806 1.31191 0.0933924 -1.36017 -0.0328785 -2.0318\n", - " 0.142966 0.0406816 -0.865487 -0.624338 -0.530236 -0.897478\n", - " -0.693914 0.829991 1.02462 -0.550439 0.301905 1.42072\n", - " -1.15714 -1.47882 -0.283758 -0.191736 0.993716 1.3369\n", - " -0.635416 0.0334166 -0.970027 … 0.892278 1.03937 -0.771635\n", - " -0.881207 -0.269159 0.149887 -2.51309 0.637512 -1.40859\n", - " -0.115297 -0.723799 1.74762 -2.52757 0.402234 0.141189\n", - " 0.249918 0.0963379 0.912426 -0.165192 -1.12594 0.00358097" - ] - }, - "execution_count": 15, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "x = randn(Float32, (64, 64, 3, 2))\n", - "x\n" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "id": "6e0743e2", - "metadata": {}, - "outputs": [ - { - "ename": "LoadError", - "evalue": "MethodError: no method matching Array{Float32, 4}(::Int64)\n\u001b[0mClosest candidates are:\n\u001b[0m Array{T, N}(\u001b[91m::Union{Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}} where T, Union{Base.LogicalIndex{T, <:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}}, Base.ReinterpretArray{T, N, <:Any, <:Union{SubArray{<:Any, <:Any, var\"#s14\"}, var\"#s14\"}} where var\"#s14\"<:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}, Base.ReshapedArray{T, N, <:Union{Base.ReinterpretArray{<:Any, <:Any, <:Any, <:Union{SubArray{<:Any, <:Any, var\"#s15\"}, var\"#s15\"}}, SubArray{<:Any, <:Any, var\"#s15\"}, var\"#s15\"}} where var\"#s15\"<:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}, SubArray{T, N, <:Union{Base.ReinterpretArray{<:Any, <:Any, <:Any, <:Union{SubArray{<:Any, <:Any, var\"#s16\"}, var\"#s16\"}}, Base.ReshapedArray{<:Any, <:Any, <:Union{Base.ReinterpretArray{<:Any, <:Any, <:Any, <:Union{SubArray{<:Any, <:Any, var\"#s16\"}, var\"#s16\"}}, SubArray{<:Any, <:Any, var\"#s16\"}, var\"#s16\"}}, var\"#s16\"}} where var\"#s16\"<:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}, LinearAlgebra.Adjoint{T, <:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}}, LinearAlgebra.Diagonal{T, <:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}}, LinearAlgebra.LowerTriangular{T, <:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}}, LinearAlgebra.Symmetric{T, <:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}}, LinearAlgebra.Transpose{T, <:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}}, LinearAlgebra.Tridiagonal{T, <:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}}, LinearAlgebra.UnitLowerTriangular{T, <:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}}, LinearAlgebra.UnitUpperTriangular{T, <:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}}, LinearAlgebra.UpperTriangular{T, <:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}}, PermutedDimsArray{T, N, <:Any, <:Any, <:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}}} where {T, N}}\u001b[39m) where {T, N} at C:\\Users\\Yash\\.julia\\packages\\NNlibCUDA\\kCpTE\\src\\batchedadjtrans.jl:16\n\u001b[0m Array{T, N}(\u001b[91m::BitArray{N}\u001b[39m) where {T, N} at bitarray.jl:495\n\u001b[0m Array{T, N}(\u001b[91m::FillArrays.Zeros{V, N}\u001b[39m) where {T, V, N} at C:\\Users\\Yash\\.julia\\packages\\FillArrays\\Slipo\\src\\FillArrays.jl:441\n\u001b[0m ...", - "output_type": "error", - "traceback": [ - "MethodError: no method matching Array{Float32, 4}(::Int64)\n\u001b[0mClosest candidates are:\n\u001b[0m Array{T, N}(\u001b[91m::Union{Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}} where T, Union{Base.LogicalIndex{T, <:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}}, Base.ReinterpretArray{T, N, <:Any, <:Union{SubArray{<:Any, <:Any, var\"#s14\"}, var\"#s14\"}} where var\"#s14\"<:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}, Base.ReshapedArray{T, N, <:Union{Base.ReinterpretArray{<:Any, <:Any, <:Any, <:Union{SubArray{<:Any, <:Any, var\"#s15\"}, var\"#s15\"}}, SubArray{<:Any, <:Any, var\"#s15\"}, var\"#s15\"}} where var\"#s15\"<:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}, SubArray{T, N, <:Union{Base.ReinterpretArray{<:Any, <:Any, <:Any, <:Union{SubArray{<:Any, <:Any, var\"#s16\"}, var\"#s16\"}}, Base.ReshapedArray{<:Any, <:Any, <:Union{Base.ReinterpretArray{<:Any, <:Any, <:Any, <:Union{SubArray{<:Any, <:Any, var\"#s16\"}, var\"#s16\"}}, SubArray{<:Any, <:Any, var\"#s16\"}, var\"#s16\"}}, var\"#s16\"}} where var\"#s16\"<:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}, LinearAlgebra.Adjoint{T, <:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}}, LinearAlgebra.Diagonal{T, <:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}}, LinearAlgebra.LowerTriangular{T, <:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}}, LinearAlgebra.Symmetric{T, <:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}}, LinearAlgebra.Transpose{T, <:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}}, LinearAlgebra.Tridiagonal{T, <:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}}, LinearAlgebra.UnitLowerTriangular{T, <:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}}, LinearAlgebra.UnitUpperTriangular{T, <:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}}, LinearAlgebra.UpperTriangular{T, <:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}}, PermutedDimsArray{T, N, <:Any, <:Any, <:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}}} where {T, N}}\u001b[39m) where {T, N} at C:\\Users\\Yash\\.julia\\packages\\NNlibCUDA\\kCpTE\\src\\batchedadjtrans.jl:16\n\u001b[0m Array{T, N}(\u001b[91m::BitArray{N}\u001b[39m) where {T, N} at bitarray.jl:495\n\u001b[0m Array{T, N}(\u001b[91m::FillArrays.Zeros{V, N}\u001b[39m) where {T, V, N} at C:\\Users\\Yash\\.julia\\packages\\FillArrays\\Slipo\\src\\FillArrays.jl:441\n\u001b[0m ...", - "", - "Stacktrace:", - " [1] relu(x::Array{Float32, 4})", - " @ Knet.Ops20 C:\\Users\\Yash\\.julia\\packages\\Knet\\YIFWC\\src\\ops20\\activation.jl:26", - " [2] (::ResNetLayer)(x::Array{Float32, 4})", - " @ Main .\\In[10]:4", - " [3] top-level scope", - " @ In[16]:1", - " [4] eval", - " @ .\\boot.jl:368 [inlined]", - " [5] include_string(mapexpr::typeof(REPL.softscope), mod::Module, code::String, filename::String)", - " @ Base .\\loading.jl:1428" - ] - } - ], - "source": [ - "y = l(x)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b08cda2d", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Julia 1.8.2", - "language": "julia", - "name": "julia-1.8" - }, - "language_info": { - "file_extension": ".jl", - "mimetype": "application/julia", - "name": "julia", - "version": "1.8.2" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} From 30b96f4b1274f1c25db68940823e7dc5555432a9 Mon Sep 17 00:00:00 2001 From: Yash Pokra <79229682+yashpokra@users.noreply.github.com> Date: Sat, 28 Jan 2023 22:26:43 -0500 Subject: [PATCH 21/26] Add files via upload --- ...ia_resnetmodel_updated_FINAL_version.ipynb | 1626 +++++++++++++++++ 1 file changed, 1626 insertions(+) create mode 100644 convolutional neural network/julia_resnetmodel_updated_FINAL_version.ipynb diff --git a/convolutional neural network/julia_resnetmodel_updated_FINAL_version.ipynb b/convolutional neural network/julia_resnetmodel_updated_FINAL_version.ipynb new file mode 100644 index 0000000..0ed2ca5 --- /dev/null +++ b/convolutional neural network/julia_resnetmodel_updated_FINAL_version.ipynb @@ -0,0 +1,1626 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "69f91157", + "metadata": {}, + "source": [ + "# Imports" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "9b1583d4", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "using Yota;\n", + "using MLDatasets;\n", + "using NNlib;\n", + "using Statistics;\n", + "using Distributions;\n", + "using Functors;\n", + "using Optimisers;\n", + "using MLUtils: DataLoader;\n", + "using OneHotArrays: onehotbatch\n", + "using Knet:conv4\n", + "using Metrics;\n", + "using TimerOutputs;\n", + "using Flux: BatchNorm, kaiming_uniform, nfan;\n", + "using Functors\n", + "\n", + "# Model creation\n", + "using NNlib;\n", + "using Flux: BatchNorm, Chain, GlobalMeanPool, kaiming_uniform, nfan;\n", + "using Statistics;\n", + "using Distributions;\n", + "using Functors;\n", + "\n", + "# Data processing\n", + "using MLDatasets;\n", + "using MLUtils: DataLoader;\n", + "using MLDataPattern;\n", + "using ImageCore;\n", + "using Augmentor;\n", + "using ImageFiltering;\n", + "using MappedArrays;\n", + "using Random;\n", + "using Flux: DataLoader;\n", + "# using OneHotArrays: onehotbatch\n", + "\n", + "# Training\n", + "# using Yota;\n", + "using Zygote;\n", + "using Optimisers;\n", + "using Metrics;\n", + "using TimerOutputs;\n", + "\n", + "\n", + "#using Knet: Knet, dir, accuracy, progress, sgd, gc, Data, nll, relu\n", + "\n", + "\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "id": "19aff91e", + "metadata": {}, + "source": [ + "# Conv 2D" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "481a3d9a", + "metadata": {}, + "outputs": [], + "source": [ + "mutable struct Conv2D{T}\n", + " w::AbstractArray{T, 4}\n", + " b::AbstractVector{T}\n", + " use_bias::Bool\n", + " padding::Int \n", + "end\n", + "\n", + "@functor Conv2D (w, b)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "59da1b27", + "metadata": {}, + "outputs": [], + "source": [ + "function Conv2D(kernel_size::Tuple{Int, Int}, in_channels::Int, out_channels::Int;\n", + " bias::Bool=false, padding::Int=1)\n", + " w_size = (kernel_size..., in_channels, out_channels)\n", + " w = kaiming_uniform(w_size...)\n", + " (fan_in, fan_out) = nfan(w_size)\n", + " \n", + " if bias\n", + " # Init bias with fan_in from weights. Use gain = √2 for ReLU\n", + " bound = √3 * √2 / √fan_in\n", + " rng = Uniform(-bound, bound)\n", + " b = rand(rng, out_channels, Float32)\n", + " else\n", + " b = zeros(Float32, out_channels)\n", + " end\n", + "\n", + " return Conv2D(w, b, bias, padding)\n", + "end\n", + "\n", + "function (self::Conv2D)(x::AbstractArray; stride::Int=1, pad::Int=0, dilation::Int=1)\n", + " y = conv4(self.w, x; stride=stride, padding=self.padding, dilation=dilation)\n", + " if self.use_bias\n", + " # Bias is applied channel-wise\n", + " (w, h, c, b) = size(y)\n", + " bias = reshape(self.b, (1, 1, c, 1))\n", + " y = y .+ bias\n", + " end\n", + " return y\n", + "end\n", + " " + ] + }, + { + "cell_type": "markdown", + "id": "252e934f", + "metadata": {}, + "source": [ + "# ResNetLayer" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "3e66be4f", + "metadata": {}, + "outputs": [], + "source": [ + "mutable struct ResNetLayer\n", + " conv1::Conv2D\n", + " conv2::Conv2D\n", + " bn1::BatchNorm\n", + " bn2::BatchNorm\n", + " f::Function\n", + " in_channels::Int\n", + " channels::Int\n", + " stride::Int\n", + "end\n", + "\n", + "@functor ResNetLayer (conv1, conv2, bn1, bn2)\n", + "\n", + "function residual_identity(layer::ResNetLayer, x::AbstractArray{T, 4}) where {T<:Number}\n", + " (w, h, c, b) = size(x)\n", + " stride = layer.stride\n", + " if stride > 1\n", + " @assert ((w % stride == 0) & (h % stride == 0)) \"Spatial dimensions are not divisible by `stride`\"\n", + " \n", + " # Strided downsample\n", + " x_id = copy(x[begin:2:end, begin:2:end, :, :])\n", + " else\n", + " x_id = x\n", + " end\n", + "\n", + " channels = layer.channels\n", + " in_channels = layer.in_channels\n", + " if in_channels < channels\n", + " # Zero padding on extra channels\n", + " (w, h, c, b) = size(x_id)\n", + " pad = zeros(w, h, channels - in_channels, b)\n", + " x_id = cat(x_id, pad; dims=3)\n", + " elseif in_channels > channels\n", + " error(\"in_channels > out_channels not supported\")\n", + " end\n", + " return x_id\n", + "end\n", + "\n", + "function ResNetLayer(in_channels::Int, channels::Int; stride=1, f=relu)\n", + " bn1 = BatchNorm(in_channels)\n", + " conv1 = Conv2D((3, 3), in_channels, channels, bias=false)\n", + " bn2 = BatchNorm(channels)\n", + " conv2 = Conv2D((3, 3), channels, channels, bias=false)\n", + "\n", + " return ResNetLayer(conv1, conv2, bn1, bn2, f, in_channels, channels, stride)\n", + "end\n", + "\n", + "\n", + "function (self::ResNetLayer)(x::AbstractArray)\n", + " identity = residual_identity(self, x)\n", + " z = self.bn1(x)\n", + " z = self.f(z)\n", + " z = self.conv1(z; pad=1, stride=self.stride) # pad=1 will keep same size with (3x3) kernel\n", + " z = self.bn2(z)\n", + " z = self.f(z)\n", + " z = self.conv2(z; pad=1)\n", + "\n", + " y = z + identity\n", + " return y\n", + "end" + ] + }, + { + "cell_type": "markdown", + "id": "9f06e04e", + "metadata": {}, + "source": [ + "# Testing ResNetLayer" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "7cdc72a9", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(16, 16, 10, 4)" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "\n", + "l = ResNetLayer(3, 10; stride=2);\n", + "x = randn(Float32, (32, 32, 3, 4));\n", + "y = l(x);\n", + "size(y)" + ] + }, + { + "cell_type": "markdown", + "id": "7b21b952", + "metadata": {}, + "source": [ + "# Linear Layer" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "8987f02c", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING: method definition for Linear at In[6]:22 declares type variable T but does not use it.\n" + ] + } + ], + "source": [ + "mutable struct Linear\n", + " W::AbstractMatrix{T} where T\n", + " b::AbstractVector{T} where T\n", + "end\n", + "\n", + "@functor Linear\n", + "\n", + "# Init\n", + "function Linear(in_features::Int, out_features::Int)\n", + " k_sqrt = sqrt(1 / in_features)\n", + " d = Uniform(-k_sqrt, k_sqrt)\n", + " return Linear(rand(d, out_features, in_features), rand(d, out_features))\n", + "end\n", + "Linear(in_out::Pair{Int, Int}) = Linear(in_out[1], in_out[2])\n", + "\n", + "function Base.show(io::IO, l::Linear)\n", + " o, i = size(l.W)\n", + " print(io, \"Linear(o)\")\n", + "end\n", + "\n", + "# Forward\n", + "(l::Linear)(x::AbstractArray) where T = l.W * x .+ l.b\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a386ea7a", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "02eca287", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "ResNet20Model" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# ResNet Architecture\n", + "\n", + "mutable struct ResNet20Model\n", + " input_conv::Conv2D\n", + " resnet_blocks::Chain\n", + " pool::GlobalMeanPool\n", + " linear::Linear\n", + "end\n", + "\n", + "@functor ResNet20Model\n", + "\n", + "function ResNet20Model(in_channels::Int, num_classes::Int)\n", + " resnet_blocks = Chain(\n", + " block_1 = ResNetLayer(16, 16),\n", + " block_2 = ResNetLayer(16, 16),\n", + " block_3 = ResNetLayer(16, 16),\n", + " block_4 = ResNetLayer(16, 32; stride=2),\n", + " block_5 = ResNetLayer(32, 32),\n", + " block_6 = ResNetLayer(32, 32),\n", + " block_7 = ResNetLayer(32, 64; stride=2),\n", + " block_8 = ResNetLayer(64, 64),\n", + " block_9 = ResNetLayer(64, 64)\n", + " )\n", + " return ResNet20Model(\n", + " Conv2D((3, 3), in_channels, 16, bias=false),\n", + " resnet_blocks,\n", + " GlobalMeanPool(),\n", + " Linear(64, num_classes)\n", + " )\n", + "end" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "cdef0144", + "metadata": {}, + "outputs": [], + "source": [ + "function (self::ResNet20Model)(x::AbstractArray)\n", + " z = self.input_conv(x)\n", + " z = self.resnet_blocks(z)\n", + " z = self.pool(z)\n", + " z = dropdims(z, dims=(1, 2))\n", + " y = self.linear(z)\n", + " return y\n", + "end\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "25c15eb5", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "┌ Warning: Slow fallback implementation invoked for conv! You probably don't want this; check your datatypes.\n", + "│ yT = Float64\n", + "│ T1 = Float64\n", + "│ T2 = Float32\n", + "└ @ NNlib C:\\Users\\Yash\\.julia\\packages\\NNlib\\0QnJJ\\src\\conv.jl:285\n" + ] + }, + { + "data": { + "text/plain": [ + "(10, 4)" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "\n", + "# Testing ResNet20 model\n", + "# Expected output: (10, 4)\n", + "m = ResNet20Model(3, 10);\n", + "inputs = randn(Float32, (32, 32, 3, 4))\n", + "outputs = m(inputs);\n", + "size(outputs)\n", + " " + ] + }, + { + "cell_type": "markdown", + "id": "8e43380e", + "metadata": {}, + "source": [ + "# Data Preprocessing " + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "84857fa0", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "32×32×3×45000 Array{Float32, 4}\n", + "45000-element Vector{Int64}\n", + "32×32×3×5000 Array{Float32, 4}\n", + "5000-element Vector{Int64}\n", + "32×32×3×10000 Array{Float32, 4}\n", + "10000-element Vector{Int64}\n" + ] + } + ], + "source": [ + "# This loads the CIFAR-10 Dataset for training, validation, and evaluation\n", + "xtrn,ytrn = CIFAR10.traindata(Float32, 1:45000)\n", + "xval,yval = CIFAR10.traindata(Float32, 45001:50000)\n", + "xtst,ytst = CIFAR10.testdata(Float32)\n", + "println.(summary.((xtrn,ytrn,xval, yval, xtst,ytst)));" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "45acc000", + "metadata": {}, + "outputs": [], + "source": [ + "# Normalize all the data\n", + "\n", + "means = reshape([0.485, 0.465, 0.406], (1, 1, 3, 1))\n", + "stdevs = reshape([0.229, 0.224, 0.225], (1, 1, 3, 1))\n", + "normalize(x) = (x .- means) ./ stdevs\n", + "\n", + "train_x = normalize(xtrn);\n", + "val_x = normalize(xval);\n", + "test_x = normalize(xtst);" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "9e93cda3", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "splitobs (generic function with 11 methods)" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "\n", + "# Train-test split\n", + "# Copied from https://github.com/JuliaML/MLUtils.jl/blob/v0.2.11/src/splitobs.jl#L65\n", + "# obsview doesn't work with this data, so use getobs instead\n", + "\n", + "import MLDataPattern.splitobs;\n", + "\n", + "function splitobs(data; at, shuffle::Bool=false)\n", + " if shuffle\n", + " data = shuffleobs(data)\n", + " end\n", + " n = numobs(data)\n", + " return map(idx -> MLDataPattern.getobs(data, idx), splitobs(n, at))\n", + "end" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "9c649cac", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "# Notebook testing: Use less data\n", + "train_x, train_y = MLDatasets.getobs((train_x, ytrn), 1:500);\n", + "\n", + "val_x, val_y = MLDatasets.getobs((val_x, yval), 1:50);\n", + "\n", + "test_x, test_y = MLDatasets.getobs((test_x, ytst), 1:50);" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "75266187", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(40, 40, 3, 500)" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "\n", + "# Pad the training data for further augmentation\n", + "train_x_padded = padarray(train_x, Fill(0, (4, 4, 0, 0))); \n", + "size(train_x_padded) # Should be (40, 40, 3, 50000)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "fc788d3e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "6-step Augmentor.ImmutablePipeline:\n", + " 1.) Permute dimension order to (3, 1, 2)\n", + " 2.) Combine color channels into colorant RGB\n", + " 3.) Either: (50%) Flip the X axis. (50%) No operation.\n", + " 4.) Crop random window with size (32, 32)\n", + " 5.) Split colorant into its color channels\n", + " 6.) Permute dimension order to (2, 3, 1)" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pl = PermuteDims((3, 1, 2)) |> CombineChannels(RGB) |> Either(FlipX(), NoOp()) |> RCropSize(32, 32) |> SplitChannels() |> PermuteDims((2, 3, 1))" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "815faf28", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "outbatch (generic function with 1 method)" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Create an output array for augmented images\n", + "outbatch(X) = Array{Float32}(undef, (32, 32, 3, nobs(X)))" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "2e86e8f7", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "augmentbatch (generic function with 1 method)" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Function that takes a batch (images and targets) and augments the images\n", + "augmentbatch((X, y)) = (augmentbatch!(outbatch(X), X, pl), y)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "e4d362ce", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "┌ Warning: The specified values for size and/or count will result in 4 unused data points\n", + "└ @ MLDataPattern C:\\Users\\Yash\\.julia\\packages\\MLDataPattern\\KlSmO\\src\\dataview.jl:205\n" + ] + } + ], + "source": [ + "\n", + "# Shuffled and batched dataset of augmented images\n", + "train_batch_size = 16\n", + "\n", + "train_batches = mappedarray(augmentbatch, batchview(shuffleobs((train_x_padded, train_y)), size=train_batch_size));\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "e2386c3c", + "metadata": {}, + "outputs": [], + "source": [ + "# Test and Validation data\n", + "test_batch_size = 32\n", + "\n", + "val_loader = DataLoader((val_x, val_y), shuffle=true, batchsize=test_batch_size);\n", + "test_loader = DataLoader((test_x, test_y), shuffle=true, batchsize=test_batch_size);" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": "3998a220", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "# # Create model with 3 input channels and 10 classes\n", + " model = ResNet20Model(3, 10);" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "id": "3731cc35", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "\n", + "# loss(xtst, ytst) = nll(model(xtst), ytst)\n", + "# evalcb = () -> (loss(xtst, ytst)) #function that will be called to get the loss \n", + "# const to = TimerOutput() # creating a TimerOutput, keeps track of everything\n", + "\n", + "\n", + "# @timeit to \"Train Total\" begin\n", + "# for epoch in 1:10\n", + "# train_epoch = epoch > 1 ? \"train_epoch\" : \"train_ji\"\n", + "# @timeit to train_epoch begin\n", + "# progress!(adam(model, train_batches; lr = 1e-3))\n", + "# end\n", + " \n", + "# evaluation = epoch > 1 ? \"evaluation\" : \"eval_jit\"\n", + "# @timeit to evaluation begin\n", + "# accuracy(model, test_loader)\n", + "# end \n", + " \n", + "# end \n", + "# end \n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c33ae82c", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1edf6901", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c04eb217", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c945bf07", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "19b1cfc9", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0396b9b1", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "05599606", + "metadata": {}, + "source": [ + "# Training setup" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "id": "fd7aadd5", + "metadata": {}, + "outputs": [], + "source": [ + "#Sparse Cross Entropy function" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "id": "9f6c4d38", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "sparse_logit_cross_entropy (generic function with 1 method)" + ] + }, + "execution_count": 36, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "\n", + "\"\"\"\n", + " sparse_logit_cross_entropy(logits, labels)\n", + "\n", + "Efficient computation of cross entropy loss with model logits and integer indices as labels.\n", + "Integer indices are from [0, N-1], where N is the number of classes\n", + "Similar to TensorFlow SparseCategoricalCrossEntropy\n", + "\n", + "# Arguments\n", + "- `logits::AbstractArray`: 2D model logits tensor of shape (classes, batch size)\n", + "- `labels::AbstractArray`: 1D integer label indices of shape (batch size,)\n", + "\n", + "# Returns\n", + "- `loss::Float32`: Cross entropy loss\n", + "\"\"\"\n", + "# function sparse_logit_cross_entropy(logits, labels)\n", + "# log_probs = logsoftmax(logits);\n", + "# # Select indices of labels for loss\n", + "# log_probs = map((x, i) -> x[i + 1], eachslice(log_probs; dims=2), labels);\n", + "# loss = -mean(log_probs);\n", + "# return loss\n", + "# end\n", + "\n", + "function sparse_logit_cross_entropy(logits, labels)\n", + " log_probs = logsoftmax(logits);\n", + " inds = CartesianIndex.(labels .+ 1, axes(log_probs, 2));\n", + " # Select indices of labels for loss\n", + " log_probs = log_probs[inds];\n", + " loss = -mean(log_probs);\n", + " return loss\n", + "end\n" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "id": "6fa4497b", + "metadata": {}, + "outputs": [], + "source": [ + "# Setup AdamW optimizer\n", + "β = (0.9, 0.999);\n", + "decay = 1e-4;\n", + "state = Optimisers.setup(Optimisers.Adam(1e-3, β, decay), model);" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "id": "b852506d", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "(x, y) = first(train_batches);" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "id": "e71cc12e", + "metadata": {}, + "outputs": [], + "source": [ + "# loss, g = grad(loss_function, model, x, y);" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "id": "1a9a8a89", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "loss_function (generic function with 1 method)" + ] + }, + "execution_count": 40, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "mutable struct ResNet5\n", + " input_conv::Conv2D\n", + " resnet_block::ResNetLayer\n", + " pool::GlobalMeanPool\n", + " linear::Linear\n", + "end\n", + "\n", + "@functor ResNet5\n", + "\n", + "function ResNet5(in_channels::Int, num_classes::Int)\n", + " return ResNet5(\n", + " Conv2D((3, 3), in_channels, 16, bias=false),\n", + " ResNetLayer(16, 16),\n", + " GlobalMeanPool(),\n", + " Linear(16, num_classes)\n", + " )\n", + "end\n", + "\n", + "function (self::ResNet5)(x::AbstractArray)\n", + " z = self.input_conv(x)\n", + " z = self.resnet_block(z)\n", + " z = self.pool(z)\n", + " z = dropdims(z, dims=(1, 2))\n", + " y = self.linear(z)\n", + " return y\n", + "end\n", + "\n", + "\n", + "function loss_function(model::ResNet5, x::AbstractArray, y::AbstractArray)\n", + " ŷ = model(x)\n", + " loss = sparse_logit_cross_entropy(ŷ, y)\n", + " return loss\n", + "end" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "id": "028a6d25", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "# Yota is unable to compute gradients through the ResNet for some reason, maybe due to residual connections?\n", + "# loss, g = grad(loss_function, model, x, y)\n", + "model = ResNet5(3, 10);\n", + "\n", + "loss, g = Zygote.gradient(loss_function, model, x, y);" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "id": "696231c0", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "32×32×3×16 Array{Float32, 4}:\n", + "[:, :, 1, 1] =\n", + " 2.87384f-5 5.15218f-6 -7.95893f-6 … 2.03477f-5 -2.35546f-5\n", + " 1.31753f-5 -5.36217f-6 -2.37871f-5 1.75244f-5 -9.57452f-6\n", + " 5.49028f-6 -1.71476f-5 -3.2099f-5 9.48742f-6 -1.40432f-5\n", + " 7.2699f-6 -5.92972f-6 -2.18901f-5 2.19976f-5 -1.03878f-5\n", + " -2.49811f-7 -1.22833f-6 -1.28048f-5 2.23008f-5 -8.82221f-6\n", + " -2.13457f-6 2.44618f-6 -1.34374f-5 … 2.04867f-5 -8.85707f-6\n", + " -1.58989f-6 1.10097f-6 -3.85356f-5 1.71607f-5 -8.91197f-6\n", + " 2.43009f-6 2.74427f-6 -2.31398f-5 1.43616f-5 -8.99907f-6\n", + " 8.6382f-6 4.24066f-6 -1.89015f-5 2.12419f-5 -9.01441f-6\n", + " 1.01316f-5 2.10697f-6 -2.02417f-5 2.03862f-5 -8.99482f-6\n", + " 9.46375f-6 -6.25004f-6 -1.73722f-5 … 2.14269f-5 -8.98492f-6\n", + " 9.65401f-6 -3.91408f-6 -2.25808f-5 2.14387f-5 -8.96013f-6\n", + " 9.63481f-6 -4.90097f-6 -2.71419f-5 2.27392f-5 -8.99733f-6\n", + " ⋮ ⋱ ⋮ \n", + " 6.6796f-6 -7.3569f-6 -2.57043f-5 … 2.38568f-5 -8.91103f-6\n", + " 7.58086f-7 -2.9958f-6 -1.29167f-5 2.07219f-5 -8.85528f-6\n", + " 2.3564f-6 4.17472f-6 1.37998f-7 2.16737f-5 -8.90104f-6\n", + " 8.35725f-6 7.49693f-6 -1.33849f-5 2.03694f-5 -8.93113f-6\n", + " 1.10655f-5 1.39176f-6 -3.57415f-5 2.42324f-5 -8.66875f-6\n", + " 1.06251f-5 -2.52541f-6 -3.04435f-5 … 2.01572f-5 -8.3407f-6\n", + " 8.71558f-6 -4.15225f-6 -2.62513f-5 2.06746f-5 -8.21051f-6\n", + " 4.74822f-6 1.08056f-5 -2.1629f-5 2.31422f-5 -7.38363f-6\n", + " 4.34095f-6 -6.39721f-6 -3.72078f-6 2.3523f-5 -1.0264f-5\n", + " 1.64836f-5 1.37776f-5 -1.70328f-5 2.25005f-5 -1.09304f-5\n", + " -2.97065f-6 3.41663f-7 -1.48043f-5 … 1.87838f-5 -1.2569f-5\n", + " -4.29779f-5 -3.22912f-6 -2.24622f-5 -2.00159f-5 -1.13841f-5\n", + "\n", + "[:, :, 2, 1] =\n", + " 3.51288f-6 -1.22697f-5 -2.48598f-5 … -5.24601f-5 -2.63561f-5\n", + " 5.0802f-6 -8.81436f-6 -9.8856f-6 -7.7401f-5 -4.79854f-5\n", + " 6.16604f-6 -1.53578f-5 -1.70939f-5 -6.56983f-5 -4.47286f-5\n", + " 1.24349f-5 -1.22695f-5 -2.24474f-5 -6.35044f-5 -4.26525f-5\n", + " 1.34362f-5 -1.62494f-5 -1.91734f-5 -6.25156f-5 -4.13404f-5\n", + " 1.10909f-5 -2.10735f-5 -1.57398f-5 … -6.32187f-5 -4.12555f-5\n", + " 9.65864f-6 -1.58431f-5 -2.85245f-5 -6.47562f-5 -4.12256f-5\n", + " 8.55365f-6 -1.49781f-5 1.02663f-5 -6.98997f-5 -4.12231f-5\n", + " 7.57427f-6 -1.58006f-5 5.02288f-5 -6.04549f-5 -4.13147f-5\n", + " 9.93018f-6 -7.01827f-6 1.28258f-5 -5.44541f-5 -4.14396f-5\n", + " 1.12572f-5 -1.46116f-5 -2.1672f-5 … -5.62349f-5 -4.13453f-5\n", + " 1.18546f-5 -1.30768f-5 -1.94482f-5 -5.62452f-5 -4.12333f-5\n", + " 1.18409f-5 -1.25132f-5 -2.20645f-5 -5.53337f-5 -4.11621f-5\n", + " ⋮ ⋱ ⋮ \n", + " 1.18743f-5 -1.3683f-5 -2.22972f-5 … -5.65869f-5 -4.12856f-5\n", + " 1.42515f-5 -1.64138f-5 -1.96379f-5 -5.67975f-5 -4.13787f-5\n", + " 1.24159f-5 -1.52812f-5 -1.22782f-5 -6.03015f-5 -4.13753f-5\n", + " 9.13442f-6 -9.09309f-6 -1.12662f-5 -6.61189f-5 -4.21891f-5\n", + " 1.05468f-5 -7.99576f-6 -1.33269f-5 -5.49894f-5 -4.08029f-5\n", + " 1.1644f-5 -1.0743f-5 -1.19364f-5 … -5.90684f-5 -4.122f-5\n", + " 1.14246f-5 -1.58993f-5 -2.733f-5 -6.15047f-5 -4.00522f-5\n", + " 2.33722f-5 -1.66764f-5 -8.95946f-6 -5.53262f-5 -4.31432f-5\n", + " -1.36753f-5 -3.431f-5 -1.81776f-5 -5.894f-5 -4.33847f-5\n", + " -5.28935f-6 -4.45395f-5 -4.33759f-5 -6.29569f-5 -3.67273f-5\n", + " -5.14466f-5 -5.65307f-5 -5.83267f-5 … -7.28656f-5 -3.01449f-5\n", + " -3.77745f-5 -3.05045f-5 -4.26683f-5 -6.23611f-5 -2.93418f-5\n", + "\n", + "[:, :, 3, 1] =\n", + " 1.30337f-5 1.97476f-5 -1.37888f-6 … 2.5648f-5 -4.64034f-6\n", + " 3.68331f-7 3.68014f-5 1.60665f-5 1.87886f-5 1.83326f-5\n", + " 1.01084f-5 3.88496f-5 3.575f-5 2.20447f-5 1.84873f-5\n", + " 1.87285f-6 3.03808f-5 3.2317f-5 2.7348f-5 1.9302f-5\n", + " 3.31373f-6 3.43455f-5 3.81177f-5 3.29928f-5 1.7696f-5\n", + " 4.18613f-6 4.38729f-5 2.48672f-5 … 2.99008f-5 1.75935f-5\n", + " 2.50239f-6 3.87682f-5 4.05324f-5 2.82491f-5 1.7558f-5\n", + " 3.0775f-6 4.32761f-5 3.71651f-5 3.06394f-5 1.76267f-5\n", + " 3.0736f-6 3.7843f-5 -3.7053f-5 2.99056f-5 1.76673f-5\n", + " 2.25529f-6 3.46948f-5 -1.37121f-5 3.12987f-5 1.76435f-5\n", + " 4.21507f-6 4.15411f-5 2.95277f-5 … 3.20953f-5 1.75889f-5\n", + " 3.72306f-6 3.7588f-5 1.91054f-5 3.17634f-5 1.75182f-5\n", + " 3.94906f-6 4.03008f-5 2.22407f-5 3.2732f-5 1.74733f-5\n", + " ⋮ ⋱ ⋮ \n", + " 3.294f-6 3.00897f-5 3.7794f-5 … 2.85874f-5 1.75208f-5\n", + " 2.5483f-6 3.25366f-5 2.8732f-5 3.14754f-5 1.76076f-5\n", + " 3.52557f-6 3.8136f-5 2.43104f-5 3.40014f-5 1.76464f-5\n", + " 2.90404f-6 3.00002f-5 1.83853f-5 3.81449f-5 1.71642f-5\n", + " 2.10267f-6 2.82629f-5 3.10803f-5 3.24481f-5 1.66842f-5\n", + " 3.57918f-6 3.26719f-5 2.92418f-5 … 3.31228f-5 1.64505f-5\n", + " 7.08678f-6 2.56342f-5 4.57358f-5 3.47681f-5 1.91743f-5\n", + " -3.32079f-6 2.68047f-5 2.83111f-5 3.74222f-5 1.95459f-5\n", + " 7.50175f-6 4.48023f-5 3.10073f-5 3.69629f-5 2.05496f-5\n", + " 3.77103f-5 3.62848f-5 3.53838f-5 3.04651f-5 1.50938f-5\n", + " 2.26291f-6 6.13585f-5 5.42865f-5 … 3.72573f-5 2.30365f-5\n", + " 9.99016f-6 1.11194f-5 1.05421f-5 1.7655f-5 6.95817f-6\n", + "\n", + "[:, :, 1, 2] =\n", + " -1.09797f-5 -2.13722f-5 -1.27747f-5 … -1.69448f-5 -7.88239f-6\n", + " 1.12475f-5 -2.0106f-5 -3.77094f-5 -3.23303f-5 -2.27554f-5\n", + " -5.15608f-6 -2.29666f-5 -1.57832f-5 -1.66505f-5 -2.19483f-5\n", + " -4.40688f-6 -9.89303f-6 -2.67959f-5 -6.70534f-6 -1.72258f-5\n", + " 1.08458f-6 -2.6123f-5 -5.26342f-5 -6.83994f-6 -1.43384f-5\n", + " -3.56635f-6 -1.78059f-5 -2.80402f-5 … -6.95699f-6 -1.45106f-5\n", + " -3.02137f-6 -1.73025f-5 -1.40618f-5 -2.04053f-5 -6.60865f-6\n", + " -4.13502f-7 -1.72447f-5 -1.44892f-5 -2.04278f-5 -6.64876f-6\n", + " -6.49256f-6 -9.48326f-6 -1.42205f-5 -2.04246f-5 -6.61602f-6\n", + " -1.11668f-5 -3.3754f-6 -8.46039f-6 -2.0354f-5 -6.67575f-6\n", + " 7.79242f-6 -1.16571f-5 -1.09261f-5 … -2.0452f-5 -6.61188f-6\n", + " 1.53111f-5 -1.4928f-5 2.88932f-6 -2.04422f-5 -6.60986f-6\n", + " 5.2131f-6 -1.3793f-5 -3.3462f-5 -2.0508f-5 -6.62253f-6\n", + " ⋮ ⋱ ⋮ \n", + " 1.38997f-5 -1.92094f-5 -3.57026f-5 … -1.14854f-5 -8.50423f-6\n", + " 3.52126f-5 5.34156f-6 -4.03604f-5 -2.19193f-5 -9.99198f-6\n", + " 3.7833f-5 -1.7161f-5 -5.25765f-5 -2.1176f-5 -4.21367f-6\n", + " -4.3647f-6 5.2507f-6 -5.71789f-6 -5.89364f-6 -1.47315f-5\n", + " 1.43847f-6 -3.27774f-5 4.64099f-5 -2.04356f-5 -6.57705f-6\n", + " 9.02621f-6 -6.39047f-5 2.31406f-5 … -2.04515f-5 -6.59162f-6\n", + " 1.05266f-5 -7.57128f-5 -4.46544f-5 -2.04393f-5 -6.60923f-6\n", + " 9.78515f-6 -3.16386f-5 -5.53192f-5 -2.03575f-5 -6.66031f-6\n", + " 2.875f-5 -1.83572f-5 -4.66145f-5 -1.73273f-5 -6.96841f-6\n", + " 1.6593f-5 -2.25877f-5 -2.63732f-5 -2.25418f-5 -1.45672f-6\n", + " -1.93264f-5 -4.5622f-6 -2.44407f-6 … -5.44911f-6 -5.07896f-6\n", + " -1.26567f-5 -8.28866f-6 -1.98972f-5 -1.90232f-5 -1.22861f-5\n", + "\n", + "[:, :, 2, 2] =\n", + " -4.52723f-6 -3.49323f-5 -4.12799f-5 … -1.50356f-5 -2.16989f-7\n", + " -6.78492f-5 -7.84942f-5 -8.87692f-5 -2.63529f-5 -4.13737f-6\n", + " -8.41087f-5 -3.1084f-5 -9.76829f-5 -4.07474f-5 -7.83035f-6\n", + " -7.08706f-5 -2.40582f-5 -0.000109596 -5.32373f-5 -3.9194f-6\n", + " -7.77566f-5 -6.85934f-7 -0.000118317 -5.26677f-5 -3.95299f-6\n", + " -8.1118f-5 9.22139f-7 -0.000108936 … -4.52668f-5 9.43691f-6\n", + " -8.13958f-5 -1.85499f-7 -0.000107129 -5.46121f-5 1.62983f-5\n", + " -8.20843f-5 -5.70543f-7 -0.00011982 -5.47232f-5 1.63308f-5\n", + " -8.19961f-5 -9.28187f-6 -0.000109798 -5.46595f-5 1.6309f-5\n", + " -7.93751f-5 -2.48049f-5 -0.000109075 -5.44518f-5 1.6343f-5\n", + " -7.90918f-5 -3.43573f-5 -0.000101994 … -5.43597f-5 1.63746f-5\n", + " -6.23688f-5 -2.64341f-5 -9.38569f-5 -5.44339f-5 1.63483f-5\n", + " -6.52792f-5 -2.48703f-5 -7.8946f-5 -5.44905f-5 1.6355f-5\n", + " ⋮ ⋱ ⋮ \n", + " -8.1285f-5 -4.88402f-5 -0.000100694 … -4.41222f-5 9.497f-6\n", + " -4.88119f-5 -8.28254f-5 -9.25717f-5 -5.31737f-5 1.68613f-5\n", + " -5.17702f-5 -4.68936f-5 -7.68762f-5 -5.82533f-5 1.66828f-6\n", + " -5.55461f-5 -4.18896f-5 -7.47917f-5 -4.51579f-5 9.44686f-6\n", + " -4.83162f-5 -5.65184f-5 -4.35828f-5 -5.4252f-5 1.63604f-5\n", + " -5.14842f-5 -3.70771f-5 -5.60604f-5 … -5.44706f-5 1.62527f-5\n", + " -7.58537f-5 -8.31569f-6 -0.000117182 -5.46463f-5 1.62203f-5\n", + " -8.04193f-5 1.43932f-5 -0.000156959 -5.45568f-5 1.63059f-5\n", + " -5.56057f-5 -8.65628f-6 -0.000125903 -5.44588f-5 1.16301f-5\n", + " -7.1956f-5 -1.97293f-5 -0.000133141 -5.4016f-5 1.89256f-5\n", + " -7.28343f-5 -3.14315f-5 -0.000127052 … -5.81531f-5 7.89226f-6\n", + " -6.13478f-5 -5.94695f-7 -7.32276f-5 -2.62953f-5 2.57695f-5\n", + "\n", + "[:, :, 3, 2] =\n", + " -2.19261f-5 2.7254f-5 -4.52801f-6 … -5.49369f-6 1.45297f-5\n", + " -4.46353f-5 2.62317f-5 8.04015f-6 -2.53547f-5 4.73811f-6\n", + " -3.79392f-5 -1.06354f-5 5.85113f-5 -1.7056f-5 6.20228f-6\n", + " 1.1313f-5 1.57169f-5 5.71432f-5 -3.7475f-5 5.49092f-6\n", + " 2.04558f-6 3.71181f-5 4.01594f-5 -3.35148f-5 -6.25117f-6\n", + " 4.73865f-6 1.68689f-5 4.19958f-5 … -2.51303f-5 -4.91489f-6\n", + " 3.20972f-6 1.13822f-5 3.06967f-5 -2.25985f-5 -1.27144f-5\n", + " 4.89944f-6 9.41842f-6 3.82225f-5 -2.26813f-5 -1.27046f-5\n", + " 1.26125f-5 2.34069f-5 2.76001f-5 -2.27173f-5 -1.27674f-5\n", + " 2.32491f-5 1.94585f-5 2.85936f-5 -2.26634f-5 -1.27467f-5\n", + " 4.74152f-6 2.08929f-5 3.76962f-5 … -2.26154f-5 -1.27296f-5\n", + " -2.01045f-6 1.11967f-5 3.11925f-5 -2.26469f-5 -1.2722f-5\n", + " -1.39923f-6 2.3897f-5 3.35059f-5 -2.27345f-5 -1.27035f-5\n", + " ⋮ ⋱ ⋮ \n", + " 7.09541f-6 -3.04937f-5 4.89153f-5 … -1.33056f-5 -9.03981f-6\n", + " -1.36651f-5 2.70014f-5 3.32767f-5 -3.2459f-5 -2.12493f-7\n", + " 3.17517f-5 6.80056f-6 2.52505f-6 -2.99075f-5 -1.37626f-5\n", + " 8.04813f-6 9.21358f-6 -1.15224f-5 -2.61928f-5 -4.83346f-6\n", + " -5.77461f-5 4.40058f-5 2.18061f-5 -2.26172f-5 -1.26375f-5\n", + " -3.28887f-5 1.54883f-5 2.57241f-5 … -2.26395f-5 -1.25606f-5\n", + " 1.8196f-5 7.08923f-6 -3.4964f-6 -2.27131f-5 -1.26321f-5\n", + " 2.73781f-5 3.86673f-7 2.03445f-5 -2.27442f-5 -1.27123f-5\n", + " 1.39553f-5 -5.43552f-6 4.37894f-5 -2.18561f-5 -1.45538f-5\n", + " 1.25204f-5 4.71521f-6 4.39248f-5 -1.00257f-5 -1.665f-5\n", + " 2.14967f-5 1.08339f-5 -6.27319f-7 … 6.15765f-6 -2.11569f-5\n", + " 5.82686f-5 -3.92828f-6 3.74167f-5 9.41486f-7 -1.17081f-5\n", + "\n", + "[:, :, 1, 3] =\n", + " -1.32921f-5 -2.54748f-5 -2.4276f-5 … -2.82705f-5 -1.69928f-5\n", + " -2.25105f-5 -3.48271f-5 -2.74502f-5 -1.38684f-5 -2.77036f-5\n", + " -2.05614f-5 -1.17774f-5 -4.85968f-6 1.2789f-5 -1.05062f-5\n", + " -3.98836f-5 -2.81254f-5 -6.40034f-5 -4.08217f-5 -2.0309f-5\n", + " -4.38632f-5 -2.61992f-5 -4.8767f-5 -1.90552f-6 -7.2014f-6\n", + " -2.09627f-5 -4.17273f-5 -5.62594f-5 … -1.09989f-5 -1.79831f-5\n", + " -1.28669f-5 -1.27832f-5 -3.18246f-5 -4.09247f-5 -1.90839f-5\n", + " -3.97809f-5 -4.64297f-5 -6.28011f-5 -7.87657f-6 -1.12435f-5\n", + " -4.04715f-5 -4.66917f-5 -4.45681f-5 -4.18667f-5 3.13917f-6\n", + " -2.64794f-5 -7.84953f-5 -7.23706f-5 -2.23594f-5 -7.51302f-6\n", + " -3.37457f-5 -1.42143f-5 -4.43477f-5 … -7.00164f-7 -2.95561f-5\n", + " -2.87837f-5 -2.55903f-5 -2.86363f-5 -2.50456f-5 -3.61103f-5\n", + " -7.72938f-6 -5.0325f-6 -6.46772f-5 -6.57527f-5 -1.66607f-5\n", + " ⋮ ⋱ ⋮ \n", + " -3.11348f-5 -4.19208f-5 -6.37247f-5 … -2.89499f-5 -4.03736f-6\n", + " -3.92671f-5 -3.46014f-5 -3.81206f-5 -2.31593f-5 -1.69803f-5\n", + " -3.53422f-5 -2.82261f-5 -4.3099f-5 -2.38775f-5 -1.47377f-5\n", + " -4.09482f-5 -3.56281f-5 -2.9176f-5 -1.49256f-5 -2.23803f-5\n", + " -8.15614f-6 -1.66638f-5 -4.15551f-5 -9.67728f-6 -1.75268f-5\n", + " -2.87176f-5 -1.85177f-5 -1.01355f-5 … -2.92795f-5 -3.84081f-6\n", + " -3.27056f-5 -4.02847f-5 -5.10919f-5 -2.17138f-5 -1.25814f-5\n", + " -2.75769f-5 -2.67168f-5 -2.62755f-5 -2.40851f-5 -1.13193f-5\n", + " -1.80984f-5 -3.46694f-5 -3.13359f-5 -3.40728f-5 -1.15171f-5\n", + " -4.0983f-5 -1.22354f-5 -2.21499f-5 -3.92156f-5 2.48212f-6\n", + " -2.17541f-5 -1.4721f-5 -3.61195f-5 … -1.8059f-5 -2.34319f-5\n", + " -1.73993f-5 -8.16249f-6 -2.49898f-5 -6.6299f-8 -3.32326f-6\n", + "\n", + "[:, :, 2, 3] =\n", + " 1.37565f-5 5.39358f-5 4.58624f-5 … 4.44795f-5 2.50945f-5\n", + " -1.05192f-5 6.36829f-5 4.8567f-5 6.02761f-5 1.23682f-5\n", + " 8.05044f-6 4.69759f-5 2.89093f-5 1.65648f-5 1.09588f-5\n", + " 2.33734f-6 0.000106804 3.09279f-5 -3.14127f-5 1.23341f-5\n", + " 9.22226f-7 0.000113907 2.53979f-5 4.82065f-6 7.19267f-6\n", + " -5.75003f-6 0.000115487 6.12588f-5 … -9.67223f-6 -3.66514f-6\n", + " -7.21526f-6 0.000118417 4.04848f-5 -1.31087f-5 4.85808f-6\n", + " -1.83487f-5 0.000144694 3.57431f-5 2.71545f-6 1.97464f-6\n", + " -2.31197f-6 0.000116426 4.23696f-5 -1.99311f-6 1.24911f-5\n", + " 1.33507f-5 0.000142583 5.01268f-5 8.75845f-6 4.12369f-5\n", + " 2.56636f-5 0.000180252 5.40545f-5 … 0.000105484 2.48439f-5\n", + " -3.61895f-6 0.000149931 4.13484f-5 4.68084f-5 4.66717f-5\n", + " 2.01801f-5 8.81076f-5 -2.2904f-5 -1.71351f-5 -4.21337f-6\n", + " ⋮ ⋱ ⋮ \n", + " -1.24302f-5 0.00013385 9.00135f-5 … -1.33804f-5 5.75934f-6\n", + " 1.25677f-5 0.000127111 8.09964f-5 -1.75825f-5 9.55642f-6\n", + " -1.91218f-6 9.21105f-5 9.24472f-5 -2.37863f-5 9.99772f-6\n", + " 1.86928f-5 0.000107805 9.78562f-5 -9.6838f-6 -4.90789f-6\n", + " 1.90321f-6 9.47662f-5 6.81009f-5 -1.21807f-5 -6.32489f-7\n", + " 4.47338f-6 8.07572f-5 3.97976f-5 … -1.41697f-5 1.19191f-5\n", + " 1.87808f-5 5.76884f-5 5.41935f-5 -8.75647f-6 8.57532f-6\n", + " 3.17319f-5 4.42051f-5 2.78475f-5 -1.52921f-5 9.61939f-6\n", + " 9.52952f-6 5.40576f-5 5.0719f-5 -1.71491f-5 9.88385f-6\n", + " 5.17164f-7 5.02872f-5 3.96059f-5 -1.22204f-5 1.37712f-5\n", + " 1.50076f-5 8.3402f-5 5.2399f-5 … 6.42688f-6 -1.64415f-5\n", + " 3.14011f-6 4.67719f-5 4.23395f-5 2.95754f-5 3.52819f-6\n", + "\n", + "[:, :, 3, 3] =\n", + " -2.13358f-5 -3.7937f-5 -2.88963f-5 … -5.82107f-5 -3.07175f-5\n", + " -8.75597f-6 -8.63936f-6 -1.4399f-6 -4.39748f-5 -3.46493f-5\n", + " -3.99843f-5 -7.18937f-5 -5.48815f-5 2.70835f-5 -2.11729f-5\n", + " -3.38074f-5 -8.02526f-5 -3.76416f-5 -1.26173f-5 -3.77554f-5\n", + " -5.74915f-6 -5.5656f-5 -3.4937f-5 -2.98646f-5 -3.76138f-5\n", + " -8.11329f-6 -2.84418f-5 -3.05806f-5 … -2.07718f-5 -4.41476f-5\n", + " -2.09927f-5 -5.92586f-5 -3.75812f-5 -1.33266f-5 -3.69712f-5\n", + " -2.36342f-5 -6.4137f-5 -2.07103f-6 -1.34164f-5 -4.23239f-5\n", + " -1.78956f-5 -6.06292f-5 -2.3291f-5 -4.24331f-6 -3.4804f-5\n", + " -2.15475f-5 -2.42674f-5 -6.6006f-5 -2.98182f-5 -6.30099f-5\n", + " -8.82353f-6 -4.50255f-5 -3.5869f-5 … -6.11666f-5 -3.6939f-5\n", + " -2.73957f-5 -5.63472f-5 -1.24705f-5 -2.32421f-5 -3.37618f-5\n", + " -1.91092f-5 -0.000121495 -4.93165f-5 -2.39933f-5 -3.47519f-5\n", + " ⋮ ⋱ ⋮ \n", + " -1.63307f-5 -5.76098f-5 -1.52878f-5 … -2.98358f-5 -3.83354f-5\n", + " -3.66333f-6 -5.44591f-5 -6.64642f-5 -2.32273f-5 -4.34126f-5\n", + " -1.08237f-5 -3.67816f-5 -7.95501f-5 -3.41968f-5 -3.89323f-5\n", + " 4.25022f-6 -4.04467f-5 -2.91673f-5 -6.07295f-6 -3.00576f-5\n", + " -2.61168f-5 -2.23294f-5 -5.08112f-5 -1.30184f-5 -4.02096f-5\n", + " -2.06026f-5 -3.47827f-5 -5.70607f-5 … -1.61879f-5 -4.02096f-5\n", + " -1.69183f-5 -3.22452f-5 -2.34737f-5 -1.40456f-5 -4.36295f-5\n", + " -1.28868f-5 -2.36515f-5 -4.06212f-5 -1.34111f-5 -4.07483f-5\n", + " -3.36046f-5 -1.2702f-5 -6.03335f-6 -1.00821f-5 -3.58138f-5\n", + " -1.11375f-5 -5.46132f-5 -6.48593f-6 -1.6174f-5 -3.2f-5\n", + " 1.48887f-5 -2.30397f-5 -1.40682f-5 … -8.4154f-6 -3.42971f-5\n", + " 2.61805f-5 -9.19203f-7 -5.72442f-7 -3.74935f-6 -2.72021f-6\n", + "\n", + ";;;; … \n", + "\n", + "[:, :, 1, 14] =\n", + " 1.33663f-5 4.85302f-6 3.58623f-6 … 1.37095f-5 1.52744f-6\n", + " 2.40446f-5 -2.13888f-5 -7.09837f-7 -2.65502f-5 -3.84853f-5\n", + " 2.15673f-5 -3.83408f-5 -6.04328f-6 -2.47308f-5 -6.90114f-5\n", + " 1.92326f-5 -3.71648f-5 -2.57633f-5 8.74128f-6 -5.50297f-5\n", + " 1.71903f-5 -3.32595f-5 -3.98588f-5 1.39541f-5 -5.28194f-5\n", + " 1.4936f-5 -3.24147f-5 -4.14944f-5 … 1.85891f-5 -4.86048f-5\n", + " 1.73485f-5 -2.66423f-5 -2.27061f-5 1.55611f-5 -4.97301f-5\n", + " 1.35476f-5 -3.83446f-5 -1.91897f-5 1.55515f-5 -4.66568f-5\n", + " 1.39347f-5 -3.14124f-5 -2.70663f-5 1.66001f-5 -4.76427f-5\n", + " 2.03841f-5 -2.95635f-5 -4.12947f-5 1.51247f-5 -4.52537f-5\n", + " 1.9246f-5 -4.71404f-5 -3.9704f-5 … 1.44092f-5 -4.75653f-5\n", + " 1.02947f-5 -2.77938f-5 -1.65234f-5 9.53047f-6 -4.31208f-5\n", + " 1.7194f-5 -2.49142f-5 -2.85847f-5 2.09133f-5 -5.40761f-5\n", + " ⋮ ⋱ ⋮ \n", + " 1.87105f-5 -5.13546f-5 -1.63782f-5 … 1.17234f-5 -4.45439f-5\n", + " 1.84893f-5 -4.52386f-5 -3.94779f-5 -2.03202f-5 -5.12953f-5\n", + " 1.86344f-5 -4.50125f-5 -4.02963f-5 6.36818f-6 -4.96217f-5\n", + " 1.55367f-5 -3.01873f-5 -4.11503f-5 1.33834f-5 -5.56446f-5\n", + " 1.6725f-5 -3.16873f-5 -3.58039f-5 2.98642f-5 -5.41134f-5\n", + " 1.91699f-5 -4.34588f-5 -5.96472f-6 … 3.36201f-5 -6.05377f-5\n", + " 1.02027f-5 -3.68617f-5 -1.26946f-6 1.37002f-5 -5.93067f-5\n", + " 2.02941f-5 -3.33365f-5 -2.79833f-5 -1.27102f-7 -4.2245f-5\n", + " 1.90906f-5 -4.93408f-5 -1.02181f-5 4.24493f-6 -6.76327f-5\n", + " 1.35392f-5 -3.17202f-5 -1.05651f-5 1.37745f-5 -6.95663f-5\n", + " 3.22179f-5 -3.63395f-5 -3.44376f-5 … -1.8031f-5 -3.67907f-5\n", + " 3.1825f-5 -2.07306f-5 -1.293f-5 -1.55749f-5 -5.47155f-5\n", + "\n", + "[:, :, 2, 14] =\n", + " -1.16588f-5 -1.2407f-5 3.49407f-6 … -2.00775f-5 2.20019f-5\n", + " -1.22088f-5 -8.70201f-6 7.22299f-6 -2.42942f-5 3.22809f-5\n", + " -4.46188f-6 1.12563f-6 1.55854f-6 1.82184f-6 2.75427f-5\n", + " -6.50865f-6 6.71314f-6 -4.40675f-6 -1.15001f-5 1.37051f-5\n", + " 1.46834f-6 1.8095f-6 -1.09453f-5 -1.93772f-5 1.18843f-5\n", + " 9.96785f-6 -6.14118f-6 -1.91112f-5 … -2.1082f-5 2.99534f-6\n", + " 7.32546f-6 -1.10439f-5 -1.05676f-5 -1.90141f-5 6.74762f-6\n", + " 4.39465f-7 -8.66755f-7 -3.26735f-5 -1.68382f-5 5.62134f-6\n", + " 1.84562f-6 8.40483f-6 -3.41861f-5 -1.86092f-5 4.87877f-6\n", + " 3.29915f-6 -6.79852f-7 -3.56272f-5 -1.23352f-5 7.23447f-6\n", + " -7.65528f-7 -3.32969f-6 -2.83803f-5 … -1.67057f-5 1.49106f-6\n", + " 3.31928f-6 -5.36953f-6 -2.76984f-5 -1.32463f-5 1.17819f-5\n", + " 8.95364f-6 -1.72866f-6 -3.75768f-5 1.09688f-5 -4.38009f-6\n", + " ⋮ ⋱ ⋮ \n", + " -3.17671f-6 5.2434f-6 -2.69256f-5 … -4.84416f-5 8.25841f-6\n", + " -3.05886f-6 1.01041f-6 -2.00724f-5 -2.82864f-5 -3.48747f-6\n", + " -3.11412f-6 -2.6824f-6 -1.84471f-5 -2.03048f-5 -8.71729f-6\n", + " 2.70974f-6 -4.42211f-6 -1.84807f-5 -1.75151f-5 -2.35158f-5\n", + " 5.19128f-6 -2.57658f-6 -2.41242f-5 2.26607f-5 2.58383f-6\n", + " -4.43079f-7 9.24148f-6 -2.55504f-5 … -4.00863f-5 -2.20372f-7\n", + " 3.4482f-6 9.72595f-6 -3.16414f-5 -3.03052f-5 -2.44575f-5\n", + " 3.71146f-6 3.85173f-6 -2.39893f-5 -2.33657f-5 1.92735f-5\n", + " -4.2958f-7 6.57322f-6 -9.47067f-6 -3.18343f-5 1.15899f-5\n", + " 3.97834f-7 -3.9832f-6 -2.29879f-5 -2.98401f-5 1.83165f-5\n", + " -2.48082f-6 2.33363f-6 -1.80932f-5 … -3.67383f-5 3.11761f-5\n", + " 1.40496f-5 -1.21956f-5 -1.12533f-5 -3.16762f-5 -1.81151f-5\n", + "\n", + "[:, :, 3, 14] =\n", + " -7.93633f-6 -1.74683f-5 -3.30643f-5 … -2.41837f-6 -2.35818f-5\n", + " -2.14967f-5 -3.45701f-5 -4.38043f-5 -3.09205f-6 -2.01295f-5\n", + " -2.20267f-5 -3.53623f-5 -3.24018f-5 6.1502f-6 -1.65125f-5\n", + " -2.08004f-5 -4.26441f-5 -2.3583f-5 -2.71736f-5 -1.68349f-5\n", + " -1.79565f-5 -4.85526f-5 -4.17504f-5 -2.07986f-5 -1.3357f-5\n", + " -2.09896f-5 -2.60134f-5 -5.04995f-5 … -1.76324f-5 -7.67686f-6\n", + " -3.54343f-5 -3.14013f-5 -2.514f-5 -2.07742f-5 -1.19488f-5\n", + " -2.83928f-5 -4.10369f-5 -2.33926f-5 -2.41132f-5 -1.1116f-5\n", + " -1.46165f-5 -4.19793f-5 -2.49181f-5 -2.58115f-5 -1.18515f-5\n", + " -2.30708f-5 -4.50573f-5 -3.26967f-5 -2.11465f-5 -1.01586f-5\n", + " -3.29882f-5 -3.43335f-5 -1.62386f-5 … -2.59596f-5 -7.33063f-6\n", + " -1.49393f-5 -4.06295f-5 -2.20788f-5 -2.70018f-5 -1.27446f-5\n", + " -2.14708f-5 -2.84157f-5 -4.1274f-5 -2.99866f-5 -5.0091f-6\n", + " ⋮ ⋱ ⋮ \n", + " -2.12108f-5 -5.19599f-5 -3.05368f-5 … -2.42368f-5 -1.09666f-5\n", + " -2.10179f-5 -4.74168f-5 -4.27474f-5 -1.5578f-5 -1.5984f-5\n", + " -2.09912f-5 -3.35144f-5 -3.88128f-5 -2.7594f-5 3.39654f-6\n", + " -1.95054f-5 -3.31035f-5 -3.50551f-5 -4.12634f-5 -2.56047f-5\n", + " -2.33043f-5 -3.33993f-5 -3.51916f-5 -1.40093f-5 -1.22362f-5\n", + " -3.31995f-5 -3.87969f-5 -2.05606f-5 … -2.27915f-5 -1.53536f-5\n", + " -1.52208f-5 -5.35259f-5 -2.78927f-5 -2.51014f-5 -2.51453f-5\n", + " -2.30069f-5 -4.30843f-5 -5.04652f-5 -2.78615f-5 1.22018f-6\n", + " -3.28321f-5 -3.75078f-5 -2.48635f-5 -2.51772f-5 -1.9633f-5\n", + " -7.88154f-6 -4.40925f-5 -3.61323f-5 -2.70516f-5 -1.97826f-5\n", + " -1.74243f-5 -3.12033f-5 -4.37996f-5 … -2.22341f-5 -3.2528f-5\n", + " -3.9618f-5 -5.10782f-5 -3.54032f-5 -7.39695f-6 -1.27495f-7\n", + "\n", + "[:, :, 1, 15] =\n", + " -1.42322f-5 -8.28306f-7 -1.66416f-5 … -2.11629f-5 -8.51041f-6\n", + " 1.78201f-5 -2.03631f-5 -2.93015f-5 -3.91225f-5 -2.38064f-5\n", + " -4.01629f-6 7.20289f-6 -5.25714f-6 -1.72928f-5 -3.32815f-5\n", + " -2.82395f-5 -4.71707f-6 -3.74868f-5 -3.02409f-5 -1.64734f-5\n", + " -1.81109f-5 5.98791f-6 -5.90359f-5 -1.7286f-5 -2.49253f-5\n", + " 8.58902f-6 -1.84229f-5 -5.12326f-5 … -3.34777f-5 1.37166f-5\n", + " 4.36534f-5 -1.47507f-5 -3.43016f-5 -2.38659f-5 1.95873f-5\n", + " 3.05164f-5 -1.30216f-5 -2.95457f-5 -2.2512f-5 -1.54837f-5\n", + " 3.12878f-5 -2.15927f-5 -3.26094f-5 -1.72527f-5 -6.8008f-6\n", + " 3.13024f-5 -2.1157f-5 -3.26002f-5 -1.25f-5 -7.81954f-6\n", + " 3.44746f-5 -2.64019f-5 -2.90338f-5 … -8.88944f-6 -1.10185f-5\n", + " 3.56305f-5 -2.51382f-5 -3.36172f-5 -1.70083f-5 -5.94835f-6\n", + " 3.22221f-5 -1.83961f-5 -3.63357f-5 -1.24371f-5 -7.30951f-6\n", + " ⋮ ⋱ ⋮ \n", + " 3.40669f-5 -2.65137f-5 -2.84164f-5 … -1.8529f-5 -8.00771f-6\n", + " 3.23159f-5 -1.99016f-5 -3.71337f-5 -1.29374f-5 -2.77934f-6\n", + " 3.06859f-5 -1.97791f-5 -3.23031f-5 -1.18352f-5 -5.00627f-6\n", + " 2.92809f-5 -2.24281f-5 -2.87513f-5 -1.18677f-5 -1.03392f-5\n", + " 3.13974f-5 -2.08298f-5 -3.25597f-5 -1.30022f-5 -1.88982f-5\n", + " 3.4333f-5 -2.63916f-5 -2.87311f-5 … -1.94221f-5 -9.26061f-6\n", + " 3.27389f-5 -2.00057f-5 -3.68491f-5 -1.92571f-5 -5.71103f-6\n", + " 3.07853f-5 -1.95269f-5 -3.19953f-5 -1.89164f-5 -6.1808f-6\n", + " 3.26149f-5 -2.96727f-5 -2.83632f-6 -1.86268f-5 -3.14514f-6\n", + " 4.09495f-5 -8.29647f-6 -2.99314f-5 -1.44101f-5 -7.22131f-6\n", + " 3.36188f-5 -5.28598f-6 2.68256f-6 … -1.77605f-5 -1.72127f-5\n", + " 1.5718f-5 -1.06357f-5 -1.58782f-5 -5.37788f-6 -4.31066f-5\n", + "\n", + "[:, :, 2, 15] =\n", + " -2.71481f-5 -6.23731f-5 -6.94312f-5 … -3.79918f-5 -7.59295f-6\n", + " -5.76539f-5 -6.44929f-5 -6.63022f-5 -4.24491f-5 -3.27987f-6\n", + " -3.08224f-5 -4.92059f-5 -5.37293f-5 -5.99305f-5 -9.06002f-6\n", + " -6.15207f-5 -5.25847f-5 -5.51519f-5 -5.93541f-5 -2.0077f-5\n", + " -8.6014f-5 -6.91141f-5 -0.000100434 -9.90005f-5 -5.61515f-5\n", + " -5.77686f-5 -0.000104908 -4.87899f-5 … -8.39651f-5 -5.26676f-5\n", + " -6.25512f-5 -0.00010837 -2.85305f-5 -8.89585f-5 -4.53344f-6\n", + " -5.81532f-5 -0.00010201 -2.09666f-5 -8.76519f-5 -1.36613f-5\n", + " -6.27479f-5 -0.000103551 -1.84921f-5 -9.89555f-5 -2.4265f-5\n", + " -6.32125f-5 -0.000103278 -1.85082f-5 -9.37496f-5 -2.54385f-5\n", + " -6.8004f-5 -0.000106113 -1.69805f-5 … -9.61866f-5 -2.37284f-5\n", + " -6.86537f-5 -9.7697f-5 -2.13105f-5 -0.00010507 -1.76843f-5\n", + " -6.46231f-5 -8.94775f-5 -3.06477f-5 -9.72222f-5 -1.85753f-5\n", + " ⋮ ⋱ ⋮ \n", + " -6.71139f-5 -0.000105918 -1.67125f-5 … -9.68319f-5 -2.71956f-5\n", + " -6.31445f-5 -9.47925f-5 -2.28639f-5 -9.50865f-5 -2.68261f-5\n", + " -6.359f-5 -9.87085f-5 -2.61428f-5 -0.000101427 -1.90983f-5\n", + " -6.16764f-5 -0.000107163 -1.4755f-5 -9.67539f-5 -2.50946f-5\n", + " -5.63963f-5 -0.000103415 -2.06108f-5 -8.91569f-5 -3.287f-5\n", + " -6.76495f-5 -0.000106304 -1.62743f-5 … -9.48742f-5 -3.3864f-5\n", + " -6.35575f-5 -9.48656f-5 -2.22142f-5 -0.000101478 -2.79286f-5\n", + " -6.35552f-5 -9.81299f-5 -2.57596f-5 -0.000101187 -2.81493f-5\n", + " -5.26933f-5 -0.000102399 -1.71688f-5 -9.89127f-5 -2.84281f-5\n", + " -5.17102f-5 -7.33952f-5 -3.16825f-5 -0.000103105 -1.59417f-5\n", + " -3.87775f-5 -5.88389f-5 -4.60914f-5 … -8.24853f-5 -6.29939f-6\n", + " -1.01194f-5 -5.58116f-5 1.2544f-8 -2.73959f-5 7.38215f-7\n", + "\n", + "[:, :, 3, 15] =\n", + " -1.05581f-5 4.04572f-6 4.48011f-6 … 2.63798f-6 1.36589f-6\n", + " -1.10731f-5 -1.23928f-6 1.43986f-6 -3.17876f-5 -4.85235f-6\n", + " -3.0389f-5 -1.88746f-5 -1.09316f-6 -2.05259f-5 -6.97609f-6\n", + " -2.31538f-5 -1.87011f-5 4.27785f-5 5.31586f-6 -1.19822f-7\n", + " 1.77212f-5 -3.64338f-5 4.40577f-6 2.44596f-5 2.9193f-5\n", + " -5.06462f-6 -3.26252f-5 -4.55707f-6 … -3.38217f-6 1.32613f-5\n", + " -2.9407f-5 4.73535f-5 1.84007f-5 1.7393f-5 -7.39461f-6\n", + " -3.36971f-6 3.14314f-6 2.49574f-5 2.26891f-5 1.02206f-5\n", + " -2.95847f-6 1.58656f-5 1.15972f-5 2.19764f-5 8.8516f-6\n", + " -2.94252f-6 1.59466f-5 1.12964f-5 2.21957f-5 4.75332f-6\n", + " -4.56594f-6 1.67592f-5 1.0183f-5 … 2.72823f-5 4.26673f-6\n", + " -3.58368f-7 9.34662f-6 1.08674f-5 2.61219f-5 8.15499f-6\n", + " 3.7385f-6 1.83577f-5 1.69644f-5 2.75185f-5 6.39195f-6\n", + " ⋮ ⋱ ⋮ \n", + " -5.14102f-6 1.59719f-5 9.30539f-6 … 2.49045f-5 5.96053f-6\n", + " 1.24572f-6 7.70415f-6 1.15018f-5 2.81032f-5 8.23876f-6\n", + " -7.85542f-8 2.54156f-5 1.60599f-5 3.06293f-5 7.06368f-6\n", + " -6.55128f-6 1.78671f-5 5.77923f-6 1.16846f-5 -2.76677f-7\n", + " -3.26391f-6 9.23154f-6 1.38038f-5 1.70896f-5 6.30555f-6\n", + " -4.88428f-6 1.65113f-5 1.02148f-5 … 2.52434f-5 5.1654f-6\n", + " 7.36047f-7 8.26562f-6 1.18475f-5 2.72033f-5 4.53531f-6\n", + " -2.71816f-8 2.55809f-5 1.5987f-5 2.37181f-5 5.42823f-6\n", + " -5.75722f-6 3.17648f-5 -1.50827f-5 2.2849f-5 6.46165f-6\n", + " 5.26474f-6 2.6f-5 1.6355f-5 2.14678f-5 -6.35829f-6\n", + " 7.64364f-6 1.2984f-5 1.15768f-5 … 9.00361f-6 -2.41165f-5\n", + " 5.46411f-6 2.37213f-5 1.45142f-5 1.40133f-5 8.73353f-6\n", + "\n", + "[:, :, 1, 16] =\n", + " -5.03802f-6 -1.53795f-6 4.60553f-6 … 8.59933f-6 1.34548f-5\n", + " -1.33139f-5 8.2283f-6 4.72714f-6 1.69092f-5 4.31501f-6\n", + " -1.87007f-6 1.15159f-5 5.04875f-6 5.16017f-6 8.04131f-6\n", + " 6.35418f-8 3.08693f-6 -1.29598f-5 8.22963f-6 2.8338f-6\n", + " -2.16983f-5 -2.4079f-5 7.10067f-6 1.87707f-5 4.5435f-6\n", + " -1.25426f-5 -1.95699f-6 -3.54366f-7 … 3.89216f-6 4.52989f-6\n", + " -1.91067f-5 1.63425f-5 -2.13026f-5 9.29225f-6 1.57625f-6\n", + " -1.27482f-5 -6.22103f-6 -1.65056f-5 1.1202f-5 1.11683f-5\n", + " -2.21989f-5 -3.97914f-6 -2.48552f-5 3.4693f-6 6.54992f-6\n", + " -1.66821f-5 3.08557f-6 -2.25274f-5 2.99258f-6 1.19086f-5\n", + " -1.35066f-5 -3.56337f-5 -3.92915f-5 … -6.58634f-7 4.13911f-6\n", + " 2.01907f-7 -4.15685f-6 -2.32959f-5 -3.8229f-6 6.58547f-6\n", + " -2.55842f-5 -4.38035f-6 -7.13148f-6 -9.832f-7 3.90915f-6\n", + " ⋮ ⋱ ⋮ \n", + " -2.38357f-5 -1.98589f-5 -2.21745f-5 … 9.79263f-6 -5.56729f-7\n", + " -2.90955f-5 -2.92635f-6 -2.00649f-5 6.53204f-7 4.46353f-6\n", + " -2.28595f-5 -1.98966f-5 -8.38325f-6 1.59205f-5 4.65142f-6\n", + " -2.96641f-5 -2.09721f-5 -1.28431f-5 4.71605f-6 2.15453f-6\n", + " -2.20343f-5 -3.17625f-5 -5.10912f-5 6.94432f-6 1.09819f-5\n", + " -2.3516f-6 1.31662f-6 -1.44183f-5 … -4.05642f-6 -2.73961f-6\n", + " -2.16968f-5 -6.0583f-6 -9.5029f-6 1.47617f-5 1.67887f-6\n", + " -3.17614f-6 -2.08875f-5 2.43521f-5 -9.33056f-6 8.64669f-6\n", + " -2.03552f-5 9.42903f-6 2.71075f-6 9.54452f-6 2.22076f-5\n", + " 5.75616f-6 1.0806f-5 -2.36865f-7 4.67347f-6 1.61059f-5\n", + " -1.40951f-5 -1.02704f-5 -3.18635f-6 … -8.39052f-7 2.45332f-5\n", + " -3.1905f-6 -3.46778f-6 -3.62573f-6 1.71537f-5 2.06471f-5\n", + "\n", + "[:, :, 2, 16] =\n", + " -2.15884f-5 -1.14227f-5 -1.62082f-5 … -2.55764f-5 -3.13291f-5\n", + " -2.23083f-5 1.90108f-5 1.10183f-5 -5.42675f-7 -1.1911f-5\n", + " -9.12809f-6 3.98662f-5 2.59647f-6 -3.74951f-5 -8.69153f-6\n", + " -7.62137f-6 3.32652f-5 1.61093f-5 -2.59258f-5 7.50499f-6\n", + " -3.47015f-6 2.78749f-7 1.91968f-5 -2.03829f-5 1.30214f-5\n", + " -4.30571f-6 1.6286f-5 3.87096f-5 … -1.3211f-5 1.14148f-5\n", + " 1.43065f-7 3.21748f-5 4.9644f-5 -1.63238f-5 8.8627f-6\n", + " 4.03356f-6 4.056f-5 3.15086f-5 -1.25394f-5 2.2954f-6\n", + " -1.13654f-5 7.34695f-5 3.3689f-5 -1.45838f-5 1.1383f-7\n", + " -2.92628f-5 5.7042f-5 -2.2081f-6 3.45306f-6 1.79821f-6\n", + " -1.09181f-5 4.72039f-5 1.92767f-5 … 2.18812f-6 -6.04413f-6\n", + " -1.69237f-5 1.27184f-5 1.6065f-5 5.88449f-6 -1.0836f-5\n", + " -8.81984f-6 2.0075f-5 4.56131f-5 -8.64202f-6 -1.25129f-5\n", + " ⋮ ⋱ ⋮ \n", + " 1.20612f-5 5.55669f-5 6.90552f-5 … -2.94486f-5 -1.32609f-5\n", + " 2.02f-5 4.96709f-5 6.90785f-5 -1.91592f-5 4.8409f-6\n", + " 1.83125f-5 6.40737f-5 6.18068f-5 -2.0052f-5 1.90681f-6\n", + " 7.91827f-6 6.21826f-5 6.86201f-5 -2.40519f-5 7.56757f-6\n", + " 4.21825f-7 4.81752f-5 4.54214f-5 -1.54302f-5 5.53415f-6\n", + " 2.0481f-6 3.59647f-5 -6.68709f-6 … -8.75948f-6 5.73184f-6\n", + " 2.2825f-6 4.04816f-5 1.54492f-5 1.44027f-5 5.88194f-6\n", + " -2.17447f-5 1.23614f-8 -5.07233f-7 -9.76107f-6 -5.2898f-6\n", + " -4.04921f-6 3.28196f-5 1.09735f-5 -2.13296f-5 -1.0723f-5\n", + " -6.53629f-7 9.64771f-6 2.88722f-5 -5.34081f-6 1.22531f-5\n", + " -3.75205f-6 5.36633f-6 2.31161f-5 … -9.85085f-7 1.42955f-5\n", + " 5.52764f-6 2.27731f-5 3.81831f-5 1.51786f-5 1.40957f-5\n", + "\n", + "[:, :, 3, 16] =\n", + " 2.09887f-5 2.48268f-5 3.26262f-5 … 2.52067f-5 1.83206f-5\n", + " 1.60499f-5 3.12218f-5 3.18235f-5 1.53867f-5 2.27519f-5\n", + " 3.48141f-5 9.79703f-6 2.88586f-6 1.14747f-5 1.88519f-5\n", + " 8.71227f-7 2.75015f-5 4.37508f-5 1.78847f-5 1.87048f-5\n", + " 2.18388f-5 4.08884f-5 3.65291f-5 1.0776f-5 2.12948f-5\n", + " 3.40946f-5 4.08884f-5 1.94329f-5 … 1.0377f-5 2.05071f-5\n", + " 1.92069f-5 1.97018f-5 3.90101f-5 1.92215f-5 1.76457f-5\n", + " 2.36644f-5 9.86611f-6 2.17295f-5 2.36401f-5 1.73641f-5\n", + " 2.31846f-5 2.88264f-5 3.39537f-5 2.51549f-5 1.47432f-5\n", + " 2.57393f-5 -9.20106f-6 3.30965f-5 4.29673f-5 1.3416f-5\n", + " 1.70486f-5 1.40648f-5 1.81989f-5 … 3.695f-5 1.24074f-5\n", + " 2.94645f-5 4.57801f-5 2.73179f-5 3.59657f-5 8.28307f-6\n", + " 2.56688f-5 1.4644f-5 2.24626f-5 2.24926f-5 8.88082f-6\n", + " ⋮ ⋱ ⋮ \n", + " 2.03935f-5 3.42854f-5 5.86772f-6 … 4.35592f-7 1.05793f-5\n", + " 1.92946f-5 2.94853f-5 3.70969f-5 1.82321f-5 1.76721f-5\n", + " 1.20686f-5 3.67779f-5 2.70968f-5 6.38749f-6 1.87694f-5\n", + " 1.40735f-5 3.77764f-6 2.13491f-5 1.19133f-5 1.27518f-5\n", + " 2.57298f-5 2.21913f-5 1.81671f-5 1.72985f-5 1.27492f-5\n", + " 1.5189f-5 2.35697f-5 1.81276f-5 … 1.98066f-5 7.02275f-6\n", + " 2.92245f-6 4.0379f-6 3.79888f-5 2.83694f-5 1.09155f-5\n", + " 2.72736f-5 8.72931f-6 -5.25945f-6 2.87468f-5 2.20531f-5\n", + " 2.43792f-5 1.75247f-5 6.85327f-6 -2.04517f-5 -5.39688f-6\n", + " 7.97139f-6 9.4859f-6 -1.00873f-6 2.4968f-5 1.06572f-5\n", + " 9.65145f-6 1.21162f-5 1.81417f-8 … 9.58578f-6 6.21811f-6\n", + " -1.35314f-5 -1.89171f-5 -2.95601f-5 -2.69976f-5 -1.33066f-5" + ] + }, + "execution_count": 42, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "g" + ] + }, + { + "cell_type": "markdown", + "id": "a57b3c8d", + "metadata": {}, + "source": [ + "# Evaluation Function" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "id": "02f69609", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "evaluate (generic function with 1 method)" + ] + }, + "execution_count": 43, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "\n", + "function evaluate(model, test_loader)\n", + " preds = []\n", + " targets = []\n", + " for (x, y) in test_loader\n", + " # Get model predictions\n", + " # Note argmax of nd-array gives CartesianIndex\n", + " # Need to grab the first element of each CartesianIndex to get the true index\n", + " logits = model(x)\n", + " ŷ = map(i -> i[1], argmax(logits, dims=1))\n", + " append!(preds, ŷ)\n", + "\n", + " # Get true labels\n", + " append!(targets, y)\n", + " end\n", + " accuracy = sum(preds .== targets) / length(targets)\n", + " return accuracy\n", + "end" + ] + }, + { + "cell_type": "markdown", + "id": "f2072bd8", + "metadata": {}, + "source": [ + "# Training Loop" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "id": "cc39bcab", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING: redefinition of constant to. This may fail, cause incorrect answers, or produce other errors.\n" + ] + }, + { + "data": { + "text/plain": [ + "\u001b[0m\u001b[1m ────────────────────────────────────────────────────────────────────\u001b[22m\n", + "\u001b[0m\u001b[1m \u001b[22m Time Allocations \n", + " ─────────────────────── ────────────────────────\n", + " Tot / % measured: 1.35ms / 0.0% 13.7KiB / 0.0% \n", + "\n", + " Section ncalls time %tot avg alloc %tot avg\n", + " ────────────────────────────────────────────────────────────────────\n", + "\u001b[0m\u001b[1m ────────────────────────────────────────────────────────────────────\u001b[22m" + ] + }, + "execution_count": 44, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "\n", + "# Setup timing output\n", + "const to = TimerOutput()" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "id": "9b5c088c", + "metadata": {}, + "outputs": [ + { + "ename": "LoadError", + "evalue": "No derivative rule found for op %1174 = lastindex(%1172)::Int64 , try defining it using \n\n\tChainRulesCore.rrule(::typeof(lastindex), ::NTuple{4, Int64}) = ...\n", + "output_type": "error", + "traceback": [ + "No derivative rule found for op %1174 = lastindex(%1172)::Int64 , try defining it using \n\n\tChainRulesCore.rrule(::typeof(lastindex), ::NTuple{4, Int64}) = ...\n", + "", + "Stacktrace:", + " [1] error(s::String)", + " @ Base .\\error.jl:35", + " [2] step_back!(tape::Umlaut.Tape{Yota.GradCtx}, y::Umlaut.Variable)", + " @ Yota C:\\Users\\Yash\\.julia\\packages\\Yota\\KJQ6n\\src\\grad.jl:219", + " [3] back!(tape::Umlaut.Tape{Yota.GradCtx}; seed::Int64)", + " @ Yota C:\\Users\\Yash\\.julia\\packages\\Yota\\KJQ6n\\src\\grad.jl:260", + " [4] gradtape!(tape::Umlaut.Tape{Yota.GradCtx}; seed::Int64)", + " @ Yota C:\\Users\\Yash\\.julia\\packages\\Yota\\KJQ6n\\src\\grad.jl:273", + " [5] gradtape(::Function, ::ResNet5, ::Vararg{Any}; ctx::Yota.GradCtx, seed::Int64)", + " @ Yota C:\\Users\\Yash\\.julia\\packages\\Yota\\KJQ6n\\src\\grad.jl:300", + " [6] grad(::Function, ::ResNet5, ::Vararg{Any}; seed::Int64)", + " @ Yota C:\\Users\\Yash\\.julia\\packages\\Yota\\KJQ6n\\src\\grad.jl:370", + " [7] grad(::Function, ::ResNet5, ::Vararg{Any})", + " @ Yota C:\\Users\\Yash\\.julia\\packages\\Yota\\KJQ6n\\src\\grad.jl:362", + " [8] macro expansion", + " @ .\\In[45]:14 [inlined]", + " [9] macro expansion", + " @ C:\\Users\\Yash\\.julia\\packages\\TimerOutputs\\4yHI4\\src\\TimerOutput.jl:237 [inlined]", + " [10] macro expansion", + " @ .\\In[45]:9 [inlined]", + " [11] top-level scope", + " @ C:\\Users\\Yash\\.julia\\packages\\TimerOutputs\\4yHI4\\src\\TimerOutput.jl:237 [inlined]", + " [12] top-level scope", + " @ .\\In[45]:0", + " [13] eval", + " @ .\\boot.jl:368 [inlined]", + " [14] include_string(mapexpr::typeof(REPL.softscope), mod::Module, code::String, filename::String)", + " @ Base .\\loading.jl:1428" + ] + } + ], + "source": [ + "last_loss = 0;\n", + "@timeit to \"total_training_time\" begin\n", + " for epoch in 1:10\n", + " timing_name = epoch > 1 ? \"average_epoch_training_time\" : \"train_jit\"\n", + "\n", + " # Create lazily evaluated augmented training data\n", + " train_batches = mappedarray(augmentbatch, batchview(shuffleobs((train_x_padded, train_y)), size=train_batch_size));\n", + "\n", + " @timeit to timing_name begin\n", + " losses = []\n", + " for (x, y) in train_batches\n", + " # loss_function does forward pass\n", + " # Yota.jl grad function computes model parameter gradients in g[2]\n", + " loss, g = grad(loss_function, model, x, y)\n", + " \n", + " # Optimiser updates parameters\n", + " Optimisers.update!(state, model, g[2])\n", + " push!(losses, loss)\n", + " end\n", + " last_loss = mean(losses)\n", + " @info(\"epoch (mean(losses))\")\n", + " end\n", + " # timing_name = epoch > 1 ? \"average_inference_time\" : \"eval_jit\"\n", + " # @timeit to timing_name begin\n", + " # acc = evaluate(model, test_loader)\n", + " # @info(\"epoch (acc)\")\n", + " # end\n", + " end\n", + "end" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1955c486", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9ace272d", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Julia 1.8.2", + "language": "julia", + "name": "julia-1.8" + }, + "language_info": { + "file_extension": ".jl", + "mimetype": "application/julia", + "name": "julia", + "version": "1.8.2" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From eb03d862c37772a9a6fdf52661f606396aa92ddf Mon Sep 17 00:00:00 2001 From: Yash Pokra <79229682+yashpokra@users.noreply.github.com> Date: Sat, 28 Jan 2023 22:27:00 -0500 Subject: [PATCH 22/26] Delete julia_resnetmodel_updated.ipynb --- .../julia_resnetmodel_updated.ipynb | 400 ------------------ 1 file changed, 400 deletions(-) delete mode 100644 convolutional neural network/julia_resnetmodel_updated.ipynb diff --git a/convolutional neural network/julia_resnetmodel_updated.ipynb b/convolutional neural network/julia_resnetmodel_updated.ipynb deleted file mode 100644 index e356372..0000000 --- a/convolutional neural network/julia_resnetmodel_updated.ipynb +++ /dev/null @@ -1,400 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "69f91157", - "metadata": {}, - "source": [ - "# Imports" - ] - }, - { - "cell_type": "code", - "execution_count": 28, - "id": "9b1583d4", - "metadata": {}, - "outputs": [], - "source": [ - "using Yota;\n", - "using MLDatasets;\n", - "using NNlib;\n", - "using Statistics;\n", - "using Distributions;\n", - "using Functors;\n", - "using Optimisers;\n", - "using MLUtils: DataLoader;\n", - "using OneHotArrays: onehotbatch\n", - "using Knet:conv4\n", - "using Metrics;\n", - "using TimerOutputs;\n", - "using Flux: BatchNorm, kaiming_uniform, nfan;\n", - "using Functors\n", - "\n", - "# Model creation\n", - "using NNlib;\n", - "using Flux: BatchNorm, Chain, GlobalMeanPool, kaiming_uniform, nfan;\n", - "using Statistics;\n", - "using Distributions;\n", - "using Functors;\n" - ] - }, - { - "cell_type": "markdown", - "id": "19aff91e", - "metadata": {}, - "source": [ - "# Conv 2D" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "id": "481a3d9a", - "metadata": {}, - "outputs": [], - "source": [ - "mutable struct Conv2D{T}\n", - " w::AbstractArray{T, 4}\n", - " b::AbstractVector{T}\n", - " use_bias::Bool\n", - "end\n", - "\n", - "@functor Conv2D (w, b)" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "id": "59da1b27", - "metadata": {}, - "outputs": [], - "source": [ - "function Conv2D(kernel_size::Tuple{Int, Int}, in_channels::Int, out_channels::Int;\n", - " bias::Bool=false)\n", - " w_size = (kernel_size..., in_channels, out_channels)\n", - " w = kaiming_uniform(w_size...)\n", - " (fan_in, fan_out) = nfan(w_size)\n", - " \n", - " if bias\n", - " # Init bias with fan_in from weights. Use gain = √2 for ReLU\n", - " bound = √3 * √2 / √fan_in\n", - " rng = Uniform(-bound, bound)\n", - " b = rand(rng, out_channels, Float32)\n", - " else\n", - " b = zeros(Float32, out_channels)\n", - " end\n", - "\n", - " return Conv2D(w, b, bias)\n", - "end\n", - "\n", - "function (self::Conv2D)(x::AbstractArray; stride::Int=1, pad::Int=0, dilation::Int=1)\n", - " y = conv4(self.w, x; stride=stride, padding=pad, dilation=dilation)\n", - " if self.use_bias\n", - " # Bias is applied channel-wise\n", - " (w, h, c, b) = size(y)\n", - " bias = reshape(self.b, (1, 1, c, 1))\n", - " y = y .+ bias\n", - " end\n", - " return y\n", - "end\n", - " " - ] - }, - { - "cell_type": "markdown", - "id": "252e934f", - "metadata": {}, - "source": [ - "# ResNetLayer" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "id": "3e66be4f", - "metadata": {}, - "outputs": [], - "source": [ - "mutable struct ResNetLayer\n", - " conv1::Conv2D\n", - " conv2::Conv2D\n", - " bn1::BatchNorm\n", - " bn2::BatchNorm\n", - " f::Function\n", - " in_channels::Int\n", - " channels::Int\n", - " stride::Int\n", - "end\n", - "\n", - "@functor ResNetLayer (conv1, conv2, bn1, bn2)\n", - "\n", - "function residual_identity(layer::ResNetLayer, x::AbstractArray{T, 4}) where {T<:Number}\n", - " (w, h, c, b) = size(x)\n", - " stride = layer.stride\n", - " if stride > 1\n", - " @assert ((w % stride == 0) & (h % stride == 0)) \"Spatial dimensions are not divisible by `stride`\"\n", - " \n", - " # Strided downsample\n", - " x_id = copy(x[begin:2:end, begin:2:end, :, :])\n", - " else\n", - " x_id = x\n", - " end\n", - "\n", - " channels = layer.channels\n", - " in_channels = layer.in_channels\n", - " if in_channels < channels\n", - " # Zero padding on extra channels\n", - " (w, h, c, b) = size(x_id)\n", - " pad = zeros(w, h, channels - in_channels, b)\n", - " x_id = cat(x_id, pad; dims=3)\n", - " elseif in_channels > channels\n", - " error(\"in_channels > out_channels not supported\")\n", - " end\n", - " return x_id\n", - "end\n", - "\n", - "function ResNetLayer(in_channels::Int, channels::Int; stride=1, f=relu)\n", - " bn1 = BatchNorm(in_channels)\n", - " conv1 = Conv2D((3, 3), in_channels, channels, bias=false)\n", - " bn2 = BatchNorm(channels)\n", - " conv2 = Conv2D((3, 3), channels, channels, bias=false)\n", - "\n", - " return ResNetLayer(conv1, conv2, bn1, bn2, f, in_channels, channels, stride)\n", - "end\n", - "\n", - "\n", - "function (self::ResNetLayer)(x::AbstractArray)\n", - " identity = residual_identity(self, x)\n", - " z = self.bn1(x)\n", - " z = self.f(z)\n", - " z = self.conv1(z; pad=1, stride=self.stride) # pad=1 will keep same size with (3x3) kernel\n", - " z = self.bn2(z)\n", - " z = self.f(z)\n", - " z = self.conv2(z; pad=1)\n", - "\n", - " y = z + identity\n", - " return y\n", - "end" - ] - }, - { - "cell_type": "markdown", - "id": "9f06e04e", - "metadata": {}, - "source": [ - "# Testing ResNetLayer" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "id": "7cdc72a9", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(16, 16, 10, 4)" - ] - }, - "execution_count": 23, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "\n", - "l = ResNetLayer(3, 10; stride=2);\n", - "x = randn(Float32, (32, 32, 3, 4));\n", - "y = l(x);\n", - "size(y)" - ] - }, - { - "cell_type": "markdown", - "id": "7b21b952", - "metadata": {}, - "source": [ - "# Linear Layer" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "id": "8987f02c", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING: method definition for Linear at In[24]:22 declares type variable T but does not use it.\n" - ] - } - ], - "source": [ - "mutable struct Linear\n", - " W::AbstractMatrix{T} where T\n", - " b::AbstractVector{T} where T\n", - "end\n", - "\n", - "@functor Linear\n", - "\n", - "# Init\n", - "function Linear(in_features::Int, out_features::Int)\n", - " k_sqrt = sqrt(1 / in_features)\n", - " d = Uniform(-k_sqrt, k_sqrt)\n", - " return Linear(rand(d, out_features, in_features), rand(d, out_features))\n", - "end\n", - "Linear(in_out::Pair{Int, Int}) = Linear(in_out[1], in_out[2])\n", - "\n", - "function Base.show(io::IO, l::Linear)\n", - " o, i = size(l.W)\n", - " print(io, \"Linear(o)\")\n", - "end\n", - "\n", - "# Forward\n", - "(l::Linear)(x::AbstractArray) where T = l.W * x .+ l.b\n" - ] - }, - { - "cell_type": "code", - "execution_count": 38, - "id": "02eca287", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "ResNet20Model" - ] - }, - "execution_count": 38, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# ResNet Architecture\n", - "\n", - "mutable struct ResNet20Model\n", - " input_conv::Conv2D\n", - " resnet_blocks::Chain\n", - " pool::GlobalMeanPool\n", - " linear::Linear\n", - "end\n", - "\n", - "@functor ResNet20Model\n", - "\n", - "function ResNet20Model(in_channels::Int, num_classes::Int)\n", - " resnet_blocks = Chain(\n", - " block_1 = ResNetLayer(16, 16),\n", - " block_2 = ResNetLayer(16, 16),\n", - " block_3 = ResNetLayer(16, 16),\n", - " block_4 = ResNetLayer(16, 32; stride=2),\n", - " block_5 = ResNetLayer(32, 32),\n", - " block_6 = ResNetLayer(32, 32),\n", - " block_7 = ResNetLayer(32, 64; stride=2),\n", - " block_8 = ResNetLayer(64, 64),\n", - " block_9 = ResNetLayer(64, 64)\n", - " )\n", - " return ResNet20Model(\n", - " Conv2D((3, 3), in_channels, 16, bias=false),\n", - " resnet_blocks,\n", - " GlobalMeanPool(),\n", - " Linear(64, num_classes)\n", - " )\n", - "end" - ] - }, - { - "cell_type": "code", - "execution_count": 39, - "id": "cdef0144", - "metadata": {}, - "outputs": [], - "source": [ - "function (self::ResNet20Model)(x::AbstractArray)\n", - " z = self.input_conv(x)\n", - " z = self.resnet_blocks(z)\n", - " z = self.pool(z)\n", - " z = dropdims(z, dims=(1, 2))\n", - " y = self.linear(z)\n", - " return y\n", - "end\n", - " " - ] - }, - { - "cell_type": "code", - "execution_count": 40, - "id": "25c15eb5", - "metadata": {}, - "outputs": [ - { - "ename": "LoadError", - "evalue": "AssertionError: Spatial dimensions are not divisible by `stride`", - "output_type": "error", - "traceback": [ - "AssertionError: Spatial dimensions are not divisible by `stride`", - "", - "Stacktrace:", - " [1] residual_identity(layer::ResNetLayer, x::Array{Float64, 4})", - " @ Main .\\In[22]:18", - " [2] (::ResNetLayer)(x::Array{Float64, 4})", - " @ Main .\\In[22]:50", - " [3] macro expansion", - " @ C:\\Users\\Yash\\.julia\\packages\\Flux\\4k0Ls\\src\\layers\\basic.jl:53 [inlined]", - " [4] _applychain(layers::NTuple{9, ResNetLayer}, x::Array{Float32, 4})", - " @ Flux C:\\Users\\Yash\\.julia\\packages\\Flux\\4k0Ls\\src\\layers\\basic.jl:53", - " [5] _applychain", - " @ C:\\Users\\Yash\\.julia\\packages\\Flux\\4k0Ls\\src\\layers\\basic.jl:59 [inlined]", - " [6] (::Chain{NamedTuple{(:block_1, :block_2, :block_3, :block_4, :block_5, :block_6, :block_7, :block_8, :block_9), NTuple{9, ResNetLayer}}})(x::Array{Float32, 4})", - " @ Flux C:\\Users\\Yash\\.julia\\packages\\Flux\\4k0Ls\\src\\layers\\basic.jl:51", - " [7] (::ResNet20Model)(x::Array{Float32, 4})", - " @ Main .\\In[39]:3", - " [8] top-level scope", - " @ In[40]:6", - " [9] eval", - " @ .\\boot.jl:368 [inlined]", - " [10] include_string(mapexpr::typeof(REPL.softscope), mod::Module, code::String, filename::String)", - " @ Base .\\loading.jl:1428" - ] - } - ], - "source": [ - "\n", - "# Testing ResNet20 model\n", - "# Expected output: (10, 4)\n", - "m = ResNet20Model(3, 10);\n", - "inputs = randn(Float32, (32, 32, 3, 4))\n", - "outputs = m(inputs);\n", - "size(outputs)\n", - " " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "df6a846b", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Julia 1.8.2", - "language": "julia", - "name": "julia-1.8" - }, - "language_info": { - "file_extension": ".jl", - "mimetype": "application/julia", - "name": "julia", - "version": "1.8.2" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} From c276c411d9e6676087adcf93acab5cf441148eea Mon Sep 17 00:00:00 2001 From: Yash Pokra <79229682+yashpokra@users.noreply.github.com> Date: Thu, 9 Feb 2023 01:43:29 -0500 Subject: [PATCH 23/26] Add files via upload --- julia_resnetmodel_updated_model v2.ipynb | 1036 ++++++++++++++++++++++ 1 file changed, 1036 insertions(+) create mode 100644 julia_resnetmodel_updated_model v2.ipynb diff --git a/julia_resnetmodel_updated_model v2.ipynb b/julia_resnetmodel_updated_model v2.ipynb new file mode 100644 index 0000000..68d134d --- /dev/null +++ b/julia_resnetmodel_updated_model v2.ipynb @@ -0,0 +1,1036 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "69f91157", + "metadata": {}, + "source": [ + "# Imports" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "9b1583d4", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "using Yota;\n", + "using MLDatasets;\n", + "using NNlib;\n", + "using Statistics;\n", + "using Distributions;\n", + "using Functors;\n", + "using Optimisers;\n", + "using MLUtils: DataLoader;\n", + "using OneHotArrays: onehotbatch\n", + "using Knet:Knet,conv4, adam\n", + "using Knet: dir, accuracy, progress, sgd, gc\n", + "using Metrics;\n", + "using TimerOutputs;\n", + "using Flux: BatchNorm, kaiming_uniform, nfan;\n", + "using Functors\n", + "\n", + "# Model creation\n", + "using NNlib;\n", + "using Flux: BatchNorm, Chain, GlobalMeanPool, kaiming_uniform, nfan;\n", + "using Statistics;\n", + "using Distributions;\n", + "using Functors;\n", + "\n", + "# Data processing\n", + "using MLDatasets;\n", + "using MLUtils: DataLoader;\n", + "using MLDataPattern;\n", + "using ImageCore;\n", + "using Augmentor;\n", + "using ImageFiltering;\n", + "using MappedArrays;\n", + "using Random;\n", + "using Flux: DataLoader;\n", + "# using OneHotArrays: onehotbatch\n", + "\n", + "\n", + "# Training\n", + "# using Yota;\n", + "using Zygote;\n", + "using Optimisers;\n", + "using Metrics;\n", + "using TimerOutputs;\n", + "\n", + "\n", + "\n", + "# Issue when running this\n", + "#using Knet: Knet, dir, accuracy, progress, sgd, gc, Data, nll, relu\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "id": "19aff91e", + "metadata": {}, + "source": [ + "# Conv 2D" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "481a3d9a", + "metadata": {}, + "outputs": [], + "source": [ + "mutable struct Conv2D{T}\n", + " w::AbstractArray{T, 4}\n", + " b::AbstractVector{T}\n", + " use_bias::Bool\n", + " padding::Int \n", + "end\n", + "\n", + "@functor Conv2D (w, b)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "59da1b27", + "metadata": {}, + "outputs": [], + "source": [ + "function Conv2D(kernel_size::Tuple{Int, Int}, in_channels::Int, out_channels::Int;\n", + " bias::Bool=false, padding::Int=1)\n", + " w_size = (kernel_size..., in_channels, out_channels)\n", + " w = kaiming_uniform(w_size...)\n", + " (fan_in, fan_out) = nfan(w_size)\n", + " \n", + " if bias\n", + " # Init bias with fan_in from weights. Use gain = √2 for ReLU\n", + " bound = √3 * √2 / √fan_in\n", + " rng = Uniform(-bound, bound)\n", + " b = rand(rng, out_channels, Float32)\n", + " else\n", + " b = zeros(Float32, out_channels)\n", + " end\n", + "\n", + " return Conv2D(w, b, bias, padding)\n", + "end\n", + "\n", + "function (self::Conv2D)(x::AbstractArray; stride::Int=1, pad::Int=0, dilation::Int=1)\n", + " y = conv4(self.w, x; stride=stride, padding=self.padding, dilation=dilation)\n", + " if self.use_bias\n", + " # Bias is applied channel-wise\n", + " (w, h, c, b) = size(y)\n", + " bias = reshape(self.b, (1, 1, c, 1))\n", + " y = y .+ bias\n", + " end\n", + " return y\n", + "end\n", + " " + ] + }, + { + "cell_type": "markdown", + "id": "252e934f", + "metadata": {}, + "source": [ + "# ResNetLayer" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "3e66be4f", + "metadata": {}, + "outputs": [], + "source": [ + "mutable struct ResNetLayer\n", + " conv1::Conv2D\n", + " conv2::Conv2D\n", + " bn1::BatchNorm\n", + " bn2::BatchNorm\n", + " f::Function\n", + " in_channels::Int\n", + " channels::Int\n", + " stride::Int\n", + "end\n", + "\n", + "@functor ResNetLayer (conv1, conv2, bn1, bn2)\n", + "\n", + "function residual_identity(layer::ResNetLayer, x::AbstractArray{T, 4}) where {T<:Number}\n", + " (w, h, c, b) = size(x)\n", + " stride = layer.stride\n", + " if stride > 1\n", + " @assert ((w % stride == 0) & (h % stride == 0)) \"Spatial dimensions are not divisible by `stride`\"\n", + " \n", + " # Strided downsample\n", + " x_id = copy(x[begin:2:end, begin:2:end, :, :])\n", + " else\n", + " x_id = x\n", + " end\n", + "\n", + " channels = layer.channels\n", + " in_channels = layer.in_channels\n", + " if in_channels < channels\n", + " # Zero padding on extra channels\n", + " (w, h, c, b) = size(x_id)\n", + " pad = zeros(w, h, channels - in_channels, b)\n", + " x_id = cat(x_id, pad; dims=3)\n", + " elseif in_channels > channels\n", + " error(\"in_channels > out_channels not supported\")\n", + " end\n", + " return x_id\n", + "end\n", + "\n", + "function ResNetLayer(in_channels::Int, channels::Int; stride=1, f=relu)\n", + " bn1 = BatchNorm(in_channels)\n", + " conv1 = Conv2D((3, 3), in_channels, channels, bias=false)\n", + " bn2 = BatchNorm(channels)\n", + " conv2 = Conv2D((3, 3), channels, channels, bias=false)\n", + "\n", + " return ResNetLayer(conv1, conv2, bn1, bn2, f, in_channels, channels, stride)\n", + "end\n", + "\n", + "\n", + "function (self::ResNetLayer)(x::AbstractArray)\n", + " identity = residual_identity(self, x)\n", + " z = self.bn1(x)\n", + " z = self.f(z)\n", + " z = self.conv1(z; pad=1, stride=self.stride) # pad=1 will keep same size with (3x3) kernel\n", + " z = self.bn2(z)\n", + " z = self.f(z)\n", + " z = self.conv2(z; pad=1)\n", + "\n", + " y = z + identity\n", + " return y\n", + "end" + ] + }, + { + "cell_type": "markdown", + "id": "9f06e04e", + "metadata": {}, + "source": [ + "# Testing ResNetLayer" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "7cdc72a9", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(16, 16, 10, 4)" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "\n", + "l = ResNetLayer(3, 10; stride=2);\n", + "x = randn(Float32, (32, 32, 3, 4));\n", + "y = l(x);\n", + "size(y)" + ] + }, + { + "cell_type": "markdown", + "id": "7b21b952", + "metadata": {}, + "source": [ + "# Linear Layer" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "8987f02c", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING: method definition for Linear at In[6]:22 declares type variable T but does not use it.\n" + ] + } + ], + "source": [ + "mutable struct Linear\n", + " W::AbstractMatrix{T} where T\n", + " b::AbstractVector{T} where T\n", + "end\n", + "\n", + "@functor Linear\n", + "\n", + "# Init\n", + "function Linear(in_features::Int, out_features::Int)\n", + " k_sqrt = sqrt(1 / in_features)\n", + " d = Uniform(-k_sqrt, k_sqrt)\n", + " return Linear(rand(d, out_features, in_features), rand(d, out_features))\n", + "end\n", + "Linear(in_out::Pair{Int, Int}) = Linear(in_out[1], in_out[2])\n", + "\n", + "function Base.show(io::IO, l::Linear)\n", + " o, i = size(l.W)\n", + " print(io, \"Linear(o)\")\n", + "end\n", + "\n", + "# Forward\n", + "(l::Linear)(x::AbstractArray) where T = l.W * x .+ l.b\n" + ] + }, + { + "cell_type": "markdown", + "id": "79e8c6ca", + "metadata": {}, + "source": [ + "# Defining a Chain Layer" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "a47a2eaa", + "metadata": {}, + "outputs": [], + "source": [ + "# Define a chain of layers and a loss function:\n", + "struct Chain1; layers; end\n", + "(c::Chain1)(x) = (for l in c.layers; x = l(x); end; x)\n", + "(c::Chain1)(x,y) = nll(c(x),y)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "02eca287", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "ResNet20Model" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# ResNet Architecture\n", + "\n", + "mutable struct ResNet20Model\n", + " input_conv::Conv2D\n", + " resnet_blocks::Chain1\n", + " pool::GlobalMeanPool\n", + " linear::Linear\n", + "end\n", + "\n", + "@functor ResNet20Model\n", + "\n", + "function ResNet20Model(in_channels::Int, num_classes::Int)\n", + " resnet_blocks = Chain1((\n", + " block_1 = ResNetLayer(16, 16),\n", + " block_2 = ResNetLayer(16, 16),\n", + " block_3 = ResNetLayer(16, 16),\n", + " block_4 = ResNetLayer(16, 32; stride=2),\n", + " block_5 = ResNetLayer(32, 32),\n", + " block_6 = ResNetLayer(32, 32),\n", + " block_7 = ResNetLayer(32, 64; stride=2),\n", + " block_8 = ResNetLayer(64, 64),\n", + " block_9 = ResNetLayer(64, 64)\n", + " ))\n", + " return ResNet20Model(\n", + " Conv2D((3, 3), in_channels, 16, bias=false),\n", + " resnet_blocks,\n", + " GlobalMeanPool(),\n", + " Linear(64, num_classes)\n", + " )\n", + "end" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "cdef0144", + "metadata": {}, + "outputs": [], + "source": [ + "function (self::ResNet20Model)(x::AbstractArray)\n", + " z = self.input_conv(x)\n", + " z = self.resnet_blocks(z)\n", + " z = self.pool(z)\n", + " z = dropdims(z, dims=(1, 2))\n", + " y = self.linear(z)\n", + " return y\n", + "end\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "25c15eb5", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "┌ Warning: Slow fallback implementation invoked for conv! You probably don't want this; check your datatypes.\n", + "│ yT = Float64\n", + "│ T1 = Float64\n", + "│ T2 = Float32\n", + "└ @ NNlib C:\\Users\\Yash\\.julia\\packages\\NNlib\\0QnJJ\\src\\conv.jl:285\n" + ] + }, + { + "data": { + "text/plain": [ + "(10, 4)" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "\n", + "# Testing ResNet20 model\n", + "# Expected output: (10, 4)\n", + "m = ResNet20Model(3, 10);\n", + "inputs = randn(Float32, (32, 32, 3, 4))\n", + "outputs = m(inputs);\n", + "size(outputs)\n", + " " + ] + }, + { + "cell_type": "markdown", + "id": "8e43380e", + "metadata": {}, + "source": [ + "# Data Preprocessing " + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "84857fa0", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "32×32×3×45000 Array{Float32, 4}\n", + "45000-element Vector{Int64}\n", + "32×32×3×5000 Array{Float32, 4}\n", + "5000-element Vector{Int64}\n", + "32×32×3×10000 Array{Float32, 4}\n", + "10000-element Vector{Int64}\n" + ] + } + ], + "source": [ + "# This loads the CIFAR-10 Dataset for training, validation, and evaluation\n", + "xtrn,ytrn = CIFAR10.traindata(Float32, 1:45000)\n", + "xval,yval = CIFAR10.traindata(Float32, 45001:50000)\n", + "xtst,ytst = CIFAR10.testdata(Float32)\n", + "println.(summary.((xtrn,ytrn,xval, yval, xtst,ytst)));" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "45acc000", + "metadata": {}, + "outputs": [], + "source": [ + "# Normalize all the data\n", + "\n", + "means = reshape([0.485, 0.465, 0.406], (1, 1, 3, 1))\n", + "stdevs = reshape([0.229, 0.224, 0.225], (1, 1, 3, 1))\n", + "normalize(x) = (x .- means) ./ stdevs\n", + "\n", + "train_x = normalize(xtrn);\n", + "val_x = normalize(xval);\n", + "test_x = normalize(xtst);" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "9e93cda3", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "splitobs (generic function with 11 methods)" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "\n", + "# Train-test split\n", + "# Copied from https://github.com/JuliaML/MLUtils.jl/blob/v0.2.11/src/splitobs.jl#L65\n", + "# obsview doesn't work with this data, so use getobs instead\n", + "\n", + "import MLDataPattern.splitobs;\n", + "\n", + "function splitobs(data; at, shuffle::Bool=false)\n", + " if shuffle\n", + " data = shuffleobs(data)\n", + " end\n", + " n = numobs(data)\n", + " return map(idx -> MLDataPattern.getobs(data, idx), splitobs(n, at))\n", + "end" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "9c649cac", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "# Notebook testing: Use less data\n", + "train_x, train_y = MLDatasets.getobs((train_x, ytrn), 1:500);\n", + "\n", + "val_x, val_y = MLDatasets.getobs((val_x, yval), 1:50);\n", + "\n", + "test_x, test_y = MLDatasets.getobs((test_x, ytst), 1:50);" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "75266187", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(40, 40, 3, 500)" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "\n", + "# Pad the training data for further augmentation\n", + "train_x_padded = padarray(train_x, Fill(0, (4, 4, 0, 0))); \n", + "size(train_x_padded) # Should be (40, 40, 3, 50000)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "fc788d3e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "6-step Augmentor.ImmutablePipeline:\n", + " 1.) Permute dimension order to (3, 1, 2)\n", + " 2.) Combine color channels into colorant RGB\n", + " 3.) Either: (50%) Flip the X axis. (50%) No operation.\n", + " 4.) Crop random window with size (32, 32)\n", + " 5.) Split colorant into its color channels\n", + " 6.) Permute dimension order to (2, 3, 1)" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pl = PermuteDims((3, 1, 2)) |> CombineChannels(RGB) |> Either(FlipX(), NoOp()) |> RCropSize(32, 32) |> SplitChannels() |> PermuteDims((2, 3, 1))" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "815faf28", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "outbatch (generic function with 1 method)" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Create an output array for augmented images\n", + "outbatch(X) = Array{Float32}(undef, (32, 32, 3, nobs(X)))" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "2e86e8f7", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "augmentbatch (generic function with 1 method)" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Function that takes a batch (images and targets) and augments the images\n", + "augmentbatch((X, y)) = (augmentbatch!(outbatch(X), X, pl), y)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "e4d362ce", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "┌ Warning: The specified values for size and/or count will result in 4 unused data points\n", + "└ @ MLDataPattern C:\\Users\\Yash\\.julia\\packages\\MLDataPattern\\KlSmO\\src\\dataview.jl:205\n" + ] + } + ], + "source": [ + "\n", + "# Shuffled and batched dataset of augmented images\n", + "train_batch_size = 16\n", + "\n", + "train_batches = mappedarray(augmentbatch, batchview(shuffleobs((train_x_padded, train_y)), size=train_batch_size));\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "e2386c3c", + "metadata": {}, + "outputs": [], + "source": [ + "# Test and Validation data\n", + "test_batch_size = 32\n", + "\n", + "val_loader = DataLoader((val_x, val_y), shuffle=true, batchsize=test_batch_size);\n", + "test_loader = DataLoader((test_x, test_y), shuffle=true, batchsize=test_batch_size);" + ] + }, + { + "cell_type": "markdown", + "id": "05599606", + "metadata": {}, + "source": [ + "# Training setup" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "fd7aadd5", + "metadata": {}, + "outputs": [], + "source": [ + "#Sparse Cross Entropy function" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "9f6c4d38", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "sparse_logit_cross_entropy (generic function with 1 method)" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "\n", + "\"\"\"\n", + " sparse_logit_cross_entropy(logits, labels)\n", + "\n", + "Efficient computation of cross entropy loss with model logits and integer indices as labels.\n", + "Integer indices are from [0, N-1], where N is the number of classes\n", + "Similar to TensorFlow SparseCategoricalCrossEntropy\n", + "\n", + "# Arguments\n", + "- `logits::AbstractArray`: 2D model logits tensor of shape (classes, batch size)\n", + "- `labels::AbstractArray`: 1D integer label indices of shape (batch size,)\n", + "\n", + "# Returns\n", + "- `loss::Float32`: Cross entropy loss\n", + "\"\"\"\n", + "# function sparse_logit_cross_entropy(logits, labels)\n", + "# log_probs = logsoftmax(logits);\n", + "# # Select indices of labels for loss\n", + "# log_probs = map((x, i) -> x[i + 1], eachslice(log_probs; dims=2), labels);\n", + "# loss = -mean(log_probs);\n", + "# return loss\n", + "# end\n", + "\n", + "function sparse_logit_cross_entropy(logits, labels)\n", + " log_probs = logsoftmax(logits);\n", + " inds = CartesianIndex.(labels .+ 1, axes(log_probs, 2));\n", + " # Select indices of labels for loss\n", + " log_probs = log_probs[inds];\n", + " loss = -mean(log_probs);\n", + " return loss\n", + "end\n" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "3998a220", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "# Create model with 3 input channels and 10 classes\n", + "model = ResNet20Model(3, 10);" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "6fa4497b", + "metadata": {}, + "outputs": [], + "source": [ + "# Setup AdamW optimizer\n", + "β = (0.9, 0.999);\n", + "decay = 1e-4;\n", + "state = Optimisers.setup(Optimisers.Adam(1e-3, β, decay), model);" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "b852506d", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "(x, y) = first(train_batches);" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "e71cc12e", + "metadata": {}, + "outputs": [], + "source": [ + "# loss, g = grad(loss_function, model, x, y);" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "1a9a8a89", + "metadata": {}, + "outputs": [], + "source": [ + "mutable struct ResNet5\n", + " input_conv::Conv2D\n", + " resnet_block::ResNetLayer\n", + " pool::GlobalMeanPool\n", + " linear::Linear\n", + "end\n", + "\n", + "@functor ResNet5\n", + "\n", + "function ResNet5(in_channels::Int, num_classes::Int)\n", + " return ResNet5(\n", + " Conv2D((3, 3), in_channels, 16, bias=false),\n", + " ResNetLayer(16, 16),\n", + " GlobalMeanPool(),\n", + " Linear(16, num_classes)\n", + " )\n", + "end\n", + "\n", + "function (self::ResNet5)(x::AbstractArray)\n", + " z = self.input_conv(x)\n", + " z = self.resnet_block(z)\n", + " z = self.pool(z)\n", + " z = dropdims(z, dims=(1, 2))\n", + " y = self.linear(z)\n", + " return y\n", + "end\n", + "\n", + "\n", + "# function loss_function(model::ResNet5, x::AbstractArray, y::AbstractArray)\n", + "# ŷ = model(x)\n", + "# loss = sparse_logit_cross_entropy(ŷ, y)\n", + "# return loss\n", + "# end" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "028a6d25", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "# Yota is unable to compute gradients through the ResNet for some reason, maybe due to residual connections?\n", + "# loss, g = grad(loss_function, model, x, y)\n", + "model = ResNet5(3, 10);\n", + "\n", + "# loss, g = Zygote.gradient(loss_function, model, x, y);" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "696231c0", + "metadata": {}, + "outputs": [], + "source": [ + "# g" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "id": "7d23487b", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "ResNet5(Conv2D{Float32}([-0.3269322 -0.09079589 -0.30220258; 0.29980195 -0.35697645 -0.43193826; 0.41894063 -0.27608046 -0.35023037;;; -0.11048572 -0.18733561 -0.048941202; -0.25535512 0.41386655 0.23646444; 0.1226552 0.19139434 -0.3201441;;; -0.2881651 0.4041223 -0.11729951; 0.28896266 0.124178275 0.10890088; -0.07136462 0.37597623 0.2907424;;;; 0.0116342725 -0.06491064 0.03947901; 0.36589766 -0.31363672 0.32354057; -0.101177834 0.22076249 0.26570976;;; -0.22781743 -0.16796216 0.079579934; 0.43243396 -0.18935399 0.3949348; -0.3725451 -0.06775151 0.21907443;;; -0.05270441 -0.43405735 -0.44125763; -0.47045088 -0.30292767 0.014733751; -0.04850591 -0.2133474 0.2412362;;;; 0.2572607 0.18735757 -0.33566207; -0.03157889 -0.04323261 0.1315869; -0.16356815 -0.23604983 0.051579874;;; -0.3262316 0.40397793 0.07843399; 0.17368728 0.31032175 -0.2273731; -0.20191403 0.11151084 0.33216488;;; -0.34083363 -0.46381113 0.055753145; -0.44104743 0.31393462 0.2622986; 0.13619547 -0.12979876 -0.043511562;;;; … ;;;; -0.020386824 -0.104114234 0.37638608; 0.41557428 -0.19767518 0.15894295; 0.150955 0.4521936 -0.26687187;;; 0.38604698 0.30180404 -0.1059084; -0.15032865 -0.031554278 -0.21704273; -0.03794346 0.3485954 0.38278958;;; -0.3071966 -0.3373205 0.26615357; 0.4422373 0.13577293 -0.10324652; 0.2894867 -0.23344572 0.39201385;;;; -0.13550712 0.30746755 0.38600484; 0.35319903 0.27227426 -0.42721114; -0.4167391 0.460941 0.23783916;;; 0.45264333 0.30202055 -0.32739767; -0.34008625 -0.23484135 0.19659689; -0.14264174 -0.09916833 0.27199847;;; -0.38800704 -0.060515907 0.4428402; 0.24729028 0.38564798 0.008954014; 0.10717848 0.3565583 -0.40317935;;;; 0.29413882 0.032473866 -0.24675108; 0.18658455 0.4415373 -0.07814981; 0.33296683 -0.1115019 -0.33509403;;; -0.24343053 0.042397596 0.35703608; 0.36186588 -0.05911843 -0.08993424; 0.13785711 -0.26265797 -0.067820854;;; -0.10355408 -0.26968983 0.097447224; -0.25024468 -0.1599089 0.4510931; 0.4365045 0.18134817 0.32099614], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], false, 1), ResNetLayer(Conv2D{Float32}([0.19761112 0.13021815 -0.028087448; -0.00057760417 -0.11352962 0.05831184; 0.124577686 0.089965284 0.13280511;;; -0.11783607 0.089995995 -0.1501678; -0.015031724 -0.012284083 -0.0898654; -0.09346841 -0.1621385 -0.0109067345;;; 0.17110728 -0.10372673 -0.015643882; 0.008885473 0.09228887 0.20323306; 0.015771464 -0.12023175 0.039679002;;; … ;;; -0.046293672 0.0779443 0.011689981; 0.122653276 -0.106488414 0.16222343; 0.006254584 -0.017572897 0.07593362;;; 0.0962515 0.18770924 0.19851086; -0.13754818 -0.029091569 0.06325006; 0.19043466 -0.08298591 0.17023039;;; -0.061017822 0.17036065 0.06359598; -0.005008466 -0.15680619 0.1233305; 0.17700247 0.20328574 -0.16379757;;;; -0.18611959 -0.043676604 0.108012736; 0.02672942 -0.13492495 0.15822719; -0.002692112 0.16126484 -0.025754329;;; -0.15724711 -0.04828544 -0.1707655; -0.0970394 -0.055499617 -0.024079112; 0.025048656 -0.082812436 -0.022185529;;; -0.1882148 0.16506943 0.000598823; -0.0912264 -0.16550767 -0.16717595; 0.06827358 0.17189996 -0.16707191;;; … ;;; 0.18228532 -0.17576592 -0.18167828; 0.19565201 -0.20267504 -0.18348633; -0.15575626 0.15550001 -0.14896752;;; -0.08209669 0.15093498 -0.11133591; -0.05105649 0.022263639 0.027264051; -0.02951181 0.18026373 -0.07041432;;; 0.025740579 0.075161055 -0.0525457; 0.092696406 -0.09947786 -0.12829517; -0.069097444 0.12314727 -0.1672388;;;; 0.10623432 -0.057691578 0.119959146; 0.03539229 -0.022670422 -0.111270726; 0.008098309 0.0037883115 -0.026243139;;; -0.070837826 -0.11017056 -0.178822; 0.06665229 -0.005612837 0.07103156; 0.1561577 0.031744555 0.0140344165;;; 0.15617746 -0.14973398 0.07564629; 0.0016903019 0.18394831 0.09675205; 0.19826071 -0.09340203 0.1700775;;; … ;;; 0.069550656 0.14834794 -0.06968259; 0.20270519 0.11043808 0.027695874; 0.13334787 0.16532846 0.048797905;;; 0.18591698 0.018436953 -0.0032594716; -0.08772257 -0.052733872 -0.14566335; 0.011975072 -0.15187715 0.10042701;;; 0.112629846 0.18635167 0.16804218; -0.19342335 0.010884567 0.14426668; -0.059680305 -0.038495857 0.19673485;;;; … ;;;; 0.12924753 0.11957796 -0.107034236; -0.15904024 -0.1602915 -0.094139844; 0.08867885 0.17599945 -0.04848101;;; -0.10934039 -0.19765444 -0.14756997; 0.17149675 0.14435652 -0.10002485; 0.18451841 0.066142604 0.17169759;;; -0.11077233 -0.0039441674 -0.069028825; -0.110270225 0.00804806 -0.080900356; -0.16658163 0.054695323 -0.015006443;;; … ;;; 0.00212283 0.18980761 -0.122687295; -0.115651555 -0.14763168 -0.06032928; -0.18875845 0.16435969 0.015897948;;; -0.16687778 -0.13686112 -0.10878749; -0.06288342 0.18291253 0.06357345; 0.09998129 0.1873411 0.08618598;;; -0.08974118 -0.0026351716 0.16592684; 0.12683599 -0.0756004 -0.058696333; -0.025129566 -0.17873748 0.15622494;;;; -0.020486929 -0.1648232 0.0033708948; 0.07356035 -0.04425343 -0.006855549; -0.18207838 0.07768629 -0.11125955;;; 0.05911343 -0.13387455 -0.02985895; 0.19365218 -0.0906411 -0.047491293; 0.1927835 -0.12751262 -0.13904938;;; 0.09187676 0.122309804 0.055273946; 0.1549648 0.06880033 -0.11684039; -0.10373542 0.018027786 -0.1502131;;; … ;;; -0.013978474 -0.051094815 -0.0014352868; -0.046254106 0.19333138 0.0050440175; 0.02186895 -0.083350666 -0.20366172;;; 0.087388806 -0.12761207 -0.027745271; -0.13694362 -0.14262795 0.15379757; -0.088099465 0.05600551 0.05024689;;; -0.18494785 3.431023f-5 -0.059908554; -0.0059224563 0.1425408 -0.02109383; 0.08903239 0.13319586 -0.1353602;;;; 0.0057842666 0.15523006 -0.031229122; -0.19541235 0.051155504 0.06527311; -0.1730501 -0.15493153 -0.1712542;;; 0.15538698 0.19521399 -0.061315224; 0.05348293 -0.15681091 0.06969483; -0.15938468 0.00255144 -0.008160578;;; -0.042394936 -0.023364875 0.030308511; -0.106786154 -0.062052142 0.110279866; -0.024075657 -0.20078343 -0.11427486;;; … ;;; 0.18105797 -0.13237794 -0.1409754; 0.19100334 -0.061168298 0.046601903; 0.14125203 0.17658699 -0.14446235;;; 0.05207166 0.06348066 -0.09771767; 0.0501899 0.09775206 0.1500323; -0.048741743 -0.13970241 -0.013351497;;; 0.11777932 -0.08002654 0.18714364; 0.12242211 0.17391886 0.12460178; -0.0737201 -0.15127876 -0.10289584], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], false, 1), Conv2D{Float32}([0.09717922 -0.03407784 0.014590169; -0.18445018 0.12831593 -0.07739436; -0.09349194 0.0037209808 -0.13268006;;; 0.17895766 -0.02247517 -0.037817243; -0.111089125 0.12584838 0.0016550183; 0.077631414 -0.044876367 -0.14228758;;; 0.13512933 0.14910595 -0.07057957; -0.121997386 0.13264237 0.19979271; -0.20085463 0.11576198 0.14067839;;; … ;;; 0.10820168 0.12200413 -0.08980403; 0.0011746995 0.1944284 0.18440181; 0.1767769 0.026093902 0.12368353;;; -0.007022744 -0.11667823 -0.20263171; 0.17421563 -0.11981526 -0.049599864; -0.15285212 -0.11957154 0.11077546;;; -0.158289 0.14341965 0.027710011; 0.1390186 -0.12730537 -0.16582859; -0.12259298 -0.11377486 -0.056429546;;;; 0.13843888 0.03513197 -0.15407993; -0.022550944 0.13677955 0.10618293; -0.14883694 0.16104752 -0.033526104;;; 0.11633827 0.073986895 0.039874643; 0.017279312 0.07693811 0.093623415; 0.10816275 -0.20379694 0.018923648;;; 0.15248454 0.06787851 -0.01800759; 0.09909373 -0.060843762 -0.027705412; -0.099589966 -0.1385283 -0.14809528;;; … ;;; -0.16959943 -0.06297494 0.05770058; 0.20113298 -0.047386974 0.11995617; -0.0032085418 -0.022590097 0.010460361;;; -0.18173099 0.13940151 -0.045853406; 0.07445716 0.10802089 -0.18385637; 0.12254616 -0.1648605 -0.1824452;;; -0.011050132 0.12620825 -0.021450827; 0.1770418 0.116070434 -0.091566995; -0.19971325 -0.16312777 -0.013281781;;;; -0.035456408 -0.036289588 -0.05979168; 0.09422635 -0.056371607 -0.0059971116; -0.19720648 0.015351248 -0.09199891;;; -0.03359222 0.0419892 0.18662222; -0.05208587 0.090805836 -0.073787235; -0.0937151 0.044613346 0.16400264;;; -0.10861183 0.17393792 -0.100663245; -0.09707624 -0.09358217 -0.17184801; 0.081750296 -0.076938204 0.06963327;;; … ;;; 0.108367175 -0.14059828 0.012450817; -0.045001905 -0.19173874 0.15082036; -0.008574759 0.15566646 0.14164986;;; -0.1068494 -0.030200982 0.17375617; 0.08390958 0.08448434 0.17300947; -0.16909282 0.056779485 0.03290582;;; -0.07413019 -0.031639017 0.0054012574; 0.033854 0.0916983 -0.18244815; 0.10101544 0.15076157 -0.16013426;;;; … ;;;; -0.063150845 0.17775898 0.1134759; -0.058250155 -0.0747142 0.15499249; -0.05641818 0.073472485 -0.17872544;;; 0.2006494 0.047009416 -0.006391436; -0.07353716 -0.09418866 -0.17701323; 0.024081083 -0.19869477 -0.00069250696;;; -0.049676735 0.0745745 -0.13708827; 0.08180036 -0.02534971 0.03376078; 0.071060985 0.17471206 -0.16068852;;; … ;;; 0.049811736 -0.11532743 0.002412666; 0.16178377 0.1469855 0.10922075; 0.04306834 0.057045255 -0.16866772;;; -0.013301638 -0.15541886 -0.035264272; 0.072164506 0.16410089 -0.09112215; 0.15948404 0.11376505 -0.10438497;;; 0.09450845 -0.19718246 0.073083684; 0.051564816 -0.033912666 -0.1722798; -0.047565997 -0.014022055 0.12652722;;;; 7.1759474f-5 -0.031632643 -0.09353316; 0.027056292 0.035871707 -0.011876182; 0.17951545 -0.13125665 -0.070746765;;; 0.059760485 -0.118754365 0.093996614; 0.13821077 0.19534363 0.12683688; 0.15409826 -0.17226508 -0.051132996;;; -0.092778996 -0.079336025 -0.10923138; 0.08583115 -0.0496824 0.012007193; 0.16740225 0.19941543 -0.06619081;;; … ;;; -0.059872467 0.015607577 0.13966718; 0.10251013 0.030340316 0.20154124; 0.10072651 0.06240758 -0.19964255;;; 0.19304037 0.102634646 0.15421943; 0.15746589 0.096430525 -0.04338655; 0.044689316 -0.047149528 -0.011374255;;; -0.026993804 -0.069641635 -0.063028984; 0.052857243 -0.08376117 -0.021661483; 0.09019801 -0.0031467103 -0.18673009;;;; 0.14249101 -0.18121263 0.020420937; -0.19246987 -0.011450296 -0.1786581; -0.0054819956 -0.016487014 -0.2034004;;; 0.17387968 -0.16070417 0.13330147; 0.07486626 0.16504952 -0.073846295; -0.14962983 -0.09689917 0.000648001;;; -0.15290444 0.17162122 0.0069305687; -0.06845603 -0.1037042 -0.13593975; 0.15990376 -0.08033438 -0.14466582;;; … ;;; -0.029019274 -0.16538228 0.11829052; 0.18074305 -0.04143785 0.16447136; 0.041418746 -0.17317066 0.11413952;;; 0.14402567 -0.030171514 0.10105601; -0.13740699 -0.02441596 0.11153571; 0.086634375 -0.15125358 0.14048335;;; 0.05758249 0.118906185 -0.10497864; -0.17066532 -0.044436127 0.023161229; -0.14597139 -0.03848912 0.17103036], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], false, 1), BatchNorm(16), BatchNorm(16), NNlib.relu, 16, 16, 1), GlobalMeanPool(), Linear(o))" + ] + }, + "execution_count": 30, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model" + ] + }, + { + "cell_type": "markdown", + "id": "01294ff6", + "metadata": {}, + "source": [ + "# Training Knet" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "74f3f1da", + "metadata": {}, + "outputs": [ + { + "ename": "LoadError", + "evalue": "UndefVarError: progress! not defined", + "output_type": "error", + "traceback": [ + "UndefVarError: progress! not defined", + "", + "Stacktrace:", + " [1] macro expansion", + " @ .\\In[31]:12 [inlined]", + " [2] macro expansion", + " @ C:\\Users\\Yash\\.julia\\packages\\TimerOutputs\\4yHI4\\src\\TimerOutput.jl:237 [inlined]", + " [3] macro expansion", + " @ .\\In[31]:11 [inlined]", + " [4] top-level scope", + " @ C:\\Users\\Yash\\.julia\\packages\\TimerOutputs\\4yHI4\\src\\TimerOutput.jl:237 [inlined]", + " [5] top-level scope", + " @ .\\In[31]:0", + " [6] eval", + " @ .\\boot.jl:368 [inlined]", + " [7] include_string(mapexpr::typeof(REPL.softscope), mod::Module, code::String, filename::String)", + " @ Base .\\loading.jl:1428" + ] + } + ], + "source": [ + "\n", + "\n", + "loss(test_x, test_y) = nll(model(test_x), test_y)\n", + "evalcb = () -> (loss(test_x, test_y)) #function that will be called to get the loss \n", + "const to = TimerOutput() # creating a TimerOutput, keeps track of everything\n", + "\n", + "\n", + "@timeit to \"Train Total\" begin\n", + " for epoch in 1:10\n", + " train_epoch = epoch > 1 ? \"train_epoch\" : \"train_ji\"\n", + " @timeit to train_epoch begin\n", + " progress!(adam(model, train_batches; lr = 1e-3))\n", + " end\n", + " \n", + " evaluation = epoch > 1 ? \"evaluation\" : \"eval_jit\"\n", + " @timeit to evaluation begin\n", + " accuracy(model, train_batches)\n", + " end \n", + " \n", + " end \n", + "end " + ] + }, + { + "cell_type": "markdown", + "id": "a57b3c8d", + "metadata": {}, + "source": [ + "# Evaluation Function" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": "02f69609", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "# function evaluate(model, test_loader)\n", + "# preds = []\n", + "# targets = []\n", + "# for (x, y) in test_loader\n", + "# # Get model predictions\n", + "# # Note argmax of nd-array gives CartesianIndex\n", + "# # Need to grab the first element of each CartesianIndex to get the true index\n", + "# logits = model(x)\n", + "# ŷ = map(i -> i[1], argmax(logits, dims=1))\n", + "# append!(preds, ŷ)\n", + "\n", + "# # Get true labels\n", + "# append!(targets, y)\n", + "# end\n", + "# accuracy = sum(preds .== targets) / length(targets)\n", + "# return accuracy\n", + "# end" + ] + }, + { + "cell_type": "markdown", + "id": "f2072bd8", + "metadata": {}, + "source": [ + "# Training Loop" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": "cc39bcab", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "# # Setup timing output\n", + "# const to = TimerOutput()" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "id": "9b5c088c", + "metadata": {}, + "outputs": [], + "source": [ + "# # last_loss = 0;\n", + "# # @timeit to \"total_training_time\" begin\n", + "# for epoch in 1:10\n", + "# timing_name = epoch > 1 ? \"average_epoch_training_time\" : \"train_jit\"\n", + "\n", + "# # Create lazily evaluated augmented training data\n", + "# train_batches = mappedarray(augmentbatch, batchview(shuffleobs((train_x_padded, train_y)), size=train_batch_size));\n", + "\n", + "# @timeit to timing_name begin\n", + "# losses = []\n", + "# for (x, y) in train_batches\n", + "# # loss_function does forward pass\n", + "# # Yota.jl grad function computes model parameter gradients in g[2]\n", + "# loss, g = grad(loss_function, model, x, y)\n", + " \n", + "# # Optimiser updates parameters\n", + "# Optimisers.update!(state, model, g[2])\n", + "# push!(losses, loss)\n", + "# end\n", + "# last_loss = mean(losses)\n", + "# @info(\"epoch (mean(losses))\")\n", + "# end\n", + "# # timing_name = epoch > 1 ? \"average_inference_time\" : \"eval_jit\"\n", + "# # @timeit to timing_name begin\n", + "# # acc = evaluate(model, test_loader)\n", + "# # @info(\"epoch (acc)\")\n", + "# # end\n", + "# end\n", + "# end" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1955c486", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Julia 1.8.2", + "language": "julia", + "name": "julia-1.8" + }, + "language_info": { + "file_extension": ".jl", + "mimetype": "application/julia", + "name": "julia", + "version": "1.8.2" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From 5ad010b6e1c97c4bde2200431521a96b370f66dc Mon Sep 17 00:00:00 2001 From: Yash Pokra <79229682+yashpokra@users.noreply.github.com> Date: Thu, 9 Feb 2023 01:44:01 -0500 Subject: [PATCH 24/26] Delete julia_resnetmodel_updated_model v2.ipynb --- julia_resnetmodel_updated_model v2.ipynb | 1036 ---------------------- 1 file changed, 1036 deletions(-) delete mode 100644 julia_resnetmodel_updated_model v2.ipynb diff --git a/julia_resnetmodel_updated_model v2.ipynb b/julia_resnetmodel_updated_model v2.ipynb deleted file mode 100644 index 68d134d..0000000 --- a/julia_resnetmodel_updated_model v2.ipynb +++ /dev/null @@ -1,1036 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "69f91157", - "metadata": {}, - "source": [ - "# Imports" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "9b1583d4", - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "using Yota;\n", - "using MLDatasets;\n", - "using NNlib;\n", - "using Statistics;\n", - "using Distributions;\n", - "using Functors;\n", - "using Optimisers;\n", - "using MLUtils: DataLoader;\n", - "using OneHotArrays: onehotbatch\n", - "using Knet:Knet,conv4, adam\n", - "using Knet: dir, accuracy, progress, sgd, gc\n", - "using Metrics;\n", - "using TimerOutputs;\n", - "using Flux: BatchNorm, kaiming_uniform, nfan;\n", - "using Functors\n", - "\n", - "# Model creation\n", - "using NNlib;\n", - "using Flux: BatchNorm, Chain, GlobalMeanPool, kaiming_uniform, nfan;\n", - "using Statistics;\n", - "using Distributions;\n", - "using Functors;\n", - "\n", - "# Data processing\n", - "using MLDatasets;\n", - "using MLUtils: DataLoader;\n", - "using MLDataPattern;\n", - "using ImageCore;\n", - "using Augmentor;\n", - "using ImageFiltering;\n", - "using MappedArrays;\n", - "using Random;\n", - "using Flux: DataLoader;\n", - "# using OneHotArrays: onehotbatch\n", - "\n", - "\n", - "# Training\n", - "# using Yota;\n", - "using Zygote;\n", - "using Optimisers;\n", - "using Metrics;\n", - "using TimerOutputs;\n", - "\n", - "\n", - "\n", - "# Issue when running this\n", - "#using Knet: Knet, dir, accuracy, progress, sgd, gc, Data, nll, relu\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "id": "19aff91e", - "metadata": {}, - "source": [ - "# Conv 2D" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "481a3d9a", - "metadata": {}, - "outputs": [], - "source": [ - "mutable struct Conv2D{T}\n", - " w::AbstractArray{T, 4}\n", - " b::AbstractVector{T}\n", - " use_bias::Bool\n", - " padding::Int \n", - "end\n", - "\n", - "@functor Conv2D (w, b)" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "59da1b27", - "metadata": {}, - "outputs": [], - "source": [ - "function Conv2D(kernel_size::Tuple{Int, Int}, in_channels::Int, out_channels::Int;\n", - " bias::Bool=false, padding::Int=1)\n", - " w_size = (kernel_size..., in_channels, out_channels)\n", - " w = kaiming_uniform(w_size...)\n", - " (fan_in, fan_out) = nfan(w_size)\n", - " \n", - " if bias\n", - " # Init bias with fan_in from weights. Use gain = √2 for ReLU\n", - " bound = √3 * √2 / √fan_in\n", - " rng = Uniform(-bound, bound)\n", - " b = rand(rng, out_channels, Float32)\n", - " else\n", - " b = zeros(Float32, out_channels)\n", - " end\n", - "\n", - " return Conv2D(w, b, bias, padding)\n", - "end\n", - "\n", - "function (self::Conv2D)(x::AbstractArray; stride::Int=1, pad::Int=0, dilation::Int=1)\n", - " y = conv4(self.w, x; stride=stride, padding=self.padding, dilation=dilation)\n", - " if self.use_bias\n", - " # Bias is applied channel-wise\n", - " (w, h, c, b) = size(y)\n", - " bias = reshape(self.b, (1, 1, c, 1))\n", - " y = y .+ bias\n", - " end\n", - " return y\n", - "end\n", - " " - ] - }, - { - "cell_type": "markdown", - "id": "252e934f", - "metadata": {}, - "source": [ - "# ResNetLayer" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "3e66be4f", - "metadata": {}, - "outputs": [], - "source": [ - "mutable struct ResNetLayer\n", - " conv1::Conv2D\n", - " conv2::Conv2D\n", - " bn1::BatchNorm\n", - " bn2::BatchNorm\n", - " f::Function\n", - " in_channels::Int\n", - " channels::Int\n", - " stride::Int\n", - "end\n", - "\n", - "@functor ResNetLayer (conv1, conv2, bn1, bn2)\n", - "\n", - "function residual_identity(layer::ResNetLayer, x::AbstractArray{T, 4}) where {T<:Number}\n", - " (w, h, c, b) = size(x)\n", - " stride = layer.stride\n", - " if stride > 1\n", - " @assert ((w % stride == 0) & (h % stride == 0)) \"Spatial dimensions are not divisible by `stride`\"\n", - " \n", - " # Strided downsample\n", - " x_id = copy(x[begin:2:end, begin:2:end, :, :])\n", - " else\n", - " x_id = x\n", - " end\n", - "\n", - " channels = layer.channels\n", - " in_channels = layer.in_channels\n", - " if in_channels < channels\n", - " # Zero padding on extra channels\n", - " (w, h, c, b) = size(x_id)\n", - " pad = zeros(w, h, channels - in_channels, b)\n", - " x_id = cat(x_id, pad; dims=3)\n", - " elseif in_channels > channels\n", - " error(\"in_channels > out_channels not supported\")\n", - " end\n", - " return x_id\n", - "end\n", - "\n", - "function ResNetLayer(in_channels::Int, channels::Int; stride=1, f=relu)\n", - " bn1 = BatchNorm(in_channels)\n", - " conv1 = Conv2D((3, 3), in_channels, channels, bias=false)\n", - " bn2 = BatchNorm(channels)\n", - " conv2 = Conv2D((3, 3), channels, channels, bias=false)\n", - "\n", - " return ResNetLayer(conv1, conv2, bn1, bn2, f, in_channels, channels, stride)\n", - "end\n", - "\n", - "\n", - "function (self::ResNetLayer)(x::AbstractArray)\n", - " identity = residual_identity(self, x)\n", - " z = self.bn1(x)\n", - " z = self.f(z)\n", - " z = self.conv1(z; pad=1, stride=self.stride) # pad=1 will keep same size with (3x3) kernel\n", - " z = self.bn2(z)\n", - " z = self.f(z)\n", - " z = self.conv2(z; pad=1)\n", - "\n", - " y = z + identity\n", - " return y\n", - "end" - ] - }, - { - "cell_type": "markdown", - "id": "9f06e04e", - "metadata": {}, - "source": [ - "# Testing ResNetLayer" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "7cdc72a9", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(16, 16, 10, 4)" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "\n", - "l = ResNetLayer(3, 10; stride=2);\n", - "x = randn(Float32, (32, 32, 3, 4));\n", - "y = l(x);\n", - "size(y)" - ] - }, - { - "cell_type": "markdown", - "id": "7b21b952", - "metadata": {}, - "source": [ - "# Linear Layer" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "8987f02c", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING: method definition for Linear at In[6]:22 declares type variable T but does not use it.\n" - ] - } - ], - "source": [ - "mutable struct Linear\n", - " W::AbstractMatrix{T} where T\n", - " b::AbstractVector{T} where T\n", - "end\n", - "\n", - "@functor Linear\n", - "\n", - "# Init\n", - "function Linear(in_features::Int, out_features::Int)\n", - " k_sqrt = sqrt(1 / in_features)\n", - " d = Uniform(-k_sqrt, k_sqrt)\n", - " return Linear(rand(d, out_features, in_features), rand(d, out_features))\n", - "end\n", - "Linear(in_out::Pair{Int, Int}) = Linear(in_out[1], in_out[2])\n", - "\n", - "function Base.show(io::IO, l::Linear)\n", - " o, i = size(l.W)\n", - " print(io, \"Linear(o)\")\n", - "end\n", - "\n", - "# Forward\n", - "(l::Linear)(x::AbstractArray) where T = l.W * x .+ l.b\n" - ] - }, - { - "cell_type": "markdown", - "id": "79e8c6ca", - "metadata": {}, - "source": [ - "# Defining a Chain Layer" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "a47a2eaa", - "metadata": {}, - "outputs": [], - "source": [ - "# Define a chain of layers and a loss function:\n", - "struct Chain1; layers; end\n", - "(c::Chain1)(x) = (for l in c.layers; x = l(x); end; x)\n", - "(c::Chain1)(x,y) = nll(c(x),y)" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "02eca287", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "ResNet20Model" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# ResNet Architecture\n", - "\n", - "mutable struct ResNet20Model\n", - " input_conv::Conv2D\n", - " resnet_blocks::Chain1\n", - " pool::GlobalMeanPool\n", - " linear::Linear\n", - "end\n", - "\n", - "@functor ResNet20Model\n", - "\n", - "function ResNet20Model(in_channels::Int, num_classes::Int)\n", - " resnet_blocks = Chain1((\n", - " block_1 = ResNetLayer(16, 16),\n", - " block_2 = ResNetLayer(16, 16),\n", - " block_3 = ResNetLayer(16, 16),\n", - " block_4 = ResNetLayer(16, 32; stride=2),\n", - " block_5 = ResNetLayer(32, 32),\n", - " block_6 = ResNetLayer(32, 32),\n", - " block_7 = ResNetLayer(32, 64; stride=2),\n", - " block_8 = ResNetLayer(64, 64),\n", - " block_9 = ResNetLayer(64, 64)\n", - " ))\n", - " return ResNet20Model(\n", - " Conv2D((3, 3), in_channels, 16, bias=false),\n", - " resnet_blocks,\n", - " GlobalMeanPool(),\n", - " Linear(64, num_classes)\n", - " )\n", - "end" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "cdef0144", - "metadata": {}, - "outputs": [], - "source": [ - "function (self::ResNet20Model)(x::AbstractArray)\n", - " z = self.input_conv(x)\n", - " z = self.resnet_blocks(z)\n", - " z = self.pool(z)\n", - " z = dropdims(z, dims=(1, 2))\n", - " y = self.linear(z)\n", - " return y\n", - "end\n", - " " - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "25c15eb5", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "┌ Warning: Slow fallback implementation invoked for conv! You probably don't want this; check your datatypes.\n", - "│ yT = Float64\n", - "│ T1 = Float64\n", - "│ T2 = Float32\n", - "└ @ NNlib C:\\Users\\Yash\\.julia\\packages\\NNlib\\0QnJJ\\src\\conv.jl:285\n" - ] - }, - { - "data": { - "text/plain": [ - "(10, 4)" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "\n", - "# Testing ResNet20 model\n", - "# Expected output: (10, 4)\n", - "m = ResNet20Model(3, 10);\n", - "inputs = randn(Float32, (32, 32, 3, 4))\n", - "outputs = m(inputs);\n", - "size(outputs)\n", - " " - ] - }, - { - "cell_type": "markdown", - "id": "8e43380e", - "metadata": {}, - "source": [ - "# Data Preprocessing " - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "84857fa0", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "32×32×3×45000 Array{Float32, 4}\n", - "45000-element Vector{Int64}\n", - "32×32×3×5000 Array{Float32, 4}\n", - "5000-element Vector{Int64}\n", - "32×32×3×10000 Array{Float32, 4}\n", - "10000-element Vector{Int64}\n" - ] - } - ], - "source": [ - "# This loads the CIFAR-10 Dataset for training, validation, and evaluation\n", - "xtrn,ytrn = CIFAR10.traindata(Float32, 1:45000)\n", - "xval,yval = CIFAR10.traindata(Float32, 45001:50000)\n", - "xtst,ytst = CIFAR10.testdata(Float32)\n", - "println.(summary.((xtrn,ytrn,xval, yval, xtst,ytst)));" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "45acc000", - "metadata": {}, - "outputs": [], - "source": [ - "# Normalize all the data\n", - "\n", - "means = reshape([0.485, 0.465, 0.406], (1, 1, 3, 1))\n", - "stdevs = reshape([0.229, 0.224, 0.225], (1, 1, 3, 1))\n", - "normalize(x) = (x .- means) ./ stdevs\n", - "\n", - "train_x = normalize(xtrn);\n", - "val_x = normalize(xval);\n", - "test_x = normalize(xtst);" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "9e93cda3", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "splitobs (generic function with 11 methods)" - ] - }, - "execution_count": 13, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "\n", - "# Train-test split\n", - "# Copied from https://github.com/JuliaML/MLUtils.jl/blob/v0.2.11/src/splitobs.jl#L65\n", - "# obsview doesn't work with this data, so use getobs instead\n", - "\n", - "import MLDataPattern.splitobs;\n", - "\n", - "function splitobs(data; at, shuffle::Bool=false)\n", - " if shuffle\n", - " data = shuffleobs(data)\n", - " end\n", - " n = numobs(data)\n", - " return map(idx -> MLDataPattern.getobs(data, idx), splitobs(n, at))\n", - "end" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "9c649cac", - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "# Notebook testing: Use less data\n", - "train_x, train_y = MLDatasets.getobs((train_x, ytrn), 1:500);\n", - "\n", - "val_x, val_y = MLDatasets.getobs((val_x, yval), 1:50);\n", - "\n", - "test_x, test_y = MLDatasets.getobs((test_x, ytst), 1:50);" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "75266187", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(40, 40, 3, 500)" - ] - }, - "execution_count": 15, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "\n", - "# Pad the training data for further augmentation\n", - "train_x_padded = padarray(train_x, Fill(0, (4, 4, 0, 0))); \n", - "size(train_x_padded) # Should be (40, 40, 3, 50000)" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "id": "fc788d3e", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "6-step Augmentor.ImmutablePipeline:\n", - " 1.) Permute dimension order to (3, 1, 2)\n", - " 2.) Combine color channels into colorant RGB\n", - " 3.) Either: (50%) Flip the X axis. (50%) No operation.\n", - " 4.) Crop random window with size (32, 32)\n", - " 5.) Split colorant into its color channels\n", - " 6.) Permute dimension order to (2, 3, 1)" - ] - }, - "execution_count": 16, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "pl = PermuteDims((3, 1, 2)) |> CombineChannels(RGB) |> Either(FlipX(), NoOp()) |> RCropSize(32, 32) |> SplitChannels() |> PermuteDims((2, 3, 1))" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "id": "815faf28", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "outbatch (generic function with 1 method)" - ] - }, - "execution_count": 17, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# Create an output array for augmented images\n", - "outbatch(X) = Array{Float32}(undef, (32, 32, 3, nobs(X)))" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "id": "2e86e8f7", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "augmentbatch (generic function with 1 method)" - ] - }, - "execution_count": 18, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# Function that takes a batch (images and targets) and augments the images\n", - "augmentbatch((X, y)) = (augmentbatch!(outbatch(X), X, pl), y)" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "id": "e4d362ce", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "┌ Warning: The specified values for size and/or count will result in 4 unused data points\n", - "└ @ MLDataPattern C:\\Users\\Yash\\.julia\\packages\\MLDataPattern\\KlSmO\\src\\dataview.jl:205\n" - ] - } - ], - "source": [ - "\n", - "# Shuffled and batched dataset of augmented images\n", - "train_batch_size = 16\n", - "\n", - "train_batches = mappedarray(augmentbatch, batchview(shuffleobs((train_x_padded, train_y)), size=train_batch_size));\n", - " " - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "id": "e2386c3c", - "metadata": {}, - "outputs": [], - "source": [ - "# Test and Validation data\n", - "test_batch_size = 32\n", - "\n", - "val_loader = DataLoader((val_x, val_y), shuffle=true, batchsize=test_batch_size);\n", - "test_loader = DataLoader((test_x, test_y), shuffle=true, batchsize=test_batch_size);" - ] - }, - { - "cell_type": "markdown", - "id": "05599606", - "metadata": {}, - "source": [ - "# Training setup" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "id": "fd7aadd5", - "metadata": {}, - "outputs": [], - "source": [ - "#Sparse Cross Entropy function" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "id": "9f6c4d38", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "sparse_logit_cross_entropy (generic function with 1 method)" - ] - }, - "execution_count": 22, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "\n", - "\"\"\"\n", - " sparse_logit_cross_entropy(logits, labels)\n", - "\n", - "Efficient computation of cross entropy loss with model logits and integer indices as labels.\n", - "Integer indices are from [0, N-1], where N is the number of classes\n", - "Similar to TensorFlow SparseCategoricalCrossEntropy\n", - "\n", - "# Arguments\n", - "- `logits::AbstractArray`: 2D model logits tensor of shape (classes, batch size)\n", - "- `labels::AbstractArray`: 1D integer label indices of shape (batch size,)\n", - "\n", - "# Returns\n", - "- `loss::Float32`: Cross entropy loss\n", - "\"\"\"\n", - "# function sparse_logit_cross_entropy(logits, labels)\n", - "# log_probs = logsoftmax(logits);\n", - "# # Select indices of labels for loss\n", - "# log_probs = map((x, i) -> x[i + 1], eachslice(log_probs; dims=2), labels);\n", - "# loss = -mean(log_probs);\n", - "# return loss\n", - "# end\n", - "\n", - "function sparse_logit_cross_entropy(logits, labels)\n", - " log_probs = logsoftmax(logits);\n", - " inds = CartesianIndex.(labels .+ 1, axes(log_probs, 2));\n", - " # Select indices of labels for loss\n", - " log_probs = log_probs[inds];\n", - " loss = -mean(log_probs);\n", - " return loss\n", - "end\n" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "id": "3998a220", - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "# Create model with 3 input channels and 10 classes\n", - "model = ResNet20Model(3, 10);" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "id": "6fa4497b", - "metadata": {}, - "outputs": [], - "source": [ - "# Setup AdamW optimizer\n", - "β = (0.9, 0.999);\n", - "decay = 1e-4;\n", - "state = Optimisers.setup(Optimisers.Adam(1e-3, β, decay), model);" - ] - }, - { - "cell_type": "code", - "execution_count": 25, - "id": "b852506d", - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "(x, y) = first(train_batches);" - ] - }, - { - "cell_type": "code", - "execution_count": 26, - "id": "e71cc12e", - "metadata": {}, - "outputs": [], - "source": [ - "# loss, g = grad(loss_function, model, x, y);" - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "id": "1a9a8a89", - "metadata": {}, - "outputs": [], - "source": [ - "mutable struct ResNet5\n", - " input_conv::Conv2D\n", - " resnet_block::ResNetLayer\n", - " pool::GlobalMeanPool\n", - " linear::Linear\n", - "end\n", - "\n", - "@functor ResNet5\n", - "\n", - "function ResNet5(in_channels::Int, num_classes::Int)\n", - " return ResNet5(\n", - " Conv2D((3, 3), in_channels, 16, bias=false),\n", - " ResNetLayer(16, 16),\n", - " GlobalMeanPool(),\n", - " Linear(16, num_classes)\n", - " )\n", - "end\n", - "\n", - "function (self::ResNet5)(x::AbstractArray)\n", - " z = self.input_conv(x)\n", - " z = self.resnet_block(z)\n", - " z = self.pool(z)\n", - " z = dropdims(z, dims=(1, 2))\n", - " y = self.linear(z)\n", - " return y\n", - "end\n", - "\n", - "\n", - "# function loss_function(model::ResNet5, x::AbstractArray, y::AbstractArray)\n", - "# ŷ = model(x)\n", - "# loss = sparse_logit_cross_entropy(ŷ, y)\n", - "# return loss\n", - "# end" - ] - }, - { - "cell_type": "code", - "execution_count": 28, - "id": "028a6d25", - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "# Yota is unable to compute gradients through the ResNet for some reason, maybe due to residual connections?\n", - "# loss, g = grad(loss_function, model, x, y)\n", - "model = ResNet5(3, 10);\n", - "\n", - "# loss, g = Zygote.gradient(loss_function, model, x, y);" - ] - }, - { - "cell_type": "code", - "execution_count": 29, - "id": "696231c0", - "metadata": {}, - "outputs": [], - "source": [ - "# g" - ] - }, - { - "cell_type": "code", - "execution_count": 30, - "id": "7d23487b", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "ResNet5(Conv2D{Float32}([-0.3269322 -0.09079589 -0.30220258; 0.29980195 -0.35697645 -0.43193826; 0.41894063 -0.27608046 -0.35023037;;; -0.11048572 -0.18733561 -0.048941202; -0.25535512 0.41386655 0.23646444; 0.1226552 0.19139434 -0.3201441;;; -0.2881651 0.4041223 -0.11729951; 0.28896266 0.124178275 0.10890088; -0.07136462 0.37597623 0.2907424;;;; 0.0116342725 -0.06491064 0.03947901; 0.36589766 -0.31363672 0.32354057; -0.101177834 0.22076249 0.26570976;;; -0.22781743 -0.16796216 0.079579934; 0.43243396 -0.18935399 0.3949348; -0.3725451 -0.06775151 0.21907443;;; -0.05270441 -0.43405735 -0.44125763; -0.47045088 -0.30292767 0.014733751; -0.04850591 -0.2133474 0.2412362;;;; 0.2572607 0.18735757 -0.33566207; -0.03157889 -0.04323261 0.1315869; -0.16356815 -0.23604983 0.051579874;;; -0.3262316 0.40397793 0.07843399; 0.17368728 0.31032175 -0.2273731; -0.20191403 0.11151084 0.33216488;;; -0.34083363 -0.46381113 0.055753145; -0.44104743 0.31393462 0.2622986; 0.13619547 -0.12979876 -0.043511562;;;; … ;;;; -0.020386824 -0.104114234 0.37638608; 0.41557428 -0.19767518 0.15894295; 0.150955 0.4521936 -0.26687187;;; 0.38604698 0.30180404 -0.1059084; -0.15032865 -0.031554278 -0.21704273; -0.03794346 0.3485954 0.38278958;;; -0.3071966 -0.3373205 0.26615357; 0.4422373 0.13577293 -0.10324652; 0.2894867 -0.23344572 0.39201385;;;; -0.13550712 0.30746755 0.38600484; 0.35319903 0.27227426 -0.42721114; -0.4167391 0.460941 0.23783916;;; 0.45264333 0.30202055 -0.32739767; -0.34008625 -0.23484135 0.19659689; -0.14264174 -0.09916833 0.27199847;;; -0.38800704 -0.060515907 0.4428402; 0.24729028 0.38564798 0.008954014; 0.10717848 0.3565583 -0.40317935;;;; 0.29413882 0.032473866 -0.24675108; 0.18658455 0.4415373 -0.07814981; 0.33296683 -0.1115019 -0.33509403;;; -0.24343053 0.042397596 0.35703608; 0.36186588 -0.05911843 -0.08993424; 0.13785711 -0.26265797 -0.067820854;;; -0.10355408 -0.26968983 0.097447224; -0.25024468 -0.1599089 0.4510931; 0.4365045 0.18134817 0.32099614], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], false, 1), ResNetLayer(Conv2D{Float32}([0.19761112 0.13021815 -0.028087448; -0.00057760417 -0.11352962 0.05831184; 0.124577686 0.089965284 0.13280511;;; -0.11783607 0.089995995 -0.1501678; -0.015031724 -0.012284083 -0.0898654; -0.09346841 -0.1621385 -0.0109067345;;; 0.17110728 -0.10372673 -0.015643882; 0.008885473 0.09228887 0.20323306; 0.015771464 -0.12023175 0.039679002;;; … ;;; -0.046293672 0.0779443 0.011689981; 0.122653276 -0.106488414 0.16222343; 0.006254584 -0.017572897 0.07593362;;; 0.0962515 0.18770924 0.19851086; -0.13754818 -0.029091569 0.06325006; 0.19043466 -0.08298591 0.17023039;;; -0.061017822 0.17036065 0.06359598; -0.005008466 -0.15680619 0.1233305; 0.17700247 0.20328574 -0.16379757;;;; -0.18611959 -0.043676604 0.108012736; 0.02672942 -0.13492495 0.15822719; -0.002692112 0.16126484 -0.025754329;;; -0.15724711 -0.04828544 -0.1707655; -0.0970394 -0.055499617 -0.024079112; 0.025048656 -0.082812436 -0.022185529;;; -0.1882148 0.16506943 0.000598823; -0.0912264 -0.16550767 -0.16717595; 0.06827358 0.17189996 -0.16707191;;; … ;;; 0.18228532 -0.17576592 -0.18167828; 0.19565201 -0.20267504 -0.18348633; -0.15575626 0.15550001 -0.14896752;;; -0.08209669 0.15093498 -0.11133591; -0.05105649 0.022263639 0.027264051; -0.02951181 0.18026373 -0.07041432;;; 0.025740579 0.075161055 -0.0525457; 0.092696406 -0.09947786 -0.12829517; -0.069097444 0.12314727 -0.1672388;;;; 0.10623432 -0.057691578 0.119959146; 0.03539229 -0.022670422 -0.111270726; 0.008098309 0.0037883115 -0.026243139;;; -0.070837826 -0.11017056 -0.178822; 0.06665229 -0.005612837 0.07103156; 0.1561577 0.031744555 0.0140344165;;; 0.15617746 -0.14973398 0.07564629; 0.0016903019 0.18394831 0.09675205; 0.19826071 -0.09340203 0.1700775;;; … ;;; 0.069550656 0.14834794 -0.06968259; 0.20270519 0.11043808 0.027695874; 0.13334787 0.16532846 0.048797905;;; 0.18591698 0.018436953 -0.0032594716; -0.08772257 -0.052733872 -0.14566335; 0.011975072 -0.15187715 0.10042701;;; 0.112629846 0.18635167 0.16804218; -0.19342335 0.010884567 0.14426668; -0.059680305 -0.038495857 0.19673485;;;; … ;;;; 0.12924753 0.11957796 -0.107034236; -0.15904024 -0.1602915 -0.094139844; 0.08867885 0.17599945 -0.04848101;;; -0.10934039 -0.19765444 -0.14756997; 0.17149675 0.14435652 -0.10002485; 0.18451841 0.066142604 0.17169759;;; -0.11077233 -0.0039441674 -0.069028825; -0.110270225 0.00804806 -0.080900356; -0.16658163 0.054695323 -0.015006443;;; … ;;; 0.00212283 0.18980761 -0.122687295; -0.115651555 -0.14763168 -0.06032928; -0.18875845 0.16435969 0.015897948;;; -0.16687778 -0.13686112 -0.10878749; -0.06288342 0.18291253 0.06357345; 0.09998129 0.1873411 0.08618598;;; -0.08974118 -0.0026351716 0.16592684; 0.12683599 -0.0756004 -0.058696333; -0.025129566 -0.17873748 0.15622494;;;; -0.020486929 -0.1648232 0.0033708948; 0.07356035 -0.04425343 -0.006855549; -0.18207838 0.07768629 -0.11125955;;; 0.05911343 -0.13387455 -0.02985895; 0.19365218 -0.0906411 -0.047491293; 0.1927835 -0.12751262 -0.13904938;;; 0.09187676 0.122309804 0.055273946; 0.1549648 0.06880033 -0.11684039; -0.10373542 0.018027786 -0.1502131;;; … ;;; -0.013978474 -0.051094815 -0.0014352868; -0.046254106 0.19333138 0.0050440175; 0.02186895 -0.083350666 -0.20366172;;; 0.087388806 -0.12761207 -0.027745271; -0.13694362 -0.14262795 0.15379757; -0.088099465 0.05600551 0.05024689;;; -0.18494785 3.431023f-5 -0.059908554; -0.0059224563 0.1425408 -0.02109383; 0.08903239 0.13319586 -0.1353602;;;; 0.0057842666 0.15523006 -0.031229122; -0.19541235 0.051155504 0.06527311; -0.1730501 -0.15493153 -0.1712542;;; 0.15538698 0.19521399 -0.061315224; 0.05348293 -0.15681091 0.06969483; -0.15938468 0.00255144 -0.008160578;;; -0.042394936 -0.023364875 0.030308511; -0.106786154 -0.062052142 0.110279866; -0.024075657 -0.20078343 -0.11427486;;; … ;;; 0.18105797 -0.13237794 -0.1409754; 0.19100334 -0.061168298 0.046601903; 0.14125203 0.17658699 -0.14446235;;; 0.05207166 0.06348066 -0.09771767; 0.0501899 0.09775206 0.1500323; -0.048741743 -0.13970241 -0.013351497;;; 0.11777932 -0.08002654 0.18714364; 0.12242211 0.17391886 0.12460178; -0.0737201 -0.15127876 -0.10289584], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], false, 1), Conv2D{Float32}([0.09717922 -0.03407784 0.014590169; -0.18445018 0.12831593 -0.07739436; -0.09349194 0.0037209808 -0.13268006;;; 0.17895766 -0.02247517 -0.037817243; -0.111089125 0.12584838 0.0016550183; 0.077631414 -0.044876367 -0.14228758;;; 0.13512933 0.14910595 -0.07057957; -0.121997386 0.13264237 0.19979271; -0.20085463 0.11576198 0.14067839;;; … ;;; 0.10820168 0.12200413 -0.08980403; 0.0011746995 0.1944284 0.18440181; 0.1767769 0.026093902 0.12368353;;; -0.007022744 -0.11667823 -0.20263171; 0.17421563 -0.11981526 -0.049599864; -0.15285212 -0.11957154 0.11077546;;; -0.158289 0.14341965 0.027710011; 0.1390186 -0.12730537 -0.16582859; -0.12259298 -0.11377486 -0.056429546;;;; 0.13843888 0.03513197 -0.15407993; -0.022550944 0.13677955 0.10618293; -0.14883694 0.16104752 -0.033526104;;; 0.11633827 0.073986895 0.039874643; 0.017279312 0.07693811 0.093623415; 0.10816275 -0.20379694 0.018923648;;; 0.15248454 0.06787851 -0.01800759; 0.09909373 -0.060843762 -0.027705412; -0.099589966 -0.1385283 -0.14809528;;; … ;;; -0.16959943 -0.06297494 0.05770058; 0.20113298 -0.047386974 0.11995617; -0.0032085418 -0.022590097 0.010460361;;; -0.18173099 0.13940151 -0.045853406; 0.07445716 0.10802089 -0.18385637; 0.12254616 -0.1648605 -0.1824452;;; -0.011050132 0.12620825 -0.021450827; 0.1770418 0.116070434 -0.091566995; -0.19971325 -0.16312777 -0.013281781;;;; -0.035456408 -0.036289588 -0.05979168; 0.09422635 -0.056371607 -0.0059971116; -0.19720648 0.015351248 -0.09199891;;; -0.03359222 0.0419892 0.18662222; -0.05208587 0.090805836 -0.073787235; -0.0937151 0.044613346 0.16400264;;; -0.10861183 0.17393792 -0.100663245; -0.09707624 -0.09358217 -0.17184801; 0.081750296 -0.076938204 0.06963327;;; … ;;; 0.108367175 -0.14059828 0.012450817; -0.045001905 -0.19173874 0.15082036; -0.008574759 0.15566646 0.14164986;;; -0.1068494 -0.030200982 0.17375617; 0.08390958 0.08448434 0.17300947; -0.16909282 0.056779485 0.03290582;;; -0.07413019 -0.031639017 0.0054012574; 0.033854 0.0916983 -0.18244815; 0.10101544 0.15076157 -0.16013426;;;; … ;;;; -0.063150845 0.17775898 0.1134759; -0.058250155 -0.0747142 0.15499249; -0.05641818 0.073472485 -0.17872544;;; 0.2006494 0.047009416 -0.006391436; -0.07353716 -0.09418866 -0.17701323; 0.024081083 -0.19869477 -0.00069250696;;; -0.049676735 0.0745745 -0.13708827; 0.08180036 -0.02534971 0.03376078; 0.071060985 0.17471206 -0.16068852;;; … ;;; 0.049811736 -0.11532743 0.002412666; 0.16178377 0.1469855 0.10922075; 0.04306834 0.057045255 -0.16866772;;; -0.013301638 -0.15541886 -0.035264272; 0.072164506 0.16410089 -0.09112215; 0.15948404 0.11376505 -0.10438497;;; 0.09450845 -0.19718246 0.073083684; 0.051564816 -0.033912666 -0.1722798; -0.047565997 -0.014022055 0.12652722;;;; 7.1759474f-5 -0.031632643 -0.09353316; 0.027056292 0.035871707 -0.011876182; 0.17951545 -0.13125665 -0.070746765;;; 0.059760485 -0.118754365 0.093996614; 0.13821077 0.19534363 0.12683688; 0.15409826 -0.17226508 -0.051132996;;; -0.092778996 -0.079336025 -0.10923138; 0.08583115 -0.0496824 0.012007193; 0.16740225 0.19941543 -0.06619081;;; … ;;; -0.059872467 0.015607577 0.13966718; 0.10251013 0.030340316 0.20154124; 0.10072651 0.06240758 -0.19964255;;; 0.19304037 0.102634646 0.15421943; 0.15746589 0.096430525 -0.04338655; 0.044689316 -0.047149528 -0.011374255;;; -0.026993804 -0.069641635 -0.063028984; 0.052857243 -0.08376117 -0.021661483; 0.09019801 -0.0031467103 -0.18673009;;;; 0.14249101 -0.18121263 0.020420937; -0.19246987 -0.011450296 -0.1786581; -0.0054819956 -0.016487014 -0.2034004;;; 0.17387968 -0.16070417 0.13330147; 0.07486626 0.16504952 -0.073846295; -0.14962983 -0.09689917 0.000648001;;; -0.15290444 0.17162122 0.0069305687; -0.06845603 -0.1037042 -0.13593975; 0.15990376 -0.08033438 -0.14466582;;; … ;;; -0.029019274 -0.16538228 0.11829052; 0.18074305 -0.04143785 0.16447136; 0.041418746 -0.17317066 0.11413952;;; 0.14402567 -0.030171514 0.10105601; -0.13740699 -0.02441596 0.11153571; 0.086634375 -0.15125358 0.14048335;;; 0.05758249 0.118906185 -0.10497864; -0.17066532 -0.044436127 0.023161229; -0.14597139 -0.03848912 0.17103036], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], false, 1), BatchNorm(16), BatchNorm(16), NNlib.relu, 16, 16, 1), GlobalMeanPool(), Linear(o))" - ] - }, - "execution_count": 30, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "model" - ] - }, - { - "cell_type": "markdown", - "id": "01294ff6", - "metadata": {}, - "source": [ - "# Training Knet" - ] - }, - { - "cell_type": "code", - "execution_count": 31, - "id": "74f3f1da", - "metadata": {}, - "outputs": [ - { - "ename": "LoadError", - "evalue": "UndefVarError: progress! not defined", - "output_type": "error", - "traceback": [ - "UndefVarError: progress! not defined", - "", - "Stacktrace:", - " [1] macro expansion", - " @ .\\In[31]:12 [inlined]", - " [2] macro expansion", - " @ C:\\Users\\Yash\\.julia\\packages\\TimerOutputs\\4yHI4\\src\\TimerOutput.jl:237 [inlined]", - " [3] macro expansion", - " @ .\\In[31]:11 [inlined]", - " [4] top-level scope", - " @ C:\\Users\\Yash\\.julia\\packages\\TimerOutputs\\4yHI4\\src\\TimerOutput.jl:237 [inlined]", - " [5] top-level scope", - " @ .\\In[31]:0", - " [6] eval", - " @ .\\boot.jl:368 [inlined]", - " [7] include_string(mapexpr::typeof(REPL.softscope), mod::Module, code::String, filename::String)", - " @ Base .\\loading.jl:1428" - ] - } - ], - "source": [ - "\n", - "\n", - "loss(test_x, test_y) = nll(model(test_x), test_y)\n", - "evalcb = () -> (loss(test_x, test_y)) #function that will be called to get the loss \n", - "const to = TimerOutput() # creating a TimerOutput, keeps track of everything\n", - "\n", - "\n", - "@timeit to \"Train Total\" begin\n", - " for epoch in 1:10\n", - " train_epoch = epoch > 1 ? \"train_epoch\" : \"train_ji\"\n", - " @timeit to train_epoch begin\n", - " progress!(adam(model, train_batches; lr = 1e-3))\n", - " end\n", - " \n", - " evaluation = epoch > 1 ? \"evaluation\" : \"eval_jit\"\n", - " @timeit to evaluation begin\n", - " accuracy(model, train_batches)\n", - " end \n", - " \n", - " end \n", - "end " - ] - }, - { - "cell_type": "markdown", - "id": "a57b3c8d", - "metadata": {}, - "source": [ - "# Evaluation Function" - ] - }, - { - "cell_type": "code", - "execution_count": 32, - "id": "02f69609", - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "# function evaluate(model, test_loader)\n", - "# preds = []\n", - "# targets = []\n", - "# for (x, y) in test_loader\n", - "# # Get model predictions\n", - "# # Note argmax of nd-array gives CartesianIndex\n", - "# # Need to grab the first element of each CartesianIndex to get the true index\n", - "# logits = model(x)\n", - "# ŷ = map(i -> i[1], argmax(logits, dims=1))\n", - "# append!(preds, ŷ)\n", - "\n", - "# # Get true labels\n", - "# append!(targets, y)\n", - "# end\n", - "# accuracy = sum(preds .== targets) / length(targets)\n", - "# return accuracy\n", - "# end" - ] - }, - { - "cell_type": "markdown", - "id": "f2072bd8", - "metadata": {}, - "source": [ - "# Training Loop" - ] - }, - { - "cell_type": "code", - "execution_count": 33, - "id": "cc39bcab", - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "# # Setup timing output\n", - "# const to = TimerOutput()" - ] - }, - { - "cell_type": "code", - "execution_count": 34, - "id": "9b5c088c", - "metadata": {}, - "outputs": [], - "source": [ - "# # last_loss = 0;\n", - "# # @timeit to \"total_training_time\" begin\n", - "# for epoch in 1:10\n", - "# timing_name = epoch > 1 ? \"average_epoch_training_time\" : \"train_jit\"\n", - "\n", - "# # Create lazily evaluated augmented training data\n", - "# train_batches = mappedarray(augmentbatch, batchview(shuffleobs((train_x_padded, train_y)), size=train_batch_size));\n", - "\n", - "# @timeit to timing_name begin\n", - "# losses = []\n", - "# for (x, y) in train_batches\n", - "# # loss_function does forward pass\n", - "# # Yota.jl grad function computes model parameter gradients in g[2]\n", - "# loss, g = grad(loss_function, model, x, y)\n", - " \n", - "# # Optimiser updates parameters\n", - "# Optimisers.update!(state, model, g[2])\n", - "# push!(losses, loss)\n", - "# end\n", - "# last_loss = mean(losses)\n", - "# @info(\"epoch (mean(losses))\")\n", - "# end\n", - "# # timing_name = epoch > 1 ? \"average_inference_time\" : \"eval_jit\"\n", - "# # @timeit to timing_name begin\n", - "# # acc = evaluate(model, test_loader)\n", - "# # @info(\"epoch (acc)\")\n", - "# # end\n", - "# end\n", - "# end" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1955c486", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Julia 1.8.2", - "language": "julia", - "name": "julia-1.8" - }, - "language_info": { - "file_extension": ".jl", - "mimetype": "application/julia", - "name": "julia", - "version": "1.8.2" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} From 8f79a0e811aae6c56f12d6c16897a0dc6989f2f7 Mon Sep 17 00:00:00 2001 From: Yash Pokra <79229682+yashpokra@users.noreply.github.com> Date: Thu, 9 Feb 2023 01:44:13 -0500 Subject: [PATCH 25/26] Add files via upload --- .../julia_resnetmodel_updated_model v2.ipynb | 1036 +++++++++++++++++ 1 file changed, 1036 insertions(+) create mode 100644 convolutional neural network/julia_resnetmodel_updated_model v2.ipynb diff --git a/convolutional neural network/julia_resnetmodel_updated_model v2.ipynb b/convolutional neural network/julia_resnetmodel_updated_model v2.ipynb new file mode 100644 index 0000000..68d134d --- /dev/null +++ b/convolutional neural network/julia_resnetmodel_updated_model v2.ipynb @@ -0,0 +1,1036 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "69f91157", + "metadata": {}, + "source": [ + "# Imports" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "9b1583d4", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "using Yota;\n", + "using MLDatasets;\n", + "using NNlib;\n", + "using Statistics;\n", + "using Distributions;\n", + "using Functors;\n", + "using Optimisers;\n", + "using MLUtils: DataLoader;\n", + "using OneHotArrays: onehotbatch\n", + "using Knet:Knet,conv4, adam\n", + "using Knet: dir, accuracy, progress, sgd, gc\n", + "using Metrics;\n", + "using TimerOutputs;\n", + "using Flux: BatchNorm, kaiming_uniform, nfan;\n", + "using Functors\n", + "\n", + "# Model creation\n", + "using NNlib;\n", + "using Flux: BatchNorm, Chain, GlobalMeanPool, kaiming_uniform, nfan;\n", + "using Statistics;\n", + "using Distributions;\n", + "using Functors;\n", + "\n", + "# Data processing\n", + "using MLDatasets;\n", + "using MLUtils: DataLoader;\n", + "using MLDataPattern;\n", + "using ImageCore;\n", + "using Augmentor;\n", + "using ImageFiltering;\n", + "using MappedArrays;\n", + "using Random;\n", + "using Flux: DataLoader;\n", + "# using OneHotArrays: onehotbatch\n", + "\n", + "\n", + "# Training\n", + "# using Yota;\n", + "using Zygote;\n", + "using Optimisers;\n", + "using Metrics;\n", + "using TimerOutputs;\n", + "\n", + "\n", + "\n", + "# Issue when running this\n", + "#using Knet: Knet, dir, accuracy, progress, sgd, gc, Data, nll, relu\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "id": "19aff91e", + "metadata": {}, + "source": [ + "# Conv 2D" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "481a3d9a", + "metadata": {}, + "outputs": [], + "source": [ + "mutable struct Conv2D{T}\n", + " w::AbstractArray{T, 4}\n", + " b::AbstractVector{T}\n", + " use_bias::Bool\n", + " padding::Int \n", + "end\n", + "\n", + "@functor Conv2D (w, b)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "59da1b27", + "metadata": {}, + "outputs": [], + "source": [ + "function Conv2D(kernel_size::Tuple{Int, Int}, in_channels::Int, out_channels::Int;\n", + " bias::Bool=false, padding::Int=1)\n", + " w_size = (kernel_size..., in_channels, out_channels)\n", + " w = kaiming_uniform(w_size...)\n", + " (fan_in, fan_out) = nfan(w_size)\n", + " \n", + " if bias\n", + " # Init bias with fan_in from weights. Use gain = √2 for ReLU\n", + " bound = √3 * √2 / √fan_in\n", + " rng = Uniform(-bound, bound)\n", + " b = rand(rng, out_channels, Float32)\n", + " else\n", + " b = zeros(Float32, out_channels)\n", + " end\n", + "\n", + " return Conv2D(w, b, bias, padding)\n", + "end\n", + "\n", + "function (self::Conv2D)(x::AbstractArray; stride::Int=1, pad::Int=0, dilation::Int=1)\n", + " y = conv4(self.w, x; stride=stride, padding=self.padding, dilation=dilation)\n", + " if self.use_bias\n", + " # Bias is applied channel-wise\n", + " (w, h, c, b) = size(y)\n", + " bias = reshape(self.b, (1, 1, c, 1))\n", + " y = y .+ bias\n", + " end\n", + " return y\n", + "end\n", + " " + ] + }, + { + "cell_type": "markdown", + "id": "252e934f", + "metadata": {}, + "source": [ + "# ResNetLayer" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "3e66be4f", + "metadata": {}, + "outputs": [], + "source": [ + "mutable struct ResNetLayer\n", + " conv1::Conv2D\n", + " conv2::Conv2D\n", + " bn1::BatchNorm\n", + " bn2::BatchNorm\n", + " f::Function\n", + " in_channels::Int\n", + " channels::Int\n", + " stride::Int\n", + "end\n", + "\n", + "@functor ResNetLayer (conv1, conv2, bn1, bn2)\n", + "\n", + "function residual_identity(layer::ResNetLayer, x::AbstractArray{T, 4}) where {T<:Number}\n", + " (w, h, c, b) = size(x)\n", + " stride = layer.stride\n", + " if stride > 1\n", + " @assert ((w % stride == 0) & (h % stride == 0)) \"Spatial dimensions are not divisible by `stride`\"\n", + " \n", + " # Strided downsample\n", + " x_id = copy(x[begin:2:end, begin:2:end, :, :])\n", + " else\n", + " x_id = x\n", + " end\n", + "\n", + " channels = layer.channels\n", + " in_channels = layer.in_channels\n", + " if in_channels < channels\n", + " # Zero padding on extra channels\n", + " (w, h, c, b) = size(x_id)\n", + " pad = zeros(w, h, channels - in_channels, b)\n", + " x_id = cat(x_id, pad; dims=3)\n", + " elseif in_channels > channels\n", + " error(\"in_channels > out_channels not supported\")\n", + " end\n", + " return x_id\n", + "end\n", + "\n", + "function ResNetLayer(in_channels::Int, channels::Int; stride=1, f=relu)\n", + " bn1 = BatchNorm(in_channels)\n", + " conv1 = Conv2D((3, 3), in_channels, channels, bias=false)\n", + " bn2 = BatchNorm(channels)\n", + " conv2 = Conv2D((3, 3), channels, channels, bias=false)\n", + "\n", + " return ResNetLayer(conv1, conv2, bn1, bn2, f, in_channels, channels, stride)\n", + "end\n", + "\n", + "\n", + "function (self::ResNetLayer)(x::AbstractArray)\n", + " identity = residual_identity(self, x)\n", + " z = self.bn1(x)\n", + " z = self.f(z)\n", + " z = self.conv1(z; pad=1, stride=self.stride) # pad=1 will keep same size with (3x3) kernel\n", + " z = self.bn2(z)\n", + " z = self.f(z)\n", + " z = self.conv2(z; pad=1)\n", + "\n", + " y = z + identity\n", + " return y\n", + "end" + ] + }, + { + "cell_type": "markdown", + "id": "9f06e04e", + "metadata": {}, + "source": [ + "# Testing ResNetLayer" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "7cdc72a9", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(16, 16, 10, 4)" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "\n", + "l = ResNetLayer(3, 10; stride=2);\n", + "x = randn(Float32, (32, 32, 3, 4));\n", + "y = l(x);\n", + "size(y)" + ] + }, + { + "cell_type": "markdown", + "id": "7b21b952", + "metadata": {}, + "source": [ + "# Linear Layer" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "8987f02c", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING: method definition for Linear at In[6]:22 declares type variable T but does not use it.\n" + ] + } + ], + "source": [ + "mutable struct Linear\n", + " W::AbstractMatrix{T} where T\n", + " b::AbstractVector{T} where T\n", + "end\n", + "\n", + "@functor Linear\n", + "\n", + "# Init\n", + "function Linear(in_features::Int, out_features::Int)\n", + " k_sqrt = sqrt(1 / in_features)\n", + " d = Uniform(-k_sqrt, k_sqrt)\n", + " return Linear(rand(d, out_features, in_features), rand(d, out_features))\n", + "end\n", + "Linear(in_out::Pair{Int, Int}) = Linear(in_out[1], in_out[2])\n", + "\n", + "function Base.show(io::IO, l::Linear)\n", + " o, i = size(l.W)\n", + " print(io, \"Linear(o)\")\n", + "end\n", + "\n", + "# Forward\n", + "(l::Linear)(x::AbstractArray) where T = l.W * x .+ l.b\n" + ] + }, + { + "cell_type": "markdown", + "id": "79e8c6ca", + "metadata": {}, + "source": [ + "# Defining a Chain Layer" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "a47a2eaa", + "metadata": {}, + "outputs": [], + "source": [ + "# Define a chain of layers and a loss function:\n", + "struct Chain1; layers; end\n", + "(c::Chain1)(x) = (for l in c.layers; x = l(x); end; x)\n", + "(c::Chain1)(x,y) = nll(c(x),y)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "02eca287", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "ResNet20Model" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# ResNet Architecture\n", + "\n", + "mutable struct ResNet20Model\n", + " input_conv::Conv2D\n", + " resnet_blocks::Chain1\n", + " pool::GlobalMeanPool\n", + " linear::Linear\n", + "end\n", + "\n", + "@functor ResNet20Model\n", + "\n", + "function ResNet20Model(in_channels::Int, num_classes::Int)\n", + " resnet_blocks = Chain1((\n", + " block_1 = ResNetLayer(16, 16),\n", + " block_2 = ResNetLayer(16, 16),\n", + " block_3 = ResNetLayer(16, 16),\n", + " block_4 = ResNetLayer(16, 32; stride=2),\n", + " block_5 = ResNetLayer(32, 32),\n", + " block_6 = ResNetLayer(32, 32),\n", + " block_7 = ResNetLayer(32, 64; stride=2),\n", + " block_8 = ResNetLayer(64, 64),\n", + " block_9 = ResNetLayer(64, 64)\n", + " ))\n", + " return ResNet20Model(\n", + " Conv2D((3, 3), in_channels, 16, bias=false),\n", + " resnet_blocks,\n", + " GlobalMeanPool(),\n", + " Linear(64, num_classes)\n", + " )\n", + "end" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "cdef0144", + "metadata": {}, + "outputs": [], + "source": [ + "function (self::ResNet20Model)(x::AbstractArray)\n", + " z = self.input_conv(x)\n", + " z = self.resnet_blocks(z)\n", + " z = self.pool(z)\n", + " z = dropdims(z, dims=(1, 2))\n", + " y = self.linear(z)\n", + " return y\n", + "end\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "25c15eb5", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "┌ Warning: Slow fallback implementation invoked for conv! You probably don't want this; check your datatypes.\n", + "│ yT = Float64\n", + "│ T1 = Float64\n", + "│ T2 = Float32\n", + "└ @ NNlib C:\\Users\\Yash\\.julia\\packages\\NNlib\\0QnJJ\\src\\conv.jl:285\n" + ] + }, + { + "data": { + "text/plain": [ + "(10, 4)" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "\n", + "# Testing ResNet20 model\n", + "# Expected output: (10, 4)\n", + "m = ResNet20Model(3, 10);\n", + "inputs = randn(Float32, (32, 32, 3, 4))\n", + "outputs = m(inputs);\n", + "size(outputs)\n", + " " + ] + }, + { + "cell_type": "markdown", + "id": "8e43380e", + "metadata": {}, + "source": [ + "# Data Preprocessing " + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "84857fa0", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "32×32×3×45000 Array{Float32, 4}\n", + "45000-element Vector{Int64}\n", + "32×32×3×5000 Array{Float32, 4}\n", + "5000-element Vector{Int64}\n", + "32×32×3×10000 Array{Float32, 4}\n", + "10000-element Vector{Int64}\n" + ] + } + ], + "source": [ + "# This loads the CIFAR-10 Dataset for training, validation, and evaluation\n", + "xtrn,ytrn = CIFAR10.traindata(Float32, 1:45000)\n", + "xval,yval = CIFAR10.traindata(Float32, 45001:50000)\n", + "xtst,ytst = CIFAR10.testdata(Float32)\n", + "println.(summary.((xtrn,ytrn,xval, yval, xtst,ytst)));" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "45acc000", + "metadata": {}, + "outputs": [], + "source": [ + "# Normalize all the data\n", + "\n", + "means = reshape([0.485, 0.465, 0.406], (1, 1, 3, 1))\n", + "stdevs = reshape([0.229, 0.224, 0.225], (1, 1, 3, 1))\n", + "normalize(x) = (x .- means) ./ stdevs\n", + "\n", + "train_x = normalize(xtrn);\n", + "val_x = normalize(xval);\n", + "test_x = normalize(xtst);" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "9e93cda3", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "splitobs (generic function with 11 methods)" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "\n", + "# Train-test split\n", + "# Copied from https://github.com/JuliaML/MLUtils.jl/blob/v0.2.11/src/splitobs.jl#L65\n", + "# obsview doesn't work with this data, so use getobs instead\n", + "\n", + "import MLDataPattern.splitobs;\n", + "\n", + "function splitobs(data; at, shuffle::Bool=false)\n", + " if shuffle\n", + " data = shuffleobs(data)\n", + " end\n", + " n = numobs(data)\n", + " return map(idx -> MLDataPattern.getobs(data, idx), splitobs(n, at))\n", + "end" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "9c649cac", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "# Notebook testing: Use less data\n", + "train_x, train_y = MLDatasets.getobs((train_x, ytrn), 1:500);\n", + "\n", + "val_x, val_y = MLDatasets.getobs((val_x, yval), 1:50);\n", + "\n", + "test_x, test_y = MLDatasets.getobs((test_x, ytst), 1:50);" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "75266187", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(40, 40, 3, 500)" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "\n", + "# Pad the training data for further augmentation\n", + "train_x_padded = padarray(train_x, Fill(0, (4, 4, 0, 0))); \n", + "size(train_x_padded) # Should be (40, 40, 3, 50000)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "fc788d3e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "6-step Augmentor.ImmutablePipeline:\n", + " 1.) Permute dimension order to (3, 1, 2)\n", + " 2.) Combine color channels into colorant RGB\n", + " 3.) Either: (50%) Flip the X axis. (50%) No operation.\n", + " 4.) Crop random window with size (32, 32)\n", + " 5.) Split colorant into its color channels\n", + " 6.) Permute dimension order to (2, 3, 1)" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pl = PermuteDims((3, 1, 2)) |> CombineChannels(RGB) |> Either(FlipX(), NoOp()) |> RCropSize(32, 32) |> SplitChannels() |> PermuteDims((2, 3, 1))" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "815faf28", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "outbatch (generic function with 1 method)" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Create an output array for augmented images\n", + "outbatch(X) = Array{Float32}(undef, (32, 32, 3, nobs(X)))" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "2e86e8f7", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "augmentbatch (generic function with 1 method)" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Function that takes a batch (images and targets) and augments the images\n", + "augmentbatch((X, y)) = (augmentbatch!(outbatch(X), X, pl), y)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "e4d362ce", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "┌ Warning: The specified values for size and/or count will result in 4 unused data points\n", + "└ @ MLDataPattern C:\\Users\\Yash\\.julia\\packages\\MLDataPattern\\KlSmO\\src\\dataview.jl:205\n" + ] + } + ], + "source": [ + "\n", + "# Shuffled and batched dataset of augmented images\n", + "train_batch_size = 16\n", + "\n", + "train_batches = mappedarray(augmentbatch, batchview(shuffleobs((train_x_padded, train_y)), size=train_batch_size));\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "e2386c3c", + "metadata": {}, + "outputs": [], + "source": [ + "# Test and Validation data\n", + "test_batch_size = 32\n", + "\n", + "val_loader = DataLoader((val_x, val_y), shuffle=true, batchsize=test_batch_size);\n", + "test_loader = DataLoader((test_x, test_y), shuffle=true, batchsize=test_batch_size);" + ] + }, + { + "cell_type": "markdown", + "id": "05599606", + "metadata": {}, + "source": [ + "# Training setup" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "fd7aadd5", + "metadata": {}, + "outputs": [], + "source": [ + "#Sparse Cross Entropy function" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "9f6c4d38", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "sparse_logit_cross_entropy (generic function with 1 method)" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "\n", + "\"\"\"\n", + " sparse_logit_cross_entropy(logits, labels)\n", + "\n", + "Efficient computation of cross entropy loss with model logits and integer indices as labels.\n", + "Integer indices are from [0, N-1], where N is the number of classes\n", + "Similar to TensorFlow SparseCategoricalCrossEntropy\n", + "\n", + "# Arguments\n", + "- `logits::AbstractArray`: 2D model logits tensor of shape (classes, batch size)\n", + "- `labels::AbstractArray`: 1D integer label indices of shape (batch size,)\n", + "\n", + "# Returns\n", + "- `loss::Float32`: Cross entropy loss\n", + "\"\"\"\n", + "# function sparse_logit_cross_entropy(logits, labels)\n", + "# log_probs = logsoftmax(logits);\n", + "# # Select indices of labels for loss\n", + "# log_probs = map((x, i) -> x[i + 1], eachslice(log_probs; dims=2), labels);\n", + "# loss = -mean(log_probs);\n", + "# return loss\n", + "# end\n", + "\n", + "function sparse_logit_cross_entropy(logits, labels)\n", + " log_probs = logsoftmax(logits);\n", + " inds = CartesianIndex.(labels .+ 1, axes(log_probs, 2));\n", + " # Select indices of labels for loss\n", + " log_probs = log_probs[inds];\n", + " loss = -mean(log_probs);\n", + " return loss\n", + "end\n" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "3998a220", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "# Create model with 3 input channels and 10 classes\n", + "model = ResNet20Model(3, 10);" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "6fa4497b", + "metadata": {}, + "outputs": [], + "source": [ + "# Setup AdamW optimizer\n", + "β = (0.9, 0.999);\n", + "decay = 1e-4;\n", + "state = Optimisers.setup(Optimisers.Adam(1e-3, β, decay), model);" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "b852506d", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "(x, y) = first(train_batches);" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "e71cc12e", + "metadata": {}, + "outputs": [], + "source": [ + "# loss, g = grad(loss_function, model, x, y);" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "1a9a8a89", + "metadata": {}, + "outputs": [], + "source": [ + "mutable struct ResNet5\n", + " input_conv::Conv2D\n", + " resnet_block::ResNetLayer\n", + " pool::GlobalMeanPool\n", + " linear::Linear\n", + "end\n", + "\n", + "@functor ResNet5\n", + "\n", + "function ResNet5(in_channels::Int, num_classes::Int)\n", + " return ResNet5(\n", + " Conv2D((3, 3), in_channels, 16, bias=false),\n", + " ResNetLayer(16, 16),\n", + " GlobalMeanPool(),\n", + " Linear(16, num_classes)\n", + " )\n", + "end\n", + "\n", + "function (self::ResNet5)(x::AbstractArray)\n", + " z = self.input_conv(x)\n", + " z = self.resnet_block(z)\n", + " z = self.pool(z)\n", + " z = dropdims(z, dims=(1, 2))\n", + " y = self.linear(z)\n", + " return y\n", + "end\n", + "\n", + "\n", + "# function loss_function(model::ResNet5, x::AbstractArray, y::AbstractArray)\n", + "# ŷ = model(x)\n", + "# loss = sparse_logit_cross_entropy(ŷ, y)\n", + "# return loss\n", + "# end" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "028a6d25", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "# Yota is unable to compute gradients through the ResNet for some reason, maybe due to residual connections?\n", + "# loss, g = grad(loss_function, model, x, y)\n", + "model = ResNet5(3, 10);\n", + "\n", + "# loss, g = Zygote.gradient(loss_function, model, x, y);" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "696231c0", + "metadata": {}, + "outputs": [], + "source": [ + "# g" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "id": "7d23487b", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "ResNet5(Conv2D{Float32}([-0.3269322 -0.09079589 -0.30220258; 0.29980195 -0.35697645 -0.43193826; 0.41894063 -0.27608046 -0.35023037;;; -0.11048572 -0.18733561 -0.048941202; -0.25535512 0.41386655 0.23646444; 0.1226552 0.19139434 -0.3201441;;; -0.2881651 0.4041223 -0.11729951; 0.28896266 0.124178275 0.10890088; -0.07136462 0.37597623 0.2907424;;;; 0.0116342725 -0.06491064 0.03947901; 0.36589766 -0.31363672 0.32354057; -0.101177834 0.22076249 0.26570976;;; -0.22781743 -0.16796216 0.079579934; 0.43243396 -0.18935399 0.3949348; -0.3725451 -0.06775151 0.21907443;;; -0.05270441 -0.43405735 -0.44125763; -0.47045088 -0.30292767 0.014733751; -0.04850591 -0.2133474 0.2412362;;;; 0.2572607 0.18735757 -0.33566207; -0.03157889 -0.04323261 0.1315869; -0.16356815 -0.23604983 0.051579874;;; -0.3262316 0.40397793 0.07843399; 0.17368728 0.31032175 -0.2273731; -0.20191403 0.11151084 0.33216488;;; -0.34083363 -0.46381113 0.055753145; -0.44104743 0.31393462 0.2622986; 0.13619547 -0.12979876 -0.043511562;;;; … ;;;; -0.020386824 -0.104114234 0.37638608; 0.41557428 -0.19767518 0.15894295; 0.150955 0.4521936 -0.26687187;;; 0.38604698 0.30180404 -0.1059084; -0.15032865 -0.031554278 -0.21704273; -0.03794346 0.3485954 0.38278958;;; -0.3071966 -0.3373205 0.26615357; 0.4422373 0.13577293 -0.10324652; 0.2894867 -0.23344572 0.39201385;;;; -0.13550712 0.30746755 0.38600484; 0.35319903 0.27227426 -0.42721114; -0.4167391 0.460941 0.23783916;;; 0.45264333 0.30202055 -0.32739767; -0.34008625 -0.23484135 0.19659689; -0.14264174 -0.09916833 0.27199847;;; -0.38800704 -0.060515907 0.4428402; 0.24729028 0.38564798 0.008954014; 0.10717848 0.3565583 -0.40317935;;;; 0.29413882 0.032473866 -0.24675108; 0.18658455 0.4415373 -0.07814981; 0.33296683 -0.1115019 -0.33509403;;; -0.24343053 0.042397596 0.35703608; 0.36186588 -0.05911843 -0.08993424; 0.13785711 -0.26265797 -0.067820854;;; -0.10355408 -0.26968983 0.097447224; -0.25024468 -0.1599089 0.4510931; 0.4365045 0.18134817 0.32099614], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], false, 1), ResNetLayer(Conv2D{Float32}([0.19761112 0.13021815 -0.028087448; -0.00057760417 -0.11352962 0.05831184; 0.124577686 0.089965284 0.13280511;;; -0.11783607 0.089995995 -0.1501678; -0.015031724 -0.012284083 -0.0898654; -0.09346841 -0.1621385 -0.0109067345;;; 0.17110728 -0.10372673 -0.015643882; 0.008885473 0.09228887 0.20323306; 0.015771464 -0.12023175 0.039679002;;; … ;;; -0.046293672 0.0779443 0.011689981; 0.122653276 -0.106488414 0.16222343; 0.006254584 -0.017572897 0.07593362;;; 0.0962515 0.18770924 0.19851086; -0.13754818 -0.029091569 0.06325006; 0.19043466 -0.08298591 0.17023039;;; -0.061017822 0.17036065 0.06359598; -0.005008466 -0.15680619 0.1233305; 0.17700247 0.20328574 -0.16379757;;;; -0.18611959 -0.043676604 0.108012736; 0.02672942 -0.13492495 0.15822719; -0.002692112 0.16126484 -0.025754329;;; -0.15724711 -0.04828544 -0.1707655; -0.0970394 -0.055499617 -0.024079112; 0.025048656 -0.082812436 -0.022185529;;; -0.1882148 0.16506943 0.000598823; -0.0912264 -0.16550767 -0.16717595; 0.06827358 0.17189996 -0.16707191;;; … ;;; 0.18228532 -0.17576592 -0.18167828; 0.19565201 -0.20267504 -0.18348633; -0.15575626 0.15550001 -0.14896752;;; -0.08209669 0.15093498 -0.11133591; -0.05105649 0.022263639 0.027264051; -0.02951181 0.18026373 -0.07041432;;; 0.025740579 0.075161055 -0.0525457; 0.092696406 -0.09947786 -0.12829517; -0.069097444 0.12314727 -0.1672388;;;; 0.10623432 -0.057691578 0.119959146; 0.03539229 -0.022670422 -0.111270726; 0.008098309 0.0037883115 -0.026243139;;; -0.070837826 -0.11017056 -0.178822; 0.06665229 -0.005612837 0.07103156; 0.1561577 0.031744555 0.0140344165;;; 0.15617746 -0.14973398 0.07564629; 0.0016903019 0.18394831 0.09675205; 0.19826071 -0.09340203 0.1700775;;; … ;;; 0.069550656 0.14834794 -0.06968259; 0.20270519 0.11043808 0.027695874; 0.13334787 0.16532846 0.048797905;;; 0.18591698 0.018436953 -0.0032594716; -0.08772257 -0.052733872 -0.14566335; 0.011975072 -0.15187715 0.10042701;;; 0.112629846 0.18635167 0.16804218; -0.19342335 0.010884567 0.14426668; -0.059680305 -0.038495857 0.19673485;;;; … ;;;; 0.12924753 0.11957796 -0.107034236; -0.15904024 -0.1602915 -0.094139844; 0.08867885 0.17599945 -0.04848101;;; -0.10934039 -0.19765444 -0.14756997; 0.17149675 0.14435652 -0.10002485; 0.18451841 0.066142604 0.17169759;;; -0.11077233 -0.0039441674 -0.069028825; -0.110270225 0.00804806 -0.080900356; -0.16658163 0.054695323 -0.015006443;;; … ;;; 0.00212283 0.18980761 -0.122687295; -0.115651555 -0.14763168 -0.06032928; -0.18875845 0.16435969 0.015897948;;; -0.16687778 -0.13686112 -0.10878749; -0.06288342 0.18291253 0.06357345; 0.09998129 0.1873411 0.08618598;;; -0.08974118 -0.0026351716 0.16592684; 0.12683599 -0.0756004 -0.058696333; -0.025129566 -0.17873748 0.15622494;;;; -0.020486929 -0.1648232 0.0033708948; 0.07356035 -0.04425343 -0.006855549; -0.18207838 0.07768629 -0.11125955;;; 0.05911343 -0.13387455 -0.02985895; 0.19365218 -0.0906411 -0.047491293; 0.1927835 -0.12751262 -0.13904938;;; 0.09187676 0.122309804 0.055273946; 0.1549648 0.06880033 -0.11684039; -0.10373542 0.018027786 -0.1502131;;; … ;;; -0.013978474 -0.051094815 -0.0014352868; -0.046254106 0.19333138 0.0050440175; 0.02186895 -0.083350666 -0.20366172;;; 0.087388806 -0.12761207 -0.027745271; -0.13694362 -0.14262795 0.15379757; -0.088099465 0.05600551 0.05024689;;; -0.18494785 3.431023f-5 -0.059908554; -0.0059224563 0.1425408 -0.02109383; 0.08903239 0.13319586 -0.1353602;;;; 0.0057842666 0.15523006 -0.031229122; -0.19541235 0.051155504 0.06527311; -0.1730501 -0.15493153 -0.1712542;;; 0.15538698 0.19521399 -0.061315224; 0.05348293 -0.15681091 0.06969483; -0.15938468 0.00255144 -0.008160578;;; -0.042394936 -0.023364875 0.030308511; -0.106786154 -0.062052142 0.110279866; -0.024075657 -0.20078343 -0.11427486;;; … ;;; 0.18105797 -0.13237794 -0.1409754; 0.19100334 -0.061168298 0.046601903; 0.14125203 0.17658699 -0.14446235;;; 0.05207166 0.06348066 -0.09771767; 0.0501899 0.09775206 0.1500323; -0.048741743 -0.13970241 -0.013351497;;; 0.11777932 -0.08002654 0.18714364; 0.12242211 0.17391886 0.12460178; -0.0737201 -0.15127876 -0.10289584], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], false, 1), Conv2D{Float32}([0.09717922 -0.03407784 0.014590169; -0.18445018 0.12831593 -0.07739436; -0.09349194 0.0037209808 -0.13268006;;; 0.17895766 -0.02247517 -0.037817243; -0.111089125 0.12584838 0.0016550183; 0.077631414 -0.044876367 -0.14228758;;; 0.13512933 0.14910595 -0.07057957; -0.121997386 0.13264237 0.19979271; -0.20085463 0.11576198 0.14067839;;; … ;;; 0.10820168 0.12200413 -0.08980403; 0.0011746995 0.1944284 0.18440181; 0.1767769 0.026093902 0.12368353;;; -0.007022744 -0.11667823 -0.20263171; 0.17421563 -0.11981526 -0.049599864; -0.15285212 -0.11957154 0.11077546;;; -0.158289 0.14341965 0.027710011; 0.1390186 -0.12730537 -0.16582859; -0.12259298 -0.11377486 -0.056429546;;;; 0.13843888 0.03513197 -0.15407993; -0.022550944 0.13677955 0.10618293; -0.14883694 0.16104752 -0.033526104;;; 0.11633827 0.073986895 0.039874643; 0.017279312 0.07693811 0.093623415; 0.10816275 -0.20379694 0.018923648;;; 0.15248454 0.06787851 -0.01800759; 0.09909373 -0.060843762 -0.027705412; -0.099589966 -0.1385283 -0.14809528;;; … ;;; -0.16959943 -0.06297494 0.05770058; 0.20113298 -0.047386974 0.11995617; -0.0032085418 -0.022590097 0.010460361;;; -0.18173099 0.13940151 -0.045853406; 0.07445716 0.10802089 -0.18385637; 0.12254616 -0.1648605 -0.1824452;;; -0.011050132 0.12620825 -0.021450827; 0.1770418 0.116070434 -0.091566995; -0.19971325 -0.16312777 -0.013281781;;;; -0.035456408 -0.036289588 -0.05979168; 0.09422635 -0.056371607 -0.0059971116; -0.19720648 0.015351248 -0.09199891;;; -0.03359222 0.0419892 0.18662222; -0.05208587 0.090805836 -0.073787235; -0.0937151 0.044613346 0.16400264;;; -0.10861183 0.17393792 -0.100663245; -0.09707624 -0.09358217 -0.17184801; 0.081750296 -0.076938204 0.06963327;;; … ;;; 0.108367175 -0.14059828 0.012450817; -0.045001905 -0.19173874 0.15082036; -0.008574759 0.15566646 0.14164986;;; -0.1068494 -0.030200982 0.17375617; 0.08390958 0.08448434 0.17300947; -0.16909282 0.056779485 0.03290582;;; -0.07413019 -0.031639017 0.0054012574; 0.033854 0.0916983 -0.18244815; 0.10101544 0.15076157 -0.16013426;;;; … ;;;; -0.063150845 0.17775898 0.1134759; -0.058250155 -0.0747142 0.15499249; -0.05641818 0.073472485 -0.17872544;;; 0.2006494 0.047009416 -0.006391436; -0.07353716 -0.09418866 -0.17701323; 0.024081083 -0.19869477 -0.00069250696;;; -0.049676735 0.0745745 -0.13708827; 0.08180036 -0.02534971 0.03376078; 0.071060985 0.17471206 -0.16068852;;; … ;;; 0.049811736 -0.11532743 0.002412666; 0.16178377 0.1469855 0.10922075; 0.04306834 0.057045255 -0.16866772;;; -0.013301638 -0.15541886 -0.035264272; 0.072164506 0.16410089 -0.09112215; 0.15948404 0.11376505 -0.10438497;;; 0.09450845 -0.19718246 0.073083684; 0.051564816 -0.033912666 -0.1722798; -0.047565997 -0.014022055 0.12652722;;;; 7.1759474f-5 -0.031632643 -0.09353316; 0.027056292 0.035871707 -0.011876182; 0.17951545 -0.13125665 -0.070746765;;; 0.059760485 -0.118754365 0.093996614; 0.13821077 0.19534363 0.12683688; 0.15409826 -0.17226508 -0.051132996;;; -0.092778996 -0.079336025 -0.10923138; 0.08583115 -0.0496824 0.012007193; 0.16740225 0.19941543 -0.06619081;;; … ;;; -0.059872467 0.015607577 0.13966718; 0.10251013 0.030340316 0.20154124; 0.10072651 0.06240758 -0.19964255;;; 0.19304037 0.102634646 0.15421943; 0.15746589 0.096430525 -0.04338655; 0.044689316 -0.047149528 -0.011374255;;; -0.026993804 -0.069641635 -0.063028984; 0.052857243 -0.08376117 -0.021661483; 0.09019801 -0.0031467103 -0.18673009;;;; 0.14249101 -0.18121263 0.020420937; -0.19246987 -0.011450296 -0.1786581; -0.0054819956 -0.016487014 -0.2034004;;; 0.17387968 -0.16070417 0.13330147; 0.07486626 0.16504952 -0.073846295; -0.14962983 -0.09689917 0.000648001;;; -0.15290444 0.17162122 0.0069305687; -0.06845603 -0.1037042 -0.13593975; 0.15990376 -0.08033438 -0.14466582;;; … ;;; -0.029019274 -0.16538228 0.11829052; 0.18074305 -0.04143785 0.16447136; 0.041418746 -0.17317066 0.11413952;;; 0.14402567 -0.030171514 0.10105601; -0.13740699 -0.02441596 0.11153571; 0.086634375 -0.15125358 0.14048335;;; 0.05758249 0.118906185 -0.10497864; -0.17066532 -0.044436127 0.023161229; -0.14597139 -0.03848912 0.17103036], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], false, 1), BatchNorm(16), BatchNorm(16), NNlib.relu, 16, 16, 1), GlobalMeanPool(), Linear(o))" + ] + }, + "execution_count": 30, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model" + ] + }, + { + "cell_type": "markdown", + "id": "01294ff6", + "metadata": {}, + "source": [ + "# Training Knet" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "74f3f1da", + "metadata": {}, + "outputs": [ + { + "ename": "LoadError", + "evalue": "UndefVarError: progress! not defined", + "output_type": "error", + "traceback": [ + "UndefVarError: progress! not defined", + "", + "Stacktrace:", + " [1] macro expansion", + " @ .\\In[31]:12 [inlined]", + " [2] macro expansion", + " @ C:\\Users\\Yash\\.julia\\packages\\TimerOutputs\\4yHI4\\src\\TimerOutput.jl:237 [inlined]", + " [3] macro expansion", + " @ .\\In[31]:11 [inlined]", + " [4] top-level scope", + " @ C:\\Users\\Yash\\.julia\\packages\\TimerOutputs\\4yHI4\\src\\TimerOutput.jl:237 [inlined]", + " [5] top-level scope", + " @ .\\In[31]:0", + " [6] eval", + " @ .\\boot.jl:368 [inlined]", + " [7] include_string(mapexpr::typeof(REPL.softscope), mod::Module, code::String, filename::String)", + " @ Base .\\loading.jl:1428" + ] + } + ], + "source": [ + "\n", + "\n", + "loss(test_x, test_y) = nll(model(test_x), test_y)\n", + "evalcb = () -> (loss(test_x, test_y)) #function that will be called to get the loss \n", + "const to = TimerOutput() # creating a TimerOutput, keeps track of everything\n", + "\n", + "\n", + "@timeit to \"Train Total\" begin\n", + " for epoch in 1:10\n", + " train_epoch = epoch > 1 ? \"train_epoch\" : \"train_ji\"\n", + " @timeit to train_epoch begin\n", + " progress!(adam(model, train_batches; lr = 1e-3))\n", + " end\n", + " \n", + " evaluation = epoch > 1 ? \"evaluation\" : \"eval_jit\"\n", + " @timeit to evaluation begin\n", + " accuracy(model, train_batches)\n", + " end \n", + " \n", + " end \n", + "end " + ] + }, + { + "cell_type": "markdown", + "id": "a57b3c8d", + "metadata": {}, + "source": [ + "# Evaluation Function" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": "02f69609", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "# function evaluate(model, test_loader)\n", + "# preds = []\n", + "# targets = []\n", + "# for (x, y) in test_loader\n", + "# # Get model predictions\n", + "# # Note argmax of nd-array gives CartesianIndex\n", + "# # Need to grab the first element of each CartesianIndex to get the true index\n", + "# logits = model(x)\n", + "# ŷ = map(i -> i[1], argmax(logits, dims=1))\n", + "# append!(preds, ŷ)\n", + "\n", + "# # Get true labels\n", + "# append!(targets, y)\n", + "# end\n", + "# accuracy = sum(preds .== targets) / length(targets)\n", + "# return accuracy\n", + "# end" + ] + }, + { + "cell_type": "markdown", + "id": "f2072bd8", + "metadata": {}, + "source": [ + "# Training Loop" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": "cc39bcab", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "# # Setup timing output\n", + "# const to = TimerOutput()" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "id": "9b5c088c", + "metadata": {}, + "outputs": [], + "source": [ + "# # last_loss = 0;\n", + "# # @timeit to \"total_training_time\" begin\n", + "# for epoch in 1:10\n", + "# timing_name = epoch > 1 ? \"average_epoch_training_time\" : \"train_jit\"\n", + "\n", + "# # Create lazily evaluated augmented training data\n", + "# train_batches = mappedarray(augmentbatch, batchview(shuffleobs((train_x_padded, train_y)), size=train_batch_size));\n", + "\n", + "# @timeit to timing_name begin\n", + "# losses = []\n", + "# for (x, y) in train_batches\n", + "# # loss_function does forward pass\n", + "# # Yota.jl grad function computes model parameter gradients in g[2]\n", + "# loss, g = grad(loss_function, model, x, y)\n", + " \n", + "# # Optimiser updates parameters\n", + "# Optimisers.update!(state, model, g[2])\n", + "# push!(losses, loss)\n", + "# end\n", + "# last_loss = mean(losses)\n", + "# @info(\"epoch (mean(losses))\")\n", + "# end\n", + "# # timing_name = epoch > 1 ? \"average_inference_time\" : \"eval_jit\"\n", + "# # @timeit to timing_name begin\n", + "# # acc = evaluate(model, test_loader)\n", + "# # @info(\"epoch (acc)\")\n", + "# # end\n", + "# end\n", + "# end" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1955c486", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Julia 1.8.2", + "language": "julia", + "name": "julia-1.8" + }, + "language_info": { + "file_extension": ".jl", + "mimetype": "application/julia", + "name": "julia", + "version": "1.8.2" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From 476217326e9a5045a80f3f65127ff1a155a25832 Mon Sep 17 00:00:00 2001 From: Yash Pokra <79229682+yashpokra@users.noreply.github.com> Date: Thu, 9 Feb 2023 01:44:38 -0500 Subject: [PATCH 26/26] Delete julia_resnetmodel_updated_FINAL_version.ipynb --- ...ia_resnetmodel_updated_FINAL_version.ipynb | 1626 ----------------- 1 file changed, 1626 deletions(-) delete mode 100644 convolutional neural network/julia_resnetmodel_updated_FINAL_version.ipynb diff --git a/convolutional neural network/julia_resnetmodel_updated_FINAL_version.ipynb b/convolutional neural network/julia_resnetmodel_updated_FINAL_version.ipynb deleted file mode 100644 index 0ed2ca5..0000000 --- a/convolutional neural network/julia_resnetmodel_updated_FINAL_version.ipynb +++ /dev/null @@ -1,1626 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "69f91157", - "metadata": {}, - "source": [ - "# Imports" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "9b1583d4", - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "using Yota;\n", - "using MLDatasets;\n", - "using NNlib;\n", - "using Statistics;\n", - "using Distributions;\n", - "using Functors;\n", - "using Optimisers;\n", - "using MLUtils: DataLoader;\n", - "using OneHotArrays: onehotbatch\n", - "using Knet:conv4\n", - "using Metrics;\n", - "using TimerOutputs;\n", - "using Flux: BatchNorm, kaiming_uniform, nfan;\n", - "using Functors\n", - "\n", - "# Model creation\n", - "using NNlib;\n", - "using Flux: BatchNorm, Chain, GlobalMeanPool, kaiming_uniform, nfan;\n", - "using Statistics;\n", - "using Distributions;\n", - "using Functors;\n", - "\n", - "# Data processing\n", - "using MLDatasets;\n", - "using MLUtils: DataLoader;\n", - "using MLDataPattern;\n", - "using ImageCore;\n", - "using Augmentor;\n", - "using ImageFiltering;\n", - "using MappedArrays;\n", - "using Random;\n", - "using Flux: DataLoader;\n", - "# using OneHotArrays: onehotbatch\n", - "\n", - "# Training\n", - "# using Yota;\n", - "using Zygote;\n", - "using Optimisers;\n", - "using Metrics;\n", - "using TimerOutputs;\n", - "\n", - "\n", - "#using Knet: Knet, dir, accuracy, progress, sgd, gc, Data, nll, relu\n", - "\n", - "\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "id": "19aff91e", - "metadata": {}, - "source": [ - "# Conv 2D" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "481a3d9a", - "metadata": {}, - "outputs": [], - "source": [ - "mutable struct Conv2D{T}\n", - " w::AbstractArray{T, 4}\n", - " b::AbstractVector{T}\n", - " use_bias::Bool\n", - " padding::Int \n", - "end\n", - "\n", - "@functor Conv2D (w, b)" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "59da1b27", - "metadata": {}, - "outputs": [], - "source": [ - "function Conv2D(kernel_size::Tuple{Int, Int}, in_channels::Int, out_channels::Int;\n", - " bias::Bool=false, padding::Int=1)\n", - " w_size = (kernel_size..., in_channels, out_channels)\n", - " w = kaiming_uniform(w_size...)\n", - " (fan_in, fan_out) = nfan(w_size)\n", - " \n", - " if bias\n", - " # Init bias with fan_in from weights. Use gain = √2 for ReLU\n", - " bound = √3 * √2 / √fan_in\n", - " rng = Uniform(-bound, bound)\n", - " b = rand(rng, out_channels, Float32)\n", - " else\n", - " b = zeros(Float32, out_channels)\n", - " end\n", - "\n", - " return Conv2D(w, b, bias, padding)\n", - "end\n", - "\n", - "function (self::Conv2D)(x::AbstractArray; stride::Int=1, pad::Int=0, dilation::Int=1)\n", - " y = conv4(self.w, x; stride=stride, padding=self.padding, dilation=dilation)\n", - " if self.use_bias\n", - " # Bias is applied channel-wise\n", - " (w, h, c, b) = size(y)\n", - " bias = reshape(self.b, (1, 1, c, 1))\n", - " y = y .+ bias\n", - " end\n", - " return y\n", - "end\n", - " " - ] - }, - { - "cell_type": "markdown", - "id": "252e934f", - "metadata": {}, - "source": [ - "# ResNetLayer" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "3e66be4f", - "metadata": {}, - "outputs": [], - "source": [ - "mutable struct ResNetLayer\n", - " conv1::Conv2D\n", - " conv2::Conv2D\n", - " bn1::BatchNorm\n", - " bn2::BatchNorm\n", - " f::Function\n", - " in_channels::Int\n", - " channels::Int\n", - " stride::Int\n", - "end\n", - "\n", - "@functor ResNetLayer (conv1, conv2, bn1, bn2)\n", - "\n", - "function residual_identity(layer::ResNetLayer, x::AbstractArray{T, 4}) where {T<:Number}\n", - " (w, h, c, b) = size(x)\n", - " stride = layer.stride\n", - " if stride > 1\n", - " @assert ((w % stride == 0) & (h % stride == 0)) \"Spatial dimensions are not divisible by `stride`\"\n", - " \n", - " # Strided downsample\n", - " x_id = copy(x[begin:2:end, begin:2:end, :, :])\n", - " else\n", - " x_id = x\n", - " end\n", - "\n", - " channels = layer.channels\n", - " in_channels = layer.in_channels\n", - " if in_channels < channels\n", - " # Zero padding on extra channels\n", - " (w, h, c, b) = size(x_id)\n", - " pad = zeros(w, h, channels - in_channels, b)\n", - " x_id = cat(x_id, pad; dims=3)\n", - " elseif in_channels > channels\n", - " error(\"in_channels > out_channels not supported\")\n", - " end\n", - " return x_id\n", - "end\n", - "\n", - "function ResNetLayer(in_channels::Int, channels::Int; stride=1, f=relu)\n", - " bn1 = BatchNorm(in_channels)\n", - " conv1 = Conv2D((3, 3), in_channels, channels, bias=false)\n", - " bn2 = BatchNorm(channels)\n", - " conv2 = Conv2D((3, 3), channels, channels, bias=false)\n", - "\n", - " return ResNetLayer(conv1, conv2, bn1, bn2, f, in_channels, channels, stride)\n", - "end\n", - "\n", - "\n", - "function (self::ResNetLayer)(x::AbstractArray)\n", - " identity = residual_identity(self, x)\n", - " z = self.bn1(x)\n", - " z = self.f(z)\n", - " z = self.conv1(z; pad=1, stride=self.stride) # pad=1 will keep same size with (3x3) kernel\n", - " z = self.bn2(z)\n", - " z = self.f(z)\n", - " z = self.conv2(z; pad=1)\n", - "\n", - " y = z + identity\n", - " return y\n", - "end" - ] - }, - { - "cell_type": "markdown", - "id": "9f06e04e", - "metadata": {}, - "source": [ - "# Testing ResNetLayer" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "7cdc72a9", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(16, 16, 10, 4)" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "\n", - "l = ResNetLayer(3, 10; stride=2);\n", - "x = randn(Float32, (32, 32, 3, 4));\n", - "y = l(x);\n", - "size(y)" - ] - }, - { - "cell_type": "markdown", - "id": "7b21b952", - "metadata": {}, - "source": [ - "# Linear Layer" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "8987f02c", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING: method definition for Linear at In[6]:22 declares type variable T but does not use it.\n" - ] - } - ], - "source": [ - "mutable struct Linear\n", - " W::AbstractMatrix{T} where T\n", - " b::AbstractVector{T} where T\n", - "end\n", - "\n", - "@functor Linear\n", - "\n", - "# Init\n", - "function Linear(in_features::Int, out_features::Int)\n", - " k_sqrt = sqrt(1 / in_features)\n", - " d = Uniform(-k_sqrt, k_sqrt)\n", - " return Linear(rand(d, out_features, in_features), rand(d, out_features))\n", - "end\n", - "Linear(in_out::Pair{Int, Int}) = Linear(in_out[1], in_out[2])\n", - "\n", - "function Base.show(io::IO, l::Linear)\n", - " o, i = size(l.W)\n", - " print(io, \"Linear(o)\")\n", - "end\n", - "\n", - "# Forward\n", - "(l::Linear)(x::AbstractArray) where T = l.W * x .+ l.b\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a386ea7a", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "02eca287", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "ResNet20Model" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# ResNet Architecture\n", - "\n", - "mutable struct ResNet20Model\n", - " input_conv::Conv2D\n", - " resnet_blocks::Chain\n", - " pool::GlobalMeanPool\n", - " linear::Linear\n", - "end\n", - "\n", - "@functor ResNet20Model\n", - "\n", - "function ResNet20Model(in_channels::Int, num_classes::Int)\n", - " resnet_blocks = Chain(\n", - " block_1 = ResNetLayer(16, 16),\n", - " block_2 = ResNetLayer(16, 16),\n", - " block_3 = ResNetLayer(16, 16),\n", - " block_4 = ResNetLayer(16, 32; stride=2),\n", - " block_5 = ResNetLayer(32, 32),\n", - " block_6 = ResNetLayer(32, 32),\n", - " block_7 = ResNetLayer(32, 64; stride=2),\n", - " block_8 = ResNetLayer(64, 64),\n", - " block_9 = ResNetLayer(64, 64)\n", - " )\n", - " return ResNet20Model(\n", - " Conv2D((3, 3), in_channels, 16, bias=false),\n", - " resnet_blocks,\n", - " GlobalMeanPool(),\n", - " Linear(64, num_classes)\n", - " )\n", - "end" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "cdef0144", - "metadata": {}, - "outputs": [], - "source": [ - "function (self::ResNet20Model)(x::AbstractArray)\n", - " z = self.input_conv(x)\n", - " z = self.resnet_blocks(z)\n", - " z = self.pool(z)\n", - " z = dropdims(z, dims=(1, 2))\n", - " y = self.linear(z)\n", - " return y\n", - "end\n", - " " - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "25c15eb5", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "┌ Warning: Slow fallback implementation invoked for conv! You probably don't want this; check your datatypes.\n", - "│ yT = Float64\n", - "│ T1 = Float64\n", - "│ T2 = Float32\n", - "└ @ NNlib C:\\Users\\Yash\\.julia\\packages\\NNlib\\0QnJJ\\src\\conv.jl:285\n" - ] - }, - { - "data": { - "text/plain": [ - "(10, 4)" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "\n", - "# Testing ResNet20 model\n", - "# Expected output: (10, 4)\n", - "m = ResNet20Model(3, 10);\n", - "inputs = randn(Float32, (32, 32, 3, 4))\n", - "outputs = m(inputs);\n", - "size(outputs)\n", - " " - ] - }, - { - "cell_type": "markdown", - "id": "8e43380e", - "metadata": {}, - "source": [ - "# Data Preprocessing " - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "84857fa0", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "32×32×3×45000 Array{Float32, 4}\n", - "45000-element Vector{Int64}\n", - "32×32×3×5000 Array{Float32, 4}\n", - "5000-element Vector{Int64}\n", - "32×32×3×10000 Array{Float32, 4}\n", - "10000-element Vector{Int64}\n" - ] - } - ], - "source": [ - "# This loads the CIFAR-10 Dataset for training, validation, and evaluation\n", - "xtrn,ytrn = CIFAR10.traindata(Float32, 1:45000)\n", - "xval,yval = CIFAR10.traindata(Float32, 45001:50000)\n", - "xtst,ytst = CIFAR10.testdata(Float32)\n", - "println.(summary.((xtrn,ytrn,xval, yval, xtst,ytst)));" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "45acc000", - "metadata": {}, - "outputs": [], - "source": [ - "# Normalize all the data\n", - "\n", - "means = reshape([0.485, 0.465, 0.406], (1, 1, 3, 1))\n", - "stdevs = reshape([0.229, 0.224, 0.225], (1, 1, 3, 1))\n", - "normalize(x) = (x .- means) ./ stdevs\n", - "\n", - "train_x = normalize(xtrn);\n", - "val_x = normalize(xval);\n", - "test_x = normalize(xtst);" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "9e93cda3", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "splitobs (generic function with 11 methods)" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "\n", - "# Train-test split\n", - "# Copied from https://github.com/JuliaML/MLUtils.jl/blob/v0.2.11/src/splitobs.jl#L65\n", - "# obsview doesn't work with this data, so use getobs instead\n", - "\n", - "import MLDataPattern.splitobs;\n", - "\n", - "function splitobs(data; at, shuffle::Bool=false)\n", - " if shuffle\n", - " data = shuffleobs(data)\n", - " end\n", - " n = numobs(data)\n", - " return map(idx -> MLDataPattern.getobs(data, idx), splitobs(n, at))\n", - "end" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "9c649cac", - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "# Notebook testing: Use less data\n", - "train_x, train_y = MLDatasets.getobs((train_x, ytrn), 1:500);\n", - "\n", - "val_x, val_y = MLDatasets.getobs((val_x, yval), 1:50);\n", - "\n", - "test_x, test_y = MLDatasets.getobs((test_x, ytst), 1:50);" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "75266187", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(40, 40, 3, 500)" - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "\n", - "# Pad the training data for further augmentation\n", - "train_x_padded = padarray(train_x, Fill(0, (4, 4, 0, 0))); \n", - "size(train_x_padded) # Should be (40, 40, 3, 50000)" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "fc788d3e", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "6-step Augmentor.ImmutablePipeline:\n", - " 1.) Permute dimension order to (3, 1, 2)\n", - " 2.) Combine color channels into colorant RGB\n", - " 3.) Either: (50%) Flip the X axis. (50%) No operation.\n", - " 4.) Crop random window with size (32, 32)\n", - " 5.) Split colorant into its color channels\n", - " 6.) Permute dimension order to (2, 3, 1)" - ] - }, - "execution_count": 15, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "pl = PermuteDims((3, 1, 2)) |> CombineChannels(RGB) |> Either(FlipX(), NoOp()) |> RCropSize(32, 32) |> SplitChannels() |> PermuteDims((2, 3, 1))" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "id": "815faf28", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "outbatch (generic function with 1 method)" - ] - }, - "execution_count": 16, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# Create an output array for augmented images\n", - "outbatch(X) = Array{Float32}(undef, (32, 32, 3, nobs(X)))" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "id": "2e86e8f7", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "augmentbatch (generic function with 1 method)" - ] - }, - "execution_count": 17, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# Function that takes a batch (images and targets) and augments the images\n", - "augmentbatch((X, y)) = (augmentbatch!(outbatch(X), X, pl), y)" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "id": "e4d362ce", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "┌ Warning: The specified values for size and/or count will result in 4 unused data points\n", - "└ @ MLDataPattern C:\\Users\\Yash\\.julia\\packages\\MLDataPattern\\KlSmO\\src\\dataview.jl:205\n" - ] - } - ], - "source": [ - "\n", - "# Shuffled and batched dataset of augmented images\n", - "train_batch_size = 16\n", - "\n", - "train_batches = mappedarray(augmentbatch, batchview(shuffleobs((train_x_padded, train_y)), size=train_batch_size));\n", - " " - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "id": "e2386c3c", - "metadata": {}, - "outputs": [], - "source": [ - "# Test and Validation data\n", - "test_batch_size = 32\n", - "\n", - "val_loader = DataLoader((val_x, val_y), shuffle=true, batchsize=test_batch_size);\n", - "test_loader = DataLoader((test_x, test_y), shuffle=true, batchsize=test_batch_size);" - ] - }, - { - "cell_type": "code", - "execution_count": 33, - "id": "3998a220", - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "# # Create model with 3 input channels and 10 classes\n", - " model = ResNet20Model(3, 10);" - ] - }, - { - "cell_type": "code", - "execution_count": 34, - "id": "3731cc35", - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "\n", - "# loss(xtst, ytst) = nll(model(xtst), ytst)\n", - "# evalcb = () -> (loss(xtst, ytst)) #function that will be called to get the loss \n", - "# const to = TimerOutput() # creating a TimerOutput, keeps track of everything\n", - "\n", - "\n", - "# @timeit to \"Train Total\" begin\n", - "# for epoch in 1:10\n", - "# train_epoch = epoch > 1 ? \"train_epoch\" : \"train_ji\"\n", - "# @timeit to train_epoch begin\n", - "# progress!(adam(model, train_batches; lr = 1e-3))\n", - "# end\n", - " \n", - "# evaluation = epoch > 1 ? \"evaluation\" : \"eval_jit\"\n", - "# @timeit to evaluation begin\n", - "# accuracy(model, test_loader)\n", - "# end \n", - " \n", - "# end \n", - "# end \n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c33ae82c", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1edf6901", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c04eb217", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c945bf07", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "19b1cfc9", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0396b9b1", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "markdown", - "id": "05599606", - "metadata": {}, - "source": [ - "# Training setup" - ] - }, - { - "cell_type": "code", - "execution_count": 35, - "id": "fd7aadd5", - "metadata": {}, - "outputs": [], - "source": [ - "#Sparse Cross Entropy function" - ] - }, - { - "cell_type": "code", - "execution_count": 36, - "id": "9f6c4d38", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "sparse_logit_cross_entropy (generic function with 1 method)" - ] - }, - "execution_count": 36, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "\n", - "\"\"\"\n", - " sparse_logit_cross_entropy(logits, labels)\n", - "\n", - "Efficient computation of cross entropy loss with model logits and integer indices as labels.\n", - "Integer indices are from [0, N-1], where N is the number of classes\n", - "Similar to TensorFlow SparseCategoricalCrossEntropy\n", - "\n", - "# Arguments\n", - "- `logits::AbstractArray`: 2D model logits tensor of shape (classes, batch size)\n", - "- `labels::AbstractArray`: 1D integer label indices of shape (batch size,)\n", - "\n", - "# Returns\n", - "- `loss::Float32`: Cross entropy loss\n", - "\"\"\"\n", - "# function sparse_logit_cross_entropy(logits, labels)\n", - "# log_probs = logsoftmax(logits);\n", - "# # Select indices of labels for loss\n", - "# log_probs = map((x, i) -> x[i + 1], eachslice(log_probs; dims=2), labels);\n", - "# loss = -mean(log_probs);\n", - "# return loss\n", - "# end\n", - "\n", - "function sparse_logit_cross_entropy(logits, labels)\n", - " log_probs = logsoftmax(logits);\n", - " inds = CartesianIndex.(labels .+ 1, axes(log_probs, 2));\n", - " # Select indices of labels for loss\n", - " log_probs = log_probs[inds];\n", - " loss = -mean(log_probs);\n", - " return loss\n", - "end\n" - ] - }, - { - "cell_type": "code", - "execution_count": 37, - "id": "6fa4497b", - "metadata": {}, - "outputs": [], - "source": [ - "# Setup AdamW optimizer\n", - "β = (0.9, 0.999);\n", - "decay = 1e-4;\n", - "state = Optimisers.setup(Optimisers.Adam(1e-3, β, decay), model);" - ] - }, - { - "cell_type": "code", - "execution_count": 38, - "id": "b852506d", - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "(x, y) = first(train_batches);" - ] - }, - { - "cell_type": "code", - "execution_count": 39, - "id": "e71cc12e", - "metadata": {}, - "outputs": [], - "source": [ - "# loss, g = grad(loss_function, model, x, y);" - ] - }, - { - "cell_type": "code", - "execution_count": 40, - "id": "1a9a8a89", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "loss_function (generic function with 1 method)" - ] - }, - "execution_count": 40, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "mutable struct ResNet5\n", - " input_conv::Conv2D\n", - " resnet_block::ResNetLayer\n", - " pool::GlobalMeanPool\n", - " linear::Linear\n", - "end\n", - "\n", - "@functor ResNet5\n", - "\n", - "function ResNet5(in_channels::Int, num_classes::Int)\n", - " return ResNet5(\n", - " Conv2D((3, 3), in_channels, 16, bias=false),\n", - " ResNetLayer(16, 16),\n", - " GlobalMeanPool(),\n", - " Linear(16, num_classes)\n", - " )\n", - "end\n", - "\n", - "function (self::ResNet5)(x::AbstractArray)\n", - " z = self.input_conv(x)\n", - " z = self.resnet_block(z)\n", - " z = self.pool(z)\n", - " z = dropdims(z, dims=(1, 2))\n", - " y = self.linear(z)\n", - " return y\n", - "end\n", - "\n", - "\n", - "function loss_function(model::ResNet5, x::AbstractArray, y::AbstractArray)\n", - " ŷ = model(x)\n", - " loss = sparse_logit_cross_entropy(ŷ, y)\n", - " return loss\n", - "end" - ] - }, - { - "cell_type": "code", - "execution_count": 41, - "id": "028a6d25", - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "# Yota is unable to compute gradients through the ResNet for some reason, maybe due to residual connections?\n", - "# loss, g = grad(loss_function, model, x, y)\n", - "model = ResNet5(3, 10);\n", - "\n", - "loss, g = Zygote.gradient(loss_function, model, x, y);" - ] - }, - { - "cell_type": "code", - "execution_count": 42, - "id": "696231c0", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "32×32×3×16 Array{Float32, 4}:\n", - "[:, :, 1, 1] =\n", - " 2.87384f-5 5.15218f-6 -7.95893f-6 … 2.03477f-5 -2.35546f-5\n", - " 1.31753f-5 -5.36217f-6 -2.37871f-5 1.75244f-5 -9.57452f-6\n", - " 5.49028f-6 -1.71476f-5 -3.2099f-5 9.48742f-6 -1.40432f-5\n", - " 7.2699f-6 -5.92972f-6 -2.18901f-5 2.19976f-5 -1.03878f-5\n", - " -2.49811f-7 -1.22833f-6 -1.28048f-5 2.23008f-5 -8.82221f-6\n", - " -2.13457f-6 2.44618f-6 -1.34374f-5 … 2.04867f-5 -8.85707f-6\n", - " -1.58989f-6 1.10097f-6 -3.85356f-5 1.71607f-5 -8.91197f-6\n", - " 2.43009f-6 2.74427f-6 -2.31398f-5 1.43616f-5 -8.99907f-6\n", - " 8.6382f-6 4.24066f-6 -1.89015f-5 2.12419f-5 -9.01441f-6\n", - " 1.01316f-5 2.10697f-6 -2.02417f-5 2.03862f-5 -8.99482f-6\n", - " 9.46375f-6 -6.25004f-6 -1.73722f-5 … 2.14269f-5 -8.98492f-6\n", - " 9.65401f-6 -3.91408f-6 -2.25808f-5 2.14387f-5 -8.96013f-6\n", - " 9.63481f-6 -4.90097f-6 -2.71419f-5 2.27392f-5 -8.99733f-6\n", - " ⋮ ⋱ ⋮ \n", - " 6.6796f-6 -7.3569f-6 -2.57043f-5 … 2.38568f-5 -8.91103f-6\n", - " 7.58086f-7 -2.9958f-6 -1.29167f-5 2.07219f-5 -8.85528f-6\n", - " 2.3564f-6 4.17472f-6 1.37998f-7 2.16737f-5 -8.90104f-6\n", - " 8.35725f-6 7.49693f-6 -1.33849f-5 2.03694f-5 -8.93113f-6\n", - " 1.10655f-5 1.39176f-6 -3.57415f-5 2.42324f-5 -8.66875f-6\n", - " 1.06251f-5 -2.52541f-6 -3.04435f-5 … 2.01572f-5 -8.3407f-6\n", - " 8.71558f-6 -4.15225f-6 -2.62513f-5 2.06746f-5 -8.21051f-6\n", - " 4.74822f-6 1.08056f-5 -2.1629f-5 2.31422f-5 -7.38363f-6\n", - " 4.34095f-6 -6.39721f-6 -3.72078f-6 2.3523f-5 -1.0264f-5\n", - " 1.64836f-5 1.37776f-5 -1.70328f-5 2.25005f-5 -1.09304f-5\n", - " -2.97065f-6 3.41663f-7 -1.48043f-5 … 1.87838f-5 -1.2569f-5\n", - " -4.29779f-5 -3.22912f-6 -2.24622f-5 -2.00159f-5 -1.13841f-5\n", - "\n", - "[:, :, 2, 1] =\n", - " 3.51288f-6 -1.22697f-5 -2.48598f-5 … -5.24601f-5 -2.63561f-5\n", - " 5.0802f-6 -8.81436f-6 -9.8856f-6 -7.7401f-5 -4.79854f-5\n", - " 6.16604f-6 -1.53578f-5 -1.70939f-5 -6.56983f-5 -4.47286f-5\n", - " 1.24349f-5 -1.22695f-5 -2.24474f-5 -6.35044f-5 -4.26525f-5\n", - " 1.34362f-5 -1.62494f-5 -1.91734f-5 -6.25156f-5 -4.13404f-5\n", - " 1.10909f-5 -2.10735f-5 -1.57398f-5 … -6.32187f-5 -4.12555f-5\n", - " 9.65864f-6 -1.58431f-5 -2.85245f-5 -6.47562f-5 -4.12256f-5\n", - " 8.55365f-6 -1.49781f-5 1.02663f-5 -6.98997f-5 -4.12231f-5\n", - " 7.57427f-6 -1.58006f-5 5.02288f-5 -6.04549f-5 -4.13147f-5\n", - " 9.93018f-6 -7.01827f-6 1.28258f-5 -5.44541f-5 -4.14396f-5\n", - " 1.12572f-5 -1.46116f-5 -2.1672f-5 … -5.62349f-5 -4.13453f-5\n", - " 1.18546f-5 -1.30768f-5 -1.94482f-5 -5.62452f-5 -4.12333f-5\n", - " 1.18409f-5 -1.25132f-5 -2.20645f-5 -5.53337f-5 -4.11621f-5\n", - " ⋮ ⋱ ⋮ \n", - " 1.18743f-5 -1.3683f-5 -2.22972f-5 … -5.65869f-5 -4.12856f-5\n", - " 1.42515f-5 -1.64138f-5 -1.96379f-5 -5.67975f-5 -4.13787f-5\n", - " 1.24159f-5 -1.52812f-5 -1.22782f-5 -6.03015f-5 -4.13753f-5\n", - " 9.13442f-6 -9.09309f-6 -1.12662f-5 -6.61189f-5 -4.21891f-5\n", - " 1.05468f-5 -7.99576f-6 -1.33269f-5 -5.49894f-5 -4.08029f-5\n", - " 1.1644f-5 -1.0743f-5 -1.19364f-5 … -5.90684f-5 -4.122f-5\n", - " 1.14246f-5 -1.58993f-5 -2.733f-5 -6.15047f-5 -4.00522f-5\n", - " 2.33722f-5 -1.66764f-5 -8.95946f-6 -5.53262f-5 -4.31432f-5\n", - " -1.36753f-5 -3.431f-5 -1.81776f-5 -5.894f-5 -4.33847f-5\n", - " -5.28935f-6 -4.45395f-5 -4.33759f-5 -6.29569f-5 -3.67273f-5\n", - " -5.14466f-5 -5.65307f-5 -5.83267f-5 … -7.28656f-5 -3.01449f-5\n", - " -3.77745f-5 -3.05045f-5 -4.26683f-5 -6.23611f-5 -2.93418f-5\n", - "\n", - "[:, :, 3, 1] =\n", - " 1.30337f-5 1.97476f-5 -1.37888f-6 … 2.5648f-5 -4.64034f-6\n", - " 3.68331f-7 3.68014f-5 1.60665f-5 1.87886f-5 1.83326f-5\n", - " 1.01084f-5 3.88496f-5 3.575f-5 2.20447f-5 1.84873f-5\n", - " 1.87285f-6 3.03808f-5 3.2317f-5 2.7348f-5 1.9302f-5\n", - " 3.31373f-6 3.43455f-5 3.81177f-5 3.29928f-5 1.7696f-5\n", - " 4.18613f-6 4.38729f-5 2.48672f-5 … 2.99008f-5 1.75935f-5\n", - " 2.50239f-6 3.87682f-5 4.05324f-5 2.82491f-5 1.7558f-5\n", - " 3.0775f-6 4.32761f-5 3.71651f-5 3.06394f-5 1.76267f-5\n", - " 3.0736f-6 3.7843f-5 -3.7053f-5 2.99056f-5 1.76673f-5\n", - " 2.25529f-6 3.46948f-5 -1.37121f-5 3.12987f-5 1.76435f-5\n", - " 4.21507f-6 4.15411f-5 2.95277f-5 … 3.20953f-5 1.75889f-5\n", - " 3.72306f-6 3.7588f-5 1.91054f-5 3.17634f-5 1.75182f-5\n", - " 3.94906f-6 4.03008f-5 2.22407f-5 3.2732f-5 1.74733f-5\n", - " ⋮ ⋱ ⋮ \n", - " 3.294f-6 3.00897f-5 3.7794f-5 … 2.85874f-5 1.75208f-5\n", - " 2.5483f-6 3.25366f-5 2.8732f-5 3.14754f-5 1.76076f-5\n", - " 3.52557f-6 3.8136f-5 2.43104f-5 3.40014f-5 1.76464f-5\n", - " 2.90404f-6 3.00002f-5 1.83853f-5 3.81449f-5 1.71642f-5\n", - " 2.10267f-6 2.82629f-5 3.10803f-5 3.24481f-5 1.66842f-5\n", - " 3.57918f-6 3.26719f-5 2.92418f-5 … 3.31228f-5 1.64505f-5\n", - " 7.08678f-6 2.56342f-5 4.57358f-5 3.47681f-5 1.91743f-5\n", - " -3.32079f-6 2.68047f-5 2.83111f-5 3.74222f-5 1.95459f-5\n", - " 7.50175f-6 4.48023f-5 3.10073f-5 3.69629f-5 2.05496f-5\n", - " 3.77103f-5 3.62848f-5 3.53838f-5 3.04651f-5 1.50938f-5\n", - " 2.26291f-6 6.13585f-5 5.42865f-5 … 3.72573f-5 2.30365f-5\n", - " 9.99016f-6 1.11194f-5 1.05421f-5 1.7655f-5 6.95817f-6\n", - "\n", - "[:, :, 1, 2] =\n", - " -1.09797f-5 -2.13722f-5 -1.27747f-5 … -1.69448f-5 -7.88239f-6\n", - " 1.12475f-5 -2.0106f-5 -3.77094f-5 -3.23303f-5 -2.27554f-5\n", - " -5.15608f-6 -2.29666f-5 -1.57832f-5 -1.66505f-5 -2.19483f-5\n", - " -4.40688f-6 -9.89303f-6 -2.67959f-5 -6.70534f-6 -1.72258f-5\n", - " 1.08458f-6 -2.6123f-5 -5.26342f-5 -6.83994f-6 -1.43384f-5\n", - " -3.56635f-6 -1.78059f-5 -2.80402f-5 … -6.95699f-6 -1.45106f-5\n", - " -3.02137f-6 -1.73025f-5 -1.40618f-5 -2.04053f-5 -6.60865f-6\n", - " -4.13502f-7 -1.72447f-5 -1.44892f-5 -2.04278f-5 -6.64876f-6\n", - " -6.49256f-6 -9.48326f-6 -1.42205f-5 -2.04246f-5 -6.61602f-6\n", - " -1.11668f-5 -3.3754f-6 -8.46039f-6 -2.0354f-5 -6.67575f-6\n", - " 7.79242f-6 -1.16571f-5 -1.09261f-5 … -2.0452f-5 -6.61188f-6\n", - " 1.53111f-5 -1.4928f-5 2.88932f-6 -2.04422f-5 -6.60986f-6\n", - " 5.2131f-6 -1.3793f-5 -3.3462f-5 -2.0508f-5 -6.62253f-6\n", - " ⋮ ⋱ ⋮ \n", - " 1.38997f-5 -1.92094f-5 -3.57026f-5 … -1.14854f-5 -8.50423f-6\n", - " 3.52126f-5 5.34156f-6 -4.03604f-5 -2.19193f-5 -9.99198f-6\n", - " 3.7833f-5 -1.7161f-5 -5.25765f-5 -2.1176f-5 -4.21367f-6\n", - " -4.3647f-6 5.2507f-6 -5.71789f-6 -5.89364f-6 -1.47315f-5\n", - " 1.43847f-6 -3.27774f-5 4.64099f-5 -2.04356f-5 -6.57705f-6\n", - " 9.02621f-6 -6.39047f-5 2.31406f-5 … -2.04515f-5 -6.59162f-6\n", - " 1.05266f-5 -7.57128f-5 -4.46544f-5 -2.04393f-5 -6.60923f-6\n", - " 9.78515f-6 -3.16386f-5 -5.53192f-5 -2.03575f-5 -6.66031f-6\n", - " 2.875f-5 -1.83572f-5 -4.66145f-5 -1.73273f-5 -6.96841f-6\n", - " 1.6593f-5 -2.25877f-5 -2.63732f-5 -2.25418f-5 -1.45672f-6\n", - " -1.93264f-5 -4.5622f-6 -2.44407f-6 … -5.44911f-6 -5.07896f-6\n", - " -1.26567f-5 -8.28866f-6 -1.98972f-5 -1.90232f-5 -1.22861f-5\n", - "\n", - "[:, :, 2, 2] =\n", - " -4.52723f-6 -3.49323f-5 -4.12799f-5 … -1.50356f-5 -2.16989f-7\n", - " -6.78492f-5 -7.84942f-5 -8.87692f-5 -2.63529f-5 -4.13737f-6\n", - " -8.41087f-5 -3.1084f-5 -9.76829f-5 -4.07474f-5 -7.83035f-6\n", - " -7.08706f-5 -2.40582f-5 -0.000109596 -5.32373f-5 -3.9194f-6\n", - " -7.77566f-5 -6.85934f-7 -0.000118317 -5.26677f-5 -3.95299f-6\n", - " -8.1118f-5 9.22139f-7 -0.000108936 … -4.52668f-5 9.43691f-6\n", - " -8.13958f-5 -1.85499f-7 -0.000107129 -5.46121f-5 1.62983f-5\n", - " -8.20843f-5 -5.70543f-7 -0.00011982 -5.47232f-5 1.63308f-5\n", - " -8.19961f-5 -9.28187f-6 -0.000109798 -5.46595f-5 1.6309f-5\n", - " -7.93751f-5 -2.48049f-5 -0.000109075 -5.44518f-5 1.6343f-5\n", - " -7.90918f-5 -3.43573f-5 -0.000101994 … -5.43597f-5 1.63746f-5\n", - " -6.23688f-5 -2.64341f-5 -9.38569f-5 -5.44339f-5 1.63483f-5\n", - " -6.52792f-5 -2.48703f-5 -7.8946f-5 -5.44905f-5 1.6355f-5\n", - " ⋮ ⋱ ⋮ \n", - " -8.1285f-5 -4.88402f-5 -0.000100694 … -4.41222f-5 9.497f-6\n", - " -4.88119f-5 -8.28254f-5 -9.25717f-5 -5.31737f-5 1.68613f-5\n", - " -5.17702f-5 -4.68936f-5 -7.68762f-5 -5.82533f-5 1.66828f-6\n", - " -5.55461f-5 -4.18896f-5 -7.47917f-5 -4.51579f-5 9.44686f-6\n", - " -4.83162f-5 -5.65184f-5 -4.35828f-5 -5.4252f-5 1.63604f-5\n", - " -5.14842f-5 -3.70771f-5 -5.60604f-5 … -5.44706f-5 1.62527f-5\n", - " -7.58537f-5 -8.31569f-6 -0.000117182 -5.46463f-5 1.62203f-5\n", - " -8.04193f-5 1.43932f-5 -0.000156959 -5.45568f-5 1.63059f-5\n", - " -5.56057f-5 -8.65628f-6 -0.000125903 -5.44588f-5 1.16301f-5\n", - " -7.1956f-5 -1.97293f-5 -0.000133141 -5.4016f-5 1.89256f-5\n", - " -7.28343f-5 -3.14315f-5 -0.000127052 … -5.81531f-5 7.89226f-6\n", - " -6.13478f-5 -5.94695f-7 -7.32276f-5 -2.62953f-5 2.57695f-5\n", - "\n", - "[:, :, 3, 2] =\n", - " -2.19261f-5 2.7254f-5 -4.52801f-6 … -5.49369f-6 1.45297f-5\n", - " -4.46353f-5 2.62317f-5 8.04015f-6 -2.53547f-5 4.73811f-6\n", - " -3.79392f-5 -1.06354f-5 5.85113f-5 -1.7056f-5 6.20228f-6\n", - " 1.1313f-5 1.57169f-5 5.71432f-5 -3.7475f-5 5.49092f-6\n", - " 2.04558f-6 3.71181f-5 4.01594f-5 -3.35148f-5 -6.25117f-6\n", - " 4.73865f-6 1.68689f-5 4.19958f-5 … -2.51303f-5 -4.91489f-6\n", - " 3.20972f-6 1.13822f-5 3.06967f-5 -2.25985f-5 -1.27144f-5\n", - " 4.89944f-6 9.41842f-6 3.82225f-5 -2.26813f-5 -1.27046f-5\n", - " 1.26125f-5 2.34069f-5 2.76001f-5 -2.27173f-5 -1.27674f-5\n", - " 2.32491f-5 1.94585f-5 2.85936f-5 -2.26634f-5 -1.27467f-5\n", - " 4.74152f-6 2.08929f-5 3.76962f-5 … -2.26154f-5 -1.27296f-5\n", - " -2.01045f-6 1.11967f-5 3.11925f-5 -2.26469f-5 -1.2722f-5\n", - " -1.39923f-6 2.3897f-5 3.35059f-5 -2.27345f-5 -1.27035f-5\n", - " ⋮ ⋱ ⋮ \n", - " 7.09541f-6 -3.04937f-5 4.89153f-5 … -1.33056f-5 -9.03981f-6\n", - " -1.36651f-5 2.70014f-5 3.32767f-5 -3.2459f-5 -2.12493f-7\n", - " 3.17517f-5 6.80056f-6 2.52505f-6 -2.99075f-5 -1.37626f-5\n", - " 8.04813f-6 9.21358f-6 -1.15224f-5 -2.61928f-5 -4.83346f-6\n", - " -5.77461f-5 4.40058f-5 2.18061f-5 -2.26172f-5 -1.26375f-5\n", - " -3.28887f-5 1.54883f-5 2.57241f-5 … -2.26395f-5 -1.25606f-5\n", - " 1.8196f-5 7.08923f-6 -3.4964f-6 -2.27131f-5 -1.26321f-5\n", - " 2.73781f-5 3.86673f-7 2.03445f-5 -2.27442f-5 -1.27123f-5\n", - " 1.39553f-5 -5.43552f-6 4.37894f-5 -2.18561f-5 -1.45538f-5\n", - " 1.25204f-5 4.71521f-6 4.39248f-5 -1.00257f-5 -1.665f-5\n", - " 2.14967f-5 1.08339f-5 -6.27319f-7 … 6.15765f-6 -2.11569f-5\n", - " 5.82686f-5 -3.92828f-6 3.74167f-5 9.41486f-7 -1.17081f-5\n", - "\n", - "[:, :, 1, 3] =\n", - " -1.32921f-5 -2.54748f-5 -2.4276f-5 … -2.82705f-5 -1.69928f-5\n", - " -2.25105f-5 -3.48271f-5 -2.74502f-5 -1.38684f-5 -2.77036f-5\n", - " -2.05614f-5 -1.17774f-5 -4.85968f-6 1.2789f-5 -1.05062f-5\n", - " -3.98836f-5 -2.81254f-5 -6.40034f-5 -4.08217f-5 -2.0309f-5\n", - " -4.38632f-5 -2.61992f-5 -4.8767f-5 -1.90552f-6 -7.2014f-6\n", - " -2.09627f-5 -4.17273f-5 -5.62594f-5 … -1.09989f-5 -1.79831f-5\n", - " -1.28669f-5 -1.27832f-5 -3.18246f-5 -4.09247f-5 -1.90839f-5\n", - " -3.97809f-5 -4.64297f-5 -6.28011f-5 -7.87657f-6 -1.12435f-5\n", - " -4.04715f-5 -4.66917f-5 -4.45681f-5 -4.18667f-5 3.13917f-6\n", - " -2.64794f-5 -7.84953f-5 -7.23706f-5 -2.23594f-5 -7.51302f-6\n", - " -3.37457f-5 -1.42143f-5 -4.43477f-5 … -7.00164f-7 -2.95561f-5\n", - " -2.87837f-5 -2.55903f-5 -2.86363f-5 -2.50456f-5 -3.61103f-5\n", - " -7.72938f-6 -5.0325f-6 -6.46772f-5 -6.57527f-5 -1.66607f-5\n", - " ⋮ ⋱ ⋮ \n", - " -3.11348f-5 -4.19208f-5 -6.37247f-5 … -2.89499f-5 -4.03736f-6\n", - " -3.92671f-5 -3.46014f-5 -3.81206f-5 -2.31593f-5 -1.69803f-5\n", - " -3.53422f-5 -2.82261f-5 -4.3099f-5 -2.38775f-5 -1.47377f-5\n", - " -4.09482f-5 -3.56281f-5 -2.9176f-5 -1.49256f-5 -2.23803f-5\n", - " -8.15614f-6 -1.66638f-5 -4.15551f-5 -9.67728f-6 -1.75268f-5\n", - " -2.87176f-5 -1.85177f-5 -1.01355f-5 … -2.92795f-5 -3.84081f-6\n", - " -3.27056f-5 -4.02847f-5 -5.10919f-5 -2.17138f-5 -1.25814f-5\n", - " -2.75769f-5 -2.67168f-5 -2.62755f-5 -2.40851f-5 -1.13193f-5\n", - " -1.80984f-5 -3.46694f-5 -3.13359f-5 -3.40728f-5 -1.15171f-5\n", - " -4.0983f-5 -1.22354f-5 -2.21499f-5 -3.92156f-5 2.48212f-6\n", - " -2.17541f-5 -1.4721f-5 -3.61195f-5 … -1.8059f-5 -2.34319f-5\n", - " -1.73993f-5 -8.16249f-6 -2.49898f-5 -6.6299f-8 -3.32326f-6\n", - "\n", - "[:, :, 2, 3] =\n", - " 1.37565f-5 5.39358f-5 4.58624f-5 … 4.44795f-5 2.50945f-5\n", - " -1.05192f-5 6.36829f-5 4.8567f-5 6.02761f-5 1.23682f-5\n", - " 8.05044f-6 4.69759f-5 2.89093f-5 1.65648f-5 1.09588f-5\n", - " 2.33734f-6 0.000106804 3.09279f-5 -3.14127f-5 1.23341f-5\n", - " 9.22226f-7 0.000113907 2.53979f-5 4.82065f-6 7.19267f-6\n", - " -5.75003f-6 0.000115487 6.12588f-5 … -9.67223f-6 -3.66514f-6\n", - " -7.21526f-6 0.000118417 4.04848f-5 -1.31087f-5 4.85808f-6\n", - " -1.83487f-5 0.000144694 3.57431f-5 2.71545f-6 1.97464f-6\n", - " -2.31197f-6 0.000116426 4.23696f-5 -1.99311f-6 1.24911f-5\n", - " 1.33507f-5 0.000142583 5.01268f-5 8.75845f-6 4.12369f-5\n", - " 2.56636f-5 0.000180252 5.40545f-5 … 0.000105484 2.48439f-5\n", - " -3.61895f-6 0.000149931 4.13484f-5 4.68084f-5 4.66717f-5\n", - " 2.01801f-5 8.81076f-5 -2.2904f-5 -1.71351f-5 -4.21337f-6\n", - " ⋮ ⋱ ⋮ \n", - " -1.24302f-5 0.00013385 9.00135f-5 … -1.33804f-5 5.75934f-6\n", - " 1.25677f-5 0.000127111 8.09964f-5 -1.75825f-5 9.55642f-6\n", - " -1.91218f-6 9.21105f-5 9.24472f-5 -2.37863f-5 9.99772f-6\n", - " 1.86928f-5 0.000107805 9.78562f-5 -9.6838f-6 -4.90789f-6\n", - " 1.90321f-6 9.47662f-5 6.81009f-5 -1.21807f-5 -6.32489f-7\n", - " 4.47338f-6 8.07572f-5 3.97976f-5 … -1.41697f-5 1.19191f-5\n", - " 1.87808f-5 5.76884f-5 5.41935f-5 -8.75647f-6 8.57532f-6\n", - " 3.17319f-5 4.42051f-5 2.78475f-5 -1.52921f-5 9.61939f-6\n", - " 9.52952f-6 5.40576f-5 5.0719f-5 -1.71491f-5 9.88385f-6\n", - " 5.17164f-7 5.02872f-5 3.96059f-5 -1.22204f-5 1.37712f-5\n", - " 1.50076f-5 8.3402f-5 5.2399f-5 … 6.42688f-6 -1.64415f-5\n", - " 3.14011f-6 4.67719f-5 4.23395f-5 2.95754f-5 3.52819f-6\n", - "\n", - "[:, :, 3, 3] =\n", - " -2.13358f-5 -3.7937f-5 -2.88963f-5 … -5.82107f-5 -3.07175f-5\n", - " -8.75597f-6 -8.63936f-6 -1.4399f-6 -4.39748f-5 -3.46493f-5\n", - " -3.99843f-5 -7.18937f-5 -5.48815f-5 2.70835f-5 -2.11729f-5\n", - " -3.38074f-5 -8.02526f-5 -3.76416f-5 -1.26173f-5 -3.77554f-5\n", - " -5.74915f-6 -5.5656f-5 -3.4937f-5 -2.98646f-5 -3.76138f-5\n", - " -8.11329f-6 -2.84418f-5 -3.05806f-5 … -2.07718f-5 -4.41476f-5\n", - " -2.09927f-5 -5.92586f-5 -3.75812f-5 -1.33266f-5 -3.69712f-5\n", - " -2.36342f-5 -6.4137f-5 -2.07103f-6 -1.34164f-5 -4.23239f-5\n", - " -1.78956f-5 -6.06292f-5 -2.3291f-5 -4.24331f-6 -3.4804f-5\n", - " -2.15475f-5 -2.42674f-5 -6.6006f-5 -2.98182f-5 -6.30099f-5\n", - " -8.82353f-6 -4.50255f-5 -3.5869f-5 … -6.11666f-5 -3.6939f-5\n", - " -2.73957f-5 -5.63472f-5 -1.24705f-5 -2.32421f-5 -3.37618f-5\n", - " -1.91092f-5 -0.000121495 -4.93165f-5 -2.39933f-5 -3.47519f-5\n", - " ⋮ ⋱ ⋮ \n", - " -1.63307f-5 -5.76098f-5 -1.52878f-5 … -2.98358f-5 -3.83354f-5\n", - " -3.66333f-6 -5.44591f-5 -6.64642f-5 -2.32273f-5 -4.34126f-5\n", - " -1.08237f-5 -3.67816f-5 -7.95501f-5 -3.41968f-5 -3.89323f-5\n", - " 4.25022f-6 -4.04467f-5 -2.91673f-5 -6.07295f-6 -3.00576f-5\n", - " -2.61168f-5 -2.23294f-5 -5.08112f-5 -1.30184f-5 -4.02096f-5\n", - " -2.06026f-5 -3.47827f-5 -5.70607f-5 … -1.61879f-5 -4.02096f-5\n", - " -1.69183f-5 -3.22452f-5 -2.34737f-5 -1.40456f-5 -4.36295f-5\n", - " -1.28868f-5 -2.36515f-5 -4.06212f-5 -1.34111f-5 -4.07483f-5\n", - " -3.36046f-5 -1.2702f-5 -6.03335f-6 -1.00821f-5 -3.58138f-5\n", - " -1.11375f-5 -5.46132f-5 -6.48593f-6 -1.6174f-5 -3.2f-5\n", - " 1.48887f-5 -2.30397f-5 -1.40682f-5 … -8.4154f-6 -3.42971f-5\n", - " 2.61805f-5 -9.19203f-7 -5.72442f-7 -3.74935f-6 -2.72021f-6\n", - "\n", - ";;;; … \n", - "\n", - "[:, :, 1, 14] =\n", - " 1.33663f-5 4.85302f-6 3.58623f-6 … 1.37095f-5 1.52744f-6\n", - " 2.40446f-5 -2.13888f-5 -7.09837f-7 -2.65502f-5 -3.84853f-5\n", - " 2.15673f-5 -3.83408f-5 -6.04328f-6 -2.47308f-5 -6.90114f-5\n", - " 1.92326f-5 -3.71648f-5 -2.57633f-5 8.74128f-6 -5.50297f-5\n", - " 1.71903f-5 -3.32595f-5 -3.98588f-5 1.39541f-5 -5.28194f-5\n", - " 1.4936f-5 -3.24147f-5 -4.14944f-5 … 1.85891f-5 -4.86048f-5\n", - " 1.73485f-5 -2.66423f-5 -2.27061f-5 1.55611f-5 -4.97301f-5\n", - " 1.35476f-5 -3.83446f-5 -1.91897f-5 1.55515f-5 -4.66568f-5\n", - " 1.39347f-5 -3.14124f-5 -2.70663f-5 1.66001f-5 -4.76427f-5\n", - " 2.03841f-5 -2.95635f-5 -4.12947f-5 1.51247f-5 -4.52537f-5\n", - " 1.9246f-5 -4.71404f-5 -3.9704f-5 … 1.44092f-5 -4.75653f-5\n", - " 1.02947f-5 -2.77938f-5 -1.65234f-5 9.53047f-6 -4.31208f-5\n", - " 1.7194f-5 -2.49142f-5 -2.85847f-5 2.09133f-5 -5.40761f-5\n", - " ⋮ ⋱ ⋮ \n", - " 1.87105f-5 -5.13546f-5 -1.63782f-5 … 1.17234f-5 -4.45439f-5\n", - " 1.84893f-5 -4.52386f-5 -3.94779f-5 -2.03202f-5 -5.12953f-5\n", - " 1.86344f-5 -4.50125f-5 -4.02963f-5 6.36818f-6 -4.96217f-5\n", - " 1.55367f-5 -3.01873f-5 -4.11503f-5 1.33834f-5 -5.56446f-5\n", - " 1.6725f-5 -3.16873f-5 -3.58039f-5 2.98642f-5 -5.41134f-5\n", - " 1.91699f-5 -4.34588f-5 -5.96472f-6 … 3.36201f-5 -6.05377f-5\n", - " 1.02027f-5 -3.68617f-5 -1.26946f-6 1.37002f-5 -5.93067f-5\n", - " 2.02941f-5 -3.33365f-5 -2.79833f-5 -1.27102f-7 -4.2245f-5\n", - " 1.90906f-5 -4.93408f-5 -1.02181f-5 4.24493f-6 -6.76327f-5\n", - " 1.35392f-5 -3.17202f-5 -1.05651f-5 1.37745f-5 -6.95663f-5\n", - " 3.22179f-5 -3.63395f-5 -3.44376f-5 … -1.8031f-5 -3.67907f-5\n", - " 3.1825f-5 -2.07306f-5 -1.293f-5 -1.55749f-5 -5.47155f-5\n", - "\n", - "[:, :, 2, 14] =\n", - " -1.16588f-5 -1.2407f-5 3.49407f-6 … -2.00775f-5 2.20019f-5\n", - " -1.22088f-5 -8.70201f-6 7.22299f-6 -2.42942f-5 3.22809f-5\n", - " -4.46188f-6 1.12563f-6 1.55854f-6 1.82184f-6 2.75427f-5\n", - " -6.50865f-6 6.71314f-6 -4.40675f-6 -1.15001f-5 1.37051f-5\n", - " 1.46834f-6 1.8095f-6 -1.09453f-5 -1.93772f-5 1.18843f-5\n", - " 9.96785f-6 -6.14118f-6 -1.91112f-5 … -2.1082f-5 2.99534f-6\n", - " 7.32546f-6 -1.10439f-5 -1.05676f-5 -1.90141f-5 6.74762f-6\n", - " 4.39465f-7 -8.66755f-7 -3.26735f-5 -1.68382f-5 5.62134f-6\n", - " 1.84562f-6 8.40483f-6 -3.41861f-5 -1.86092f-5 4.87877f-6\n", - " 3.29915f-6 -6.79852f-7 -3.56272f-5 -1.23352f-5 7.23447f-6\n", - " -7.65528f-7 -3.32969f-6 -2.83803f-5 … -1.67057f-5 1.49106f-6\n", - " 3.31928f-6 -5.36953f-6 -2.76984f-5 -1.32463f-5 1.17819f-5\n", - " 8.95364f-6 -1.72866f-6 -3.75768f-5 1.09688f-5 -4.38009f-6\n", - " ⋮ ⋱ ⋮ \n", - " -3.17671f-6 5.2434f-6 -2.69256f-5 … -4.84416f-5 8.25841f-6\n", - " -3.05886f-6 1.01041f-6 -2.00724f-5 -2.82864f-5 -3.48747f-6\n", - " -3.11412f-6 -2.6824f-6 -1.84471f-5 -2.03048f-5 -8.71729f-6\n", - " 2.70974f-6 -4.42211f-6 -1.84807f-5 -1.75151f-5 -2.35158f-5\n", - " 5.19128f-6 -2.57658f-6 -2.41242f-5 2.26607f-5 2.58383f-6\n", - " -4.43079f-7 9.24148f-6 -2.55504f-5 … -4.00863f-5 -2.20372f-7\n", - " 3.4482f-6 9.72595f-6 -3.16414f-5 -3.03052f-5 -2.44575f-5\n", - " 3.71146f-6 3.85173f-6 -2.39893f-5 -2.33657f-5 1.92735f-5\n", - " -4.2958f-7 6.57322f-6 -9.47067f-6 -3.18343f-5 1.15899f-5\n", - " 3.97834f-7 -3.9832f-6 -2.29879f-5 -2.98401f-5 1.83165f-5\n", - " -2.48082f-6 2.33363f-6 -1.80932f-5 … -3.67383f-5 3.11761f-5\n", - " 1.40496f-5 -1.21956f-5 -1.12533f-5 -3.16762f-5 -1.81151f-5\n", - "\n", - "[:, :, 3, 14] =\n", - " -7.93633f-6 -1.74683f-5 -3.30643f-5 … -2.41837f-6 -2.35818f-5\n", - " -2.14967f-5 -3.45701f-5 -4.38043f-5 -3.09205f-6 -2.01295f-5\n", - " -2.20267f-5 -3.53623f-5 -3.24018f-5 6.1502f-6 -1.65125f-5\n", - " -2.08004f-5 -4.26441f-5 -2.3583f-5 -2.71736f-5 -1.68349f-5\n", - " -1.79565f-5 -4.85526f-5 -4.17504f-5 -2.07986f-5 -1.3357f-5\n", - " -2.09896f-5 -2.60134f-5 -5.04995f-5 … -1.76324f-5 -7.67686f-6\n", - " -3.54343f-5 -3.14013f-5 -2.514f-5 -2.07742f-5 -1.19488f-5\n", - " -2.83928f-5 -4.10369f-5 -2.33926f-5 -2.41132f-5 -1.1116f-5\n", - " -1.46165f-5 -4.19793f-5 -2.49181f-5 -2.58115f-5 -1.18515f-5\n", - " -2.30708f-5 -4.50573f-5 -3.26967f-5 -2.11465f-5 -1.01586f-5\n", - " -3.29882f-5 -3.43335f-5 -1.62386f-5 … -2.59596f-5 -7.33063f-6\n", - " -1.49393f-5 -4.06295f-5 -2.20788f-5 -2.70018f-5 -1.27446f-5\n", - " -2.14708f-5 -2.84157f-5 -4.1274f-5 -2.99866f-5 -5.0091f-6\n", - " ⋮ ⋱ ⋮ \n", - " -2.12108f-5 -5.19599f-5 -3.05368f-5 … -2.42368f-5 -1.09666f-5\n", - " -2.10179f-5 -4.74168f-5 -4.27474f-5 -1.5578f-5 -1.5984f-5\n", - " -2.09912f-5 -3.35144f-5 -3.88128f-5 -2.7594f-5 3.39654f-6\n", - " -1.95054f-5 -3.31035f-5 -3.50551f-5 -4.12634f-5 -2.56047f-5\n", - " -2.33043f-5 -3.33993f-5 -3.51916f-5 -1.40093f-5 -1.22362f-5\n", - " -3.31995f-5 -3.87969f-5 -2.05606f-5 … -2.27915f-5 -1.53536f-5\n", - " -1.52208f-5 -5.35259f-5 -2.78927f-5 -2.51014f-5 -2.51453f-5\n", - " -2.30069f-5 -4.30843f-5 -5.04652f-5 -2.78615f-5 1.22018f-6\n", - " -3.28321f-5 -3.75078f-5 -2.48635f-5 -2.51772f-5 -1.9633f-5\n", - " -7.88154f-6 -4.40925f-5 -3.61323f-5 -2.70516f-5 -1.97826f-5\n", - " -1.74243f-5 -3.12033f-5 -4.37996f-5 … -2.22341f-5 -3.2528f-5\n", - " -3.9618f-5 -5.10782f-5 -3.54032f-5 -7.39695f-6 -1.27495f-7\n", - "\n", - "[:, :, 1, 15] =\n", - " -1.42322f-5 -8.28306f-7 -1.66416f-5 … -2.11629f-5 -8.51041f-6\n", - " 1.78201f-5 -2.03631f-5 -2.93015f-5 -3.91225f-5 -2.38064f-5\n", - " -4.01629f-6 7.20289f-6 -5.25714f-6 -1.72928f-5 -3.32815f-5\n", - " -2.82395f-5 -4.71707f-6 -3.74868f-5 -3.02409f-5 -1.64734f-5\n", - " -1.81109f-5 5.98791f-6 -5.90359f-5 -1.7286f-5 -2.49253f-5\n", - " 8.58902f-6 -1.84229f-5 -5.12326f-5 … -3.34777f-5 1.37166f-5\n", - " 4.36534f-5 -1.47507f-5 -3.43016f-5 -2.38659f-5 1.95873f-5\n", - " 3.05164f-5 -1.30216f-5 -2.95457f-5 -2.2512f-5 -1.54837f-5\n", - " 3.12878f-5 -2.15927f-5 -3.26094f-5 -1.72527f-5 -6.8008f-6\n", - " 3.13024f-5 -2.1157f-5 -3.26002f-5 -1.25f-5 -7.81954f-6\n", - " 3.44746f-5 -2.64019f-5 -2.90338f-5 … -8.88944f-6 -1.10185f-5\n", - " 3.56305f-5 -2.51382f-5 -3.36172f-5 -1.70083f-5 -5.94835f-6\n", - " 3.22221f-5 -1.83961f-5 -3.63357f-5 -1.24371f-5 -7.30951f-6\n", - " ⋮ ⋱ ⋮ \n", - " 3.40669f-5 -2.65137f-5 -2.84164f-5 … -1.8529f-5 -8.00771f-6\n", - " 3.23159f-5 -1.99016f-5 -3.71337f-5 -1.29374f-5 -2.77934f-6\n", - " 3.06859f-5 -1.97791f-5 -3.23031f-5 -1.18352f-5 -5.00627f-6\n", - " 2.92809f-5 -2.24281f-5 -2.87513f-5 -1.18677f-5 -1.03392f-5\n", - " 3.13974f-5 -2.08298f-5 -3.25597f-5 -1.30022f-5 -1.88982f-5\n", - " 3.4333f-5 -2.63916f-5 -2.87311f-5 … -1.94221f-5 -9.26061f-6\n", - " 3.27389f-5 -2.00057f-5 -3.68491f-5 -1.92571f-5 -5.71103f-6\n", - " 3.07853f-5 -1.95269f-5 -3.19953f-5 -1.89164f-5 -6.1808f-6\n", - " 3.26149f-5 -2.96727f-5 -2.83632f-6 -1.86268f-5 -3.14514f-6\n", - " 4.09495f-5 -8.29647f-6 -2.99314f-5 -1.44101f-5 -7.22131f-6\n", - " 3.36188f-5 -5.28598f-6 2.68256f-6 … -1.77605f-5 -1.72127f-5\n", - " 1.5718f-5 -1.06357f-5 -1.58782f-5 -5.37788f-6 -4.31066f-5\n", - "\n", - "[:, :, 2, 15] =\n", - " -2.71481f-5 -6.23731f-5 -6.94312f-5 … -3.79918f-5 -7.59295f-6\n", - " -5.76539f-5 -6.44929f-5 -6.63022f-5 -4.24491f-5 -3.27987f-6\n", - " -3.08224f-5 -4.92059f-5 -5.37293f-5 -5.99305f-5 -9.06002f-6\n", - " -6.15207f-5 -5.25847f-5 -5.51519f-5 -5.93541f-5 -2.0077f-5\n", - " -8.6014f-5 -6.91141f-5 -0.000100434 -9.90005f-5 -5.61515f-5\n", - " -5.77686f-5 -0.000104908 -4.87899f-5 … -8.39651f-5 -5.26676f-5\n", - " -6.25512f-5 -0.00010837 -2.85305f-5 -8.89585f-5 -4.53344f-6\n", - " -5.81532f-5 -0.00010201 -2.09666f-5 -8.76519f-5 -1.36613f-5\n", - " -6.27479f-5 -0.000103551 -1.84921f-5 -9.89555f-5 -2.4265f-5\n", - " -6.32125f-5 -0.000103278 -1.85082f-5 -9.37496f-5 -2.54385f-5\n", - " -6.8004f-5 -0.000106113 -1.69805f-5 … -9.61866f-5 -2.37284f-5\n", - " -6.86537f-5 -9.7697f-5 -2.13105f-5 -0.00010507 -1.76843f-5\n", - " -6.46231f-5 -8.94775f-5 -3.06477f-5 -9.72222f-5 -1.85753f-5\n", - " ⋮ ⋱ ⋮ \n", - " -6.71139f-5 -0.000105918 -1.67125f-5 … -9.68319f-5 -2.71956f-5\n", - " -6.31445f-5 -9.47925f-5 -2.28639f-5 -9.50865f-5 -2.68261f-5\n", - " -6.359f-5 -9.87085f-5 -2.61428f-5 -0.000101427 -1.90983f-5\n", - " -6.16764f-5 -0.000107163 -1.4755f-5 -9.67539f-5 -2.50946f-5\n", - " -5.63963f-5 -0.000103415 -2.06108f-5 -8.91569f-5 -3.287f-5\n", - " -6.76495f-5 -0.000106304 -1.62743f-5 … -9.48742f-5 -3.3864f-5\n", - " -6.35575f-5 -9.48656f-5 -2.22142f-5 -0.000101478 -2.79286f-5\n", - " -6.35552f-5 -9.81299f-5 -2.57596f-5 -0.000101187 -2.81493f-5\n", - " -5.26933f-5 -0.000102399 -1.71688f-5 -9.89127f-5 -2.84281f-5\n", - " -5.17102f-5 -7.33952f-5 -3.16825f-5 -0.000103105 -1.59417f-5\n", - " -3.87775f-5 -5.88389f-5 -4.60914f-5 … -8.24853f-5 -6.29939f-6\n", - " -1.01194f-5 -5.58116f-5 1.2544f-8 -2.73959f-5 7.38215f-7\n", - "\n", - "[:, :, 3, 15] =\n", - " -1.05581f-5 4.04572f-6 4.48011f-6 … 2.63798f-6 1.36589f-6\n", - " -1.10731f-5 -1.23928f-6 1.43986f-6 -3.17876f-5 -4.85235f-6\n", - " -3.0389f-5 -1.88746f-5 -1.09316f-6 -2.05259f-5 -6.97609f-6\n", - " -2.31538f-5 -1.87011f-5 4.27785f-5 5.31586f-6 -1.19822f-7\n", - " 1.77212f-5 -3.64338f-5 4.40577f-6 2.44596f-5 2.9193f-5\n", - " -5.06462f-6 -3.26252f-5 -4.55707f-6 … -3.38217f-6 1.32613f-5\n", - " -2.9407f-5 4.73535f-5 1.84007f-5 1.7393f-5 -7.39461f-6\n", - " -3.36971f-6 3.14314f-6 2.49574f-5 2.26891f-5 1.02206f-5\n", - " -2.95847f-6 1.58656f-5 1.15972f-5 2.19764f-5 8.8516f-6\n", - " -2.94252f-6 1.59466f-5 1.12964f-5 2.21957f-5 4.75332f-6\n", - " -4.56594f-6 1.67592f-5 1.0183f-5 … 2.72823f-5 4.26673f-6\n", - " -3.58368f-7 9.34662f-6 1.08674f-5 2.61219f-5 8.15499f-6\n", - " 3.7385f-6 1.83577f-5 1.69644f-5 2.75185f-5 6.39195f-6\n", - " ⋮ ⋱ ⋮ \n", - " -5.14102f-6 1.59719f-5 9.30539f-6 … 2.49045f-5 5.96053f-6\n", - " 1.24572f-6 7.70415f-6 1.15018f-5 2.81032f-5 8.23876f-6\n", - " -7.85542f-8 2.54156f-5 1.60599f-5 3.06293f-5 7.06368f-6\n", - " -6.55128f-6 1.78671f-5 5.77923f-6 1.16846f-5 -2.76677f-7\n", - " -3.26391f-6 9.23154f-6 1.38038f-5 1.70896f-5 6.30555f-6\n", - " -4.88428f-6 1.65113f-5 1.02148f-5 … 2.52434f-5 5.1654f-6\n", - " 7.36047f-7 8.26562f-6 1.18475f-5 2.72033f-5 4.53531f-6\n", - " -2.71816f-8 2.55809f-5 1.5987f-5 2.37181f-5 5.42823f-6\n", - " -5.75722f-6 3.17648f-5 -1.50827f-5 2.2849f-5 6.46165f-6\n", - " 5.26474f-6 2.6f-5 1.6355f-5 2.14678f-5 -6.35829f-6\n", - " 7.64364f-6 1.2984f-5 1.15768f-5 … 9.00361f-6 -2.41165f-5\n", - " 5.46411f-6 2.37213f-5 1.45142f-5 1.40133f-5 8.73353f-6\n", - "\n", - "[:, :, 1, 16] =\n", - " -5.03802f-6 -1.53795f-6 4.60553f-6 … 8.59933f-6 1.34548f-5\n", - " -1.33139f-5 8.2283f-6 4.72714f-6 1.69092f-5 4.31501f-6\n", - " -1.87007f-6 1.15159f-5 5.04875f-6 5.16017f-6 8.04131f-6\n", - " 6.35418f-8 3.08693f-6 -1.29598f-5 8.22963f-6 2.8338f-6\n", - " -2.16983f-5 -2.4079f-5 7.10067f-6 1.87707f-5 4.5435f-6\n", - " -1.25426f-5 -1.95699f-6 -3.54366f-7 … 3.89216f-6 4.52989f-6\n", - " -1.91067f-5 1.63425f-5 -2.13026f-5 9.29225f-6 1.57625f-6\n", - " -1.27482f-5 -6.22103f-6 -1.65056f-5 1.1202f-5 1.11683f-5\n", - " -2.21989f-5 -3.97914f-6 -2.48552f-5 3.4693f-6 6.54992f-6\n", - " -1.66821f-5 3.08557f-6 -2.25274f-5 2.99258f-6 1.19086f-5\n", - " -1.35066f-5 -3.56337f-5 -3.92915f-5 … -6.58634f-7 4.13911f-6\n", - " 2.01907f-7 -4.15685f-6 -2.32959f-5 -3.8229f-6 6.58547f-6\n", - " -2.55842f-5 -4.38035f-6 -7.13148f-6 -9.832f-7 3.90915f-6\n", - " ⋮ ⋱ ⋮ \n", - " -2.38357f-5 -1.98589f-5 -2.21745f-5 … 9.79263f-6 -5.56729f-7\n", - " -2.90955f-5 -2.92635f-6 -2.00649f-5 6.53204f-7 4.46353f-6\n", - " -2.28595f-5 -1.98966f-5 -8.38325f-6 1.59205f-5 4.65142f-6\n", - " -2.96641f-5 -2.09721f-5 -1.28431f-5 4.71605f-6 2.15453f-6\n", - " -2.20343f-5 -3.17625f-5 -5.10912f-5 6.94432f-6 1.09819f-5\n", - " -2.3516f-6 1.31662f-6 -1.44183f-5 … -4.05642f-6 -2.73961f-6\n", - " -2.16968f-5 -6.0583f-6 -9.5029f-6 1.47617f-5 1.67887f-6\n", - " -3.17614f-6 -2.08875f-5 2.43521f-5 -9.33056f-6 8.64669f-6\n", - " -2.03552f-5 9.42903f-6 2.71075f-6 9.54452f-6 2.22076f-5\n", - " 5.75616f-6 1.0806f-5 -2.36865f-7 4.67347f-6 1.61059f-5\n", - " -1.40951f-5 -1.02704f-5 -3.18635f-6 … -8.39052f-7 2.45332f-5\n", - " -3.1905f-6 -3.46778f-6 -3.62573f-6 1.71537f-5 2.06471f-5\n", - "\n", - "[:, :, 2, 16] =\n", - " -2.15884f-5 -1.14227f-5 -1.62082f-5 … -2.55764f-5 -3.13291f-5\n", - " -2.23083f-5 1.90108f-5 1.10183f-5 -5.42675f-7 -1.1911f-5\n", - " -9.12809f-6 3.98662f-5 2.59647f-6 -3.74951f-5 -8.69153f-6\n", - " -7.62137f-6 3.32652f-5 1.61093f-5 -2.59258f-5 7.50499f-6\n", - " -3.47015f-6 2.78749f-7 1.91968f-5 -2.03829f-5 1.30214f-5\n", - " -4.30571f-6 1.6286f-5 3.87096f-5 … -1.3211f-5 1.14148f-5\n", - " 1.43065f-7 3.21748f-5 4.9644f-5 -1.63238f-5 8.8627f-6\n", - " 4.03356f-6 4.056f-5 3.15086f-5 -1.25394f-5 2.2954f-6\n", - " -1.13654f-5 7.34695f-5 3.3689f-5 -1.45838f-5 1.1383f-7\n", - " -2.92628f-5 5.7042f-5 -2.2081f-6 3.45306f-6 1.79821f-6\n", - " -1.09181f-5 4.72039f-5 1.92767f-5 … 2.18812f-6 -6.04413f-6\n", - " -1.69237f-5 1.27184f-5 1.6065f-5 5.88449f-6 -1.0836f-5\n", - " -8.81984f-6 2.0075f-5 4.56131f-5 -8.64202f-6 -1.25129f-5\n", - " ⋮ ⋱ ⋮ \n", - " 1.20612f-5 5.55669f-5 6.90552f-5 … -2.94486f-5 -1.32609f-5\n", - " 2.02f-5 4.96709f-5 6.90785f-5 -1.91592f-5 4.8409f-6\n", - " 1.83125f-5 6.40737f-5 6.18068f-5 -2.0052f-5 1.90681f-6\n", - " 7.91827f-6 6.21826f-5 6.86201f-5 -2.40519f-5 7.56757f-6\n", - " 4.21825f-7 4.81752f-5 4.54214f-5 -1.54302f-5 5.53415f-6\n", - " 2.0481f-6 3.59647f-5 -6.68709f-6 … -8.75948f-6 5.73184f-6\n", - " 2.2825f-6 4.04816f-5 1.54492f-5 1.44027f-5 5.88194f-6\n", - " -2.17447f-5 1.23614f-8 -5.07233f-7 -9.76107f-6 -5.2898f-6\n", - " -4.04921f-6 3.28196f-5 1.09735f-5 -2.13296f-5 -1.0723f-5\n", - " -6.53629f-7 9.64771f-6 2.88722f-5 -5.34081f-6 1.22531f-5\n", - " -3.75205f-6 5.36633f-6 2.31161f-5 … -9.85085f-7 1.42955f-5\n", - " 5.52764f-6 2.27731f-5 3.81831f-5 1.51786f-5 1.40957f-5\n", - "\n", - "[:, :, 3, 16] =\n", - " 2.09887f-5 2.48268f-5 3.26262f-5 … 2.52067f-5 1.83206f-5\n", - " 1.60499f-5 3.12218f-5 3.18235f-5 1.53867f-5 2.27519f-5\n", - " 3.48141f-5 9.79703f-6 2.88586f-6 1.14747f-5 1.88519f-5\n", - " 8.71227f-7 2.75015f-5 4.37508f-5 1.78847f-5 1.87048f-5\n", - " 2.18388f-5 4.08884f-5 3.65291f-5 1.0776f-5 2.12948f-5\n", - " 3.40946f-5 4.08884f-5 1.94329f-5 … 1.0377f-5 2.05071f-5\n", - " 1.92069f-5 1.97018f-5 3.90101f-5 1.92215f-5 1.76457f-5\n", - " 2.36644f-5 9.86611f-6 2.17295f-5 2.36401f-5 1.73641f-5\n", - " 2.31846f-5 2.88264f-5 3.39537f-5 2.51549f-5 1.47432f-5\n", - " 2.57393f-5 -9.20106f-6 3.30965f-5 4.29673f-5 1.3416f-5\n", - " 1.70486f-5 1.40648f-5 1.81989f-5 … 3.695f-5 1.24074f-5\n", - " 2.94645f-5 4.57801f-5 2.73179f-5 3.59657f-5 8.28307f-6\n", - " 2.56688f-5 1.4644f-5 2.24626f-5 2.24926f-5 8.88082f-6\n", - " ⋮ ⋱ ⋮ \n", - " 2.03935f-5 3.42854f-5 5.86772f-6 … 4.35592f-7 1.05793f-5\n", - " 1.92946f-5 2.94853f-5 3.70969f-5 1.82321f-5 1.76721f-5\n", - " 1.20686f-5 3.67779f-5 2.70968f-5 6.38749f-6 1.87694f-5\n", - " 1.40735f-5 3.77764f-6 2.13491f-5 1.19133f-5 1.27518f-5\n", - " 2.57298f-5 2.21913f-5 1.81671f-5 1.72985f-5 1.27492f-5\n", - " 1.5189f-5 2.35697f-5 1.81276f-5 … 1.98066f-5 7.02275f-6\n", - " 2.92245f-6 4.0379f-6 3.79888f-5 2.83694f-5 1.09155f-5\n", - " 2.72736f-5 8.72931f-6 -5.25945f-6 2.87468f-5 2.20531f-5\n", - " 2.43792f-5 1.75247f-5 6.85327f-6 -2.04517f-5 -5.39688f-6\n", - " 7.97139f-6 9.4859f-6 -1.00873f-6 2.4968f-5 1.06572f-5\n", - " 9.65145f-6 1.21162f-5 1.81417f-8 … 9.58578f-6 6.21811f-6\n", - " -1.35314f-5 -1.89171f-5 -2.95601f-5 -2.69976f-5 -1.33066f-5" - ] - }, - "execution_count": 42, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "g" - ] - }, - { - "cell_type": "markdown", - "id": "a57b3c8d", - "metadata": {}, - "source": [ - "# Evaluation Function" - ] - }, - { - "cell_type": "code", - "execution_count": 43, - "id": "02f69609", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "evaluate (generic function with 1 method)" - ] - }, - "execution_count": 43, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "\n", - "function evaluate(model, test_loader)\n", - " preds = []\n", - " targets = []\n", - " for (x, y) in test_loader\n", - " # Get model predictions\n", - " # Note argmax of nd-array gives CartesianIndex\n", - " # Need to grab the first element of each CartesianIndex to get the true index\n", - " logits = model(x)\n", - " ŷ = map(i -> i[1], argmax(logits, dims=1))\n", - " append!(preds, ŷ)\n", - "\n", - " # Get true labels\n", - " append!(targets, y)\n", - " end\n", - " accuracy = sum(preds .== targets) / length(targets)\n", - " return accuracy\n", - "end" - ] - }, - { - "cell_type": "markdown", - "id": "f2072bd8", - "metadata": {}, - "source": [ - "# Training Loop" - ] - }, - { - "cell_type": "code", - "execution_count": 44, - "id": "cc39bcab", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING: redefinition of constant to. This may fail, cause incorrect answers, or produce other errors.\n" - ] - }, - { - "data": { - "text/plain": [ - "\u001b[0m\u001b[1m ────────────────────────────────────────────────────────────────────\u001b[22m\n", - "\u001b[0m\u001b[1m \u001b[22m Time Allocations \n", - " ─────────────────────── ────────────────────────\n", - " Tot / % measured: 1.35ms / 0.0% 13.7KiB / 0.0% \n", - "\n", - " Section ncalls time %tot avg alloc %tot avg\n", - " ────────────────────────────────────────────────────────────────────\n", - "\u001b[0m\u001b[1m ────────────────────────────────────────────────────────────────────\u001b[22m" - ] - }, - "execution_count": 44, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "\n", - "# Setup timing output\n", - "const to = TimerOutput()" - ] - }, - { - "cell_type": "code", - "execution_count": 45, - "id": "9b5c088c", - "metadata": {}, - "outputs": [ - { - "ename": "LoadError", - "evalue": "No derivative rule found for op %1174 = lastindex(%1172)::Int64 , try defining it using \n\n\tChainRulesCore.rrule(::typeof(lastindex), ::NTuple{4, Int64}) = ...\n", - "output_type": "error", - "traceback": [ - "No derivative rule found for op %1174 = lastindex(%1172)::Int64 , try defining it using \n\n\tChainRulesCore.rrule(::typeof(lastindex), ::NTuple{4, Int64}) = ...\n", - "", - "Stacktrace:", - " [1] error(s::String)", - " @ Base .\\error.jl:35", - " [2] step_back!(tape::Umlaut.Tape{Yota.GradCtx}, y::Umlaut.Variable)", - " @ Yota C:\\Users\\Yash\\.julia\\packages\\Yota\\KJQ6n\\src\\grad.jl:219", - " [3] back!(tape::Umlaut.Tape{Yota.GradCtx}; seed::Int64)", - " @ Yota C:\\Users\\Yash\\.julia\\packages\\Yota\\KJQ6n\\src\\grad.jl:260", - " [4] gradtape!(tape::Umlaut.Tape{Yota.GradCtx}; seed::Int64)", - " @ Yota C:\\Users\\Yash\\.julia\\packages\\Yota\\KJQ6n\\src\\grad.jl:273", - " [5] gradtape(::Function, ::ResNet5, ::Vararg{Any}; ctx::Yota.GradCtx, seed::Int64)", - " @ Yota C:\\Users\\Yash\\.julia\\packages\\Yota\\KJQ6n\\src\\grad.jl:300", - " [6] grad(::Function, ::ResNet5, ::Vararg{Any}; seed::Int64)", - " @ Yota C:\\Users\\Yash\\.julia\\packages\\Yota\\KJQ6n\\src\\grad.jl:370", - " [7] grad(::Function, ::ResNet5, ::Vararg{Any})", - " @ Yota C:\\Users\\Yash\\.julia\\packages\\Yota\\KJQ6n\\src\\grad.jl:362", - " [8] macro expansion", - " @ .\\In[45]:14 [inlined]", - " [9] macro expansion", - " @ C:\\Users\\Yash\\.julia\\packages\\TimerOutputs\\4yHI4\\src\\TimerOutput.jl:237 [inlined]", - " [10] macro expansion", - " @ .\\In[45]:9 [inlined]", - " [11] top-level scope", - " @ C:\\Users\\Yash\\.julia\\packages\\TimerOutputs\\4yHI4\\src\\TimerOutput.jl:237 [inlined]", - " [12] top-level scope", - " @ .\\In[45]:0", - " [13] eval", - " @ .\\boot.jl:368 [inlined]", - " [14] include_string(mapexpr::typeof(REPL.softscope), mod::Module, code::String, filename::String)", - " @ Base .\\loading.jl:1428" - ] - } - ], - "source": [ - "last_loss = 0;\n", - "@timeit to \"total_training_time\" begin\n", - " for epoch in 1:10\n", - " timing_name = epoch > 1 ? \"average_epoch_training_time\" : \"train_jit\"\n", - "\n", - " # Create lazily evaluated augmented training data\n", - " train_batches = mappedarray(augmentbatch, batchview(shuffleobs((train_x_padded, train_y)), size=train_batch_size));\n", - "\n", - " @timeit to timing_name begin\n", - " losses = []\n", - " for (x, y) in train_batches\n", - " # loss_function does forward pass\n", - " # Yota.jl grad function computes model parameter gradients in g[2]\n", - " loss, g = grad(loss_function, model, x, y)\n", - " \n", - " # Optimiser updates parameters\n", - " Optimisers.update!(state, model, g[2])\n", - " push!(losses, loss)\n", - " end\n", - " last_loss = mean(losses)\n", - " @info(\"epoch (mean(losses))\")\n", - " end\n", - " # timing_name = epoch > 1 ? \"average_inference_time\" : \"eval_jit\"\n", - " # @timeit to timing_name begin\n", - " # acc = evaluate(model, test_loader)\n", - " # @info(\"epoch (acc)\")\n", - " # end\n", - " end\n", - "end" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1955c486", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9ace272d", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Julia 1.8.2", - "language": "julia", - "name": "julia-1.8" - }, - "language_info": { - "file_extension": ".jl", - "mimetype": "application/julia", - "name": "julia", - "version": "1.8.2" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -}