我正在通过trainable=False
API实现的所有图层中都设置了Model
,但我想验证一下是否可行。 model.count_params()
返回参数的总数,但是除了查看model.summary()
的最后几行之外,有什么方法可以获取可训练参数的总数?
最佳答案
from keras import backend as K
trainable_count = int(
np.sum([K.count_params(p) for p in set(model.trainable_weights)]))
non_trainable_count = int(
np.sum([K.count_params(p) for p in set(model.non_trainable_weights)]))
print('Total params: {:,}'.format(trainable_count + non_trainable_count))
print('Trainable params: {:,}'.format(trainable_count))
print('Non-trainable params: {:,}'.format(non_trainable_count))
可以在
layer_utils.print_summary()
所调用的 summary()
定义的末尾发现上述片段。编辑:Keras的最新版本为此提供了一个辅助函数
count_params()
:from keras.utils.layer_utils import count_params
trainable_count = count_params(model.trainable_weights)
non_trainable_count = count_params(model.non_trainable_weights)
关于python - 如何在Keras中获取模型的可训练参数的数量?,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/45046525/