我的模型在每个输入批次中使用按时间顺序排列的序列。因此,我在改组输入数据之前创建批次。这带来了批处理始终在整个数据集中包含相同数据样本的问题(从相同的索引开始 - 由 batch_size
移动),我通过缓存初始数据集并从跳过的数据集中采样解决了这个问题,但这会很快消耗内存(虽然我的数据集只有 150MB):
dataset = tf.data.Dataset.from_tensor_slices(data)
dataset = dataset.window(size=window_size, shift=window_shift, stride=window_stride, drop_remainder=True).flat_map(lambda x: x.batch(window_size))
dataset = dataset.map(process_fn, num_parallel_calls=8)
dataset = dataset.cache()
datasets = []
for i in range(0, batch_size):
d = dataset.skip(i)
d = d.batch(batch_size, drop_remainder=True)
datasets.append(d)
dataset = tf.data.experimental.sample_from_datasets(datasets)
dataset = dataset.shuffle(buffer_size=30000, reshuffle_each_iteration=False)
dataset = dataset.repeat()
有没有另一种方法来实现这种行为?我想涵盖批次中第一个序列开始的所有可能索引。
最佳答案
您正在消耗内存,因为您正在洗牌整批 - 也跳过可能不是很有效。由于您的数据似乎全部在内存中,因此您可以直接在 python 中对数据进行采样,而不必太担心性能:
def make_batch(start_idx):
batch = np.empty((batch_size, window_size), dtype=data.dtype)
for batch_idx, data_idx in enumerate(
range(start_idx, start_idx + window_shift * batch_size, window_shift)):
batch[batch_idx] = data[data_idx:data_idx + window_size * window_stride:window_stride]
return batch
dataset = (tf.data.Dataset
.range(len(data) - window_stride * (window_size - 1) - window_shift * (batch_size- 1))
.shuffle(buffer_size=30000, reshuffle_each_iteration=False)
.map(lambda x: tf.py_func(make_batch, [x], tf.float32)) # assuming your data is float32
.repeat()
.prefetch(1)) # you might want to consider prefetching for performance
改组现在发生在索引上,而不是整个批次上,因此内存占用要低得多。
关于python - tensorflow - tf.data.Dataset 在批处理之前随机跳过样本以获得不同的批次,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/53451927/