diff --git a/examples/vision/supervised_contrastive_learning.ipynb b/examples/vision/supervised_contrastive_learning.ipynb new file mode 100644 index 00000000..7e3715f8 --- /dev/null +++ b/examples/vision/supervised_contrastive_learning.ipynb @@ -0,0 +1,884 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "accelerator": "GPU", + "colab": { + "name": "supervised-contrastive-learning", + "provenance": [], + "collapsed_sections": [], + "toc_visible": true, + "machine_shape": "hm" + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.0" + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "vg1RoGTshaeA" + }, + "source": [ + "# 지도학습 기반의 대조학습\n", + "\n", + "**지은이**: [Khalid Salama](https://www.linkedin.com/in/khalid-salama-24403144/)
\n", + "**옮긴이**: [Janghoo Lee](https://www.linkedin.com/in/janghoo-lee-25212a1a0/)
\n", + "**원본노트북:** [Supervised Contrastive Learning](https://keras.io/examples/vision/supervised-contrastive-learning/)
\n", + "**원본작성일:** 2020/11/30
\n", + "**최종수정일:** 2020/11/30
\n", + "**번역일:** 2021/09/04
\n", + "**번역최종수정일:** 2021/10/11
\n", + "**설명:** 지도학습 기반의 대조학습contrastive learning을 사용해서 이미지 분류 문제를 풀어 봅니다." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "66HEO9Dd9yCd" + }, + "source": [ + "**선행 추천 노트북**
\n", + "\n", + "1. \n", + "Eng : [Semi-supervised image classification using contrastive pretraining with SimCLR](https://keras.io/examples/vision/semisupervised_simclr/)
\n", + "Kor : 준지도학습 기반의 대조적 사전학습 모델 (SimCLR) 을 이용한 이미지 분류\n", + "2.\n", + "Eng : [Image similarity estimation using a Siamese Network with a contrastive loss](https://keras.io/examples/vision/siamese_contrastive/)\n", + "\n", + "
\n", + "\n", + "*이 노트북은 2021 Open Source Contribution Contribution Academy, Keras Korea 의 지원을 받아 제작되었습니다. 한국어로 옮겨진 노트북은 이해를 돕기 위해 원본 노트북에서 제공하는 설명에 대해 추가적인 내용이 들어가 있음을 알립니다. 원문 설명은 [원본 노트북](https://keras.io/examples/vision/simsiam/)을 참고하세요.*" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "UkGY6ko5iDbP" + }, + "source": [ + "## 들어가며\n", + "\n", + "기존에 널리 사용되던 크로스엔트로피 손실 기반의 지도학습 방식을 뛰어넘은 최근 학습 방식이 있습니다. 바로 대조학습contrastive learning 이라는 방법입니다. 대조학습은 주로 자기지도학습self-supervised 중심으로 연구가 진행되어 왔습니다. [MoCo](https://arxiv.org/abs/1911.05722) 나 [SimCLR](https://arxiv.org/abs/2002.05709) 같은 연구들이 이에 속합니다. 하지만 이 노트북에서 소개할 내용은 [지도학습 기반의 대조학습](https://arxiv.org/abs/2004.11362) (Prannay Khosla et al.) 입니다. 이미지 분류모델을 지도학습 기반의 대조학습으로 학습시키려면 아래와 같은 과정을 따라야 합니다.\n", + "\n", + "1. 모델에 입력된 이미지를 잘 표현하는 벡터를 만들어내는 인코더를 학습시킵니다. 이때 인코더가 해야 하는 일은 대략적으로 다음과 같습니다. 범주 A 에 속하는 이미지를 `{a1, a2, a3, ...}` 라고 하고, 범주 B 에 속하는 이미지를 `{b1, b2, b3, ...}` 라고 해 봅시다.\n", + " - `I`. 동일한 범주의 이미지에 대한 표현 벡터 쌍을 (a1, a2), (a1, a3), ... 이라고 하겠습니다.\n", + " - `II`. 다른 범주의 이미지에 대한 표현 벡터 쌍을 (a1, b1), (a1, b2), ... 이라고 하겠습니다.\n", + " - `III`. 이때, 모델은 `I`. 벡터 쌍 `(a1, a2)`, 또는 `(a1, a3)`, ... 등이 가까운 코사인거리를 가지고, `II`. 두 벡터 `(a1, b1)`, 또는 `(a1, b2)`, ... 등이 `I`. 에 비해 상대적으로 높은 코사인거리를 가지도록 학습합니다. \n", + "2. 그 다음, 훈련이 불가능하도록 동결시킨frozen 인코더의 끝단에, 인코더가 생성하는 표현 벡터를 입력받아 이미지의 클래스를 구분해내는 분류기 레이어를 추가로 붙이고 학습시킵니다.\n", + "\n", + "이 노트북의 예제를 실행해 보기 위해서는 [TensorFlow Addons]((https://www.tensorflow.org/addons)) 가 필요합니다. 이 커맨드를 통해 다운받도록 합니다.\n", + "\n", + "```python\n", + "pip install tensorflow-addons\n", + "```" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "sKf_ayfvk71V", + "outputId": "ffc69fbd-e2c3-428c-bd15-3a9ead190ec8" + }, + "source": [ + "# 구글 코랩 (Google COLAB) 환경이라면 아랫줄 코드를 주석 해제한 뒤 셀을 실행합니다.\n", + "!pip install tensorflow-addons" + ], + "execution_count": null, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: tensorflow-addons in /usr/local/lib/python3.7/dist-packages (0.14.0)\n", + "Requirement already satisfied: typeguard>=2.7 in /usr/local/lib/python3.7/dist-packages (from tensorflow-addons) (2.7.1)\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "tJlVOPHhlFRI" + }, + "source": [ + "## 환경 준비하기" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "t8kkxBEchMqz" + }, + "source": [ + "import tensorflow as tf\n", + "import tensorflow_addons as tfa\n", + "import numpy as np\n", + "from tensorflow import keras\n", + "from tensorflow.keras import layers" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "1KiZmZj4lMnA" + }, + "source": [ + "## 데이터 준비하기" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "KGxz2a9lhMq1", + "outputId": "effae69a-0dd5-4c01-e2a2-8a23f34e6f25" + }, + "source": [ + "num_classes = 10\n", + "input_shape = (32, 32, 3)\n", + "\n", + "# 훈련 데이터 세트와 평가 데이터 세트를 로드합니다.\n", + "(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()\n", + "\n", + "# 훈련 데이터 세트와 평가 데이터 세트의 모양을 확인합니다.\n", + "print(f\"x_train shape: {x_train.shape} - y_train shape: {y_train.shape}\")\n", + "print(f\"x_test shape: {x_test.shape} - y_test shape: {y_test.shape}\")" + ], + "execution_count": null, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "x_train shape: (50000, 32, 32, 3) - y_train shape: (50000, 1)\n", + "x_test shape: (10000, 32, 32, 3) - y_test shape: (10000, 1)\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "3Yml2UAblOnb" + }, + "source": [ + "## 이미지 변형 정의하기" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "yN8W0X0EhMq2" + }, + "source": [ + "data_augmentation = keras.Sequential(\n", + " [\n", + " layers.Normalization(),\n", + " layers.RandomFlip(\"horizontal\"),\n", + " layers.RandomRotation(0.02),\n", + " layers.RandomWidth(0.2),\n", + " layers.RandomHeight(0.2),\n", + " ]\n", + ")\n", + "\n", + "# 일부 레이어는 (이 코드에서는 Normalization 레이어) 내부적으로 상태를 가지고 있습니다.\n", + "# 이러한 상태는 데이터세트에 맞게 미리 설정되어야 합니다.\n", + "# https://keras.io/guides/preprocessing_layers/ 을 참고하세요.\n", + "data_augmentation.layers[0].adapt(x_train)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jK3SfiJmlUFX" + }, + "source": [ + "## 인코더 모델 만들기\n", + "\n", + "인코더는 입력 이미지를 받아서 2048 차원의 특징 벡터를 만들어냅니다." + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "fYZuqTLAhMq3", + "outputId": "a0b9d351-c46f-4a86-fe98-5f297101ef31" + }, + "source": [ + "def create_encoder():\n", + " resnet = keras.applications.ResNet50V2(\n", + " include_top=False, weights=None, input_shape=input_shape, pooling=\"avg\"\n", + " )\n", + "\n", + " inputs = keras.Input(shape=input_shape)\n", + " augmented = data_augmentation(inputs)\n", + " outputs = resnet(augmented)\n", + " model = keras.Model(inputs=inputs, outputs=outputs, name=\"cifar10-encoder\")\n", + " return model\n", + "\n", + "\n", + "encoder = create_encoder()\n", + "encoder.summary()" + ], + "execution_count": null, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model: \"cifar10-encoder\"\n", + "_________________________________________________________________\n", + "Layer (type) Output Shape Param # \n", + "=================================================================\n", + "input_10 (InputLayer) [(None, 32, 32, 3)] 0 \n", + "_________________________________________________________________\n", + "sequential_1 (Sequential) (None, None, None, 3) 7 \n", + "_________________________________________________________________\n", + "resnet50v2 (Functional) (None, 2048) 23564800 \n", + "=================================================================\n", + "Total params: 23,564,807\n", + "Trainable params: 23,519,360\n", + "Non-trainable params: 45,447\n", + "_________________________________________________________________\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "yzVII368leub" + }, + "source": [ + "## 분류 모델 만들기\n", + "\n", + "분류 모델은 아까 만들었던 인코더에 완전연결층과 소프트맥스층을 추가적으로 붙여서 완성합니다." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "3Y8B3Z8chMq4" + }, + "source": [ + "def create_classifier(encoder, trainable=True):\n", + "\n", + " for layer in encoder.layers:\n", + " layer.trainable = trainable\n", + "\n", + " inputs = keras.Input(shape=input_shape)\n", + " features = encoder(inputs)\n", + " features = layers.Dropout(dropout_rate)(features)\n", + " features = layers.Dense(hidden_units, activation=\"relu\")(features)\n", + " features = layers.Dropout(dropout_rate)(features)\n", + " outputs = layers.Dense(num_classes, activation=\"softmax\")(features)\n", + "\n", + " model = keras.Model(inputs=inputs, outputs=outputs, name=\"cifar10-classifier\")\n", + " model.compile(\n", + " optimizer=keras.optimizers.Adam(learning_rate),\n", + " loss=keras.losses.SparseCategoricalCrossentropy(),\n", + " metrics=[keras.metrics.SparseCategoricalAccuracy()],\n", + " )\n", + " return model" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "yXgXFGzUlssG" + }, + "source": [ + "## 실험 1: 기준이 되는 분류 모델 만들어보기\n", + "\n", + "기존의 분류 모델에서는 인코더에 분류기를 붙여 크로스엔트로피 손실로 인코더와 분류기 전체를 학습했고, 최근 제안된 지도학습 방식의 대조학습에 따르면 대조학습 방식으로 사전학습된 인코더에 분류기(완전연결층과 소프트맥스층) 를 붙여 완성한다고 했습니다.\n", + "\n", + "이 실험에서는 인코더에 분류기를 붙여 크로스엔트로피 손실로 인코더와 분류기 전체를 학습하는 일반적으로 널리 사용되는 방식의 분류기를 만들 것입니다." + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "5z_TMMOohMq5", + "outputId": "544ec6c0-e96e-4e21-e6c3-e2eebac5c8e2" + }, + "source": [ + "learning_rate = 0.001\n", + "batch_size = 265\n", + "hidden_units = 512\n", + "projection_units = 128\n", + "num_epochs = 50\n", + "dropout_rate = 0.5\n", + "temperature = 0.05\n", + "\n", + "encoder = create_encoder()\n", + "classifier = create_classifier(encoder)\n", + "classifier.summary()\n", + "\n", + "history = classifier.fit(x=x_train, y=y_train, batch_size=batch_size, epochs=num_epochs)\n", + "\n", + "accuracy = classifier.evaluate(x_test, y_test)[1]\n", + "print(f\"Test accuracy: {round(accuracy * 100, 2)}%\")" + ], + "execution_count": null, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model: \"cifar10-classifier\"\n", + "_________________________________________________________________\n", + "Layer (type) Output Shape Param # \n", + "=================================================================\n", + "input_13 (InputLayer) [(None, 32, 32, 3)] 0 \n", + "_________________________________________________________________\n", + "cifar10-encoder (Functional) (None, 2048) 23564807 \n", + "_________________________________________________________________\n", + "dropout_2 (Dropout) (None, 2048) 0 \n", + "_________________________________________________________________\n", + "dense_3 (Dense) (None, 512) 1049088 \n", + "_________________________________________________________________\n", + "dropout_3 (Dropout) (None, 512) 0 \n", + "_________________________________________________________________\n", + "dense_4 (Dense) (None, 10) 5130 \n", + "=================================================================\n", + "Total params: 24,619,025\n", + "Trainable params: 24,573,578\n", + "Non-trainable params: 45,447\n", + "_________________________________________________________________\n", + "Epoch 1/50\n", + "189/189 [==============================] - 21s 86ms/step - loss: 1.9464 - sparse_categorical_accuracy: 0.2857\n", + "Epoch 2/50\n", + "189/189 [==============================] - 16s 85ms/step - loss: 1.5153 - sparse_categorical_accuracy: 0.4535\n", + "Epoch 3/50\n", + "189/189 [==============================] - 15s 82ms/step - loss: 1.4006 - sparse_categorical_accuracy: 0.4939\n", + "Epoch 4/50\n", + "189/189 [==============================] - 15s 82ms/step - loss: 1.2541 - sparse_categorical_accuracy: 0.5567\n", + "Epoch 5/50\n", + "189/189 [==============================] - 15s 81ms/step - loss: 1.1425 - sparse_categorical_accuracy: 0.5990\n", + "Epoch 6/50\n", + "189/189 [==============================] - 15s 81ms/step - loss: 1.0840 - sparse_categorical_accuracy: 0.6243\n", + "Epoch 7/50\n", + "189/189 [==============================] - 15s 81ms/step - loss: 0.9992 - sparse_categorical_accuracy: 0.6570\n", + "Epoch 8/50\n", + "189/189 [==============================] - 16s 82ms/step - loss: 0.9312 - sparse_categorical_accuracy: 0.6831\n", + "Epoch 9/50\n", + "189/189 [==============================] - 15s 81ms/step - loss: 0.8784 - sparse_categorical_accuracy: 0.6970\n", + "Epoch 10/50\n", + "189/189 [==============================] - 16s 82ms/step - loss: 0.8552 - sparse_categorical_accuracy: 0.7068\n", + "Epoch 11/50\n", + "189/189 [==============================] - 16s 82ms/step - loss: 0.7908 - sparse_categorical_accuracy: 0.7318\n", + "Epoch 12/50\n", + "189/189 [==============================] - 15s 81ms/step - loss: 0.7521 - sparse_categorical_accuracy: 0.7428\n", + "Epoch 13/50\n", + "189/189 [==============================] - 15s 82ms/step - loss: 0.7325 - sparse_categorical_accuracy: 0.7500\n", + "Epoch 14/50\n", + "189/189 [==============================] - 15s 81ms/step - loss: 0.7226 - sparse_categorical_accuracy: 0.7552\n", + "Epoch 15/50\n", + "189/189 [==============================] - 16s 83ms/step - loss: 0.7012 - sparse_categorical_accuracy: 0.7621\n", + "Epoch 16/50\n", + "189/189 [==============================] - 15s 80ms/step - loss: 0.7739 - sparse_categorical_accuracy: 0.7394\n", + "Epoch 17/50\n", + "189/189 [==============================] - 15s 81ms/step - loss: 0.6715 - sparse_categorical_accuracy: 0.7725\n", + "Epoch 18/50\n", + "189/189 [==============================] - 15s 80ms/step - loss: 0.6285 - sparse_categorical_accuracy: 0.7889\n", + "Epoch 19/50\n", + "189/189 [==============================] - 15s 82ms/step - loss: 0.5905 - sparse_categorical_accuracy: 0.7993\n", + "Epoch 20/50\n", + "189/189 [==============================] - 15s 81ms/step - loss: 0.5729 - sparse_categorical_accuracy: 0.8084\n", + "Epoch 21/50\n", + "189/189 [==============================] - 15s 81ms/step - loss: 0.5559 - sparse_categorical_accuracy: 0.8130\n", + "Epoch 22/50\n", + "189/189 [==============================] - 15s 81ms/step - loss: 0.6524 - sparse_categorical_accuracy: 0.7797\n", + "Epoch 23/50\n", + "189/189 [==============================] - 15s 81ms/step - loss: 0.5609 - sparse_categorical_accuracy: 0.8115\n", + "Epoch 24/50\n", + "189/189 [==============================] - 15s 82ms/step - loss: 0.5822 - sparse_categorical_accuracy: 0.8041\n", + "Epoch 25/50\n", + "189/189 [==============================] - 15s 80ms/step - loss: 0.5737 - sparse_categorical_accuracy: 0.8063\n", + "Epoch 26/50\n", + "189/189 [==============================] - 15s 81ms/step - loss: 0.5072 - sparse_categorical_accuracy: 0.8285\n", + "Epoch 27/50\n", + "189/189 [==============================] - 15s 80ms/step - loss: 0.4793 - sparse_categorical_accuracy: 0.8378\n", + "Epoch 28/50\n", + "189/189 [==============================] - 15s 80ms/step - loss: 0.4598 - sparse_categorical_accuracy: 0.8430\n", + "Epoch 29/50\n", + "189/189 [==============================] - 15s 81ms/step - loss: 0.4438 - sparse_categorical_accuracy: 0.8498\n", + "Epoch 30/50\n", + "189/189 [==============================] - 15s 80ms/step - loss: 0.4334 - sparse_categorical_accuracy: 0.8511\n", + "Epoch 31/50\n", + "189/189 [==============================] - 15s 81ms/step - loss: 0.4379 - sparse_categorical_accuracy: 0.8508\n", + "Epoch 32/50\n", + "189/189 [==============================] - 15s 80ms/step - loss: 0.4195 - sparse_categorical_accuracy: 0.8568\n", + "Epoch 33/50\n", + "189/189 [==============================] - 16s 82ms/step - loss: 0.3972 - sparse_categorical_accuracy: 0.8628\n", + "Epoch 34/50\n", + "189/189 [==============================] - 15s 81ms/step - loss: 0.7014 - sparse_categorical_accuracy: 0.7644\n", + "Epoch 35/50\n", + "189/189 [==============================] - 15s 80ms/step - loss: 0.5443 - sparse_categorical_accuracy: 0.8155\n", + "Epoch 36/50\n", + "189/189 [==============================] - 15s 81ms/step - loss: 0.4521 - sparse_categorical_accuracy: 0.8470\n", + "Epoch 37/50\n", + "189/189 [==============================] - 16s 82ms/step - loss: 0.4085 - sparse_categorical_accuracy: 0.8590\n", + "Epoch 38/50\n", + "189/189 [==============================] - 15s 81ms/step - loss: 0.3918 - sparse_categorical_accuracy: 0.8661\n", + "Epoch 39/50\n", + "189/189 [==============================] - 15s 80ms/step - loss: 0.5150 - sparse_categorical_accuracy: 0.8283\n", + "Epoch 40/50\n", + "189/189 [==============================] - 16s 82ms/step - loss: 0.4703 - sparse_categorical_accuracy: 0.8420\n", + "Epoch 41/50\n", + "189/189 [==============================] - 16s 82ms/step - loss: 0.4059 - sparse_categorical_accuracy: 0.8643\n", + "Epoch 42/50\n", + "189/189 [==============================] - 15s 80ms/step - loss: 0.3850 - sparse_categorical_accuracy: 0.8682\n", + "Epoch 43/50\n", + "189/189 [==============================] - 15s 79ms/step - loss: 0.3451 - sparse_categorical_accuracy: 0.8816\n", + "Epoch 44/50\n", + "189/189 [==============================] - 15s 82ms/step - loss: 0.3264 - sparse_categorical_accuracy: 0.8863\n", + "Epoch 45/50\n", + "189/189 [==============================] - 15s 81ms/step - loss: 0.3097 - sparse_categorical_accuracy: 0.8934\n", + "Epoch 46/50\n", + "189/189 [==============================] - 15s 81ms/step - loss: 0.3180 - sparse_categorical_accuracy: 0.8911\n", + "Epoch 47/50\n", + "189/189 [==============================] - 16s 82ms/step - loss: 0.8022 - sparse_categorical_accuracy: 0.7336\n", + "Epoch 48/50\n", + "189/189 [==============================] - 15s 80ms/step - loss: 0.5037 - sparse_categorical_accuracy: 0.8293\n", + "Epoch 49/50\n", + "189/189 [==============================] - 15s 81ms/step - loss: 0.3876 - sparse_categorical_accuracy: 0.8684\n", + "Epoch 50/50\n", + "189/189 [==============================] - 15s 81ms/step - loss: 0.3992 - sparse_categorical_accuracy: 0.8646\n", + "313/313 [==============================] - 5s 13ms/step - loss: 0.9051 - sparse_categorical_accuracy: 0.7858\n", + "Test accuracy: 78.58%\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "bUmUKZNzoTKv" + }, + "source": [ + "## 실험 2: 지도학습 방식의 대조학습\n", + "\n", + "이제부터는 모델을 두 단계로 나누어 훈련시킵니다. \n", + "\n", + "1. 첫 단계에는, [Prannay Khosla et al.](https://arxiv.org/abs/2004.11362) 에서 제안한 대로, 인코더가 크로스엔트로피 손실이 아닌, 대조학습 손실을 최적화시키도록 사전학습시킵니다.\n", + "2. 두 번째 단계에는, 사전학습된 인코더를 사용하는 분류기의 가중치만 크로스엔트로피 손실을 최적화하도록 학습시킵니다.\n", + "\n", + "*역주 : 이 노트북에서는 논문이 제안하는 loss 를 정확히 재구현하지는 않고, 지도학습 기반의 대조학습이라는 컨셉트만 유지합니다. 이를 간단히 구현하기 위해 n pair loss 를 사용하는데, 이에 대해서 더 궁금하다면 [이 블로그(영문)](https://towardsdatascience.com/contrasting-contrastive-loss-functions-3c13ca5f055e) 을 참고하세요.* " + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "VYQ4qzQhoS_N" + }, + "source": [ + "### 1. 지도학습 방식의 대조 손실 함수" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "iuqQXAE7hMq6" + }, + "source": [ + "class SupervisedContrastiveLoss(keras.losses.Loss):\n", + " def __init__(self, temperature=1, name=None):\n", + " super(SupervisedContrastiveLoss, self).__init__(name=name)\n", + " self.temperature = temperature\n", + "\n", + " def __call__(self, labels, feature_vectors, sample_weight=None):\n", + " \n", + " # 128 차원으로 투영된 b 개의 특징벡터들 [b, 128]\n", + " feature_vectors_normalized = tf.math.l2_normalize(feature_vectors, axis=1)\n", + " \n", + " # 코사인거리 [b, b]\n", + " cosine_sim = tf.matmul(feature_vectors_normalized, tf.transpose(feature_vectors_normalized)) # [b, 128] * [128, b] = [b, b]\n", + " # 참고 : 코사인거리 i 행 j 열이 나타내는 값은 i 번째 이미지와 j 번째 이미지의 코사인거리를 나타냅니다.\n", + " logits = tf.divide(cosine_sim, self.temperature,)\n", + "\n", + " # npairs_loss(y_true shape : [b,], y_pred shape:[b, b])\n", + " # n pair loss 를 적용합니다.\n", + " return tfa.losses.npairs_loss(tf.squeeze(labels), logits)\n", + "\n", + "\n", + "def add_projection_head(encoder):\n", + " inputs = keras.Input(shape=input_shape)\n", + " features = encoder(inputs) # features : 인코더가 만든 특징 벡터 [b, 2048]\n", + " outputs = layers.Dense(projection_units, activation=\"relu\")(features) # outputs : 투영된 특징 벡터 [b, 128]\n", + " model = keras.Model(\n", + " inputs=inputs, outputs=outputs, name=\"cifar-encoder_with_projection-head\"\n", + " )\n", + " return model" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "eA16iGmPtcAo" + }, + "source": [ + "### 2. 인코더 사전학습" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "background_save": true, + "base_uri": "https://localhost:8080/" + }, + "id": "3UX19YfHhMq7", + "outputId": "5bd61c9a-8592-4801-e2ef-929a9d980edd" + }, + "source": [ + "encoder = create_encoder()\n", + "\n", + "encoder_with_projection_head = add_projection_head(encoder)\n", + "encoder_with_projection_head.compile(\n", + " optimizer=keras.optimizers.Adam(learning_rate),\n", + " loss=SupervisedContrastiveLoss(temperature),\n", + ")\n", + "\n", + "encoder_with_projection_head.summary()\n", + "\n", + "history = encoder_with_projection_head.fit(\n", + " x=x_train, y=y_train, batch_size=batch_size, epochs=num_epochs,\n", + ")" + ], + "execution_count": null, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model: \"cifar-encoder_with_projection-head\"\n", + "_________________________________________________________________\n", + "Layer (type) Output Shape Param # \n", + "=================================================================\n", + "input_16 (InputLayer) [(None, 32, 32, 3)] 0 \n", + "_________________________________________________________________\n", + "cifar10-encoder (Functional) (None, 2048) 23564807 \n", + "_________________________________________________________________\n", + "dense_5 (Dense) (None, 128) 262272 \n", + "=================================================================\n", + "Total params: 23,827,079\n", + "Trainable params: 23,781,632\n", + "Non-trainable params: 45,447\n", + "_________________________________________________________________\n", + "Epoch 1/50\n", + "189/189 [==============================] - 20s 81ms/step - loss: 5.3682\n", + "Epoch 2/50\n", + "189/189 [==============================] - 16s 83ms/step - loss: 5.1349\n", + "Epoch 3/50\n", + "189/189 [==============================] - 15s 82ms/step - loss: 5.0188\n", + "Epoch 4/50\n", + "189/189 [==============================] - 15s 80ms/step - loss: 4.9108\n", + "Epoch 5/50\n", + "189/189 [==============================] - 15s 81ms/step - loss: 4.8326\n", + "Epoch 6/50\n", + "189/189 [==============================] - 15s 82ms/step - loss: 4.7406\n", + "Epoch 7/50\n", + "189/189 [==============================] - 15s 82ms/step - loss: 4.6803\n", + "Epoch 8/50\n", + "189/189 [==============================] - 15s 80ms/step - loss: 4.6140\n", + "Epoch 9/50\n", + "189/189 [==============================] - 15s 80ms/step - loss: 4.5639\n", + "Epoch 10/50\n", + "189/189 [==============================] - 16s 83ms/step - loss: 4.4961\n", + "Epoch 11/50\n", + "189/189 [==============================] - 15s 82ms/step - loss: 4.4604\n", + "Epoch 12/50\n", + "189/189 [==============================] - 15s 82ms/step - loss: 4.4255\n", + "Epoch 13/50\n", + "189/189 [==============================] - 15s 82ms/step - loss: 4.3784\n", + "Epoch 14/50\n", + "189/189 [==============================] - 15s 81ms/step - loss: 4.3428\n", + "Epoch 15/50\n", + "189/189 [==============================] - 16s 82ms/step - loss: 4.3088\n", + "Epoch 16/50\n", + "189/189 [==============================] - 15s 82ms/step - loss: 4.2886\n", + "Epoch 17/50\n", + "189/189 [==============================] - 15s 81ms/step - loss: 4.2633\n", + "Epoch 18/50\n", + "189/189 [==============================] - 15s 81ms/step - loss: 4.2312\n", + "Epoch 19/50\n", + "189/189 [==============================] - 15s 81ms/step - loss: 4.2044\n", + "Epoch 20/50\n", + "189/189 [==============================] - 15s 81ms/step - loss: 4.1854\n", + "Epoch 21/50\n", + "189/189 [==============================] - 15s 80ms/step - loss: 4.1569\n", + "Epoch 22/50\n", + "189/189 [==============================] - 15s 79ms/step - loss: 4.1476\n", + "Epoch 23/50\n", + "189/189 [==============================] - 15s 80ms/step - loss: 4.1111\n", + "Epoch 24/50\n", + "189/189 [==============================] - 15s 82ms/step - loss: 4.0948\n", + "Epoch 25/50\n", + "189/189 [==============================] - 16s 82ms/step - loss: 4.0644\n", + "Epoch 26/50\n", + "189/189 [==============================] - 15s 81ms/step - loss: 4.0529\n", + "Epoch 27/50\n", + "189/189 [==============================] - 15s 81ms/step - loss: 4.0437\n", + "Epoch 28/50\n", + "189/189 [==============================] - 15s 80ms/step - loss: 4.0156\n", + "Epoch 29/50\n", + "189/189 [==============================] - 16s 82ms/step - loss: 3.9909\n", + "Epoch 30/50\n", + "189/189 [==============================] - 15s 82ms/step - loss: 3.9837\n", + "Epoch 31/50\n", + "189/189 [==============================] - 15s 80ms/step - loss: 3.9744\n", + "Epoch 32/50\n", + "189/189 [==============================] - 15s 80ms/step - loss: 3.9566\n", + "Epoch 33/50\n", + "189/189 [==============================] - 15s 81ms/step - loss: 3.9408\n", + "Epoch 34/50\n", + "189/189 [==============================] - 15s 81ms/step - loss: 3.9320\n", + "Epoch 35/50\n", + "189/189 [==============================] - 15s 81ms/step - loss: 3.9232\n", + "Epoch 36/50\n", + "189/189 [==============================] - 15s 80ms/step - loss: 3.8979\n", + "Epoch 37/50\n", + "189/189 [==============================] - 16s 82ms/step - loss: 3.8850\n", + "Epoch 38/50\n", + "189/189 [==============================] - 15s 80ms/step - loss: 3.8749\n", + "Epoch 39/50\n", + "189/189 [==============================] - 15s 80ms/step - loss: 3.8712\n", + "Epoch 40/50\n", + "189/189 [==============================] - 15s 81ms/step - loss: 3.8489\n", + "Epoch 41/50\n", + "189/189 [==============================] - 15s 81ms/step - loss: 3.8374\n", + "Epoch 42/50\n", + "189/189 [==============================] - 15s 81ms/step - loss: 3.8349\n", + "Epoch 43/50\n", + "189/189 [==============================] - 16s 82ms/step - loss: 3.8176\n", + "Epoch 44/50\n", + "189/189 [==============================] - 15s 82ms/step - loss: 3.8166\n", + "Epoch 45/50\n", + "189/189 [==============================] - 16s 82ms/step - loss: 3.7990\n", + "Epoch 46/50\n", + "189/189 [==============================] - 15s 80ms/step - loss: 3.7941\n", + "Epoch 47/50\n", + "189/189 [==============================] - 15s 81ms/step - loss: 3.7782\n", + "Epoch 48/50\n", + "189/189 [==============================] - 15s 81ms/step - loss: 3.7689\n", + "Epoch 49/50\n", + "189/189 [==============================] - 15s 80ms/step - loss: 3.7417\n", + "Epoch 50/50\n", + "189/189 [==============================] - 15s 81ms/step - loss: 3.7507\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Mv9aQmG0ta1z" + }, + "source": [ + "### 3. 훈련되지 않도록 가중치가 고정된 인코더의 결과를 사용하는 분류기 훈련" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "background_save": true + }, + "id": "cJbwVmeMhMq7", + "outputId": "0b23f20f-42f2-423c-ecd0-6a3393f83433" + }, + "source": [ + "classifier = create_classifier(encoder, trainable=False)\n", + "\n", + "history = classifier.fit(x=x_train, y=y_train, batch_size=batch_size, epochs=num_epochs)\n", + "\n", + "accuracy = classifier.evaluate(x_test, y_test)[1]\n", + "print(f\"Test accuracy: {round(accuracy * 100, 2)}%\")" + ], + "execution_count": null, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1/50\n", + "189/189 [==============================] - 7s 26ms/step - loss: 0.3634 - sparse_categorical_accuracy: 0.8994\n", + "Epoch 2/50\n", + "189/189 [==============================] - 5s 26ms/step - loss: 0.3038 - sparse_categorical_accuracy: 0.9100\n", + "Epoch 3/50\n", + "189/189 [==============================] - 5s 26ms/step - loss: 0.2902 - sparse_categorical_accuracy: 0.9115\n", + "Epoch 4/50\n", + "189/189 [==============================] - 5s 26ms/step - loss: 0.2894 - sparse_categorical_accuracy: 0.9093\n", + "Epoch 5/50\n", + "189/189 [==============================] - 5s 26ms/step - loss: 0.2775 - sparse_categorical_accuracy: 0.9151\n", + "Epoch 6/50\n", + "189/189 [==============================] - 5s 26ms/step - loss: 0.2791 - sparse_categorical_accuracy: 0.9124\n", + "Epoch 7/50\n", + "189/189 [==============================] - 5s 26ms/step - loss: 0.2789 - sparse_categorical_accuracy: 0.9134\n", + "Epoch 8/50\n", + "189/189 [==============================] - 5s 26ms/step - loss: 0.2765 - sparse_categorical_accuracy: 0.9134\n", + "Epoch 9/50\n", + "189/189 [==============================] - 5s 26ms/step - loss: 0.2790 - sparse_categorical_accuracy: 0.9135\n", + "Epoch 10/50\n", + "189/189 [==============================] - 5s 26ms/step - loss: 0.2771 - sparse_categorical_accuracy: 0.9130\n", + "Epoch 11/50\n", + "189/189 [==============================] - 5s 26ms/step - loss: 0.2895 - sparse_categorical_accuracy: 0.9116\n", + "Epoch 12/50\n", + "189/189 [==============================] - 5s 26ms/step - loss: 0.2762 - sparse_categorical_accuracy: 0.9132\n", + "Epoch 13/50\n", + "189/189 [==============================] - 5s 26ms/step - loss: 0.2822 - sparse_categorical_accuracy: 0.9125\n", + "Epoch 14/50\n", + "189/189 [==============================] - 5s 26ms/step - loss: 0.2779 - sparse_categorical_accuracy: 0.9126\n", + "Epoch 15/50\n", + "189/189 [==============================] - 5s 26ms/step - loss: 0.2697 - sparse_categorical_accuracy: 0.9155\n", + "Epoch 16/50\n", + "189/189 [==============================] - 5s 26ms/step - loss: 0.2766 - sparse_categorical_accuracy: 0.9129\n", + "Epoch 17/50\n", + "189/189 [==============================] - 5s 26ms/step - loss: 0.2795 - sparse_categorical_accuracy: 0.9128\n", + "Epoch 18/50\n", + "189/189 [==============================] - 5s 26ms/step - loss: 0.2761 - sparse_categorical_accuracy: 0.9134\n", + "Epoch 19/50\n", + "189/189 [==============================] - 5s 26ms/step - loss: 0.2817 - sparse_categorical_accuracy: 0.9113\n", + "Epoch 20/50\n", + "189/189 [==============================] - 5s 26ms/step - loss: 0.2890 - sparse_categorical_accuracy: 0.9113\n", + "Epoch 21/50\n", + "189/189 [==============================] - 5s 26ms/step - loss: 0.2788 - sparse_categorical_accuracy: 0.9123\n", + "Epoch 22/50\n", + "189/189 [==============================] - 5s 25ms/step - loss: 0.2879 - sparse_categorical_accuracy: 0.9089\n", + "Epoch 23/50\n", + "189/189 [==============================] - 5s 26ms/step - loss: 0.2714 - sparse_categorical_accuracy: 0.9151\n", + "Epoch 24/50\n", + "189/189 [==============================] - 5s 25ms/step - loss: 0.2837 - sparse_categorical_accuracy: 0.9101\n", + "Epoch 25/50\n", + "189/189 [==============================] - 5s 26ms/step - loss: 0.2785 - sparse_categorical_accuracy: 0.9127\n", + "Epoch 26/50\n", + "189/189 [==============================] - 5s 26ms/step - loss: 0.2840 - sparse_categorical_accuracy: 0.9102\n", + "Epoch 27/50\n", + "189/189 [==============================] - 5s 26ms/step - loss: 0.2656 - sparse_categorical_accuracy: 0.9177\n", + "Epoch 28/50\n", + "189/189 [==============================] - 5s 26ms/step - loss: 0.2687 - sparse_categorical_accuracy: 0.9150\n", + "Epoch 29/50\n", + "189/189 [==============================] - 5s 26ms/step - loss: 0.2789 - sparse_categorical_accuracy: 0.9121\n", + "Epoch 30/50\n", + "189/189 [==============================] - 5s 26ms/step - loss: 0.2754 - sparse_categorical_accuracy: 0.9135\n", + "Epoch 31/50\n", + "189/189 [==============================] - 5s 26ms/step - loss: 0.2793 - sparse_categorical_accuracy: 0.9123\n", + "Epoch 32/50\n", + "189/189 [==============================] - 5s 26ms/step - loss: 0.2788 - sparse_categorical_accuracy: 0.9132\n", + "Epoch 33/50\n", + "189/189 [==============================] - 5s 26ms/step - loss: 0.2729 - sparse_categorical_accuracy: 0.9136\n", + "Epoch 34/50\n", + "189/189 [==============================] - 5s 26ms/step - loss: 0.2842 - sparse_categorical_accuracy: 0.9108\n", + "Epoch 35/50\n", + "189/189 [==============================] - 5s 26ms/step - loss: 0.2810 - sparse_categorical_accuracy: 0.9103\n", + "Epoch 36/50\n", + "189/189 [==============================] - 5s 26ms/step - loss: 0.2719 - sparse_categorical_accuracy: 0.9137\n", + "Epoch 37/50\n", + "189/189 [==============================] - 5s 26ms/step - loss: 0.2779 - sparse_categorical_accuracy: 0.9126\n", + "Epoch 38/50\n", + "189/189 [==============================] - 5s 26ms/step - loss: 0.2658 - sparse_categorical_accuracy: 0.9157\n", + "Epoch 39/50\n", + "189/189 [==============================] - 5s 26ms/step - loss: 0.2836 - sparse_categorical_accuracy: 0.9118\n", + "Epoch 40/50\n", + "189/189 [==============================] - 5s 26ms/step - loss: 0.2786 - sparse_categorical_accuracy: 0.9123\n", + "Epoch 41/50\n", + "189/189 [==============================] - 5s 26ms/step - loss: 0.2762 - sparse_categorical_accuracy: 0.9125\n", + "Epoch 42/50\n", + "189/189 [==============================] - 5s 26ms/step - loss: 0.2680 - sparse_categorical_accuracy: 0.9150\n", + "Epoch 43/50\n", + "189/189 [==============================] - 5s 26ms/step - loss: 0.2755 - sparse_categorical_accuracy: 0.9121\n", + "Epoch 44/50\n", + "189/189 [==============================] - 5s 26ms/step - loss: 0.2775 - sparse_categorical_accuracy: 0.9116\n", + "Epoch 45/50\n", + "189/189 [==============================] - 5s 26ms/step - loss: 0.2832 - sparse_categorical_accuracy: 0.9102\n", + "Epoch 46/50\n", + "189/189 [==============================] - 5s 26ms/step - loss: 0.2752 - sparse_categorical_accuracy: 0.9138\n", + "Epoch 47/50\n", + "189/189 [==============================] - 5s 26ms/step - loss: 0.2733 - sparse_categorical_accuracy: 0.9134\n", + "Epoch 48/50\n", + "189/189 [==============================] - 5s 26ms/step - loss: 0.2779 - sparse_categorical_accuracy: 0.9137\n", + "Epoch 49/50\n", + "189/189 [==============================] - 5s 26ms/step - loss: 0.2847 - sparse_categorical_accuracy: 0.9107\n", + "Epoch 50/50\n", + "189/189 [==============================] - 5s 26ms/step - loss: 0.2707 - sparse_categorical_accuracy: 0.9139\n", + "313/313 [==============================] - 5s 13ms/step - loss: 0.7402 - sparse_categorical_accuracy: 0.8153\n", + "Test accuracy: 81.53%\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "_dWDES1gtmGE" + }, + "source": [ + "검증 데이터셋 기준으로 실험1 에서 만든 모델을 통해 얻었던 결과보다 더 좋은 성능을 낼 수 있다는 것을 보여줍니다." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "z01wTknStajI" + }, + "source": [ + "## 마치며\n", + "\n", + "실험에서 보았듯, 지도학습 방식의 대조학습을 사용한다면 비슷한 실험조건(훈련 에폭수 등...) 에서 전통적으로 사용되던 성능보다 더 우수한 성능을 얻을 수 있게 됩니다. 대조학습은 이 노트북에서 사용한 모델보다 훈련이 더 어렵고 복잡한 아키텍처에서도 잘 동작하고, 다중 클래스 분류같이 단순 이미지 분류보다 확장되고 복잡해진 작업에서도 잘 동작합니다.\n", + "\n", + "이 훈련 방법을 더 효율적으로 사용하기 위해서는 배치사이즈를 늘리고, 분류기를 더 깊게 쌓는 것이 도움이 될 수 있습니다. 더 자세한 내용이 궁금하다면, 논문 [Supervised Contrastive Learning](https://arxiv.org/abs/2004.11362) 을 참고하세요.\n", + "\n" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "background_save": true + }, + "id": "SPowTOMmI_A7" + }, + "source": [ + "" + ], + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file diff --git a/examples/vision/supervised_contrastive_learning.md b/examples/vision/supervised_contrastive_learning.md new file mode 100644 index 00000000..8bd23d5b --- /dev/null +++ b/examples/vision/supervised_contrastive_learning.md @@ -0,0 +1,16 @@ +# Supervised Contrastive Learning + +## 관련 파일 + +- `examples/vision/` + - `supervised_contrastive_learning.ipynb` : 작업된 노트북 + - `supervised_contrastive_learning.md` : 작업 기록 + +## 작업 내역 + +### 2021/10/11 + +- [Supervised Contrastive Learning](https://keras.io/examples/vision/supervised-contrastive-learning/) 를 기반으로 번역 +- 설명을 조금 더 다듬음 +- 일부 코드를 보기 좋게 정리 +- 선행 추천 노트북을 추가함 \ No newline at end of file