Saving a model in TensorFlow is pretty easy.
Let's say you have a linear model with input `x` and want to predict an output `y`. The loss here is the mean squared error (MSE), and the batch size is 16.
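```python
import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder, tf.train.Saver)

# Define the model
x = tf.placeholder(tf.float32, [16, 10])  # input batch
y = tf.placeholder(tf.float32, [16, 1])   # target batch
w = tf.Variable(tf.zeros([10, 1]), dtype=tf.float32)  # weights
res = tf.matmul(x, w)                      # predictions, shape [16, 1]
loss = tf.reduce_mean(tf.square(res - y))  # mean squared error over the batch
train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
```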
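The following plain-Python sketch (standard library only, no TensorFlow) makes concrete what each `sess.run(train_op, ...)` computes for this linear model. It computes the predictions `x @ w`, the mean squared error over the batch, and one gradient-descent update with learning rate 0.01. The helper names `predict`, `mse` and `gd_step` are hypothetical. `y` and `w` are flattened to plain lists instead of `[16, 1]` and `[10, 1]` tensors. TensorFlow derives the gradient automatically, whereas here it is written out by hand.

```python
import random

BATCH, DIM = 16, 10  # same shapes as the placeholders x: [16, 10] and y: [16, 1]
LR = 0.01            # same learning rate as GradientDescentOptimizer(0.01)


def predict(x, w):
    """res = x @ w: one scalar prediction per row of the batch."""
    return [sum(xi * wi for xi, wi in zip(row, w)) for row in x]


def mse(x, y, w):
    """Mean over the batch of the squared residuals (res - y)^2."""
    res = predict(x, w)
    return sum((r - t) ** 2 for r, t in zip(res, y)) / len(y)


def gd_step(x, y, w, lr=LR):
    """One gradient-descent update: w <- w - lr * dMSE/dw.

    dMSE/dw_j = (2 / N) * sum_i (res_i - y_i) * x_ij
    """
    res = predict(x, w)
    n = len(y)
    grad = [
        2.0 / n * sum((r - t) * row[j] for r, t, row in zip(res, y, x))
        for j in range(len(w))
    ]
    return [wj - lr * gj for wj, gj in zip(w, grad)]


# Dummy batch, analogous to np.random.randn(16, 10) / np.random.randn(16, 1)
random.seed(0)
x = [[random.gauss(0.0, 1.0) for _ in range(DIM)] for _ in range(BATCH)]
y = [random.gauss(0.0, 1.0) for _ in range(BATCH)]
w = [0.0] * DIM  # analogous to tf.zeros([10, 1])

loss_before = mse(x, y, w)
w = gd_step(x, y, w)
loss_after = mse(x, y, w)
print(f"loss before: {loss_before:.4f}, after one step: {loss_after:.4f}")
```

Because the loss is a convex quadratic in `w` and 0.01 is a small step, a single update on a fixed batch lowers the loss.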
Next comes the `tf.train.Saver` object, which accepts several optional parameters (see the `tf.train.Saver` API documentation for the full list).
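```python
# Define the tf.train.Saver object
# max_to_keep=5: keep only the 5 most recent checkpoints on disk
# keep_checkpoint_every_n_hours=1: additionally keep one checkpoint per hour of training
saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=1)
```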
Finally, we train the model in a `tf.Session()` for 1000 iterations, saving it only every 100 iterations.
```python
max_steps = 1000

# Start a session
with tf.Session() as sess:
    # Initialize the variables
    sess.run(tf.global_variables_initializer())
    for step in range(max_steps):
        # Dummy random input and target batches
        feed_dict = {x: np.random.randn(16, 10), y: np.random.randn(16, 1)}
        _, loss_value = sess.run([train_op, loss], feed_dict=feed_dict)
        # Save the model every 100 iterations (steps 0, 100, ..., 900)
        if step % 100 == 0:
            saver.save(sess, "./model", global_step=step)
```
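After running this code, you should see the last 5 checkpoints in your directory:

- `model-500` and `model-500.meta`
- `model-600` and `model-600.meta`
- `model-700` and `model-700.meta`
- `model-800` and `model-800.meta`
- `model-900` and `model-900.meta`

The saver also writes a small `checkpoint` file that records these checkpoints. More recent TensorFlow releases use the V2 checkpoint format by default. In that format, each `model-N` file is replaced by `model-N.index` and `model-N.data-00000-of-00001`.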
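Why are exactly these five checkpoints left? The save condition fires at steps 0, 100, ..., 900, which makes ten saves. With `max_to_keep=5`, the oldest checkpoint is deleted whenever a sixth would otherwise be kept. The standard-library sketch below models only that retention rule. The helper names `saved_steps` and `kept_checkpoints` are hypothetical. The sketch also ignores `keep_checkpoint_every_n_hours`, which would only preserve extra checkpoints if training ran for more than an hour.

```python
from collections import deque


def saved_steps(max_steps, every):
    """Steps at which `if step % every == 0` triggers saver.save(...)."""
    return [step for step in range(max_steps) if step % every == 0]


def kept_checkpoints(steps, max_to_keep):
    """Checkpoint prefixes left on disk when only the newest `max_to_keep` are kept.

    A deque with maxlen discards its oldest entry once full, mirroring how the
    Saver deletes its oldest checkpoint after a new save exceeds max_to_keep.
    """
    kept = deque(maxlen=max_to_keep)
    for step in steps:
        kept.append(f"model-{step}")
    return list(kept)


steps = saved_steps(max_steps=1000, every=100)
print(steps)  # [0, 100, 200, 300, 400, 500, 600, 700, 800, 900]
print(kept_checkpoints(steps, max_to_keep=5))
# ['model-500', 'model-600', 'model-700', 'model-800', 'model-900']
```

Raising `max_to_keep` to 10 or more in this model would keep every checkpoint, starting from `model-0`.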
Note that in this example the saver stores both the current values of the variables (the checkpoint) and the structure of the graph (`*.meta`). However, no particular care was taken over how to retrieve, e.g., the placeholders `x` and `y` once the model is restored. If the model is restored anywhere other than in this training script, retrieving `x` and `y` from the restored graph can be cumbersome, especially in more complicated models. To avoid that, always give names to your variables, placeholders, and ops, or consider using collections (`tf.add_to_collection` / `tf.get_collection`), as shown in one of the remarks.