{ "cells": [ { "cell_type": "markdown", "id": "754b8ebe", "metadata": { "id": "754b8ebe" }, "source": [ "# Image Super-Resolution using an Efficient Sub-Pixel CNN (ESPCN)" ] }, { "cell_type": "markdown", "id": "58c9ca95", "metadata": { "id": "58c9ca95" }, "source": [ "\n", "\n", "- [source paper](https://arxiv.org/abs/1609.05158)\n", "- [reference source](https://keras.io/examples/vision/super_resolution_sub_pixel/)" ] }, { "cell_type": "code", "execution_count": null, "id": "61b69ce3", "metadata": { "id": "61b69ce3" }, "outputs": [], "source": [ "import os\n", "import math\n", "import cv2\n", "import random\n", "import numpy as np\n", "\n", "import tensorflow as tf\n", "from tensorflow.keras.preprocessing import image_dataset_from_directory\n", "from tensorflow.keras.preprocessing.image import load_img\n", "from tensorflow.keras.preprocessing.image import array_to_img\n", "from tensorflow.keras.preprocessing.image import img_to_array\n", "\n", "from IPython.display import display\n", "\n", "import matplotlib.pyplot as plt\n", "from mpl_toolkits.axes_grid1.inset_locator import mark_inset\n", "from mpl_toolkits.axes_grid1.inset_locator import zoomed_inset_axes\n", "\n", "import PIL" ] }, { "cell_type": "markdown", "id": "2d7ce220", "metadata": { "id": "2d7ce220" }, "source": [ "# Load data: BSDS500 dataset" ] }, { "cell_type": "code", "execution_count": null, "id": "08205c10", "metadata": { "id": "08205c10" }, "outputs": [], "source": [ "dataset_url = \"http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/BSR/BSR_bsds500.tgz\"\n", "data_dir = tf.keras.utils.get_file(origin=dataset_url, fname=\"BSR\", untar=True)\n", "root_dir = os.path.join(data_dir, \"BSDS500/data\")" ] }, { "cell_type": "code", "execution_count": null, "id": "77f26e26", "metadata": { "id": "77f26e26" }, "outputs": [], "source": [ "crop_size = 300\n", "upscale_factor = 3\n", "input_size = crop_size // upscale_factor\n", "batch_size = 8" ] }, { "cell_type": "code", "execution_count": null, "id": "be4534b0", "metadata": { "id": "be4534b0" }, "outputs": [], "source": [ "dataset = os.path.join(root_dir, 'images', 'test')\n", "test_img_paths = sorted(\n", " [os.path.join(dataset, file)\n", " for file in os.listdir(dataset)\n", " if '.jpg' in file]\n", ")" ] }, { "cell_type": "markdown", "id": "cd854e5d", "metadata": { "id": "cd854e5d" }, "source": [ "# Creat Datasets, Crop and resize images" ] }, { "cell_type": "code", "execution_count": null, "id": "c7145508", "metadata": { "id": "c7145508" }, "outputs": [], "source": [ "def process_input(inputs, input_size):\n", " return tf.image.resize(inputs, [input_size, input_size], method=\"area\")\n", "\n", "\n", "def data_generater(dataset):\n", " datalist = [file\n", " for file in os.listdir(os.path.join(root_dir,\n", " 'images',\n", " bytes.decode(dataset)))\n", " if '.jpg' in file]\n", " random.shuffle(datalist)\n", " for file in datalist:\n", " image = cv2.imread(os.path.join(root_dir,\n", " 'images',\n", " bytes.decode(dataset),\n", " file))\n", " image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n", " image = cv2.resize(image, (crop_size, crop_size))\n", " image = image / 255.0\n", " yield process_input(image, input_size), image\n" ] }, { "cell_type": "code", "execution_count": null, "id": "637f45c9", "metadata": { "id": "637f45c9" }, "outputs": [], "source": [ "train_ds = tf.data.Dataset.from_generator(\n", " data_generater,\n", " output_signature=(tf.TensorSpec(shape=(None, None, 3),\n", " dtype=tf.float32),\n", " tf.TensorSpec(shape=(None, None, 3),\n", " dtype=tf.float32)),\n", " args=['train']\n", ")\n", "\n", "valid_ds = tf.data.Dataset.from_generator(\n", " data_generater,\n", " output_signature=(tf.TensorSpec(shape=(None, None, 3),\n", " dtype=tf.float32),\n", " tf.TensorSpec(shape=(None, None, 3),\n", " dtype=tf.float32)),\n", " args=['val']\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "a76b63b8", "metadata": { "id": "a76b63b8" }, "outputs": [], "source": [ "train_ds = train_ds.batch(batch_size).prefetch(buffer_size=32)\n", "valid_ds = valid_ds.batch(batch_size).prefetch(buffer_size=32)" ] }, { "cell_type": "code", "execution_count": null, "id": "0dfec4d7", "metadata": { "id": "0dfec4d7" }, "outputs": [], "source": [ "for batch in train_ds.take(1):\n", " for img in batch[0]:\n", " print(img.shape)\n", " display(array_to_img(img))\n", " for img in batch[1]:\n", " print(img.shape)\n", " display(array_to_img(img))" ] }, { "cell_type": "markdown", "id": "ca310c67", "metadata": { "id": "ca310c67" }, "source": [ "# Build a model" ] }, { "cell_type": "code", "execution_count": null, "id": "0112aec0", "metadata": { "id": "0112aec0" }, "outputs": [], "source": [ "def get_model(upscale_factor=3, channels=3):\n", " inputs = tf.keras.Input(shape=(None, None, channels))\n", " x = tf.keras.layers.Conv2D(64, (5, 5),\n", " activation='relu', padding='same')(inputs)\n", " x = tf.keras.layers.Conv2D(64, (3, 3),\n", " activation='relu', padding='same')(x)\n", " x = tf.keras.layers.Conv2D(32, (3, 3),\n", " activation='relu', padding='same')(x)\n", " x = tf.keras.layers.Conv2D(channels * (upscale_factor ** 2), (3, 3),\n", " activation='relu', padding='same')(x)\n", " outputs = tf.nn.depth_to_space(x, upscale_factor)\n", " model = tf.keras.Model(inputs, outputs)\n", " return model\n", "\n", "\n", "model = get_model(upscale_factor=upscale_factor, channels=3)" ] }, { "cell_type": "markdown", "id": "fb11f3d7", "metadata": { "id": "fb11f3d7" }, "source": [ "# Define callbacks to monitor training" ] }, { "cell_type": "code", "execution_count": null, "id": "1bedc111", "metadata": { "id": "1bedc111" }, "outputs": [], "source": [ "def plot_results(img, prefix, title):\n", " \"\"\"Plot the result with zoom-in area.\"\"\"\n", " img_array = img_to_array(img)\n", " img_array = img_array.astype(\"float32\") / 255.0\n", "\n", " # Create a new figure with a default subplot.\n", " fig, ax = plt.subplots()\n", " im = ax.imshow(img_array[::-1], origin=\"lower\")\n", "\n", " plt.title(title)\n", " # zoom-factor: 2.0, location: upper-left\n", " axins = zoomed_inset_axes(ax, 2, loc=2)\n", " axins.imshow(img_array[::-1], origin=\"lower\")\n", "\n", " # Specify the limits.\n", " x1, x2, y1, y2 = 200, 300, 100, 200\n", " # Apply the x-limits.\n", " axins.set_xlim(x1, x2)\n", " # Apply the y-limits.\n", " axins.set_ylim(y1, y2)\n", "\n", " plt.yticks(visible=False)\n", " plt.xticks(visible=False)\n", "\n", " # Make the line.\n", " mark_inset(ax, axins, loc1=1, loc2=3, fc=\"none\", ec=\"blue\")\n", " plt.savefig(str(prefix) + \"-\" + title + \".png\")\n", " plt.show()\n", "\n", "\n", "def get_lowres_image(img, upscale_factor):\n", " \"\"\"Return low-resolution image to use as model input.\"\"\"\n", " new_img = img.resize((img.size[0] // upscale_factor,\n", " img.size[1] // upscale_factor),\n", " PIL.Image.BICUBIC,)\n", " return new_img\n", "\n", "\n", "def upscale_image(model, img):\n", " \"\"\"Predict the result based on input image and restore the image as RGB.\"\"\"\n", "\n", " img = img_to_array(img)\n", " img = img.astype(\"float32\") / 255.0\n", "\n", " inputs = np.expand_dims(img, axis=0)\n", " outputs = model.predict(inputs)\n", "\n", " output_img = outputs[0]\n", " output_img *= 255.0\n", " output_img = output_img.clip(0, 255)\n", "\n", " return output_img" ] }, { "cell_type": "code", "execution_count": null, "id": "02d73185", "metadata": { "id": "02d73185" }, "outputs": [], "source": [ "class ESPCNCallback(tf.keras.callbacks.Callback):\n", " def __init__(self):\n", " super().__init__()\n", " self.test_img = get_lowres_image(load_img(test_img_paths[0]),\n", " upscale_factor)\n", "\n", " # Store PSNR value in each epoch.\n", " def on_epoch_begin(self, epoch, logs=None):\n", " self.psnr = []\n", "\n", " def on_epoch_end(self, epoch, logs=None):\n", " print(\"Mean PSNR for epoch: %.2f\" % (np.mean(self.psnr)))\n", " if epoch % 20 == 0:\n", " prediction = upscale_image(self.model, self.test_img)\n", " plot_results(prediction, \"epoch-\" + str(epoch), \"prediction\")\n", "\n", " def on_test_batch_end(self, batch, logs=None):\n", " self.psnr.append(10 * math.log10(1 / logs[\"loss\"]))" ] }, { "cell_type": "code", "execution_count": null, "id": "d18278f5", "metadata": { "id": "d18278f5" }, "outputs": [], "source": [ "early_stopping_callback = tf.keras.callbacks.EarlyStopping(monitor=\"loss\",\n", " patience=10)\n", "\n", "checkpoint_filepath = \"/tmp/checkpoint\"\n", "\n", "model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(\n", " filepath=checkpoint_filepath,\n", " save_weights_only=True,\n", " monitor=\"loss\",\n", " mode=\"min\",\n", " save_best_only=True,\n", ")\n", "\n", "model = get_model(upscale_factor=upscale_factor, channels=3)\n", "model.summary()\n", "\n", "callbacks = [ESPCNCallback(),\n", " early_stopping_callback,\n", " model_checkpoint_callback]\n", "\n", "loss_fn = tf.keras.losses.MeanSquaredError()\n", "optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)" ] }, { "cell_type": "markdown", "id": "bddc22d9", "metadata": { "id": "bddc22d9" }, "source": [ "# Train the model" ] }, { "cell_type": "code", "execution_count": null, "id": "f0e159f0", "metadata": { "id": "f0e159f0" }, "outputs": [], "source": [ "epochs = 100\n", "\n", "model.compile(optimizer=optimizer, loss=loss_fn)\n", "\n", "model.fit(train_ds, epochs=epochs, callbacks=callbacks,\n", " validation_data=valid_ds, verbose=2)\n", "\n", "# The model weights (that are considered the best) are loaded into the model.\n", "model.load_weights(checkpoint_filepath)" ] }, { "cell_type": "markdown", "id": "cb3ee9e0", "metadata": { "id": "cb3ee9e0" }, "source": [ "# Run model prediction and plot the results" ] }, { "cell_type": "code", "execution_count": null, "id": "f30863d8", "metadata": { "id": "f30863d8" }, "outputs": [], "source": [ "total_bicubic_psnr = 0.0\n", "total_test_psnr = 0.0\n", "\n", "for index, test_img_path in enumerate(test_img_paths[50:60]):\n", " img = load_img(test_img_path)\n", " lowres_input = get_lowres_image(img, upscale_factor)\n", " w = lowres_input.size[0] * upscale_factor\n", " h = lowres_input.size[1] * upscale_factor\n", " highres_img = img.resize((w, h))\n", " prediction = upscale_image(model, lowres_input)\n", " lowres_img = lowres_input.resize((w, h))\n", " lowres_img_arr = img_to_array(lowres_img)\n", " highres_img_arr = img_to_array(highres_img)\n", " predict_img_arr = img_to_array(prediction)\n", " bicubic_psnr = tf.image.psnr(lowres_img_arr, highres_img_arr, max_val=255)\n", " test_psnr = tf.image.psnr(predict_img_arr, highres_img_arr, max_val=255)\n", "\n", " total_bicubic_psnr += bicubic_psnr\n", " total_test_psnr += test_psnr\n", "\n", " print(\"PSNR of low resolution image and high resolution image is %.4f\"\n", " % bicubic_psnr)\n", " print(\"PSNR of predict and high resolution is %.4f\" % test_psnr)\n", " plot_results(lowres_img, index, \"lowres\")\n", " plot_results(highres_img, index, \"highres\")\n", " plot_results(prediction, index, \"prediction\")\n", "\n", "print(\"Avg. PSNR of lowres images is %.4f\" % (total_bicubic_psnr / 10))\n", "print(\"Avg. PSNR of reconstructions is %.4f\" % (total_test_psnr / 10))" ] }, { "cell_type": "code", "execution_count": null, "id": "db4a17c6", "metadata": { "id": "db4a17c6" }, "outputs": [], "source": [ "model.save('SRCNN_rgb.h5')" ] } ], "metadata": { "colab": { "provenance": [] }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.12" } }, "nbformat": 4, "nbformat_minor": 5 }