Save and Restore Models

This lesson goes through how to save and restore the variables in a TensorFlow graph. TensorFlow has built-in functions that allow you to save the variables as "checkpoint files".

Checkpoint files

When TensorFlow creates a checkpoint file, it actually creates several accompanying files along with it. It is therefore a good idea to place the checkpoints in a dedicated subdirectory, to keep all the related files organized.

So let's start by creating a subdirectory called "checkpoints", and specifying the path of the checkpoint file to be "checkpoints/checkpoint.chk".

import os

# Specify the name of the checkpoints directory
checkpoint_dir = "checkpoints"

# Create the directory if it does not already exist
if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)

# Specify the path to the checkpoint file
checkpoint_file = os.path.join(checkpoint_dir, "checkpoint.chk")

Saving and Restoring Operations

In order to actually save and restore checkpoints, we first need to create a saver operation in the TensorFlow graph using tf.train.Saver().

saver = tf.train.Saver(name="saver")

Save

In order to save the variables in a TensorFlow graph to the checkpoint file we created, we call the saver.save() function inside a currently running TensorFlow session. Note that by default the saver handles all savable variables in the graph, not just the trainable ones.

saver.save(session, checkpoint_file)

Restore

In order to restore the saved values from a checkpoint file, we call the saver.restore() function inside a currently running TensorFlow session.

saver.restore(session, checkpoint_file)

Check existence of checkpoint file

If you want to check whether a checkpoint file has already been created and saved, you can call the following function.

tf.train.checkpoint_exists(checkpoint_file)

NOTE: the function above requires the parent directory of the checkpoint file to exist. If the directory structure has not been created, it will throw an error instead of simply returning False. That is why you should create the necessary directory structure first, as we did at the beginning of this lesson.
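A related helper is tf.train.latest_checkpoint(), which reads the "checkpoint" index file inside a directory and returns the path of the most recently saved checkpoint, or None if no checkpoint has been saved there yet. A small sketch:

```python
import tensorflow as tf

# Returns e.g. "checkpoints/checkpoint.chk" if a checkpoint has been
# saved in that directory, or None if no checkpoint exists yet.
latest = tf.train.latest_checkpoint("checkpoints")
if latest is None:
    print("No checkpoint found, initializing from scratch")
else:
    print("Most recent checkpoint:", latest)
```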

Example

The following is a minimal working example of how to put all the components together in a pipeline. It does the following things:

  • If a checkpoint file has not been created, then it initializes the variables to their default values.
  • If a checkpoint file already exists, then it restores the previously saved values.
  • After running the graph operations in a session, it saves the updated values to the checkpoint file.

import os
import tensorflow as tf

# Specify the name of the checkpoints directory
checkpoint_dir = "checkpoints"

# Create the directory if it does not already exist
if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)

# Specify the path to the checkpoint file
checkpoint_file = os.path.join(checkpoint_dir, "checkpoint.chk")

# CREATE THE GRAPH
graph = tf.Graph()
with graph.as_default():
    tf_w1 = tf.Variable(tf.constant(1, shape=[2, 3]), name="weights_1")
    tf_w2 = tf.Variable(tf.constant(1, shape=[2, 3]), name="weights_2")
    update_vars = tf_w1.assign(tf_w1 + tf_w2) # update the value of w1

    # Create a Saver Object
    saver = tf.train.Saver(name="saver")

# RUN THE SESSION
with tf.Session(graph=graph) as session:
    # Initialize Variables
    if tf.train.checkpoint_exists(checkpoint_file):
        print("Restoring from file:", checkpoint_file)
        saver.restore(session, checkpoint_file)
    else:
        print("Initializing from scratch")
        session.run(tf.global_variables_initializer())

    # RUN THE GRAPH - updating the variables
    session.run(update_vars)
    w1 = session.run(tf_w1)
    print("Value of w1 after running:\n", w1)

    # Save a snapshot of the variables
    saver.save(session, checkpoint_file)

Explaining the Checkpoint Files

The saver actually saves several files:

  • checkpoint
  • checkpoint.chk.data-00000-of-00001
  • checkpoint.chk.index
  • checkpoint.chk.meta

  • checkpoint — a small text file that keeps track of the most recent checkpoint (and a list of recent ones) saved in that directory. Helpers such as tf.train.latest_checkpoint() read this file.
  • checkpoint.chk.data-00000-of-00001 — the actual values of the saved variables, possibly sharded across several .data files.
  • checkpoint.chk.index — an index mapping each variable name to the location of its values in the .data file(s).
  • checkpoint.chk.meta — the serialized MetaGraphDef, i.e. the structure of the graph itself, which makes it possible to re-import the graph without rebuilding it in code.

Selectively Saving / Restoring Specific Variables

Sometimes it is useful to save or restore only a subset of the variables to/from a checkpoint file, e.g. in the case of transfer learning.

This can be done by passing a var_list argument to tf.train.Saver(), containing only the variables (or a dictionary mapping checkpoint names to variables) that this particular saver should handle.