{ "cells": [
 { "cell_type": "markdown", "metadata": { "id": "AH4N4GKBKhFI" }, "source": [
    "# WGAN on MNIST" ] },
 { "cell_type": "markdown", "metadata": { "id": "fXEj4ZhEKhFW" }, "source": [
    "### Chapter outline\n",
    "* [WGAN Principles](#WGAN-Principles)\n",
    "* [1-Lipschitz Implementation](#1-Lipschitz-實作)" ] },
 { "cell_type": "markdown", "metadata": { "id": "a_1rlgR5KhFY" }, "source": [
    "# WGAN Principles" ] },
 { "cell_type": "markdown", "metadata": { "id": "OjezS-pyKhFc" }, "source": [
    "A vanilla GAN uses the JS divergence to measure how far apart two distributions are. The JS divergence has a drawback, however: when the two distributions do not overlap, its value is always log 2. Both the generated data and the real data can be viewed as low-dimensional manifolds in a high-dimensional space, so their overlap is essentially negligible (in other words, a discriminator that separates them easily can always be found). The JS divergence therefore stays at log 2, and the generator's gradient stays at 0, so the generator cannot be updated.\n",
    "\n",
    "
\n",
    "\n",
    "\n",
    "The WGAN authors replace the JS divergence with the Wasserstein distance (also known as the Earth mover's distance, EMD). Put simply, the EMD is the minimum cost of transforming distribution P into distribution Q:\n",
    "\n",
    "$$B(\\gamma) = \\sum_{x_p,x_q}\\gamma(x_p,x_q)||x_p-x_q|| $$\n",
    "$$W(P,Q) = \\min_{\\gamma \\in \\Pi} B(\\gamma)$$\n",
    "\n",
    "Here $\\gamma(x_p,x_q)$ is the amount of mass that a moving plan $\\gamma$ transports from $x_p$ to $x_q$, and $\\Pi$ is the set of all valid moving plans.\n",
    "\n",
    "Evaluating this definition directly would require enumerating every moving plan. Through a lengthy derivation (the Kantorovich-Rubinstein duality), the WGAN authors show that the following discriminator objective function measures the Wasserstein distance between $P_G$ and $P_{data}$ directly:\n",
    "\n",
    "$$V(G,D) = \\max_{D \\in 1-Lipschitz} (E_{x \\sim P_{data}}\\;[D(x)] - E_{x \\sim P_{G}}\\;[D(x)])$$\n",
    "\n",
    "Roughly speaking, the 1-Lipschitz constraint forces the discriminator to be smooth, so that the generator can follow its gradient when updating." ] },
 { "cell_type": "markdown", "metadata": { "id": "K9DECLFuKhFe" }, "source": [
    "# Import" ] },
 { "cell_type": "code", "execution_count": null, "metadata": { "id": "4n1DqVBoKhFl" }, "outputs": [], "source": [
    "# Basic packages\n",
    "import os\n",
    "import time\n",
    "import imageio\n",
    "import glob\n",
    "from IPython.display import display, Image\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import tensorflow as tf\n",
    "\n",
    "from tensorflow.keras import Model, Sequential\n",
    "from tensorflow.keras.layers import (\n",
    "    Dense, Conv2DTranspose, Conv2D, BatchNormalization,\n",
    "    LeakyReLU, Dropout, Reshape, Flatten\n",
    ")\n",
    "\n",
    "from tensorflow.keras.losses import BinaryCrossentropy\n",
    "\n",
    "from tensorflow.keras.optimizers import Adam, RMSprop\n",
    "\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = '0'" ] },
 { "cell_type": "markdown", "metadata": { "id": "kw5j7uQnKhFm" }, "source": [
    "# Config" ] },
 { "cell_type": "code", "execution_count": null, "metadata": { "id": "8xTSLWN8KhFn" }, "outputs": [], "source": [
    "BATCH_SIZE = 128\n",
    "BUFFER_SIZE = 60000  # shuffle buffer size for tf.data; 60000 covers the whole training set\n",
    "z_dim = 100  # dimension of the latent/noise vector z\n",
    "EPOCHS = 50\n",
    "learning_rate = 1e-4\n",
    "num_examples_to_generate = 16\n",
    "clip = [-0.05, 0.05]  # clip the discriminator weights to the range [-0.05, 0.05]\n",
    "\n",
    "\n",
    "(train_images, train_labels), (_, _) = 
tf.keras.datasets.mnist.load_data()\n",
    "train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')\n",
    "\n",
    "# Normalize the images to [-1, 1]\n",
    "train_images = (train_images - 127.5) / 127.5\n",
    "\n",
    "train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)" ] },
 { "cell_type": "markdown", "metadata": { "id": "fOz76b91KhFo" }, "source": [
    "# Train" ] },
 { "cell_type": "markdown", "metadata": { "id": "CZyaGRc8KhFp" }, "source": [
    "### Define the generator & discriminator" ] },
 { "cell_type": "code", "execution_count": null, "metadata": { "id": "w7b0v1kdKhFq" }, "outputs": [], "source": [
    "# Define model\n",
    "\n",
    "\n",
    "class Generator(Model):\n",
    "    def __init__(self, z_dim):\n",
    "        super(Generator, self).__init__()\n",
    "\n",
    "        self.model = Sequential()\n",
    "\n",
    "        # [z_dim] => [7, 7, 128]\n",
    "        self.model.add(Dense(7 * 7 * 128, input_shape=(z_dim,)))\n",
    "        self.model.add(LeakyReLU())\n",
    "        self.model.add(Reshape((7, 7, 128)))\n",
    "\n",
    "        # [7, 7, 128] => [14, 14, 64]\n",
    "        self.model.add(Conv2DTranspose(64, 5, strides=2, padding='same'))\n",
    "        self.model.add(LeakyReLU())\n",
    "\n",
    "        # [14, 14, 64] => [28, 28, 1]\n",
    "        self.model.add(Conv2DTranspose(1, 5, strides=2, padding='same', activation='tanh'))\n",
    "\n",
    "    def call(self, x):\n",
    "        return self.model(x)\n",
    "\n",
    "\n",
    "class Discriminator(Model):\n",
    "    def __init__(self):\n",
    "        super(Discriminator, self).__init__()\n",
    "\n",
    "        self.model = Sequential()\n",
    "\n",
    "        # [28, 28, 1] => [14, 14, 64]\n",
    "        self.model.add(Conv2D(64, 5, strides=2, padding='same', input_shape=(28, 28, 1)))\n",
    "        self.model.add(LeakyReLU())\n",
    "\n",
    "        # [14, 14, 64] => [7, 7, 128]\n",
    "        self.model.add(Conv2D(128, 5, strides=2, padding='same'))\n",
    "        self.model.add(LeakyReLU())\n",
    "\n",
    "        # [7, 7, 128] => [4, 4, 256]\n",
    "        self.model.add(Conv2D(256, 5, strides=2, padding='same'))\n",
    "        self.model.add(LeakyReLU())\n",
    "\n",
    "        # [4, 4, 256] => [1]\n",
    "        
self.model.add(Flatten())\n",
    "        self.model.add(Dense(1))\n",
    "\n",
    "    def call(self, x):\n",
    "        return self.model(x)" ] },
 { "cell_type": "code", "execution_count": null, "metadata": { "id": "n86PEGeFKhFr" }, "outputs": [], "source": [
    "def gan_loss(d_real_output, d_fake_output):\n",
    "\n",
    "    # Unlike the vanilla GAN, no log is applied: the raw discriminator\n",
    "    # outputs are used directly to compute the gradients.\n",
    "\n",
    "    # discriminator loss\n",
    "\n",
    "    d_loss = tf.reduce_mean(d_fake_output) - tf.reduce_mean(d_real_output)\n",
    "\n",
    "    # generator loss\n",
    "    g_loss = tf.reduce_mean(-d_fake_output)\n",
    "    return d_loss, g_loss" ] },
 { "cell_type": "code", "execution_count": null, "metadata": { "id": "ix8rfvdeKhFs" }, "outputs": [], "source": [
    "generator = Generator(z_dim)\n",
    "discriminator = Discriminator()\n",
    "\n",
    "g_optimizer = RMSprop(learning_rate)\n",
    "d_optimizer = RMSprop(learning_rate)\n",
    "\n",
    "# Fix the seed so that the sample images can be compared across epochs\n",
    "# to see whether their quality improves\n",
    "seed = tf.random.normal([num_examples_to_generate, z_dim])\n",
    "\n",
    "save_dir = './saved_imgs_wgan'\n",
    "checkpoint_dir = './training_checkpoints_wgan'\n",
    "\n",
    "# Create the output directories if they do not exist yet\n",
    "os.makedirs(save_dir, exist_ok=True)\n",
    "os.makedirs(checkpoint_dir, exist_ok=True)\n",
    "\n",
    "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt\")\n",
    "checkpoint = tf.train.Checkpoint(g_optimizer=g_optimizer,\n",
    "                                 d_optimizer=d_optimizer,\n",
    "                                 generator=generator,\n",
    "                                 discriminator=discriminator)" ] },
 { "cell_type": "markdown", "metadata": { "id": "6xYKUTB4KhFs" }, "source": [
    "# 1-Lipschitz Implementation <a id=\"1-Lipschitz-實作\"></a>" ] },
 { "cell_type": "markdown", "metadata": { "id": "WhMG1AGKKhFt" }, "source": [
    "At the time, the WGAN authors had not yet found a better way to enforce the 1-Lipschitz constraint, so WGAN simply applies weight clipping to keep the discriminator smooth. Here c is the clipping bound: every discriminator weight is clipped to [-c, c] (the `clip` setting in Config). If c is chosen small enough, the discriminator can indeed be made a 1-Lipschitz function (the condition is satisfied as long as its slope is at most 1), but this also limits the discriminator's capacity. In practice, c is therefore tuned to balance the two goals: keep the discriminator as close to 1-Lipschitz as possible while preserving enough of its capacity." ] },
 { "cell_type": "code", "execution_count": null, "metadata": { "id": "MGLhEXJtKhFt" }, "outputs": [], "source": 
[
    "@tf.function\n",
    "def train_step(real_images, generator, discriminator, g_optimizer, d_optimizer):\n",
    "    noise = tf.random.normal([BATCH_SIZE, z_dim])\n",
    "    with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:\n",
    "        fake_images = generator(noise)\n",
    "\n",
    "        d_real_output = discriminator(real_images)\n",
    "        d_fake_output = discriminator(fake_images)\n",
    "\n",
    "        d_loss, g_loss = gan_loss(d_real_output, d_fake_output)\n",
    "\n",
    "    g_gradients = g_tape.gradient(g_loss, generator.trainable_variables)\n",
    "    d_gradients = d_tape.gradient(d_loss, discriminator.trainable_variables)\n",
    "\n",
    "    g_optimizer.apply_gradients(zip(g_gradients, generator.trainable_variables))\n",
    "    d_optimizer.apply_gradients(zip(d_gradients, discriminator.trainable_variables))\n",
    "\n",
    "    # Weight clipping (WGAN): after each update, force every discriminator\n",
    "    # weight back into the range [clip[0], clip[1]]\n",
    "    for var in discriminator.trainable_variables:\n",
    "        var.assign(tf.clip_by_value(var, clip[0], clip[1]))\n",
    "\n",
    "    return d_loss, g_loss" ] },
 { "cell_type": "code", "execution_count": null, "metadata": { "id": "JtX8q5RKKhFu" }, "outputs": [], "source": [
    "def train(dataset, epochs):\n",
    "    for epoch in range(epochs):\n",
    "        start = time.time()\n",
    "\n",
    "        for image_batch in dataset:\n",
    "            d_loss, g_loss = train_step(image_batch, generator, discriminator,\n",
    "                                        g_optimizer, d_optimizer)\n",
    "\n",
    "        # Report progress and generate sample images\n",
    "        print('Time for epoch {} is {} sec'.format(epoch + 1, time.time() - start))\n",
    "        print('discriminator loss: %.5f' % d_loss)\n",
    "        print('generator loss: %.5f' % g_loss)\n",
    "        generate_and_save_images(generator, epoch + 1, seed, save_dir)\n",
    "\n",
    "        # Save the model every 25 epochs\n",
    "        if (epoch + 1) % 25 == 0:\n",
    "            checkpoint.save(file_prefix=checkpoint_prefix)\n",
    "\n",
    "    # After the last epoch, generate images and save the weights once more\n",
    "    generate_and_save_images(generator, epochs, seed, save_dir)\n",
    "    checkpoint.save(file_prefix=checkpoint_prefix)" ] },
 { "cell_type": "code", "execution_count": null, "metadata": { "id": "PHsHPlCgKhFv" }, "outputs": [], "source": [ 
"def generate_and_save_images(model, epoch, test_input, save_path):\n", "\n", " predictions = model(test_input)\n", "\n", " fig = plt.figure(figsize=(4, 4))\n", "\n", " for i in range(predictions.shape[0]):\n", " plt.subplot(4, 4, i + 1)\n", " plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')\n", " plt.axis('off')\n", "\n", " # 每 5 個 epoches 存一次圖片\n", " if (epoch + 1) % 5 == 0:\n", " plt.savefig(os.path.join(save_path, 'image_at_epoch_{:04d}.png'.format(epoch)))\n", "\n", " plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true, "jupyter": { "outputs_hidden": true }, "id": "SK7WFmBBKhFw" }, "outputs": [], "source": [ "%%time\n", "train(train_dataset, EPOCHS)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "zixPRvSFKhFz" }, "outputs": [], "source": [ "# 使用imageio製作gif圖\n", "anim_file = 'saved_imgs_wgan/wgan.gif'\n", "\n", "with imageio.get_writer(anim_file, mode='I') as writer:\n", "\n", " filenames = glob.glob('saved_imgs_wgan/image*.png')\n", " filenames = sorted(filenames)\n", " last = -1\n", " for i, filename in enumerate(filenames):\n", " frame = 2*(i**0.5)\n", " if round(frame) > round(last):\n", " last = frame\n", " else:\n", " continue\n", " image = imageio.imread(filename)\n", " writer.append_data(image)\n", " image = imageio.imread(filename)\n", " writer.append_data(image)\n", "\n", "display(Image(filename=anim_file))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "STtymd-TKhF0" }, "outputs": [], "source": [ "checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "YJd3MN-bKhF1" }, "outputs": [], "source": [ "noise = tf.random.normal([1, z_dim])\n", "img = generator(noise, training=False)\n", "\n", "plt.imshow(img[0, :, :, 0] * 127.5 + 127.5, cmap='gray')\n", "plt.axis('off')\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": 
"ijwqAfrMKhF1" }, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.12" }, "colab": { "provenance": [], "gpuType": "T4", "include_colab_link": true }, "accelerator": "GPU" }, "nbformat": 4, "nbformat_minor": 0 }