{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "Q4-heyrYM7nL"
},
"source": [
"# Cycle-GAN"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ls9dpE03M7nR"
},
"source": [
"### Chapter outline\n",
"* [Build a CycleGAN Model](#Build-a-CycleGAN-Model)\n",
"* [Cycle Loss](#Cycle-Loss)"
]
},
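{
"cell_type": "markdown",
"metadata": {
"id": "eY1k3YonM7nS"
},
"source": [
"Whereas a conditional GAN needs paired data to train, CycleGAN does not. In other words, CycleGAN is an option when paired data cannot be obtained. (For example, to convert photos into Monet's style: Monet died long ago, so there is no way to obtain paintings that correspond to those photos.)"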
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "39eA5wuWM7nT"
},
"source": [
"## Import"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ZYbHgsapM7nV"
},
"outputs": [],
"source": [
"# Package required for instance normalization\n",
"!pip install tensorflow-addons"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "EENUvKkkM7nW"
},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"from tensorflow.keras import Model\n",
"from tensorflow.keras.layers import (\n",
"    Input, Conv2DTranspose, Conv2D, BatchNormalization,\n",
"    ReLU, LeakyReLU, Dropout, Reshape, Activation, add, concatenate\n",
")\n",
"from tensorflow.keras.initializers import glorot_uniform\n",
"from tensorflow_addons.layers import InstanceNormalization\n",
"\n",
"import os\n",
"import time\n",
"import numpy as np\n",
"import glob\n",
"import matplotlib.pyplot as plt\n",
"\n",
"AUTOTUNE = tf.data.experimental.AUTOTUNE"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "eekG90CsM7nY"
},
"outputs": [],
"source": [
"len(tf.config.experimental.list_physical_devices('GPU'))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fOxSQhLOM7nY"
},
"source": [
"## Preparing data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "_8T5PPccM7nZ"
},
"outputs": [],
"source": [
"# Download and unzip the data\n",
"!wget -q https://github.com/TA-aiacademy/course_3.0/releases/download/v2.5_gan/GAN_part3.zip\n",
"!unzip -q GAN_part3.zip"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "eIBDIZs9M7nZ"
},
"outputs": [],
"source": [
"CLASS_MAP = {'trainA': 0, 'trainB': 1, 'testA': 0, 'testB': 1}\n",
"\n",
"\n",
"def paths2labels(paths):\n",
"    # The parent directory name determines the label\n",
"    return [CLASS_MAP[p.split(os.sep)[-2]] for p in paths]\n",
"\n",
"\n",
"# Read an image and resize it\n",
"def load_image(path):\n",
"    image = tf.io.read_file(path)\n",
"    image = tf.image.decode_jpeg(image, channels=3)\n",
"    image = tf.image.resize(image, [256, 256])\n",
"    return image\n",
"\n",
"\n",
"# Build a tf.data.Dataset from file paths\n",
"def build_ds(paths):\n",
"    labels = paths2labels(paths)  # paths -> labels\n",
"    image_ds = tf.data.Dataset.from_tensor_slices((paths, labels))\n",
"    image_ds = image_ds.map(lambda path, label: (load_image(path), label))  # path -> (image, label)\n",
"    image_ds = image_ds.prefetch(AUTOTUNE)\n",
"    return image_ds"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "A6Gxe_lxM7nZ"
},
"outputs": [],
"source": [
"image_dir = glob.glob('summer2winter_yosemite/*/*.jpg')\n",
"metadata_dict = dict()\n",
"\n",
"for i in image_dir:\n",
"    _, dirs, files = i.split('/')\n",
"    if dirs not in metadata_dict:\n",
"        metadata_dict[dirs] = [i]\n",
"    else:\n",
"        metadata_dict[dirs].append(i)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mv25dbWIM7na"
},
"outputs": [],
"source": [
"train_summer, train_winter = build_ds(metadata_dict['trainA']), build_ds(metadata_dict['trainB'])\n",
"test_summer, test_winter = build_ds(metadata_dict['testA']), build_ds(metadata_dict['testB'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "5u2xB4PRM7na"
},
"outputs": [],
"source": [
"BUFFER_SIZE = 1000\n",
"BATCH_SIZE = 4\n",
"IMG_WIDTH = 256\n",
"IMG_HEIGHT = 256\n",
"\n",
"LAMBDA_cycle = 10  # weight of the cycle loss\n",
"LAMBDA_identity = 0.5  # weight of the identity loss\n",
"\n",
"EPOCHS = 5\n",
"epoch_decay = 100  # start decaying the learning rate after 100 epochs"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "lAFlRHisM7nb"
},
"outputs": [],
"source": [
"def random_crop(image):\n",
"    cropped_image = tf.image.random_crop(\n",
"        image, size=[IMG_HEIGHT, IMG_WIDTH, 3])\n",
"\n",
"    return cropped_image\n",
"\n",
"\n",
"# Normalize the image to [-1, 1]\n",
"def normalize(image):\n",
"    image = tf.cast(image, tf.float32)\n",
"    image = (image / 127.5) - 1\n",
"    return image\n",
"\n",
"\n",
"def random_jitter(image):\n",
"\n",
"    # Resize to 286 x 286 x 3\n",
"    image = tf.image.resize(image, [286, 286],\n",
"                            method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)\n",
"    # Randomly crop back to 256 x 256 x 3\n",
"    image = random_crop(image)\n",
"\n",
"    # Random horizontal flip\n",
"    image = tf.image.random_flip_left_right(image)\n",
"\n",
"    return image"
]
},
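{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check of `normalize`, the mapping and the inverse used later when plotting (`* 0.5 + 0.5`) can be written out as below. This is only a sketch of the arithmetic, assuming pixel values $x$ in $[0, 255]$:\n",
"\n",
"$$x' = \\frac{x}{127.5} - 1 \\in [-1, 1], \\qquad 0.5\\,x' + 0.5 = \\frac{x}{255} \\in [0, 1]$$\n",
"\n",
"For example, $x = 0$ maps to $-1$, $x = 127.5$ maps to $0$, and $x = 255$ maps to $1$."
]
},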
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Isdgvj7fM7nb"
},
"outputs": [],
"source": [
"def preprocess_image_train(image, label):\n",
"    # The 'label' argument only absorbs the label that the dataset yields with each image (it is not used here)\n",
"    image = random_jitter(image)\n",
"    image = normalize(image)\n",
"    return image\n",
"\n",
"\n",
"def preprocess_image_test(image, label):\n",
"    image = normalize(image)\n",
"    return image"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "GQSmMKewM7nb"
},
"outputs": [],
"source": [
"# cache() keeps the data in memory to speed up later epochs. For the training sets it is placed before\n",
"# map() so that the random jitter is re-applied every epoch instead of being cached once.\n",
"# num_parallel_calls sets how many elements are processed in parallel; AUTOTUNE (defined above) lets tf.data pick the value.\n",
"\n",
"train_summer = train_summer.cache().map(\n",
"    preprocess_image_train, num_parallel_calls=AUTOTUNE).shuffle(\n",
"    BUFFER_SIZE).batch(BATCH_SIZE)\n",
"\n",
"train_winter = train_winter.cache().map(\n",
"    preprocess_image_train, num_parallel_calls=AUTOTUNE).shuffle(\n",
"    BUFFER_SIZE).batch(BATCH_SIZE)\n",
"\n",
"test_summer = test_summer.map(\n",
"    preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(\n",
"    BUFFER_SIZE).batch(1)\n",
"\n",
"test_winter = test_winter.map(\n",
"    preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(\n",
"    BUFFER_SIZE).batch(1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Mdv8dYj6M7nc"
},
"outputs": [],
"source": [
"sample_summer = next(iter(train_summer))[0]\n",
"sample_winter = next(iter(train_winter))[0]\n",
"\n",
"plt.figure(figsize=(15,15))\n",
"plt.subplot(2,2,1)\n",
"plt.title('Yosemite in summer')\n",
"plt.imshow(sample_summer * 0.5 + 0.5)\n",
"\n",
"plt.subplot(2,2,2)\n",
"plt.title('Yosemite in summer with random jitter')\n",
"plt.imshow(random_jitter(sample_summer) * 0.5 + 0.5)\n",
"\n",
"plt.subplot(2,2,3)\n",
"plt.title('Yosemite in winter')\n",
"plt.imshow(sample_winter * 0.5 + 0.5)\n",
"\n",
"plt.subplot(2,2,4)\n",
"plt.title('Yosemite in winter with random jitter')\n",
"plt.imshow(random_jitter(sample_winter) * 0.5 + 0.5)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hPjKtW11M7nc"
},
"source": [
"## Build a CycleGAN Model"
]
},
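{
"cell_type": "markdown",
"metadata": {
"id": "QF5zqpBxM7nc"
},
"source": [
"Unlike Pix2Pix, CycleGAN trains on unpaired data, so it needs two generators and two discriminators trained together.\n",
"\n",
"The two generators, G and F, translate domain X to domain Y and domain Y to domain X, respectively.\n",
"\n",
"The discriminators do the same job as before: each one judges whether data from its domain is real or generated.\n",
"##### Structural differences from Pix2Pix:\n",
"- Normalization uses Instance Normalization instead of Batch Normalization (because the batch size is assumed to be 1; note that this notebook sets `BATCH_SIZE = 4`)\n",
"- The generator is built on a ResNet backbone instead of a U-Net"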
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "DQ5fOMdIM7nd"
},
"outputs": [],
"source": [
"def identity_block(X, f, filters):\n",
"\n",
"    # Keep the input for the skip connection\n",
"    X_shortcut = X\n",
"\n",
"    X = tf.pad(X, [[0, 0], [1, 1], [1, 1], [0, 0]], \"REFLECT\")\n",
"    X = Conv2D(filters=filters, kernel_size=(f, f), strides=(1, 1),\n",
"               padding='valid', use_bias=False)(X)\n",
"    X = InstanceNormalization()(X)\n",
"    X = ReLU()(X)\n",
"\n",
"    X = tf.pad(X, [[0, 0], [1, 1], [1, 1], [0, 0]], \"REFLECT\")\n",
"    X = Conv2D(filters=filters, kernel_size=(f, f), strides=(1, 1),\n",
"               padding='valid', use_bias=False)(X)\n",
"    X = InstanceNormalization()(X)\n",
"\n",
"    # Add the block output to its input\n",
"    X = add([X, X_shortcut])\n",
"\n",
"    return X\n",
"\n",
"\n",
"def Conv(X, f, s, filters, leaky=False, padding='same'):\n",
"\n",
"    if padding == 'same':\n",
"        conv = Conv2D(filters, (f, f), padding='same', use_bias=False, strides=(s, s))(X)\n",
"    elif padding == 'valid':\n",
"        conv = Conv2D(filters, (f, f), padding='valid', use_bias=False, strides=(s, s))(X)\n",
"\n",
"    conv = InstanceNormalization()(conv)\n",
"\n",
"    if leaky:\n",
"        conv = LeakyReLU(alpha=0.2)(conv)\n",
"    else:\n",
"        conv = ReLU()(conv)\n",
"    return conv"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "6I7bTyFyM7nd"
},
"outputs": [],
"source": [
"def ResGen(input_size=(256, 256, 3), n_blocks=9):\n",
"\n",
"    inputs = Input(input_size)\n",
"    # ======================Encoder======================\n",
"\n",
"    # The batch and channel axes are not padded\n",
"    X = tf.pad(inputs, [[0, 0], [3, 3], [3, 3], [0, 0]], \"REFLECT\")\n",
"\n",
"    X = Conv(X, 7, 1, 64, leaky=False, padding='valid')  # (256,256,64)\n",
"\n",
"    X = Conv(X, 3, 2, 128, leaky=False)  # (128,128,128)\n",
"\n",
"    X = Conv(X, 3, 2, 256, leaky=False)  # (64,64,256)\n",
"\n",
"    # ====================Encoder end====================\n",
"\n",
"    for i in range(n_blocks):\n",
"        X = identity_block(X, 3, 256)\n",
"\n",
"    # ====================Decoder========================\n",
"\n",
"    X = Conv2DTranspose(128, (3, 3), strides=(2, 2),\n",
"                        padding='same', use_bias=False)(X)  # (128,128,128)\n",
"    X = InstanceNormalization()(X)\n",
"    X = ReLU()(X)\n",
"\n",
"    X = Conv2DTranspose(64, (3, 3), strides=(2, 2),\n",
"                        padding='same', use_bias=False)(X)  # (256,256,64)\n",
"    X = InstanceNormalization()(X)\n",
"    X = ReLU()(X)\n",
"\n",
"    X = tf.pad(X, [[0, 0], [3, 3], [3, 3], [0, 0]], \"REFLECT\")\n",
"    X = Conv2D(3, (7, 7), padding='valid', activation='tanh')(X)  # (256,256,3)\n",
"\n",
"    # ======================Decoder end==========================\n",
"\n",
"    return Model(inputs=[inputs], outputs=[X])"
]
},
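{
"cell_type": "markdown",
"metadata": {},
"source": [
"The shape comments in `ResGen` can be checked with the usual output-size rules. The following is a sketch, assuming a $256 \\times 256$ input, where $i$ is the input size, $k$ the kernel size, $s$ the stride, and $p$ the reflection padding added on each side:\n",
"\n",
"$$o_{\\text{valid}} = \\left\\lfloor \\frac{i + 2p - k}{s} \\right\\rfloor + 1, \\qquad o_{\\text{same}} = \\left\\lceil \\frac{i}{s} \\right\\rceil, \\qquad o_{\\text{transpose, same}} = i \\cdot s$$\n",
"\n",
"- First $7 \\times 7$ convolution ($p = 3$, `valid`): $256 + 6 - 7 + 1 = 256$\n",
"- Two stride-2 convolutions (`same`): $256 \\rightarrow 128 \\rightarrow 64$\n",
"- Each residual block ($p = 1$, $k = 3$, `valid`): $64 + 2 - 3 + 1 = 64$\n",
"- Two stride-2 transposed convolutions (`same`): $64 \\rightarrow 128 \\rightarrow 256$\n",
"- Final $7 \\times 7$ convolution ($p = 3$, `valid`): $256 + 6 - 7 + 1 = 256$"
]
},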
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "zqQ5pEpQM7nd"
},
"outputs": [],
"source": [
"def Disc(input_size=(256, 256, 3), n_blocks=1):\n",
"    inputs = Input(input_size)\n",
"\n",
"    # ======================Encoder======================\n",
"\n",
"    X = Conv2D(64, (4, 4), strides=2, padding='same')(inputs)\n",
"    X = LeakyReLU(alpha=0.2)(X)\n",
"\n",
"    X = Conv(X, 4, 2, 128, leaky=True)\n",
"\n",
"    X = Conv(X, 4, 2, 256, leaky=True)\n",
"\n",
"    X = Conv(X, 4, 1, 512, leaky=True)\n",
"\n",
"    # ======================Encoder end======================\n",
"\n",
"    X = Conv2D(1, (4, 4), padding='same', activation=None)(X)\n",
"\n",
"    return Model(inputs=[inputs], outputs=[X])"
]
},
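{
"cell_type": "markdown",
"metadata": {},
"source": [
"Because every convolution in `Disc` uses `same` padding, the spatial size after each layer follows $o = \\lceil i / s \\rceil$. As a sketch, assuming a $256 \\times 256$ input and the strides $2, 2, 2, 1, 1$ used above:\n",
"\n",
"$$256 \\rightarrow 128 \\rightarrow 64 \\rightarrow 32 \\rightarrow 32 \\rightarrow 32$$\n",
"\n",
"The discriminator therefore returns a $32 \\times 32 \\times 1$ map of scores rather than a single scalar, so each value judges a region of the input instead of the whole image."
]
},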
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ZpSdWWCBM7nd"
},
"outputs": [],
"source": [
"generator_g = ResGen()\n",
"generator_f = ResGen()\n",
"\n",
"discriminator_x = Disc()\n",
"discriminator_y = Disc()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "6ITK1Az4M7ne"
},
"outputs": [],
"source": [
"# Inspect the Generator output\n",
"to_winter = generator_g(sample_summer[tf.newaxis, ...])[0]\n",
"to_summer = generator_f(sample_winter[tf.newaxis, ...])[0]\n",
"plt.figure(figsize=(15, 15))\n",
"contrast = 8\n",
"\n",
"imgs = [sample_summer, to_winter, sample_winter, to_summer]\n",
"title = ['Summer', 'To Winter', 'Winter', 'To Summer']\n",
"\n",
"# Plot the generator outputs\n",
"for i in range(len(imgs)):\n",
"    plt.subplot(2, 2, i+1)\n",
"    plt.title(title[i])\n",
"    if i % 2 == 0:\n",
"        plt.imshow(imgs[i] * 0.5 + 0.5)  # rescale values to [0, 1]\n",
"    else:\n",
"        plt.imshow(imgs[i] * 0.5 * contrast + 0.5)\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "AnfQaH0OM7ne"
},
"outputs": [],
"source": [
"# Inspect the Discriminator output\n",
"# With cmap='RdBu_r', lower values are drawn in blue and higher values in red\n",
"plt.figure(figsize=(15, 15))\n",
"\n",
"plt.subplot(121)\n",
"plt.title('Is a real summer?')\n",
"plt.imshow(discriminator_x(sample_summer[tf.newaxis, ...])[0, ..., -1], cmap='RdBu_r')\n",
"\n",
"plt.subplot(122)\n",
"plt.title('Is a real winter?')\n",
"plt.imshow(discriminator_y(sample_winter[tf.newaxis, ...])[0, ..., -1], cmap='RdBu_r')\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fphy6kGBM7ne"
},
"source": [
"## Cycle Loss"
]
},
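{
"cell_type": "markdown",
"metadata": {
"id": "DluQgPviM7nf"
},
"source": [
"Because CycleGAN has no paired data, we cannot guarantee that the image produced by a generator is the original image with its style converted rather than just any image that fools the discriminator. To force the generator to learn the other style while still preserving the content of the original image, the authors proposed the cycle consistency loss. Simply put, the input $X$ is passed through the $X\\rightarrow Y$ generator and then through the $Y\\rightarrow X$ generator to obtain $\\hat{X}$; the difference between $X$ and $\\hat{X}$ is the cycle consistency loss."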
]
},
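{
"cell_type": "markdown",
"metadata": {},
"source": [
"With $G: X \\rightarrow Y$ and $F: Y \\rightarrow X$, the idea above is commonly written as an L1 objective over both directions. This is a sketch of the standard formulation; in this notebook it is implemented by `calc_cycle_loss` as a mean absolute difference scaled by `LAMBDA_cycle`:\n",
"\n",
"$$\\mathcal{L}_{cyc}(G, F) = \\mathbb{E}_{x}\\big[\\lVert F(G(x)) - x \\rVert_1\\big] + \\mathbb{E}_{y}\\big[\\lVert G(F(y)) - y \\rVert_1\\big]$$\n",
"\n",
"The first term is the $X \\rightarrow Y \\rightarrow X$ cycle and the second is the $Y \\rightarrow X \\rightarrow Y$ cycle."
]
},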
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "xv9nYb4QM7nf"
},
"outputs": [],
"source": [
"# Use the LSGAN (least-squares) loss function here\n",
"loss_obj = tf.losses.MeanSquaredError()\n",
"\n",
"\n",
"def discriminator_loss(real, generated):\n",
"    real_loss = loss_obj(tf.ones_like(real), real)\n",
"\n",
"    generated_loss = loss_obj(tf.zeros_like(generated), generated)\n",
"\n",
"    total_disc_loss = real_loss + generated_loss\n",
"\n",
"    return total_disc_loss\n",
"\n",
"\n",
"def generator_loss(generated):\n",
"    return loss_obj(tf.ones_like(generated), generated)"
]
},
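{
"cell_type": "markdown",
"metadata": {},
"source": [
"The two functions above correspond to the least-squares objectives sketched below, written for the pair $(G, D_Y)$; the pair $(F, D_X)$ is symmetric. The expectation is taken over images and, in this notebook, also over every position of the discriminator's output map:\n",
"\n",
"$$\\mathcal{L}_{D_Y} = \\mathbb{E}_{y}\\big[(D_Y(y) - 1)^2\\big] + \\mathbb{E}_{x}\\big[D_Y(G(x))^2\\big]$$\n",
"\n",
"$$\\mathcal{L}_{G}^{adv} = \\mathbb{E}_{x}\\big[(D_Y(G(x)) - 1)^2\\big]$$\n",
"\n",
"The discriminator is pushed toward 1 on real images and 0 on generated ones, while the generator is pushed to make the discriminator output 1 on its samples."
]
},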
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "WSch1u07M7nf"
},
"outputs": [],
"source": [
"# Cycle loss: translate X into fake_Y, then translate fake_Y back into X_cycle.\n",
"# The difference between X and X_cycle is the cycle consistency loss.\n",
"def calc_cycle_loss(real_image, cycled_image):\n",
"    loss = tf.reduce_mean(tf.abs(real_image - cycled_image))\n",
"\n",
"    return LAMBDA_cycle * loss\n",
"\n",
"\n",
"# Identity loss: feeding X into the Y -> X generator gives X_hat, which should stay close to X\n",
"# rather than being changed arbitrarily; in effect this provides an identity-mapping reference.\n",
"def identity_loss(real_image, same_image):\n",
"    loss = tf.reduce_mean(tf.abs(real_image - same_image))\n",
"    return LAMBDA_identity * loss"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "HP804DRuM7nf"
},
"outputs": [],
"source": [
"class LinearDecay(tf.keras.optimizers.schedules.LearningRateSchedule):\n",
"    # While the current step < step_decay, use the initial learning rate;\n",
"    # afterwards, decay the learning rate linearly until it reaches 0.\n",
"\n",
"    def __init__(self, initial_learning_rate, total_steps, step_decay):\n",
"        super(LinearDecay, self).__init__()\n",
"        self._initial_learning_rate = initial_learning_rate\n",
"        self._steps = total_steps\n",
"        self._step_decay = step_decay\n",
"        self.current_learning_rate = tf.Variable(initial_value=initial_learning_rate,\n",
"                                                 trainable=False, dtype=tf.float32)\n",
"\n",
"    def __call__(self, step):\n",
"        self.current_learning_rate.assign(tf.cond(\n",
"            step >= self._step_decay,\n",
"            true_fn=lambda: self._initial_learning_rate * (1 - 1 / (self._steps - self._step_decay) *\n",
"                                                           (step - self._step_decay)),\n",
"            false_fn=lambda: self._initial_learning_rate\n",
"        ))\n",
"        return self.current_learning_rate"
]
},
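{
"cell_type": "markdown",
"metadata": {},
"source": [
"Written as a formula, the schedule in `LinearDecay` is the piecewise function below, where $\\eta_0$ is `initial_learning_rate`, $T$ is `total_steps`, $t_d$ is `step_decay`, and $t$ is the current step. This is a sketch that assumes $T > t_d$:\n",
"\n",
"$$\\eta(t) = \\begin{cases} \\eta_0, & t < t_d \\\\\\\\ \\eta_0 \\left(1 - \\dfrac{t - t_d}{T - t_d}\\right), & t \\ge t_d \\end{cases}$$\n",
"\n",
"The rate reaches $0$ at $t = T$. As a purely illustrative example with hypothetical values $\\eta_0 = 2 \\times 10^{-4}$, $t_d = 100$, and $T = 200$, the rate at $t = 150$ is $2 \\times 10^{-4} \\cdot (1 - 50/100) = 1 \\times 10^{-4}$."
]
},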
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "4CCs7IJtM7ng"
},
"outputs": [],
"source": [
"lr_scheduler = LinearDecay(2e-4, EPOCHS * BUFFER_SIZE, epoch_decay * BUFFER_SIZE)\n",
"\n",
"generator_g_optimizer = tf.keras.optimizers.legacy.Adam(lr_scheduler, beta_1=0.5)\n",
"generator_f_optimizer = tf.keras.optimizers.legacy.Adam(lr_scheduler, beta_1=0.5)\n",
"\n",
"discriminator_x_optimizer = tf.keras.optimizers.legacy.Adam(lr_scheduler, beta_1=0.5)\n",
"discriminator_y_optimizer = tf.keras.optimizers.legacy.Adam(lr_scheduler, beta_1=0.5)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Oa6bsdRBM7ng"
},
"outputs": [],
"source": [
"checkpoint_path = \"./training_checkpoints\"\n",
"\n",
"ckpt = tf.train.Checkpoint(generator_g=generator_g,\n",
"                           generator_f=generator_f,\n",
"                           discriminator_x=discriminator_x,\n",
"                           discriminator_y=discriminator_y,\n",
"                           generator_g_optimizer=generator_g_optimizer,\n",
"                           generator_f_optimizer=generator_f_optimizer,\n",
"                           discriminator_x_optimizer=discriminator_x_optimizer,\n",
"                           discriminator_y_optimizer=discriminator_y_optimizer)\n",
"\n",
"ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)  # keep at most 5 sets of weights\n",
"\n",
"# If a checkpoint exists, restore its weights\n",
"if ckpt_manager.latest_checkpoint:\n",
"    ckpt.restore(ckpt_manager.latest_checkpoint)\n",
"    print('Latest checkpoint restored!!')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "u08a23oCM7ng"
},
"outputs": [],
"source": [
"def generate_images(model, test_input):\n",
"    prediction = model(test_input)\n",
"\n",
"    plt.figure(figsize=(12, 12))\n",
"\n",
"    display_list = [test_input[0], prediction[0]]\n",
"    title = ['Input Image', 'Predicted Image']\n",
"\n",
"    for i in range(2):\n",
"        plt.subplot(1, 2, i+1)\n",
"        plt.title(title[i])\n",
"        plt.imshow(display_list[i] * 0.5 + 0.5)\n",
"        plt.axis('off')\n",
"    plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "NYAzHdbcM7ng"
},
"outputs": [],
"source": [
"@tf.function\n",
"def train_step(real_x, real_y):\n",
"\n",
"    # A tape's resources are released once it has been used to compute a gradient;\n",
"    # persistent=True prevents that, which is needed when computing several gradients.\n",
"    with tf.GradientTape(persistent=True) as tape:\n",
"\n",
"        # Generator G translates X -> Y\n",
"        # Generator F translates Y -> X\n",
"\n",
"        fake_y = generator_g(real_x, training=True)\n",
"        cycled_x = generator_f(fake_y, training=True)\n",
"\n",
"        fake_x = generator_f(real_y, training=True)\n",
"        cycled_y = generator_g(fake_x, training=True)\n",
"\n",
"        # output discriminator logits\n",
"        disc_real_x = discriminator_x(real_x, training=True)\n",
"        disc_real_y = discriminator_y(real_y, training=True)\n",
"\n",
"        disc_fake_x = discriminator_x(fake_x, training=True)\n",
"        disc_fake_y = discriminator_y(fake_y, training=True)\n",
"\n",
"        # same_x and same_y are used to compute the identity loss\n",
"        same_x = generator_f(real_x, training=True)\n",
"        same_y = generator_g(real_y, training=True)\n",
"\n",
"        # Compute the cycle loss\n",
"        total_cycle_loss = calc_cycle_loss(real_x, cycled_x) + calc_cycle_loss(real_y, cycled_y)\n",
"\n",
"        # Total generator loss = adversarial loss + cycle loss + identity loss\n",
"        gen_g_loss = generator_loss(disc_fake_y)\n",
"        gen_f_loss = generator_loss(disc_fake_x)\n",
"\n",
"        total_gen_g_loss = gen_g_loss + total_cycle_loss + identity_loss(real_y, same_y)\n",
"        total_gen_f_loss = gen_f_loss + total_cycle_loss + identity_loss(real_x, same_x)\n",
"\n",
"        disc_x_loss = discriminator_loss(disc_real_x, disc_fake_x)\n",
"        disc_y_loss = discriminator_loss(disc_real_y, disc_fake_y)\n",
"\n",
"    # Compute the gradients for each network separately\n",
"    generator_g_gradients = tape.gradient(total_gen_g_loss,\n",
"                                          generator_g.trainable_variables)\n",
"    generator_f_gradients = tape.gradient(total_gen_f_loss,\n",
"                                          generator_f.trainable_variables)\n",
"\n",
"    discriminator_x_gradients = tape.gradient(disc_x_loss,\n",
"                                              discriminator_x.trainable_variables)\n",
"    discriminator_y_gradients = tape.gradient(disc_y_loss,\n",
"                                              discriminator_y.trainable_variables)\n",
"\n",
"    generator_g_optimizer.apply_gradients(zip(generator_g_gradients,\n",
"                                              generator_g.trainable_variables))\n",
"\n",
"    generator_f_optimizer.apply_gradients(zip(generator_f_gradients,\n",
"                                              generator_f.trainable_variables))\n",
"\n",
"    discriminator_x_optimizer.apply_gradients(zip(discriminator_x_gradients,\n",
"                                                  discriminator_x.trainable_variables))\n",
"\n",
"    discriminator_y_optimizer.apply_gradients(zip(discriminator_y_gradients,\n",
"                                                  discriminator_y.trainable_variables))\n",
"\n",
"    return total_gen_g_loss, total_gen_f_loss, disc_y_loss, disc_x_loss"
]
},
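{
"cell_type": "markdown",
"metadata": {},
"source": [
"Putting the pieces of `train_step` together, the quantity minimized for generator $G$ (`total_gen_g_loss`) can be sketched as below, where $F$ is the other generator, $D_Y$ is `discriminator_y`, $\\lambda_{cyc}$ is `LAMBDA_cycle` ($10$), and $\\lambda_{id}$ is `LAMBDA_identity` ($0.5$). The expectations stand for the batch means computed in the code:\n",
"\n",
"$$\\mathcal{L}_{G} = \\underbrace{\\mathbb{E}_{x}\\big[(D_Y(G(x)) - 1)^2\\big]}_{\\text{adversarial}} + \\lambda_{cyc} \\underbrace{\\Big(\\mathbb{E}_{x}\\big[\\lVert F(G(x)) - x \\rVert_1\\big] + \\mathbb{E}_{y}\\big[\\lVert G(F(y)) - y \\rVert_1\\big]\\Big)}_{\\text{cycle}} + \\lambda_{id} \\underbrace{\\mathbb{E}_{y}\\big[\\lVert G(y) - y \\rVert_1\\big]}_{\\text{identity}}$$\n",
"\n",
"The loss for $F$ (`total_gen_f_loss`) is symmetric: it swaps the roles of $x$ and $y$, uses $D_X$ in the adversarial term, and shares the same cycle term."
]
},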
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "HBVmkxBjM7nh"
},
"outputs": [],
"source": [
"for epoch in range(EPOCHS):\n",
"    start = time.time()\n",
"\n",
"    n = 0\n",
"    for image_x, image_y in tf.data.Dataset.zip((train_summer, train_winter)):\n",
"        s2w_Gloss, w2s_Gloss, y_Dloss, x_Dloss = train_step(image_x, image_y)\n",
"        if n % 10 == 0:\n",
"            print('.', end='')\n",
"        n += 1\n",
"\n",
"    # Use the same photo (sample_summer) each time to watch how the model is learning\n",
"    if (epoch + 1) % 2 == 0:\n",
"        generate_images(generator_g, sample_summer[tf.newaxis, ...])\n",
"\n",
"    if (epoch + 1) % 5 == 0:\n",
"        ckpt_save_path = ckpt_manager.save()\n",
"        print('')\n",
"        print('Saving checkpoint for epoch {} at {}'.format(epoch+1,\n",
"                                                             ckpt_save_path))\n",
"    print('')\n",
"    print('Summer2Winter G_loss: %.5f, Winter2Summer G_loss: %.5f' % (s2w_Gloss, w2s_Gloss))\n",
"    print('Summer2Winter D_loss: %.5f, Winter2Summer D_loss: %.5f' % (y_Dloss, x_Dloss))\n",
"    print('Time taken for epoch {} is {} sec\\n'.format(epoch + 1,\n",
"                                                        time.time()-start))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vKyuDg1bM7nh"
},
"source": [
"## Check results (after 200 epochs of training, ~12 hrs)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "qSizuQjAM7nn"
},
"outputs": [],
"source": [
"# summer to winter\n",
"for inp in test_summer.take(5):\n",
"    generate_images(generator_g, inp)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mkj9BpINM7nn"
},
"outputs": [],
"source": [
"# winter to summer\n",
"for inp in test_winter.take(5):\n",
"    generate_images(generator_f, inp)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ZCc878Y3M7nn"
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.12"
},
"colab": {
"provenance": [],
"gpuType": "T4",
"include_colab_link": true
},
"accelerator": "GPU"
},
"nbformat": 4,
"nbformat_minor": 0
}