Problem Description
I decided to switch from keras to tf.keras (as recommended here). Therefore I installed tf.__version__=2.0.0 and tf.keras.__version__=2.2.4-tf. In an older version of my code (using some older Tensorflow version, tf.__version__=1.x.x) I used a callback to compute custom metrics on the entire validation data at the end of each epoch. The idea to do so was taken from here. However, it seems that the "validation_data" attribute has been deprecated, so the following code no longer works.
import numpy as np
from tensorflow.keras.callbacks import Callback

class ValMetrics(Callback):

    def on_train_begin(self, logs={}):
        self.val_all_mse = []

    def on_epoch_end(self, epoch, logs):
        val_predict = np.asarray(self.model.predict(self.validation_data[0]))
        val_targ = self.validation_data[1]
        val_epoch_mse = mse_score(val_targ, val_predict)
        self.val_all_mse.append(val_epoch_mse)
        # Add custom metrics to the logs, so that we can use them with
        # EarlyStopping and CSVLogger callbacks
        logs["val_epoch_mse"] = val_epoch_mse
        print(f"\nEpoch: {epoch + 1}")
        print("-----------------")
        print("val_mse: {:+.6f}".format(val_epoch_mse))
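The `mse_score` helper is not shown in the question; a minimal NumPy sketch (the name and signature here are assumptions, not part of any library) might look like:

```python
import numpy as np

def mse_score(y_true, y_pred):
    # Hypothetical helper: mean squared error over the whole validation set.
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    return float(np.mean((y_true - y_pred) ** 2))
```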
My current workaround is the following: I simply pass validation_data as an argument to the ValMetrics class:
class ValMetrics(Callback):

    def __init__(self, validation_data):
        super().__init__()
        self.X_val, self.y_val = validation_data
Still, I have some questions: Is the "validation_data" attribute really deprecated, or can it be found elsewhere? Is there a better way to access the validation data at the end of each epoch than the workaround above?
Many thanks!
Recommended Answer
You are right that, as per the Tensorflow Callback documentation, the validation_data attribute is deprecated.
The issue you are facing has been raised on Github; related issues are Issue1, Issue2 and Issue3.
None of the above Github issues has been resolved yet. Your workaround of passing validation_data as an argument to a custom callback is a good one: as per this Github comment, many people have found it useful.
The code of the workaround is given below for the benefit of the Stackoverflow community, even though it is present in Github.
import numpy as np
from sklearn.metrics import f1_score, precision_score, recall_score
from tensorflow.keras.callbacks import Callback

class Metrics(Callback):

    def __init__(self, val_data, batch_size=20):
        super().__init__()
        self.validation_data = val_data
        self.batch_size = batch_size

    def on_train_begin(self, logs={}):
        print(self.validation_data)
        self.val_f1s = []
        self.val_recalls = []
        self.val_precisions = []

    def on_epoch_end(self, epoch, logs={}):
        batches = len(self.validation_data)
        total = batches * self.batch_size
        val_pred = np.zeros((total, 1))
        val_true = np.zeros(total)
        for batch in range(batches):
            xVal, yVal = next(self.validation_data)
            val_pred[batch * self.batch_size : (batch + 1) * self.batch_size] = np.asarray(self.model.predict(xVal)).round()
            val_true[batch * self.batch_size : (batch + 1) * self.batch_size] = yVal
        val_pred = np.squeeze(val_pred)
        _val_f1 = f1_score(val_true, val_pred)
        _val_precision = precision_score(val_true, val_pred)
        _val_recall = recall_score(val_true, val_pred)
        self.val_f1s.append(_val_f1)
        self.val_recalls.append(_val_recall)
        self.val_precisions.append(_val_precision)
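The loop above sizes its buffers as `batches * batch_size` and fills one batch-sized slice per iteration. That slicing can be checked in isolation with plain NumPy; the validation "generator" and the model predictions below are hypothetical stand-ins, not part of the callback:

```python
import numpy as np

batch_size, batches = 4, 3
total = batches * batch_size

# Stand-in validation "generator": three (x, y) batches of fixed size.
val_data = iter([(np.zeros((batch_size, 2)),
                  np.arange(b * batch_size, (b + 1) * batch_size) % 2)
                 for b in range(batches)])

val_pred = np.zeros((total, 1))
val_true = np.zeros(total)
for batch in range(batches):
    xVal, yVal = next(val_data)
    # A real callback would use self.model.predict(xVal).round() here.
    val_pred[batch * batch_size : (batch + 1) * batch_size] = np.ones((batch_size, 1))
    val_true[batch * batch_size : (batch + 1) * batch_size] = yVal

val_pred = np.squeeze(val_pred)
print(val_pred.shape, val_true.tolist())
```

Note that `len(self.validation_data)` in the callback assumes the validation object supports `len()` (e.g. a `keras.utils.Sequence`); a plain Python generator does not, and would need the number of batches passed in separately.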
I will keep following the Github Issues mentioned above and will update the Answer accordingly.
Hope this helps. Happy Learning!
This concludes the article on accessing the deprecated attribute "validation_data" in tf.keras.callbacks.Callback. We hope the recommended answer was helpful.