This blog post will go through the steps needed to perform transfer learning using the Inception V3 architecture in Python using Tensorflow. There are actually several types of transfer learning, as can be seen in the diagram below. This tutorial will cover the method A2. That is, we will use all the layers from the pre-trained model, except the final layer, creating a new output layer that is similar in structure to the original, but adapted to train on a new classification task. The layers from the pre-trained model will remain trainable, and continue to be fine-tuned for the new task.
I will write the steps for some of the other types of transfer learning in future posts.
Inception V3 is a powerful deep neural network architecture developed by researchers at Google and described in the paper "Rethinking the Inception Architecture for Computer Vision". A broad overview of the layers of this model can be viewed in the diagram below (using the exact same layer names as in the implementation we will use).
For a more detailed visualization of the architecture, you can scroll to the very bottom of this tutorial, which shows a tensorboard screenshot of the graph.
Tensorflow has made a checkpoint file available for an inception V3 model that was trained on the Imagenet dataset. The checkpoint can be downloaded from this link.
We can start by creating a place to store the pre-trained weights. This can be done on the terminal in Linux by running:
# Create a place to save the pre-trained checkpoint
WEIGHTS_DIR="/path/to/inception_v3"
mkdir -p ${WEIGHTS_DIR}
cd ${WEIGHTS_DIR}
We now download and extract the tar.gz file that contains the weights. On Linux this can be done as follows:
wget -c http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz
tar -xzvf inception_v3_2016_08_28.tar.gz
This extracts a file called inception_v3.ckpt into that same directory. This is the checkpoint file containing all the pre-trained weights in a format that Tensorflow will be able to load from.
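If you want to double-check that the download worked, you can peek inside the checkpoint. Here is a minimal sketch using tf.train.NewCheckpointReader (the path is just the location we extracted to above):

import tensorflow as tf

# Open the checkpoint and list a few of the variables stored inside it.
reader = tf.train.NewCheckpointReader("/path/to/inception_v3/inception_v3.ckpt")
var_shapes = reader.get_variable_to_shape_map()
print("variables in checkpoint:", len(var_shapes))
for name in sorted(var_shapes)[:5]:
    print(name, var_shapes[name])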
We start by importing the libraries we need.
from __future__ import print_function, division
import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim.nets
tensorflow.contrib.slim.nets will give us access to the models that are in the official Tensorflow repository. We will be making use of Inception V3, which is specified here.
We will set up some important variables that will be used throughout the script. We want to specify the location of the pre-trained checkpoint file we downloaded, as well as where we want to save the snapshot for the fine-tuned model.
# SNAPSHOT FILES
SNAPSHOT_FILE = "/path/to/my/snapshot.ckpt"
PRETRAINED_SNAPSHOT_FILE = "/path/to/inception_v3/inception_v3.ckpt"
We will also specify the dimensions of the inputs and outputs. The weights for the Inception V3 model that we downloaded were trained on RGB images of shape 299 by 299, and 1001 output classes. If we keep the input dimensions exactly the same, and modify the number of output classes, then we will be able to make use of all the pre-trained weights right before the output layer, and only the output layer will need to be initialized to random weights.
IMG_WIDTH, IMG_HEIGHT = [299, 299]
N_CHANNELS = 3
N_CLASSES = 10
If you plan on using this model to train on your own images, then you will need to make sure that you resize the images to 299 by 299 pixels.
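As a rough sketch of that preprocessing step, the resizing could be done ahead of time with PIL (the filename here is just a placeholder):

from PIL import Image
import numpy as np

# Load an image and resize it to the 299 by 299 input size that
# Inception V3 expects ("my_image.jpg" is a hypothetical file).
img = Image.open("my_image.jpg").convert("RGB")
img = img.resize((299, 299), Image.BILINEAR)
img = np.asarray(img, dtype=np.uint8)  # shape [299, 299, 3], values 0-255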
The functions for specifying the Inception V3 architecture are here.
In Python, the function that creates the Inception V3 architecture is tf.contrib.slim.nets.inception.inception_v3(). But, in order for this function to create the graph correctly, we have to first make use of something called an argument scope. An argument scope is a feature of Tensorflow that allows you to specify your own default values for the arguments in Tensorflow functions. If you wish to learn more about argument scopes, you can check out this blog post.
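To give a feel for what an argument scope does, here is a small illustrative sketch (not part of this tutorial's model) that sets default argument values for slim's conv2d layers:

import tensorflow as tf
slim = tf.contrib.slim

inputs = tf.placeholder(tf.float32, shape=[None, 32, 32, 3])

# Every slim.conv2d call nested inside the scope picks up these
# defaults, unless they are explicitly overridden.
with tf.contrib.framework.arg_scope([slim.conv2d],
                                    padding="SAME",
                                    activation_fn=tf.nn.relu):
    net = slim.conv2d(inputs, 64, [3, 3])                   # uses the defaults
    net = slim.conv2d(net, 64, [3, 3], activation_fn=None)  # overrides one default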
tf.contrib.slim.nets.inception.inception_v3_arg_scope() creates an argument scope that specifies a bunch of default values for the layer functions used by the Inception V3 model. We will create the argument scope as follows:
arg_scope = tf.contrib.slim.nets.inception.inception_v3_arg_scope()
For the argument scope to take effect, we make use of Tensorflow's argument scope context manager tf.contrib.framework.arg_scope(). Any functions that are nested inside of this context manager will make use of the default values specified by arg_scope. We can now call the tf.contrib.slim.nets.inception.inception_v3() function inside of this context manager:
with tf.contrib.framework.arg_scope(arg_scope):
    tf_logits, end_points = tf.contrib.slim.nets.inception.inception_v3(
        scaled_inputs,
        num_classes=N_CLASSES,
        is_training=tf_is_training,
        dropout_keep_prob=0.8)
tf.contrib.slim.nets.inception.inception_v3 takes as its first argument the tensor containing the batch of input images (scaled to values between 0 and 1). It also takes the number of output classes, and a boolean (or boolean placeholder) specifying whether the model is currently in training mode. By default, the dropout keep rate is set to 0.8, but you could set this to some other value, or even use a placeholder that feeds in the value.
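For instance, feeding the keep rate through a placeholder might look like the following sketch (tf_keep_prob is a name introduced here purely for illustration; this would replace the call shown above, not be added alongside it):

# A placeholder with a default value, so dropout can be adjusted at feed time.
tf_keep_prob = tf.placeholder_with_default(0.8, shape=None, name="keep_prob")

with tf.contrib.framework.arg_scope(arg_scope):
    tf_logits, end_points = tf.contrib.slim.nets.inception.inception_v3(
        scaled_inputs,
        num_classes=N_CLASSES,
        is_training=tf_is_training,
        dropout_keep_prob=tf_keep_prob)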
tf.contrib.slim.nets.inception.inception_v3 returns two things. It returns the final output logits. It also returns end_points, which is a dictionary containing the output tensors for all the important layers from the model. The tensors in end_points are particularly useful when you wish to modify and extend the inception v3 architecture in interesting ways. For instance, for image segmentation, you need access to previous layers in order to create skip-connections.
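Accessing one of those layers is just a dictionary lookup. A quick sketch (what you then build on top of it, e.g. an upsampling decoder, is up to you):

# Pull out intermediate activations by their layer names.
mixed_7c = end_points["Mixed_7c"]  # the deepest mixed block
mixed_6e = end_points["Mixed_6e"]  # an earlier layer, e.g. for a skip-connection
print(mixed_7c.shape, mixed_6e.shape)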
The tensors in end_points are keyed by their layer names. All of the keys are listed below:
'Conv2d_1a_3x3', 'Conv2d_2a_3x3', 'Conv2d_2b_3x3', 'Conv2d_3b_1x1',
'Conv2d_4a_3x3', 'MaxPool_3a_3x3', 'MaxPool_5a_3x3',
'Mixed_5b', 'Mixed_5c', 'Mixed_5d',
'Mixed_6a', 'Mixed_6b', 'Mixed_6c', 'Mixed_6d', 'Mixed_6e',
'Mixed_7a', 'Mixed_7b', 'Mixed_7c',
'AuxLogits', 'PreLogits', 'Logits', 'Predictions'
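If you want to verify these names, along with the shape of each output tensor, a quick sketch like this (run after building the graph) will print them all:

# Print every end point name along with the shape of its output tensor.
for name, tensor in sorted(end_points.items()):
    print("{:20s} {}".format(name, tensor.shape))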
As part of the graph, you will need to create two separate saver objects: one to load the weights for the unmodified portion of the pre-trained inception v3 model, and a second to save and restore all the weights associated with the fine-tuned model.
In order to properly use the saver that will load from the pre-trained inception model, we need to specify which weights we need to extract from the checkpoint. Due to the way the name and variable scopes have been set up in the original code, it is quite easy to separate the weights that belong to the trunk of the model from the weights associated with the output layers.
All the weights associated with the Inception V3 model are inside the "InceptionV3" name scope. The weights associated with the final output layer are in the "InceptionV3/Logits" scope. There is an additional branch of weights in the "InceptionV3/AuxLogits" scope that we do not need either.
# Lists of scopes of weights to include/exclude from pretrained snapshot
pretrained_include = ["InceptionV3"]
pretrained_exclude = ["InceptionV3/AuxLogits", "InceptionV3/Logits"]
We can get the full list of variables we want to extract by using the tf.contrib.framework.get_variables_to_restore() function. These variables are then passed to the saver object that will be used to load up the weights from the pre-trained inception V3 checkpoint.
# PRETRAINED SAVER - For loading pretrained weights on the first run
pretrained_vars = tf.contrib.framework.get_variables_to_restore(
    include=pretrained_include,
    exclude=pretrained_exclude)
tf_pretrained_saver = tf.train.Saver(pretrained_vars, name="pretrained_saver")
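As a quick sanity check, you might confirm that none of the output layer variables slipped into that list; a small sketch:

# Confirm that the Logits/AuxLogits variables were excluded.
assert not any("Logits" in v.op.name for v in pretrained_vars)
for v in pretrained_vars[:5]:
    print(v.op.name)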
The second saver will just save and restore all the weights from the model.
# MAIN SAVER - For saving/restoring your complete model
tf_saver = tf.train.Saver(name="saver")
When you create a session, you will need to initialize the variables. The very first time you run the model, you will want to initialize the weights using the pre-trained snapshot (and random values for the final layer). However, after saving your model for the first time, you will want to load from your own saved snapshot.
# INITIALIZE VARS
if tf.train.checkpoint_exists(SNAPSHOT_FILE):
    print(" Loading from Main Checkpoint")
    tf_saver.restore(session, SNAPSHOT_FILE)
else:
    print("Initializing from Pretrained Weights")
    session.run(tf.global_variables_initializer())
    tf_pretrained_saver.restore(session, PRETRAINED_SNAPSHOT_FILE)
WARNING: As of Tensorflow 1.3, tf.train.checkpoint_exists() requires that the parent directory of the snapshot file actually exists, otherwise it throws an error.
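One simple way to guard against this is to create the parent directory up front, e.g.:

import os

# Make sure the snapshot's parent directory exists before
# tf.train.checkpoint_exists() is called on the file path.
snapshot_dir = os.path.dirname(SNAPSHOT_FILE)
if not os.path.exists(snapshot_dir):
    os.makedirs(snapshot_dir)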
In order to save a snapshot of your full model, from a running session, simply run:
# SAVE SNAPSHOT
tf_saver.save(session, SNAPSHOT_FILE)
Putting it all together into a complete, yet bare-bones, working example, we have the following:
""" # ============================================================================== # MINIMAL TRANSFER LEARNING EXAMPLE USING PRETRAINED INCEPTION V3 MODEL # Author: Ronny Restrepo # Origin: http://ronny.rest/blog/post_2017_10_13_tf_transfer_learning # License: MIT License # Copyright (c) 2017 Ronny Restrepo # ============================================================================== """ from __future__ import print_function, division import numpy as np import tensorflow as tf import tensorflow.contrib.slim.nets SNAPSHOT_FILE = "/path/to/my/snapshot.ckpt" PRETRAINED_SNAPSHOT_FILE = "/path/to/inception_v3/inception_v3.ckpt" # somewhere to store the tensorboard files - to visualise the graph TENSORBOARD_DIR = "/path/to/my/tensorboard_dir" # IMAGE SETTINGS IMG_WIDTH, IMG_HEIGHT = [299,299] # Dimensions required by inception V3 N_CHANNELS = 3 # Number of channels required by inception V3 N_CLASSES = 10 # Change N_CLASSES to suit your needs graph = tf.Graph() with graph.as_default(): # INPUTS with tf.name_scope("inputs") as scope: input_dims = (None, IMG_HEIGHT, IMG_WIDTH, N_CHANNELS) tf_X = tf.placeholder(tf.float32, shape=input_dims, name="X") tf_Y = tf.placeholder(tf.int32, shape=[None], name="Y") tf_alpha = tf.placeholder_with_default(0.001, shape=None, name="alpha") tf_is_training = tf.placeholder_with_default(False, shape=None, name="is_training") # PREPROCESSING STEPS with tf.name_scope("preprocess") as scope: scaled_inputs = tf.div(tf_X, 255., name="rescaled_inputs") # BODY arg_scope = tf.contrib.slim.nets.inception.inception_v3_arg_scope() with tf.contrib.framework.arg_scope(arg_scope): tf_logits, end_points = tf.contrib.slim.nets.inception.inception_v3( scaled_inputs, num_classes=N_CLASSES, is_training=tf_is_training, dropout_keep_prob=0.8) # PREDICTIONS tf_preds = tf.to_int32(tf.argmax(tf_logits, axis=-1), name="preds") # LOSS - Sums all losses (even Regularization losses) with tf.variable_scope('loss') as scope: unrolled_labels = tf.reshape(tf_Y, (-1,)) tf.losses.sparse_softmax_cross_entropy(labels=unrolled_labels, logits=tf_logits) tf_loss = tf.losses.get_total_loss() # OPTIMIZATION - Also updates batchnorm operations automatically with tf.variable_scope('opt') as scope: tf_optimizer = tf.train.AdamOptimizer(tf_alpha, name="optimizer") update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) # for batchnorm with tf.control_dependencies(update_ops): tf_train_op = tf_optimizer.minimize(tf_loss, name="train_op") # PRETRAINED SAVER SETTINGS # Lists of scopes of weights to include/exclude from pretrained snapshot pretrained_include = ["InceptionV3"] pretrained_exclude = ["InceptionV3/AuxLogits", "InceptionV3/Logits"] # PRETRAINED SAVER - For loading pretrained weights on the first run pretrained_vars = tf.contrib.framework.get_variables_to_restore( include=pretrained_include, exclude=pretrained_exclude) tf_pretrained_saver = tf.train.Saver(pretrained_vars, name="pretrained_saver") # MAIN SAVER - For saving/restoring your complete model tf_saver = tf.train.Saver(name="saver") # TENSORBOARD - To visialize the architecture with tf.variable_scope('tensorboard') as scope: tf_summary_writer = tf.summary.FileWriter(TENSORBOARD_DIR, graph=graph) tf_dummy_summary = tf.summary.scalar(name="dummy", tensor=1) def initialize_vars(session): # INITIALIZE VARS if tf.train.checkpoint_exists(SNAPSHOT_FILE): print(" Loading from Main Checkpoint") tf_saver.restore(session, SNAPSHOT_FILE) else: print("Initializing from Pretrained Weights") session.run(tf.global_variables_initializer()) 
tf_pretrained_saver.restore(session, PRETRAINED_SNAPSHOT_FILE) with tf.Session(graph=graph) as sess: n_epochs = 2 print_every = 32 batch_size = 2 # small batch size so inception v3 can be run on laptops steps_per_epoch = len(X_train)//batch_size initialize_vars(session=sess) for epoch in range(n_epochs): print("----------------------------------------------") print("EPOCH {}/{}".format(epoch+1, n_epochs)) print("----------------------------------------------") for step in range(steps_per_epoch): # EXTRACT A BATCH OF TRAINING DATA X_batch = X_train[batch_size*step: batch_size*(step+1)] Y_batch = Y_train[batch_size*step: batch_size*(step+1)] # RUN ONE TRAINING STEP - feeding batch of data feed_dict = {tf_X: X_batch, tf_Y: Y_batch, tf_alpha:0.0001, tf_is_training: True} loss, _ = sess.run([tf_loss, tf_train_op], feed_dict=feed_dict) # PRINT FEED BACK - once every `print_every` steps if (step+1)%print_every == 0: print("STEP: {: 4d} LOSS: {:0.4f}".format(step, loss)) # SAVE SNAPSHOT - after each epoch tf_saver.save(sess, SNAPSHOT_FILE)
NOTE: It assumes that you have already loaded the data as:

X_train: a numpy array of input images, with shape [n_samples, 299, 299, 3] and pixel values from 0 to 255.
Y_train: a numpy array of class labels, with shape [n_samples] and values from 0 to (N_CLASSES - 1).

This, of course, is a very rudimentary example. It is the bare minimum needed to show that training is occurring by only tracking the batch training losses. You should extend it further if you want to use it for training anything seriously.
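One natural extension, for example, is to evaluate on held-out data after each epoch. A rough sketch (X_valid and Y_valid are hypothetical arrays in the same format as the training data; in practice you would loop over them in batches, just like the training loop):

# EVALUATE - run predictions on validation data (is_training stays False)
feed_dict = {tf_X: X_valid, tf_is_training: False}
preds = sess.run(tf_preds, feed_dict=feed_dict)
accuracy = np.mean(preds == Y_valid)
print("VALID ACCURACY: {:0.4f}".format(accuracy))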
Tensorboard is a tool provided by Tensorflow that allows you to visualize the architecture of your model. The code provided above automatically created the necessary files to use Tensorboard. The only thing you need to do now is run the following in a terminal window:
tensorboard --logdir /path/to/my/tensorboard_dir
WARNING: Running and viewing Tensorboard can consume quite a lot of RAM. If your computer is already straining for memory space, then it is advisable that you only run these steps after the model has finished training.
When you run this, it will print out a URL you can use to view tensorboard. Usually, it will be something like:
http://localhost:6006
Copy and paste this URL into your web browser, then go into the Graphs tab. This will allow you to explore the architecture of the model by expanding/collapsing different portions of the model.
Here is a screenshot of the inception graph that is displayed once we expand most of the inception modules to show the outline of the branches.
And expanding Mixed_7c/Branch_2 to see what is inside, we get: