[Tensorflow] ResNet Implementation

A TensorFlow Residual Network implementation for training on CIFAR100.

Reference: https://github.com/weiaicunzai/pytorch-cifar100

Common

Data pipeline

import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator


def cifar100_dataloader():
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar100.load_data()
    x_train, x_test = x_train / 255., x_test / 255.

    train_kwargs = {
        'featurewise_center': True,
        'featurewise_std_normalization': True,
        'rotation_range': 15,
        'horizontal_flip': True,
        # the next two lines approximate torchvision's 'transforms.RandomCrop(32, padding=4)'
        'width_shift_range': 5,
        'height_shift_range': 5
    }

    test_kwargs = {
        'featurewise_center': True,
        'featurewise_std_normalization': True
    }
    
    train_gen = ImageDataGenerator(**train_kwargs)
    test_gen = ImageDataGenerator(**test_kwargs)

    # channel-wise CIFAR100 mean/std, hard-coded instead of calling train_gen.fit(x_train)
    train_gen.mean = np.array([[[0.5070751592371323, 0.48654887331495095, 0.4409178433670343]]])
    train_gen.std = np.array([[[0.2673342858792401, 0.2564384629170883, 0.27615047132568404]]])

    # the test generator must use the same training-set statistics (test_gen.fit(x_train))
    test_gen.mean = np.array([[[0.5070751592371323, 0.48654887331495095, 0.4409178433670343]]])
    test_gen.std = np.array([[[0.2673342858792401, 0.2564384629170883, 0.27615047132568404]]])

    train_dl = train_gen.flow(x_train, y_train, batch_size=128, shuffle=True)
    test_dl = test_gen.flow(x_test, y_test, batch_size=128, shuffle=False)  # no need to shuffle evaluation data

    return train_dl, test_dl
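
A quick smoke test of the pipeline; flow returns an iterator of (images, labels) batches:

train_dl, test_dl = cifar100_dataloader()
x_batch, y_batch = next(train_dl)
print(x_batch.shape, y_batch.shape)  # (128, 32, 32, 3) (128, 1)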

Learning rate scheduling

def lr_schedule(optimizer, epoch, batch, n_batches):
    # step decay 0.1 -> 0.02 -> 0.004 -> 0.0008 at epochs 60 / 120 / 160,
    # matching MultiStepLR(milestones=[60, 120, 160], gamma=0.2) in the
    # reference repo; the branches must be checked in descending order,
    # otherwise 'epoch > 60' swallows the later milestones
    if epoch == 0:  # linear learning-rate warmup over the first epoch
        optimizer.learning_rate = 0.1 * (batch / n_batches)
    elif epoch > 160:
        optimizer.learning_rate = 0.0008
    elif epoch > 120:
        optimizer.learning_rate = 0.004
    elif epoch > 60:
        optimizer.learning_rate = 0.02
    else:
        optimizer.learning_rate = 0.1
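
A quick sanity check of the decay points (with 50000 / 128 ≈ 391 batches per epoch):

opt = tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.9)
for e in [0, 30, 61, 121, 161]:
    lr_schedule(opt, epoch=e, batch=195, n_batches=391)
    print(e, float(opt.learning_rate))
# expected: ~0.0499, 0.1, 0.02, 0.004, 0.0008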

Weight decay

Add tf.keras.regularizers.L2(l2=5e-4) to the weights of every layer (the weight_decay constant below is reused in the model code):

weight_decay = 5e-4
layers.Conv2D(out_channels, kernel_size=3, strides=strides, padding='same', use_bias=False,
              kernel_regularizer=regularizers.L2(weight_decay))
layers.BatchNormalization(axis=-1, momentum=0.9, epsilon=1e-5,
                          gamma_regularizer=regularizers.L2(weight_decay),
                          beta_regularizer=regularizers.L2(weight_decay))
layers.Dense(n_class,
             kernel_regularizer=regularizers.L2(weight_decay),
             bias_regularizer=regularizers.L2(weight_decay))

Loss / Optimizer

# the Dense classifier returns raw logits, hence from_logits=True
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.9)
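
In a custom training loop, the L2 penalties registered above accumulate in model.losses and have to be added to the data loss by hand; Keras only does this automatically inside model.fit. A minimal train-step sketch (model here is the resnet34 instance built in the ResNet34 section below):

@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        # data loss + sum of all L2 regularization terms
        loss = loss_object(y, logits) + tf.add_n(model.losses)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss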

ResNet34

Basic Block

import tensorflow as tf
import tensorflow.keras as keras
import tensorflow.keras.layers as layers
import tensorflow.keras.regularizers as regularizers

weight_decay = 5e-4  # shared L2 coefficient, used by every layer below

class BasicBlock(keras.Model):
    
    expansion = 1
    
    def __init__(self, in_channels, out_channels, strides=1):
        super(BasicBlock, self).__init__()
        
        self.residual_function = keras.Sequential([
            layers.Conv2D(out_channels,
                          kernel_size=3,
                          strides=strides,
                          padding='same',
                          use_bias=False,
                          kernel_regularizer=regularizers.L2(weight_decay)),
            layers.BatchNormalization(axis=-1,
                                      momentum=0.9,
                                      epsilon=1e-5,
                                      gamma_regularizer=regularizers.L2(weight_decay), 
                                      beta_regularizer=regularizers.L2(weight_decay)),
            layers.ReLU(),
            layers.Conv2D(out_channels * BasicBlock.expansion,
                          kernel_size=3,
                          strides=1,
                          padding='same',
                          use_bias=False,
                          kernel_regularizer=regularizers.L2(weight_decay)),
            layers.BatchNormalization(axis=-1,
                                      momentum=0.9,
                                      epsilon=1e-5,
                                      gamma_regularizer=regularizers.L2(weight_decay),
                                      beta_regularizer=regularizers.L2(weight_decay)),
        ])
        
        # identity shortcut by default
        self.shortcut = keras.Sequential([])
        
        # when the stride or channel count changes, project the input with a
        # 1x1 convolution + BN so it matches the residual branch's shape
        if strides != 1 or in_channels != out_channels * BasicBlock.expansion:
            self.shortcut = keras.Sequential([
                layers.Conv2D(out_channels * BasicBlock.expansion,
                              kernel_size=1,
                              strides=strides,
                              padding='same',
                              use_bias=False,
                              kernel_regularizer=regularizers.L2(weight_decay)),
                layers.BatchNormalization(axis=-1,
                                          momentum=0.9,
                                          epsilon=1e-5,
                                          gamma_regularizer=regularizers.L2(weight_decay),
                                          beta_regularizer=regularizers.L2(weight_decay)),
            ])
    
    def call(self, x):
        # add the residual and shortcut branches, then apply the final ReLU
        return tf.nn.relu(self.residual_function(x) + self.shortcut(x))
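
A quick shape check for the downsampling case (stride 2 halves the spatial size while doubling the channels):

block = BasicBlock(in_channels=64, out_channels=128, strides=2)
out = block(tf.random.normal((1, 32, 32, 64)))
print(out.shape)  # (1, 16, 16, 128)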

ResNet34

def resnet34(n_class=100):
    # 16 BasicBlocks (3 + 4 + 6 + 3), each with 2 convs, plus conv1 and the
    # final Dense layer: 16 * 2 + 1 + 1 = 34 weighted layers
    inputs = keras.Input(shape=(32, 32, 3), name='img')

    conv1 = keras.Sequential([
        layers.Conv2D(64,
                      kernel_size=3,
                      strides=1,
                      padding='same',
                      use_bias=False,
                      kernel_regularizer=regularizers.L2(weight_decay)),
        layers.BatchNormalization(axis=-1,
                                  momentum=0.9,
                                  epsilon=1e-5,
                                  gamma_regularizer=regularizers.L2(weight_decay),
                                  beta_regularizer=regularizers.L2(weight_decay)),
        layers.ReLU()
    ], name='conv1')
    
    conv2_x = keras.Sequential([
        BasicBlock(64, 64, 1),
        BasicBlock(64, 64, 1),
        BasicBlock(64, 64, 1)
    ], name='conv2_x')
    
    conv3_x = keras.Sequential([
        BasicBlock(64, 128, 2),
        BasicBlock(128, 128, 1),
        BasicBlock(128, 128, 1),
        BasicBlock(128, 128, 1)
    ], name='conv3_x')
    
    conv4_x = keras.Sequential([
        BasicBlock(128, 256, 2),
        BasicBlock(256, 256, 1),
        BasicBlock(256, 256, 1),
        BasicBlock(256, 256, 1),
        BasicBlock(256, 256, 1),
        BasicBlock(256, 256, 1)
    ], name='conv4_x')
    
    conv5_x = keras.Sequential([
        BasicBlock(256, 512, 2),
        BasicBlock(512, 512, 1),
        BasicBlock(512, 512, 1)
    ], name='conv5_x')
    
    avg_pool = layers.GlobalAvgPool2D(name='avg_pool')
    classifier = layers.Dense(n_class,
                              name='classifier',
                              kernel_regularizer=regularizers.L2(weight_decay),
                              bias_regularizer=regularizers.L2(weight_decay))
    
    outputs = conv1(inputs)
    outputs = conv2_x(outputs)
    outputs = conv3_x(outputs)
    outputs = conv4_x(outputs)
    outputs = conv5_x(outputs)
    outputs = avg_pool(outputs)
    outputs = classifier(outputs)

    model = keras.Model(inputs, outputs, name="resnet34")

    return model
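
Finally, a sketch of wiring the pieces together (200 epochs, following the reference repository); note that ImageDataGenerator.flow yields batches indefinitely, so each epoch has to be cut off manually:

model = resnet34(n_class=100)
train_dl, test_dl = cifar100_dataloader()
n_batches = len(train_dl)  # batches per epoch

for epoch in range(200):
    for batch, (x, y) in enumerate(train_dl):
        if batch >= n_batches:  # flow() loops forever, so end the epoch by hand
            break
        lr_schedule(optimizer, epoch, batch, n_batches)
        train_step(x, y)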