{ "cells": [ { "cell_type": "markdown", "id": "999d65a0", "metadata": { "id": "999d65a0" }, "source": [ "# **Custom Dataset (PyTorch)**\n", "This notebook is a PyTorch reference implementation of Custom_dataset.\n", "\n", "## Section outline\n", "* ### [torch.utils.data.Dataset](#Dataset)\n", "* ### [torch.utils.data.DataLoader (Dataset operation)](#DataLoader)\n", "---" ] }, { "cell_type": "markdown", "id": "3988fa89", "metadata": { "id": "3988fa89" }, "source": [ "## Import packages" ] }, { "cell_type": "code", "execution_count": null, "id": "a1ae3118", "metadata": { "id": "a1ae3118" }, "outputs": [], "source": [ "import torch" ] }, { "cell_type": "markdown", "id": "2c697117", "metadata": { "id": "2c697117" }, "source": [ "## torch.utils.data.Dataset" ] }, { "cell_type": "code", "execution_count": null, "id": "879fea0f", "metadata": { "id": "879fea0f" }, "outputs": [], "source": [ "value = torch.randn((5, 2))\n", "target = torch.randint(high=5, size=(5,), dtype=torch.int64)" ] }, { "cell_type": "markdown", "id": "f40f303c", "metadata": { "id": "f40f303c" }, "source": [ "* ### TensorDataset" ] }, { "cell_type": "code", "execution_count": null, "id": "7b78c316", "metadata": { "id": "7b78c316" }, "outputs": [], "source": [ "# Analogous to tf.data.Dataset.from_tensor_slices: sample i is (value[i], target[i])\n", "dataset1 = torch.utils.data.TensorDataset(value, target)" ] }, { "cell_type": "code", "execution_count": null, "id": "8a1d1dae", "metadata": { "id": "8a1d1dae" }, "outputs": [], "source": [ "dataset1" ] }, { "cell_type": "code", "execution_count": null, "id": "61319313", "metadata": { "id": "61319313" }, "outputs": [], "source": [ "it = iter(dataset1)\n", "print(next(it))" ] }, { "cell_type": "code", "execution_count": null, "id": "76705383", "metadata": { "id": "76705383" }, "outputs": [], "source": [ "for idx, elem in enumerate(dataset1):\n", "    print(f'{idx}. 
{elem}')" ] }, { "cell_type": "markdown", "id": "6fc54fc0", "metadata": { "id": "6fc54fc0" }, "source": [ "* ### Custom Dataset" ] }, { "cell_type": "code", "execution_count": null, "id": "248b8f8c", "metadata": { "id": "248b8f8c" }, "outputs": [], "source": [ "class myDataset(torch.utils.data.Dataset):  # map-style dataset defined as a custom class\n", "    def __init__(self, X, y):\n", "        # --------------------------------------------\n", "        # Initialize paths, transforms, and so on\n", "        # --------------------------------------------\n", "        self.X = X\n", "        self.y = y\n", "\n", "    def __getitem__(self, index):\n", "        # --------------------------------------------\n", "        # 1. Read one sample from file (e.g. numpy.fromfile, PIL.Image.open)\n", "        # 2. Preprocess the data (e.g. torchvision.transforms)\n", "        # 3. Return the data (e.g. image and label)\n", "        # --------------------------------------------\n", "        return self.X[index], self.y[index]\n", "\n", "    def __len__(self):\n", "        # --------------------------------------------\n", "        # Return the total number of samples in the dataset\n", "        # --------------------------------------------\n", "        return len(self.X)" ] }, { "cell_type": "code", "execution_count": null, "id": "22c8d53f", "metadata": { "id": "22c8d53f" }, "outputs": [], "source": [ "dataset2 = myDataset(value, target)" ] }, { "cell_type": "code", "execution_count": null, "id": "98729db5", "metadata": { "id": "98729db5" }, "outputs": [], "source": [ "it = iter(dataset2)\n", "print(next(it))" ] }, { "cell_type": "code", "execution_count": null, "id": "cbb5a93b", "metadata": { "id": "cbb5a93b" }, "outputs": [], "source": [ "for idx, elem in enumerate(dataset2):\n", "    print(f'{idx}. 
{elem}')" ] }, { "cell_type": "markdown", "id": "aeb62475", "metadata": { "id": "aeb62475" }, "source": [ "## torch.utils.data.DataLoader" ] }, { "cell_type": "code", "execution_count": null, "id": "de7d6776", "metadata": { "id": "de7d6776" }, "outputs": [], "source": [ "dataset1_loader = torch.utils.data.DataLoader(dataset1, batch_size=2, shuffle=True)\n", "next(iter(dataset1_loader))" ] }, { "cell_type": "code", "execution_count": null, "id": "0523babc", "metadata": { "id": "0523babc" }, "outputs": [], "source": [ "dataset2_loader = torch.utils.data.DataLoader(dataset2, batch_size=2, shuffle=True)\n", "next(iter(dataset2_loader))" ] }, { "cell_type": "markdown", "id": "59a0ea47", "metadata": { "id": "59a0ea47" }, "source": [ "Reference: https://pytorch.org/docs/stable/data.html" ] }, { "cell_type": "code", "execution_count": null, "id": "38a6e573", "metadata": { "id": "38a6e573" }, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.12" }, "colab": { "provenance": [] }, "accelerator": "GPU", "gpuClass": "standard" }, "nbformat": 4, "nbformat_minor": 5 }