Looking for tensorflow Answers? Try Ask4KnowledgeBase
Looking for tensorflow Keywords? Try Ask4Keywords

tensorflowGuardar y restaurar un modelo en TensorFlow


Introducción

Tensorflow distingue entre guardar / restaurar los valores actuales de todas las variables en un gráfico y guardar / restaurar la estructura del gráfico real. Para restaurar el gráfico, puede utilizar cualquiera de las funciones de Tensorflow o simplemente volver a llamar a su parte del código, que creó el gráfico en primer lugar. Al definir el gráfico, también debe pensar en cuáles y cómo deben recuperarse las variables / operaciones una vez que el gráfico se haya guardado y restaurado.

Observaciones

En la sección de modelo de restauración anterior, si entiendo correctamente, usted construye el modelo y luego restaura las variables. Creo que no es necesario reconstruir el modelo siempre que agregue los tensores / marcadores de posición relevantes al guardar utilizando tf.add_to_collection() . Por ejemplo:

tf.add_to_collection('cost_op', cost_op)

Luego, más tarde, puede restaurar el gráfico guardado y obtener acceso a cost_op usando

with tf.Session() as sess:
    new_saver = tf.train.import_meta_graph('model.meta')` 
    new_saver.restore(sess, 'model')
    cost_op = tf.get_collection('cost_op')[0]

Incluso si no ejecuta tf.add_to_collection() , puede recuperar sus tensores, pero el proceso es un poco más incómodo, y es posible que tenga que cavar un poco para encontrar los nombres correctos para las cosas. Por ejemplo:

en un script que construye un gráfico de tensorflow, definimos un conjunto de tensores lab_squeeze :

...
with tf.variable_scope("inputs"):
    y=tf.convert_to_tensor([[0,1],[1,0]])
    split_labels=tf.split(1,0,x,name='lab_split')
    split_labels=[tf.squeeze(i,name='lab_squeeze') for i in split_labels]
...
with tf.Session().as_default() as sess:
    saver=tf.train.Saver(sess,split_labels)
    saver.save("./checkpoint.chk")
    

Podemos recordarlos más adelante de la siguiente manera:

with tf.Session() as sess:
    g=tf.get_default_graph()
    new_saver = tf.train.import_meta_graph('./checkpoint.chk.meta')` 
    new_saver.restore(sess, './checkpoint.chk')
    split_labels=['inputs/lab_squeeze:0','inputs/lab_squeeze_1:0','inputs/lab_squeeze_2:0']

    split_label_0=g.get_tensor_by_name('inputs/lab_squeeze:0') 
    split_label_1=g.get_tensor_by_name("inputs/lab_squeeze_1:0")

Hay varias formas de encontrar el nombre de un tensor: puede encontrarlo en su gráfica en el tablero tensor, o puede buscarlo con algo como:

sess=tf.Session()
g=tf.get_default_graph()
...
x=g.get_collection_keys()
[i.name for j in x for i in g.get_collection(j)] # will list out most, if not all, tensors on the graph

Guardar y restaurar un modelo en TensorFlow Ejemplos relacionados