Training small convolutional neural networks with MobileNetV2 in the encoder

Lightweight networks are useful for reducing computational costs and speeding up training and inference. The script below can be used to train a model once the records have been exported from Google Earth Engine.

More information can be found in this paper.

Use the code below, changing the paths to the ones on your desktop or server.
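Before training, it can help to check that the exported TFRecord files are where the script expects them: three sub-folders (training, testing, validation) under the raw data path. A minimal sanity check, using the same placeholder path as in the script below:

import glob

raw_data_path = "/path/to/"  # same placeholder as in the training script below
for split in ("training", "testing", "validation"):
    files = glob.glob(raw_data_path + "/" + split + "/*")
    print(split, ":", len(files), "TFRecord file(s) found")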

import glob
from functools import partial

import tensorflow as tf
from tensorflow.keras import Model, callbacks, metrics
from tensorflow.keras import backend as K
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.layers import (Activation, BatchNormalization, Concatenate,
                                     Conv2D, Input, UpSampling2D)



# Read a serialized example and stack its bands and labels into a single (H, W, C) tensor
def parse_tfrecord(example_proto, features=None, labels=None, patch_shape=None):
    keys = features + labels
    columns = [
        tf.io.FixedLenFeature(shape=patch_shape, dtype=tf.float32) for k in keys
    ]
    proto_struct = dict(zip(keys, columns))
    inputs = tf.io.parse_single_example(example_proto, proto_struct)
    inputs_list = [inputs.get(key) for key in keys]
    stacked = tf.stack(inputs_list, axis=0)
    stacked = tf.transpose(stacked, [1, 2, 0])
    return tf.data.Dataset.from_tensors(stacked)

# Split a stacked patch into a (features, labels) tuple
def to_tuple(dataset, n_features=None):
    features = dataset[:, :, :, :n_features]
    labels = dataset[:, :, :, n_features:]
    return features, labels

# read the TFRecord files into a batched tf.data.Dataset
def get_dataset(files, features, labels, patch_shape, batch_size,
                buffer_size=1000, training=False, **kwargs):
 
    parser = partial(parse_tfrecord,
                     features=features,
                     labels=labels,
                     patch_shape=patch_shape
                     )

    split_data = partial(to_tuple, n_features=len(features))

    dataset = tf.data.TFRecordDataset(files, compression_type='GZIP')
    dataset = dataset.interleave(parser, num_parallel_calls=tf.data.experimental.AUTOTUNE)

    # for training we apply random transform for data augmentation
    if training:
        dataset = (
            dataset.shuffle(buffer_size, reshuffle_each_iteration=True)
            .batch(batch_size)
            .map(random_transform, num_parallel_calls=tf.data.experimental.AUTOTUNE)
            .map(split_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
        )
    else:
        dataset = dataset.batch(batch_size).map(split_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)

    return dataset


# calculate recall
def recall_m(y_true, y_pred):
    true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    possible_positives = K.sum(K.round(K.clip(y_true, 0, 1)))
    recall = true_positives / (possible_positives + K.epsilon())
    return recall

# calculate precision
def precision_m(y_true, y_pred):
    true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1)))
    precision = true_positives / (predicted_positives + K.epsilon())
    return precision

# calculate f1 score
def f1_m(y_true, y_pred):
    precision = precision_m(y_true, y_pred)
    recall = recall_m(y_true, y_pred)
    return 2 * ((precision * recall) / (precision + recall + K.epsilon()))

# calculate dice coefficient
def dice_coef(y_true, y_pred, smooth=1):
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)

# calculate the dice loss    
def dice_loss(y_true, y_pred, smooth=1):
    intersection = K.sum(K.abs(y_true * y_pred), axis=-1)
    true_sum = K.sum(K.square(y_true), -1)
    pred_sum = K.sum(K.square(y_pred), -1)
    return 1 - ((2. * intersection + smooth) / (true_sum + pred_sum + smooth))
    

# apply random data transform
def random_transform(dataset):
    x = tf.random.uniform(())

    if x < 0.10:
        dataset = tf.image.flip_left_right(dataset)
    elif tf.math.logical_and(x >= 0.10, x < 0.20):
        dataset = tf.image.flip_up_down(dataset)
    elif tf.math.logical_and(x >= 0.20, x < 0.30):
        dataset = tf.image.flip_left_right(tf.image.flip_up_down(dataset))
    elif tf.math.logical_and(x >= 0.30, x < 0.40):
        dataset = tf.image.rot90(dataset, k=1)
    elif tf.math.logical_and(x >= 0.40, x < 0.50):
        dataset = tf.image.rot90(dataset, k=2)
    elif tf.math.logical_and(x >= 0.50, x < 0.60):
        dataset = tf.image.rot90(dataset, k=3)
    elif tf.math.logical_and(x >= 0.60, x < 0.70):
        dataset = tf.image.flip_left_right(tf.image.rot90(dataset, k=2))
    else:
        pass
    return dataset

# flip data up and down
def flip_inputs_up_down(inputs):
    return tf.image.flip_up_down(inputs)


# flip data left and right
def flip_inputs_left_right(inputs):
    return tf.image.flip_left_right(inputs)


# transpose data
def transpose_inputs(inputs):
    flip_up_down = tf.image.flip_up_down(inputs)
    transpose = tf.image.flip_left_right(flip_up_down)
    return transpose


# rotate the data 90 degrees
def rotate_inputs_90(inputs):
    return tf.image.rot90(inputs, k=1)


# rotate 180 degrees
def rotate_inputs_180(inputs):
    return tf.image.rot90(inputs, k=2)


# rotate 270 degrees
def rotate_inputs_270(inputs):
    return tf.image.rot90(inputs, k=3)


# function to build the unet with mobilenetv2 in the encoder
def build_model_mobilenet():
    # change the number of bands from 4 to 3 when using RGB only
    inputs = Input(shape=(None, None, 4), name="input_image")

    # ImageNet weights can only be used with 3-band (RGB) inputs
    # encoder = MobileNetV2(input_tensor=inputs, weights="imagenet", include_top=False)

    encoder = MobileNetV2(input_tensor=inputs, weights=None, include_top=False)
    skip_connection_names = ["input_image", "block_1_expand_relu", "block_3_expand_relu", "block_6_expand_relu"]
    encoder_output = encoder.get_layer("block_13_expand_relu").output
    
    f = [32, 64, 128, 256]
    x = encoder_output
    for i in range(1, len(skip_connection_names)+1, 1):
        x_skip = encoder.get_layer(skip_connection_names[-i]).output
        x = UpSampling2D((2, 2))(x)
        x = Concatenate()([x, x_skip])
        
        x = Conv2D(f[-i], (3, 3), padding="same")(x)
        x = BatchNormalization()(x)
        x = Activation("relu")(x)
        
        x = Conv2D(f[-i], (3, 3), padding="same")(x)
        x = BatchNormalization()(x)
        x = Activation("relu")(x)
        
    x = Conv2D(2, (2, 2), padding="same")(x)
    x = Activation("sigmoid")(x)
    
    model = Model(inputs, x)
    return model

# function to get the model
def get_model(input_optimizer, input_loss_function, evaluation_metrics):
    model = build_model_mobilenet()
    model.summary()
    model.compile(
        optimizer = input_optimizer, 
        loss = input_loss_function,
        metrics = evaluation_metrics
        )
    return model

if __name__ == "__main__":

    # get the number of available GPUs
    print("Num GPUs Available: ", len(tf.config.experimental.list_physical_devices('GPU')))

    # get distributed strategy and apply distribute i/o and model build
    strategy = tf.distribute.MirroredStrategy()

    print('Number of devices: {}'.format(strategy.num_replicas_in_sync))

    # Set the path to the raw data
    # this expects three folders with training, testing and validation data
    raw_data_path = r"/path/to/"
        
    # Define the path to the log directory for tensorboard
    log_dir = r'/path/to/log'
        
    # Define the directory where the models will be saved
    model_dir = r'/path/to/'

    # Specify inputs (Landsat bands) to the model and the response variable.
    # the names should match the band names you exported from GEE
    BANDS = ['red','green','blue',"nir"]
    
    # these are the response variables that you defined in GEE
    RESPONSE = ["crops","other"]

    # we combine the bands and response variables
    FEATURES = BANDS + RESPONSE
    
    # we define the training size here
    TRAIN_SIZE = 120000    
    BATCH_SIZE = 12
    EPOCHS = 40
    BUFFER_SIZE = 4000
    optimizer = 'adam'
    eval_metrics = [metrics.categorical_accuracy, f1_m, precision_m, recall_m]
    
    # Specify the size and shape of patches expected by the model.
    # this is the kernel size specified in GEE; 128 or 256 is recommended
    kernel_size = 128
    PATCH_SHAPE = (kernel_size, kernel_size)
   
    training_files = glob.glob(raw_data_path + '/training/*')
    training_ds = get_dataset(training_files, BANDS, RESPONSE, PATCH_SHAPE, BATCH_SIZE,
                              buffer_size=BUFFER_SIZE, training=True).repeat()
    
    testing_files = glob.glob(raw_data_path + '/testing/*')
    testing_ds = get_dataset(testing_files, BANDS, RESPONSE, PATCH_SHAPE, BATCH_SIZE,
                              buffer_size=BUFFER_SIZE, training=True).repeat()

    validation_files = glob.glob(raw_data_path + '/validation/*')
    validation_ds = get_dataset(validation_files, BANDS, RESPONSE, PATCH_SHAPE, BATCH_SIZE,
                              buffer_size=BUFFER_SIZE, training=False)

    # run on multiple gpu if possible
    with strategy.scope():
        model = get_model(optimizer, dice_loss, eval_metrics)
    

    CALLBACK_PARAMETER = 'val_loss'
    
    # set the early stopping algorithm
    early_stopping = callbacks.EarlyStopping(
        monitor=CALLBACK_PARAMETER, patience=12, verbose=0,
        mode='auto', restore_best_weights=True)

    # Fit the model to the data
    model.fit(
        x = training_ds, 
        epochs = EPOCHS, 
        steps_per_epoch = int(TRAIN_SIZE / BATCH_SIZE),
        validation_data = testing_ds,
        validation_steps = 100,
        callbacks = [tf.keras.callbacks.TensorBoard(log_dir),early_stopping]
        )
    
    # Save the model
    model.save(model_dir, save_format='tf')

    # evaluate the performance of the model
    model.evaluate(validation_ds)
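Once training has finished, the saved model can be reloaded later for inference. A minimal sketch, assuming the model was written to model_dir as above; compile=False skips restoring the custom loss and metrics, which are not needed for prediction:

    # reload the trained model for inference only
    reloaded = tf.keras.models.load_model(model_dir, compile=False)

    # predict on a batch of validation patches; the output has two channels (crops, other)
    predictions = reloaded.predict(validation_ds, steps=1)
    print(predictions.shape)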

