{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "view-in-github", "colab_type": "text" }, "source": [ "\"Open" ] }, { "cell_type": "markdown", "metadata": { "id": "PAoR-j1b1Huc" }, "source": [ "# Vanilla GAN on MNIST (with DCGAN structure)" ] }, { "cell_type": "markdown", "metadata": { "id": "CSZejKDO1Huq" }, "source": [ "### 本章節內容大綱\n", "* [Generator 生成器](#Generator-生成器)\n", "* [Discriminator 判別器](#Discriminator-判別器)\n", "* [Loss function 損失函數](#Loss-function-損失函數)\n", "* [Mode Collapse 模式崩塌](#Mode-Collapse-模式崩塌)" ] }, { "cell_type": "markdown", "metadata": { "id": "GyigS8Po1Hus" }, "source": [ "在 unsupervised learning 的領域裡, GAN 是 generative model 劃時代的想法, AI 大神 Yann LeCun 曾在 Quora 上提到:「GAN及其變形是近十年最有趣的想法(This, and the variations that are now being proposed is the most interesting idea in the last 10 years in ML, in my opinion.)」。像 Autoregressive model 或 Variational AutoEncoder 等 Generative model ,是看了訓練的資料後,用學到的特徵來產出結果。而相較於其他的生成模型, GAN 不一樣的地方是,嚴格來說 GAN 的 Generator 從未看過一筆訓練資料,他看到的只有 Discriminator 的得分(真偽分類結果),透過得分來更新 Generator 的網路權重,所以 GAN 也就能產生出訓練集中完全沒看過的假資料。\n", "\n", "
\n", "\n", "\n", "上圖為 GAN 在訓練時的基本流程, Generator在做的是先 sample 一個 latent vector(潛在向量/本徵相量) z ,之後丟給 Generator 生成一張圖片,而 Disciminator 則是透過真實的資料學習並且檢查 Generator 生成圖片的真偽,好讓 Generator 更新權重。在反覆的迭代後, Generator 與 Discriminator 都變得更會產生假資料及更會分辨假資料,之後我們就能夠拿 Generator 來產生我們要的假資料了。" ] }, { "cell_type": "markdown", "metadata": { "id": "01rB4Qb41Huu" }, "source": [ "# Import" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "QZ0hbn_b1Hu2" }, "outputs": [], "source": [ "''' basic package '''\n", "import os\n", "import time\n", "import imageio\n", "import glob\n", "from IPython.display import display, Image\n", "\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "\n", "import tensorflow as tf\n", "\n", "from tensorflow.keras import Model, Sequential\n", "from tensorflow.keras.layers import (\n", " Dense, Conv2DTranspose, Conv2D, BatchNormalization,\n", " LeakyReLU, Dropout, Reshape, Flatten\n", ")\n", "\n", "from tensorflow.keras.losses import BinaryCrossentropy\n", "from tensorflow.keras.optimizers import Adam\n", "\n", "# 告訴系統要第幾張卡被看到。 Ex. 硬體總共有8張顯卡,以下設定只讓系統看到第1張顯卡\n", "# 若沒設定,則 Tensorflow 在運行時,預設會把所有卡都佔用\n", "# 要看裝置內顯卡數量及目前狀態的話,請在終端機內輸入 \"nvidia-smi\"\n", "# 若你的裝置只有一張顯卡可以使用,可以忽略此設定\n", "os.environ[\"CUDA_VISIBLE_DEVICES\"] = '0'" ] }, { "cell_type": "markdown", "metadata": { "id": "cYMj4jRA1Hu9" }, "source": [ "# Config" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "IBBZfByQ1Hu_" }, "outputs": [], "source": [ "BATCH_SIZE = 256\n", "BUFFER_SIZE = 60000 # tf2.0 的 shuffle 需要定義「抽籤桶」要多大,設 60000 意指全部的資料\n", "z_dim = 100 # latent/noise vector z 的維度\n", "EPOCHS = 50\n", "learning_rate = 1e-4\n", "num_examples_to_generate = 16 # 之後我們要 plot 出來檢視 generate 圖片品質的張數\n", "\n", "(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()\n", "\n", "train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')\n", "\n", "# 將圖片正規化至 [-1 ~ 1]\n", "train_images = (train_images - 127.5) / 127.5\n", "\n", "train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "OW43zb0_1HvD" }, "outputs": [], "source": [ "plt.imshow(train_images[0].squeeze(), cmap='gray')" ] }, { "cell_type": "markdown", "metadata": { "id": "TB09E-8K1HvF" }, "source": [ "# Generator 生成器" ] }, { "cell_type": "markdown", "metadata": { "id": "OXOv4A0N1HvG" }, "source": [ "在生成圖片時需要用到 convolution 的變形: transposed convolution(轉置卷積)
\n", "一般的卷積可以理解為將圖片的特徵萃取成一張張的 feature map ,\n", "而轉置卷積簡單來說可以理解為將 feature map 還原為圖片(一種 upsampling 的方式)
\n", "下圖是 convolution 跟 transposed convolution(deconvolution) 的對應關係:\n", "\n", "\n", "\n", "這邊要注意的一點是,在做 Conv2DTranspose 時要注意 feature map 的大小,\n", "而下面將簡述 Conv2DTranspose 的兩個重要的參數所代表的意義:
\n", "\n", "* stride 步長\n", "\n", "\n", "\n", "剛好與 Conv2D 相反,在一般卷積中, stride 表示一次走幾格,有就是說當 stride=2 時, feature map size 就會縮小一半,\n", "而 Conv2DTranspose 中的 stride=2 所表示的,則是會將 feature map 中間插入孔洞,因而放大兩倍(當然還要搭配 padding 的設定)。\n", "\n", "* padding 填補\n", "\n", "\n", "\n", "與 Conv2D 不同,當設定為 'VALID' 時就會像是上圖一樣是有 padding 的,理解的方法可利用 Conv2D 的概念去回推,\n", "當 Conv2D 的 padding='VALID' 時, 4x4 的圖經過 Conv2D(kernel=3) 會產生 2x2 的圖,\n", "那 4x4 就會是 2x2 的圖經過 Conv2DTranspose(kernel=3, padding='VALID') 時所產出的大小。\n", "而在 tf.keras 的程式碼中,當 padding 是 'SAME' 時,會直接放大 stride 的倍數, output size 計算的方式會是如下:\n", "\n", "```python\n", "# for `tf.layers.conv2d_transpose()` with `SAME` padding:\n", "out_height = in_height * strides[1]\n", "out_width = in_width * strides[2]\n", "```\n", "(這邊跟最上面右邊的圖(stride=2, padding='same')有出入的原因是 keras 預設只會將 output 的特徵圖去掉 right/bottom 的 pad(與 padding='valid' 相比),但如果設定超參數 Conv2DTranspose(output_padding=0) 的話,就會將四周都去掉,而輸出跟最上面右邊的圖一樣的 shape)\n", "\n", "\n", "如果是自行定義的padding的話,則計算公式如下:\n", "\n", "```python\n", "# for `tf.layers.conv2d_transpose()` with given padding:\n", "out_height = strides[1] * (in_height - 1) + kernel_size[0] - 2 * padding_height\n", "out_width = strides[2] * (in_width - 1) + kernel_size[1] - 2 * padding_width\n", "```\n", "\n", "如果對於 convolution 的 input size 與 output size 不熟的話,以下是公式,如果要把它轉成 deconvolution 的話,將 input 與 output size 互換就可以得到還原後的結果:
\n", "\n", "\n", "\n", "W 代表 convolution 的 input size ; pad 表示單邊的 padding size ; ks 表示 kernel size ; S 表示 stride\n", "\n", "\n", "如果上面的解釋還是不清楚,下面有兩種對於 transposed convolution 用圖像化的解釋給學員參考:\n", "\n", "\n", "
\n", "" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "HNUihVbS1HvH" }, "outputs": [], "source": [ "class Generator(Model):\n", " def __init__(self, z_dim):\n", "\n", " # super 是讓 Generator class 能夠繼承 Model 的 __init__ 中的變數\n", " super(Generator, self).__init__()\n", "\n", " self.model = Sequential()\n", "\n", " # [z_dim] => [7, 7, 128]\n", " self.model.add(Dense(7 * 7 * 128, use_bias=False, input_shape=(z_dim,)))\n", "\n", " self.model.add(LeakyReLU())\n", " self.model.add(Reshape((7, 7, 128)))\n", "\n", " # [7, 7, 128] => [14, 14, 64]\n", " self.model.add(Conv2DTranspose(64, 5, strides=2, padding='same'))\n", " self.model.add(LeakyReLU())\n", "\n", " # [14, 14, 64] => [28, 28, 1]\n", " self.model.add(Conv2DTranspose(1, 5, strides=2, padding='same', activation='tanh'))\n", "\n", " def call(self, x):\n", " return self.model(x)" ] }, { "cell_type": "markdown", "metadata": { "id": "sMXZS-Qk1HvH" }, "source": [ "# Discriminator 判別器" ] }, { "cell_type": "markdown", "metadata": { "id": "Fnqxsn3W1HvI" }, "source": [ "discriminator 的結構就跟一般 cnn 一模一樣,這邊就不再多加贅述。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "s0Wx1v8b1HvI" }, "outputs": [], "source": [ "class Discriminator(Model):\n", " def __init__(self):\n", "\n", " # super 是讓 Discriminator class 能夠繼承 Model 的 __init__ 中的變數\n", " super(Discriminator, self).__init__()\n", "\n", " self.model = Sequential()\n", "\n", " # [28, 28, 1] => [14, 14, 64]\n", " self.model.add(Conv2D(64, 5, strides=2, padding='same', input_shape=(28, 28, 1)))\n", " self.model.add(LeakyReLU())\n", "\n", " # [14, 14, 64] => [7, 7, 128]\n", " self.model.add(Conv2D(128, 5, strides=2, padding='same'))\n", " self.model.add(LeakyReLU())\n", "\n", " # [7, 7, 128] => [4, 4, 256]\n", " self.model.add(Conv2D(256, 5, strides=2, padding='same'))\n", " self.model.add(LeakyReLU())\n", "\n", " # [4, 4, 256] => [1]\n", " self.model.add(Flatten())\n", " self.model.add(Dense(1))\n", "\n", " def call(self, x):\n", " return self.model(x)" ] }, { "cell_type": "markdown", "metadata": { "id": "y684I7AW1HvJ" }, "source": [ "# Loss function 損失函數" ] }, { "cell_type": "markdown", "metadata": { "id": "z5yJq10g1HvJ" }, "source": [ "GAN 中 generator 與 discriminator 要 optimize 的對象是不同的,所以要寫兩個 loss function 分別作 gradient descend 。
\n", "\n", "\n", "\n", "上表是 GAN 的演算法, discriminator 的 loss 為「真實資料卻分成假的(d_loss_real)+假資料卻分成真的(d_loss_fake)」,\n", "generator 的 loss 為「產生的假資料被判為假資料(g_loss)」
\n", "(這邊的例子中,我們的 generator loss 會改寫成由 Ian Goodfellow 在論文後半部所提出的下面的形式,好處是一開始參數更新得比較快)
\n", "\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "xnl-9l2p1HvJ" }, "outputs": [], "source": [ "def gan_loss(d_real_logits, d_fake_logits):\n", " # d_real_logits 為 discriminator 看了真實資料的 output,\n", " # d_fake_logits 為 discriminator 看了 generator 製造的假資料的 output\n", "\n", " # 定義 cross entropy 作為衡量標準,from_logits 是指要將 input 作 sigmoid 處理\n", " cross_entropy = BinaryCrossentropy(from_logits=True)\n", " # discriminator loss\n", " d_loss_real = cross_entropy(tf.ones_like(d_real_logits), d_real_logits)\n", " d_loss_fake = cross_entropy(tf.zeros_like(d_fake_logits), d_fake_logits)\n", " d_loss = d_loss_real + d_loss_fake\n", " # generator loss\n", " g_loss = cross_entropy(tf.ones_like(d_fake_logits), d_fake_logits)\n", " return d_loss, g_loss" ] }, { "cell_type": "markdown", "metadata": { "id": "VCtrXHWy1HvK" }, "source": [ "# Start training" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "CokcxivF1HvK" }, "outputs": [], "source": [ "generator = Generator(z_dim)\n", "discriminator = Discriminator()\n", "g_optimizer = Adam(learning_rate)\n", "d_optimizer = Adam(learning_rate)\n", "\n", "# 固定 seed 以利觀察產出的圖片品質\n", "seed = tf.random.normal([num_examples_to_generate, z_dim])\n", "\n", "save_dir = './saved_imgs'\n", "checkpoint_dir = './training_checkpoints'\n", "\n", "if not os.path.exists(save_dir):\n", " os.makedirs(save_dir)\n", "if not os.path.exists(checkpoint_dir):\n", " os.makedirs(checkpoint_dir)\n", "\n", "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt\")\n", "checkpoint = tf.train.Checkpoint(g_optimizer=g_optimizer,\n", " d_optimizer=d_optimizer,\n", " generator=generator,\n", " discriminator=discriminator)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "w4eN_ZGM1HvL" }, "outputs": [], "source": [ "@tf.function\n", "def train_step(real_images, generator, discriminator, g_optimizer, d_optimizer):\n", "\n", " noise = tf.random.normal([BATCH_SIZE, z_dim])\n", " with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:\n", " fake_images = generator(noise)\n", "\n", " d_real_logits = discriminator(real_images)\n", " d_fake_logits = discriminator(fake_images)\n", "\n", " d_loss, g_loss = gan_loss(d_real_logits, d_fake_logits)\n", "\n", " g_gradients = g_tape.gradient(g_loss, generator.trainable_variables)\n", " d_gradients = d_tape.gradient(d_loss, discriminator.trainable_variables)\n", "\n", " g_optimizer.apply_gradients(zip(g_gradients, generator.trainable_variables))\n", " d_optimizer.apply_gradients(zip(d_gradients, discriminator.trainable_variables))\n", " return d_loss, g_loss" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "99DGnnOB1HvL" }, "outputs": [], "source": [ "def generate_and_save_images(model, epoch, test_input, save_path):\n", "\n", " predictions = model(test_input)\n", "\n", " fig = plt.figure(figsize=(4, 4))\n", "\n", " for i in range(predictions.shape[0]):\n", " plt.subplot(4, 4, i + 1)\n", " plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')\n", " plt.axis('off')\n", "\n", " # 每 5 個 epoches 存一次圖片\n", " if (epoch + 1) % 5 == 0:\n", " plt.savefig(os.path.join(save_path, 'image_at_epoch_{:04d}.png'.format(epoch)))\n", "\n", " plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Eimmx4mn1HvM" }, "outputs": [], "source": [ "def train(dataset, epochs):\n", " for epoch in range(epochs):\n", " start = time.time()\n", "\n", " for image_batch in dataset:\n", " d_loss, g_loss = train_step(image_batch, generator, discriminator,\n", " g_optimizer, d_optimizer)\n", "\n", " # 產生圖片\n", " print('Time for epoch {} is {} sec'.format(epoch + 1, time.time() - start))\n", " print('discriminator loss: %.5f' % d_loss)\n", " print('generator loss: %.5f' % g_loss)\n", " generate_and_save_images(generator, epoch + 1, seed, save_dir)\n", "\n", " # 每 25 個 epochs 存一次模型\n", " if (epoch + 1) % 25 == 0:\n", " checkpoint.save(file_prefix=checkpoint_prefix)\n", "\n", " # 在最後一個 epoch 再產生一次圖片與儲存一次權重\n", " generate_and_save_images(generator, epochs, seed, save_dir)\n", " checkpoint.save(file_prefix=checkpoint_prefix)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "rHauRzH51HvM" }, "outputs": [], "source": [ "%%time\n", "train(train_dataset, EPOCHS)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "L1_9Pj1d1HvN" }, "outputs": [], "source": [ "# 使用 imageio 製作 gif 圖\n", "anim_file = 'saved_imgs/dcgan.gif'\n", "\n", "with imageio.get_writer(anim_file, mode='I') as writer:\n", "\n", " filenames = glob.glob('saved_imgs/image*.png')\n", " filenames = sorted(filenames)\n", " last = -1\n", " for i,filename in enumerate(filenames):\n", " # skip 掉 部分的 frame\n", " frame = 2*(i**0.5)\n", " if round(frame) > round(last):\n", " last = frame\n", " else:\n", " continue\n", " image = imageio.imread(filename)\n", " writer.append_data(image)\n", " image = imageio.imread(filename)\n", " writer.append_data(image)\n", "\n", "display(Image(filename=anim_file))" ] }, { "cell_type": "markdown", "metadata": { "id": "1xl52-yF1HvO" }, "source": [ "# Mode Collapse 模式崩塌" ] }, { "cell_type": "markdown", "metadata": { "id": "cm0Hc4V31HvP" }, "source": [ "從上面的訓練結果看來,我們的 Generator 最後好像只會生成很像1或9這種簡單的圖案,像是8、5或是2這些相對複雜的數字圖案是沒有生成出來的。這種問題發生於即便 z vector(雜訊向量/本徵相量)改變了,其對應 generator 的 G(z) 還是不會改變,而這種現象被稱之為 Mode Collapse 。造成這種現象的主要問題是因為 generator 的 loss function 改寫成 \"-log D(x)\" 的形式所導致,實際的問題用數學來說為 generator 要 minimize 的 loss function 其實是下式:
\n", "\n", "\n", "這個公式有個嚴重的矛盾即是同時要拉近兩個分佈的距離,同時又要推遠兩個分佈,如此會導致梯度不穩定。另一個問題是 KL divergence 懲罰錯誤的不一制性:
\n", "
\n", "\n", "$$ when \\; P_g(x)\\rightarrow 0 \\; and\\; P_r(x)\\rightarrow 1,\\; KL(P_g||P_r) = P_g(x) log\\frac{P_g(x)}{P_r(x)} \\rightarrow 0$$\n", "\n", "$$ when \\; P_g(x)\\rightarrow 1 \\; and\\; P_r(x)\\rightarrow 0,\\; KL(P_g||P_r) = P_g(x) log\\frac{P_g(x)}{P_r(x)} \\rightarrow +\\infty$$\n", "\n", "一式為 generator 沒辦法產生真實的資料, KL divergence 對 loss 的影響幾乎沒有;二式為 generator 產生了不真實的資料, KL divergence 對 loss 的影響非常強烈。\n", "由這兩點可以發現,就算 generator 不會產生一些相對複雜的數字也不會有太大的懲罰,但是對於產生一些「實驗性」的數字就會有很大的懲罰,所以 generator 就會產生相對「簡單又安全」的數字來確保 loss 最低。\n", "\n", "(對於想了解數學推導又不想看英文論文的人,可以詳見知乎的對於 GAN 的詳細數學說明:https://zhuanlan.zhihu.com/p/25071913 )\n", "\n", "後續有許多對 GAN 的研究也圍繞在 GAN 的 loss function 上面,像是接下來的 WGAN 就是一個例子。" ] }, { "cell_type": "markdown", "metadata": { "id": "2kEWOgPP1HvP" }, "source": [ "# Generating random digits image" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3W0k7mye1HvP" }, "outputs": [], "source": [ "checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "WEc-xAAl1HvQ" }, "outputs": [], "source": [ "noise = tf.random.normal([1, z_dim])\n", "img = generator(noise, training=False)\n", "\n", "# 生成的 pixel 數值介於-1~1之間,故要將其縮放至0~255方能顯示\n", "plt.imshow(img[0,:, :, 0] * 127.5 + 127.5, cmap='gray')\n", "plt.axis('off')\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "4jHiBQGG1HvS" }, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.12" }, "colab": { "provenance": [], "gpuType": "T4", "include_colab_link": true }, "accelerator": "GPU" }, "nbformat": 4, "nbformat_minor": 0 }