我正在尝试了解TensorFlow广泛和深度学习教程。普查收入数据集有两个文件可用于验证:adult.data和adult.test。
经过一定数量的时间后,它将打印评估结果(您可以在此处查看完整的代码:https://github.com/tensorflow/models/blob/master/official/wide_deep/wide_deep.py)。它使用“ input_fn”从csv文件读取输入信息。它用于读取文件public.data和adult.test。

def input_fn(data_file, num_epochs, shuffle, batch_size):
  """Generate an input function for the Estimator."""
  assert tf.gfile.Exists(data_file), (
      '%s not found. Please make sure you have either run data_download.py or '
      'set both arguments --train_data and --test_data.' % data_file)

  def parse_csv(value):
    print('Parsing', data_file)
    columns = tf.decode_csv(value, record_defaults=_CSV_COLUMN_DEFAULTS)
    features = dict(zip(_CSV_COLUMNS, columns))
    labels = features.pop('income_bracket')
    return features, tf.equal(labels, '>50K')

  # Extract lines from input files using the Dataset API.
  dataset = tf.data.TextLineDataset(data_file)

  if shuffle:
    dataset = dataset.shuffle(buffer_size=_NUM_EXAMPLES['train'])

  dataset = dataset.map(parse_csv, num_parallel_calls=5)

  # We call repeat after shuffling, rather than before, to prevent separate
  # epochs from blending together.
  dataset = dataset.repeat(num_epochs)
  dataset = dataset.batch(batch_size)
  return dataset


它构建一个估计器DNNLinearCombinedClassifier,并评估和打印精度,如下所示:

...
results = model.evaluate(input_fn=lambda: input_fn(
    FLAGS.test_data, 1, False, FLAGS.batch_size))

# Display evaluation metrics
print('Results at epoch', (n + 1) * FLAGS.epochs_per_eval)
print('-' * 60)

for key in sorted(results):
  print('%s: %s' % (key, results[key]))


我了解您应该分批训练您的网络。我的问题是,他们为什么要批量评估模型?他们不应该使用整个评估数据集吗?数据集具有16281个验证值,不应该这样调用model.evaluate吗?:

_NUM_EXAMPLES = {
  'train': 32561,
  'validation': 16281,
}
...
results = model.evaluate(input_fn=lambda: input_fn(
    FLAGS.test_data, 1, False, _NUM_EXAMPLES['validation']))


使用整个验证数据集是否错误?

最佳答案

训练和测试都需要小批量数据,因为否则两者都可能导致内存不足错误(OOM)。没错,该问题在训练中更为关键,因为后向传递实际上会使内存消耗翻倍。但这并不意味着OOM不可能进行推断。

以我的经验为例:


Python kernel died when using tensorflow
OOM when allocating tensor


...而且我敢肯定还有很多我还没有看过的例子。根据您的资源,16281可能足够小,可以放入一个批处理中,但是总的来说,以推理方式对批处理进行迭代并为此批处理大小进行单独设置是非常有意义的,例如,如果模型可以在另一台资源较少的机器。

08-25 00:31
查看更多