我的问题很简单,但到目前为止(网上)我找不到确切的答案。

经过定义次数的训练后,我使用adam优化器训练的keras模型的权重已保存下来:

callback = tf.keras.callbacks.ModelCheckpoint(filepath=path, save_weights_only=True)
model.fit(X,y,callbacks=[callback])


当我关闭jupyter后恢复训练时,我可以简单地使用:

model.load_weights(path)


继续训练。

由于Adam取决于时期数(例如学习率下降的情况),因此我想知道在与以前相同的条件下恢复训练的最简单方法。

按照ibarrond的回答,我编写了一个小的自定义回调。

optim = tf.keras.optimizers.Adam()
model.compile(optimizer=optim, loss='categorical_crossentropy',metrics=['accuracy'])

weight_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path, save_weights_only=True, verbose=1, save_best_only=False)

class optim_callback(tf.keras.callbacks.Callback):
    '''Custom callback to save optimiser state'''

          def on_epoch_end(self,epoch,logs=None):
                optim_state = tf.keras.optimizers.Adam.get_config(optim)
                with open(optim_state_pkl,'wb') as f_out:
                       pickle.dump(optim_state,f_out)

model.fit(X,y,callbacks=[weight_callback,optim_callback()])


当我恢复训练时:

model.load_weights(checkpoint_path)
with open(optim_state_pkl,'rb') as f_out:
                    optim_state = pickle.load(f_out)
tf.keras.optimizers.Adam.from_config(optim_state)


我只想检查一下是否正确。再次非常感谢!

附录:在进一步阅读Adam的默认Keras implementationoriginal Adam paper时,我相信默认的Adam不依赖于历元数,而仅依赖于迭代数。因此,这是不必要的。但是,该代码对于希望跟踪其他优化程序的任何人仍然可能有用。

最佳答案

为了完美地捕获优化器的状态,您应该使用功能get_config()存储其配置。此函数返回一个字典(包含选项),该字典可以使用pickle序列化并存储在文件中。

要重新启动该过程,只需d = pickle.load('my_saved_tfconf.txt')使用配置检索字典,然后使用Keras Adam Optimizer.的函数from_config(d)生成Adam Optimizer。

关于python - 在Keras中使用Adam优化器恢复培训,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/60076741/

10-12 19:26