fine tuningで芸術作品の識別をやってみた
はじめに
こんにちは、前々回に引き続きfine tuningを使ってまた別の識別問題に挑戦したいと思います.なお、前回は2クラスの識別問題として俳優の福士蒼汰さんと中川大志さんを識別するAIを作ってみました.
shirakonotempura.hatenablog.com
今回やること:芸術作品(主に絵画)の識別問題
今回は、タイトルにあるとおり絵画の識別をやってみたいと思います. つまり、作品を見てその作者(画家)を推定するAIを作ってみたいと思います.
絵画等芸術作品のブログ等への転載についてはついて詳しく調べる時間がなかったので、前回以上にモザイク多めです.ご了承ください.
集めたデータについて
今回対象とした作家は、セザンヌ、ゴーギャン、ゴッホ、ピカソ、ルノワールの5名としました.作品画像は、作家1人につき約50枚程度収集し、以下のようにデータを分けました. テストデータに各クラス5枚を確保し、残りを訓練と検証用に8:2で分けています.
class | train | valid | test |
---|---|---|---|
cezanne | 35 | 11 | 5 |
gaugain | 35 | 11 | 5 |
gogh | 34 | 11 | 5 |
picasso | 34 | 11 | 5 |
renoir | 33 | 11 | 5 |
かなりデータの数が少ないのですが、fine tuningの力を信じでやってみます.
実装
実装部分は、ほとんどこれまでと変わらないのでさくっと行きます.
- modelの最終層の活性化関数は
activation = "softmax"
- モデルのコンパイルで定義するloss関数は
loss = "categorical_crossentropy"
- ジェネレーターで定義するclass_modeは
class_mode = "categorical"
モデルの定義およびコンパイル
#---------Prameter for Modeling------------------------- # 今回は6カテゴリに分類する n_categories=5 batch_size=32 img_size = 120 # 画像データを保存しているディレクトリの定義(trainとvalidは同じデータ) train_dir="images/train" validation_dir="images/valid" test_dir="images/test" #------------------------------------------------------- # display_dirの中身はテストデータと同じもの.ただし、クラスごとのフォルダ分けはしていない display_dir='images/display' file_name='art_finetuning_vgg16' # VGG16の既存の全結合層は、1000クラス分類用なので使えない→Falseで削除 vgg16_model=VGG16(weights='imagenet',include_top=False, input_tensor=Input(shape=(img_size,img_size,3))) # 全結合層(FC層)を構築 x=vgg16_model.output x=GlobalAveragePooling2D()(x) x=Dense(1024,activation='relu')(x) prediction=Dense(n_categories, activation='softmax')(x) # vgg16と全結合層をつなぐ model=Model(inputs=vgg16_model.input, outputs=prediction) # 最後のCONV層の直前までの層を更新しない(freeze) for layer in vgg16_model.layers[:15]: layer.trainable=False # fine-tuningにおいては、optimizerはSGDが多い. model.compile(optimizer=SGD(lr=0.0001,momentum=0.9), loss='categorical_crossentropy', # 2クラスの場合はbinary_crossentropy metrics=['accuracy']) model.summary() #save model json_string=model.to_json() open(file_name+'.json','w').write(json_string)
fit_generatorでデータの水増しおよび学習
def gen_and_fit(model, file_name): train_datagen=ImageDataGenerator( rescale=1.0/255, rotation_range = 30, shear_range=0.3, zoom_range=0.3, width_shift_range = 0.2, height_shift_range = 0.2, horizontal_flip=True) validation_datagen=ImageDataGenerator( rescale=1.0/255 ) train_generator=train_datagen.flow_from_directory( train_dir, target_size=(img_size,img_size), color_mode = "rgb", batch_size=batch_size, class_mode='categorical', # categoricalからbinaryに変更 shuffle=True ) # 検証用データは一度に検証用データの枚数だけ作成(正規化するだけ) validation_generator=validation_datagen.flow_from_directory( validation_dir, target_size=(img_size,img_size), color_mode = "rgb", batch_size= num_validation, class_mode='categorical', shuffle=True ) hist=model.fit_generator(train_generator, steps_per_epoch = num_training//batch_size, epochs=epochs, verbose=1, validation_data=validation_generator, validation_steps = 1, #num_validation//batch_size, callbacks=[CSVLogger(file_name+'.csv')]) #save weights model.save(file_name+'.h5') return hist
# 実行して描画 log2graph(gen_and_fit(model, file_name), "art")
学習履歴のグラフを以下に示します.10エポック目くらいから検証データの精度が上がっていませんね.本当はここでモデルの見直しをする必要があるのですが、このままテストデータを使った評価に移ります.最終エポックの検証データに対する精度と同じく80%いかないくらいの精度になるのでしょう.
テストデータを使って評価
テストディレクトリに保存した計25枚のテストデータを使って、学習済モデルの評価を行います.
# テストデータによる評価 def testeval(model, file_name): #load model and weights json_string=open(file_name+'.json').read() model=model_from_json(json_string) model.load_weights(file_name+'.h5') model.compile(optimizer=SGD(lr=0.0001,momentum=0.9), loss='categorical_crossentropy', metrics=['accuracy']) #data generate test_datagen=ImageDataGenerator(rescale=1.0/255) test_generator=test_datagen.flow_from_directory( # validation_dir, test_dir, target_size=(img_size,img_size), batch_size= num_of_test, # batch_size = num_validation, class_mode="categorical", shuffle=True ) #evaluate model score=model.evaluate_generator(test_generator, steps = 1) print('\n test loss:',score[0]) print('\n test_acc:',score[1]) return score
出力結果
Found 25 images belonging to 5 classes.
test loss: 1.1050524711608887
test_acc: 0.7200000286102295
テストデータに対する精度は72%となりました.テストデータは25枚でしたので、7枚間違ったことになります.
テストデータの予測結果例
テストデータを25枚しか用意していないので、すべて表示して確認します.
# 以下は図で表示するためのスクリプト nb_of_disp = 25 nb_of_row = 5 #np.sqrt(nb_of_disp) nb_of_col = 5 #np.sqrt(nb_of_disp) files=os.listdir(display_dir) img=random.sample(files,nb_of_disp) plt.figure(figsize=(10,10)) for i in range(nb_of_disp): temp_img=load_img(os.path.join(display_dir,img[i]),target_size=(img_size,img_size)) plt.subplot(nb_of_row, nb_of_col,i+1) plt.imshow(temp_img) #Images normalization temp_img_array=img_to_array(temp_img) temp_img_array=temp_img_array.astype('float32')/255.0 temp_img_array=temp_img_array.reshape((1,img_size,img_size,3)) #predict image img_pred=model.predict(temp_img_array) plt.title(label[np.argmax(img_pred)]) plt.xticks([]),plt.yticks([]) plt.show()
もはやモザイクで良く分からないですね.今後、正解しているのかどうかが分かるような表示方法に修正していくようにします.
まとめ
今回は、芸術作品を見て作家を推定するAIを作成しました.ちょっと絵画を知っている人くらいの精度は出たんでしょうか.まあ、AIとしては全然ダメな精度ですね.
ただ、ここまでデータが少ない場合、どうすればいいんでしょうか.水増しするにしても限界があるというか.詳しい人の意見が欲しい・・.
Colaboratory上のノートはコチラ
今回参考にさせていただいたすばらしい記事たちです.ありがとうございます.
VGG16を転移学習させて「まどか☆マギカ」のキャラを見分ける - Qiita