{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Variational Auto Encoder"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Preparations"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Imports and Installs"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%pip install plotly"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"\n",
"import matplotlib.pyplot as plt\n",
"from plotly import express as px\n",
"\n",
"import numpy as np\n",
"import tensorflow as tf"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Data"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"本次使用MNIST數字資料集"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from tensorflow import keras\n",
"\n",
"(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()\n",
"mnist_digits = np.concatenate([x_train, x_test], axis=0)\n",
"mnist_digits = np.expand_dims(mnist_digits, -1).astype(\"float32\") / 255"
]
},
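{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check (a minimal sketch, not part of the original pipeline), we can confirm the shape and value range of `mnist_digits` and preview one digit:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sanity check: 70,000 grayscale 28x28 images scaled to [0, 1]\n",
"print(\"shape:\", mnist_digits.shape)\n",
"print(\"value range:\", mnist_digits.min(), \"-\", mnist_digits.max())\n",
"\n",
"plt.imshow(mnist_digits[0, :, :, 0], cmap=\"gray\")\n",
"plt.axis(\"off\")\n",
"plt.show()"
]
},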
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Model Construction"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Sampling Cell\n",
"我們在課堂上學到,Auto Encoder會將圖片壓縮成較低維的feature map再還原,藉此學習特徵的萃取或壓縮\n",
"\n",
"
\n",
"\n",
"而VAE會進一步將雜訊加入encode完的feature map中增加這個萃取訓練的穩定性:\n",
"\n",
"$$\n",
"z = \\mu + \\epsilon , \\epsilon \\sim \\mathcal{N}(0,\\sigma)\n",
"$$\n",
"\n",
"但因為Normal Distribution這種取樣的步驟不可微分,所以實際上的做法則是使用 NN encode出兩個latent變數$\\mu$與$\\sigma$,再使用一個sample來的權重$\\epsilon$就可以等效為上面z的算式:\n",
"\n",
"
\n",
"\n",
"下面我們來實作一下這個算式作為一個Layer\n",
"$$\n",
"\\epsilon \\sim \\mathcal{N}(0,1)\n",
"$$\n",
"$$\n",
"z = \\mu + \\sigma \\odot \\epsilon\n",
"$$"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from tensorflow.keras import layers\n",
"\n",
"class Sampling(layers.Layer):\n",
" \"\"\"Uses (z_mean, z_log_var) to sample z, the vector encoding a digit.\"\"\"\n",
"\n",
" def call(self, inputs):\n",
" z_mean, z_log_var = inputs\n",
" batch = tf.shape(z_mean)[0]\n",
" dim = tf.shape(z_mean)[1]\n",
" epsilon = tf.random.normal(shape=(batch, dim))\n",
" return z_mean + tf.exp(0.5 * z_log_var) * epsilon"
]
},
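{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"A small usage check of the `Sampling` layer (a sketch with made-up inputs): with `z_log_var = 0` we have $\\sigma = 1$, so the sampled `z` should simply be `z_mean` plus standard-normal noise, with the same shape."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Smoke test with made-up tensors (batch of 4, latent_dim = 2); not part of training\n",
"dummy_mean = tf.zeros((4, 2))\n",
"dummy_log_var = tf.zeros((4, 2))  # log(sigma^2) = 0  ->  sigma = 1\n",
"z_sample = Sampling()([dummy_mean, dummy_log_var])\n",
"print(z_sample.shape)    # (4, 2), same shape as z_mean\n",
"print(z_sample.numpy())  # values scattered around 0 with std 1"
]
},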
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Encoder"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class Encoder(keras.Model):\n",
" def __init__(self, latent_dim=2):\n",
" super().__init__()\n",
" self.conv1 = layers.Conv2D(32, 3, activation=\"relu\", strides=2, padding=\"same\")\n",
" self.conv2 = layers.Conv2D(64, 3, activation=\"relu\", strides=2, padding=\"same\")\n",
" self.flatten = layers.Flatten()\n",
" self.dense1 = layers.Dense(16, activation=\"relu\")\n",
" self.z_mean = layers.Dense(latent_dim, name=\"z_mean\")\n",
" self.z_log_var = layers.Dense(latent_dim, name=\"z_log_var\")\n",
" self.sampling = Sampling()\n",
"\n",
" def call(self, inputs):\n",
" x = self.conv1(inputs)\n",
" x = self.conv2(x)\n",
" x = self.flatten(x)\n",
" x = self.dense1(x)\n",
" z_mean = self.z_mean(x)\n",
" z_log_var = self.z_log_var(x)\n",
" z = self.sampling([z_mean, z_log_var])\n",
" return z_mean, z_log_var, z\n",
"\n",
"latent_dim = 2\n",
"encoder = Encoder(latent_dim=latent_dim)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Decoder (The Generator)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"輸入vector (noise),輸出圖片"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class Decoder(keras.Model):\n",
" def __init__(self, latent_dim=2):\n",
" super().__init__()\n",
" self.dense1 = layers.Dense(7 * 7 * 64, activation=\"relu\")\n",
" self.reshape = layers.Reshape((7, 7, 64))\n",
" self.deconv1 = layers.Conv2DTranspose(64, 3, activation=\"relu\", strides=2, padding=\"same\")\n",
" self.deconv2 = layers.Conv2DTranspose(32, 3, activation=\"relu\", strides=2, padding=\"same\")\n",
" self.outputs = layers.Conv2DTranspose(1, 3, activation=\"sigmoid\", padding=\"same\")\n",
"\n",
" def call(self, inputs):\n",
" x = self.dense1(inputs)\n",
" x = self.reshape(x)\n",
" x = self.deconv1(x)\n",
" x = self.deconv2(x)\n",
" y = self.outputs(x)\n",
" return y\n",
"\n",
"decoder = Decoder()"
]
},
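{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Before wiring these into the training loop, here is a quick round-trip shape check (a sketch using a few digits from `mnist_digits`; the weights are still untrained at this point): the encoder should map a batch of 28×28×1 images to 2-D latent codes, and the decoder should map those codes back to 28×28×1 images."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Round-trip shape check: image -> (z_mean, z_log_var, z) -> reconstructed image\n",
"sample_batch = mnist_digits[:4]\n",
"z_mean, z_log_var, z = encoder(sample_batch)\n",
"reconstruction = decoder(z)\n",
"print(\"z_mean:\", z_mean.shape)                  # (4, 2)\n",
"print(\"z:\", z.shape)                            # (4, 2)\n",
"print(\"reconstruction:\", reconstruction.shape)  # (4, 28, 28, 1)"
]
},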
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training Step"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Model Class with training step"
]
},
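{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"The custom `train_step` below minimizes the usual VAE objective: a reconstruction term (pixel-wise binary cross-entropy) plus a KL term that pulls each encoded distribution $\\mathcal{N}(\\mu, \\sigma^{2})$ toward the standard normal prior. For a diagonal Gaussian this KL term has the closed form used in the code:\n",
"$$\n",
"D_{KL}\\left(\\mathcal{N}(\\mu, \\sigma^{2}) \\,\\|\\, \\mathcal{N}(0, 1)\\right) = -\\frac{1}{2} \\sum_{i} \\left(1 + \\log\\sigma_i^{2} - \\mu_i^{2} - \\sigma_i^{2}\\right)\n",
"$$"
]
},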
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class VAE(keras.Model):\n",
" def __init__(self, encoder, decoder, **kwargs):\n",
" super().__init__(**kwargs)\n",
" self.encoder = encoder\n",
" self.decoder = decoder\n",
" self.total_loss_tracker = keras.metrics.Mean(name=\"total_loss\")\n",
" self.reconstruction_loss_tracker = keras.metrics.Mean(\n",
" name=\"reconstruction_loss\"\n",
" )\n",
" self.kl_loss_tracker = keras.metrics.Mean(name=\"kl_loss\")\n",
"\n",
" @property\n",
" def metrics(self):\n",
" return [\n",
" self.total_loss_tracker,\n",
" self.reconstruction_loss_tracker,\n",
" self.kl_loss_tracker,\n",
" ]\n",
"\n",
" def train_step(self, data):\n",
" with tf.GradientTape() as tape:\n",
" z_mean, z_log_var, z = self.encoder(data)\n",
" reconstruction = self.decoder(z)\n",
" reconstruction_loss = tf.reduce_mean(\n",
" tf.reduce_sum(\n",
" keras.losses.binary_crossentropy(data, reconstruction), axis=(1, 2)\n",
" )\n",
" )\n",
" kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))\n",
" kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))\n",
" total_loss = reconstruction_loss + kl_loss\n",
" grads = tape.gradient(total_loss, self.trainable_weights)\n",
" self.optimizer.apply_gradients(zip(grads, self.trainable_weights))\n",
" self.total_loss_tracker.update_state(total_loss)\n",
" self.reconstruction_loss_tracker.update_state(reconstruction_loss)\n",
" self.kl_loss_tracker.update_state(kl_loss)\n",
" return {\n",
" \"loss\": self.total_loss_tracker.result(),\n",
" \"reconstruction_loss\": self.reconstruction_loss_tracker.result(),\n",
" \"kl_loss\": self.kl_loss_tracker.result(),\n",
" }\n",
"\n",
" def test_step(self, data):\n",
" z_mean, z_log_var, z = self.encoder(data)\n",
" reconstruction = self.decoder(z)\n",
" reconstruction_loss = tf.reduce_mean(\n",
" tf.reduce_sum(\n",
" keras.losses.binary_crossentropy(data, reconstruction), axis=(1, 2)\n",
" )\n",
" )\n",
" kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))\n",
" kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))\n",
" total_loss = reconstruction_loss + kl_loss\n",
"\n",
" self.total_loss_tracker.update_state(total_loss)\n",
" self.reconstruction_loss_tracker.update_state(reconstruction_loss)\n",
" self.kl_loss_tracker.update_state(kl_loss)\n",
" \n",
" return {\n",
" \"loss\": self.total_loss_tracker.result(),\n",
" \"reconstruction_loss\": self.reconstruction_loss_tracker.result(),\n",
" \"kl_loss\": self.kl_loss_tracker.result(),\n",
" }\n",
"\n",
"vae = VAE(encoder, decoder)\n",
"vae.compile(optimizer=keras.optimizers.Adam())"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"看一下我們的兩個部分:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Encoder, 負責將圖片對應的Latent vector產出,方便訓練latent vector到圖的 generation\n",
"encoder.build(input_shape=(None, 28, 28, 1))\n",
"encoder.summary()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Decoder, 我們欲訓練的生成器,給予一個 vector 生一張圖片\n",
"decoder.build(input_shape=(None, latent_dim))\n",
"decoder.summary()"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Training Start"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"訓練將壓縮至latent space的資料還原回圖片,我們以這種還原訓練圖片生成的能力"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"history = vae.fit(mnist_digits, epochs=20, batch_size=128, validation_split=0.2)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Evaluate for convergence\n",
"訓練完成後看看 loss 是否收斂"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Show History"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plt.plot(history.history[\"val_loss\"])\n",
"plt.plot(history.history[\"loss\"])"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Output Observation"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Encoder Latent Space\n",
"在Encoder latent space中我們可以觀察到我們訓練出來的Encoder究竟將影像embed成什麼樣子的分布"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import plotly.express as px\n",
"\n",
"def plot_label_clusters(vae, data, labels):\n",
" # Display a 2D plot of the digit classes in the latent space\n",
" z_mean, _, _ = vae.encoder.predict(data, verbose=False)\n",
"\n",
" fig = px.scatter(x=z_mean[:, 0], y=z_mean[:, 1], color=labels)\n",
" fig.update_layout(\n",
" width=600,\n",
" height=500,\n",
" xaxis_title=\"z[0]\",\n",
" yaxis_title=\"z[1]\",\n",
" coloraxis_colorbar_title=\"Labels\"\n",
" )\n",
" fig.show()\n",
"\n",
"\n",
"\n",
"_ , (x_test, y_test) = keras.datasets.mnist.load_data()\n",
"x_test = np.expand_dims(x_test, -1).astype(\"float32\") / 255\n",
"\n",
"plot_label_clusters(vae, x_test, y_test)\n",
"\n",
"# Output 出來每個class的分布都不一樣,可看出大致上在latent space (z[0],z[1]) 中不同數字佔不同區塊,分布也有些差異"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Decode Latent Space\n",
"\n",
"Decoder 負責將latent space中的vector還原到image space中,我們來看看space中每個位置還原回來長什麼樣子"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import plotly.graph_objects as go\n",
"\n",
"def plot_latent_space(vae, n=20, figsize=600):\n",
" # Display an n*n 2D manifold of digits\n",
" digit_size = 28\n",
" scale = 1.0\n",
" figure = np.zeros((digit_size * n, digit_size * n))\n",
" # Linearly spaced coordinates corresponding to the 2D plot\n",
" # of digit classes in the latent space\n",
" grid_x = np.linspace(-scale, scale, n)\n",
" grid_y = np.linspace(-scale, scale, n)[::-1]\n",
"\n",
" for i, yi in enumerate(grid_y):\n",
" for j, xi in enumerate(grid_x):\n",
" z_sample = np.array([[xi, yi]])\n",
" x_decoded = vae.decoder.predict(z_sample, verbose=False)\n",
" digit = x_decoded[0].reshape(digit_size, digit_size)\n",
" figure[\n",
" i * digit_size : (i + 1) * digit_size,\n",
" j * digit_size : (j + 1) * digit_size,\n",
" ] = digit\n",
"\n",
" start_range = digit_size // 2\n",
" end_range = n * digit_size + start_range\n",
" pixel_range = np.arange(start_range, end_range, digit_size)\n",
" sample_range_x = np.round(grid_x, 1)\n",
" sample_range_y = np.round(grid_y, 1)\n",
"\n",
" fig = go.Figure(data=go.Heatmap(z=figure[::-1], colorscale=\"gray\"))\n",
" fig.update_xaxes(tickvals=pixel_range, ticktext=sample_range_x)\n",
" fig.update_yaxes(tickvals=pixel_range, ticktext=sample_range_y)\n",
" fig.update_layout(\n",
" width=figsize,\n",
" height=figsize,\n",
" xaxis_title=\"z[0]\",\n",
" yaxis_title=\"z[1]\",\n",
" )\n",
" fig.show()\n",
"plot_latent_space(vae)"
]
},
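{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Generate New Digits from the Prior\n",
"Because the KL term pushes the latent codes toward $\\mathcal{N}(0, 1)$, we can also generate brand-new digits by sampling $z$ directly from the standard normal prior and decoding it (a minimal sketch using the trained `decoder`):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sample latent vectors from the standard normal prior and decode them into digits\n",
"n_samples = 10\n",
"z_prior = np.random.normal(size=(n_samples, latent_dim))\n",
"generated = vae.decoder.predict(z_prior, verbose=False)\n",
"\n",
"fig, axes = plt.subplots(1, n_samples, figsize=(n_samples, 1.5))\n",
"for ax, img in zip(axes, generated):\n",
"    ax.imshow(img[:, :, 0], cmap=\"gray\")\n",
"    ax.axis(\"off\")\n",
"plt.show()"
]
},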
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.12"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}