168 changes: 127 additions & 41 deletions handwriting_model.ipynb
@@ -36,13 +36,13 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 18,
"metadata": {
"id": "naG4_UDjlpdV",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "5b8a92f0-04b9-4d79-c54d-fe5eeaf1f827"
"outputId": "c74612f5-adf7-4748-8339-3395d60e5275"
},
"outputs": [
{
@@ -169,7 +169,7 @@
" # print(\"processing data for optimal results...\")\n",
" # x_train, y_train, x_test, y_test = shape_data(x_train, y_train, x_test, y_test)\n",
"\n",
" return [x_train, x_test, y_train, y_test]\n",
" return [x_train, y_train, x_test, y_test]\n",
"\n",
"\n",
"def load_convex_data():\n",
@@ -195,12 +195,12 @@
" # split data\n",
" x_train, x_test, y_train, y_test = train_test_split(x_data, y_data, test_size=0.2, random_state=17)\n",
"\n",
" return [x_train, x_test, y_train, y_test]"
" return [x_train, y_train, x_test, y_test]"
],
"metadata": {
"id": "jYxDV2NxltT5"
},
"execution_count": 20,
"execution_count": 19,
"outputs": []
},
{
@@ -226,7 +226,7 @@
" model.add(MaxPooling2D((2, 2)))\n",
" model.add(Flatten())\n",
" model.add(Dense(64, activation='relu'))\n",
" model.add(Dense(10, activation='softmax'))\n",
" model.add(Dense(num_chars, activation='softmax'))\n",
"\n",
" # compile model\n",
" print(\"compiling model...\")\n",
@@ -243,35 +243,38 @@
" x_test = emnist_data[2]\n",
" y_test = emnist_data[3]\n",
"\n",
" y_train = tf.keras.utils.to_categorical(y_train, num_chars)\n",
" y_test = tf.keras.utils.to_categorical(y_test, num_chars)\n",
"\n",
" # train model on MNIST\n",
" print(\"training on MNIST data...\")\n",
" model.fit(x_train, y_train, batch_size=128, epochs=5, validation_data=(x_test, y_test))\n",
"\n",
" # save weights\n",
" save_weights(model)\n",
" return model\n",
"\n",
"\n",
"def save_weights(model):\n",
"def save_weights(model, dir):\n",
" # save weights\n",
" print(\"saving MNIST weights...\")\n",
" weights = model.get_weights()\n",
" with open(\"mnist_weights.csv\", 'w', newline='') as csvfile:\n",
" with open(dir, 'w', newline='') as csvfile:\n",
" writer = csv.writer(csvfile)\n",
" for weight in weights:\n",
" writer.writerow(weight.flatten())\n",
" \n",
"\n",
"def load_weights(filepath):\n",
"def load_weights(model, dir):\n",
" # load from file\n",
" print(\"loading weights for transfer learning...\")\n",
" with open(filepath, 'r') as csvfile:\n",
" with open(dir, 'r') as csvfile:\n",
" reader = csv.reader(csvfile)\n",
" weights = []\n",
" for row in reader:\n",
" weights.append(row.astype(float))\n",
" \n",
" # return weights\n",
" return weights\n",
" # return model\n",
" model = model.set_weights(weights)\n",
" return model\n",
"\n",
" \n",
"\"\"\"\n",
@@ -282,7 +285,7 @@
" x_train = data[0]\n",
" y_train = data[1]\n",
" x_test = data[2]\n",
" x_test = data[3]\n",
" y_test = data[3]\n",
"\n",
" # load weights for transfer learning\n",
" print(\"transferring learning & retraining...\")\n",
@@ -296,7 +299,7 @@
"metadata": {
"id": "t3g5Duu_C8uF"
},
"execution_count": 21,
"execution_count": 20,
"outputs": []
},
{
@@ -326,27 +329,26 @@
{
"cell_type": "code",
"source": [
"print(\" ::: STARTED MODEL TRAINING ::: \")\n",
"\n",
"# load data #\n",
"emnist_data = load_emnist()\n",
"\n",
"# train model for transfer #\n",
"handwriting_model = model_architecture()\n",
"handwriting_model = train_emnist(handwriting_model, emnist_data)\n",
"emnist_model = model_architecture()\n",
"emnist_model = train_emnist(emnist_model, emnist_data)\n",
"\n",
"print(emnist_model.summary())\n",
"\n",
"# save weights #\n",
"save_weights(handwriting_model)"
"save_weights(emnist_model, dir='mnist_weights')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 782
"base_uri": "https://localhost:8080/"
},
"id": "bRyicZ01Ew1B",
"outputId": "30ae983a-203e-4489-967c-6111e8a8eb4f"
"outputId": "5457ee2e-04de-4c69-d34d-2fa8a418f841"
},
"execution_count": 22,
"execution_count": 21,
"outputs": [
{
"output_type": "stream",
@@ -356,25 +358,86 @@
"loading data for transfer learning...\n",
"defining model architecture...\n",
"compiling model...\n",
"training on MNIST data...\n"
]
},
{
"output_type": "error",
"ename": "ValueError",
"evalue": "ignored",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-22-20dac4a5fe8b>\u001b[0m in \u001b[0;36m<cell line: 8>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;31m# train model for transfer #\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0mhandwriting_model\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel_architecture\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 8\u001b[0;31m \u001b[0mhandwriting_model\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrain_emnist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhandwriting_model\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0memnist_data\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 9\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0;31m# save weights #\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m<ipython-input-21-11e3aca60f6b>\u001b[0m in \u001b[0;36mtrain_emnist\u001b[0;34m(model, emnist_data)\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[0;31m# train model on MNIST\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 29\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"training on MNIST data...\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 30\u001b[0;31m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx_train\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_train\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m128\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mepochs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m5\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalidation_data\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx_test\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_test\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 31\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 32\u001b[0m \u001b[0;31m# save weights\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/keras/utils/traceback_utils.py\u001b[0m in \u001b[0;36merror_handler\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 68\u001b[0m \u001b[0;31m# To get the full stack trace, call:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 69\u001b[0m \u001b[0;31m# `tf.debugging.disable_traceback_filtering()`\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 70\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwith_traceback\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfiltered_tb\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 71\u001b[0m \u001b[0;32mfinally\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 72\u001b[0m \u001b[0;32mdel\u001b[0m \u001b[0mfiltered_tb\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/keras/engine/data_adapter.py\u001b[0m in \u001b[0;36m_check_data_cardinality\u001b[0;34m(data)\u001b[0m\n\u001b[1;32m 1850\u001b[0m )\n\u001b[1;32m 1851\u001b[0m \u001b[0mmsg\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;34m\"Make sure all arrays contain the same number of samples.\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1852\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1853\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1854\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mValueError\u001b[0m: Data cardinality is ambiguous:\n x sizes: 112800\n y sizes: 18800\nMake sure all arrays contain the same number of samples."
"training on MNIST data...\n",
"Epoch 1/5\n",
"882/882 [==============================] - 112s 125ms/step - loss: 2.7690 - accuracy: 0.3033 - val_loss: 0.9500 - val_accuracy: 0.7102\n",
"Epoch 2/5\n",
"882/882 [==============================] - 104s 118ms/step - loss: 0.6859 - accuracy: 0.7786 - val_loss: 0.6115 - val_accuracy: 0.8078\n",
"Epoch 3/5\n",
"882/882 [==============================] - 103s 117ms/step - loss: 0.5122 - accuracy: 0.8288 - val_loss: 0.5290 - val_accuracy: 0.8254\n",
"Epoch 4/5\n",
"882/882 [==============================] - 103s 116ms/step - loss: 0.4492 - accuracy: 0.8457 - val_loss: 0.4931 - val_accuracy: 0.8365\n",
"Epoch 5/5\n",
"882/882 [==============================] - 104s 118ms/step - loss: 0.4137 - accuracy: 0.8548 - val_loss: 0.4704 - val_accuracy: 0.8453\n",
"Model: \"sequential_4\"\n",
"_________________________________________________________________\n",
" Layer (type) Output Shape Param # \n",
"=================================================================\n",
" conv2d_8 (Conv2D) (None, 26, 26, 32) 320 \n",
" \n",
" max_pooling2d_8 (MaxPooling (None, 13, 13, 32) 0 \n",
" 2D) \n",
" \n",
" conv2d_9 (Conv2D) (None, 11, 11, 64) 18496 \n",
" \n",
" max_pooling2d_9 (MaxPooling (None, 5, 5, 64) 0 \n",
" 2D) \n",
" \n",
" flatten_4 (Flatten) (None, 1600) 0 \n",
" \n",
" dense_8 (Dense) (None, 64) 102464 \n",
" \n",
" dense_9 (Dense) (None, 84) 5460 \n",
" \n",
"=================================================================\n",
"Total params: 126,740\n",
"Trainable params: 126,740\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n",
"None\n",
"saving MNIST weights...\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"## Train on Math Symbols ##\n",
"In order to get proper OCR for symbols like '=', '+', etc. we must also transfer the learning from the EMNIST dataset to the Kaggle Math Symbols dataset.\n",
"\n",
"Then only can training be transferred to the final handwriting model for running/further training."
],
"metadata": {
"id": "b9jSz48m-E4Y"
}
},
{
"cell_type": "code",
"source": [
"\"\"\"\n",
"NOTE: this part is still heavily in development\n",
"\n",
"# load data #\n",
"symbol_data = load_math_data()\n",
"\n",
"# train model for transfer #\n",
"symbol_model = model_architecture()\n",
"symbol_model = load_weights(symbol_model, 'mnist_weights.csv')\n",
"symbol_model = train_emnist(symbol_model, symbol_data)\n",
"\n",
"print(symbol_model.summary())\n",
"\n",
"# save weights #\n",
"save_weights(symbol_model, dir='symbol_weights.csv')\n",
"\"\"\""
],
"metadata": {
"id": "wUlyVQBF-WVa"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
Expand All @@ -387,13 +450,36 @@
},
{
"cell_type": "code",
"source": [],
"source": [
"# load data #\n",
"handwriting_data = load_convex_data()\n",
"\n",
"# train model for transfer #\n",
"handwriting_model = model_architecture()\n",
"handwriting_model = load_weights(handwriting_model, 'symbol_weights.csv')\n",
"handwriting_model = train_emnist(handwriting_model, handwriting_data)\n",
"\n",
"print(handwriting_model.summary())\n",
"\n",
"# save weights #\n",
"save_weights(handwriting_model, 'handwriting_weights.csv')"
],
"metadata": {
"id": "R4Utf-jOGHhM"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Testing ##\n",
"Only if necessary, here are some tests to run (unfinished)."
],
"metadata": {
"id": "OC4fEZfv_QAG"
}
},
{
"cell_type": "code",
"source": [