Keras Transformer
2021.12.13 更新
2010年代ごろから始まったAIブーム以来、画像認識といえばCNN(Convolutional neural network, 畳み込み)という状況が長くつづいてきましたが、2020年になって、とうとうCNNを一切使わないモデルがSoTAを叩き出したことは驚きのニュースでした。Vision Transformer(ViT)です。
今回は、Vision Transformer(ViT)について調べてみました。
ViTとは、2020年にGoogleが発表した画像認識モデルです。自然言語解析の分野では、BERTやGPT-3などがSelf-Attention構造にもとづいたTransformerを応用したモデルとしてそれまでのRNNに取って替わりましたが、それを画像認識タスクにも応用したものです。
ちなみにAttentionとは、データのどこに注意を向けるべきかを学習するための機構です。
出典: https://arxiv.org/abs/2010.11929
ViTの構造は、BERTと同様にTransformerのEncoder部分と同じです。BERTでは各単語のベクトル表現をEncoderに入力していましたが、ViTでは画像を小さなパッチに分割してベクトル化したものを入力します。
今回もKeras Code Exampleからサンプルを借りて動作を見てみました。
Image classification with Vision Transformer
https://keras.io/examples/vision/image_classification_with_vision_transformer/
使用するデータセットはCIFAR-100です。CIFAR-100は32×32の画像が60000枚あるデータセットで、100のカテゴリからなります。画像をパッチに分割すると以下のようなイメージになります。
Google Colaboratryにサンプスソースを準備して、初期値から学習をスタートしました。
学習自体は正常に動作させることができましたが、処理にかかる時間が長く、1エポックあたり約45分かかりました。サンプルの確認として本来は100エポックまで学習させたいのですが、Google Colabの時間制限にひっかかるため、今回はここで断念してKeras Code Exampleの検証結果をそのまま見たいと思います。
※弊社のGPUパソコンに空きがあるときに再度試したいと思います。
Epoch 1/100 176/176 [==============================] - 33s 136ms/step - loss: 4.8863 - accuracy: 0.0294 - top-5-accuracy: 0.1117 - val_loss: 3.9661 - val_accuracy: 0.0992 - val_top-5-accuracy: 0.3056 Epoch 2/100 176/176 [==============================] - 22s 127ms/step - loss: 4.0162 - accuracy: 0.0865 - top-5-accuracy: 0.2683 - val_loss: 3.5691 - val_accuracy: 0.1630 - val_top-5-accuracy: 0.4226 Epoch 3/100 176/176 [==============================] - 22s 127ms/step - loss: 3.7313 - accuracy: 0.1254 - top-5-accuracy: 0.3535 - val_loss: 3.3455 - val_accuracy: 0.1976 - val_top-5-accuracy: 0.4756 Epoch 4/100 176/176 [==============================] - 23s 128ms/step - loss: 3.5411 - accuracy: 0.1541 - top-5-accuracy: 0.4121 - val_loss: 3.1925 - val_accuracy: 0.2274 - val_top-5-accuracy: 0.5126 Epoch 5/100 176/176 [==============================] - 22s 127ms/step - loss: 3.3749 - accuracy: 0.1847 - top-5-accuracy: 0.4572 - val_loss: 3.1043 - val_accuracy: 0.2388 - val_top-5-accuracy: 0.5320 Epoch 6/100 176/176 [==============================] - 22s 127ms/step - loss: 3.2589 - accuracy: 0.2057 - top-5-accuracy: 0.4906 - val_loss: 2.9319 - val_accuracy: 0.2782 - val_top-5-accuracy: 0.5756 Epoch 7/100 176/176 [==============================] - 22s 127ms/step - loss: 3.1165 - accuracy: 0.2331 - top-5-accuracy: 0.5273 - val_loss: 2.8072 - val_accuracy: 0.2972 - val_top-5-accuracy: 0.5946 Epoch 8/100 176/176 [==============================] - 22s 127ms/step - loss: 2.9902 - accuracy: 0.2563 - top-5-accuracy: 0.5556 - val_loss: 2.7207 - val_accuracy: 0.3188 - val_top-5-accuracy: 0.6258 Epoch 9/100 176/176 [==============================] - 22s 127ms/step - loss: 2.8828 - accuracy: 0.2800 - top-5-accuracy: 0.5827 - val_loss: 2.6396 - val_accuracy: 0.3244 - val_top-5-accuracy: 0.6402 Epoch 10/100 176/176 [==============================] - 23s 128ms/step - loss: 2.7824 - accuracy: 0.2997 - top-5-accuracy: 0.6110 - val_loss: 2.5580 - val_accuracy: 0.3494 - val_top-5-accuracy: 0.6568 Epoch 11/100 176/176 [==============================] - 23s 130ms/step - loss: 2.6743 - accuracy: 0.3209 - top-5-accuracy: 0.6333 - val_loss: 2.5000 - val_accuracy: 0.3594 - val_top-5-accuracy: 0.6726 Epoch 12/100 176/176 [==============================] - 23s 130ms/step - loss: 2.5800 - accuracy: 0.3431 - top-5-accuracy: 0.6522 - val_loss: 2.3900 - val_accuracy: 0.3798 - val_top-5-accuracy: 0.6878 Epoch 13/100 176/176 [==============================] - 23s 128ms/step - loss: 2.5019 - accuracy: 0.3559 - top-5-accuracy: 0.6671 - val_loss: 2.3464 - val_accuracy: 0.3960 - val_top-5-accuracy: 0.7002 Epoch 14/100 176/176 [==============================] - 22s 128ms/step - loss: 2.4207 - accuracy: 0.3728 - top-5-accuracy: 0.6905 - val_loss: 2.3130 - val_accuracy: 0.4032 - val_top-5-accuracy: 0.7040 Epoch 15/100 176/176 [==============================] - 23s 128ms/step - loss: 2.3371 - accuracy: 0.3932 - top-5-accuracy: 0.7093 - val_loss: 2.2447 - val_accuracy: 0.4136 - val_top-5-accuracy: 0.7202 Epoch 16/100 176/176 [==============================] - 23s 128ms/step - loss: 2.2650 - accuracy: 0.4077 - top-5-accuracy: 0.7201 - val_loss: 2.2101 - val_accuracy: 0.4222 - val_top-5-accuracy: 0.7246 Epoch 17/100 176/176 [==============================] - 22s 127ms/step - loss: 2.1822 - accuracy: 0.4204 - top-5-accuracy: 0.7376 - val_loss: 2.1446 - val_accuracy: 0.4344 - val_top-5-accuracy: 0.7416 Epoch 18/100 176/176 [==============================] - 22s 128ms/step - loss: 2.1485 - accuracy: 0.4284 - top-5-accuracy: 0.7476 - val_loss: 2.1094 - val_accuracy: 0.4432 - val_top-5-accuracy: 0.7454 Epoch 19/100 176/176 [==============================] - 22s 128ms/step - loss: 2.0717 - accuracy: 0.4464 - top-5-accuracy: 0.7618 - val_loss: 2.0718 - val_accuracy: 0.4584 - val_top-5-accuracy: 0.7570 Epoch 20/100 176/176 [==============================] - 22s 127ms/step - loss: 2.0031 - accuracy: 0.4605 - top-5-accuracy: 0.7731 - val_loss: 2.0286 - val_accuracy: 0.4610 - val_top-5-accuracy: 0.7654 Epoch 21/100 176/176 [==============================] - 22s 127ms/step - loss: 1.9650 - accuracy: 0.4700 - top-5-accuracy: 0.7820 - val_loss: 2.0225 - val_accuracy: 0.4642 - val_top-5-accuracy: 0.7628 Epoch 22/100 176/176 [==============================] - 22s 127ms/step - loss: 1.9066 - accuracy: 0.4839 - top-5-accuracy: 0.7904 - val_loss: 1.9961 - val_accuracy: 0.4746 - val_top-5-accuracy: 0.7656 Epoch 23/100 176/176 [==============================] - 22s 127ms/step - loss: 1.8564 - accuracy: 0.4952 - top-5-accuracy: 0.8030 - val_loss: 1.9769 - val_accuracy: 0.4828 - val_top-5-accuracy: 0.7742 Epoch 24/100 176/176 [==============================] - 22s 128ms/step - loss: 1.8167 - accuracy: 0.5034 - top-5-accuracy: 0.8099 - val_loss: 1.9730 - val_accuracy: 0.4766 - val_top-5-accuracy: 0.7728 Epoch 25/100 176/176 [==============================] - 22s 128ms/step - loss: 1.7788 - accuracy: 0.5124 - top-5-accuracy: 0.8174 - val_loss: 1.9187 - val_accuracy: 0.4926 - val_top-5-accuracy: 0.7854 Epoch 26/100 176/176 [==============================] - 23s 128ms/step - loss: 1.7437 - accuracy: 0.5187 - top-5-accuracy: 0.8206 - val_loss: 1.9732 - val_accuracy: 0.4792 - val_top-5-accuracy: 0.7772 Epoch 27/100 176/176 [==============================] - 23s 128ms/step - loss: 1.6929 - accuracy: 0.5300 - top-5-accuracy: 0.8287 - val_loss: 1.9109 - val_accuracy: 0.4928 - val_top-5-accuracy: 0.7912 Epoch 28/100 176/176 [==============================] - 23s 129ms/step - loss: 1.6647 - accuracy: 0.5400 - top-5-accuracy: 0.8362 - val_loss: 1.9031 - val_accuracy: 0.4984 - val_top-5-accuracy: 0.7824 Epoch 29/100 176/176 [==============================] - 23s 129ms/step - loss: 1.6295 - accuracy: 0.5488 - top-5-accuracy: 0.8402 - val_loss: 1.8744 - val_accuracy: 0.4982 - val_top-5-accuracy: 0.7910 Epoch 30/100 176/176 [==============================] - 22s 128ms/step - loss: 1.5860 - accuracy: 0.5548 - top-5-accuracy: 0.8504 - val_loss: 1.8551 - val_accuracy: 0.5108 - val_top-5-accuracy: 0.7946 Epoch 31/100 176/176 [==============================] - 22s 127ms/step - loss: 1.5666 - accuracy: 0.5614 - top-5-accuracy: 0.8548 - val_loss: 1.8720 - val_accuracy: 0.5076 - val_top-5-accuracy: 0.7960 Epoch 32/100 176/176 [==============================] - 22s 127ms/step - loss: 1.5272 - accuracy: 0.5712 - top-5-accuracy: 0.8596 - val_loss: 1.8840 - val_accuracy: 0.5106 - val_top-5-accuracy: 0.7966 Epoch 33/100 176/176 [==============================] - 22s 128ms/step - loss: 1.4995 - accuracy: 0.5779 - top-5-accuracy: 0.8651 - val_loss: 1.8660 - val_accuracy: 0.5116 - val_top-5-accuracy: 0.7904 Epoch 34/100 176/176 [==============================] - 22s 128ms/step - loss: 1.4686 - accuracy: 0.5849 - top-5-accuracy: 0.8685 - val_loss: 1.8544 - val_accuracy: 0.5126 - val_top-5-accuracy: 0.7954 Epoch 35/100 176/176 [==============================] - 22s 127ms/step - loss: 1.4276 - accuracy: 0.5992 - top-5-accuracy: 0.8743 - val_loss: 1.8497 - val_accuracy: 0.5164 - val_top-5-accuracy: 0.7990 Epoch 36/100 176/176 [==============================] - 22s 127ms/step - loss: 1.4102 - accuracy: 0.5970 - top-5-accuracy: 0.8768 - val_loss: 1.8496 - val_accuracy: 0.5198 - val_top-5-accuracy: 0.7948 Epoch 37/100 176/176 [==============================] - 22s 126ms/step - loss: 1.3800 - accuracy: 0.6112 - top-5-accuracy: 0.8814 - val_loss: 1.8033 - val_accuracy: 0.5284 - val_top-5-accuracy: 0.8068 Epoch 38/100 176/176 [==============================] - 22s 126ms/step - loss: 1.3500 - accuracy: 0.6103 - top-5-accuracy: 0.8862 - val_loss: 1.8092 - val_accuracy: 0.5214 - val_top-5-accuracy: 0.8128 Epoch 39/100 176/176 [==============================] - 22s 127ms/step - loss: 1.3575 - accuracy: 0.6127 - top-5-accuracy: 0.8857 - val_loss: 1.8175 - val_accuracy: 0.5198 - val_top-5-accuracy: 0.8086 Epoch 40/100 176/176 [==============================] - 22s 126ms/step - loss: 1.3030 - accuracy: 0.6283 - top-5-accuracy: 0.8927 - val_loss: 1.8361 - val_accuracy: 0.5170 - val_top-5-accuracy: 0.8056 Epoch 41/100 176/176 [==============================] - 22s 125ms/step - loss: 1.3160 - accuracy: 0.6247 - top-5-accuracy: 0.8923 - val_loss: 1.8074 - val_accuracy: 0.5260 - val_top-5-accuracy: 0.8082 Epoch 42/100 176/176 [==============================] - 22s 126ms/step - loss: 1.2679 - accuracy: 0.6329 - top-5-accuracy: 0.9002 - val_loss: 1.8430 - val_accuracy: 0.5244 - val_top-5-accuracy: 0.8100 Epoch 43/100 176/176 [==============================] - 22s 126ms/step - loss: 1.2514 - accuracy: 0.6375 - top-5-accuracy: 0.9034 - val_loss: 1.8318 - val_accuracy: 0.5196 - val_top-5-accuracy: 0.8034 Epoch 44/100 176/176 [==============================] - 22s 126ms/step - loss: 1.2311 - accuracy: 0.6431 - top-5-accuracy: 0.9067 - val_loss: 1.8283 - val_accuracy: 0.5218 - val_top-5-accuracy: 0.8050 Epoch 45/100 176/176 [==============================] - 22s 125ms/step - loss: 1.2073 - accuracy: 0.6484 - top-5-accuracy: 0.9098 - val_loss: 1.8384 - val_accuracy: 0.5302 - val_top-5-accuracy: 0.8056 Epoch 46/100 176/176 [==============================] - 22s 125ms/step - loss: 1.1775 - accuracy: 0.6558 - top-5-accuracy: 0.9117 - val_loss: 1.8409 - val_accuracy: 0.5294 - val_top-5-accuracy: 0.8078 Epoch 47/100 176/176 [==============================] - 22s 126ms/step - loss: 1.1891 - accuracy: 0.6563 - top-5-accuracy: 0.9103 - val_loss: 1.8167 - val_accuracy: 0.5346 - val_top-5-accuracy: 0.8142 Epoch 48/100 176/176 [==============================] - 22s 127ms/step - loss: 1.1586 - accuracy: 0.6621 - top-5-accuracy: 0.9161 - val_loss: 1.8285 - val_accuracy: 0.5314 - val_top-5-accuracy: 0.8086 Epoch 49/100 176/176 [==============================] - 22s 126ms/step - loss: 1.1586 - accuracy: 0.6634 - top-5-accuracy: 0.9154 - val_loss: 1.8189 - val_accuracy: 0.5366 - val_top-5-accuracy: 0.8134 Epoch 50/100 176/176 [==============================] - 22s 126ms/step - loss: 1.1306 - accuracy: 0.6682 - top-5-accuracy: 0.9199 - val_loss: 1.8442 - val_accuracy: 0.5254 - val_top-5-accuracy: 0.8096 Epoch 51/100 176/176 [==============================] - 22s 126ms/step - loss: 1.1175 - accuracy: 0.6708 - top-5-accuracy: 0.9227 - val_loss: 1.8513 - val_accuracy: 0.5230 - val_top-5-accuracy: 0.8104 Epoch 52/100 176/176 [==============================] - 22s 126ms/step - loss: 1.1104 - accuracy: 0.6743 - top-5-accuracy: 0.9226 - val_loss: 1.8041 - val_accuracy: 0.5332 - val_top-5-accuracy: 0.8142 Epoch 53/100 176/176 [==============================] - 22s 127ms/step - loss: 1.0914 - accuracy: 0.6809 - top-5-accuracy: 0.9236 - val_loss: 1.8213 - val_accuracy: 0.5342 - val_top-5-accuracy: 0.8094 Epoch 54/100 176/176 [==============================] - 22s 126ms/step - loss: 1.0681 - accuracy: 0.6856 - top-5-accuracy: 0.9270 - val_loss: 1.8429 - val_accuracy: 0.5328 - val_top-5-accuracy: 0.8086 Epoch 55/100 176/176 [==============================] - 22s 126ms/step - loss: 1.0625 - accuracy: 0.6862 - top-5-accuracy: 0.9301 - val_loss: 1.8316 - val_accuracy: 0.5364 - val_top-5-accuracy: 0.8090 Epoch 56/100 176/176 [==============================] - 22s 127ms/step - loss: 1.0474 - accuracy: 0.6920 - top-5-accuracy: 0.9308 - val_loss: 1.8310 - val_accuracy: 0.5440 - val_top-5-accuracy: 0.8132 Epoch 57/100 176/176 [==============================] - 22s 127ms/step - loss: 1.0381 - accuracy: 0.6974 - top-5-accuracy: 0.9297 - val_loss: 1.8447 - val_accuracy: 0.5368 - val_top-5-accuracy: 0.8126 Epoch 58/100 176/176 [==============================] - 22s 126ms/step - loss: 1.0230 - accuracy: 0.7011 - top-5-accuracy: 0.9341 - val_loss: 1.8241 - val_accuracy: 0.5418 - val_top-5-accuracy: 0.8094 Epoch 59/100 176/176 [==============================] - 22s 127ms/step - loss: 1.0113 - accuracy: 0.7023 - top-5-accuracy: 0.9361 - val_loss: 1.8216 - val_accuracy: 0.5380 - val_top-5-accuracy: 0.8134 Epoch 60/100 176/176 [==============================] - 22s 126ms/step - loss: 0.9953 - accuracy: 0.7031 - top-5-accuracy: 0.9386 - val_loss: 1.8356 - val_accuracy: 0.5422 - val_top-5-accuracy: 0.8122 Epoch 61/100 176/176 [==============================] - 22s 126ms/step - loss: 0.9928 - accuracy: 0.7084 - top-5-accuracy: 0.9375 - val_loss: 1.8514 - val_accuracy: 0.5342 - val_top-5-accuracy: 0.8182 Epoch 62/100 176/176 [==============================] - 22s 126ms/step - loss: 0.9740 - accuracy: 0.7121 - top-5-accuracy: 0.9387 - val_loss: 1.8674 - val_accuracy: 0.5366 - val_top-5-accuracy: 0.8092 Epoch 63/100 176/176 [==============================] - 22s 126ms/step - loss: 0.9742 - accuracy: 0.7112 - top-5-accuracy: 0.9413 - val_loss: 1.8274 - val_accuracy: 0.5414 - val_top-5-accuracy: 0.8144 Epoch 64/100 176/176 [==============================] - 22s 126ms/step - loss: 0.9633 - accuracy: 0.7147 - top-5-accuracy: 0.9393 - val_loss: 1.8250 - val_accuracy: 0.5434 - val_top-5-accuracy: 0.8180 Epoch 65/100 176/176 [==============================] - 22s 126ms/step - loss: 0.9407 - accuracy: 0.7221 - top-5-accuracy: 0.9444 - val_loss: 1.8456 - val_accuracy: 0.5424 - val_top-5-accuracy: 0.8120 Epoch 66/100 176/176 [==============================] - 22s 126ms/step - loss: 0.9410 - accuracy: 0.7194 - top-5-accuracy: 0.9447 - val_loss: 1.8559 - val_accuracy: 0.5460 - val_top-5-accuracy: 0.8144 Epoch 67/100 176/176 [==============================] - 22s 126ms/step - loss: 0.9359 - accuracy: 0.7252 - top-5-accuracy: 0.9421 - val_loss: 1.8352 - val_accuracy: 0.5458 - val_top-5-accuracy: 0.8110 Epoch 68/100 176/176 [==============================] - 22s 126ms/step - loss: 0.9232 - accuracy: 0.7254 - top-5-accuracy: 0.9460 - val_loss: 1.8479 - val_accuracy: 0.5444 - val_top-5-accuracy: 0.8132 Epoch 69/100 176/176 [==============================] - 22s 126ms/step - loss: 0.9138 - accuracy: 0.7283 - top-5-accuracy: 0.9456 - val_loss: 1.8697 - val_accuracy: 0.5312 - val_top-5-accuracy: 0.8052 Epoch 70/100 176/176 [==============================] - 22s 126ms/step - loss: 0.9095 - accuracy: 0.7295 - top-5-accuracy: 0.9478 - val_loss: 1.8550 - val_accuracy: 0.5376 - val_top-5-accuracy: 0.8170 Epoch 71/100 176/176 [==============================] - 22s 126ms/step - loss: 0.8945 - accuracy: 0.7332 - top-5-accuracy: 0.9504 - val_loss: 1.8286 - val_accuracy: 0.5436 - val_top-5-accuracy: 0.8198 Epoch 72/100 176/176 [==============================] - 22s 125ms/step - loss: 0.8936 - accuracy: 0.7344 - top-5-accuracy: 0.9479 - val_loss: 1.8727 - val_accuracy: 0.5438 - val_top-5-accuracy: 0.8182 Epoch 73/100 176/176 [==============================] - 22s 126ms/step - loss: 0.8775 - accuracy: 0.7355 - top-5-accuracy: 0.9510 - val_loss: 1.8522 - val_accuracy: 0.5404 - val_top-5-accuracy: 0.8170 Epoch 74/100 176/176 [==============================] - 22s 126ms/step - loss: 0.8660 - accuracy: 0.7390 - top-5-accuracy: 0.9513 - val_loss: 1.8432 - val_accuracy: 0.5448 - val_top-5-accuracy: 0.8156 Epoch 75/100 176/176 [==============================] - 22s 126ms/step - loss: 0.8583 - accuracy: 0.7441 - top-5-accuracy: 0.9532 - val_loss: 1.8419 - val_accuracy: 0.5462 - val_top-5-accuracy: 0.8226 Epoch 76/100 176/176 [==============================] - 22s 126ms/step - loss: 0.8549 - accuracy: 0.7443 - top-5-accuracy: 0.9529 - val_loss: 1.8757 - val_accuracy: 0.5454 - val_top-5-accuracy: 0.8086 Epoch 77/100 176/176 [==============================] - 22s 125ms/step - loss: 0.8578 - accuracy: 0.7384 - top-5-accuracy: 0.9531 - val_loss: 1.9051 - val_accuracy: 0.5462 - val_top-5-accuracy: 0.8136 Epoch 78/100 176/176 [==============================] - 22s 125ms/step - loss: 0.8530 - accuracy: 0.7442 - top-5-accuracy: 0.9526 - val_loss: 1.8496 - val_accuracy: 0.5384 - val_top-5-accuracy: 0.8124 Epoch 79/100 176/176 [==============================] - 22s 125ms/step - loss: 0.8403 - accuracy: 0.7485 - top-5-accuracy: 0.9542 - val_loss: 1.8701 - val_accuracy: 0.5550 - val_top-5-accuracy: 0.8228 Epoch 80/100 176/176 [==============================] - 22s 126ms/step - loss: 0.8410 - accuracy: 0.7491 - top-5-accuracy: 0.9538 - val_loss: 1.8737 - val_accuracy: 0.5502 - val_top-5-accuracy: 0.8150 Epoch 81/100 176/176 [==============================] - 22s 126ms/step - loss: 0.8275 - accuracy: 0.7547 - top-5-accuracy: 0.9532 - val_loss: 1.8391 - val_accuracy: 0.5534 - val_top-5-accuracy: 0.8156 Epoch 82/100 176/176 [==============================] - 22s 125ms/step - loss: 0.8221 - accuracy: 0.7528 - top-5-accuracy: 0.9562 - val_loss: 1.8775 - val_accuracy: 0.5428 - val_top-5-accuracy: 0.8120 Epoch 83/100 176/176 [==============================] - 22s 125ms/step - loss: 0.8270 - accuracy: 0.7526 - top-5-accuracy: 0.9550 - val_loss: 1.8464 - val_accuracy: 0.5468 - val_top-5-accuracy: 0.8148 Epoch 84/100 176/176 [==============================] - 22s 126ms/step - loss: 0.8080 - accuracy: 0.7551 - top-5-accuracy: 0.9576 - val_loss: 1.8789 - val_accuracy: 0.5486 - val_top-5-accuracy: 0.8204 Epoch 85/100 176/176 [==============================] - 22s 125ms/step - loss: 0.8058 - accuracy: 0.7593 - top-5-accuracy: 0.9573 - val_loss: 1.8691 - val_accuracy: 0.5446 - val_top-5-accuracy: 0.8156 Epoch 86/100 176/176 [==============================] - 22s 126ms/step - loss: 0.8092 - accuracy: 0.7564 - top-5-accuracy: 0.9560 - val_loss: 1.8588 - val_accuracy: 0.5524 - val_top-5-accuracy: 0.8172 Epoch 87/100 176/176 [==============================] - 22s 125ms/step - loss: 0.7897 - accuracy: 0.7613 - top-5-accuracy: 0.9604 - val_loss: 1.8649 - val_accuracy: 0.5490 - val_top-5-accuracy: 0.8166 Epoch 88/100 176/176 [==============================] - 22s 126ms/step - loss: 0.7890 - accuracy: 0.7635 - top-5-accuracy: 0.9598 - val_loss: 1.9060 - val_accuracy: 0.5446 - val_top-5-accuracy: 0.8112 Epoch 89/100 176/176 [==============================] - 22s 126ms/step - loss: 0.7682 - accuracy: 0.7687 - top-5-accuracy: 0.9620 - val_loss: 1.8645 - val_accuracy: 0.5474 - val_top-5-accuracy: 0.8150 Epoch 90/100 176/176 [==============================] - 22s 125ms/step - loss: 0.7958 - accuracy: 0.7617 - top-5-accuracy: 0.9600 - val_loss: 1.8549 - val_accuracy: 0.5496 - val_top-5-accuracy: 0.8140 Epoch 91/100 176/176 [==============================] - 22s 125ms/step - loss: 0.7978 - accuracy: 0.7603 - top-5-accuracy: 0.9590 - val_loss: 1.9169 - val_accuracy: 0.5440 - val_top-5-accuracy: 0.8140 Epoch 92/100 176/176 [==============================] - 22s 125ms/step - loss: 0.7898 - accuracy: 0.7630 - top-5-accuracy: 0.9594 - val_loss: 1.9015 - val_accuracy: 0.5540 - val_top-5-accuracy: 0.8174 Epoch 93/100 176/176 [==============================] - 22s 125ms/step - loss: 0.7550 - accuracy: 0.7722 - top-5-accuracy: 0.9622 - val_loss: 1.9219 - val_accuracy: 0.5410 - val_top-5-accuracy: 0.8098 Epoch 94/100 176/176 [==============================] - 22s 125ms/step - loss: 0.7692 - accuracy: 0.7689 - top-5-accuracy: 0.9599 - val_loss: 1.8928 - val_accuracy: 0.5506 - val_top-5-accuracy: 0.8184 Epoch 95/100 176/176 [==============================] - 22s 126ms/step - loss: 0.7783 - accuracy: 0.7661 - top-5-accuracy: 0.9597 - val_loss: 1.8646 - val_accuracy: 0.5490 - val_top-5-accuracy: 0.8166 Epoch 96/100 176/176 [==============================] - 22s 125ms/step - loss: 0.7547 - accuracy: 0.7711 - top-5-accuracy: 0.9638 - val_loss: 1.9347 - val_accuracy: 0.5484 - val_top-5-accuracy: 0.8150 Epoch 97/100 176/176 [==============================] - 22s 125ms/step - loss: 0.7603 - accuracy: 0.7692 - top-5-accuracy: 0.9616 - val_loss: 1.8966 - val_accuracy: 0.5522 - val_top-5-accuracy: 0.8144 Epoch 98/100 176/176 [==============================] - 22s 125ms/step - loss: 0.7595 - accuracy: 0.7730 - top-5-accuracy: 0.9610 - val_loss: 1.8728 - val_accuracy: 0.5470 - val_top-5-accuracy: 0.8170 Epoch 99/100 176/176 [==============================] - 22s 125ms/step - loss: 0.7542 - accuracy: 0.7736 - top-5-accuracy: 0.9622 - val_loss: 1.9132 - val_accuracy: 0.5504 - val_top-5-accuracy: 0.8156 Epoch 100/100 176/176 [==============================] - 22s 125ms/step - loss: 0.7410 - accuracy: 0.7787 - top-5-accuracy: 0.9635 - val_loss: 1.9233 - val_accuracy: 0.5428 - val_top-5-accuracy: 0.8120 313/313 [==============================] - 4s 12ms/step - loss: 1.8487 - accuracy: 0.5514 - top-5-accuracy: 0.8186 Test accuracy: 55.14% Test top 5 accuracy: 81.86%
100エポックの学習が全て完了すると、モデルのTop-1正解率は55.14%、Top-5は81.86%という結果になります。あれ? 決して良いとはいえない精度です。
ViTの論文をみると、ImageNetだけでなく、JFT-300M のような別の大きなデータセットも用いて学習していました。ViTではこのように大きなデータセットでの学習が前提で、サンプルで試したような小さな(?)データセットで、学習済みモデルなしで初期値から学習しただけのような場合には、そこまで真価を発揮しないようです。とってもData Hungry!なのです。
今回は、Self-Attentionを使ったTransformerを画像認識タスクに適用したVision Transformer(ViT)について調べました。ViTでは大量データによる学習を必要としており、サンプルで試したタスクでは良い精度は得られませんでした。
これまでCNNが当たり前だった画像認識タスクですが、全く新しい構造を用いてすばらしい成果を出しているViTは非常に期待の大きいモデルです。しかしAIをもっと社会に適用していきたい現在においては、むしろ少ないデータで良いモデルを得たいことが多いです。
そういった意味では、これまでの転移学習やファインチューニングであったり、その他のFew Shot Learng関連のテクニックのほうが、現状においては活用しやすいかもしれません。
より小さな計算量で、良い精度が得られるViTの改良が期待されます。
[1] An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
https://arxiv.org/abs/2010.11929
[2] Image classification with Vision Transformer
https://keras.io/examples/vision/image_classification_with_vision_transformer/