Restoring is also quite nice and easy.
Here's a handy helper function:
def restore_vars(saver, sess, chkpt_dir):
    """Initialize all variables, then restore the latest checkpoint if any.

    If ``chkpt_dir`` does not exist yet it is created so checkpoints can be
    stored there later, and nothing is restored.

    Args:
        saver: Object whose ``restore(sess, path)`` method is called to load
            the checkpoint (a ``tf.train.Saver`` in the calling code).
        sess: Session used to run the variable initializer and handed to
            ``saver.restore``.
        chkpt_dir: Directory searched for a checkpoint state.

    Returns:
        True if a checkpoint was restored into ``sess``; False if the
        directory had to be created or no usable checkpoint was found.

    Raises:
        OSError: If ``chkpt_dir`` is missing and cannot be created.
    """
    # Initialize first so every variable has a value even when nothing is
    # restored below. (Replaces the deprecated tf.initialize_all_variables.)
    sess.run(tf.global_variables_initializer())

    # Try to create the directory instead of checking first: this avoids the
    # window between an exists() check and makedirs().
    try:
        os.makedirs(chkpt_dir)
    except OSError:
        # Only an already-existing path is acceptable; anything else
        # (permissions, invalid path, ...) is a real error.
        if not os.path.exists(chkpt_dir):
            raise
    else:
        # Freshly created directory: there cannot be anything to restore.
        print("making checkpoint_dir")
        return False

    ckpt_state = tf.train.get_checkpoint_state(chkpt_dir)
    print("path = ", ckpt_state)
    # No checkpoint state, or one without a recorded checkpoint path.
    if ckpt_state is None or not ckpt_state.model_checkpoint_path:
        return False

    saver.restore(sess, ckpt_state.model_checkpoint_path)
    return True
Main code:
# Directory that holds the checkpoints; used for both restoring and saving.
path_to_saved_model = './'
max_steps = 1

# Start a session
with tf.Session() as sess:
    # ... define the model here ...
    # (the model definition must create the `loss` op fetched below)

    print("define the param saver")
    saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=1)

    # restore session if there is a saved checkpoint
    print("restoring model")
    restored = restore_vars(saver, sess, path_to_saved_model)
    print("model restored ", restored)

    # Save into the same directory restore_vars() reads from, so a later run
    # actually finds these checkpoints even if path_to_saved_model changes.
    checkpoint_prefix = os.path.join(path_to_saved_model, "model")

    # Now continue training if you so choose
    for step in range(max_steps):
        # do an update on the model (not needed)
        # NOTE: fetching a one-element list, so the result is a list too.
        loss_value = sess.run([loss])
        # Now save the model; global_step is appended to the file name.
        saver.save(sess, checkpoint_prefix, global_step=step)