{
"cells": [
{
"cell_type": "markdown",
"id": "1f146166",
"metadata": {
"id": "1f146166"
},
"source": [
"# **常見訓練設定**\n",
"此份程式碼將介紹在訓練過程中,可用來調整訓練設定或記錄訓練狀態的回呼函式(callbacks)。\n",
"\n",
"## 本章節內容大綱\n",
"* ### EarlyStopping(已於 part3/2_Overfitting.ipynb 介紹)\n",
"* ### [ModelCheckpoint](#ModelCheckpoint)\n",
"* ### [LearningRateScheduler](#LearningRateSchedular)\n",
"* ### [CSVLogger](#CSVLogger)"
]
},
{
"cell_type": "markdown",
"id": "011871a9",
"metadata": {
"id": "011871a9"
},
"source": [
"## 匯入套件"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d8600dff",
"metadata": {
"id": "d8600dff"
},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# Tensorflow 相關套件\n",
"import tensorflow as tf\n",
"from tensorflow import keras\n",
"from tensorflow.keras import layers, callbacks"
]
},
{
"cell_type": "markdown",
"id": "33e4ec08",
"metadata": {
"id": "33e4ec08"
},
"source": [
"## 創建資料集/載入資料集(Dataset Creating / Loading)"
]
},
{
"cell_type": "code",
"source": [
"# Download and extract the course data.\n",
"# -nc skips the download if the zip already exists; -o lets unzip overwrite\n",
"# existing files without prompting, so re-running this cell never hangs.\n",
"!wget -q -nc https://github.com/TA-aiacademy/course_3.0/releases/download/DL/Data_part3.zip\n",
"!unzip -q -o Data_part3.zip"
],
"metadata": {
"id": "P8myXB-0vB_F"
},
"id": "P8myXB-0vB_F",
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "d4f04ed4",
"metadata": {
"id": "d4f04ed4"
},
"outputs": [],
"source": [
"# Load the pre-split training and test tables\n",
"train_df = pd.read_csv('./Data/News_train.csv', sep=',')\n",
"test_df = pd.read_csv('./Data/News_test.csv', sep=',')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "092919ec",
"metadata": {
"id": "092919ec"
},
"outputs": [],
"source": [
"# Peek at the first rows of the training table\n",
"train_df.head(n=5)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c516c346",
"metadata": {
"id": "c516c346"
},
"outputs": [],
"source": [
"# Separate features and target by column name, so the label can never end up\n",
"# in X even if y_category is not the last column of the CSV.\n",
"X_df = train_df.drop(columns='y_category').values\n",
"y_df = train_df['y_category'].values"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2dfe2b2f",
"metadata": {
"id": "2dfe2b2f"
},
"outputs": [],
"source": [
"# Separate features and target by column name (label never leaks into X_test)\n",
"X_test = test_df.drop(columns='y_category').values\n",
"y_test = test_df['y_category'].values"
]
},
{
"cell_type": "markdown",
"id": "063a839a",
"metadata": {
"id": "063a839a"
},
"source": [
"## 資料前處理(Data Preprocessing)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "04803f0e",
"metadata": {
"id": "04803f0e"
},
"outputs": [],
"source": [
"from sklearn.preprocessing import StandardScaler, MinMaxScaler\n",
"\n",
"# Feature scaling: standardize every feature to zero mean / unit variance.\n",
"# StandardScaler is unsupervised (it ignores y), so no labels are passed in.\n",
"# NOTE: the statistics are fit on all of X_df, which still contains the rows\n",
"# that are later held out as the validation set.\n",
"sc = StandardScaler()\n",
"X_scale = sc.fit_transform(X_df)\n",
"X_test_scale = sc.transform(X_test)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f6de8e01",
"metadata": {
"id": "f6de8e01"
},
"outputs": [],
"source": [
"# Convert integer class labels to one-hot vectors.\n",
"# The test encoding reuses the training class count, so both arrays have the\n",
"# same width even if the highest label happens to be absent from the test set.\n",
"y_onehot = keras.utils.to_categorical(y_df)\n",
"y_test_onehot = keras.utils.to_categorical(y_test, num_classes=y_onehot.shape[1])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bd513648",
"metadata": {
"id": "bd513648"
},
"outputs": [],
"source": [
"from sklearn.model_selection import train_test_split\n",
"\n",
"# Hold out 20% of the scaled training data for validation, stratified on the\n",
"# integer labels so every class keeps its proportion in both splits.\n",
"X_train, X_valid, y_train, y_valid = train_test_split(\n",
"    X_scale,\n",
"    y_onehot,\n",
"    stratify=y_df,\n",
"    test_size=0.2,\n",
"    random_state=17,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c51092ac",
"metadata": {
"id": "c51092ac"
},
"outputs": [],
"source": [
"# Sanity-check the shapes of the train / validation splits (one per line)\n",
"print(f'X_train shape: {X_train.shape}',\n",
"      f'X_valid shape: {X_valid.shape}',\n",
"      f'y_train shape: {y_train.shape}',\n",
"      f'y_valid shape: {y_valid.shape}',\n",
"      sep='\\n')"
]
},
{
"cell_type": "markdown",
"id": "670e0666",
"metadata": {
"id": "670e0666"
},
"source": [
"## 模型建置(Model Building)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ce8aca19",
"metadata": {
"id": "ce8aca19"
},
"outputs": [],
"source": [
"def build_model(input_shape, output_shape,\n",
"                hidden_units=(64, 64), activation='tanh', seed=17):\n",
"    \"\"\"Build a fully-connected softmax classifier.\n",
"\n",
"    The Keras session is cleared and the TensorFlow seed is fixed first, so\n",
"    every call restarts layer naming and draws the same initial weights.\n",
"\n",
"    Args:\n",
"        input_shape: shape of a single sample, e.g. ``X_train[0].shape``.\n",
"        output_shape: number of classes (width of the softmax output layer).\n",
"        hidden_units: size of each hidden Dense layer, in order.\n",
"        activation: activation function of the hidden layers.\n",
"        seed: value passed to ``tf.random.set_seed``.\n",
"\n",
"    Returns:\n",
"        An uncompiled ``keras.models.Sequential`` model.\n",
"    \"\"\"\n",
"    keras.backend.clear_session()\n",
"    tf.random.set_seed(seed)\n",
"\n",
"    model = keras.models.Sequential()\n",
"    # Declare the input explicitly rather than passing input_shape to the first\n",
"    # Dense layer; that argument is deprecated in recent Keras releases.\n",
"    model.add(keras.Input(shape=input_shape))\n",
"    for units in hidden_units:\n",
"        model.add(layers.Dense(units, activation=activation))\n",
"    model.add(layers.Dense(output_shape, activation='softmax'))\n",
"    return model"
]
},
{
"cell_type": "markdown",
"id": "5046649f",
"metadata": {
"id": "5046649f"
},
"source": [
"\n",
"* ## ModelCheckpoint"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a1ccd11e",
"metadata": {
"id": "a1ccd11e"
},
"outputs": [],
"source": [
"# Fresh model for the ModelCheckpoint demo\n",
"model = build_model(input_shape=X_train[0].shape,\n",
"                    output_shape=y_onehot.shape[1])\n",
"model.summary()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4a70dc9f",
"metadata": {
"id": "4a70dc9f"
},
"outputs": [],
"source": [
"# Naming the metric 'acc' makes Keras log it as 'acc' / 'val_acc'\n",
"model.compile(loss='categorical_crossentropy',\n",
"              metrics=['acc'],\n",
"              optimizer='nadam')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e6bc4b25",
"metadata": {
"id": "e6bc4b25"
},
"outputs": [],
"source": [
"# Where the best model gets written.\n",
"# NOTE(review): Keras 3 may require a `.keras` suffix here instead of `.h5`; confirm for the installed version.\n",
"model_path = './Data/callbacks_model.h5'\n",
"\n",
"# Save the model whenever validation accuracy reaches a new maximum\n",
"checkpoint = callbacks.ModelCheckpoint(\n",
"    filepath=model_path,\n",
"    monitor='val_acc',      # metric that decides which epoch is \"best\"\n",
"    mode='max',             # a higher val_acc counts as an improvement\n",
"    save_best_only=True,    # only overwrite the file when the metric improves\n",
"    verbose=1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "575624f3",
"metadata": {
"id": "575624f3"
},
"outputs": [],
"source": [
"# Train for 20 epochs; the checkpoint callback keeps the best epoch on disk\n",
"history = model.fit(\n",
"    X_train, y_train,\n",
"    validation_data=(X_valid, y_valid),\n",
"    epochs=20,\n",
"    batch_size=512,\n",
"    callbacks=[checkpoint],\n",
")"
]
},
{
"cell_type": "markdown",
"id": "815412b3",
"metadata": {
"id": "815412b3"
},
"source": [
"<a id=\"LearningRateSchedular\"></a>\n",
"\n",
"* ## LearningRateScheduler"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "80bbc0ab",
"metadata": {
"id": "80bbc0ab"
},
"outputs": [],
"source": [
"def schedule(epoch):\n",
"    \"\"\"Piecewise-constant learning rate for the given 0-based ``epoch``.\n",
"\n",
"    epochs 0-9 -> 0.001, epochs 10-14 -> 0.0001, epoch 15 onwards -> 0.00001\n",
"    \"\"\"\n",
"    for boundary, rate in ((10, 0.001), (15, 0.0001)):\n",
"        if epoch < boundary:\n",
"            return rate\n",
"    return 0.00001\n",
"\n",
"\n",
"# Apply `schedule` at the start of every epoch; verbose prints the chosen rate\n",
"lr_schedule = callbacks.LearningRateScheduler(schedule, verbose=2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "470a7f9b",
"metadata": {
"id": "470a7f9b"
},
"outputs": [],
"source": [
"# Multiply the learning rate by 0.1 once val_loss has not improved for 5 epochs\n",
"rlp = callbacks.ReduceLROnPlateau(\n",
"    monitor='val_loss',  # metric watched for improvement\n",
"    mode='min',          # a lower val_loss counts as an improvement\n",
"    factor=0.1,          # new_lr = lr * factor\n",
"    patience=5,          # epochs without improvement before reducing\n",
"    verbose=2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "24de7e6a",
"metadata": {
"id": "24de7e6a"
},
"outputs": [],
"source": [
"# Training curves of each run, one entry per callback setting\n",
"train_loss_list = []\n",
"train_acc_list = []\n",
"\n",
"# Validation curves of each run, one entry per callback setting\n",
"valid_loss_list = []\n",
"valid_acc_list = []\n",
"\n",
"# Each value is the list of callbacks handed straight to fit();\n",
"# 'non' is the baseline run without any learning-rate callback.\n",
"callback_l = {'non': [], 'lr_s': [lr_schedule], 'rlp': [rlp]}\n",
"for cb, cb_list in callback_l.items():\n",
"    print('Training a model with callbacks: {}'.format(cb))\n",
"    # build_model clears the session and re-seeds, so every run starts from the same weights\n",
"    model = build_model(X_train[0].shape, y_onehot.shape[1])\n",
"    model.compile(optimizer='nadam',\n",
"                  loss='categorical_crossentropy',\n",
"                  metrics=['acc'])\n",
"    history = model.fit(X_train, y_train,\n",
"                        epochs=20,\n",
"                        batch_size=64,\n",
"                        validation_data=(X_valid, y_valid),\n",
"                        callbacks=cb_list,\n",
"                        verbose=0)\n",
"\n",
"    # Record this run's learning curves\n",
"    train_loss_list.append(history.history['loss'])\n",
"    valid_loss_list.append(history.history['val_loss'])\n",
"    train_acc_list.append(history.history['acc'])\n",
"    valid_acc_list.append(history.history['val_acc'])\n",
"    print('\\n')\n",
"print('----------------- training done! -----------------')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5b56d666",
"metadata": {
"id": "5b56d666"
},
"outputs": [],
"source": [
"# Visualize the learning curves of every callback setting\n",
"fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(15, 5))\n",
"\n",
"train_line = ()\n",
"valid_line = ()\n",
"\n",
"# Loss panel: solid = training, dashed = validation\n",
"for cb, loss, val_loss in zip(callback_l, train_loss_list, valid_loss_list):\n",
"    train_l = ax_loss.plot(range(len(loss)), loss,\n",
"                           label=f'Training callback:{cb}')\n",
"    valid_l = ax_loss.plot(range(len(val_loss)), val_loss, '--',\n",
"                           label=f'Validation callback:{cb}')\n",
"    train_line += tuple(train_l)\n",
"    valid_line += tuple(valid_l)\n",
"ax_loss.set(title='Loss', xlabel='Epoch', ylabel='Loss')\n",
"\n",
"# Accuracy panel: lines are drawn in the same order as the loss panel, so colors match\n",
"for cb, acc, val_acc in zip(callback_l, train_acc_list, valid_acc_list):\n",
"    ax_acc.plot(range(len(acc)), acc,\n",
"                label=f'Training callback:{cb}')\n",
"    ax_acc.plot(range(len(val_acc)), val_acc, '--',\n",
"                label=f'Validation callback:{cb}')\n",
"ax_acc.set(title='Accuracy', xlabel='Epoch', ylabel='Accuracy')\n",
"\n",
"# Two legends outside the right panel: training lines on top, validation lines below\n",
"first_legend = ax_acc.legend(handles=train_line, bbox_to_anchor=(1.05, 1))\n",
"ax_acc.add_artist(first_legend)\n",
"ax_acc.legend(handles=valid_line, bbox_to_anchor=(1.05, 0.7))\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "33e55106",
"metadata": {
"id": "33e55106"
},
"source": [
"\n",
"* ## CSVLogger"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c63f4345",
"metadata": {
"id": "c63f4345"
},
"outputs": [],
"source": [
"# Fresh model for the CSVLogger demo\n",
"model = build_model(input_shape=X_train[0].shape,\n",
"                    output_shape=y_onehot.shape[1])\n",
"model.summary()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3c873f5f",
"metadata": {
"id": "3c873f5f"
},
"outputs": [],
"source": [
"# Same optimizer / loss / metric setup as the earlier experiments\n",
"model.compile(loss='categorical_crossentropy',\n",
"              metrics=['acc'],\n",
"              optimizer='nadam')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4aa8a4f5",
"metadata": {
"id": "4aa8a4f5"
},
"outputs": [],
"source": [
"# Write one row of metrics per epoch to a CSV file (overwritten at the start of each fit)\n",
"csv_logger = callbacks.CSVLogger(filename='./Data/training_log.csv')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c8010893",
"metadata": {
"id": "c8010893"
},
"outputs": [],
"source": [
"# Train for 20 epochs while the CSVLogger records each epoch's metrics\n",
"history = model.fit(\n",
"    X_train, y_train,\n",
"    validation_data=(X_valid, y_valid),\n",
"    epochs=20,\n",
"    batch_size=512,\n",
"    callbacks=[csv_logger],\n",
")"
]
},
{
"cell_type": "markdown",
"id": "12554d04",
"metadata": {
"id": "12554d04"
},
"source": [
"---\n",
"wandb(補充教材): https://docs.wandb.ai/v/zh-hans/quickstart"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7548c212",
"metadata": {
"id": "7548c212"
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.12"
},
"colab": {
"provenance": []
},
"accelerator": "GPU",
"gpuClass": "standard"
},
"nbformat": 4,
"nbformat_minor": 5
}