我通过读取tfrecords创建了一个数据集,映射了这些值,并希望筛选特定值的数据集,但是由于结果是一个带有张量的dict,所以我无法获取张量的实际值,也无法使用tf.cond()/tf.equal检查它。我该怎么做?

def mapping_func(serialized_example):
    feature = { 'label': tf.FixedLenFeature([1], tf.string) }
    features = tf.parse_single_example(serialized_example, features=feature)
    return features

def filter_func(features):
    # this doesn't work
    #result = features['label'] == 'some_label_value'
    # neither this
    result = tf.reshape(tf.equal(features['label'], 'some_label_value'), [])
    return result

def main():
    file_names = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
    dataset = tf.contrib.data.TFRecordDataset(file_names)
    dataset = dataset.map(mapping_func)
    dataset = dataset.shuffle(buffer_size=10000)
    dataset = dataset.filter(filter_func)
    dataset = dataset.repeat()
    iterator = dataset.make_one_shot_iterator()
    sample = iterator.get_next()

最佳答案

我在回答我自己的问题。我发现了问题!
我需要做的是像这样的标签:

label = tf.unstack(features['label'])
label = label[0]

在我把它交给
result = tf.reshape(tf.equal(label, 'some_label_value'), [])

我认为问题在于标签被定义为一个数组,其中一个元素的类型为stringtf.unstack(),所以为了得到第一个元素和单个元素,我必须将其解包(这将创建一个列表),然后获取索引为0的元素,如果我错了,请纠正我。

10-06 06:30