# -*- coding: utf-8 -*-
"""
TorchMultimodal Tutorial: Finetuning FLAVA
============================================

**Translated by:** `chanmuzi <https://github.com/chanmuzi>`__

"""
######################################################################
# Multimodal AI has recently become very popular owing to its ubiquitous
# nature, from use cases such as image captioning and visual search to more
# recent applications such as image generation from text. **TorchMultimodal
# is a library powered by PyTorch consisting of building blocks and end-to-end
# examples, aiming to enable and accelerate research in multimodality**.
#
# In this tutorial, we will demonstrate how to use a **pretrained SoTA
# model called** `FLAVA <https://arxiv.org/pdf/2112.04482.pdf>`__ **from
# the TorchMultimodal library to finetune on a multimodal task, namely visual
# question answering** (VQA). The model consists of two unimodal transformer-based
# encoders for text and image and a multimodal encoder that combines
# the two embeddings. It is pretrained using contrastive, image-text matching, and
# text, image, and multimodal masking losses.

######################################################################
# Installation
# -----------------
# We will use the TextVQA dataset and the ``bert`` tokenizer from Hugging Face for this
# tutorial. So you need to install ``datasets`` and ``transformers`` in addition to TorchMultimodal.
#
# .. note::
#
#    When running this tutorial in Google Colab, install the required packages by
#    creating a new cell and running the following commands:
#
# .. code-block::
#
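#    # install commands assumed from the packages named above
#    !pip install torchmultimodal-nightly
#    !pip install datasets
#    !pip install transformers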
#
######################################################################
# Steps
# -----
#
# 1. Download the Hugging Face dataset to a directory on your computer by running the following command:
#
# .. code-block::
#
#    wget http://dl.fbaipublicfiles.com/pythia/data/vocab.tar.gz
#    tar xf vocab.tar.gz
#
# .. note::
#    If you are running this tutorial in Google Colab, run these commands
#    in a new cell and prepend these commands with an exclamation mark (!).
#
#
# 2. For this tutorial, we treat VQA as a classification task where
# the inputs are images and questions (text) and the output is an answer class.
# So we need to download the vocab file with answer classes and create the
# answer-to-label mapping.
#
# We also load the `textvqa
# dataset <https://arxiv.org/pdf/1904.08920.pdf>`__ containing 34602 training samples
# (images, questions, and answers) from Hugging Face.
#
# We see there are 3997 answer classes including a class representing
# unknown answers.
#
from datasets import load_dataset

with open("data/vocabs/answers_textvqa_more_than_1.txt") as f:
    vocab = [x.strip() for x in f.readlines()]

dataset = load_dataset("textvqa")
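
######################################################################
# As a brief sketch of the answer-to-label mapping described in step 2, we can
# index the vocab read above (the ``answer_to_idx`` name is our own choice):

answer_to_idx = {answer: idx for idx, answer in enumerate(vocab)}
print(f"Number of answer classes: {len(vocab)}")  # 3997, including the unknown-answer class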
######################################################################
# Let's display a sample entry from the dataset:
#
import matplotlib.pyplot as plt
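
# A short sketch of inspecting one entry; the index 5 and the (500, 500)
# display size are arbitrary choices, and the column names assume the standard
# ``textvqa`` schema (``question``, ``answers``, ``image``).
import numpy as np

idx = 5
print("Question: ", dataset["train"][idx]["question"])
print("Answers: ", dataset["train"][idx]["answers"])
im = np.asarray(dataset["train"][idx]["image"].resize((500, 500)))
plt.imshow(im)
plt.show()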
######################################################################
# 3. Next, we write the transform function to convert the image and text into
# Tensors consumable by our model:
#
# - For images, we use the transforms from torchvision to convert to a Tensor
#   and resize to a uniform size.
# - For text, we tokenize (and pad) the question using the ``BertTokenizer``
#   from Hugging Face.
# - For answers (that is, labels), we take the most frequently occurring answer
#   as the label to train with:
#
import torch
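
######################################################################
# Below is a minimal sketch of such a transform (a sketch, not necessarily the
# tutorial's exact code): the 224x224 resize, the max length of 512, the
# fallback to class 0 for answers outside the vocab, and the returned key
# names are all assumptions, and ``answer_to_idx`` is the mapping sketched
# earlier.

from collections import Counter

import torchvision.transforms as T
from transformers import BertTokenizer


def transform(tokenizer, sample):
    # image: PIL image -> Tensor, resized to a uniform size
    image_transform = T.Compose([T.ToTensor(), T.Resize([224, 224])])
    image = image_transform(sample["image"].convert("RGB"))

    # text: tokenize and pad the question
    tokens = tokenizer(sample["question"], padding="max_length",
                       max_length=512, truncation=True, return_tensors="pt")

    # label: the most frequently occurring answer, mapped to its class index
    most_common_answer = Counter(sample["answers"]).most_common(1)[0][0]
    label = torch.as_tensor(answer_to_idx.get(most_common_answer, 0))

    return {"input_ids": tokens["input_ids"].squeeze(0),
            "image": image,
            "answers": label}


tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")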
######################################################################
# 4. Finally, we import the ``flava_model_for_classification`` from
# ``torchmultimodal``. It loads the pretrained FLAVA checkpoint by default and
# includes a classification head.
#
# The model forward function passes the image through the visual encoder
# and the question through the text encoder. The image and question
# embeddings are then passed through the multimodal encoder. The final
# embedding corresponding to the CLS token is passed through an MLP head
# which finally gives the probability distribution over each possible
# answer.
#
from torchmultimodal.models.flava.model import flava_model_for_classification
model = flava_model_for_classification(num_classes=len(vocab))
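
######################################################################
# As a quick smoke test of the forward pass described above, we can feed the
# model dummy inputs; the tensor shapes and the ``text``/``image``/``labels``
# keyword names, as well as the ``loss`` output field, are assumptions here.

dummy_text = torch.randint(0, 30522, (2, 512))   # two "tokenized questions"
dummy_image = torch.randn(2, 3, 224, 224)        # two random RGB images
dummy_labels = torch.tensor([0, 1])              # two answer-class indices
out = model(text=dummy_text, image=dummy_image, labels=dummy_labels)
print(out.loss)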
######################################################################
# 5. We put together the dataset and model in a toy training loop to
# demonstrate how to train the model for 3 iterations:
#
from torch import nn
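
######################################################################
# The loop below is a toy sketch of step 5 under a few assumptions: batches
# are built by hand with the ``transform`` sketched earlier, the model accepts
# ``text``, ``image`` and ``labels`` keyword arguments, and its output exposes
# a ``loss`` field.

BATCH_SIZE = 2
MAX_STEPS = 3

optimizer = torch.optim.AdamW(model.parameters())

model.train()
for step in range(MAX_STEPS):
    # build a tiny batch by stacking transformed samples
    samples = [transform(tokenizer, dataset["train"][step * BATCH_SIZE + i])
               for i in range(BATCH_SIZE)]
    batch = {key: torch.stack([sample[key] for sample in samples])
             for key in samples[0]}

    optimizer.zero_grad()
    out = model(text=batch["input_ids"], image=batch["image"], labels=batch["answers"])
    loss = out.loss
    loss.backward()
    optimizer.step()
    print(f"Loss at step {step} = {loss}")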
######################################################################
# Conclusion
# -------------------
#
# This tutorial introduced the basics around how to finetune on a
# multimodal task using FLAVA from TorchMultimodal. Please also check out
# other examples from the library like
# `MDETR <https://github.com/facebookresearch/multimodal/tree/main/torchmultimodal/models/mdetr>`__,
# which is a multimodal model for object detection, and
# `Omnivore <https://github.com/facebookresearch/multimodal/blob/main/torchmultimodal/models/omnivore.py>`__,
# which is a multitask model spanning image, video, and 3D classification.
#