Using if condition inside the TensorFlow graph with tf.cond
Define and use functions fn1 and fn2 with parameters


Example

You can pass parameters to the functions in tf.cond() by wrapping each call in a lambda, as shown below.

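import tensorflow.compat.v1 as tf  # TF 1.x-style graph API (TF 1.14+ and TF 2.x)
tf.disable_v2_behavior()           # required on TF 2.x for placeholders and sessions

x = tf.placeholder(tf.float32)
y = tf.placeholder(tf.float32)
z = tf.placeholder(tf.float32)

def fn1(a, b):
  return tf.multiply(a, b)  # tf.mul was renamed to tf.multiply in TF 1.0

def fn2(a, b):
  return tf.add(a, b)

pred = tf.placeholder(tf.bool)
# tf.cond expects zero-argument callables, so each call is wrapped in a lambda
result = tf.cond(pred, lambda: fn1(x, y), lambda: fn2(y, z))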

Then you can run it as follows:

with tf.Session() as sess:
  print(sess.run(result, feed_dict={x: 1, y: 2, z: 3, pred: True}))
  # The result is 2.0 (fn1: 1 * 2)
  print(sess.run(result, feed_dict={x: 1, y: 2, z: 3, pred: False}))
  # The result is 5.0 (fn2: 2 + 3)
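In TensorFlow 2.x, placeholders and sessions are gone by default. A minimal sketch of the same pattern, assuming TF 2.x with eager execution, wraps the tf.cond call in a tf.function whose tensor arguments take the place of the placeholders. The wrapper name `cond_result` is hypothetical, chosen for illustration.

```python
import tensorflow as tf  # assumes TensorFlow 2.x with eager execution enabled


def fn1(a, b):
    return tf.multiply(a, b)


def fn2(a, b):
    return tf.add(a, b)


# Hypothetical wrapper: the tensor arguments of this tf.function play the role
# of the TF 1.x placeholders, and calling it replaces sess.run(...).
@tf.function
def cond_result(pred, x, y, z):
    # Each branch is still a zero-argument callable, just as in the graph version.
    return tf.cond(pred, lambda: fn1(x, y), lambda: fn2(y, z))


x, y, z = tf.constant(1.0), tf.constant(2.0), tf.constant(3.0)
print(cond_result(tf.constant(True), x, y, z).numpy())   # 2.0
print(cond_result(tf.constant(False), x, y, z).numpy())  # 5.0
```

The lambdas are only one way to bind the arguments: `functools.partial(fn1, x, y)` also produces a zero-argument callable and can be passed to tf.cond in the same position.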