我想在 TF 2.0 中使用 optimizer.get_gradients() 代替 GradientTape()，来训练如下所示的简单模型
In TF 2.0, I want to use optimizer.get_gradients() instead of tf.GradientTape() to compute the gradients of a simple model, as follows:
import numpy as np
import tensorflow as tf
class C3BR(tf.keras.Model):
    """Conv3D -> BatchNormalization -> ReLU block.

    Args:
        filterNum: number of output filters of the 3-D convolution.
        kSize: edge length of the cubic kernel; expanded to (kSize, kSize, kSize).
        strSize: convolution strides, passed unchanged to ``Conv3D``.
        padMode: padding mode passed unchanged to ``Conv3D`` (e.g. ``'valid'``).
        dFormat: ``'channels_first'`` (default) or another ``Conv3D``
            ``data_format`` value; it also selects the axis that
            BatchNormalization normalizes over.
    """

    def __init__(self, filterNum, kSize, strSize, padMode, dFormat='channels_first'):
        super(C3BR, self).__init__()
        # BatchNormalization must normalize over the channel axis, whose
        # position depends on the data format: axis 1 for channels_first,
        # the last axis otherwise. (The else branch was missing, so the
        # channels_first value was always overwritten with -1.)
        if dFormat == 'channels_first':
            self.conAx = 1
        else:
            self.conAx = -1
        self.kSize = (kSize, kSize, kSize)
        # Use the layers through `tf.keras.layers`; a bare `layers` name is
        # not imported anywhere in this file.
        self.conv = tf.keras.layers.Conv3D(
            filters=filterNum,
            kernel_size=self.kSize,
            strides=strSize,
            padding=padMode,
            data_format=dFormat,
        )
        self.BN = tf.keras.layers.BatchNormalization(axis=self.conAx)
        self.Relu = tf.keras.layers.ReLU()

    def call(self, inputs, ifTrain=False):
        """Apply convolution, batch normalization and ReLU in sequence.

        Args:
            inputs: input tensor laid out according to ``dFormat``.
            ifTrain: forwarded to BatchNormalization as ``training``.

        Returns:
            The activated output tensor.
        """
        x = self.conv(inputs)
        x = self.BN(x, training=ifTrain)
        outputs = self.Relu(x)
        return outputs
# Question code: reproduces the failure being asked about.
model = C3BR(32, 3, 1, 'valid')
# Uncomment to build the model and print its layers before the first call
# (Keras models expose build(); there is no build_model()):
# model.build(input_shape=(2, 4, 64, 64, 64))
# model.summary()
curOpt = tf.keras.optimizers.Adam(learning_rate=1e-4)
# Dummy batch in the model's default channels_first layout:
# 2 samples, 4 channels, a 64x64x64 volume.
x = tf.ones((2, 4, 64, 64, 64), dtype=tf.float32)
# A 3x3x3 'valid' convolution with 32 filters maps 64 -> 64 - 3 + 1 = 62
# along each spatial axis, hence this target shape.
yTrue = tf.ones((2, 32, 62, 62, 62), dtype=tf.float32)
yPred = model(x,ifTrain=True)
loss = tf.reduce_mean(yPred-yTrue)
# Why does this not work?
# NOTE(review): optimizer.get_gradients() differentiates symbolically (via
# tf.gradients), which needs graph mode. The forward pass above ran eagerly
# and was not recorded anywhere that call can see, so the gradients likely
# come back as None and the call raises. Under eager execution, run the
# forward pass inside `with tf.GradientTape() as tape:` and use
# tape.gradient(loss, model.trainable_variables) instead.
gradients = curOpt.get_gradients(loss, model.trainable_variables)
curOpt.apply_gradients(zip(gradients, model.trainable_variables))
gradNorm = tf.linalg.global_norm(gradients)
However, running this code raises an error saying that the gradients for C3BR's variables are None.
Am I using optimizer.get_gradients(...) incorrectly?
在 TF2 中计算梯度的推荐方法是使用 tf.GradientTape()。
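The recommended way to compute gradients in TF2 is to use tf.GradientTape(). The likely reason the code above fails is that optimizer.get_gradients() computes gradients symbolically (through tf.gradients), which requires graph mode. Under TF2's default eager execution, nothing records the forward pass in a form that call can differentiate, so the gradients come back as None. The following version, which records the forward pass on a tape, works: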
# Working version: record the forward pass on a GradientTape so gradients
# can be computed under eager execution, then apply them with the optimizer.
model = C3BR(32, 3, 1, 'valid')
# Uncomment to build the model and print its layers before the first call
# (Keras models expose build(); there is no build_model()):
# model.build(input_shape=(2, 4, 64, 64, 64))
# model.summary()
curOpt = tf.keras.optimizers.Adam(learning_rate=1e-4)
# Dummy batch in the model's default channels_first layout:
# 2 samples, 4 channels, a 64x64x64 volume.
x = tf.ones((2, 4, 64, 64, 64), dtype=tf.float32)
# A 3x3x3 'valid' convolution with 32 filters maps 64 -> 62 per spatial axis.
yTrue = tf.ones((2, 32, 62, 62, 62), dtype=tf.float32)
# Only operations executed inside the tape context are recorded, so both the
# forward pass and the loss must be computed here.
with tf.GradientTape() as tape:
    yPred = model(x, ifTrain=True)
    loss = tf.reduce_mean(yPred - yTrue)
gradients = tape.gradient(loss, model.trainable_variables)
# Same update and diagnostics that the question attempted.
curOpt.apply_gradients(zip(gradients, model.trainable_variables))
gradNorm = tf.linalg.global_norm(gradients)