TensorFlow: Creating a custom operation with tf.py_func (CPU only): Why use tf.py_func


The tf.py_func() operator lets you run arbitrary Python code in the middle of a TensorFlow graph. It is particularly convenient for wrapping custom NumPy operators for which no equivalent TensorFlow operator exists (yet). Using tf.py_func() is an alternative to interrupting the graph with intermediate sess.run() calls.

Without tf.py_func, you have to cut the graph into two parts:

# Part 1 of the graph
inputs = ...  # in the TF graph

# Get the numpy array and apply func
val = sess.run(inputs)  # get the value of inputs
output_val = func(val)  # numpy array

# Part 2 of the graph
output = tf.placeholder(tf.float32, shape=...)
train_op = ...

# We feed the output_val to the tensor output
sess.run(train_op, feed_dict={output: output_val})
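
To make the split-graph pattern concrete, here is a minimal runnable sketch. Everything specific in it is an assumption made for illustration: func is a hypothetical NumPy-only operation (a running maximum), the shapes are arbitrary, a simple reduce_sum stands in for the second part of the graph, and TensorFlow is imported through tensorflow.compat.v1 so that the 1.x-style API is also available on TensorFlow 2.x.

```python
import numpy as np
import tensorflow.compat.v1 as tf  # assumption: TF 1.14+ or 2.x


def func(val):
    """Hypothetical NumPy-only operation: running maximum along the last axis."""
    return np.maximum.accumulate(val, axis=-1)


def run_split_graph(values):
    with tf.Graph().as_default():
        # Part 1 of the graph (illustrative: doubles the raw input)
        raw = tf.placeholder(tf.float32, shape=[None, 3])
        inputs = raw * 2.0

        # Part 2 of the graph, fed through an intermediate placeholder
        output = tf.placeholder(tf.float32, shape=[None, 3])
        result = tf.reduce_sum(output, axis=1)  # stand-in for the real downstream op

        with tf.Session() as sess:
            # First call: pull the value out of the graph as a NumPy array
            val = sess.run(inputs, feed_dict={raw: values})
            output_val = func(val)
            # Second call: feed the NumPy result back into the graph
            return sess.run(result, feed_dict={output: output_val})
```

Note that two sess.run calls and an extra placeholder are needed only to get the data in and out of Python.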

With tf.py_func this is much easier:

# Part 1 of the graph
inputs = ...

# call to tf.py_func
output = tf.py_func(func, [inputs], [tf.float32])[0]

# Part 2 of the graph
train_op = ...

# Only one call to sess.run is needed, with no intermediate placeholder
sess.run(train_op)
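
A minimal end-to-end sketch of the tf.py_func version might look like the following. It rests on several assumptions: func is a hypothetical NumPy-only operation (a running maximum), the shapes are arbitrary, a reduce_sum stands in for the rest of the graph, and TensorFlow is imported through tensorflow.compat.v1 so the 1.x-style API is available on TensorFlow 2.x as well. Because tf.py_func cannot infer the static shape of what the Python function returns, the sketch sets it by hand.

```python
import numpy as np
import tensorflow.compat.v1 as tf  # assumption: TF 1.14+ or 2.x


def func(val):
    """Hypothetical NumPy-only operation: running maximum along the last axis."""
    # Must return the dtype declared in Tout (float32 in, float32 out here).
    return np.maximum.accumulate(val, axis=-1)


def run_with_py_func(values):
    with tf.Graph().as_default():
        # Part 1 of the graph (illustrative: doubles the raw input)
        raw = tf.placeholder(tf.float32, shape=[None, 3])
        inputs = raw * 2.0

        # Wrap the NumPy function as a graph op
        output = tf.py_func(func, [inputs], [tf.float32])[0]
        output.set_shape(inputs.get_shape())  # py_func leaves the shape unknown

        # Part 2 of the graph (stand-in for the real downstream op)
        result = tf.reduce_sum(output, axis=1)

        with tf.Session() as sess:
            # A single call runs both graph parts and the Python function
            return sess.run(result, feed_dict={raw: values})
```

Keep in mind that the wrapped function runs in the Python process that built the graph, which is why this approach is CPU only.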