我正在通过trainable=False API实现的所有图层中都设置了Model,但我想验证一下是否可行。 model.count_params()返回参数的总数,但是除了查看model.summary()的最后几行之外,有什么方法可以获取可训练参数的总数?

最佳答案

from keras import backend as K

trainable_count = int(
    np.sum([K.count_params(p) for p in set(model.trainable_weights)]))
non_trainable_count = int(
    np.sum([K.count_params(p) for p in set(model.non_trainable_weights)]))

print('Total params: {:,}'.format(trainable_count + non_trainable_count))
print('Trainable params: {:,}'.format(trainable_count))
print('Non-trainable params: {:,}'.format(non_trainable_count))

可以在 layer_utils.print_summary() 所调用的 summary() 定义的末尾发现上述片段。

编辑:Keras的最新版本为此提供了一个辅助函数 count_params() :
from keras.utils.layer_utils import count_params

trainable_count = count_params(model.trainable_weights)
non_trainable_count = count_params(model.non_trainable_weights)

关于python - 如何在Keras中获取模型的可训练参数的数量?,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/45046525/

10-11 06:52