{
"cells": [
{
"cell_type": "markdown",
"id": "14c6c536",
"metadata": {
"id": "14c6c536"
},
"source": [
"# **自定義資料集(Custom Dataset)**\n",
"神經網路的訓練中,往往資料量都是相當龐大的(無法一次讀取進記憶體的資料量),因此需要透過 Dataset 的建立,拆分成數個較小的資料,批次讀取進模型訓練。\n",
"\n",
"## 本章節內容大綱\n",
"* ### [tf.data.Dataset](#Dataset)\n",
" * #### from_tensors\n",
" * #### from_tensor_slices\n",
" * #### from_generator\n",
"* ### [Dataset operation](#Operation)\n",
"---"
]
},
{
"cell_type": "markdown",
"id": "628e39a2",
"metadata": {
"id": "628e39a2"
},
"source": [
"## 匯入套件"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f149c344",
"metadata": {
"id": "f149c344"
},
"outputs": [],
"source": [
"import tensorflow as tf"
]
},
{
"cell_type": "markdown",
"id": "f45b5381",
"metadata": {
"id": "f45b5381"
},
"source": [
"<a name=\"Dataset\"></a>\n",
"## tf.data.Dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "16719c03",
"metadata": {
"id": "16719c03"
},
"outputs": [],
"source": [
"value = tf.random.uniform((5, 2))\n",
"target = tf.random.uniform((5,), maxval=5, dtype=tf.int64)"
]
},
{
"cell_type": "markdown",
"id": "589d0e7f",
"metadata": {
"id": "589d0e7f"
},
"source": [
"* ### from_tensors"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "18a85b52",
"metadata": {
"id": "18a85b52"
},
"outputs": [],
"source": [
"dataset1 = tf.data.Dataset.from_tensors((value, target))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "eae82d8e",
"metadata": {
"id": "eae82d8e"
},
"outputs": [],
"source": [
"dataset1"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f5c8af35",
"metadata": {
"id": "f5c8af35"
},
"outputs": [],
"source": [
"it = iter(dataset1)\n",
"print(next(it))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "98c789bd",
"metadata": {
"id": "98c789bd"
},
"outputs": [],
"source": [
"for idx, elem in enumerate(dataset1):\n",
" print(f'{idx}. {elem}')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d9817477",
"metadata": {
"id": "d9817477"
},
"outputs": [],
"source": [
"list(dataset1.as_numpy_iterator())"
]
},
{
"cell_type": "markdown",
"id": "76aa6b5f",
"metadata": {
"id": "76aa6b5f"
},
"source": [
"* ### from_tensor_slices"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "781b7595",
"metadata": {
"id": "781b7595"
},
"outputs": [],
"source": [
"dataset2 = tf.data.Dataset.from_tensor_slices((value, target))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f5b2bb11",
"metadata": {
"id": "f5b2bb11"
},
"outputs": [],
"source": [
"it = iter(dataset2)\n",
"print('0.', next(it))\n",
"print('1.', next(it))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3aa92f76",
"metadata": {
"id": "3aa92f76"
},
"outputs": [],
"source": [
"for idx, elem in enumerate(dataset2):\n",
" print(f'{idx}. {elem}')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4c0990d0",
"metadata": {
"id": "4c0990d0"
},
"outputs": [],
"source": [
"list(dataset2.as_numpy_iterator())"
]
},
{
"cell_type": "markdown",
"id": "228698d9",
"metadata": {
"id": "228698d9"
},
"source": [
"* ### from_generator"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "045eab63",
"metadata": {
"id": "045eab63"
},
"outputs": [],
"source": [
"# generator function\n",
"def sample(value, target):\n",
" i = 0\n",
" stop = 5\n",
" while i < stop:\n",
" yield (value[i, :], target[i])\n",
" i += 1\n",
"\n",
"\n",
"dataset3 = tf.data.Dataset.from_generator(sample,\n",
"                                           args=(value, target),\n",
"                                           output_types=(tf.float32, tf.int64))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2e9179a8",
"metadata": {
"id": "2e9179a8"
},
"outputs": [],
"source": [
"it = iter(dataset3)\n",
"print('0.', next(it))\n",
"print('1.', next(it))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "87d2ecdf",
"metadata": {
"id": "87d2ecdf"
},
"outputs": [],
"source": [
"list(dataset3.as_numpy_iterator())"
]
},
{
"cell_type": "markdown",
"id": "40022490",
"metadata": {
"id": "40022490"
},
"source": [
"### ZipDataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "14fa4ebb",
"metadata": {
"id": "14fa4ebb"
},
"outputs": [],
"source": [
"x_dataset = tf.data.Dataset.from_tensor_slices(value)\n",
"y_dataset = tf.data.Dataset.from_tensor_slices(target)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fdf6638f",
"metadata": {
"id": "fdf6638f"
},
"outputs": [],
"source": [
"dataset4 = tf.data.Dataset.zip((x_dataset, y_dataset))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bb76987f",
"metadata": {
"id": "bb76987f"
},
"outputs": [],
"source": [
"# zip dataset 每次輸出的是來自兩個資料集的一組樣本 (feature, label)\n",
"for idx, (x, y) in enumerate(dataset4):\n",
"    print(f'{idx}. ({x.numpy()}, {y.numpy()})')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "02b98815",
"metadata": {
"id": "02b98815"
},
"outputs": [],
"source": [
"list(dataset4.as_numpy_iterator())"
]
},
{
"cell_type": "markdown",
"id": "8e6eaf9a",
"metadata": {
"id": "8e6eaf9a"
},
"source": [
"<a name=\"Operation\"></a>\n",
"## Dataset 物件的操作"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9caa4432",
"metadata": {
"id": "9caa4432"
},
"outputs": [],
"source": [
"random_v = tf.random.normal((10, 4))\n",
"dataset = tf.data.Dataset.from_tensor_slices(random_v)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fbc6dfd2",
"metadata": {
"id": "fbc6dfd2"
},
"outputs": [],
"source": [
"for idx, elem in enumerate(dataset):\n",
" print(f'{idx}. {elem.numpy()}')"
]
},
{
"cell_type": "markdown",
"id": "0050c2f5",
"metadata": {
"id": "0050c2f5"
},
"source": [
"* ### shuffle"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6df6cd08",
"metadata": {
"id": "6df6cd08"
},
"outputs": [],
"source": [
"shuffle_dataset = dataset.shuffle(buffer_size=3, reshuffle_each_iteration=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d36bcfcd",
"metadata": {
"id": "d36bcfcd"
},
"outputs": [],
"source": [
"shuffle_dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e813a256",
"metadata": {
"id": "e813a256"
},
"outputs": [],
"source": [
"for idx, elem in enumerate(shuffle_dataset):\n",
" print(f'{idx}. {elem.numpy()}')"
]
},
{
"cell_type": "markdown",
"id": "90a8feb7",
"metadata": {
"id": "90a8feb7"
},
"source": [
"* ### batch"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "af9e03b8",
"metadata": {
"id": "af9e03b8"
},
"outputs": [],
"source": [
"batch_dataset = dataset.batch(batch_size=2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f51d460d",
"metadata": {
"id": "f51d460d"
},
"outputs": [],
"source": [
"for idx, elem in enumerate(batch_dataset):\n",
" print(f'{idx}. {elem.numpy()}')"
]
},
{
"cell_type": "markdown",
"id": "f2010fba",
"metadata": {
"id": "f2010fba"
},
"source": [
"* ### repeat"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5ef2eb07",
"metadata": {
"id": "5ef2eb07"
},
"outputs": [],
"source": [
"repeat_dataset = dataset.repeat(count=2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "df73c366",
"metadata": {
"id": "df73c366"
},
"outputs": [],
"source": [
"for idx, elem in enumerate(repeat_dataset):\n",
" print(f'{idx}. {elem.numpy()}')"
]
},
{
"cell_type": "markdown",
"id": "6392bfcf",
"metadata": {
"id": "6392bfcf"
},
"source": [
"* #### take"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "652e080f",
"metadata": {
"id": "652e080f"
},
"outputs": [],
"source": [
"take_dataset = dataset.take(count=5)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c4596ce1",
"metadata": {
"id": "c4596ce1"
},
"outputs": [],
"source": [
"for idx, elem in enumerate(take_dataset):\n",
" print(f'{idx}. {elem.numpy()}')"
]
},
{
"cell_type": "markdown",
"id": "ed51b2b5",
"metadata": {
"id": "ed51b2b5"
},
"source": [
"* ### prefetch: 在訓練時,同時讀取下一批資料,並做轉換。"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c001a395",
"metadata": {
"id": "c001a395"
},
"outputs": [],
"source": [
"import time\n",
"class ArtificialDataset(tf.data.Dataset):\n",
" def _generator(num_samples):\n",
" # Opening the file\n",
" time.sleep(0.03)\n",
"\n",
" for sample_idx in range(num_samples):\n",
" # Reading data (line, record) from the file\n",
" time.sleep(0.015)\n",
"\n",
" yield (sample_idx,)\n",
"\n",
" def __new__(cls, num_samples=3):\n",
" return tf.data.Dataset.from_generator(\n",
" cls._generator,\n",
" output_signature=tf.TensorSpec(shape=(1,), dtype=tf.int64),\n",
" args=(num_samples,)\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0f822dd6",
"metadata": {
"id": "0f822dd6"
},
"outputs": [],
"source": [
"# 模擬訓練運行時間\n",
"def benchmark(dataset, num_epochs=2):\n",
" start_time = time.perf_counter()\n",
" for epoch_num in range(num_epochs):\n",
" for sample in dataset:\n",
" # Performing a training step\n",
" time.sleep(0.01)\n",
" print(\"Execution time:\", time.perf_counter() - start_time)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "64b246cb",
"metadata": {
"id": "64b246cb"
},
"outputs": [],
"source": [
"benchmark(ArtificialDataset())"
]
},
{
"cell_type": "markdown",
"id": "6540e9f0",
"metadata": {
"id": "6540e9f0"
},
"source": [
"未使用 prefetch 時,每個訓練步驟必須等待資料讀取完成才能開始;加上 prefetch 後,資料讀取會與模型訓練重疊進行,可縮短整體執行時間。"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1209ad64",
"metadata": {
"id": "1209ad64"
},
"outputs": [],
"source": [
"benchmark(\n",
" ArtificialDataset()\n",
" .prefetch(tf.data.AUTOTUNE)\n",
")"
]
},
{
"cell_type": "markdown",
"id": "7ced4471",
"metadata": {
"id": "7ced4471"
},
"source": [
""
]
},
{
"cell_type": "markdown",
"id": "07d83eef",
"metadata": {
"id": "07d83eef"
},
"source": [
"* ### cache: 可將讀出的資料留在快取記憶體,之後重複使用。"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "49b86ec7",
"metadata": {
"id": "49b86ec7"
},
"outputs": [],
"source": [
"benchmark(\n",
" ArtificialDataset()\n",
" .cache()\n",
")"
]
},
{
"cell_type": "markdown",
"id": "63d4c992",
"metadata": {
"id": "63d4c992"
},
"source": [
""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "88158d7b",
"metadata": {
"id": "88158d7b"
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.12"
},
"colab": {
"provenance": []
},
"accelerator": "GPU",
"gpuClass": "standard"
},
"nbformat": 4,
"nbformat_minor": 5
}