{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "view-in-github", "colab_type": "text" }, "source": [ "\"Open" ] }, { "cell_type": "markdown", "metadata": { "id": "Q4-heyrYM7nL" }, "source": [ "# Cycle-GAN" ] }, { "cell_type": "markdown", "metadata": { "id": "ls9dpE03M7nR" }, "source": [ "### 本章節內容大綱\n", "* [Build a CyclaGAN Model](#Build-a-CyclaGAN-Model)\n", "* [Cycle Loss](#Cycle-Loss)" ] }, { "cell_type": "markdown", "metadata": { "id": "FS_1kw3JM7nS" }, "source": [ "
\n", "" ] }, { "cell_type": "markdown", "metadata": { "id": "eY1k3YonM7nS" }, "source": [ "相較於conditional GAN需要一組的paired data來訓練模型,cycle GAN並不需要paired data,或是說,當我們無法取得paired data的時候,就可以使用cycle GAN。(像是要把照片的風格轉換成莫內的風格,但是莫內早已不在人世,我們就沒辦法得到對應於照片的畫作)" ] }, { "cell_type": "markdown", "metadata": { "id": "39eA5wuWM7nT" }, "source": [ "## Import" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ZYbHgsapM7nV" }, "outputs": [], "source": [ "# 使用 instance normalization 所需要的套件\n", "!pip install tensorflow-addons" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "EENUvKkkM7nW" }, "outputs": [], "source": [ "import tensorflow as tf\n", "from tensorflow.keras import Model\n", "from tensorflow.keras.layers import (\n", " Input, Conv2DTranspose, Conv2D, BatchNormalization,\n", " ReLU, LeakyReLU, Dropout, Reshape, Activation, add, concatenate\n", ")\n", "from tensorflow.keras.initializers import glorot_uniform\n", "from tensorflow_addons.layers import InstanceNormalization\n", "\n", "import os\n", "import time\n", "import numpy as np\n", "import glob\n", "import matplotlib.pyplot as plt\n", "\n", "AUTOTUNE = tf.data.experimental.AUTOTUNE" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "eekG90CsM7nY" }, "outputs": [], "source": [ "len(tf.config.experimental.list_physical_devices('GPU'))" ] }, { "cell_type": "markdown", "metadata": { "id": "fOxSQhLOM7nY" }, "source": [ "## Preparing data" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "_8T5PPccM7nZ" }, "outputs": [], "source": [ "# 上傳資料\n", "!wget -q https://github.com/TA-aiacademy/course_3.0/releases/download/v2.5_gan/GAN_part3.zip\n", "!unzip -q GAN_part3.zip" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "eIBDIZs9M7nZ" }, "outputs": [], "source": [ "CLASS_MAP = {'trainA':0, 'trainB':1, 'testA':0, 'testB':1}\n", "def paths2labels(paths):\n", " return [CLASS_MAP[p.split(os.sep)[-2]] for p in paths]\n", "\n", "# 影像讀取 & resize\n", "def load_image(path):\n", " image = tf.io.read_file(path)\n", " image = tf.image.decode_jpeg(image, channels=3)\n", " image = tf.image.resize(image, [256, 256])\n", " return image\n", "\n", "# 使用路徑建構 tf.data.Dataset\n", "def build_ds(paths):\n", " labels = paths2labels(paths) # paths -> labels\n", " image_ds = tf.data.Dataset.from_tensor_slices((paths, labels))\n", " image_ds = image_ds.map(lambda path, label: (load_image(path), label)) # path -> img, labels\n", " image_ds = image_ds.prefetch(AUTOTUNE)\n", " return image_ds" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "A6Gxe_lxM7nZ" }, "outputs": [], "source": [ "image_dir = glob.glob('summer2winter_yosemite/*/*.jpg')\n", "metadata_dict = dict()\n", "\n", "for i in image_dir:\n", " _, dirs, files = i.split('/')\n", " if dirs not in metadata_dict:\n", " metadata_dict[dirs] = [i]\n", " else:\n", " metadata_dict[dirs].append(i)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mv25dbWIM7na" }, "outputs": [], "source": [ "train_summer, train_winter = build_ds(metadata_dict['trainA']), build_ds(metadata_dict['trainB'])\n", "test_summer, test_winter = build_ds(metadata_dict['testA']), build_ds(metadata_dict['testB'])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "5u2xB4PRM7na" }, "outputs": [], "source": [ "BUFFER_SIZE = 1000\n", "BATCH_SIZE = 4\n", "IMG_WIDTH = 256\n", "IMG_HEIGHT = 256\n", "\n", "LAMBDA_cycle = 10 # cycle loss 的權重\n", "LAMBDA_identity = 0.5 # identity loss 的權重\n", "\n", "EPOCHS = 5\n", "epoch_decay = 100 # 在 100 個 epoches 後 weight decay" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "lAFlRHisM7nb" }, "outputs": [], "source": [ "def random_crop(image):\n", " cropped_image = tf.image.random_crop(\n", " image, size=[IMG_HEIGHT, IMG_WIDTH, 3])\n", "\n", " return cropped_image\n", "\n", "\n", "# 將圖片正規化到 [-1, 1]\n", "def normalize(image):\n", " image = tf.cast(image, tf.float32)\n", " image = (image / 127.5) - 1\n", " return image\n", "\n", "\n", "def random_jitter(image):\n", "\n", " # resize 到 286 x 286 x 3\n", " image = tf.image.resize(image, [286, 286],\n", " method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)\n", " # 隨機 crop 到 256 x 256 x 3\n", " image = random_crop(image)\n", "\n", " # 隨機翻轉\n", " image = tf.image.random_flip_left_right(image)\n", "\n", " return image" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Isdgvj7fM7nb" }, "outputs": [], "source": [ "def preprocess_image_train(image, label):\n", " # 'label' 這個參數是為了接原本圖片有多一個label的維度(這邊不需要使用label)\n", " image = random_jitter(image)\n", " image = normalize(image)\n", " return image\n", "\n", "\n", "def preprocess_image_test(image, label):\n", " image = normalize(image)\n", " return image" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "GQSmMKewM7nb" }, "outputs": [], "source": [ "# cache 可以將 data 先讀入 memory 加快速度\n", "# num_parallel_calls 是一次準備多少圖片一起處理,AUTOTUNE 是我們前面有定義的變數,他可以最佳化到底要讀多少圖的這個參數\n", "\n", "train_summer = train_summer.map(\n", " preprocess_image_train, num_parallel_calls=AUTOTUNE).cache().shuffle(\n", " BUFFER_SIZE).batch(BATCH_SIZE)\n", "\n", "train_winter = train_winter.map(\n", " preprocess_image_train, num_parallel_calls=AUTOTUNE).cache().shuffle(\n", " BUFFER_SIZE).batch(BATCH_SIZE)\n", "\n", "test_summer = test_summer.map(\n", " preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(\n", " BUFFER_SIZE).batch(1)\n", "\n", "test_winter = test_winter.map(\n", " preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(\n", " BUFFER_SIZE).batch(1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Mdv8dYj6M7nc" }, "outputs": [], "source": [ "sample_summer = next(iter(train_summer))[0]\n", "sample_winter = next(iter(train_winter))[0]\n", "\n", "plt.figure(figsize=(15,15))\n", "plt.subplot(2,2,1)\n", "plt.title('Yosemite in summer')\n", "plt.imshow(sample_summer * 0.5 + 0.5)\n", "\n", "plt.subplot(2,2,2)\n", "plt.title('Yosemite in summer with random jitter')\n", "plt.imshow(random_jitter(sample_summer) * 0.5 + 0.5)\n", "\n", "plt.subplot(2,2,3)\n", "plt.title('Yosemite in winter')\n", "plt.imshow(sample_winter * 0.5 + 0.5)\n", "\n", "plt.subplot(2,2,4)\n", "plt.title('Yosemite in winter with random jitter')\n", "plt.imshow(random_jitter(sample_winter) * 0.5 + 0.5)" ] }, { "cell_type": "markdown", "metadata": { "id": "hPjKtW11M7nc" }, "source": [ "## Build a CycleGAN Model" ] }, { "cell_type": "markdown", "metadata": { "id": "QF5zqpBxM7nc" }, "source": [ "與Pix2Pix不一樣,CycleGAN因為是unpaired data的訓練,所以需要各兩個Generator與Discriminator一起訓練,\n", "\n", "兩個Generators : G 與 F ,的任務分別是將X Domain轉成Y Domain與將Y Domain轉成X Domain,\n", "\n", "而Discriminators的任務也一樣,要判別這兩個Domain的資料到底是真實的還是生成的。\n", "##### 模型的結構上與Pix2Pix不同的部分有:\n", "- Normalize的部分使用了Instance Normalization而非Batch Normalization(因為Batch_size = 1)\n", "- 整體的結構使用了resnet為基底的Generator而非Unet\n", "\n", "\n", "
\n", "" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "DQ5fOMdIM7nd" }, "outputs": [], "source": [ "def identity_block(X, f, filters):\n", "\n", " # 儲存 input 來作 skip connection\n", " X_shortcut = X\n", "\n", " X = tf.pad(X, [[0, 0], [1, 1], [1, 1], [0, 0]], \"REFLECT\")\n", " X = Conv2D(filters=filters, kernel_size=(f, f), strides=(1, 1),\n", " padding='valid', use_bias=False)(X)\n", " X = InstanceNormalization()(X)\n", " X = ReLU()(X)\n", "\n", " X = tf.pad(X, [[0, 0], [1, 1], [1, 1], [0, 0]], \"REFLECT\")\n", " X = Conv2D(filters=filters, kernel_size=(f, f), strides=(1, 1),\n", " padding='valid', use_bias=False)(X)\n", " X = InstanceNormalization()(X)\n", "\n", " # 將 output 與 input 合在一起\n", " X = add([X, X_shortcut])\n", "\n", " return X\n", "\n", "\n", "def Conv(X, f, s, filters, leaky=False, padding='same'):\n", "\n", " if padding == 'same':\n", " conv = Conv2D(filters, (f, f), padding='same', use_bias=False, strides=(s, s))(X)\n", " elif padding == 'valid':\n", " conv = Conv2D(filters, (f, f), padding='valid', use_bias=False, strides=(s, s))(X)\n", "\n", " conv = InstanceNormalization()(conv)\n", "\n", " if leaky:\n", " conv = LeakyReLU(alpha=0.2)(conv)\n", " else:\n", " conv = ReLU()(conv)\n", " return conv" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "6I7bTyFyM7nd" }, "outputs": [], "source": [ "def ResGen(input_size=(256, 256, 3), n_blocks=9):\n", "\n", " inputs = Input(input_size)\n", " # ======================Encoder======================\n", "\n", " # batch 與 channel 不需要 padding\n", " X = tf.pad(inputs, [[0, 0], [3, 3], [3, 3], [0, 0]], \"REFLECT\")\n", "\n", " X = Conv(X, 7, 1, 64, leaky=False, padding='valid') # (256,256,64)\n", "\n", " X = Conv(X, 3, 2, 128, leaky=False) # (128,128,128)\n", "\n", " X = Conv(X, 3, 2, 256, leaky=False) # (64,64,256)\n", "\n", " # ====================Encoder end====================\n", "\n", " for i in range(n_blocks):\n", " X = identity_block(X, 3, 256)\n", "\n", " # ====================Decoder========================\n", "\n", " X = Conv2DTranspose(128, (3, 3), strides=(2, 2),\n", " padding='same', use_bias=False)(X) # (128,128,128)\n", " X = InstanceNormalization()(X)\n", " X = ReLU()(X)\n", "\n", " X = Conv2DTranspose(64, (3, 3), strides=(2, 2),\n", " padding='same', use_bias=False)(X) # (256,256,64)\n", " X = InstanceNormalization()(X)\n", " X = ReLU()(X)\n", "\n", " X = tf.pad(X, [[0, 0], [3, 3], [3, 3], [0, 0]], \"REFLECT\")\n", " X = Conv2D(3, (7, 7), padding='valid', activation='tanh')(X) # (256,256,3)\n", "\n", " # ======================Decoder end==========================\n", "\n", " return Model(inputs=[inputs], outputs=[X])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "zqQ5pEpQM7nd" }, "outputs": [], "source": [ "def Disc(input_size=(256, 256, 3), n_blocks=1):\n", " inputs = Input(input_size)\n", "\n", " # ======================Encoder======================\n", "\n", " X = Conv2D(64, (4, 4), strides=2, padding='same')(inputs)\n", " X = LeakyReLU(alpha=0.2)(X)\n", "\n", " X = Conv(X, 4, 2, 128, leaky=True)\n", "\n", " X = Conv(X, 4, 2, 256, leaky=True)\n", "\n", " X = Conv(X, 4, 1, 512, leaky=True)\n", "\n", " # ======================Encoder end======================\n", "\n", " X = Conv2D(1, (4, 4), padding='same', activation=None)(X)\n", "\n", " return Model(inputs=[inputs], outputs=[X])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ZpSdWWCBM7nd" }, "outputs": [], "source": [ "generator_g = ResGen()\n", "generator_f = ResGen()\n", "\n", "discriminator_x = Disc()\n", "discriminator_y = Disc()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "6ITK1Az4M7ne" }, "outputs": [], "source": [ "# 檢查 Generator 的 output\n", "to_winter = generator_g(sample_summer[tf.newaxis, ...])[0]\n", "to_summer = generator_f(sample_winter[tf.newaxis, ...])[0]\n", "plt.figure(figsize=(15, 15))\n", "contrast = 8\n", "\n", "imgs = [sample_summer, to_winter, sample_winter, to_summer]\n", "title = ['Summer', 'To Winter', 'Winter', 'To Summer']\n", "\n", "# 畫出 generator 的 output\n", "for i in range(len(imgs)):\n", " plt.subplot(2, 2, i+1)\n", " plt.title(title[i])\n", " if i % 2 == 0:\n", " plt.imshow(imgs[i] * 0.5 + 0.5) # rescale value to [0-1]\n", " else:\n", " plt.imshow(imgs[i] * 0.5 * contrast + 0.5)\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "AnfQaH0OM7ne" }, "outputs": [], "source": [ "# 檢查 Discriminator 的 output\n", "# 越紅越接近 0,越藍越接近 1\n", "plt.figure(figsize=(15, 15))\n", "\n", "plt.subplot(121)\n", "plt.title('Is a real summer?')\n", "plt.imshow(discriminator_x(sample_summer[tf.newaxis, ...])[0, ..., -1], cmap='RdBu_r')\n", "\n", "plt.subplot(122)\n", "plt.title('Is a real winter?')\n", "plt.imshow(discriminator_y(sample_winter[tf.newaxis, ...])[0, ..., -1], cmap='RdBu_r')\n", "\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "id": "fphy6kGBM7ne" }, "source": [ "## Cycle Loss" ] }, { "cell_type": "markdown", "metadata": { "id": "DluQgPviM7nf" }, "source": [ "在CycleGAN裡面,由於沒有paired data的緣故,所以我們無法保證從generator output出來的圖片是原圖轉換風格還是只是隨便一張能騙過Discriminator的圖。所以為了要強制generator能夠學到另一個風格並且還能維持原本的圖片語意,作者提出了 cycle consistency loss。簡單來說,就是把input $X$ 丟進 $X\\rightarrow Y$ 的generator再丟進 $Y\\rightarrow X$ 的generator最後得到 $\\hat{X}$ ,計算 $X$ 與 $\\hat{X}$ 之間的差異就是cycle consistency loss。\n", "\n", "
\n", "" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "xv9nYb4QM7nf" }, "outputs": [], "source": [ "# 這裡使用 lsgan 的 loss function\n", "loss_obj = tf.losses.MeanSquaredError()\n", "\n", "\n", "def discriminator_loss(real, generated):\n", " real_loss = loss_obj(tf.ones_like(real), real)\n", "\n", " generated_loss = loss_obj(tf.zeros_like(generated), generated)\n", "\n", " total_disc_loss = real_loss + generated_loss\n", "\n", " return total_disc_loss\n", "\n", "\n", "def generator_loss(generated):\n", " return loss_obj(tf.ones_like(generated), generated)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "WSch1u07M7nf" }, "outputs": [], "source": [ "# Cycle loss 是將 X 轉為 fake_Y 後,再將 fake_Y 轉回 X_cycle\n", "# 計算 X 與 X_cycle 的差異來得到 cycle consistency loss\n", "def calc_cycle_loss(real_image, cycled_image):\n", " loss = tf.reduce_mean(tf.abs(real_image - cycled_image))\n", "\n", " return LAMBDA_cycle * loss\n", "\n", "\n", "# Identity loss 即將 X 丟進 Y -> X 的 generator 後得到的 X_hat,應該也要跟 X 一樣而非胡亂轉換,\n", "# 效果類似提供一個 identity mapping 的參考\n", "def identity_loss(real_image, same_image):\n", " loss = tf.reduce_mean(tf.abs(real_image - same_image))\n", " return LAMBDA_identity * loss" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "HP804DRuM7nf" }, "outputs": [], "source": [ "class LinearDecay(tf.keras.optimizers.schedules.LearningRateSchedule):\n", " # 如果目前的 step < step_dacay 就使用原本的 learning rate\n", " # 否則會以線性的衰退來挑整 learning rate 直至 0\n", "\n", " def __init__(self, initial_learning_rate, total_steps, step_decay):\n", " super(LinearDecay, self).__init__()\n", " self._initial_learning_rate = initial_learning_rate\n", " self._steps = total_steps\n", " self._step_decay = step_decay\n", " self.current_learning_rate = tf.Variable(initial_value=initial_learning_rate,\n", " trainable=False, dtype=tf.float32)\n", "\n", " def __call__(self, step):\n", " self.current_learning_rate.assign(tf.cond(\n", " step >= self._step_decay,\n", " true_fn=lambda: self._initial_learning_rate * (1 - 1 / (self._steps - self._step_decay) *\n", " (step - self._step_decay)),\n", " false_fn=lambda: self._initial_learning_rate\n", " ))\n", " return self.current_learning_rate" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "4CCs7IJtM7ng" }, "outputs": [], "source": [ "lr_scheduler = LinearDecay(2e-4, EPOCHS * BUFFER_SIZE, epoch_decay * BUFFER_SIZE)\n", "\n", "generator_g_optimizer = tf.keras.optimizers.legacy.Adam(lr_scheduler, beta_1=0.5)\n", "generator_f_optimizer = tf.keras.optimizers.legacy.Adam(lr_scheduler, beta_1=0.5)\n", "\n", "discriminator_x_optimizer = tf.keras.optimizers.legacy.Adam(lr_scheduler, beta_1=0.5)\n", "discriminator_y_optimizer = tf.keras.optimizers.legacy.Adam(lr_scheduler, beta_1=0.5)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Oa6bsdRBM7ng" }, "outputs": [], "source": [ "checkpoint_path = \"./training_checkpoints\"\n", "\n", "ckpt = tf.train.Checkpoint(generator_g=generator_g,\n", " generator_f=generator_f,\n", " discriminator_x=discriminator_x,\n", " discriminator_y=discriminator_y,\n", " generator_g_optimizer=generator_g_optimizer,\n", " generator_f_optimizer=generator_f_optimizer,\n", " discriminator_x_optimizer=discriminator_x_optimizer,\n", " discriminator_y_optimizer=discriminator_y_optimizer)\n", "\n", "ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5) # 最多存 5 組 weight\n", "\n", "# 如果有 checkpoint ,載入其權重\n", "if ckpt_manager.latest_checkpoint:\n", " ckpt.restore(ckpt_manager.latest_checkpoint)\n", " print('Latest checkpoint restored!!')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "u08a23oCM7ng" }, "outputs": [], "source": [ "def generate_images(model, test_input):\n", " prediction = model(test_input)\n", "\n", " plt.figure(figsize=(12, 12))\n", "\n", " display_list = [test_input[0], prediction[0]]\n", " title = ['Input Image', 'Predicted Image']\n", "\n", " for i in range(2):\n", " plt.subplot(1, 2, i+1)\n", " plt.title(title[i])\n", " plt.imshow(display_list[i] * 0.5 + 0.5)\n", " plt.axis('off')\n", " plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "NYAzHdbcM7ng" }, "outputs": [], "source": [ "@tf.function\n", "def train_step(real_x, real_y):\n", "\n", " # 當 tape 被拿來計算 graident 就會被記憶體釋放,persistent 可以避免被釋放\n", " # 通常用於計算多個 gradient 時使用\n", " with tf.GradientTape(persistent=True) as tape:\n", "\n", " # Generator G translates X -> Y\n", " # Generator F translates Y -> X\n", "\n", " fake_y = generator_g(real_x, training=True)\n", " cycled_x = generator_f(fake_y, training=True)\n", "\n", " fake_x = generator_f(real_y, training=True)\n", " cycled_y = generator_g(fake_x, training=True)\n", "\n", " # output discriminator logits\n", " disc_real_x = discriminator_x(real_x, training=True)\n", " disc_real_y = discriminator_y(real_y, training=True)\n", "\n", " disc_fake_x = discriminator_x(fake_x, training=True)\n", " disc_fake_y = discriminator_y(fake_y, training=True)\n", "\n", " # same_x 、 same_y 用來計算 identity loss\n", " same_x = generator_f(real_x, training=True)\n", " same_y = generator_g(real_y, training=True)\n", "\n", " # 計算 cycle loss\n", " total_cycle_loss = calc_cycle_loss(real_x, cycled_x) + calc_cycle_loss(real_y, cycled_y)\n", "\n", " # Total generator loss = adversarial loss + cycle loss + identity loss\n", " gen_g_loss = generator_loss(disc_fake_y)\n", " gen_f_loss = generator_loss(disc_fake_x)\n", "\n", " total_gen_g_loss = gen_g_loss + total_cycle_loss + identity_loss(real_y, same_y)\n", " total_gen_f_loss = gen_f_loss + total_cycle_loss + identity_loss(real_x, same_x)\n", "\n", " disc_x_loss = discriminator_loss(disc_real_x, disc_fake_x)\n", " disc_y_loss = discriminator_loss(disc_real_y, disc_fake_y)\n", "\n", " # 分別計算 gradients\n", " generator_g_gradients = tape.gradient(total_gen_g_loss,\n", " generator_g.trainable_variables)\n", " generator_f_gradients = tape.gradient(total_gen_f_loss,\n", " generator_f.trainable_variables)\n", "\n", " discriminator_x_gradients = tape.gradient(disc_x_loss,\n", " discriminator_x.trainable_variables)\n", " discriminator_y_gradients = tape.gradient(disc_y_loss,\n", " discriminator_y.trainable_variables)\n", "\n", " generator_g_optimizer.apply_gradients(zip(generator_g_gradients,\n", " generator_g.trainable_variables))\n", "\n", " generator_f_optimizer.apply_gradients(zip(generator_f_gradients,\n", " generator_f.trainable_variables))\n", "\n", " discriminator_x_optimizer.apply_gradients(zip(discriminator_x_gradients,\n", " discriminator_x.trainable_variables))\n", "\n", " discriminator_y_optimizer.apply_gradients(zip(discriminator_y_gradients,\n", " discriminator_y.trainable_variables))\n", "\n", " return total_gen_g_loss, total_gen_f_loss, disc_y_loss, disc_x_loss" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "HBVmkxBjM7nh" }, "outputs": [], "source": [ "for epoch in range(EPOCHS):\n", " start = time.time()\n", "\n", " n = 0\n", " for image_x, image_y in tf.data.Dataset.zip((train_summer, train_winter)):\n", " s2w_Gloss, w2s_Gloss, y_Dloss, x_Dloss = train_step(image_x, image_y)\n", " if n % 10 == 0:\n", " print('.', end='')\n", " n+=1\n", "\n", " # 利用一樣的照片 (sample_summer) 來觀察模型學習的效果\n", " if (epoch + 1) % 2 == 0:\n", " generate_images(generator_g, sample_summer[tf.newaxis, ...])\n", "\n", " if (epoch + 1) % 5 == 0:\n", " ckpt_save_path = ckpt_manager.save()\n", " print('')\n", " print('Saving checkpoint for epoch {} at {}'.format(epoch+1,\n", " ckpt_save_path))\n", " print('')\n", " print('Summer2Winter G_loss: %.5f, Winter2Summer G_loss: %.5f' % (s2w_Gloss, w2s_Gloss))\n", " print('Summer2Winter D_loss: %.5f, Winter2Summer D_loss: %.5f' % (y_Dloss, x_Dloss))\n", " print('Time taken for epoch {} is {} sec\\n'.format(epoch + 1,\n", " time.time()-start))" ] }, { "cell_type": "markdown", "metadata": { "id": "vKyuDg1bM7nh" }, "source": [ "## Check result(After 200 epoches training, ~12hrs)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "qSizuQjAM7nn" }, "outputs": [], "source": [ "# summer to winter\n", "for inp in test_summer.take(5):\n", " generate_images(generator_g, inp)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mkj9BpINM7nn" }, "outputs": [], "source": [ "# winter to summer\n", "for inp in test_winter.take(5):\n", " generate_images(generator_f, inp)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ZCc878Y3M7nn" }, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.12" }, "colab": { "provenance": [], "gpuType": "T4", "include_colab_link": true }, "accelerator": "GPU" }, "nbformat": 4, "nbformat_minor": 0 }