TensorFlow: Reading data for n epochs with batching


Example

Assume your data examples have already been read into a Python variable and you would like to read them n times, in batches of a given size:

import numpy as np
import tensorflow as tf
data = np.array([1, 2, 3, 4, 5])
n = 4

To group the data into batches, optionally with random shuffling, you can use tf.train.batch or tf.train.shuffle_batch, but you need to pass it a tensor that produces the whole dataset n times:

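limited_tensor = tf.train.limit_epochs(data, n)
# min_after_dequeue is a required argument of shuffle_batch; keep it below capacity
batch = tf.train.shuffle_batch([limited_tensor], batch_size=3, enqueue_many=True,
                               capacity=4, min_after_dequeue=2)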

tf.train.limit_epochs converts the NumPy array to a tensor under the hood and returns a tensor that yields the whole array n times and raises an OutOfRangeError on any later evaluation. It counts epochs in a local variable, which is why local variables have to be initialized before the pipeline runs. The enqueue_many=True argument passed to shuffle_batch indicates that each tensor in the list [limited_tensor] holds several examples along its first dimension, rather than a single example. Note that the capacity of the batching queue can be smaller than the number of examples in the tensor.
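To get a feel for what capacity and shuffle_batch's min_after_dequeue argument do, here is a rough pure-Python model of the pipeline. It is an assumption-based sketch, not TensorFlow's implementation: the function name simulate_shuffle_batch is hypothetical, a plain list stands in for the shuffling queue, and an iterator stands in for limit_epochs. The model takes one random element at a time while more than min_after_dequeue elements remain (or the queue is closed), and drops a final batch that cannot be filled, which is my understanding of shuffle_batch's behaviour with its default allow_smaller_final_batch=False.

```python
import random


def simulate_shuffle_batch(data, num_epochs, batch_size, capacity,
                           min_after_dequeue, seed=0):
    """Hypothetical model of limit_epochs + shuffle_batch(enqueue_many=True)."""
    assert capacity > min_after_dequeue, "the model would stall otherwise"
    rng = random.Random(seed)
    # Stand-in for limit_epochs: every element of data, num_epochs times over.
    source = iter([x for _ in range(num_epochs) for x in data])
    queue, batches = [], []
    closed = False

    def refill():
        nonlocal closed
        # Stand-in for the queue runner: enqueue until the queue is full.
        while not closed and len(queue) < capacity:
            try:
                queue.append(next(source))
            except StopIteration:
                closed = True  # source exhausted, so the queue is closed

    while True:
        batch = []
        while len(batch) < batch_size:
            refill()
            # While open, keep more than min_after_dequeue elements queued;
            # once closed, the remaining elements may be drained.
            if queue and (closed or len(queue) > min_after_dequeue):
                batch.append(queue.pop(rng.randrange(len(queue))))
            elif closed:
                return batches  # a partial final batch is discarded
        batches.append(batch)


print(simulate_shuffle_batch([1, 2, 3, 4, 5], num_epochs=4, batch_size=3,
                             capacity=4, min_after_dequeue=2))
```

Because the queue never holds more than capacity examples, the shuffling is only local: in this model the first element of the first batch can only be one of the first four examples read. A larger capacity gives better mixing at the cost of memory.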

The batches can then be processed as usual:

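with tf.Session() as sess:
  # limit_epochs keeps its epoch counter in a local variable
  sess.run(tf.local_variables_initializer())
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(sess=sess, coord=coord)
  try:
    while True:
      data_batch = sess.run(batch)
      # process data
  except tf.errors.OutOfRangeError:
    pass  # all n epochs have been consumed
  finally:
    coord.request_stop()
    coord.join(threads)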
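How many times does the loop body run? With the values used here the stream holds 5 × 4 = 20 examples, so batch_size=3 should give 6 full batches. As far as I know, the 2 leftover examples are never returned when the queue closes, unless allow_smaller_final_batch=True is passed to shuffle_batch, in which case a final batch of 2 is produced. The helper below (the name count_batches is hypothetical) only does that arithmetic:

```python
def count_batches(num_examples, num_epochs, batch_size,
                  allow_smaller_final_batch=False):
    """Expected number of batches before OutOfRangeError (hypothetical helper)."""
    total = num_examples * num_epochs
    full, remainder = divmod(total, batch_size)
    # A short final batch is only emitted when explicitly allowed.
    if allow_smaller_final_batch and remainder:
        return full + 1
    return full


print(count_batches(5, 4, 3))
print(count_batches(5, 4, 3, allow_smaller_final_batch=True))
```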