Transfer Learning in Tensorflow for a New Classification Task

Oct. 13, 2017, 2:57 a.m.

Description

This blog post goes through the steps needed to perform transfer learning with the Inception V3 architecture in Python using Tensorflow. There are several types of transfer learning, as can be seen in the diagram below. This tutorial covers method A2: we use all the layers from the pre-trained model except the final layer, and create a new output layer that is similar in structure to the original but adapted to the new classification task. The layers from the pre-trained model remain trainable, and continue to be fine-tuned for the new task.

Image of taxonomy of transfer learning

I will write the steps for some of the other types of transfer learning in future posts.

About Inception v3

Inception V3 is a powerful deep neural network architecture developed by researchers at Google and described in this paper. A broad overview of the layers of this model can be viewed in the diagram below (using the exact same layer names as in the implementation we will use).

Image of Inception V3 architecture

For a more detailed visualization of the architecture, see the Tensorboard screenshot of the graph at the very bottom of this tutorial.

Getting the Pretrained Weights

Tensorflow has made available a checkpoint file for an Inception V3 model trained on the ImageNet dataset. The checkpoint can be downloaded from this link.

We can start by creating a place to store the pre-trained weights. On Linux, this can be done from the terminal by running:

# Create a place to save the pre-trained checkpoint
WEIGHTS_DIR="/path/to/inception_v3"
mkdir -p ${WEIGHTS_DIR}
cd ${WEIGHTS_DIR}

We now download and extract the tar.gz file that contains the weights. On Linux this can be done as follows:

wget -c http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz
tar -xzvf inception_v3_2016_08_28.tar.gz

This extracts a file called inception_v3.ckpt into the same directory. This is the checkpoint file containing all the pre-trained weights, in a format that Tensorflow can load.
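If you would rather fetch the weights from within Python, the two shell commands above can be approximated with the standard library. This is a sketch assuming Python 3; download_checkpoint and extract_checkpoint are hypothetical helper names, and unlike `wget -c` the download step does not resume partial downloads, it simply skips downloading when the archive is already present.

```python
import os
import tarfile
import urllib.request

CHECKPOINT_URL = "http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz"


def download_checkpoint(url, dest_dir):
    """Download the tarball into dest_dir (skipped if the file already exists)."""
    os.makedirs(dest_dir, exist_ok=True)
    tar_path = os.path.join(dest_dir, os.path.basename(url))
    if not os.path.exists(tar_path):
        urllib.request.urlretrieve(url, tar_path)
    return tar_path


def extract_checkpoint(tar_path, dest_dir, member="inception_v3.ckpt"):
    """Extract the archive into dest_dir and return the checkpoint file path."""
    os.makedirs(dest_dir, exist_ok=True)
    with tarfile.open(tar_path, "r:gz") as tar:
        tar.extractall(dest_dir)
    ckpt_path = os.path.join(dest_dir, member)
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(ckpt_path)
    return ckpt_path


# Usage (the directory is a placeholder, as in the shell example):
# tar_path = download_checkpoint(CHECKPOINT_URL, "/path/to/inception_v3")
# ckpt = extract_checkpoint(tar_path, "/path/to/inception_v3")
```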

Preparing the Python Environment

We start by importing the libraries we need.

from __future__ import print_function, division
import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim.nets

tensorflow.contrib.slim.nets gives us access to the model definitions in the official Tensorflow repository.

We will be making use of Inception V3, which is specified here.

Settings

We will set up some important variables that will be used throughout the script. We want to specify the location of the pre-trained checkpoint file we downloaded, as well as where we want to save the snapshot for the fine-tuned model.

# SNAPSHOT FILES
SNAPSHOT_FILE = "/path/to/my/snapshot.ckpt"
PRETRAINED_SNAPSHOT_FILE = "/path/to/inception_v3/inception_v3.ckpt"

We will also specify the dimensions of the inputs and outputs. The Inception V3 weights we downloaded were trained on RGB images of shape 299 by 299, with 1001 output classes. If we keep the input dimensions exactly the same and only change the number of output classes, we can reuse all of the pre-trained weights up to the output layer, and only the output layer (plus the auxiliary classifier branch, covered in the Saver Operations section) will need to be initialized with random weights.

IMG_WIDTH, IMG_HEIGHT = [299,299]
N_CHANNELS = 3
N_CLASSES = 10
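To see how little has to be learned from scratch, we can count the parameters of the new output layer. This is a back-of-the-envelope sketch that assumes, as in the slim implementation, that the final layer is a 1x1 convolution over 2048 pooled feature channels with one bias per class; logits_layer_params is a hypothetical helper.

```python
def logits_layer_params(n_classes, in_channels=2048):
    """Parameters of a 1x1 conv output layer: weights plus one bias per class."""
    weights = 1 * 1 * in_channels * n_classes  # kernel shape [1, 1, in, out]
    biases = n_classes
    return weights + biases


print(logits_layer_params(1001))  # 2051049: the original ImageNet output layer
print(logits_layer_params(10))    # 20490: the new layer for N_CLASSES = 10
```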

If you plan on using this model to train on your own images, you will need to resize your images to 299 by 299 pixels.
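For illustration, a nearest-neighbour resize can be written in a few lines of plain Python. This is only a sketch: it assumes each image is a list of rows of (R, G, B) pixel tuples, and resize_nearest is a hypothetical helper. In practice you would normally use an image library (for example Pillow's `Image.resize((299, 299))`), which also offers smoother interpolation. Note that stretching a non-square image to 299 by 299 distorts its aspect ratio.

```python
def resize_nearest(image, out_h=299, out_w=299):
    """Nearest-neighbour resize of an image stored as a list of pixel rows."""
    in_h, in_w = len(image), len(image[0])
    return [
        # Map each output coordinate back to the nearest source coordinate.
        [image[y * in_h // out_h][x * in_w // out_w] for x in range(out_w)]
        for y in range(out_h)
    ]


tiny = [[(0, 0, 0), (255, 0, 0)],
        [(0, 255, 0), (0, 0, 255)]]
big = resize_nearest(tiny)
print(len(big), len(big[0]))  # 299 299
```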

Creating the Inception V3 graph

The functions for specifying the Inception V3 architecture are here.

In Python, the function that creates the Inception V3 architecture is tf.contrib.slim.nets.inception.inception_v3(). However, for this function to create the graph correctly, we first have to make use of something called an argument scope. An argument scope is a feature of Tensorflow that lets you specify your own default values for the arguments of Tensorflow functions. If you wish to learn more about argument scopes, you can check out this blog post.

tf.contrib.slim.nets.inception.inception_v3_arg_scope() creates an argument scope that specifies a bunch of default values for the layer functions used by the inception v3 model. We will create the argument scope as follows:

arg_scope = tf.contrib.slim.nets.inception.inception_v3_arg_scope()

For the argument scope to take effect, we make use of Tensorflow's argument scope context manager tf.contrib.framework.arg_scope(). Any functions called inside this context manager will use the default values specified by arg_scope. We can now call the tf.contrib.slim.nets.inception.inception_v3() function inside this context manager:

with tf.contrib.framework.arg_scope(arg_scope):
    tf_logits, end_points = tf.contrib.slim.nets.inception.inception_v3(
        scaled_inputs,
        num_classes=N_CLASSES,
        is_training=tf_is_training,
        dropout_keep_prob=0.8)

tf.contrib.slim.nets.inception.inception_v3 takes as its first argument the tensor containing the batch of input images (scaled to values between 0 and 1). It also takes the number of output classes, and a boolean (or boolean placeholder) specifying whether the model is currently in training mode. The dropout keep probability defaults to 0.8; you could set it to some other value, or even use a placeholder to feed in the value.

tf.contrib.slim.nets.inception.inception_v3 returns two things: the final output logits, and end_points, a dictionary containing the output tensors of all the important layers of the model. The tensors in end_points are particularly useful when you wish to modify and extend the Inception V3 architecture. For instance, for image segmentation you need access to earlier layers in order to create skip-connections.

The tensors in end_points are keyed by their layer names. The full list of keys is:

'Conv2d_1a_3x3',
'Conv2d_2a_3x3',
'Conv2d_2b_3x3',
'Conv2d_3b_1x1',
'Conv2d_4a_3x3',
'MaxPool_3a_3x3',
'MaxPool_5a_3x3',
'Mixed_5b',
'Mixed_5c',
'Mixed_5d',
'Mixed_6a',
'Mixed_6b',
'Mixed_6c',
'Mixed_6d',
'Mixed_6e',
'Mixed_7a',
'Mixed_7b',
'Mixed_7c',
'AuxLogits',
'PreLogits',
'Logits',
'Predictions'
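Knowing the spatial size of each end point helps when wiring in skip-connections. The sketch below recomputes the height/width of the main end points for a 299 by 299 input, using the usual output-size rules: (in - kernel) // stride + 1 for VALID padding, and ceil(in / stride) for SAME padding. The kernel sizes, strides and paddings are our reading of the slim inception_v3_base definition and should be treated as an assumption to check against the source; each group of Mixed blocks is collapsed to its net effect on the spatial size.

```python
# (end point, kernel, stride, padding): assumed settings, see lead-in.
LAYERS = [
    ("Conv2d_1a_3x3", 3, 2, "VALID"),
    ("Conv2d_2a_3x3", 3, 1, "VALID"),
    ("Conv2d_2b_3x3", 3, 1, "SAME"),
    ("MaxPool_3a_3x3", 3, 2, "VALID"),
    ("Conv2d_3b_1x1", 1, 1, "VALID"),
    ("Conv2d_4a_3x3", 3, 1, "VALID"),
    ("MaxPool_5a_3x3", 3, 2, "VALID"),
    ("Mixed_5d", 3, 1, "SAME"),   # Mixed_5b..5d keep the size
    ("Mixed_6a", 3, 2, "VALID"),  # net effect of its stride-2 branches
    ("Mixed_6e", 3, 1, "SAME"),   # Mixed_6b..6e keep the size
    ("Mixed_7a", 3, 2, "VALID"),  # net effect of its stride-2 branches
    ("Mixed_7c", 3, 1, "SAME"),   # Mixed_7b..7c keep the size
]


def out_size(size, kernel, stride, padding):
    """Output size along one spatial dimension of a conv or pooling layer."""
    if padding == "SAME":
        return -(-size // stride)  # ceil(size / stride)
    return (size - kernel) // stride + 1


def spatial_sizes(input_size=299):
    """Height (= width) of each listed end point for a square input."""
    sizes, s = {}, input_size
    for name, kernel, stride, padding in LAYERS:
        s = out_size(s, kernel, stride, padding)
        sizes[name] = s
    return sizes


for name, s in spatial_sizes().items():
    print("{:15s} {} x {}".format(name, s, s))
```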

Saver Operations

As part of the graph, you will need to create two separate saver objects. The first loads the weights for the unmodified portion of the pre-trained Inception V3 model. The second saves and restores all the weights of the fine-tuned model.

In order to properly use the saver that will load from the pre-trained inception model, we need to specify which weights we need to extract from the checkpoint. Due to the way the name and variable scopes have been set up in the original code, it is quite easy to separate the weights that belong to the trunk of the model, from the weights associated with the output layers.

All the weights of the Inception V3 model are inside the "InceptionV3" variable scope. The weights of the final output layer are in the "InceptionV3/Logits" scope, which we exclude because its shape depends on the number of classes. The auxiliary classifier branch in the "InceptionV3/AuxLogits" scope also ends in a layer sized for the original 1001 classes, so we exclude it as well.

# Lists of scopes of weights to include/exclude from pretrained snapshot
pretrained_include = ["InceptionV3"]
pretrained_exclude = ["InceptionV3/AuxLogits", "InceptionV3/Logits"]

We can get the full list of variables we want to extract by using the tf.contrib.framework.get_variables_to_restore() function. These variables are then passed to the saver object that will be used to load up the weights from the pre-trained inception V3 checkpoint.

# PRETRAINED SAVER - For loading pretrained weights on the first run
pretrained_vars = tf.contrib.framework.get_variables_to_restore(include=pretrained_include, exclude=pretrained_exclude)
tf_pretrained_saver = tf.train.Saver(pretrained_vars, name="pretrained_saver")
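The include/exclude logic can be mimicked on plain variable names to check which weights will be taken from the checkpoint. This is a simplified, stand-alone sketch rather than the actual get_variables_to_restore() code (which works on the graph's variable objects); in_scopes, select_pretrained and the example names are illustrative. Here a name counts as inside a scope if it equals the scope or starts with the scope followed by "/".

```python
def in_scopes(name, scopes):
    """True if `name` lies inside any of the given scopes."""
    return any(name == s or name.startswith(s + "/") for s in scopes)


def select_pretrained(var_names, include, exclude):
    """Keep names inside an `include` scope but outside every `exclude` scope."""
    return [n for n in var_names
            if in_scopes(n, include) and not in_scopes(n, exclude)]


pretrained_include = ["InceptionV3"]
pretrained_exclude = ["InceptionV3/AuxLogits", "InceptionV3/Logits"]

# Illustrative variable names, in the style of the slim Inception V3 graph.
names = [
    "InceptionV3/Conv2d_1a_3x3/weights",
    "InceptionV3/Mixed_7c/Branch_2/Conv2d_0a_1x1/weights",
    "InceptionV3/AuxLogits/Conv2d_2b_1x1/weights",
    "InceptionV3/Logits/Conv2d_1c_1x1/weights",
    "opt/beta1_power",
]
print(select_pretrained(names, pretrained_include, pretrained_exclude))
```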

The second saver will just save and restore all the weights from the model.

# MAIN SAVER - For saving/restoring your complete model
tf_saver = tf.train.Saver(name="saver")

Initialize Variables

When you create a session, you will need to initialize the variables. The very first time you run the model, you will want to initialize the weights using the pre-trained snapshot (and random values for the final layer). However, after saving your model for the first time, you will want to load from your own saved snapshot.

# INITIALIZE VARS
if tf.train.checkpoint_exists(SNAPSHOT_FILE):
    print(" Loading from Main Checkpoint")
    tf_saver.restore(session, SNAPSHOT_FILE)
else:
    print("Initializing from Pretrained Weights")
    session.run(tf.global_variables_initializer())
    tf_pretrained_saver.restore(session, PRETRAINED_SNAPSHOT_FILE)

WARNING: As of Tensorflow 1.3, tf.train.checkpoint_exists() requires that the parent directory of the snapshot file already exists; otherwise, it throws an error.
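One way to avoid that error is to create the parent directory before checking for the snapshot. The sketch below uses only the standard library; ensure_snapshot_dir and snapshot_files_present are hypothetical helpers. The second one is only a rough stand-in for checkpoint_exists(), and assumes the default V2 checkpoint format, in which saving to snapshot.ckpt writes several files such as snapshot.ckpt.index and snapshot.ckpt.data-00000-of-00001 rather than a single file.

```python
import os


def ensure_snapshot_dir(snapshot_file):
    """Create the directory that will hold the snapshot, if it is missing."""
    parent = os.path.dirname(os.path.abspath(snapshot_file))
    os.makedirs(parent, exist_ok=True)
    return parent


def snapshot_files_present(snapshot_file):
    """Rough stand-in for checkpoint_exists(): look for the V2 index file,
    or for a single file with exactly this name."""
    return (os.path.exists(snapshot_file + ".index")
            or os.path.exists(snapshot_file))


# For example, call ensure_snapshot_dir(SNAPSHOT_FILE) before initialize_vars().
```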

Saving Snapshots

To save a snapshot of your full model from a running session, simply run:

# SAVE SNAPSHOT
tf_saver.save(session, SNAPSHOT_FILE)

Minimal Working Example

Putting it all together into a complete, yet bare-bones, working example, we have the following:

"""
# ==============================================================================
# MINIMAL TRANSFER LEARNING EXAMPLE USING PRETRAINED INCEPTION V3 MODEL
# Author: Ronny Restrepo
# Origin: http://ronny.rest/blog/post_2017_10_13_tf_transfer_learning
# License: MIT License
# Copyright (c) 2017 Ronny Restrepo
# ==============================================================================
"""
from __future__ import print_function, division
import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim.nets

SNAPSHOT_FILE = "/path/to/my/snapshot.ckpt"
PRETRAINED_SNAPSHOT_FILE = "/path/to/inception_v3/inception_v3.ckpt"

# somewhere to store the tensorboard files - to visualise the graph
TENSORBOARD_DIR = "/path/to/my/tensorboard_dir"

# IMAGE SETTINGS
IMG_WIDTH, IMG_HEIGHT = [299,299] # Dimensions required by inception V3
N_CHANNELS = 3                    # Number of channels required by inception V3
N_CLASSES = 10                    # Change N_CLASSES to suit your needs

graph = tf.Graph()
with graph.as_default():
    # INPUTS
    with tf.name_scope("inputs") as scope:
        input_dims = (None, IMG_HEIGHT, IMG_WIDTH, N_CHANNELS)
        tf_X = tf.placeholder(tf.float32, shape=input_dims, name="X")
        tf_Y = tf.placeholder(tf.int32, shape=[None], name="Y")
        tf_alpha = tf.placeholder_with_default(0.001, shape=None, name="alpha")
        tf_is_training = tf.placeholder_with_default(False, shape=None, name="is_training")

    # PREPROCESSING STEPS
    with tf.name_scope("preprocess") as scope:
        scaled_inputs = tf.div(tf_X, 255., name="rescaled_inputs")

    # BODY
    arg_scope = tf.contrib.slim.nets.inception.inception_v3_arg_scope()
    with tf.contrib.framework.arg_scope(arg_scope):
        tf_logits, end_points = tf.contrib.slim.nets.inception.inception_v3(
            scaled_inputs,
            num_classes=N_CLASSES,
            is_training=tf_is_training,
            dropout_keep_prob=0.8)

    # PREDICTIONS
    tf_preds = tf.to_int32(tf.argmax(tf_logits, axis=-1), name="preds")

    # LOSS - Sums all losses (even Regularization losses)
    with tf.variable_scope('loss') as scope:
        unrolled_labels = tf.reshape(tf_Y, (-1,))
        tf.losses.sparse_softmax_cross_entropy(labels=unrolled_labels,
                                               logits=tf_logits)
        tf_loss = tf.losses.get_total_loss()

    # OPTIMIZATION - Also updates batchnorm operations automatically
    with tf.variable_scope('opt') as scope:
        tf_optimizer = tf.train.AdamOptimizer(tf_alpha, name="optimizer")
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) # for batchnorm
        with tf.control_dependencies(update_ops):
            tf_train_op = tf_optimizer.minimize(tf_loss, name="train_op")

    # PRETRAINED SAVER SETTINGS
    # Lists of scopes of weights to include/exclude from pretrained snapshot
    pretrained_include = ["InceptionV3"]
    pretrained_exclude = ["InceptionV3/AuxLogits", "InceptionV3/Logits"]

    # PRETRAINED SAVER - For loading pretrained weights on the first run
    pretrained_vars = tf.contrib.framework.get_variables_to_restore(
        include=pretrained_include,
        exclude=pretrained_exclude)
    tf_pretrained_saver = tf.train.Saver(pretrained_vars, name="pretrained_saver")

    # MAIN SAVER - For saving/restoring your complete model
    tf_saver = tf.train.Saver(name="saver")

    # TENSORBOARD - To visualize the architecture
    with tf.variable_scope('tensorboard') as scope:
        tf_summary_writer = tf.summary.FileWriter(TENSORBOARD_DIR, graph=graph)
        tf_dummy_summary = tf.summary.scalar(name="dummy", tensor=1)


def initialize_vars(session):
    # INITIALIZE VARS
    if tf.train.checkpoint_exists(SNAPSHOT_FILE):
        print(" Loading from Main Checkpoint")
        tf_saver.restore(session, SNAPSHOT_FILE)
    else:
        print("Initializing from Pretrained Weights")
        session.run(tf.global_variables_initializer())
        tf_pretrained_saver.restore(session, PRETRAINED_SNAPSHOT_FILE)


with tf.Session(graph=graph) as sess:
    n_epochs = 2
    print_every = 32
    batch_size = 2 # small batch size so inception v3 can be run on laptops
    steps_per_epoch = len(X_train)//batch_size

    initialize_vars(session=sess)

    for epoch in range(n_epochs):
        print("----------------------------------------------")
        print("EPOCH {}/{}".format(epoch+1, n_epochs))
        print("----------------------------------------------")
        for step in range(steps_per_epoch):
            # EXTRACT A BATCH OF TRAINING DATA
            X_batch = X_train[batch_size*step: batch_size*(step+1)]
            Y_batch = Y_train[batch_size*step: batch_size*(step+1)]

            # RUN ONE TRAINING STEP - feeding batch of data
            feed_dict = {tf_X: X_batch,
                         tf_Y: Y_batch,
                         tf_alpha:0.0001,
                         tf_is_training: True}
            loss, _ = sess.run([tf_loss, tf_train_op], feed_dict=feed_dict)

            # PRINT FEED BACK - once every `print_every` steps
            if (step+1)%print_every == 0:
                print("STEP: {: 4d}  LOSS: {:0.4f}".format(step, loss))

        # SAVE SNAPSHOT - after each epoch
        tf_saver.save(sess, SNAPSHOT_FILE)

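NOTE: The example assumes that you have already loaded your training data into X_train (images of shape [n_samples, 299, 299, 3], with pixel values from 0 to 255) and Y_train (integer class labels of shape [n_samples]).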

This is, of course, a very rudimentary example. It only tracks the batch training losses, which is the bare minimum needed to show that training is occurring. You should extend it further if you want to use it to train anything seriously.
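As one small extension, the batch slicing in the training loop can be factored into a generator. This is a sketch under assumptions, not part of the original example: iterate_minibatches is a hypothetical helper that yields lists of sample indices. With shuffle=False and drop_last=True it reproduces the loop above, which skips the last len(X_train) % batch_size samples in each epoch; with shuffle=True and no seed, the order changes every epoch.

```python
import random


def iterate_minibatches(n_samples, batch_size, shuffle=False,
                        drop_last=True, seed=None):
    """Yield one list of sample indices per mini-batch."""
    indices = list(range(n_samples))
    if shuffle:
        random.Random(seed).shuffle(indices)
    # drop_last mirrors `steps_per_epoch = len(X_train)//batch_size`.
    stop = n_samples - n_samples % batch_size if drop_last else n_samples
    for start in range(0, stop, batch_size):
        yield indices[start:start + batch_size]


# Usage, assuming X_train and Y_train are NumPy arrays (list indexing works):
# for idx in iterate_minibatches(len(X_train), batch_size, shuffle=True):
#     X_batch, Y_batch = X_train[idx], Y_train[idx]
```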

Viewing model in Tensorboard

Tensorboard is a tool provided by Tensorflow that lets you visualize the architecture of your model. The code above automatically creates the files Tensorboard needs. All you need to do now is run the following in a terminal window:

tensorboard --logdir /path/to/my/tensorboard_dir

WARNING: Running and viewing Tensorboard can consume quite a lot of RAM. If your computer is already short on memory, it is advisable to run these steps only after the model has finished training.

When you run this, it will print out a URL you can use to view tensorboard. Usually, it will be something like:

http://localhost:6006

Copy and paste this URL into your web browser. Then go into the Graphs tab. This will allow you to explore the architecture of the model, by expanding/collapsing different portions of the model.

Here is a screenshot of the Inception graph with most of the inception modules expanded to show the outline of their branches.

Image of Inception V3 tensorboard graph

And expanding Mixed_7c/Branch_2 to see what is inside, we get:

Image of Branch 2
