{ "cells": [ { "cell_type": "markdown", "id": "e1138fe9", "metadata": { "id": "e1138fe9" }, "source": [ "# **Training a Model on a Simple Dataset**\n", "This notebook uses a simple two-dimensional dataset to walk through the complete training workflow of a deep learning model: building, training, and evaluating the model, then saving it and loading it back to reproduce the results.\n", "\n", "## Outline\n", "* ### [Dataset Creating / Loading](#DatasetCreating/Loading)\n", "* ### [Model Building](#ModelBuilding)\n", "* ### [Model Training](#ModelTraining)\n", "* ### [Model Evaluation](#ModelEvaluation)\n", "* ### [Model Saving / Loading](#ModelSaving/Loading)\n", "-----------------" ] }, { "cell_type": "markdown", "id": "bae31b71", "metadata": { "id": "bae31b71" }, "source": [ "## Import Packages" ] }, { "cell_type": "code", "execution_count": null, "id": "3fae739d", "metadata": { "id": "3fae739d" }, "outputs": [], "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "\n", "# TensorFlow packages\n", "import tensorflow as tf\n", "from tensorflow import keras\n", "from tensorflow.keras import layers" ] }, { "cell_type": "markdown", "id": "23b85c0b", "metadata": { "id": "23b85c0b" }, "source": [ "\n", "## Dataset Creating / Loading" ] }, { "cell_type": "code", "execution_count": null, "id": "7afa2ca3", "metadata": { "id": "7afa2ca3" }, "outputs": [], "source": [ "np.random.seed(12)\n", "num_samples_per_class = 1000\n", "\n", "# Create the negative samples (X1_neg, X2_neg)\n", "negative_samples = np.random.multivariate_normal(\n", "    mean=[0, 3],                 # mean of each dimension\n", "    cov=[[1, 0.5], [0.5, 1]],    # covariance matrix\n", "    size=num_samples_per_class)  # number of samples\n", "\n", "# Create the positive samples (X1_pos, X2_pos)\n", "positive_samples = np.random.multivariate_normal(\n", "    mean=[3, 0],                 # mean of each dimension\n", "    cov=[[1, 0.5], [0.5, 1]],    # covariance matrix\n", "    size=num_samples_per_class)  # number of samples\n", "\n", "print('shape of neg samples:', negative_samples.shape)\n", "print('shape of pos samples:', positive_samples.shape)" ] }, { "cell_type": "code", "execution_count": null, "id": "6a00e49f", "metadata": { "id": "6a00e49f" }, "outputs": [], "source": [ "inputs = 
np.vstack((negative_samples, positive_samples)).astype(np.float32)\n", "targets = np.vstack((\n", "    np.zeros((num_samples_per_class, 1), dtype='float32'),  # negative labels (0)\n", "    np.ones((num_samples_per_class, 1), dtype='float32')))  # positive labels (1)" ] }, { "cell_type": "code", "execution_count": null, "id": "7b38ef96", "metadata": { "id": "7b38ef96" }, "outputs": [], "source": [ "# Side-by-side 2D and 3D views of the dataset\n", "plt.figure(figsize=(10, 4))\n", "ax1 = plt.subplot(121)\n", "ax2 = plt.subplot(122, projection='3d')\n", "\n", "# --- Plot in 2D space ---\n", "# Plot the training dataset\n", "ax1.scatter(negative_samples[:, 0],\n", "            negative_samples[:, 1],\n", "            label='negative samples')\n", "ax1.scatter(positive_samples[:, 0],\n", "            positive_samples[:, 1],\n", "            label='positive samples')\n", "\n", "ax1.set_xlabel('x1')\n", "ax1.set_ylabel('x2')\n", "ax1.set_title('x1-x2 plane')\n", "ax1.legend()\n", "\n", "# --- Plot in 3D space ---\n", "# Plot the training dataset, using the label (0 or 1) as the height\n", "ax2.scatter(negative_samples[:, 0],\n", "            negative_samples[:, 1],\n", "            np.zeros((num_samples_per_class, 1), dtype='float32'),\n", "            label='negative samples')\n", "\n", "ax2.scatter(positive_samples[:, 0],\n", "            positive_samples[:, 1],\n", "            np.ones((num_samples_per_class, 1), dtype='float32'),\n", "            label='positive samples')\n", "\n", "ax2.set_xlabel('x1')\n", "ax2.set_ylabel('x2')\n", "ax2.set_zlabel('y')\n", "ax2.set_title('x1-x2-y space')\n", "ax2.view_init(45, 285)\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "id": "50f2fb6a", "metadata": { "id": "50f2fb6a" }, "outputs": [], "source": [ "# Shuffle the dataset; the same index order keeps inputs and targets aligned\n", "shuffle_idx = tf.random.shuffle(range(2*num_samples_per_class), seed=17).numpy()\n", "inputs = inputs[shuffle_idx]\n", "targets = targets[shuffle_idx]" ] }, { "cell_type": "markdown", "id": "a73ba887", "metadata": { "id": "a73ba887" }, "source": [ "\n", "## Model Building" ] }, { "cell_type": "markdown", "id": "268d89d7", "metadata": { "id": "268d89d7" }, "source": [ "Goal: find a plane that fits these two groups of data points. Assume the plane is given by the following equation:" ] }, { "cell_type": "markdown", "id": 
"cd08c6c0", "metadata": { "id": "cd08c6c0" }, "source": [ "![](https://hackmd.io/_uploads/rJiUtK8-T.png)\n" ] }, { "cell_type": "markdown", "id": "bfc3fa76", "metadata": { "id": "bfc3fa76" }, "source": [ "* ### Sequential model(序列模型)\n", "單輸入單輸出的模型,依順序堆疊網路層。" ] }, { "cell_type": "code", "execution_count": null, "id": "181dfad7", "metadata": { "id": "181dfad7" }, "outputs": [], "source": [ "keras.backend.clear_session() # 重置 keras 的所有狀態\n", "tf.random.set_seed(17) # 設定 tensorflow 隨機種子\n", "\n", "model = keras.models.Sequential()\n", "model.add(layers.Dense(1, # 神經元個數\n", " input_shape=inputs[0].shape)) # 輸入形狀\n", "\n", "# 以下寫法等同以上結果,將所有網路層按順序,以串列(list)的方式輸進 Sequential\n", "# model = keras.models.Sequential(\n", "# [layers.Dense(1, input_shape=inputs[0].shape)])" ] }, { "cell_type": "code", "execution_count": null, "id": "dd0c80d5", "metadata": { "id": "dd0c80d5" }, "outputs": [], "source": [ "model.summary()" ] }, { "cell_type": "markdown", "id": "b23bec3a", "metadata": { "id": "b23bec3a" }, "source": [ "* ### Functional API\n", "除了單輸入單輸出外,也可以支援多輸入多輸出,相較 Sequential model 更彈性,依照變數傳遞的方式串接網路層。" ] }, { "cell_type": "code", "execution_count": null, "id": "37ad8028", "metadata": { "id": "37ad8028" }, "outputs": [], "source": [ "keras.backend.clear_session() # 重置 keras 的所有狀態\n", "tf.random.set_seed(17) # 設定 tensorflow 隨機種子\n", "\n", "x_inputs = layers.Input(shape=inputs[0].shape)\n", "x_outputs = layers.Dense(1)(x_inputs)\n", "\n", "model = keras.models.Model(inputs=x_inputs, outputs=x_outputs)" ] }, { "cell_type": "code", "execution_count": null, "id": "e0649cc7", "metadata": { "id": "e0649cc7" }, "outputs": [], "source": [ "model.summary()" ] }, { "cell_type": "markdown", "id": "b6ae643d", "metadata": { "id": "b6ae643d" }, "source": [ "\n", "## 模型訓練(Model Training)" ] }, { "cell_type": "markdown", "id": "efae4402", "metadata": { "id": "efae4402" }, "source": [ "* ### 模型編譯(model compile)\n", "設定模型訓練時,所需的優化器 (optimizer)、損失函數 (loss function)" ] }, { "cell_type": "code", 
"execution_count": null, "id": "8f53f320", "metadata": { "id": "8f53f320" }, "outputs": [], "source": [ "model.compile(optimizer='rmsprop', # 優化器\n", " loss='mean_squared_error') # 損失函數" ] }, { "cell_type": "code", "execution_count": null, "id": "d007e74d", "metadata": { "id": "d007e74d" }, "outputs": [], "source": [ "history = model.fit(\n", " inputs, # 輸入(訓練集)\n", " targets, # 標籤(訓練集)\n", " batch_size=16, # 批次數量\n", " epochs=20, # 訓練回合數\n", " validation_split=0.2) # 切分驗證集,前 80 % 為訓練集,後 20 % 為驗證集" ] }, { "cell_type": "markdown", "id": "e612e37d", "metadata": { "id": "e612e37d" }, "source": [ "\n", "## 模型評估(Model Evaluation)" ] }, { "cell_type": "markdown", "id": "33d36220", "metadata": { "id": "33d36220" }, "source": [ "* ### 視覺化訓練過程的評估指標 (Visualization)" ] }, { "cell_type": "code", "execution_count": null, "id": "33f9b739", "metadata": { "id": "33f9b739" }, "outputs": [], "source": [ "# type(history.history) = dictionary\n", "print(history.history.keys())" ] }, { "cell_type": "code", "execution_count": null, "id": "67569b4c", "metadata": { "id": "67569b4c" }, "outputs": [], "source": [ "train_loss = history.history['loss']\n", "valid_loss = history.history['val_loss']" ] }, { "cell_type": "code", "execution_count": null, "id": "4bbb0a56", "metadata": { "id": "4bbb0a56" }, "outputs": [], "source": [ "# 繪製 Epochs vs. 
MSE\n", "plt.figure(figsize=(15, 4))\n", "plt.plot(range(len(train_loss)), train_loss, label='train_loss')\n", "plt.plot(range(len(valid_loss)), valid_loss, label='valid_loss')\n", "plt.xlabel('Epochs')\n", "plt.ylabel('MSE')\n", "plt.legend()\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "9b639f83", "metadata": { "id": "9b639f83" }, "source": [ "* ### Model predictions" ] }, { "cell_type": "code", "execution_count": null, "id": "638dca26", "metadata": { "id": "638dca26" }, "outputs": [], "source": [ "# Calling the model directly returns a tf.Tensor\n", "predictions = model(inputs)\n", "print(predictions)\n", "print(type(predictions))" ] }, { "cell_type": "code", "execution_count": null, "id": "39744711", "metadata": { "id": "39744711" }, "outputs": [], "source": [ "# model.predict() returns a NumPy array\n", "predictions = model.predict(inputs)\n", "print(predictions)\n", "print('Type:', type(predictions))" ] }, { "cell_type": "code", "execution_count": null, "id": "5f9f5711", "metadata": { "id": "5f9f5711" }, "outputs": [], "source": [ "loss = model.evaluate(inputs, targets)\n", "print(f'MSE: {loss}')" ] }, { "cell_type": "markdown", "id": "1c5705e4", "metadata": { "id": "1c5705e4" }, "source": [ "* ### Visualizing the results" ] }, { "cell_type": "code", "execution_count": null, "id": "02b976c8", "metadata": { "id": "02b976c8" }, "outputs": [], "source": [ "model.variables  # model variables" ] }, { "cell_type": "code", "execution_count": null, "id": "35537bff", "metadata": { "id": "35537bff" }, "outputs": [], "source": [ "w = model.variables[0]  # Dense kernel (weights), shape (2, 1)\n", "b = model.variables[1]  # Dense bias, shape (1,)" ] }, { "cell_type": "code", "execution_count": null, "id": "a1b19f0b", "metadata": { "id": "a1b19f0b" }, "outputs": [], "source": [ "# Side-by-side 2D and 3D views of the fitted model\n", "plt.figure(figsize=(10, 4))\n", "ax1 = plt.subplot(121)\n", "ax2 = plt.subplot(122, projection='3d')\n", "\n", "# --- Plot in 2D space ---\n", "# The decision boundary is w1*x1 + w2*x2 + b = 0.5,\n", "# i.e. x2 = -(w1 / w2) * x1 + (0.5 - b) / w2\n", "x = np.linspace(-3, 6, 100)  # 100 evenly spaced points from -3 to 6\n", "boundary = - w[0] / w[1] * x + (0.5 - b) / w[1]\n", "\n", "# Plot the decision boundary line\n", "ax1.plot(x, boundary, '-r', 
label='Decision Boundary')\n", "\n", "# Plot the training dataset\n", "ax1.scatter(negative_samples[:, 0],\n", "            negative_samples[:, 1],\n", "            label='negative samples')\n", "ax1.scatter(positive_samples[:, 0],\n", "            positive_samples[:, 1],\n", "            label='positive samples')\n", "\n", "ax1.set_xlabel('x1')\n", "ax1.set_ylabel('x2')\n", "ax1.set_title('x1-x2 plane')\n", "ax1.legend()\n", "\n", "# --- Plot in 3D space ---\n", "x1 = np.linspace(-3, 5, 100)\n", "x2 = np.linspace(-3, 5, 100)\n", "x1, x2 = np.meshgrid(x1, x2)  # 100x100 grid over [-3, 5] x [-3, 5]\n", "y = w[0] * x1 + w[1] * x2 + b\n", "\n", "ax2.contour3D(x1, x2, y, 100, alpha=0.5, cmap='viridis')       # fitted plane\n", "ax2.plot3D(x, boundary, 0.5, '-r', label='Decision Boundary')  # decision boundary line\n", "\n", "# Plot the training dataset\n", "ax2.scatter(negative_samples[:, 0],\n", "            negative_samples[:, 1],\n", "            np.zeros((num_samples_per_class, 1), dtype='float32'),\n", "            label='negative samples',\n", "            depthshade=False)\n", "ax2.scatter(positive_samples[:, 0],\n", "            positive_samples[:, 1],\n", "            np.ones((num_samples_per_class, 1), dtype='float32'),\n", "            label='positive samples',\n", "            depthshade=False)\n", "\n", "ax2.set_zlim(0, 1)\n", "ax2.set_xlabel('x1')\n", "ax2.set_ylabel('x2')\n", "ax2.set_zlabel('y')\n", "ax2.set_title('x1-x2-y space')\n", "ax2.view_init(45, 285)\n", "ax2.legend(bbox_to_anchor=(1.05, 1))\n", "\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "ace0a364", "metadata": { "id": "ace0a364" }, "source": [ "\n", "## Model Saving / Loading" ] }, { "cell_type": "code", "execution_count": null, "id": "6ece0e6e", "metadata": { "id": "6ece0e6e" }, "outputs": [], "source": [ "model.save('Data/model.h5')  # path to save the model to" ] }, { "cell_type": "code", "execution_count": null, "id": "733c1d31", "metadata": { "id": "733c1d31" }, "outputs": [], "source": [ "new_model = keras.models.load_model('Data/model.h5')  # path to load the model from" ] }, { "cell_type": "code", "execution_count": null, "id": "209cb9e5", "metadata": { "id": "209cb9e5" }, "outputs": [], "source": [ "new_model.summary()" ] }, 
{ "cell_type": "code", "execution_count": null, "id": "8180ac0a", "metadata": { "id": "8180ac0a" }, "outputs": [], "source": [ "# The reloaded model should reproduce the MSE of the original model\n", "loss = new_model.evaluate(inputs, targets)\n", "print(f'MSE: {loss}')" ] }, { "cell_type": "markdown", "id": "9474cb39", "metadata": { "id": "9474cb39" }, "source": [ "----------------\n", "## Try it yourself:\n", "1. Change the random seed and observe the training results (convergence speed, MSE, and so on).\n", "2. Change batch_size and observe the training results (convergence speed, MSE, and so on)." ] }, { "cell_type": "code", "execution_count": null, "id": "5e99475d", "metadata": { "id": "5e99475d" }, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.12" }, "colab": { "provenance": [] }, "accelerator": "GPU", "gpuClass": "standard" }, "nbformat": 4, "nbformat_minor": 5 }