{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": { "id": "jVehH4MNhbTP" },
   "source": [
    "# **CNN 練習**"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": { "id": "98Q6g2Bfu6hB" },
   "source": [
    "## 匯入所需套件"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "id": "BqJrEZc6Kws9" },
   "outputs": [],
   "source": [
    "# Python packages needed by tf.keras.utils.plot_model.\n",
    "# %pip (rather than !pip) installs into the environment of the running kernel.\n",
    "%pip install -q pydot graphviz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "id": "hvShL5x7Kws-" },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "id": "C-K6b_vfX3iJ" },
   "outputs": [],
   "source": [
    "from tensorflow import keras\n",
    "from tensorflow.keras.datasets import cifar10\n",
    "from tensorflow.keras.layers import (\n",
    "    Activation, Conv2D, Dense, Dropout, Flatten, Input, MaxPooling2D,\n",
    ")\n",
    "from tensorflow.keras.models import Model, load_model\n",
    "from tensorflow.keras.utils import plot_model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": { "id": "awautmYIh4a2" },
   "source": [
    "## Cifar10 資料讀入及前處理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "id": "IKC4TPyHX5NE" },
   "outputs": [],
   "source": [
    "# The images are kept under *_raw names so that the preprocessing cell below\n",
    "# can always be re-run from the untouched data.\n",
    "(x_train_raw, y_train), (x_test_raw, y_test) = cifar10.load_data()\n",
    "\n",
    "print('x_train shape:', x_train_raw.shape)\n",
    "print('y_train.shape:', y_train.shape)\n",
    "print(x_train_raw.shape[0], 'train samples')\n",
    "print(x_test_raw.shape[0], 'test samples')\n",
    "\n",
    "# The image arrays have four axes: axis 1 is the number of samples, axes 2 and 3\n",
    "# are the 32*32 image size, and axis 4 holds the three RGB colour channels.\n",
    "# There are 50000 training images and 10000 test images."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "id": "CoW9s8z0X6w3" },
   "outputs": [],
   "source": [
    "# Image preprocessing: cast to float32 and rescale the pixel values by 1/255.\n",
    "# Rescaling the image features can improve the model's accuracy and makes\n",
    "# training converge faster.\n",
    "# New arrays are derived from the raw data instead of dividing in place, so\n",
    "# re-running this cell never rescales the images a second time.\n",
    "x_train = x_train_raw.astype('float32') / 255\n",
    "x_test = x_test_raw.astype('float32') / 255"
   ]
  },
  {
   "cell_type":
"code",
   "execution_count": null,
   "metadata": { "id": "eSXBilArX8ih" },
   "outputs": [],
   "source": [
    "# One-hot encode the labels of the training and test data.\n",
    "num_classes = 10\n",
    "\n",
    "\n",
    "def one_hot(labels, num_classes):\n",
    "    \"\"\"Return `labels` as a float32 one-hot matrix with `num_classes` columns.\n",
    "\n",
    "    Expects a column of integer class ids (one id per row). An array that\n",
    "    already has `num_classes` columns is treated as encoded and passed\n",
    "    through, so re-running this cell does not raise on encoded labels.\n",
    "    \"\"\"\n",
    "    labels = np.asarray(labels)\n",
    "    if labels.ndim == 2 and labels.shape[1] == num_classes:\n",
    "        return labels.astype('float32')\n",
    "    return np.eye(num_classes, dtype='float32')[labels.reshape(-1)]\n",
    "\n",
    "\n",
    "y_train = one_hot(y_train, num_classes)\n",
    "y_test = one_hot(y_test, num_classes)\n",
    "\n",
    "print('y_train shape:', y_train.shape)\n",
    "print('y_test shape:', y_test.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": { "id": "rqCH66zEiJ8m" },
   "source": [
    "## 模型定義\n",
    "- 試著建立圖中的模型架構"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": { "id": "cFqXmMk4KwtC" },
   "source": [
    "![image](https://hackmd.io/_uploads/By2IG1PL6.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "id": "acP6_V7XKwtC" },
   "outputs": [],
   "source": [
    "print(x_train.shape[1:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "id": "3cZSnAQjX-M_" },
   "outputs": [],
   "source": [
    "'''在__________填入正確的參數讓產生的卷積影像大小不變吧'''\n",
    "# EXERCISE: replace every '______' placeholder below with the right argument.\n",
    "# This cell raises until all placeholders have been filled in.\n",
    "\n",
    "inputs = Input(shape=x_train.shape[1:])\n",
    "# First layer - convolution with 32 filters of size 3*3.\n",
    "# Choose the padding so that the convolution keeps the image size unchanged.\n",
    "# Every activation function in the network is ReLU.\n",
    "x = Conv2D('______', '______', padding='______', activation='______')(inputs)\n",
    "x = Dropout(rate=0.25)(x)\n",
    "\n",
    "# Second layer - convolution (3x3 filters) + pooling\n",
    "x = Conv2D('______', '______', padding='______', activation='______')(x)\n",
    "x = MaxPooling2D(pool_size=(2, 2))(x)\n",
    "\n",
    "# Third layer - convolution (3x3 filters)\n",
    "x = Conv2D('______', '______', padding='______', activation='______')(x)\n",
    "\n",
    "# Fourth layer - convolution (3x3 filters) + pooling\n",
    "x = Conv2D('______', '______', padding='______', activation='______')(x)\n",
    "x = MaxPooling2D(pool_size=(2, 2))(x)\n",
    "x = Dropout(0.25)(x)  # guards against overfitting\n",
    "\n",
    "# Classifier (MLP): flatten + hidden layer (512 units, ReLU) + output layer (10)\n",
    "x = Flatten()(x)\n",
    "x = Dense('______', activation='______')(x)\n",
    "x = Dropout(0.25)(x)\n",
    "outputs = Dense(num_classes, activation='softmax')(x)\n",
    "\n",
    "\n",
    "model = Model(inputs=inputs, outputs=outputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "id": "fUoML6ApKwtD" },
   "outputs": [],
   "source": [
    "# Draw the architecture (with tensor shapes) to ex_Model.png.\n",
    "# model.summary()  # text alternative to the diagram\n",
    "plot_model(model, to_file='ex_Model.png', show_shapes=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "id": "5aD2EDYZKwtD" },
   "outputs": [],
   "source": [
    "# Directory that receives the checkpoint; exist_ok makes re-running a no-op.\n",
    "model_dir = 'model-logs/'\n",
    "os.makedirs(model_dir, exist_ok=True)\n",
    "\n",
    "# NOTE(review): logfiles is not used by any callback below.\n",
    "logfiles = f'{model_dir}{model.__class__.__name__}'\n",
    "\n",
    "# Keep only the checkpoint with the best validation accuracy.\n",
    "modelfiles = f'{model_dir}basic_model-best-model.h5'\n",
    "model_mckp = keras.callbacks.ModelCheckpoint(modelfiles,\n",
    "                                             monitor='val_accuracy',\n",
    "                                             save_best_only=True)\n",
    "\n",
    "# Stop once the validation loss has not improved for 5 consecutive epochs.\n",
    "earlystop = keras.callbacks.EarlyStopping(monitor='val_loss',\n",
    "                                          patience=5,\n",
    "                                          verbose=1)\n",
    "\n",
    "callbacks_list = [model_mckp, earlystop]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "id": "8tie_2BzYAF0" },
   "outputs": [],
   "source": [
    "# Compile the model, using Adam as the optimizer.\n",
    "learning_rate = 0.0001\n",
    "optimizer = keras.optimizers.Adam(learning_rate=learning_rate)\n",
    "\n",
    "model.compile(loss='categorical_crossentropy',\n",
    "              optimizer=optimizer,\n",
    "              metrics=['accuracy'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": { "id": "CkrKrXctigSI" },
   "source": [
    "## 開始訓練模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "id": "3-2yHC87YRXh" },
   "outputs": [],
   "source": [
    "batch_size = 32\n",
    "epochs = 20\n",
    "\n",
    "# Validate on a held-out slice of the training data, not on the test set.\n",
    "# The checkpoint and early-stopping callbacks pick the model from the\n",
    "# validation metrics, so validating on x_test would bias the test-set\n",
    "# evaluation performed later in this notebook.\n",
    "# Keras takes the last `validation_split` fraction of the arrays, before\n",
    "# shuffling.\n",
    "validation_split = 0.1\n",
    "\n",
    "history = model.fit(x_train, y_train,\n",
    "                    batch_size=batch_size,\n",
    "                    epochs=epochs,\n",
    "                    validation_split=validation_split,\n",
    "                    callbacks=callbacks_list)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": { "id": "Adndiw8DCNhj" },
   "source": [
    "## 測試資料"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "id": "lpX3cwZwKwtE" },
   "outputs": [],
   "source": [
    "# Reload the checkpoint written by ModelCheckpoint. The path is taken from\n",
    "# modelfiles so that the save and load locations cannot drift apart.\n",
    "best_model = modelfiles\n",
    "model = load_model(best_model)"
   ]
  },
  {
   "cell_type": "code",
"execution_count": null,
   "metadata": { "id": "0_9DvOzzCNhk" },
   "outputs": [],
   "source": [
    "# Predict the class index of the first test image and show that image.\n",
    "test_pred = model.predict(x_test[0:1]).argmax(-1)\n",
    "\n",
    "plt.imshow(x_test[0])\n",
    "print('prediction: ', test_pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "id": "qHJahes4KwtE" },
   "outputs": [],
   "source": [
    "# Loss and accuracy of the reloaded model on the test set.\n",
    "loss, acc = model.evaluate(x_test, y_test, verbose=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "id": "Rig2pMPFKwtE" },
   "outputs": [],
   "source": [
    "# Class probabilities for every test image, then the most likely class index.\n",
    "# Separate names keep probabilities and class ids from sharing one variable.\n",
    "y_prob = model.predict(x_test)\n",
    "print(y_prob[:3])\n",
    "y_pred = y_prob.argmax(-1)\n",
    "print(y_pred[:3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "id": "iVzmHo4JKwtE" },
   "outputs": [],
   "source": [
    "from sklearn.metrics import accuracy_score, confusion_matrix\n",
    "\n",
    "# y_test is one-hot encoded; argmax turns it back into class indices.\n",
    "y_true = y_test.argmax(-1)\n",
    "print(accuracy_score(y_true, y_pred))\n",
    "print(confusion_matrix(y_true, y_pred))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": { "id": "P6OaWCuWimfe" },
   "source": [
    "## 訓練結果視覺化"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "id": "p-H_2wjabXas" },
   "outputs": [],
   "source": [
    "# Training curves: loss on the left panel, accuracy on the right panel.\n",
    "# Each entry is (panel title, history keys, legend labels).\n",
    "panels = [\n",
    "    ('loss', ['loss', 'val_loss'], ['training_loss', 'val_loss']),\n",
    "    ('accuracy', ['accuracy', 'val_accuracy'], ['training_acc', 'val_acc']),\n",
    "]\n",
    "\n",
    "fig, axes = plt.subplots(1, 2, figsize=(12, 5))\n",
    "for ax, (title, keys, labels) in zip(axes, panels):\n",
    "    for key, label in zip(keys, labels):\n",
    "        values = history.history[key]\n",
    "        ax.plot(np.arange(len(values)), values, label=label)\n",
    "    # The title names the metric of the whole panel; the original loop\n",
    "    # overwrote it with the label of the last curve drawn.\n",
    "    ax.set_title(title)\n",
    "    ax.set_xlabel('epoch')\n",
    "    ax.set_ylabel(title)\n",
    "    ax.legend(loc='best')\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "name": "ex1_CNN_practice_advance.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
"version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.12" } }, "nbformat": 4, "nbformat_minor": 0 }