自定义tf.keras.Model需要注意的点

model.save()

  • subclass 的 Model 不能直接 save 成 .h5(HDF5)格式,但可以使用 save_weights,或者在 save 时指定 save_format="tf"(SavedModel 格式)。直接保存为 .h5 会报如下错误:
NotImplementedError: Saving the model to HDF5 format requires the model to be a Functional model or a Sequential model. It does not work for subclassed models, because such models are defined via the body of a Python method, which isn't safely serializable. Consider saving to the Tensorflow SavedModel format (by setting save_format="tf") or using `save_weights`.

model.trainable_variables

  • 若没有在 __init__ 中把 layer 注册为模型的属性(例如 self.dense = tf.keras.layers.Dense(...)),那么后面应用梯度时 model.trainable_variables 里就找不到这些 layer 的变量。
    像下面这样在 call 里临时创建 layer 是不行的:
# Anti-pattern, kept on purpose as the "what not to do" example: no layer is
# created in __init__ or stored on self.
class Map_model(tf.keras.Model):
    def __init__(self, is_train=False):
        # is_train is accepted but never used in this example.
        super(Map_model, self).__init__()
    def call(self, x):
        # A new Dense layer is constructed on every call and is never stored
        # as an attribute of the model.
        # NOTE(review): the layer is not applied to the input here; the Dense
        # layer object itself is bound to x and returned.
        x = tf.keras.layers.Dense(10, activation='relu')
        return x

model.summary()

  • 调用 summary() 之前需要先确定 input_shape:可以显式调用 model.build 指定,或者直接 fit 一遍,模型也会自动确定输入形状。
    model.build(input_shape=(None, 448, 448, 3))
    # summary() prints the table itself and returns None, so wrapping it in
    # print() would emit an extra "None" line after the table.
    model.summary()
class Map_model(tf.keras.Model):
    """Small dense network: a 10-unit ReLU layer followed by a 3-way softmax.

    Both layers are created in ``__init__`` and stored as attributes of the
    model, which is the pattern the surrounding notes recommend for
    subclassed models.

    Args:
        is_train: Forwarded as the ``trainable`` flag of both Dense layers.
    """

    def __init__(self, is_train=False):
        super().__init__()
        # Layers are instantiated in the same order as before so that any
        # automatically generated layer names stay the same.
        hidden_layer = tf.keras.layers.Dense(10, activation='relu', trainable=is_train)
        output_layer = tf.keras.layers.Dense(3, activation='softmax', trainable=is_train)
        self.map_f1 = hidden_layer
        # An optional 6-unit ReLU layer (map_f2) between the two is disabled.
        self.map_f3 = output_layer

    def call(self, x):
        """Run the input through the hidden layer, then the softmax layer."""
        hidden = self.map_f1(x)
        probabilities = self.map_f3(hidden)
        return probabilities


@tf.function
def train_step(mmodel, label, L_label, loss_object, train_loss, train_accuracy, optimizer):
    """Run one optimisation step and update the running metrics.

    Args:
        mmodel: Model called on ``label`` to produce predictions.
        label: Model input for this step.
        L_label: Target passed to the loss and to the accuracy metric.
        loss_object: Callable invoked as ``loss_object(L_label, predictions)``.
        train_loss: Metric updated with the loss of this step.
        train_accuracy: Metric updated with ``(L_label, predictions)``.
        optimizer: Optimizer whose ``apply_gradients`` performs the update.
    """
    # Only the forward pass and the loss need to be recorded on the tape.
    with tf.GradientTape() as tape:
        predictions = mmodel(label)
        step_loss = loss_object(L_label, predictions)

    grads = tape.gradient(step_loss, mmodel.trainable_variables)

    # Metrics are updated before the weights change, as in the original order.
    train_loss(step_loss)
    train_accuracy(L_label, predictions)

    grads_and_vars = zip(grads, mmodel.trainable_variables)
    optimizer.apply_gradients(grads_and_vars)


def train(eps=1e-6, max_epochs=None, model_path=r'./models/'):
    """Train a ``Map_model`` on the fixed label mapping and save it to disk.

    The inputs are the one-hot encodings (depth 29) of the labels 1..29
    shifted to 0..28; the targets are the one-hot encodings (depth 3) of
    ``prpcs.map2Lclass`` applied to the original 1-based labels. The whole
    set is fed as a single batch each epoch, and training continues until
    the epoch loss is no longer greater than ``eps``.

    Args:
        eps: Loss threshold at which training stops. Also embedded in the
            saved file names.
        max_epochs: Optional upper bound on the number of epochs. ``None``
            (the default) keeps the original behaviour of looping until the
            loss threshold is reached, which may never happen.
        model_path: Directory that receives the saved model and weights.
    """
    mmodel = Map_model(is_train=True)
    optimizer = tf.keras.optimizers.Adam(0.01)
    loss_object = tf.keras.losses.CategoricalCrossentropy()
    train_loss = tf.keras.metrics.Mean(name='train_loss')
    train_accuracy = tf.keras.metrics.CategoricalAccuracy(name='train_accuracy')

    # map2Lclass is given the original 1-based label; the model input uses
    # the 0-based index of the same label.
    raw_labels = range(1, 30)
    L_labels = [int(prpcs.map2Lclass(l)) for l in raw_labels]
    labels = [l - 1 for l in raw_labels]
    labels_onehot = tf.one_hot(labels, depth=29)
    L_labels_onehot = tf.one_hot(L_labels, depth=3)

    template = 'Epoch {}, Loss: {}, Accuracy: {}'
    epoch = 0
    loss_e = float('inf')  # guarantees that at least one epoch is attempted
    while loss_e > eps and (max_epochs is None or epoch < max_epochs):
        epoch += 1
        # Metrics accumulate across calls, so clear them at each epoch start.
        train_loss.reset_states()
        train_accuracy.reset_states()
        train_step(mmodel, labels_onehot, L_labels_onehot, loss_object, train_loss, train_accuracy, optimizer)

        print(template.format(epoch,
                              train_loss.result(),
                              train_accuracy.result() * 100))
        loss_e = train_loss.result()
    print("labels_onehot shape:", labels_onehot.shape)

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(model_path, exist_ok=True)
    mmodel.save(os.path.join(model_path, 'map_model_{}'.format(eps)))
    mmodel.save_weights(os.path.join(model_path, 'map_model_weights_{}'.format(eps)))
    print("Save model!")
02-01 14:46