オムライスの備忘録

数学・統計学・機械学習・プログラミングに関することを記す

【深層学習】U-Net #実装編 #01

この記事の読者

深層学習・ディープラーニングの手法の1つである「U-Net」について知りたい.


キーワード・知ってると理解がしやすい

  • U-Net
  • pix2pix

yhayato1320.hatenablog.com

yhayato1320.hatenablog.com


Index

Task

画像の Semantic Segmentation をタスクにします.

Datasets

Oxford- IIITPetDataset と呼ばれる動物の画像のデータセットです.



python api tensorflow_datasets を使ってデータを取得します.

import tensorflow_datasets as tfds
dataset, info = tfds.load('oxford_iiit_pet:3.*.*', with_info=True)

Model

U-Net は Down Sampler (Encoder) の部分と Up Sampler (Decoder) の 2つで構成されてます. そのうち、Down Sampler (Encoder) の部分のパラメータは学習済みのパラメータを利用し、それを固定することで、学習するパラメータを減らすことができます.

base_model = tf.keras.applications.MobileNetV2(input_shape=[128, 128, 3], include_top=False)

Up Sampler (Decoder) の部分は、pix2pix の Up Sampling 部分を使います.

from tensorflow_examples.models.pix2pix import pix2pix

up_stack = [
    pix2pix.upsample(512, 3),  # 4x4 -> 8x8
    pix2pix.upsample(256, 3),  # 8x8 -> 16x16
    pix2pix.upsample(128, 3),  # 16x16 -> 32x32
    pix2pix.upsample(64, 3),   # 32x32 -> 64x64
]

モデル全体を連結します.

def unet_model(output_channels):
    inputs = tf.keras.layers.Input(shape=[128, 128, 3])

    # Downsampling through the model
    skips = down_stack(inputs)
    x = skips[-1]
    skips = reversed(skips[:-1])

    # Upsampling and establishing the skip connections
    for up, skip in zip(up_stack, skips):
      x = up(x)
      concat = tf.keras.layers.Concatenate()
      x = concat([x, skip])

    # This is the last layer of the model
    last = tf.keras.layers.Conv2DTranspose(
        output_channels, 3, strides=2,
        padding='same')  #64x64 -> 128x128

    x = last(x)

    return tf.keras.Model(inputs=inputs, outputs=x)

Train

Optimizer, loss, metrics を決定して、コンパイルします.

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

Predict

predict function で結果を出力します.

model.predict()

参考

  • U-Net: Convolutional Networks for Biomedical Image Segmentation
    • 発表論文

arxiv.org

Web サイト

www.tensorflow.org