{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "Jlc3qyFv1dJo" }, "source": [ "# WGAN-GP on MNIST" ] }, { "cell_type": "markdown", "metadata": { "id": "EXuP6sSw1dJu" }, "source": [ "### Chapter outline\n", "* [How WGAN-GP works](#How-WGAN-GP-works)\n", "* [Replacing weight clipping with a gradient penalty](#使用-Gradient-penalty-取代-weight-clipping)" ] }, { "cell_type": "markdown", "metadata": { "id": "4ZBhBvgT1dJw" }, "source": [ "# How WGAN-GP works" ] }, { "cell_type": "markdown", "metadata": { "id": "PdeX9Mep1dJw" }, "source": [ "WGAN uses a fairly crude technique, weight clipping, to keep the discriminator smooth and close to a 1-Lipschitz function. This approach has two problems:
\n", "\n", "First, the discriminator's weights tend to converge to the two boundary values of the clipping range, so the network can only learn simple functions.\n", "\n", "Second, training is very sensitive to the clipping range: a slightly different setting can easily cause exploding or vanishing gradients across the layers of the discriminator network.\n", "\n", "WGAN-GP (gradient penalty) is an improved version that replaces weight clipping. It changes the Wasserstein distance estimate to:\n", "\n", "$$W(P_{data}, P_G) = \\max_D (E_{x\\sim P_{data}}\\; [D(x)] - E_{x\\sim P_G}\\;[D(x)] - \\lambda E_{x\\sim P_{penalty}}\\;\\;[\\max(0,||\\nabla_x D(x)|| -1)])$$\n", "\n", "A differentiable function is 1-Lipschitz exactly when $||\\nabla_x D(x)|| \\leq 1$ everywhere (the norm of the discriminator's gradient with respect to its input is at most 1), so WGAN-GP adds a term that penalizes gradient norms above 1 in order to approximate a 1-Lipschitz function. In their experiments, however, the authors found that training works better when the penalty instead pushes the gradient norm toward 1:\n", "\n", "$$W(P_{data}, P_G) = \\max_D (E_{x\\sim P_{data}}\\; [D(x)] - E_{x\\sim P_G}\\;[D(x)] - \\lambda E_{x\\sim P_{penalty}}\\;\\;[(||\\nabla_x D(x)|| -1)^2])$$\n", "\n", "Here $\\lambda$ is the penalty weight, and $P_{penalty}$ samples points lying between real and generated data. Penalizing the gradient at these points is enough; there is no need to constrain the gradient over the entire data space." ] },
{ "cell_type": "markdown", "metadata": { "id": "vtIm9Zn81dJx" }, "source": [ "# Import" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "5TDL-0Xg1dJy" }, "outputs": [], "source": [ "''' basic package '''\n", "import os\n", "import time\n", "import imageio\n", "import glob\n", "from IPython.display import display, Image\n", "\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "\n", "import tensorflow as tf\n", "\n", "from tensorflow.keras import Model, Sequential\n", "from tensorflow.keras.layers import (\n", "    Dense, Conv2DTranspose, Conv2D, BatchNormalization,\n", "    LeakyReLU, Dropout, Reshape, Flatten\n", ")\n", "\n", "from tensorflow.keras.losses import BinaryCrossentropy\n", "\n", "from tensorflow.keras.optimizers import Adam, RMSprop\n", "\n", "os.environ[\"CUDA_VISIBLE_DEVICES\"] = '0'" ] },
{ "cell_type": "markdown", "metadata": { "id": "_Clsoiq61dJ0" }, "source": [ "# Config" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "2d-Z502a1dJ1" }, "outputs": [], "source": [ "BATCH_SIZE = 128\n", "BUFFER_SIZE = 60000\n", "z_dim = 100\n", "EPOCHS = 50\n", "learning_rate = 1e-4\n", "num_examples_to_generate = 16\n",
"gp_weight = 10  # lambda, the gradient penalty weight\n", "\n", "(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()\n", "train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')\n", "train_images = (train_images - 127.5) / 127.5 # Normalize the images to [-1, 1]\n", "\n", "train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)" ] },
{ "cell_type": "markdown", "metadata": { "id": "Dwm2kqfj1dJ2" }, "source": [ "# Train" ] },
{ "cell_type": "markdown", "metadata": { "id": "VmUKn6ra1dJ3" }, "source": [ "### Define the generator & discriminator" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "-ib2uP4J1dJ4" }, "outputs": [], "source": [ "# Define model\n", "\n", "\n", "class Generator(Model):\n", "    def __init__(self, z_dim):\n", "        super(Generator, self).__init__()\n", "\n", "        self.model = Sequential()\n", "\n", "        # [z_dim] => [7, 7, 128]\n", "        self.model.add(Dense(7 * 7 * 128, use_bias=False, input_shape=(z_dim,)))\n", "        self.model.add(LeakyReLU())\n", "        self.model.add(Reshape((7, 7, 128)))\n", "\n", "        # [7, 7, 128] => [14, 14, 64]\n", "        self.model.add(Conv2DTranspose(64, 5, strides=2, padding='same', use_bias=False))\n", "        self.model.add(LeakyReLU())\n", "\n", "        # [14, 14, 64] => [28, 28, 1]\n", "        self.model.add(Conv2DTranspose(1, 5, strides=2, padding='same', activation='tanh', use_bias=False))\n", "\n", "    def call(self, x):\n", "        return self.model(x)\n", "\n", "\n", "class Discriminator(Model):\n", "    def __init__(self):\n", "        super(Discriminator, self).__init__()\n", "\n", "        self.model = Sequential()\n", "\n", "        # [28, 28, 1] => [14, 14, 64]\n", "        self.model.add(Conv2D(64, 5, strides=2, padding='same', input_shape=(28, 28, 1)))\n", "        self.model.add(LeakyReLU())\n", "\n", "        # [14, 14, 64] => [7, 7, 128]\n", "        self.model.add(Conv2D(128, 5, strides=2, padding='same'))\n", "        self.model.add(LeakyReLU())\n", "\n", "        # [7, 7, 128] => [4, 4, 256]\n", "        self.model.add(Conv2D(256, 5, strides=2,
padding='same'))\n", "        self.model.add(LeakyReLU())\n", "\n", "        # [4, 4, 256] => [1]\n", "        self.model.add(Flatten())\n", "        self.model.add(Dense(1))\n", "\n", "    def call(self, x):\n", "        return self.model(x)" ] }, { "cell_type": "markdown", "metadata": { "id": "ziQcAhm21dJ5" }, "source": [ "<a id=\"使用-Gradient-penalty-取代-weight-clipping\"></a>\n", "\n", "# Replacing weight clipping with a gradient penalty" ] }, { "cell_type": "markdown", "metadata": { "id": "Ozjc7_791dJ5" }, "source": [ "The biggest difference from WGAN or a vanilla GAN is the gradient penalty added to the loss function. We sample points between real and generated data, differentiate the discriminator's output with respect to each sampled point to get the gradient at that point, take its norm, subtract one, and square the result. That is the gradient penalty.\n", "\n", "
\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "2nZ859Ne1dJ5" }, "outputs": [], "source": [ "def gradient_penalty(real_images, fake_images):\n", "\n", "    # Sample one interpolation weight per image from U(0, 1); the shape must broadcast against the image tensor\n", "    epsilon = tf.random.uniform([real_images.shape[0], 1, 1, 1], 0.0, 1.0)\n", "\n", "    # Blend each real image with a fake one, i.e. sample a point between the real and generated data\n", "    x_hat = epsilon * real_images + (1 - epsilon) * fake_images\n", "\n", "    with tf.GradientTape() as t:\n", "        # x_hat is a plain tensor, not a Variable, so the tape must be told to watch it\n", "        t.watch(x_hat)\n", "        d_hat = discriminator(x_hat)\n", "    gradients = t.gradient(d_hat, x_hat)\n", "\n", "    # Per-image L2 norm of the gradient (over height, width and channel), minus one, then squared\n", "    g_norm = tf.sqrt(tf.reduce_sum(gradients ** 2, axis=[1, 2, 3]))\n", "    penalty = tf.reduce_mean((g_norm - 1.0) ** 2)\n", "    return penalty\n", "\n", "\n", "def gan_loss(d_real_output, d_fake_output, real_images, fake_images):\n", "\n", "    # Unlike a vanilla GAN, no log is applied: the raw discriminator outputs are used directly\n", "\n", "    # discriminator loss\n", "    d_loss = tf.reduce_mean(d_fake_output) - tf.reduce_mean(d_real_output) + gradient_penalty(\n", "        real_images, fake_images) * gp_weight\n", "\n", "    # generator loss\n", "    g_loss = tf.reduce_mean(-d_fake_output)\n", "    return d_loss, g_loss" ] },
{ "cell_type": "markdown", "metadata": { "id": "6p8Y2b0p1dJ6" }, "source": [ "# Model, seed and checkpoint setting" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "OHwqNbPD1dJ6" }, "outputs": [], "source": [ "generator = Generator(z_dim)\n", "discriminator = Discriminator()\n", "\n", "g_optimizer = RMSprop(learning_rate)\n", "d_optimizer = RMSprop(learning_rate)\n", "\n", "seed = tf.random.normal([num_examples_to_generate, z_dim])\n", "\n", "save_dir = './saved_imgs_GP'\n", "checkpoint_dir = './training_checkpoints_GP'\n", "\n", "if not os.path.exists(save_dir):\n", "    os.makedirs(save_dir)\n", "if not os.path.exists(checkpoint_dir):\n", "    os.makedirs(checkpoint_dir)\n", "\n", "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt\")\n", "checkpoint = tf.train.Checkpoint(g_optimizer=g_optimizer,\n", "                                 d_optimizer=d_optimizer,\n", "                                 generator=generator,\n", "                                 discriminator=discriminator)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "raot1iXQ1dJ6" }, "outputs": [], "source": [ "@tf.function\n", "def train_step(real_images, generator, discriminator, g_optimizer, d_optimizer):\n", "    noise = tf.random.normal([real_images.shape[0], z_dim])\n", "    with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:\n", "        fake_images = generator(noise, training=True)\n", "        d_real_logits = discriminator(real_images)\n", "        d_fake_logits = discriminator(fake_images)\n", "\n", "        # Compute the losses inside the tapes so the gradient penalty is recorded by d_tape\n", "        d_loss, g_loss = gan_loss(d_real_logits, d_fake_logits, real_images, fake_images)\n", "\n", "    g_gradients = g_tape.gradient(g_loss, generator.trainable_variables)\n", "    d_gradients = d_tape.gradient(d_loss, discriminator.trainable_variables)\n", "\n", "    g_optimizer.apply_gradients(zip(g_gradients, generator.trainable_variables))\n", "    d_optimizer.apply_gradients(zip(d_gradients, discriminator.trainable_variables))\n", "\n", "    return d_loss, g_loss" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "qtSQEAEs1dJ7" }, "outputs": [], "source": [ "def train(dataset, epochs):\n", "    for epoch in range(epochs):\n", "        start = time.time()\n", "\n", "        for image_batch in dataset:\n", "            d_loss, g_loss = train_step(image_batch, generator, discriminator, g_optimizer, d_optimizer)\n", "\n", "        # Produce images\n", "        print('Time for epoch {} is {} sec'.format(epoch + 1, time.time() - start))\n", "        print('discriminator loss: %.5f' % d_loss)\n", "        print('generator loss: %.5f' % g_loss)\n", "        generate_and_save_images(generator, epoch + 1, seed, save_dir)\n", "\n", "        # Save the model every 25 epochs\n", "        if (epoch + 1) % 25 == 0:\n", "            checkpoint.save(file_prefix=checkpoint_prefix)\n", "    # generating / saving after the final epoch\n", "    generate_and_save_images(generator, epochs, seed, save_dir)\n", "    checkpoint.save(file_prefix=checkpoint_prefix)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "E-Muaay81dJ7" }, "outputs": [], "source": [ "def generate_and_save_images(model, epoch, test_input, save_path):\n", "\n", "    predictions = model(test_input)\n", "\n", "    fig = plt.figure(figsize=(4, 4))\n", "\n", "    for i in range(predictions.shape[0]):\n", "        plt.subplot(4, 4, i + 1)\n", "        plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')\n", "        plt.axis('off')\n", "\n", "    # Save the figure every 5 epochs (epoch is already 1-based when passed in)\n", "    if epoch % 5 == 0:\n", "        plt.savefig(os.path.join(save_path, 'image_at_epoch_{:04d}.png'.format(epoch)))\n", "\n", "    plt.show()" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true, "jupyter": { "outputs_hidden": true }, "id": "9QklF87f1dJ7" }, "outputs": [], "source": [ "%%time\n", "train(train_dataset, EPOCHS)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "ZnYukEMn1dJ8" }, "outputs": [], "source": [ "# Build an animated GIF from the saved images with imageio\n", "anim_file = 'saved_imgs_GP/wgan-gp.gif'\n", "\n", "with imageio.get_writer(anim_file, mode='I') as writer:\n", "\n", "    filenames = glob.glob('saved_imgs_GP/image*.png')\n", "    filenames = sorted(filenames)\n", "    last = -1\n", "    for i, filename in enumerate(filenames):\n", "        frame = 2*(i**0.5)\n", "        if round(frame) > round(last):\n", "            last = frame\n", "        else:\n", "            continue\n", "        image = imageio.imread(filename)\n", "        writer.append_data(image)\n", "    image = imageio.imread(filename)\n", "    writer.append_data(image)\n", "\n", "display(Image(filename=anim_file))" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "FWIHeppg1dJ9" }, "outputs": [], "source": [ "checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "mOySgsG11dJ9" }, "outputs": [], "source": [ "noise = tf.random.normal([1,
z_dim])\n", "img = generator(noise)\n", "\n", "plt.imshow(img[0, :, :, 0] * 127.5 + 127.5, cmap='gray')\n", "plt.axis('off')\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "HTybi8M61dJ-" }, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.12" }, "colab": { "provenance": [], "gpuType": "T4", "include_colab_link": true }, "accelerator": "GPU" }, "nbformat": 4, "nbformat_minor": 0 }