我正在构建一个 CNN,并定义了一个全连接层,使用 SELU 作为激活函数,并带有 AlphaDropout(0.5)。我正在尝试用 tf.random.normal 生成的分布来初始化 SELU,如下所示:
# NOTE(review): [5, 5, 1, 32] looks like a Conv2D kernel shape (5x5 window,
# 1 input channel, 32 filters) -- presumably; confirm. SELU itself only uses
# fixed constants (alpha, scale), so there is nothing in the activation for
# this tensor to initialize; a random-normal init belongs on layer weights.
# stddev = sqrt(1/25) = sqrt(1/fan_in) with fan_in = 5*5*1, i.e. a
# LeCun-normal-style scaling.
dist = tf.Variable(tf.random.normal([5, 5, 1, 32], stddev=np.sqrt(1/25)))
这是我的全连接层的代码:
# NOTE(review): this is the snippet that raises the TypeError shown below.
# `selu(x=seluDistribution)` *calls* selu on the tensor and returns another
# tensor; Activation(...) hands its argument to activations.get(), which needs
# the activation function itself (e.g. `selu`) or its name, hence
# "Could not interpret activation function identifier: <tf.Tensor ...>".
# The traceback below shows `gelu(x=seluDistribution)` on this line rather
# than `selu(...)`; the failure is the same either way.
# NOTE(review): `model` is not created inside the function; it presumably
# refers to a module-level Sequential model defined elsewhere (not shown).
def FullyConnectedLayer(denseUnits, seluDistribution, batchMomentum, alphaDropRate):
model.add(Dense(denseUnits, activity_regularizer='l2'))
# Bug: passes the *result* of selu (a tensor) instead of the function.
model.add(Activation(selu(x=seluDistribution)))
model.add(BatchNormalization(axis=-1, momentum=batchMomentum, epsilon=0.001))
model.add(AlphaDropout(alphaDropRate, noise_shape=None, seed=None))
return model
# Adds Dense, Activation, BatchNormalization and AlphaDropout to `model`.
model = FullyConnectedLayer(512, dist, 0.99, 0.5) # 4 LAYERS
我收到错误:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-121-f0000c6b1512> in <module>
11 model = ConvAvgStack (256, (3, 3), (1, 1), 1, 0.99, 0.3, None, (2, 2), (2, 2)) # 5 LAYERS
12 model = FlattenLayer ( ) # 1 LAYER
---> 13 model = FullyConnectedLayer (512, dist, 0.99, 0.5 ) # 4 LAYERS
14 model = FullyConnectedLayer (512, dist, 0.99, 0.5 ) # 4 LAYERS
15 model = OutputLayer ( 28 ) # 2 LAYERS
<ipython-input-119-58375bdf8845> in FullyConnectedLayer(denseUnits, seluDistribution, batchMomentum, alphaDropRate)
56 def FullyConnectedLayer(denseUnits, seluDistribution, batchMomentum, alphaDropRate):
57 model.add(Dense(denseUnits, activity_regularizer='l2'))
---> 58 model.add(Activation(gelu(x=seluDistribution)))
59 model.add(BatchNormalization(axis=-1, momentum=batchMomentum, epsilon=0.001))
60 model.add(AlphaDropout(alphaDropRate, noise_shape=None, seed=None))
~\Anaconda3\envs\py36\lib\site-packages\tensorflow_core\python\keras\layers\core.py in __init__(self, activation, **kwargs)
376 super(Activation, self).__init__(**kwargs)
377 self.supports_masking = True
--> 378 self.activation = activations.get(activation)
379
380 def call(self, inputs):
~\Anaconda3\envs\py36\lib\site-packages\tensorflow_core\python\keras\activations.py in get(identifier)
452 raise TypeError(
453 'Could not interpret activation function identifier: {}'.format(
--> 454 repr(identifier)))
TypeError: Could not interpret activation function identifier: <tf.Tensor: shape=(5, 5, 1, 32), dtype=float32, numpy=
array([[[[-1.26586094e-01, -1.02963023e-01, 3.14652212e-02,
5.46985120e-02, 8.40277504e-03, 8.54115710e-02,
-1.39087364e-01, 1.13992631e-01, 1.52557418e-01,
-1.09972686e-01, -5.12595251e-02, -1.58538278e-02,
3.14276330e-02, -3.38738156e-03, -1.00402050e-02,
9.30291191e-02, 3.55263725e-02, -1.13361394e-02,
-1.29528284e-01, 1.63152684e-02, 1.01518132e-01,
-4.35875840e-02, 1.46785110e-01, -2.23108958e-02,
-2.09968127e-02, -8.54036435e-02, 9.01642349e-03,
8.55906028e-03, 1.10763777e-02, 1.35530531e-01,
-4.25574742e-02, 4.80710454e-02]],
[[ 9.34263412e-03, 1.06001608e-01, -7.65870064e-02,
2.61795402e-01, -7.57196844e-02, -1.04116738e-01,
-8.02185014e-02, 6.67698979e-02, -8.98385793e-02,
7.15453252e-02, -9.78381783e-02, 4.91873287e-02,
1.30732795e-02, 1.55197978e-01, -3.54499035e-02,
7.09592476e-02, 4.96367812e-02, 3.68002579e-02,
1.16795145e-01, -1.42192841e-01, 4.91914898e-02,
4.16900441e-02, 3.53892595e-01, 1.78602412e-01,
-6.12295903e-02, 7.36039877e-02, -1.33156419e-01,
2.31991991e-01, 8.40820521e-02, -4.55706231e-02,
2.51267888e-02, 2.58378834e-01]],
[[-1.38585389e-01, 1.03538044e-01, 1.76681668e-01,
-6.94317510e-03, 6.14152141e-02, -3.92788239e-02,
2.71100029e-02, -7.32106417e-02, 1.82974041e-01,
-5.83523549e-02, 6.68111816e-02, 5.49897328e-02,
-5.77139147e-02, -7.64194950e-02, -7.55715296e-02,
-4.95074578e-02, 7.71198049e-02, 5.40203564e-02,
1.55790344e-01, -4.58407030e-02, -3.59891504e-02,
-9.74030495e-02, -1.00650810e-01, 1.23783059e-01,
-8.46874043e-02, -1.04908131e-01, -2.63819955e-02,
1.40812725e-01, -2.82377452e-02, -2.38265842e-02,
-1.31487399e-01, 1.30674899e-01]],
[[ 6.60606772e-02, 1.46065757e-01, 1.59279909e-02,
8.10257494e-02, -7.72565231e-02, -9.53363404e-02,
-1.20391339e-01, -7.02986643e-02, -2.74278801e-02,
-1.29030854e-01, -7.62277395e-02, -1.19075023e-01,
6.59862757e-02, -7.62806982e-02, 1.67345591e-02,
1.51549906e-01, -1.10640965e-01, -1.34767130e-01,
2.70601243e-01, -9.72772986e-02, 2.07788169e-01,
6.56675100e-02, -2.64023039e-02, 1.13911137e-01,
-9.22646299e-02, 7.98776373e-02, 6.54103830e-02,
-6.72401339e-02, -4.81364317e-02, -6.03620708e-02,
-2.84200851e-02, -9.10447016e-02]],
[[-1.23140588e-01, 1.10491589e-01, -9.61843282e-02,
-8.91052186e-02, 4.01075035e-01, 1.94666237e-02,
1.95948835e-02, -1.25196623e-02, -9.97422487e-02,
-2.61222124e-02, -1.56512097e-01, 9.74281505e-02,
-3.66279632e-02, 6.65708026e-03, 9.61058680e-03,
-1.21156186e-01, -2.98077669e-02, 1.66137442e-02,
1.17182136e-01, -1.13791995e-01, -1.27656450e-02,
1.41541764e-01, 6.68982640e-02, 4.79037128e-02,
-3.38280275e-02, -9.28360224e-02, -7.76154548e-02,
-7.96113610e-02, -2.57881228e-02, -1.58247918e-01,
1.13235332e-01, 1.41958997e-01]]],
[[[-1.13160208e-01, -1.98329911e-02, 1.20878376e-01,
-1.13716172e-02, -5.21509871e-02, 7.25255907e-02,
-1.12730011e-01, -7.29970336e-02, 6.37045652e-02,
7.17113987e-02, -4.47467379e-02, 5.34803495e-02,
-8.64603445e-02, -2.22087242e-02, -2.47925967e-02,
8.34110975e-02, 7.71386176e-02, -4.75004427e-02,
-6.44451613e-03, -1.73095725e-02, -6.07393086e-02,
5.75710386e-02, -5.33160344e-02, -8.67358595e-02,
-4.96991165e-02, -3.15147117e-02, 2.43039820e-02,
1.42646387e-01, 1.22333430e-01, -3.74684632e-02,
-7.35211000e-02, -6.92363605e-02]],
[[ 3.96580771e-02, 1.26118317e-01, 1.16271339e-01,
1.54558346e-01, 1.14904214e-02, -2.90639680e-02,
-5.80145419e-02, -2.15136074e-03, -9.12490934e-02,
1.45193376e-02, -3.00550666e-02, 1.45778894e-01,
4.00692225e-02, -1.92456692e-02, 6.31886274e-02,
-1.27457187e-01, 3.60154063e-02, 9.91806835e-02,
8.99021700e-02, 2.88172178e-02, -1.59403589e-02,
4.76611021e-04, -3.30352560e-02, 1.15945041e-02,
-4.64559309e-02, -2.11531147e-02, 4.10205543e-01,
-4.43787202e-02, 4.39099297e-02, 3.06370091e-02,
-9.87873599e-02, -5.10304309e-02]],
[[ 2.13202462e-02, 1.41525701e-01, -4.84775938e-02,
2.97882885e-01, 2.19049938e-02, 3.68789248e-02,
2.60351785e-02, -9.37016606e-02, 5.48276715e-02,
-1.43082231e-01, 4.21900637e-02, -1.17563821e-01,
-3.71489525e-02, -1.45584494e-01, -1.12884097e-01,
-7.87854716e-02, -2.01713406e-02, -3.49416770e-02,
-6.53499886e-02, -2.09143162e-02, 2.94101406e-02,
4.72677462e-02, 2.33202621e-01, -1.95219535e-02,
1.19159967e-02, -1.00374170e-01, -8.75894353e-02,
-5.27165644e-02, 1.19348057e-02, -4.39126566e-02,
-6.26288429e-02, 4.20925207e-02]],
[[-8.23830441e-02, 2.23106906e-01, 8.56178179e-02,
1.73401862e-01, -8.12073424e-02, 2.73483209e-02,
-5.99831380e-02, -1.71386788e-03, -3.62357125e-02,
-1.59021363e-01, -2.17766548e-03, 2.16864720e-01,
-5.73305860e-02, -1.80698894e-02, 1.36940643e-01,
-1.97473206e-02, 8.14313069e-02, 1.96376622e-01,
1.41641393e-01, 1.47828847e-01, -8.56224895e-02,
1.83912277e-01, -1.33015722e-01, 3.97381186e-02,
1.18237391e-01, -9.23948511e-02, 8.74724891e-03,
4.36485223e-02, 6.96098059e-02, -4.20766175e-02,
-6.43103570e-02, -3.85615453e-02]],
[[-3.53560485e-02, 4.35038935e-03, -7.06349090e-02,
-2.80691660e-03, -6.92954510e-02, 1.11481667e-01,
8.37303791e-03, 6.81344569e-02, -7.26705194e-02,
-4.58219610e-02, -2.38394644e-02, -7.87800774e-02,
1.69382155e-01, 1.03942029e-01, -1.96680743e-02,
-1.67009607e-02, 6.01479635e-02, 1.56740978e-01,
-9.78638828e-02, -4.29860055e-02, 1.38192121e-02,
-1.36006713e-01, -1.05418041e-01, -2.51792613e-02,
-1.22639257e-02, -1.21888302e-01, -5.46660051e-02,
-7.12147309e-03, -6.58531636e-02, -7.14808479e-02,
9.00977999e-02, 6.35402352e-02]]],
[[[ 6.32937178e-02, 2.72242278e-01, -3.74731459e-02,
2.15447005e-02, -1.08312249e-01, 2.10458219e-01,
3.16671804e-02, -4.71992679e-02, 3.75940092e-02,
-2.62564681e-02, -1.54855132e-01, 7.81283434e-03,
5.74255362e-02, 1.75963491e-02, -4.40403447e-02,
-8.01301673e-02, 7.47360140e-02, -5.00108190e-02,
-7.64894933e-02, 8.45131949e-02, -3.27355303e-02,
-3.79370786e-02, -6.93783676e-03, -4.87477183e-02,
9.93528962e-02, -1.05679579e-01, -1.12576345e-02,
4.84773107e-02, -1.20892882e-01, 7.03079775e-02,
5.60718998e-02, -1.91565454e-02]],
[[-7.98909813e-02, 2.59152979e-01, 1.75541520e-01,
3.17000411e-02, -1.23978313e-02, 5.59741072e-02,
-8.12215135e-02, -9.54297185e-02, 1.99518725e-03,
-3.72358635e-02, -1.39946237e-01, -5.76626435e-02,
-7.13582858e-02, 5.86171262e-02, -1.39267772e-01,
5.99225797e-02, -2.99881045e-02, 7.08236769e-02,
-1.00216493e-01, 2.68728107e-01, 1.63495377e-01,
2.52694320e-02, -7.93625191e-02, -3.71078290e-02,
-2.24205833e-02, 1.44553408e-01, -9.67240557e-02,
7.93731958e-02, 1.79968283e-01, -8.94036815e-02,
-1.24277532e-01, -1.40620157e-01]],
[[-5.69531657e-02, 1.71630532e-01, 2.86230773e-01,
-5.93378842e-02, -1.71954520e-02, -3.26295868e-02,
1.84255466e-01, 1.47821277e-01, -2.54929177e-02,
-3.98173966e-02, 7.21049905e-02, -6.91456124e-02,
-1.23138815e-01, 1.33402884e-01, -1.02245316e-01,
2.63660389e-04, 4.64727916e-02, 5.91520481e-02,
-4.69203852e-02, -1.75676849e-02, 1.40360445e-01,
1.67195871e-02, -1.11560198e-02, 4.65030931e-02,
1.73744947e-01, -1.47689149e-01, 1.10403180e-01,
-3.40559036e-02, 3.35928686e-02, -1.04908220e-01,
3.52294981e-01, -1.09612457e-01]],
[[-7.85556585e-02, -1.18466914e-01, 1.53003752e-01,
3.53524536e-01, 8.51708353e-02, 1.50212459e-02,
6.00347035e-02, 6.17506169e-02, 3.86744961e-02,
-4.67218924e-03, -1.16112582e-01, 5.51390201e-02,
-1.52055770e-02, 3.54320277e-03, 3.42624858e-02,
1.12283051e-01, 9.55326110e-02, 1.21229617e-02,
2.53406595e-02, -8.03915039e-03, -5.12704402e-02,
-1.05386212e-01, 1.98949352e-02, 2.73315758e-02,
-7.58572146e-02, 1.03625186e-01, 2.05493998e-03,
2.16567993e-01, -1.08607717e-01, -1.77554917e-02,
-1.01805991e-02, 2.19766423e-02]],
[[-1.00509115e-02, 2.22494956e-02, -9.08879191e-02,
1.40368611e-01, -7.76991919e-02, 4.06601280e-02,
-8.11229870e-02, 9.98405516e-02, -5.72074987e-02,
-1.33951874e-02, 3.92576605e-02, 1.16789080e-01,
-1.89318452e-02, -1.59033425e-02, 9.48152542e-02,
-2.66773477e-02, 1.37753570e-02, 1.79445334e-02,
-6.62883669e-02, -9.37851295e-02, 1.94142580e-01,
1.51747808e-01, 1.18895158e-01, -9.38543454e-02,
-1.28269400e-02, 1.25869989e-01, 1.50878415e-01,
-2.11219154e-02, -1.05045862e-01, -2.73662023e-02,
1.34711221e-01, -3.39821167e-02]]],
[[[ 1.83003962e-01, -4.02955636e-02, 7.92874582e-03,
-1.04859909e-02, 1.41754048e-02, -1.52763631e-02,
-9.11424682e-02, 3.24082047e-01, 1.05546042e-02,
1.61272004e-01, -6.35793507e-02, 3.40929255e-02,
1.42173097e-01, 2.29736529e-02, -9.50964168e-02,
4.37728036e-03, 2.28861179e-02, -9.05632600e-02,
4.51861843e-02, 1.37779471e-02, -1.46172449e-01,
1.64313123e-01, -3.22954543e-02, -5.28477319e-02,
4.69896160e-02, -1.41132519e-01, -7.68374726e-02,
-1.30687788e-01, -3.98224816e-02, 1.38061410e-02,
1.22488514e-01, 1.80401299e-02]],
[[ 2.33758405e-01, -9.26907063e-02, 1.65917858e-01,
6.44484162e-02, -2.50724498e-02, 6.55293986e-02,
8.01478773e-02, -1.14124026e-02, 1.28012270e-01,
2.19843611e-01, 6.57515675e-02, -3.07754893e-02,
5.73609546e-02, 2.30572656e-01, -3.31373070e-03,
-1.22203723e-01, -2.83196904e-02, 1.02213569e-01,
1.01013631e-01, -1.12923756e-01, 6.65007606e-02,
3.05338297e-02, -3.15904021e-02, 2.79964060e-02,
-5.63387433e-03, -3.08787469e-02, 1.96257643e-02,
-7.37890229e-02, -1.93086471e-02, 1.30984381e-01,
1.62610561e-01, -1.23884566e-01]],
[[-5.09779751e-02, 6.08728305e-02, -8.07061568e-02,
1.61063105e-01, -1.54089145e-02, -3.93352509e-02,
1.43149942e-01, -2.58404091e-02, -2.68822517e-02,
-1.26804784e-01, -1.43676013e-01, -3.28507088e-02,
7.94891044e-02, -1.40485764e-01, -8.11149403e-02,
6.17359020e-02, 2.30427265e-01, -1.29583761e-01,
9.54858139e-02, -3.21873813e-03, -1.57925244e-02,
9.96306986e-02, -1.02927998e-01, 8.71243626e-02,
-1.66144117e-03, -7.41888210e-03, 1.42028257e-01,
-4.99214791e-02, -1.86899900e-02, -1.09298825e-02,
-8.03249031e-02, -1.00237548e-01]],
[[-7.80191123e-02, 4.05082256e-02, 7.47731477e-02,
-8.76973122e-02, -2.91744564e-02, 1.23694569e-01,
2.35005572e-01, -1.05778649e-01, -4.78913225e-02,
-1.49070352e-01, 2.42730626e-03, 3.52480598e-02,
9.97696498e-06, -1.27278671e-01, -1.08177230e-01,
-5.62792830e-03, -2.28355639e-02, -1.27415329e-01,
3.05411909e-02, 1.00286447e-01, 1.83264986e-02,
-8.48858505e-02, -3.52028869e-02, -7.95315206e-02,
-3.92727107e-02, -4.16678861e-02, 2.39140958e-01,
4.07571718e-02, -9.46874619e-02, 1.50908276e-01,
-1.44019471e-02, -8.69576260e-03]],
[[-1.67441964e-02, -1.43177100e-02, -9.23768803e-02,
4.70091105e-02, 4.42117406e-03, 6.48477301e-02,
-2.72830930e-02, 7.51334131e-02, -2.28366554e-02,
9.48273912e-02, 4.46406417e-02, 6.07026815e-02,
5.69610856e-02, -4.77909558e-02, -6.64769933e-02,
-5.57800010e-02, -1.31770581e-01, 9.31192283e-03,
-1.38517320e-02, -1.41043484e-01, -6.42404705e-02,
2.63120145e-01, 1.80331752e-01, -1.43979434e-02,
-4.86476049e-02, -1.12639852e-01, 7.89660513e-02,
1.24138966e-01, 5.12700714e-02, -1.20767031e-03,
-1.09081008e-01, -3.03610712e-02]]],
[[[-1.40361011e-01, 1.21919084e-02, 4.36685272e-02,
-3.61564793e-02, -1.11773185e-01, 2.25092173e-02,
-1.02469876e-01, 1.76996499e-01, 4.30173017e-02,
-2.26258971e-02, 2.11037025e-01, 9.66922417e-02,
5.76661676e-02, 9.65369982e-04, -1.35565817e-01,
-4.83587980e-02, 4.68245940e-04, -1.47096828e-01,
8.96992441e-03, 4.12831195e-02, 9.53651369e-02,
-2.91392524e-02, 8.22411999e-02, 2.07852814e-02,
-4.12134677e-02, 5.33621386e-02, 9.24792588e-02,
8.16729572e-03, 4.25154343e-02, 6.19177930e-02,
7.98290670e-02, -8.52704328e-03]],
[[ 1.66879535e-01, 6.54919222e-02, -3.27483788e-02,
-1.43241754e-03, -1.14416316e-01, -2.12962832e-02,
-4.46583293e-02, 2.71647628e-02, -5.61558232e-02,
1.09621109e-02, 1.67668343e-01, 3.30472551e-02,
7.05115721e-02, 7.84466881e-03, 1.08160205e-01,
2.66151220e-01, 1.52581872e-03, 7.19077215e-02,
-1.24854170e-01, 1.25476092e-01, -7.09585026e-02,
-4.40548174e-02, 7.21732453e-02, 7.45785460e-02,
3.44901420e-02, 2.10928824e-02, -7.80880824e-02,
-1.17296316e-01, -1.46051958e-01, 1.88378561e-02,
6.55523613e-02, 3.32243517e-02]],
[[ 2.60874778e-01, -1.45940065e-01, -9.79427770e-02,
-8.68195742e-02, 2.04389215e-01, -2.24198923e-02,
4.23102900e-02, -7.01505691e-02, -1.27080590e-01,
6.70303479e-02, 1.60573255e-02, -7.93380756e-03,
-8.38927086e-03, -4.99465019e-02, 4.69646640e-02,
-7.15569034e-02, -1.78242605e-02, -8.51068646e-03,
4.20920074e-01, 7.50197982e-03, -6.86415285e-02,
7.11418912e-02, 1.07180420e-03, -9.36960131e-02,
1.57825544e-01, 5.96512817e-02, 1.75660148e-01,
-3.08227092e-02, -4.82530929e-02, 8.31630453e-02,
-4.16018628e-02, -7.55471215e-02]],
[[ 2.24076852e-01, -1.39667824e-01, 7.93220941e-03,
-1.78845283e-02, -5.64770252e-02, -7.84719810e-02,
5.26466146e-02, 6.62457757e-03, 2.76956528e-01,
9.01412778e-03, -1.48465708e-01, -9.00324360e-02,
-1.81565285e-02, 1.24106847e-01, -6.28474308e-03,
-1.72791779e-02, -3.47166769e-02, -4.92920280e-02,
1.33945951e-02, -1.16457433e-01, -1.28861982e-02,
1.83324851e-02, -1.37674257e-01, -8.29964876e-02,
-9.12440866e-02, 6.42236844e-02, -1.16013244e-01,
-7.96606317e-02, 1.50838092e-01, -4.71229590e-02,
-4.02066261e-02, 1.17019311e-01]],
[[-3.95799540e-02, -4.35096361e-02, -9.93420109e-02,
3.89132760e-02, 8.42780769e-02, -1.38364257e-02,
2.48586033e-02, -8.65626428e-03, 1.72410719e-02,
-6.20126911e-02, 1.93700612e-01, 5.02851121e-02,
-9.00325775e-02, 1.32245719e-01, 2.68575907e-01,
-8.08344856e-02, -4.56905663e-02, 1.26069590e-01,
5.42675406e-02, 1.27283424e-01, 2.92954836e-02,
2.07115993e-01, -1.58712193e-01, -2.03064550e-02,
-6.64912462e-02, 9.61613879e-02, -1.48803489e-02,
1.32543296e-01, -1.13899536e-01, 5.34827523e-02,
我不知道该如何为 SELU 激活函数初始化随机分布。任何帮助都将不胜感激!
最佳答案
首先,我认为不存在 Activation(selu(x=dist)) 这种用法。Activation 需要的是 selu 这个函数本身,而不是 selu 的输出(一个张量)。selu 的实现如下:
# Keras' SELU implementation, quoted verbatim (exported publicly as
# keras.activations.selu by the decorator below).
# SELU(x) = scale * elu(x, alpha), assuming K.elu is the standard ELU
# (x if x > 0 else alpha * (exp(x) - 1)). alpha and scale are fixed
# constants and the only argument is the input tensor x, so SELU has no
# distribution or weights of its own to initialize.
@keras_export('keras.activations.selu')
def selu(x):
alpha = 1.6732632423543772848170429916717
scale = 1.0507009873554804934193349852946
return scale * K.elu(x, alpha)
在你的情况下,我认为那篇文章(article)说的是初始化层的权重,而不是初始化 selu。根据官方 API 文档(here),我认为在你的情况下可以这样使用 selu:
# official usage
# Name the activation directly and pair it with the LeCun-normal kernel
# initializer (assumes an existing Sequential `model`).
model.add(Dense(16, activation='selu', kernel_initializer='lecun_normal'))
# In your case, build the Dense layer like the standard layer described in the article:
import numpy as np
import tensorflow as tf
from tensorflow.keras.activations import selu
from tensorflow.keras.layers import Dense, Activation, BatchNormalization, AlphaDropout
from tensorflow.keras import initializers
def FullyConnectedLayer(denseUnits, in_dim, batchMomentum, alphaDropRate, model=None):
    """Build, or append to a model, a SELU fully connected block.

    Adds four layers: Dense -> Activation(selu) -> BatchNormalization ->
    AlphaDropout.

    Args:
        denseUnits: number of units in the Dense layer.
        in_dim: fan-in of the Dense layer (width of its input). Used for the
            kernel initializer's stddev, sqrt(1 / in_dim), and as the input
            shape when a new model is created.
        batchMomentum: momentum for BatchNormalization.
        alphaDropRate: drop rate for AlphaDropout.
        model: optional existing ``tf.keras.Sequential`` to append the block
            to. When omitted (the default), a new Sequential model is created,
            exactly as before. Passing the previous return value lets several
            blocks be stacked into one model; ``in_dim`` should then be the
            output width of the preceding layer.

    Returns:
        The Sequential model containing the newly added layers.
    """
    dense_kwargs = {}
    if model is None:
        model = tf.keras.Sequential()
        # Only the first layer of a fresh model declares the input shape.
        dense_kwargs['input_shape'] = (in_dim,)
    model.add(Dense(
        denseUnits,
        activity_regularizer='l2',
        # stddev = sqrt(1 / fan_in): LeCun-normal-style scaling for SELU nets.
        kernel_initializer=initializers.RandomNormal(stddev=np.sqrt(1 / in_dim)),
        **dense_kwargs
    ))
    # Pass the activation function itself, not the tensor selu(...) returns.
    model.add(Activation(selu))
    model.add(BatchNormalization(axis=-1, momentum=batchMomentum, epsilon=0.001))
    model.add(AlphaDropout(alphaDropRate, noise_shape=None, seed=None))
    return model
# Standalone block: 512 units on a 10-dimensional input
# (Dense, Activation, BatchNormalization, AlphaDropout -> 4 layers).
model = FullyConnectedLayer(denseUnits=512, in_dim=10, batchMomentum=0.99, alphaDropRate=0.5)
总之,祝编码愉快。
关于python - SeLU 激活函数 x 参数导致类型错误,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/60675024/