When state_below is a 2D tensor, U is a 2D weight matrix, and b is a vector of length class_size:
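import tensorflow as tf
def softmax_2d(state_below, U, b):  # hypothetical wrapper name for the original function body
    logits = tf.matmul(state_below, U) + b
    return tf.nn.softmax(logits)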
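To make the shapes concrete, here is a minimal sketch of the 2D case in NumPy, which stands in for TensorFlow so the example runs without it. The sizes (batch_size=4, hidden_size=3, class_size=5) and the random inputs are arbitrary assumptions. Each row of the result should be a probability distribution over the class_size classes.

```python
import numpy as np

def softmax(logits):
    # Subtract the row-wise max for numerical stability; softmax is taken over
    # the last axis, matching tf.nn.softmax's default axis=-1.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)

def softmax_2d(state_below, U, b):
    # state_below: (batch_size, hidden_size); U: (hidden_size, class_size); b: (class_size,)
    logits = state_below @ U + b  # b broadcasts across the batch dimension
    return softmax(logits)

rng = np.random.default_rng(0)
batch_size, hidden_size, class_size = 4, 3, 5
state_below = rng.normal(size=(batch_size, hidden_size))
U = rng.normal(size=(hidden_size, class_size))
b = rng.normal(size=(class_size,))

probs = softmax_2d(state_below, U, b)
print(probs.shape)  # (4, 5)
print(np.allclose(probs.sum(axis=-1), 1.0))  # True
```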
When state_below is a 3D tensor, with U and b as before:
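def softmax_fn(current_input):
    # current_input is one 2D slice of state_below (the first axis removed)
    logits = tf.matmul(current_input, U) + b
    return tf.nn.softmax(logits)
# tf.map_fn applies softmax_fn to each slice along axis 0 and stacks the results
raw_preds = tf.map_fn(softmax_fn, state_below)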
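tf.map_fn unstacks state_below along its first axis, applies softmax_fn to each 2D slice, and stacks the outputs. If state_below has shape (batch_size, time_steps, hidden_size) (the axis names are an assumption about what the dimensions mean), raw_preds has shape (batch_size, time_steps, class_size). The NumPy sketch below imitates that loop-and-stack behaviour with a hypothetical helper, map_fn_like (not a real TensorFlow or NumPy function), and checks it against a single 2D matmul on the flattened tensor, which should give the same numbers.

```python
import numpy as np

def softmax(logits):
    # Numerically stable softmax over the last axis (tf.nn.softmax's default axis=-1).
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)

def map_fn_like(fn, elems):
    # Hypothetical stand-in for tf.map_fn: apply fn to each slice along axis 0, then stack.
    return np.stack([fn(elem) for elem in elems], axis=0)

rng = np.random.default_rng(0)
batch_size, time_steps, hidden_size, class_size = 2, 6, 3, 5
state_below = rng.normal(size=(batch_size, time_steps, hidden_size))
U = rng.normal(size=(hidden_size, class_size))
b = rng.normal(size=(class_size,))

def softmax_fn(current_input):
    # current_input is one 2D slice of shape (time_steps, hidden_size).
    logits = current_input @ U + b
    return softmax(logits)

raw_preds = map_fn_like(softmax_fn, state_below)
print(raw_preds.shape)  # (2, 6, 5)

# Same result from one 2D matmul on the flattened tensor, reshaped back to 3D.
flat = softmax(state_below.reshape(-1, hidden_size) @ U + b)
print(np.allclose(raw_preds, flat.reshape(batch_size, time_steps, class_size)))  # True
```

The flattened version is shown only as a consistency check; both compute the same per-position softmax because U and b are shared across the first axis.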