ネコ科の猛獣識別にfine tuningを適用したら結構うまくいきました

はじめに

ここ数日、画像識別に行き詰っていたのですがようやくネコ科の猛獣判別でいい精度が出せましたので、内容をまとめておきます.

画像識別で迷走した記録 - 他力本願で生き抜く(本気)

いろいろと迷走してきたのですが、今回はタイトルにあるように既存の学習済モデルの一部を利用するfine tuningというテクニックを使って、一気に識別精度を上げたいと思います.


手順概要

  • VGG16のモデルと学習済の重みをダウンロードする
  • モデルの一部のみ学習する(重みを修正する)ようにモデルの変更を行う(fine tuning)
  • fit_generator()を利用して、データの水増しと学習を同時に行う


学習用(画像)データについて

  • 今回扱う問題は4クラス問題(ネコ、ライオン、トラ、リンクス)
  • スクリプトを実行するディレクトリの下にimagesディレクトリを作成.さらに、その直下に訓練用、検証用、テスト画像用フォルダを作成.
  • データの中身を確認し、学習に不適なデータが入らないよう精査しておく(ここ大事です)
  • 今回、訓練用データは各150枚、検証用データは各50枚、テスト用データは各40枚準備しました

f:id:shirakonotempura:20190113005403p:plain
分類する4クラス(左からライオン、トラ、リンクス、ネコ)

f:id:shirakonotempura:20190113013137p:plain
フォルダ構成

訓練データ、検証データ、テストデータの分け方にはいろいろな方法、考え方があると思うのですが、訓練データと検証データで過学習していないことを確認した後、最後に全く学習・検証に関わっていないテストデータで精度のチェックを行えるよう、完全に分けています.


転移学習(Transfer Lerning)とfine tuningについて

転移学習とFine Tuningを完全に混同していましたが、厳密には2つは微妙に違うようです.既存の学習済モデル(出力層以外の部分)を、重みデータは変更せずに特徴量抽出機として利用するとことを転移学習、学習済モデルの重みの一部を再学習して特徴量抽出機として利用することをFine Tuningと呼ぶようです.
私が今回やることは、重みの一部を再学習させているので、おそらく後者のFine Tuningに該当します.

参考:
What is the difference between transfer learning and fine tuning? - Quora


違いがあることは分かったのですが、既存の学習済みモデルを使うというのが転移学習の概念なのであれば、一部を再学習さえてしまおうというFine Tuningは転移学習の派生版ですよね.くくりとしては、転移学習の方が大きいような気もします(完全に私見です.)


転移学習・fine tuningのメリット

これまで何度か試してきたCNNの実装においては、一から学習を行う必要があるため、大量の学習データが必要となります.また、その分学習時間も多くかかります.そこで、既存の学習済モデルを利用することで、少ない画像で学習効率を上げるのが転移学習およびfine tuningになります. 今回は、ImageNETの約120万枚の画像を1000クラスに分類したVGG16というモデルを利用します.VGG16のネットワーク構造を以下に示します.

f:id:shirakonotempura:20190113020724p:plain

今回は、上図において、fine-tuningと書かれた15層目以降のみ学習を行い、14層目までの学習済みの重みは学習済の重みをそのまま使用します.

なぜそんなことが可能、疑問だったのですが、CNNによる学習においては浅い層ではざっくりとした特徴の抽出が行われ、深い層では画像特融の特徴を抽出するということが分かっています.車における特徴抽出のイメージ図を以下に示します.

f:id:shirakonotempura:20190113015136p:plain

(申し訳ありません.よくブログ記事で拝見する画像なのですが、出典を理解できておりません)


実装

実装にあたっては、以下の記事を参考にしています.
絶対VGG16のクラスにないであろうアニメキャラの画像識別を行っています.あらかじめ訓練・検証・テスト用に学習データを分離して学習および最終評価をしており、その検討手順が理解しやすかったです.

qiita.com


VGG16モデルの構築およびコンパイル

主な手順は以下です.

  • VGG16モデルと学習済みの重みをロード
  • 今回の分類用の全結合層を構築
  • VGG16とFC層を接続
  • 学習させない層を定義
  • モデルのコンパイル
# 必要なライブラリのインポート
from keras.models import Model
from keras.layers import Dense, GlobalAveragePooling2D,Input
from keras.applications.vgg16 import VGG16
from keras.preprocessing.image import ImageDataGenerator
from keras.optimizers import SGD
from keras.callbacks import CSVLogger

# 今回はネコ、ライオン、トラ、リンクスの4カテゴリに分類する
n_categories=4
batch_size=32
img_size = 100

# 画像データを保存しているディレクトリの定義
train_dir="images/train"
validation_dir="images/valid"
test_dir="images/test"

# display_dirの中身はテストデータと同じもの.ただし、クラスごとのフォルダ分けはしていない
display_dir='images/display'

file_name='cats_fintuning_vgg16'

# VGG16の既存の全結合層は、1000クラス分類用なので使えない→Falseで削除
vgg16_model=VGG16(weights='imagenet',include_top=False,
                 input_tensor=Input(shape=(img_size,img_size,3)))

# 全結合層(FC層)を構築
x=vgg16_model.output
x=GlobalAveragePooling2D()(x)
x=Dense(1024,activation='relu')(x)
prediction=Dense(n_categories,activation='softmax')(x)

# vgg16と全結合層をつなぐ
model=Model(inputs=vgg16_model.input, outputs=prediction)

# 最後のCONV層の直前までの層を更新しない(freeze)
for layer in vgg16_model.layers[:15]:
    layer.trainable=False

# fine-tuningにおいては、optimizerはSGDが多い.
model.compile(optimizer=SGD(lr=0.0001,momentum=0.9),
              loss='categorical_crossentropy',
              metrics=['accuracy'])

model.summary()

# 結果の描画用(勉強不足のため理解が追い付かず)
#save model
json_string=model.to_json()
open(file_name+'.json','w').write(json_string)


必要パラメータの定義

  • epochs はエポックの数
  • num_trainingおよびnum_validation:generatorで繰り返し作成する画像の枚数. -1回のエポックでnum_training枚およびnum_validation枚の画像を作成して学習あるいは検証を行う.
  • num_of_test:テストフォルダに入っている画像の総数(今回は40枚×4クラスで160枚)
  • label:テストデータで推定した結果を表示する際に使用します.(alphabetにしておかないと、ダメっぽい?(要精査))
epochs = 50
num_training = 1600 
num_validation = 400
num_of_test = 160
label = ["cat","lion", "lynx", "tiger"]


fit_genratorを使って学習

ImageDataGeneratorで学習データの水増しを行います. 今回は、これまでのようにデータを一度作成して保存しておくわけではなくfit_generator()により、学習データの水増しと、それを使った学習を同時に行っています.flow_from_directoryで指定したフォルダ名がそのままラベルに使われるので非常に楽です.

train_datagen=ImageDataGenerator(
    rescale=1.0/255,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True)

validation_datagen=ImageDataGenerator(rescale=1.0/255)

train_generator=train_datagen.flow_from_directory(
    train_dir,
    target_size=(img_size,img_size),
    color_mode = "rgb",
    batch_size=batch_size,
    class_mode='categorical',
    shuffle=True
)

validation_generator=validation_datagen.flow_from_directory(
    validation_dir,
    target_size=(img_size,img_size),
    color_mode = "rgb",
    batch_size=batch_size,
    class_mode='categorical',
    shuffle=True
)

hist=model.fit_generator(train_generator,
                         steps_per_epoch = num_training//batch_size,
                         epochs=epochs,
                         verbose=1,
                         validation_data=validation_generator,
                         validation_steps = num_validation//batch_size,
                         callbacks=[CSVLogger(file_name+'.csv')])

#save weights
model.save(file_name+'.h5')


学習は順調に進み、最終的には検証データの精度90.6%となりました.工夫をすればもう少し良くなるかもしれません.

(一部)省略
Epoch 50/50
50/50 [===] - 7s 134ms/step - acc: 0.9975  - val_acc: 0.9056


テストデータによる結果

最後に、訓練にも検証にも使用していない、テスト用の画像データを使ってモデルの評価を行います. まず、精度の値.

Found 160 images belonging to 4 classes.
 test loss: 0.2039002388715744
 test_acc: 0.91875


テストデータによる精度は、検証精度とほぼ同等の91.8%!今までテストデータの精度だけ悪かったので、これは素直にうれしいです.fine tuning おそるべし・・.

次にdisplayディレクトリからランダムに選んだ画像に対してラベルを予測した結果を示します.
ちゃんと予想できており、一安心です.
完全にGod_KonaBananaさんのやり方を利用させてもらっています.

f:id:shirakonotempura:20190113022435p:plain

補足

補足1:
実は最初img_sizeを50x50で実行したのですが、その際の精度は80%程度でした.次に100x100にして今回の結果となりましたが元の画像サイズが75x75なので、75より大きくしても余り意味はないのかもしれません.

補足2:
水増しの影響も見るため、用意した600枚の訓練データ、200枚の検証データのみを使った検討も行いました.テストデータに対して90%近い精度を出してしまいました.水増しの効果も結構あると思っていたので、あまり素直に喜べない・・.水増しなしのアプローチは最初に行うべきなのかと思います.あとはfine tuningが効果的すぎるということなんでしょう.

補足3:
これも検討の前に調べてとけっていう話なんですが、今回用意した4クラスはいずれもVGG16の1000クラスの中に名前が挙がっていました.lynxはないだろうと思っていたのですが見事に存在しておりました.ですので、少し悲しいのですが、かなり有利な問題設定になってしまっていると思います.


まとめ

fine tuningを使って、ネコ科の猛獣4クラス分類問題を再度行いました.
結果としては、まだ改善の余地ありという結果ですがこれまでよりはかなり高精度で分類することができました.ただし、用意した4クラスがVGG16の中に含まれているため、高精度は当然の結果とも言えます.今後、全く関係ないクラスでの識別などもやっていきたいと思います.

Google Colaboratory上のノートはコチラ


今回も多くの記事を参考にさせていただきました.ありがとうございました

実装の手順が、細かく書かれています

fit_generatorのstep数などの理解に役立ちました.

非常に丁寧に説明がされています.顔認識に挑戦したくなりました.

もちろん公式ドキュメントも見ましょう.