{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "HgxXIkoTaks_" }, "source": [ "# Pix2Pix (Image-to-Image Translation with Conditional Adversarial Networks)" ] },
{ "cell_type": "markdown", "metadata": { "id": "nEBnm9VTaktI" }, "source": [ "### Chapter outline\n", "* [Build a Pix2Pix Model](#Build-a-Pix2Pix-Model)\n", "* [Build a UNET model as Generator](#Build-a-UNET-model-as-Generator)\n", "* [Build a PatchGAN as Discriminator](#Build-a-PatchGAN-as-Discriminator)\n", "* [Loss function](#Loss-function)" ] },
{ "cell_type": "markdown", "metadata": { "id": "ve6Hr-5EaktK" }, "source": [ "\n", "" ] },
{ "cell_type": "markdown", "metadata": { "id": "n-xsNFG3aktL" }, "source": [ "Pix2Pix is a kind of conditional GAN (CGAN). Unlike a vanilla GAN, whose output cannot be steered toward a particular result, Pix2Pix takes an image as input and outputs that image translated into the form we want, as in the examples above: turning a labeled segmentation map into a photo, turning a black-and-white photo into a color one, and so on. Training requires paired data (each label image matched with its photo), so Pix2Pix is a form of supervised learning." ] },
{ "cell_type": "markdown", "metadata": { "id": "ulcStG2SaktL" }, "source": [ "# Import" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "qxoMyfspaktM" }, "outputs": [], "source": [ "# basic packages\n", "import os\n", "import time\n", "import glob\n", "from IPython.display import display, Image, clear_output\n", "\n", "# use only GPU 0; set this before TensorFlow initializes the GPU\n", "os.environ[\"CUDA_VISIBLE_DEVICES\"] = '0'\n", "\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "\n", "import tensorflow as tf\n", "\n", "from tensorflow.keras import Model, Sequential\n", "from tensorflow.keras.layers import (\n", "    Dense, Conv2DTranspose, Conv2D, BatchNormalization,\n", "    LeakyReLU, Dropout, Reshape, Flatten\n", ")\n", "\n", "from tensorflow.keras.losses import BinaryCrossentropy\n", "from tensorflow.keras.optimizers import Adam" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "KkOjXBmiaktP" }, "outputs": [], "source": [ "\n", "# Download the facades dataset into the Keras cache directory (~/.keras/datasets; on the hub this is under home/jovyan/.keras)\n", "_URL = 'http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/facades.tar.gz'\n", "\n", "path_to_zip = tf.keras.utils.get_file('facades.tar.gz',\n", "                                      origin=_URL,\n", "                                      extract=True)\n", "\n", "PATH = os.path.join(os.path.dirname(path_to_zip), 'facades/')" ] },
{ "cell_type": "markdown", "metadata": { "id": "xgcpUL6zaktR" }, "source": [ "# Config" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "B3WRF2DqaktR" }, "outputs": [], "source": [ "BUFFER_SIZE = 400  # the training set contains 400 images\n", "BATCH_SIZE = 1\n", "IMG_WIDTH = 256\n", "IMG_HEIGHT = 256" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "MguQ21lPaktS" }, "outputs": [], "source": [ "# plot the original image\n", "image = tf.io.read_file(PATH+'train/100.jpg')\n", "# decode the image into a uint8 tensor\n", "image = tf.image.decode_jpeg(image)\n", "plt.imshow(image)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "p_-a_nvFaktT" }, "outputs": [], "source": [ "def load(image_file):\n", "    image = tf.io.read_file(image_file)\n", "    # decode the image into a uint8 tensor\n", "    image = tf.image.decode_jpeg(image)\n", "\n", "    w = tf.shape(image)[1]\n", "\n", "    # each file stores the photo (left half) and its label image (right half) side by side,\n", "    # so split it down the middle\n", "    w = w // 2\n", "    real_image = image[:, :w, :]\n", "    input_image = image[:, w:, :]\n", "\n", "    # convert both images to float32\n", "    input_image = tf.cast(input_image, tf.float32)\n", "    real_image = tf.cast(real_image, tf.float32)\n", "\n", "    return input_image, real_image" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "Bf59BJ-vaktT" }, "outputs": [], "source": [ "inp, re = load(PATH+'train/100.jpg')\n", "fig, axs = plt.subplots(1, 2)\n", "axs[0].imshow(inp/255.0)\n", "axs[1].imshow(re/255.0)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "w2rZpJwiaktU" }, "outputs": [], "source": [ "def resize(input_image, real_image, height, width):\n", "    # resize both images to height x width using nearest-neighbor interpolation\n", "    input_image = tf.image.resize(input_image, [height, width],\n", "                                  method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)\n", "    real_image = tf.image.resize(real_image, [height, width],\n", "                                 method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)\n", "\n", "    return input_image, real_image\n", "\n", "\n", "def random_crop(input_image, real_image):\n", "    # stack the two images along a new first axis so both are cropped at the same location\n", "    stacked_image = tf.stack([input_image, real_image], axis=0)\n", "    # randomly crop to IMG_HEIGHT x IMG_WIDTH (256 x 256);\n", "    # the stack axis (size 2) and the channel axis (size 3) keep their full size, so they are not cropped\n", "    cropped_image = tf.image.random_crop(\n", "        stacked_image, size=[2, IMG_HEIGHT, IMG_WIDTH, 3])\n", "\n", "    return cropped_image[0], cropped_image[1]\n", "\n", "\n", "# normalize the images to the range [-1, 1]\n", "def normalize(input_image, real_image):\n", "    input_image = (input_image / 127.5) - 1\n", "    real_image = (real_image / 127.5) - 1\n", "\n", "    return input_image, real_image" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "ZlNgYxgDaktW" }, "outputs": [], "source": [ "# data augmentation: resize, random crop and random mirroring\n", "@tf.function()\n", "def random_jitter(input_image, real_image):\n", "    # resize to 286 x 286 x 3\n", "    input_image, real_image = resize(input_image, real_image, 286, 286)\n", "\n", "    # randomly crop back to 256 x 256 x 3\n", "    input_image, real_image = random_crop(input_image, real_image)\n", "\n", "    # randomly flip both images horizontally (random mirroring)\n", "    if tf.random.uniform(()) > 0.5:\n", "        input_image = tf.image.flip_left_right(input_image)\n", "        real_image = tf.image.flip_left_right(real_image)\n", "\n", "    return input_image, real_image" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "7rReT1WxaktX" }, "outputs": [], "source": [ "# inspect the results of random jitter\n", "plt.figure(figsize=(6, 6))\n", "for i in range(4):\n", "    rj_inp, rj_re = random_jitter(inp, re)\n", "    plt.subplot(2, 2, i+1)\n", "    plt.imshow(rj_inp/255.0)\n", "    plt.axis('off')\n", "plt.show()" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "7by0TXiLaktY" }, "outputs": [], "source": [ "# define the train/test loading functions\n", "def load_image_train(image_file):\n", "    input_image, real_image = load(image_file)\n", "    input_image, real_image = random_jitter(input_image, real_image)\n", "    input_image, real_image = normalize(input_image, real_image)\n", "\n", "    return input_image, real_image\n", "\n", "\n", "# no jitter at testing!\n", "def load_image_test(image_file):\n", "    input_image, real_image = load(image_file)\n", "    input_image, real_image = resize(input_image, real_image,\n", "                                     IMG_HEIGHT, IMG_WIDTH)\n", "    input_image, real_image = normalize(input_image, real_image)\n", "\n", "    return input_image, real_image" ] },
{ "cell_type": "code", "execution_count": null,
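"metadata": { "id": "normalizeChk" }, "outputs": [], "source": [ "# A minimal sanity-check sketch in plain NumPy, so it runs without the dataset.\n", "# normalize() maps pixel values from [0, 255] to [-1, 1] with x / 127.5 - 1,\n", "# which matches the range of the generator's tanh output; generate_images()\n", "# later maps [-1, 1] back to [0, 1] with y * 0.5 + 0.5 so plt.imshow can plot it.\n", "# The variable names below are illustrative only.\n", "import numpy as np\n", "\n", "pixels = np.array([0.0, 127.5, 255.0], dtype=np.float32)  # darkest, middle, brightest\n", "scaled = pixels / 127.5 - 1  # same formula as normalize()\n", "for_display = scaled * 0.5 + 0.5  # same formula as generate_images()\n", "\n", "assert np.allclose(scaled, [-1.0, 0.0, 1.0])\n", "assert np.allclose(for_display, [0.0, 0.5, 1.0])\n", "print(scaled, for_display)" ] },
{ "cell_type": "code", "execution_count": null,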
"metadata": { "id": "LDOPLYSPaktZ" }, "outputs": [], "source": [ "train_dataset = tf.data.Dataset.list_files(PATH+'train/*.jpg')\n", "# num_parallel_calls sets how many images are preprocessed in parallel;\n", "# AUTOTUNE lets tf.data choose this value dynamically at runtime\n", "train_dataset = train_dataset.map(load_image_train,\n", "                                  num_parallel_calls=tf.data.experimental.AUTOTUNE)\n", "train_dataset = train_dataset.shuffle(BUFFER_SIZE)\n", "train_dataset = train_dataset.batch(BATCH_SIZE)\n", "\n", "\n", "test_dataset = tf.data.Dataset.list_files(PATH+'test/*.jpg')\n", "test_dataset = test_dataset.map(load_image_test)\n", "test_dataset = test_dataset.batch(BATCH_SIZE)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "HoVGm7G6aktZ" }, "outputs": [], "source": [ "OUTPUT_CHANNELS = 3" ] },
{ "cell_type": "markdown", "metadata": { "id": "HMn-v25_aktZ" }, "source": [ "# Build a Pix2Pix Model" ] },
{ "cell_type": "markdown", "metadata": { "id": "dZZvVH4fakta" }, "source": [ "\n", "\n", "\n", "The Pix2Pix generator has to turn a label image into a photo, so the model uses an autoencoder-like structure called a U-Net. A U-Net is essentially an autoencoder with added skip connections, which improve the quality of the generated images.\n", "The discriminator works much like the ones in earlier chapters: using the real data as its reference, it judges whether an incoming image is real or fake. Here it also receives the input label image, so it judges each (label image, photo) pair.\n", "\n", "First, the two building blocks used by both the generator and the discriminator, downsample and upsample, are written as functions." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "upVhez1bakta" }, "outputs": [], "source": [ "def downsample(filters, size, apply_batchnorm=True):\n", "    initializer = tf.random_normal_initializer(0., 0.02)  # mean=0, stddev=0.02\n", "\n", "    result = tf.keras.Sequential()\n", "\n", "    # batchnorm follows the convolution by default, so the conv layer is created without a bias\n", "    result.add(\n", "        tf.keras.layers.Conv2D(filters, size, strides=2, padding='same',\n", "                               kernel_initializer=initializer, use_bias=False))\n", "\n", "    if apply_batchnorm:\n", "        result.add(tf.keras.layers.BatchNormalization())\n", "\n", "    result.add(tf.keras.layers.LeakyReLU())\n", "\n", "    return result" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "cZRdEX8jakta" }, "outputs": [], "source": [ "def upsample(filters, size, apply_dropout=False):\n", "    initializer = tf.random_normal_initializer(0., 0.02)\n", "\n", "    result = tf.keras.Sequential()\n", "\n", "    # Conv2DTranspose was introduced in the Vanilla GAN chapter; go back to it if you need a refresher\n", "    result.add(\n", "        tf.keras.layers.Conv2DTranspose(filters, size, strides=2,\n", "                                        padding='same',\n", "                                        kernel_initializer=initializer,\n", "                                        use_bias=False))\n", "\n", "    result.add(tf.keras.layers.BatchNormalization())\n", "\n", "    if apply_dropout:\n", "        result.add(tf.keras.layers.Dropout(0.5))\n", "\n", "    result.add(tf.keras.layers.ReLU())\n", "\n", "    return result" ] },
{ "cell_type": "markdown", "metadata": { "id": "wLi-DnyQakta" }, "source": [ "## Build a UNET model as Generator" ] },
{ "cell_type": "markdown", "metadata": { "id": "XrpgDtjRakta" }, "source": [ "As described above, a U-Net is an autoencoder with skip connections: the output of each downsampling block is concatenated with the output of the upsampling block at the same resolution." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "HP3mySw1akta"
}, "outputs": [], "source": [ "def Generator():\n", "    down_stack = [\n", "        downsample(64, 4, apply_batchnorm=False),  # (bs, 128, 128, 64)\n", "        downsample(128, 4),  # (bs, 64, 64, 128)\n", "        downsample(256, 4),  # (bs, 32, 32, 256)\n", "        downsample(512, 4),  # (bs, 16, 16, 512)\n", "        downsample(512, 4),  # (bs, 8, 8, 512)\n", "        downsample(512, 4),  # (bs, 4, 4, 512)\n", "        downsample(512, 4),  # (bs, 2, 2, 512)\n", "        downsample(512, 4),  # (bs, 1, 1, 512)\n", "    ]\n", "\n", "    up_stack = [\n", "        upsample(512, 4, apply_dropout=True),  # (bs, 2, 2, 1024)\n", "        upsample(512, 4, apply_dropout=True),  # (bs, 4, 4, 1024)\n", "        upsample(512, 4, apply_dropout=True),  # (bs, 8, 8, 1024)\n", "        upsample(512, 4),  # (bs, 16, 16, 1024)\n", "        upsample(256, 4),  # (bs, 32, 32, 512)\n", "        upsample(128, 4),  # (bs, 64, 64, 256)\n", "        upsample(64, 4),  # (bs, 128, 128, 128)\n", "    ]\n", "\n", "    initializer = tf.random_normal_initializer(0., 0.02)\n", "    # the final output must lie in [-1, 1], so the activation function is tanh\n", "    last = tf.keras.layers.Conv2DTranspose(OUTPUT_CHANNELS, 4,\n", "                                           strides=2,\n", "                                           padding='same',\n", "                                           kernel_initializer=initializer,\n", "                                           activation='tanh')  # (bs, 256, 256, 3)\n", "\n", "    concat = tf.keras.layers.Concatenate()\n", "    inputs = tf.keras.layers.Input(shape=[None, None, 3])\n", "    x = inputs\n", "\n", "    # Downsampling\n", "    # keep every block's output in a list so it can be reused during upsampling\n", "    skips = []\n", "    for down in down_stack:\n", "        x = down(x)\n", "        skips.append(x)\n", "    # drop the bottleneck output and reverse the rest so the skips line up with up_stack\n", "    skips = reversed(skips[:-1])\n", "\n", "    # Upsampling and skip connections\n", "    for up, skip in zip(up_stack, skips):\n", "        x = up(x)\n", "        x = concat([x, skip])\n", "\n", "    x = last(x)\n", "\n", "    return tf.keras.Model(inputs=inputs, outputs=x)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "EyUS_VfTaktb" }, "outputs": [], "source": [ "generator = Generator()\n", "\n", "# check that the output looks as expected\n", "# remember to pass training=False; otherwise the BatchNorm moving statistics would be updated\n", "gen_output = generator(inp[tf.newaxis, ...], training=False)\n", "# rescale from the tanh range [-1, 1] to [0, 1] for plotting\n", "plt.imshow(gen_output[0, ...] * 0.5 + 0.5)" ] },
{ "cell_type": "markdown", "metadata": { "id": "NU9mfP8gaktb" }, "source": [ "## Build a PatchGAN as Discriminator" ] },
{ "cell_type": "markdown", "metadata": { "id": "Mm3SU4eFaktb" }, "source": [ "The discriminator is comparatively simple. What sets it apart is that it is a Markovian discriminator (PatchGAN): instead of collapsing the final feature maps into a single real/fake value, it judges the image one N x N patch at a time, and the loss averages these per-patch decisions. With the layers below, the output is a 30 x 30 grid, and each value scores a 70 x 70 region of the input." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "L5kxIZNNaktc" }, "outputs": [], "source": [ "def Discriminator():\n", "    initializer = tf.random_normal_initializer(0., 0.02)\n", "\n", "    inp = tf.keras.layers.Input(shape=[None, None, 3], name='input_image')\n", "    tar = tf.keras.layers.Input(shape=[None, None, 3], name='target_image')\n", "\n", "    x = tf.keras.layers.concatenate([inp, tar], axis=-1)  # (bs, 256, 256, channels*2)\n", "\n", "    down1 = downsample(64, 4, False)(x)  # (bs, 128, 128, 64)\n", "    down2 = downsample(128, 4)(down1)  # (bs, 64, 64, 128)\n", "    down3 = downsample(256, 4)(down2)  # (bs, 32, 32, 256)\n", "\n", "    zero_pad1 = tf.keras.layers.ZeroPadding2D()(down3)  # (bs, 34, 34, 256)\n", "    conv = tf.keras.layers.Conv2D(512, 4, strides=1,\n", "                                  kernel_initializer=initializer,\n", "                                  use_bias=False)(zero_pad1)  # (bs, 31, 31, 512)\n", "\n", "    batchnorm1 = tf.keras.layers.BatchNormalization()(conv)\n", "\n", "    leaky_relu = tf.keras.layers.LeakyReLU()(batchnorm1)\n", "\n", "    zero_pad2 = tf.keras.layers.ZeroPadding2D()(leaky_relu)  # (bs, 33, 33, 512)\n", "\n", "    last = tf.keras.layers.Conv2D(1, 4, strides=1,\n", "                                  kernel_initializer=initializer)(zero_pad2)  # (bs, 30, 30, 1)\n", "\n", "    return tf.keras.Model(inputs=[inp, tar], outputs=last)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "Dx7Y0N8Haktc" }, "outputs": [], "source": [ "# plot the discriminator output to check that it looks as expected\n", "discriminator = Discriminator()\n", "disc_out = discriminator([inp[tf.newaxis, ...], gen_output], training=False)\n", "plt.imshow(disc_out[0, ..., -1], vmin=-20, vmax=20, cmap='RdBu_r')\n", "plt.colorbar()" ] },
{ "cell_type": "markdown", "metadata": { "id": "sXJmfcigaktd" }, "source": [ "# Loss function" ] },
{ "cell_type": "markdown", "metadata": { "id": "mWZ05-nxaktd" }, "source": [ "The discriminator loss is the same as in a standard GAN. The generator loss differs slightly: it adds an L1 term (called the identity loss here) that corrects the difference between the generated output and the real image.\n", "\n", "- Discriminator loss\n", "    - real_loss: penalizes the discriminator for classifying real images as generated\n", "    - generated_loss: penalizes the discriminator for classifying generated images as real\n", "\n", "- Generator loss\n", "    - gan_loss: penalizes the generator when the discriminator classifies its output as generated\n", "    - l1_loss: the mean absolute difference between the generated image and the real image that corresponds to the same label image (the identity loss)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "WSLjlkvpaktd" }, "outputs": [], "source": [ "LAMBDA = 100\n", "\n", "loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)\n", "\n", "\n", "def discriminator_loss(disc_real_output, disc_generated_output):\n", "    real_loss = loss_object(tf.ones_like(disc_real_output), disc_real_output)\n", "\n", "    generated_loss = loss_object(tf.zeros_like(disc_generated_output), disc_generated_output)\n", "\n", "    total_disc_loss = real_loss + generated_loss\n", "\n", "    return total_disc_loss\n", "\n", "\n", "def generator_loss(disc_generated_output, gen_output, target):\n", "    gan_loss = loss_object(tf.ones_like(disc_generated_output), disc_generated_output)\n", "\n", "    # the L1 loss makes the model more robust and the output more detailed,\n", "    # and it gives less blurry images than an L2 loss\n", "    l1_loss = tf.reduce_mean(tf.abs(target - gen_output))\n", "\n", "    total_gen_loss = gan_loss + (LAMBDA * l1_loss)\n", "\n", "    return total_gen_loss\n", "\n", "\n", "generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)\n", "discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "RW52YKEJakte" }, "outputs": [], "source": [ "checkpoint_dir = './training_checkpoints'\n", "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt\")\n", "checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,\n", "                                 discriminator_optimizer=discriminator_optimizer,\n", "                                 generator=generator,\n", "                                 discriminator=discriminator)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "uUJHRyXQakte" }, "outputs": [], "source": [ "EPOCHS = 50\n", "\n", "\n", "def generate_images(model, test_input, tar):\n", "    # training=True is intentional: BatchNorm then normalizes with the statistics of this\n", "    # test batch instead of the moving averages from training (dropout also stays active)\n", "    prediction = model(test_input, training=True)\n", "    plt.figure(figsize=(15, 15))\n", "\n", "    display_list = [test_input[0], tar[0], prediction[0]]\n", "    title = ['Input Image', 'Ground Truth', 'Predicted Image']\n", "\n", "    for i in range(3):\n", "        plt.subplot(1, 3, i+1)\n", "        plt.title(title[i])\n", "        # rescale pixel values from [-1, 1] to [0, 1] so they can be plotted\n", "        plt.imshow(display_list[i] * 0.5 + 0.5)\n", "        plt.axis('off')\n", "    plt.show()\n", "\n", "\n", "@tf.function\n", "def train_step(input_image, target):\n", "    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:\n", "        gen_output = generator(input_image, training=True)\n", "\n", "        disc_real_output = discriminator([input_image, target], training=True)\n", "        disc_generated_output = discriminator([input_image, gen_output], training=True)\n", "\n", "        gen_loss = generator_loss(disc_generated_output, gen_output, target)\n", "        disc_loss = discriminator_loss(disc_real_output, disc_generated_output)\n", "\n", "    generator_gradients = gen_tape.gradient(gen_loss,\n", "                                            generator.trainable_variables)\n", "    discriminator_gradients = disc_tape.gradient(disc_loss,\n", "                                                 discriminator.trainable_variables)\n", "\n", "    generator_optimizer.apply_gradients(zip(generator_gradients,\n", "                                            generator.trainable_variables))\n", "    discriminator_optimizer.apply_gradients(zip(discriminator_gradients,\n", "                                                discriminator.trainable_variables))\n", "\n", "\n", "def fit(train_ds, epochs, test_ds):\n", "    for epoch in range(epochs):\n", "        start = time.time()\n", "\n", "        # Train\n", "        for input_image, target in train_ds:\n", "            train_step(input_image, target)\n", "\n", "        # show sample images every 10 epochs\n", "        if (epoch + 1) % 10 == 0:\n", "            for example_input, example_target in test_ds.take(1):\n", "                generate_images(generator, example_input, example_target)\n", "\n", "        # save the weights every 50 epochs\n", "        if (epoch + 1) % 50 == 0:\n", "            checkpoint.save(file_prefix=checkpoint_prefix)\n", "\n", "        print('Time taken for epoch {} is {} sec\\n'.format(epoch + 1,\n", "                                                           time.time()-start))" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "4nwjUWtvakte" }, "outputs": [], "source": [ "fit(train_dataset, EPOCHS, test_dataset)" ] },
{ "cell_type": "markdown", "metadata": { "id": "OGcR8CQtakte" }, "source": [ "Finally, load the saved weights and try generating some images!" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "giCsSjr5aktf" }, "outputs": [], "source": [ "checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "B2c8-IY3aktf" }, "outputs": [], "source": [ "for inp, tar in test_dataset.take(5):\n", "    generate_images(generator, inp, tar)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "VOvkz-POaktl" }, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.12" }, "colab": { "provenance": [], "gpuType": "T4", "include_colab_link": true }, "accelerator": "GPU" }, "nbformat": 4, "nbformat_minor": 0 }