{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "1f146166",
   "metadata": {
    "id": "1f146166"
   },
   "source": [
    "# **常見訓練設定**\n",
    "此份程式碼將會介紹隨著訓練過程,可以調整或者紀錄的函式。\n",
    "\n",
    "## 本章節內容大綱\n",
    "* ### EarlyStopping(已於 part3/2_Overfitting.ipynb 介紹)\n",
    "* ### [ModelCheckpoint](#ModelCheckpoint)\n",
    "* ### [LearningRateScheduler](#LearningRateScheduler)\n",
    "* ### [CSVLogger](#CSVLogger)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "011871a9",
   "metadata": {
    "id": "011871a9"
   },
   "source": [
    "## 匯入套件"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8600dff",
   "metadata": {
    "id": "d8600dff"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import StandardScaler, MinMaxScaler\n",
    "\n",
    "# TensorFlow / Keras\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras import layers, callbacks"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "33e4ec08",
   "metadata": {
    "id": "33e4ec08"
   },
   "source": [
    "## 創建資料集/載入資料集(Dataset Creating / Loading)"
   ]
  },
  {
   "cell_type": "code",
   "source": [
    "# Download and unpack the course data.\n",
    "# -nc skips the download when the archive already exists and -o overwrites\n",
    "# without prompting, so re-running this cell is safe.\n",
    "!wget -q -nc https://github.com/TA-aiacademy/course_3.0/releases/download/DL/Data_part3.zip\n",
    "!unzip -q -o Data_part3.zip"
   ],
   "metadata": {
    "id": "P8myXB-0vB_F"
   },
   "id": "P8myXB-0vB_F",
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4f04ed4",
   "metadata": {
    "id": "d4f04ed4"
   },
   "outputs": [],
   "source": [
    "train_df = pd.read_csv('./Data/News_train.csv')\n",
    "test_df = pd.read_csv('./Data/News_test.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "092919ec",
   "metadata": {
    "id": "092919ec"
   },
   "outputs": [],
   "source": [
    "train_df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c516c346",
   "metadata": {
    "id": "c516c346"
   },
   "outputs": [],
   "source": [
    "# Features: every column except the last; labels: the `y_category` column.\n",
    "# NOTE(review): assumes `y_category` is the last column -- confirm against the CSV.\n",
    "X_df = train_df.iloc[:, :-1].values\n",
    "y_df = train_df.y_category.values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2dfe2b2f",
   "metadata": {
    "id": "2dfe2b2f"
   },
   "outputs": [],
   "source": [
    "# Same feature/label split for the test file\n",
    "X_test = test_df.iloc[:, :-1].values\n",
    "y_test = test_df.y_category.values"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "063a839a",
   "metadata": {
    "id": "063a839a"
   },
   "source": [
    "## 資料前處理(Data Preprocessing)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04803f0e",
   "metadata": {
    "id": "04803f0e"
   },
   "outputs": [],
   "source": [
    "# Feature scaling: statistics are learned on the training file and reused\n",
    "# for the test file.\n",
    "# NOTE(review): the scaler is fit before the train/valid split below, so the\n",
    "# validation rows contribute to the mean/std. Fit on the training split only\n",
    "# if a strictly held-out validation estimate is required.\n",
    "sc = StandardScaler()\n",
    "X_scale = sc.fit_transform(X_df)\n",
    "X_test_scale = sc.transform(X_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6de8e01",
   "metadata": {
    "id": "f6de8e01"
   },
   "outputs": [],
   "source": [
    "# Convert to One-Hot encoding\n",
    "y_onehot = keras.utils.to_categorical(y_df)\n",
    "# Pin the width to the training labels so both encodings have the same\n",
    "# number of columns even if the test file lacks the highest class id.\n",
    "y_test_onehot = keras.utils.to_categorical(y_test,\n",
    "                                           num_classes=y_onehot.shape[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd513648",
   "metadata": {
    "id": "bd513648"
   },
   "outputs": [],
   "source": [
    "# Hold out 20% of the training file for validation, stratified by class label\n",
    "X_train, X_valid, y_train, y_valid = train_test_split(\n",
    "    X_scale, y_onehot,\n",
    "    test_size=0.2,\n",
    "    random_state=17,\n",
    "    stratify=y_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c51092ac",
   "metadata": {
    "id": "c51092ac"
   },
   "outputs": [],
   "source": [
    "print(f'X_train shape: {X_train.shape}')\n",
    "print(f'X_valid shape: {X_valid.shape}')\n",
    "print(f'y_train shape: {y_train.shape}')\n",
    "print(f'y_valid shape: {y_valid.shape}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "670e0666",
   "metadata": {
    "id": "670e0666"
   },
   "source": [
    "## 模型建置(Model Building)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce8aca19",
   "metadata": {
    "id": "ce8aca19"
   },
   "outputs": [],
   "source": [
    "def build_model(input_shape, output_shape):\n",
    "    \"\"\"Build a fresh, uncompiled classifier with two 64-unit tanh layers.\n",
    "\n",
    "    Args:\n",
    "        input_shape: Shape tuple of one sample, e.g. ``X_train[0].shape``.\n",
    "        output_shape: Number of units in the softmax output layer.\n",
    "\n",
    "    Returns:\n",
    "        The uncompiled ``Sequential`` model.\n",
    "    \"\"\"\n",
    "    # Reset Keras' global state and re-seed TensorFlow (seed 17) before any\n",
    "    # layer is created, so models built by separate calls are comparable.\n",
    "    keras.backend.clear_session()\n",
    "    tf.random.set_seed(17)\n",
    "\n",
    "    model = keras.models.Sequential()\n",
    "    model.add(layers.Dense(64,\n",
    "                           input_shape=input_shape,\n",
    "                           activation='tanh'))\n",
    "    model.add(layers.Dense(64,\n",
    "                           activation='tanh'))\n",
    "    model.add(layers.Dense(output_shape,\n",
    "                           activation='softmax'))\n",
    "    return model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5046649f",
   "metadata": {
    "id": "5046649f"
   },
   "source": [
    "\n",
    "* ## ModelCheckpoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1ccd11e",
   "metadata": {
    "id": "a1ccd11e"
   },
   "outputs": [],
   "source": [
    "model = build_model(X_train[0].shape, y_onehot.shape[1])\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a70dc9f",
   "metadata": {
    "id": "4a70dc9f"
   },
   "outputs": [],
   "source": [
    "model.compile(optimizer='nadam',\n",
    "              loss='categorical_crossentropy',\n",
    "              metrics=['acc'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6bc4b25",
   "metadata": {
    "id": "e6bc4b25"
   },
   "outputs": [],
   "source": [
    "model_path = './Data/callbacks_model.h5'  # where the model file is written\n",
    "\n",
    "# Create the checkpoint callback\n",
    "checkpoint = callbacks.ModelCheckpoint(\n",
    "    model_path,\n",
    "    verbose=1,\n",
    "    monitor='val_acc',    # metric that decides whether to save\n",
    "    save_best_only=True,  # keep only the best model seen so far\n",
    "    mode='max')           # a higher monitored value counts as better"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "575624f3",
   "metadata": {
    "id": "575624f3"
   },
   "outputs": [],
   "source": [
    "history = model.fit(X_train, y_train,\n",
    "                    batch_size=512,\n",
    "                    epochs=20,\n",
    "                    validation_data=(X_valid, y_valid),\n",
    "                    callbacks=[checkpoint])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "815412b3",
   "metadata": {
    "id": "815412b3"
   },
   "source": [
    "\n",
    "* ## LearningRateScheduler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80bbc0ab",
   "metadata": {
    "id": "80bbc0ab"
   },
   "outputs": [],
   "source": [
    "def schedule(epoch, lr=None):\n",
    "    \"\"\"Return the learning rate to use for the given epoch index.\n",
    "\n",
    "    Args:\n",
    "        epoch: Epoch index passed in by ``LearningRateScheduler``.\n",
    "        lr: Current learning rate. Accepted so the callback can call\n",
    "            ``schedule(epoch, lr)``; the value is not used.\n",
    "\n",
    "    Returns:\n",
    "        0.001 while ``epoch < 10``, 0.0001 while ``epoch < 15``,\n",
    "        and 0.00001 afterwards.\n",
    "    \"\"\"\n",
    "    if epoch < 10:\n",
    "        return 0.001\n",
    "    elif epoch < 15:\n",
    "        return 0.0001\n",
    "    else:\n",
    "        return 0.00001\n",
    "\n",
    "\n",
    "# Create the LearningRateScheduler callback\n",
    "lr_schedule = callbacks.LearningRateScheduler(\n",
    "    schedule, verbose=1)"
   ]
  },
  {
   "cell_type": "code",
"execution_count": null,
   "id": "470a7f9b",
   "metadata": {
    "id": "470a7f9b"
   },
   "outputs": [],
   "source": [
    "rlp = callbacks.ReduceLROnPlateau(\n",
    "    monitor='val_loss',  # metric watched for improvement\n",
    "    factor=0.1,          # the learning rate is multiplied by this factor\n",
    "    patience=5,          # epochs without improvement before reducing\n",
    "    verbose=1,\n",
    "    mode='min')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24de7e6a",
   "metadata": {
    "id": "24de7e6a"
   },
   "outputs": [],
   "source": [
    "# Training curves, one entry per key of `callback_l` (in iteration order)\n",
    "train_loss_list = []\n",
    "train_acc_list = []\n",
    "\n",
    "# Validation curves, same ordering\n",
    "valid_loss_list = []\n",
    "valid_acc_list = []\n",
    "\n",
    "# Every value is already a list of callbacks, so it is handed to `fit` as-is\n",
    "# (wrapping it again would give the nested list `[[]]` for the 'non' case).\n",
    "callback_l = {'non': [], 'lr_s': [lr_schedule], 'rlp': [rlp]}\n",
    "for cb, cb_list in callback_l.items():\n",
    "    print('Training a model with callbacks: {}'\n",
    "          .format(cb))\n",
    "    model = build_model(X_train[0].shape, y_onehot.shape[1])\n",
    "    model.compile(optimizer='nadam',\n",
    "                  loss='categorical_crossentropy',\n",
    "                  metrics=['acc'])\n",
    "    history = model.fit(X_train, y_train,\n",
    "                        epochs=20,\n",
    "                        batch_size=64,\n",
    "                        validation_data=(X_valid, y_valid),\n",
    "                        callbacks=cb_list,\n",
    "                        verbose=0)\n",
    "\n",
    "    # Keep the per-epoch history of this run\n",
    "    train_loss_list.append(history.history['loss'])\n",
    "    valid_loss_list.append(history.history['val_loss'])\n",
    "    train_acc_list.append(history.history['acc'])\n",
    "    valid_acc_list.append(history.history['val_acc'])\n",
    "    print('\\n')\n",
    "print('----------------- training done! -----------------')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b56d666",
   "metadata": {
    "id": "5b56d666"
   },
   "outputs": [],
   "source": [
    "# Visualise the training curves (solid: training, dashed: validation)\n",
    "fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(15, 5))\n",
    "\n",
    "train_line = []\n",
    "valid_line = []\n",
    "\n",
    "for k, cb in enumerate(callback_l):\n",
    "    # Loss curves\n",
    "    ax_loss.plot(train_loss_list[k],\n",
    "                 label=f'Training callback:{cb}')\n",
    "    ax_loss.plot(valid_loss_list[k], '--',\n",
    "                 label=f'Validation callback:{cb}')\n",
    "\n",
    "    # Accuracy curves; the handles feed the two legends below\n",
    "    train_line += ax_acc.plot(train_acc_list[k],\n",
    "                              label=f'Training callback:{cb}')\n",
    "    valid_line += ax_acc.plot(valid_acc_list[k], '--',\n",
    "                              label=f'Validation callback:{cb}')\n",
    "\n",
    "ax_loss.set_title('Loss')\n",
    "ax_loss.set_xlabel('Epoch')\n",
    "ax_acc.set_title('Accuracy')\n",
    "ax_acc.set_xlabel('Epoch')\n",
    "\n",
    "# Two legends beside the accuracy panel. A second `legend` call replaces the\n",
    "# first one, so the first is re-attached as a plain artist.\n",
    "first_legend = ax_acc.legend(handles=train_line,\n",
    "                             bbox_to_anchor=(1.05, 1))\n",
    "ax_acc.add_artist(first_legend)\n",
    "ax_acc.legend(handles=valid_line,\n",
    "              bbox_to_anchor=(1.05, 0.7))\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "33e55106",
   "metadata": {
    "id": "33e55106"
   },
   "source": [
    "\n",
    "* ## CSVLogger"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c63f4345",
   "metadata": {
    "id": "c63f4345"
   },
   "outputs": [],
   "source": [
    "model = build_model(X_train[0].shape, y_onehot.shape[1])\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c873f5f",
   "metadata": {
    "id": "3c873f5f"
   },
   "outputs": [],
   "source": [
    "model.compile(optimizer='nadam',\n",
    "              loss='categorical_crossentropy',\n",
    "              metrics=['acc'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4aa8a4f5",
   "metadata": {
    "id": "4aa8a4f5"
   },
   "outputs":
[],
   "source": [
    "# Callback that writes the training log to a CSV file\n",
    "csv_logger = callbacks.CSVLogger(\n",
    "    './Data/training_log.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8010893",
   "metadata": {
    "id": "c8010893"
   },
   "outputs": [],
   "source": [
    "# Train with the CSV logger attached\n",
    "history = model.fit(\n",
    "    X_train,\n",
    "    y_train,\n",
    "    batch_size=512,\n",
    "    epochs=20,\n",
    "    validation_data=(X_valid, y_valid),\n",
    "    callbacks=[csv_logger],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "12554d04",
   "metadata": {
    "id": "12554d04"
   },
   "source": [
    "---\n",
    "wandb(補充教材): https://docs.wandb.ai/v/zh-hans/quickstart"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7548c212",
   "metadata": {
    "id": "7548c212"
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.12"
  },
  "colab": {
   "provenance": []
  },
  "accelerator": "GPU",
  "gpuClass": "standard"
 },
 "nbformat": 4,
 "nbformat_minor": 5
}