{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "d0d29bfc",
   "metadata": {
    "id": "d0d29bfc"
   },
   "source": [
    "# **ResNet**\n",
    "此份程式碼會介紹如何使用 tf.keras 的方式建構 ResNet 的模型架構。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c76e4c9e",
   "metadata": {
    "id": "c76e4c9e"
   },
   "source": [
    "![image](https://hackmd.io/_uploads/rJ8SA1HOp.png)\n",
    "\n",
    "- [source paper](https://arxiv.org/abs/1512.03385)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d15836b8",
   "metadata": {
    "id": "d15836b8"
   },
   "source": [
    "## 匯入套件"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14e0a7c1",
   "metadata": {
    "id": "14e0a7c1"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# TensorFlow-related packages\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras import datasets, layers, Model, Sequential, losses"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2e22f407",
   "metadata": {
    "id": "2e22f407"
   },
   "source": [
    "## 載入資料集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45155e27",
   "metadata": {
    "id": "45155e27"
   },
   "outputs": [],
   "source": [
    "(x_train, y_train), (x_test, y_test) = datasets.mnist.load_data()\n",
    "\n",
    "# Append a channel axis (axis 3) to the image arrays\n",
    "x_train = tf.expand_dims(x_train, axis=3)\n",
    "x_test = tf.expand_dims(x_test, axis=3)\n",
    "print(f'x_train shape: {x_train.shape}')\n",
    "print(f'x_test shape: {x_test.shape}')\n",
    "print('----------')\n",
    "\n",
    "# Grayscale to RGB: repeat the channel axis three times\n",
    "x_train = tf.repeat(x_train, 3, axis=3)\n",
    "x_test = tf.repeat(x_test, 3, axis=3)\n",
    "print(f'x_train shape: {x_train.shape}')\n",
    "print(f'x_test shape: {x_test.shape}')\n",
    "print('----------')\n",
    "\n",
    "# Split dataset into training (first 80%) and validation (last 20%) data.\n",
    "# The validation slices are taken before the training arrays are truncated.\n",
    "split_idx = int(x_train.shape[0] * 0.8)\n",
    "x_val, y_val = x_train[split_idx:], y_train[split_idx:]\n",
    "x_train, y_train = x_train[:split_idx], y_train[:split_idx]\n",
    "print(f'x_train shape: {x_train.shape}, x_val shape: {x_val.shape}')\n",
    "print(f'y_train shape: {y_train.shape}, y_val shape: {y_val.shape}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c00e63ff",
   "metadata": {
    "id": "c00e63ff"
   },
   "source": [
    "## ResNet Architecture"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "be25a3b4",
   "metadata": {
    "id": "be25a3b4"
   },
   "source": [
    "![image](https://hackmd.io/_uploads/B16H0kHOT.png)\n",
    "\n",
    "- [source paper](https://arxiv.org/abs/1512.03385)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ded32e1",
   "metadata": {
    "id": "0ded32e1"
   },
   "outputs": [],
   "source": [
    "def ResBlock(inputs, blocks_num, filters_num, kernel_size, strides=1):\n",
    "    \"\"\"Stack `blocks_num` residual blocks on top of `inputs`.\n",
    "\n",
    "    Every block chains one Conv2D + BatchNormalization pair per entry of\n",
    "    `filters_num` (with a ReLU between consecutive pairs), adds the result\n",
    "    to a shortcut branch and applies a final ReLU.\n",
    "\n",
    "    Args:\n",
    "        inputs: Keras tensor fed into the first block.\n",
    "        blocks_num: Number of residual blocks to stack.\n",
    "        filters_num: Filter count of each convolution inside a block.\n",
    "        kernel_size: Square kernel size of each convolution; indexed in\n",
    "            parallel with `filters_num`.\n",
    "        strides: Stride of the first convolution of the first block only.\n",
    "            All other convolutions use stride 1.\n",
    "\n",
    "    Returns:\n",
    "        The output tensor of the last block, or `inputs` itself when\n",
    "        `blocks_num` is 0.\n",
    "    \"\"\"\n",
    "    outputs = inputs\n",
    "    for block_idx in range(blocks_num):\n",
    "        block_input = outputs\n",
    "        # Only the first block of the stack applies the requested stride.\n",
    "        block_strides = strides if block_idx == 0 else 1\n",
    "\n",
    "        x = layers.Conv2D(filters_num[0],\n",
    "                          (kernel_size[0], kernel_size[0]),\n",
    "                          strides=block_strides,\n",
    "                          padding='same')(block_input)\n",
    "        x = layers.BatchNormalization()(x)\n",
    "        for j in range(1, len(filters_num)):\n",
    "            x = layers.Activation('relu')(x)\n",
    "            x = layers.Conv2D(filters_num[j],\n",
    "                              (kernel_size[j], kernel_size[j]),\n",
    "                              strides=1,\n",
    "                              padding='same')(x)\n",
    "            x = layers.BatchNormalization()(x)\n",
    "\n",
    "        # Make sure both branches of the skip connection have the same shape.\n",
    "        # tuple() works whether `.shape` is a TensorShape or a plain tuple.\n",
    "        if tuple(block_input.shape) == tuple(x.shape):\n",
    "            identity = block_input\n",
    "        else:\n",
    "            # 1x1 projection shortcut. Reusing the block's stride keeps the\n",
    "            # spatial size aligned with the main branch even when the input\n",
    "            # size is not divisible by the stride.\n",
    "            identity = layers.Conv2D(filters_num[-1], (1, 1),\n",
    "                                     strides=block_strides)(block_input)\n",
    "            identity = layers.BatchNormalization()(identity)\n",
    "\n",
    "        outputs = layers.add([identity, x])\n",
    "        outputs = layers.Activation('relu')(outputs)\n",
    "\n",
    "    return outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b463453c",
   "metadata": {
    "id": "b463453c"
   },
   "outputs": [],
   "source": [
    "labels_num = 10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1407f6d8",
   "metadata": {
    "id": "1407f6d8"
   },
   "outputs": [],
   "source": [
    "tf.keras.backend.clear_session()\n",
    "inputs = layers.Input(shape=tuple(x_train.shape[1:]))\n",
    "# Resize the images to 224x224 before the stem convolution.\n",
    "# The Input layer above already fixes the input shape, so Resizing does not\n",
    "# need (and newer Keras versions warn about) an `input_shape` argument.\n",
    "x = layers.Resizing(224, 224, interpolation='bilinear')(inputs)\n",
    "x = layers.Conv2D(64, (7, 7), strides=2, padding='same')(x)\n",
    "x = layers.BatchNormalization()(x)\n",
    "x = layers.Activation('relu')(x)\n",
    "x = layers.MaxPooling2D((3, 3), strides=2, padding='same')(x)\n",
    "\n",
    "# conv2_x\n",
    "x = ResBlock(x, 3, [64, 64, 256], [1, 3, 1], strides=1)\n",
    "# conv3_x\n",
    "x = ResBlock(x, 4, [128, 128, 512], [1, 3, 1], strides=2)\n",
    "# conv4_x\n",
    "x = ResBlock(x, 6, [256, 256, 1024], [1, 3, 1], strides=2)\n",
    "# conv5_x\n",
    "x = ResBlock(x, 3, [512, 512, 2048], [1, 3, 1], strides=2)\n",
    "\n",
    "x = layers.GlobalAveragePooling2D()(x)\n",
    "# Dense has no activation here, so the model outputs raw logits.\n",
    "outputs = layers.Dense(labels_num)(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40070d68",
   "metadata": {
    "id": "40070d68"
   },
   "outputs": [],
   "source": [
    "ResNet_model = Model(inputs=inputs, outputs=outputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "013bfffb",
   "metadata": {
    "id": "013bfffb"
   },
   "outputs": [],
   "source": [
    "ResNet_model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8bbda37b",
   "metadata": {
    "id": "8bbda37b"
   },
   "outputs": [],
   "source": [
    "batch_size = 256"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "639e5dc2",
   "metadata": {
    "id": "639e5dc2"
   },
   "outputs": [],
   "source": [
    "# Smoke test: push a dummy batch of ones through the model.\n",
    "# A dedicated name is used so the Keras `inputs` tensor that defines the\n",
    "# model is not overwritten (re-running the Model(...) cell stays valid).\n",
    "dummy_batch = np.ones((batch_size, x_train.shape[1], x_train.shape[2], 3),\n",
    "                      dtype=np.float32)\n",
    "ResNet_model(dummy_batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14a1c0bd",
   "metadata": {
    "id": "14a1c0bd"
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "provenance": []
  },
  "gpuClass": "standard",
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}