How to apply a TimeDistributed layer on a CNN block?

Here is my attempt:

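from keras.layers import Input, Conv2D, MaxPooling2D, TimeDistributed
# `config` is the asker's own settings object (not shown in the post)

inputs = Input(shape=(config.N_FRAMES_IN_SEQUENCE, config.IMAGE_H, config.IMAGE_W, config.N_CHANNELS))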

def cnn_model(inputs):
    x = Conv2D(filters=32, kernel_size=(3,3), padding='same', activation='relu')(inputs)
    x = MaxPooling2D(pool_size=(2, 2))(x)

    x = Conv2D(filters=32, kernel_size=(3,3), padding='same', activation='relu')(x)
    x = MaxPooling2D(pool_size=(2, 2))(x)

    x = Conv2D(filters=64, kernel_size=(3,3), padding='same', activation='relu')(x)
    x = MaxPooling2D(pool_size=(2, 2))(x)

    x = Conv2D(filters=64, kernel_size=(3,3), padding='same', activation='relu')(x)
    x = MaxPooling2D(pool_size=(2, 2))(x)

    x = Conv2D(filters=128, kernel_size=(3,3), padding='same', activation='relu')(x)
    x = MaxPooling2D(pool_size=(2, 2))(x)

    return x

x = TimeDistributed(cnn_model)(inputs)


It fails with the following error:

AttributeError: 'function' object has no attribute 'built'
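For context on the error: TimeDistributed wraps a single Layer instance and applies it, with the same weights, to every slice along axis 1 (the time axis). A plain Python function such as cnn_model is not a Layer, so it lacks the attributes (like `built`) the wrapper expects. The sketch below assumes the standalone `keras` package and uses arbitrary illustrative sizes; it wraps one Conv2D and checks that the result matches applying that same layer frame by frame.

```python
import numpy as np
from keras.layers import Conv2D, TimeDistributed

# Illustrative sizes (not from the question): 2 clips, 3 frames, 8x8 RGB
x = np.random.rand(2, 3, 8, 8, 3).astype("float32")

conv = Conv2D(filters=4, kernel_size=(3, 3), padding="same")  # a Layer instance
td = TimeDistributed(conv)  # accepted because conv is a Layer, not a function

y = np.asarray(td(x))

# Apply the very same conv layer (same weights) to each frame on its own
per_frame = np.stack([np.asarray(conv(x[:, t])) for t in range(x.shape[1])], axis=1)

print(y.shape)                                # (2, 3, 8, 8, 4)
print(np.allclose(y, per_frame, atol=1e-4))   # True
```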

Best answer

You need to use a Lambda layer and wrap the function in it:

from keras.layers import Lambda

# cnn_model defined exactly as in your code ...

x = TimeDistributed(Lambda(cnn_model))(inputs)
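The point of the Lambda wrapper is that it turns a plain function into a Layer instance, which TimeDistributed accepts. A minimal sketch, assuming the standalone `keras` package and a hypothetical stateless function `double` (chosen so no weights are involved):

```python
import numpy as np
from keras.layers import Lambda, TimeDistributed


def double(frame):
    # Hypothetical stateless per-frame function: no weights involved
    return frame * 2.0


x = np.random.rand(2, 3, 8, 8, 3).astype("float32")

# Lambda turns the plain function into a Layer, which TimeDistributed accepts
td = TimeDistributed(Lambda(double))
y = np.asarray(td(x))

print(y.shape)                   # (2, 3, 8, 8, 3)
print(np.allclose(y, 2.0 * x))   # True
```

This only illustrates the wrapping itself. When the wrapped function creates layers, as cnn_model does, Keras generally does not track those layers' weights through the Lambda, which is one reason the Model-based alternative in this answer is usually the safer choice.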


Alternatively, you can define the block as a Model and then apply the TimeDistributed layer on that:

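from keras.layers import Input, Conv2D, MaxPooling2D, TimeDistributed
from keras.models import Model

def cnn_model():
    # The block sees a single frame (H, W, C); TimeDistributed supplies the time axis
    input_frame = Input(shape=(config.IMAGE_H, config.IMAGE_W, config.N_CHANNELS))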

    x = Conv2D(filters=32, kernel_size=(3,3), padding='same', activation='relu')(input_frame)
    x = MaxPooling2D(pool_size=(2, 2))(x)

    x = Conv2D(filters=32, kernel_size=(3,3), padding='same', activation='relu')(x)
    x = MaxPooling2D(pool_size=(2, 2))(x)

    x = Conv2D(filters=64, kernel_size=(3,3), padding='same', activation='relu')(x)
    x = MaxPooling2D(pool_size=(2, 2))(x)

    x = Conv2D(filters=64, kernel_size=(3,3), padding='same', activation='relu')(x)
    x = MaxPooling2D(pool_size=(2, 2))(x)

    x = Conv2D(filters=128, kernel_size=(3,3), padding='same', activation='relu')(x)
    x = MaxPooling2D(pool_size=(2, 2))(x)

    model = Model(input_frame, x)
    return model

inputs = Input(shape=(config.N_FRAMES_IN_SEQUENCE, config.IMAGE_H, config.IMAGE_W, config.N_CHANNELS))

# Apply the same CNN, with shared weights, to each of the N_FRAMES_IN_SEQUENCE frames
x = TimeDistributed(cnn_model())(inputs)
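A self-contained version of the Model-based approach, as a sketch: it assumes the standalone `keras` package and replaces the asker's `config` values with hypothetical ones (4 frames of 64x64 RGB). It builds the full model, runs a dummy batch, and checks that the five Conv2D layers exist only once, so every frame shares the same weights.

```python
import numpy as np
from keras.layers import Input, Conv2D, MaxPooling2D, TimeDistributed
from keras.models import Model

# Hypothetical stand-ins for the asker's config values
N_FRAMES_IN_SEQUENCE, IMAGE_H, IMAGE_W, N_CHANNELS = 4, 64, 64, 3


def cnn_model():
    """Per-frame CNN; with 64x64 input, five 2x2 poolings leave a 2x2x128 map."""
    input_frame = Input(shape=(IMAGE_H, IMAGE_W, N_CHANNELS))
    x = input_frame
    # Same five conv + pool stages as in the answer
    for filters in (32, 32, 64, 64, 128):
        x = Conv2D(filters=filters, kernel_size=(3, 3), padding="same", activation="relu")(x)
        x = MaxPooling2D(pool_size=(2, 2))(x)
    return Model(input_frame, x)


inputs = Input(shape=(N_FRAMES_IN_SEQUENCE, IMAGE_H, IMAGE_W, N_CHANNELS))
outputs = TimeDistributed(cnn_model())(inputs)
model = Model(inputs, outputs)

clips = np.random.rand(2, N_FRAMES_IN_SEQUENCE, IMAGE_H, IMAGE_W, N_CHANNELS).astype("float32")
features = model.predict(clips, verbose=0)

print(features.shape)                # (2, 4, 2, 2, 128)
print(len(model.trainable_weights))  # 10: one kernel + one bias per Conv2D, shared by all frames
```

Because the Conv2D layers are created once inside the inner Model, their weights belong to that Model and are tracked by the outer model, so they are trained normally.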

Regarding "python - How to apply a TimeDistributed layer on a CNN block?": a similar question can be found on Stack Overflow at https://stackoverflow.com/questions/52845711/

10-12 18:42