我的模型在每个输入批次中使用按时间顺序排列的序列。因此,我在改组输入数据之前创建批次。这带来了批处理始终在整个数据集中包含相同数据样本的问题(从相同的索引开始 - 由 batch_size 移动),我通过缓存初始数据集并从跳过的数据集中采样解决了这个问题,但这会很快消耗内存(虽然我的数据集只有 150MB):

dataset = tf.data.Dataset.from_tensor_slices(data)
dataset = dataset.window(size=window_size, shift=window_shift, stride=window_stride, drop_remainder=True).flat_map(lambda x: x.batch(window_size))
dataset = dataset.map(process_fn, num_parallel_calls=8)
dataset = dataset.cache()
datasets = []
for i in range(0, batch_size):
    d = dataset.skip(i)
    d = d.batch(batch_size, drop_remainder=True)
    datasets.append(d)
dataset = tf.data.experimental.sample_from_datasets(datasets)
dataset = dataset.shuffle(buffer_size=30000, reshuffle_each_iteration=False)
dataset = dataset.repeat()

有没有另一种方法来实现这种行为?我想涵盖批次中第一个序列开始的所有可能索引。

最佳答案

您正在消耗内存,因为您正在洗牌整批 - 也跳过可能不是很有效。由于您的数据似乎全部在内存中,因此您可以直接在 python 中对数据进行采样,而不必太担心性能:

def make_batch(start_idx):
  batch = np.empty((batch_size, window_size), dtype=data.dtype)
  for batch_idx, data_idx in enumerate(
      range(start_idx, start_idx + window_shift * batch_size, window_shift)):
    batch[batch_idx] = data[data_idx:data_idx + window_size * window_stride:window_stride]
  return batch

dataset = (tf.data.Dataset
  .range(len(data) - window_stride * (window_size - 1) - window_shift * (batch_size- 1))
  .shuffle(buffer_size=30000, reshuffle_each_iteration=False)
  .map(lambda x: tf.py_func(make_batch, [x], tf.float32)) # assuming your data is float32
  .repeat()
  .prefetch(1)) # you might want to consider prefetching for performance

改组现在发生在索引上,而不是整个批次上,因此内存占用要低得多。

关于python - tensorflow - tf.data.Dataset 在批处理之前随机跳过样本以获得不同的批次,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/53451927/

10-12 21:49