Save and Restore a Model in TensorFlow: Saving the model


Saving a model in TensorFlow is straightforward.

Suppose you have a linear model that takes an input x and predicts an output y. The loss is the mean squared error (MSE), and the batch size is 16.

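# Define the model
import numpy as np
import tensorflow as tf

x = tf.placeholder(tf.float32, [16, 10])  # input: batch of 16 samples, 10 features each
y = tf.placeholder(tf.float32, [16, 1])   # target output

w = tf.Variable(tf.zeros([10, 1]), dtype=tf.float32)  # weights of the linear model

res = tf.matmul(x, w)                      # predictions
loss = tf.reduce_mean(tf.square(res - y))  # mean squared error (MSE)

# One gradient descent step with learning rate 0.01
train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)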

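Next comes the Saver object, which accepts several optional parameters (cf. doc). Here, max_to_keep=5 keeps only the 5 most recent checkpoints, and keep_checkpoint_every_n_hours=1 additionally preserves one checkpoint for every hour of training.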

# Define the tf.train.Saver object
# (cf. params section for all the parameters)    
saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=1)
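
As an illustration of another Saver parameter, the sketch below uses var_list to save only a subset of the variables. It is a minimal example under stated assumptions, not part of the original script: it assumes TensorFlow 1.14+ or 2.x with the tf.compat.v1 API available, and the bias variable b, the "weights" prefix, and the temporary directory are hypothetical names introduced only for this example.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf  # assumption: the v1 compatibility API is available

tf.disable_v2_behavior()

graph = tf.Graph()
with graph.as_default():
    w = tf.Variable(tf.zeros([10, 1]), name="w")
    b = tf.Variable(tf.zeros([1]), name="b")  # hypothetical extra variable, not saved below

    # Only the variables listed in var_list end up in the checkpoint
    weights_saver = tf.train.Saver(var_list=[w], max_to_keep=2)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # "weights" is a hypothetical checkpoint prefix inside a temporary directory
        path = weights_saver.save(sess, os.path.join(tempfile.mkdtemp(), "weights"))

# Inspect which variable names the checkpoint contains
saved_names = [name for name, shape in tf.train.list_variables(path)]
print(saved_names)
```

Restricting var_list this way can be handy when only part of a model needs to be reloaded later, but a Saver created without var_list (as in this section) saves all saveable variables.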

Finally, we train the model in a tf.Session() for 1000 iterations, saving it only every 100 iterations.

# Start a session
max_steps = 1000
with tf.Session() as sess:
    # initialize the variables
    sess.run(tf.global_variables_initializer())

    for step in range(max_steps):
        feed_dict = {x: np.random.randn(16, 10), y: np.random.randn(16, 1)}  # dummy input
        _, loss_value = sess.run([train_op, loss], feed_dict=feed_dict)

        # Save the model every 100 iterations
        if step % 100 == 0:
            saver.save(sess, "./model", global_step=step)

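After running this code, you should see the last 5 checkpoints in your directory. With the V2 checkpoint format used by newer TensorFlow versions, each model-N entry below is a prefix for model-N.index and model-N.data-* files rather than a single file: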

  • model-500 and model-500.meta
  • model-600 and model-600.meta
  • model-700 and model-700.meta
  • model-800 and model-800.meta
  • model-900 and model-900.meta
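
To see why exactly these five remain, here is a plain-Python sketch of the rotation implied by max_to_keep=5. It only models the bookkeeping and is not TensorFlow's actual implementation; it also ignores keep_checkpoint_every_n_hours, on the assumption that the training run finishes in well under an hour. The names save_every and kept are introduced for this example.

```python
from collections import deque

max_steps = 1000
save_every = 100   # save interval used in the training loop
max_to_keep = 5    # same value as passed to tf.train.Saver

# A bounded queue drops its oldest entry once it holds max_to_keep items
kept = deque(maxlen=max_to_keep)

for step in range(max_steps):
    if step % save_every == 0:
        kept.append("model-%d" % step)

print(list(kept))
# ['model-500', 'model-600', 'model-700', 'model-800', 'model-900']
```

Checkpoints are written at steps 0, 100, ..., 900 (ten in total), so the five oldest (model-0 to model-400) are deleted as newer ones arrive.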

Note that in this example the saver stores both the current values of the variables (the checkpoint) and the structure of the graph (*.meta), but no particular care was taken to make the placeholders x and y easy to retrieve once the model is restored. If the model is restored anywhere other than this training script, retrieving x and y from the restored graph can be cumbersome, especially for more complicated models. To avoid that, always give names to your variables, placeholders, and ops, or consider using graph collections (tf.add_to_collection), as shown in one of the remarks.
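
The naming advice can be sketched as follows. This is a minimal example under stated assumptions rather than the original training script: it assumes TensorFlow 1.14+ or 2.x with the tf.compat.v1 API, writes to a temporary directory, and the names "x", "w", and "res" are chosen here for illustration. The restore step runs in a fresh graph to mimic a separate script.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf  # assumption: the v1 compatibility API is available

tf.disable_v2_behavior()

ckpt_prefix = os.path.join(tempfile.mkdtemp(), "model")

# Build and save: the same linear model, but with explicit names
train_graph = tf.Graph()
with train_graph.as_default():
    x = tf.placeholder(tf.float32, [16, 10], name="x")
    w = tf.Variable(tf.zeros([10, 1]), name="w")
    res = tf.matmul(x, w, name="res")

    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        save_path = saver.save(sess, ckpt_prefix, global_step=0)

# Restore in a fresh graph, as a separate script would
restored_graph = tf.Graph()
with restored_graph.as_default():
    # Rebuild the graph structure from the .meta file
    restorer = tf.train.import_meta_graph(save_path + ".meta")
    with tf.Session() as sess:
        # Load the saved variable values into the rebuilt graph
        restorer.restore(sess, save_path)

        # Tensors are looked up by "<op name>:<output index>"
        x_restored = restored_graph.get_tensor_by_name("x:0")
        res_restored = restored_graph.get_tensor_by_name("res:0")

        dummy_input = np.ones((16, 10), dtype=np.float32)
        prediction = sess.run(res_restored, feed_dict={x_restored: dummy_input})

print(prediction.shape)
```

Without the name arguments, the placeholder and the matmul op would receive auto-generated names such as "Placeholder" and "MatMul", which are harder to track down once the graph grows.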