Save and Restore a Model in TensorFlow


Introduction

TensorFlow distinguishes between saving/restoring the current values of all the variables in a graph and saving/restoring the graph structure itself. To restore the graph structure, you can either use TensorFlow's functions or simply re-run the code that built the graph in the first place. When defining the graph, you should also think about which variables/ops should be retrievable once the graph has been saved and restored, and how.
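As a rough sketch of that distinction (TF 1.x API; the checkpoint prefix ./my_model is just a placeholder): the Saver writes variable values to the checkpoint, while the .meta file written alongside it (or your own graph-building code) provides the graph structure.

import tensorflow as tf

w = tf.get_variable('w', shape=[2, 2])       # a variable whose value we want to keep
saver = tf.train.Saver()                     # saves/restores variable values only

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, './my_model')           # also writes ./my_model.meta (the graph structure)

# Later, in a fresh program: either re-run the graph-building code above and call
# saver.restore(...), or rebuild the structure from the .meta file:
tf.reset_default_graph()
with tf.Session() as sess:
    new_saver = tf.train.import_meta_graph('./my_model.meta')
    new_saver.restore(sess, './my_model')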

Remarks

In the "restoring model" section above, if I understand correctly, you build the model and then restore the variables. I believe rebuilding the model is not necessary, so long as you add the relevant tensors/placeholders to a collection when saving, using tf.add_to_collection(). For example:

tf.add_to_collection('cost_op', cost_op)
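For context, here is a minimal sketch of the saving side; the cost_op here is just a stand-in for whatever cost tensor your graph actually defines, and the checkpoint name 'model' matches the restore snippet below:

import tensorflow as tf

w = tf.get_variable('w', shape=[3])                    # some variable in the graph
cost_op = tf.reduce_mean(tf.square(w), name='cost')    # stand-in for the real cost tensor
tf.add_to_collection('cost_op', cost_op)               # collections are stored in the .meta file

saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # ... training would happen here ...
    saver.save(sess, 'model')                          # writes the checkpoint files and 'model.meta'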

Then, later, you can restore the saved graph and get access to cost_op using:

with tf.Session() as sess:
    new_saver = tf.train.import_meta_graph('model.meta')
    new_saver.restore(sess, 'model')
    cost_op = tf.get_collection('cost_op')[0]

Even if you don't run tf.add_to_collection(), you can retrieve your tensors, but the process is a bit more cumbersome, and you may have to do some digging to find the right names for things. For example:

In a script that builds a TensorFlow graph, we define a set of tensors named lab_squeeze:

...
with tf.variable_scope("inputs"):
    y = tf.convert_to_tensor([[0, 1], [1, 0], [1, 0]])
    # split y along axis 0 and drop the size-1 dimension from each piece;
    # the squeeze ops are named lab_squeeze, lab_squeeze_1, lab_squeeze_2
    split_labels = tf.split(y, 3, axis=0, name='lab_split')
    split_labels = [tf.squeeze(i, name='lab_squeeze') for i in split_labels]

w = tf.get_variable('w', shape=[2])  # at least one variable so the Saver has something to write
...
with tf.Session() as sess:
    # the Saver writes variable values; the graph structure goes into the .meta file
    saver = tf.train.Saver()
    sess.run(tf.global_variables_initializer())
    saver.save(sess, './checkpoint.chk')

We can recall them later on as follows:

with tf.Session() as sess:
    g = tf.get_default_graph()
    new_saver = tf.train.import_meta_graph('./checkpoint.chk.meta')
    new_saver.restore(sess, './checkpoint.chk')
    # the tensor names created above are:
    # 'inputs/lab_squeeze:0', 'inputs/lab_squeeze_1:0', 'inputs/lab_squeeze_2:0'

    split_label_0 = g.get_tensor_by_name('inputs/lab_squeeze:0')
    split_label_1 = g.get_tensor_by_name('inputs/lab_squeeze_1:0')
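Once retrieved, these are ordinary tensors and can be evaluated inside the same with tf.Session() block; with the y used above, a quick check could look like this (the printed values are an assumption based on that y):

    print(sess.run(split_label_0))   # [0 1]
    print(sess.run(split_label_1))   # [1 0]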

There are a number of ways to find the name of a tensor -- you can look at your graph in TensorBoard, or you can search for it with something like:

sess = tf.Session()
g = tf.get_default_graph()
...
keys = g.get_all_collection_keys()
[i.name for j in keys for i in g.get_collection(j)]  # lists most, if not all, tensors on the graph
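Another option (a small additional sketch, not from the original example) is to list every operation registered on the graph, which also catches tensors that were never added to a collection:

g = tf.get_default_graph()
op_names = [op.name for op in g.get_operations()]  # every operation on the graph
# a tensor name is its producing op's name plus an output index, e.g. 'inputs/lab_squeeze:0'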

