问题描述
遵循 Tensorflow 的最佳性能实践,我使用的是 NCHW 数据格式,但我不确定要在 tensorflow.nn.conv2d.
Following Tensorflow's best practices for performance, I am using NCHW data format, but I am not sure about the filter shape to be used in tensorflow.nn.conv2d.
文档说对 NHWC 格式使用 [filter_height, filter_width, in_channels, out_channels]
,但不清楚如何处理 NCHW.
The doc says to use [filter_height, filter_width, in_channels, out_channels]
for NHWC format, but is not clear about what to do with NCHW.
是否应该使用相同的形状?
Should the same shape be used ?
推荐答案
使用相同的过滤器形状应该可行.函数参数的唯一变化是步幅.例如,假设您希望您的架构同时使用这两种格式,这也是推荐的:
Using the same filter shape should work. The only change to the function arguments is the stride. As an example let's say you wanted your architecture to work with both formats, which is also recommended:
# input -> Tensor in NCHW format
if use_nchw:
result = tf.nn.conv2d(
input=input,
filter=filter,
strides=[1, 1, stride, stride],
data_format='NCHW')
else:
input_t = tf.transpose(input, [0, 2, 3, 1]) # NCHW to NHWC
result = tf.nn.conv2d(
input=input_t,
filter=filter,
strides=[1, stride, stride, 1])
result = tf.transpose(result, [0, 3, 1, 2]) # NHWC to NCHW
这篇关于tensorflow.nn.conv2d 中 NCHW 格式的过滤器形状的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持!