Looking for tensorflow Answers? Try Ask4KnowledgeBase
Looking for tensorflow Keywords? Try Ask4Keywords

tensorflowSpeichern und Wiederherstellen eines Modells in TensorFlow


Einführung

Tensorflow unterscheidet zwischen dem Speichern / Wiederherstellen der aktuellen Werte aller Variablen in einem Diagramm und dem Speichern / Wiederherstellen der tatsächlichen Diagrammstruktur. Um die Grafik wiederherzustellen, können Sie entweder die Tensorflow-Funktionen verwenden oder einfach Ihren Code erneut aufrufen, der die Grafik ursprünglich erstellt hat. Bei der Definition des Graphen sollten Sie auch darüber nachdenken, welche Variablen und Operationen abrufbar sind, nachdem der Graph gespeichert und wiederhergestellt wurde.

Bemerkungen

Wenn ich im Abschnitt zum Wiederherstellen des Modells richtig verstehe, erstellen Sie das Modell und stellen die Variablen wieder her. Ich bin der Meinung, dass ein Neuaufbau des Modells nicht erforderlich ist, solange Sie beim Speichern mit tf.add_to_collection() die entsprechenden Tensoren / Platzhalter tf.add_to_collection() . Zum Beispiel:

tf.add_to_collection('cost_op', cost_op)

Später können Sie dann das gespeicherte Diagramm wiederherstellen und mit cost_op zugreifen

with tf.Session() as sess:
    new_saver = tf.train.import_meta_graph('model.meta')` 
    new_saver.restore(sess, 'model')
    cost_op = tf.get_collection('cost_op')[0]

Selbst wenn Sie tf.add_to_collection() nicht ausführen, können Sie Ihre Tensoren abrufen, der Vorgang ist jedoch etwas umständlicher und Sie müssen möglicherweise etwas suchen, um die richtigen Namen für die Dinge zu finden. Zum Beispiel:

In einem Skript, das ein Tensorflow-Diagramm erstellt, definieren wir eine Reihe von Tensoren lab_squeeze :

...
with tf.variable_scope("inputs"):
    y=tf.convert_to_tensor([[0,1],[1,0]])
    split_labels=tf.split(1,0,x,name='lab_split')
    split_labels=[tf.squeeze(i,name='lab_squeeze') for i in split_labels]
...
with tf.Session().as_default() as sess:
    saver=tf.train.Saver(sess,split_labels)
    saver.save("./checkpoint.chk")
    

Wir können sie später wie folgt abrufen:

with tf.Session() as sess:
    g=tf.get_default_graph()
    new_saver = tf.train.import_meta_graph('./checkpoint.chk.meta')` 
    new_saver.restore(sess, './checkpoint.chk')
    split_labels=['inputs/lab_squeeze:0','inputs/lab_squeeze_1:0','inputs/lab_squeeze_2:0']

    split_label_0=g.get_tensor_by_name('inputs/lab_squeeze:0') 
    split_label_1=g.get_tensor_by_name("inputs/lab_squeeze_1:0")

Es gibt verschiedene Möglichkeiten, den Namen eines Tensors zu finden - Sie können ihn in Ihrem Graphen auf der Tensorplatine finden, oder Sie können mit etwas wie dem folgenden suchen:

sess=tf.Session()
g=tf.get_default_graph()
...
x=g.get_collection_keys()
[i.name for j in x for i in g.get_collection(j)] # will list out most, if not all, tensors on the graph

Speichern und Wiederherstellen eines Modells in TensorFlow Verwandte Beispiele