{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "\"Open" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Variational Auto Encoder" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Preparations" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Imports and Installs" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%pip install plotly" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%matplotlib inline\n", "\n", "import matplotlib.pyplot as plt\n", "from plotly import express as px\n", "\n", "import numpy as np\n", "import tensorflow as tf" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Data" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "本次使用MNIST數字資料集" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from tensorflow import keras\n", "\n", "(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()\n", "mnist_digits = np.concatenate([x_train, x_test], axis=0)\n", "mnist_digits = np.expand_dims(mnist_digits, -1).astype(\"float32\") / 255" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Model Construction" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Sampling Cell\n", "我們在課堂上學到,Auto Encoder會將圖片壓縮成較低維的feature map再還原,藉此學習特徵的萃取或壓縮\n", "\n", "\n", "\n", "而VAE會進一步將雜訊加入encode完的feature map中增加這個萃取訓練的穩定性:\n", "\n", "$$\n", "z = \\mu + \\epsilon , \\epsilon \\sim \\mathcal{N}(0,\\sigma)\n", "$$\n", "\n", "但因為Normal Distribution這種取樣的步驟不可微分,所以實際上的做法則是使用 NN encode出兩個latent變數$\\mu$與$\\sigma$,再使用一個sample來的權重$\\epsilon$就可以等效為上面z的算式:\n", "\n", "\n", "\n", "下面我們來實作一下這個算式作為一個Layer\n", "$$\n", "\\epsilon \\sim \\mathcal{N}(0,1)\n", "$$\n", "$$\n", "z = \\mu + \\sigma \\odot \\epsilon\n", "$$" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from tensorflow.keras import layers\n", "\n", "class Sampling(layers.Layer):\n", " \"\"\"Uses (z_mean, z_log_var) to sample z, the vector encoding a digit.\"\"\"\n", "\n", " def call(self, inputs):\n", " z_mean, z_log_var = inputs\n", " batch = tf.shape(z_mean)[0]\n", " dim = tf.shape(z_mean)[1]\n", " epsilon = tf.random.normal(shape=(batch, dim))\n", " return z_mean + tf.exp(0.5 * z_log_var) * epsilon" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Encoder" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class Encoder(keras.Model):\n", " def __init__(self, latent_dim=2):\n", " super().__init__()\n", " self.conv1 = layers.Conv2D(32, 3, activation=\"relu\", strides=2, padding=\"same\")\n", " self.conv2 = layers.Conv2D(64, 3, activation=\"relu\", strides=2, padding=\"same\")\n", " self.flatten = layers.Flatten()\n", " self.dense1 = layers.Dense(16, activation=\"relu\")\n", " self.z_mean = layers.Dense(latent_dim, name=\"z_mean\")\n", " self.z_log_var = layers.Dense(latent_dim, name=\"z_log_var\")\n", " self.sampling = Sampling()\n", "\n", " def call(self, inputs):\n", " x = self.conv1(inputs)\n", " x = self.conv2(x)\n", " x = self.flatten(x)\n", " x = self.dense1(x)\n", " z_mean = self.z_mean(x)\n", " z_log_var = self.z_log_var(x)\n", " z = self.sampling([z_mean, z_log_var])\n", " return z_mean, z_log_var, z\n", "\n", "latent_dim = 2\n", "encoder = Encoder(latent_dim=latent_dim)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Decoder (The Generator)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "輸入vector (noise),輸出圖片" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class Decoder(keras.Model):\n", " def __init__(self, latent_dim=2):\n", " super().__init__()\n", " self.dense1 = layers.Dense(7 * 7 * 64, activation=\"relu\")\n", " self.reshape = layers.Reshape((7, 7, 64))\n", " self.deconv1 = layers.Conv2DTranspose(64, 3, activation=\"relu\", strides=2, padding=\"same\")\n", " self.deconv2 = layers.Conv2DTranspose(32, 3, activation=\"relu\", strides=2, padding=\"same\")\n", " self.outputs = layers.Conv2DTranspose(1, 3, activation=\"sigmoid\", padding=\"same\")\n", "\n", " def call(self, inputs):\n", " x = self.dense1(inputs)\n", " x = self.reshape(x)\n", " x = self.deconv1(x)\n", " x = self.deconv2(x)\n", " y = self.outputs(x)\n", " return y\n", "\n", "decoder = Decoder()" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Training Step" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Model Class with training step" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class VAE(keras.Model):\n", " def __init__(self, encoder, decoder, **kwargs):\n", " super().__init__(**kwargs)\n", " self.encoder = encoder\n", " self.decoder = decoder\n", " self.total_loss_tracker = keras.metrics.Mean(name=\"total_loss\")\n", " self.reconstruction_loss_tracker = keras.metrics.Mean(\n", " name=\"reconstruction_loss\"\n", " )\n", " self.kl_loss_tracker = keras.metrics.Mean(name=\"kl_loss\")\n", "\n", " @property\n", " def metrics(self):\n", " return [\n", " self.total_loss_tracker,\n", " self.reconstruction_loss_tracker,\n", " self.kl_loss_tracker,\n", " ]\n", "\n", " def train_step(self, data):\n", " with tf.GradientTape() as tape:\n", " z_mean, z_log_var, z = self.encoder(data)\n", " reconstruction = self.decoder(z)\n", " reconstruction_loss = tf.reduce_mean(\n", " tf.reduce_sum(\n", " keras.losses.binary_crossentropy(data, reconstruction), axis=(1, 2)\n", " )\n", " )\n", " kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))\n", " kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))\n", " total_loss = reconstruction_loss + kl_loss\n", " grads = tape.gradient(total_loss, self.trainable_weights)\n", " self.optimizer.apply_gradients(zip(grads, self.trainable_weights))\n", " self.total_loss_tracker.update_state(total_loss)\n", " self.reconstruction_loss_tracker.update_state(reconstruction_loss)\n", " self.kl_loss_tracker.update_state(kl_loss)\n", " return {\n", " \"loss\": self.total_loss_tracker.result(),\n", " \"reconstruction_loss\": self.reconstruction_loss_tracker.result(),\n", " \"kl_loss\": self.kl_loss_tracker.result(),\n", " }\n", "\n", " def test_step(self, data):\n", " z_mean, z_log_var, z = self.encoder(data)\n", " reconstruction = self.decoder(z)\n", " reconstruction_loss = tf.reduce_mean(\n", " tf.reduce_sum(\n", " keras.losses.binary_crossentropy(data, reconstruction), axis=(1, 2)\n", " )\n", " )\n", " kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))\n", " kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))\n", " total_loss = reconstruction_loss + kl_loss\n", "\n", " self.total_loss_tracker.update_state(total_loss)\n", " self.reconstruction_loss_tracker.update_state(reconstruction_loss)\n", " self.kl_loss_tracker.update_state(kl_loss)\n", " \n", " return {\n", " \"loss\": self.total_loss_tracker.result(),\n", " \"reconstruction_loss\": self.reconstruction_loss_tracker.result(),\n", " \"kl_loss\": self.kl_loss_tracker.result(),\n", " }\n", "\n", "vae = VAE(encoder, decoder)\n", "vae.compile(optimizer=keras.optimizers.Adam())" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "看一下我們的兩個部分:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Encoder, 負責將圖片對應的Latent vector產出,方便訓練latent vector到圖的 generation\n", "encoder.build(input_shape=(None, 28, 28, 1))\n", "encoder.summary()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Decoder, 我們欲訓練的生成器,給予一個 vector 生一張圖片\n", "decoder.build(input_shape=(None, latent_dim))\n", "decoder.summary()" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Training Start" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "訓練將壓縮至latent space的資料還原回圖片,我們以這種還原訓練圖片生成的能力" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "\n", "history = vae.fit(mnist_digits, epochs=20, batch_size=128, validation_split=0.2)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Evaluate for convergence\n", "訓練完成後看看 loss 是否收斂" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Show History" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plt.plot(history.history[\"val_loss\"])\n", "plt.plot(history.history[\"loss\"])" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Output Observation" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Encoder Latent Space\n", "在Encoder latent space中我們可以觀察到我們訓練出來的Encoder究竟將影像embed成什麼樣子的分布" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import plotly.express as px\n", "\n", "def plot_label_clusters(vae, data, labels):\n", " # Display a 2D plot of the digit classes in the latent space\n", " z_mean, _, _ = vae.encoder.predict(data, verbose=False)\n", "\n", " fig = px.scatter(x=z_mean[:, 0], y=z_mean[:, 1], color=labels)\n", " fig.update_layout(\n", " width=600,\n", " height=500,\n", " xaxis_title=\"z[0]\",\n", " yaxis_title=\"z[1]\",\n", " coloraxis_colorbar_title=\"Labels\"\n", " )\n", " fig.show()\n", "\n", "\n", "\n", "_ , (x_test, y_test) = keras.datasets.mnist.load_data()\n", "x_test = np.expand_dims(x_test, -1).astype(\"float32\") / 255\n", "\n", "plot_label_clusters(vae, x_test, y_test)\n", "\n", "# Output 出來每個class的分布都不一樣,可看出大致上在latent space (z[0],z[1]) 中不同數字佔不同區塊,分布也有些差異" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Decode Latent Space\n", "\n", "Decoder 負責將latent space中的vector還原到image space中,我們來看看space中每個位置還原回來長什麼樣子" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import plotly.graph_objects as go\n", "\n", "def plot_latent_space(vae, n=20, figsize=600):\n", " # Display an n*n 2D manifold of digits\n", " digit_size = 28\n", " scale = 1.0\n", " figure = np.zeros((digit_size * n, digit_size * n))\n", " # Linearly spaced coordinates corresponding to the 2D plot\n", " # of digit classes in the latent space\n", " grid_x = np.linspace(-scale, scale, n)\n", " grid_y = np.linspace(-scale, scale, n)[::-1]\n", "\n", " for i, yi in enumerate(grid_y):\n", " for j, xi in enumerate(grid_x):\n", " z_sample = np.array([[xi, yi]])\n", " x_decoded = vae.decoder.predict(z_sample, verbose=False)\n", " digit = x_decoded[0].reshape(digit_size, digit_size)\n", " figure[\n", " i * digit_size : (i + 1) * digit_size,\n", " j * digit_size : (j + 1) * digit_size,\n", " ] = digit\n", "\n", " start_range = digit_size // 2\n", " end_range = n * digit_size + start_range\n", " pixel_range = np.arange(start_range, end_range, digit_size)\n", " sample_range_x = np.round(grid_x, 1)\n", " sample_range_y = np.round(grid_y, 1)\n", "\n", " fig = go.Figure(data=go.Heatmap(z=figure[::-1], colorscale=\"gray\"))\n", " fig.update_xaxes(tickvals=pixel_range, ticktext=sample_range_x)\n", " fig.update_yaxes(tickvals=pixel_range, ticktext=sample_range_y)\n", " fig.update_layout(\n", " width=figsize,\n", " height=figsize,\n", " xaxis_title=\"z[0]\",\n", " yaxis_title=\"z[1]\",\n", " )\n", " fig.show()\n", "plot_latent_space(vae)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "base", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.12" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }