This lesson goes through how to save and restore the variables in a TensorFlow graph. TensorFlow has built-in functions that allow you to save the variables as "checkpoint files".
When TensorFlow creates a checkpoint file, it actually creates several accompanying files along with it. It is therefore a good idea to place the checkpoints in a dedicated subdirectory, to keep all the related files nicely organized.
So let's start by creating a subdirectory called "checkpoints", and specifying the path of the checkpoint file to be "checkpoints/checkpoint.chk".
import os

# Specify the name of the checkpoints directory
checkpoint_dir = "checkpoints"

# Create the directory if it does not already exist
if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)

# Specify the path to the checkpoint file
checkpoint_file = os.path.join(checkpoint_dir, "checkpoint.chk")
In order to actually save and restore checkpoints, we need to create a saver operation in the TensorFlow graph using tf.train.Saver().
saver = tf.train.Saver(name="saver")
Save
In order to save the variables in a TensorFlow graph to the checkpoint file we created, we run the saver.save() function in a currently running TensorFlow session.
saver.save(session, checkpoint_file)
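As a side note, saver.save() also accepts an optional global_step argument, which appends the step number to the checkpoint filenames. A small sketch (the step value 1000 is just a placeholder):

# Optionally tag the snapshot with the current training step; this
# creates files named like "checkpoint.chk-1000.*" so that several
# snapshots can coexist in the same directory
saver.save(session, checkpoint_file, global_step=1000)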
Restore
In order to restore the saved values from a checkpoint file, we use the saver.restore() function in a currently running TensorFlow session.
saver.restore(session, checkpoint_file)
Check existence of checkpoint file
If you want to check whether a checkpoint file has already been created and saved, you can run the following function.
tf.train.checkpoint_exists(checkpoint_file)
NOTE: the function above requires the parent directory of the checkpoint file to exist. If the directory structure has not been created, it will throw an error instead of simply returning False. That is why you should create the necessary directory structure first (as we did at the beginning of this lesson).
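As a small sketch of how to guard against this yourself, assuming checkpoint_dir and checkpoint_file are defined as before:

# Only query the checkpoint if its parent directory exists, since
# tf.train.checkpoint_exists() throws an error for a missing directory
if os.path.exists(checkpoint_dir) and tf.train.checkpoint_exists(checkpoint_file):
    print("Checkpoint found: ", checkpoint_file)
else:
    print("No checkpoint found")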
The following is a minimal working example of how to put all the components together in a pipeline. It does the following things:

- Creates the checkpoints directory and specifies the checkpoint file path.
- Creates a graph containing two variables and an operation that updates the first one.
- Creates a saver object.
- Inside a session, restores the variables from the checkpoint file if one exists, and initializes them from scratch otherwise.
- Runs the update operation, prints the new value of the first variable, and saves a snapshot of the variables.
import os
import tensorflow as tf

# Specify the name of the checkpoints directory
checkpoint_dir = "checkpoints"

# Create the directory if it does not already exist
if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)

# Specify the path to the checkpoint file
checkpoint_file = os.path.join(checkpoint_dir, "checkpoint.chk")

# CREATE THE GRAPH
graph = tf.Graph()
with graph.as_default():
    tf_w1 = tf.Variable(tf.constant(1, shape=[2, 3]), name="weights_1")
    tf_w2 = tf.Variable(tf.constant(1, shape=[2, 3]), name="weights_2")
    update_vars = tf_w1.assign(tf_w1 + tf_w2)  # update the value of w1

    # Create a Saver Object
    saver = tf.train.Saver(name="saver")

# RUN THE SESSION
with tf.Session(graph=graph) as session:
    # Initialize Variables
    if tf.train.checkpoint_exists(checkpoint_file):
        print("Restoring from file: ", checkpoint_file)
        saver.restore(session, checkpoint_file)
    else:
        print("Initializing from scratch")
        session.run(tf.global_variables_initializer())

    # RUN THE GRAPH - updating the variables
    session.run(update_vars)
    w1 = session.run(tf_w1)
    print("Value of w1 after running: \n", w1)

    # Save a snapshot of the variables
    saver.save(session, checkpoint_file)
The saver actually saves several files. For a checkpoint path of "checkpoints/checkpoint.chk", you will find the following files in the checkpoints directory:

- checkpoint.chk.data-00000-of-00001 - contains the actual values of the saved variables (the numeric suffix can vary if the data is sharded across several files).
- checkpoint.chk.index - an index that maps each variable name to its location in the data file(s).
- checkpoint.chk.meta - the serialized structure of the graph (a MetaGraphDef), which can be used to re-create the graph.
- checkpoint - a small bookkeeping file that keeps track of the most recent checkpoint paths in the directory.
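If you want to see these files for yourself, here is a quick sketch (assuming the example above has been run at least once):

import os

# List every file the saver created in the checkpoints directory
for filename in sorted(os.listdir("checkpoints")):
    print(filename)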
Sometimes it is useful to only save or load a subset of the weights to/from a checkpoint file, e.g., in the case of transfer learning. One way to do this is to pass a var_list argument to tf.train.Saver(), containing only the variables you want to save or restore, as in the sketch below.
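Here is a minimal sketch, assuming the same graph as in the example above, and that we only want to checkpoint weights_1 (the subset_saver name and the "subset.chk" file path are just illustrative choices):

# Create a saver that only handles a subset of the variables.
# var_list can be a list of variables, or a dict mapping names
# in the checkpoint file to variables in the graph.
subset_saver = tf.train.Saver(var_list=[tf_w1], name="subset_saver")
subset_file = os.path.join(checkpoint_dir, "subset.chk")

# Save only tf_w1 (any other variables are left out of this file)
subset_saver.save(session, subset_file)

# Later, e.g. in a new model for transfer learning, restore only that
# variable; any remaining variables must still be initialized separately
subset_saver.restore(session, subset_file)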