问题描述
我想可视化神经网络层的权重.我正在使用 pytorch.
I want to visualize weights of the layer of a neural network. I'm using pytorch.
import torch
import torchvision.models as models
from matplotlib import pyplot as plt
def plot_kernels(tensor, num_cols=6):
if not tensor.ndim==4:
raise Exception("assumes a 4D tensor")
if not tensor.shape[-1]==3:
raise Exception("last dim needs to be 3 to plot")
num_kernels = tensor.shape[0]
num_rows = 1+ num_kernels // num_cols
fig = plt.figure(figsize=(num_cols,num_rows))
for i in range(tensor.shape[0]):
ax1 = fig.add_subplot(num_rows,num_cols,i+1)
ax1.imshow(tensor[i])
ax1.axis('off')
ax1.set_xticklabels([])
ax1.set_yticklabels([])
plt.subplots_adjust(wspace=0.1, hspace=0.1)
plt.show()
vgg = models.vgg16(pretrained=True)
mm = vgg.double()
filters = mm.modules
body_model = [i for i in mm.children()][0]
layer1 = body_model[0]
tensor = layer1.weight.data.numpy()
plot_kernels(tensor)
上面给出了这个错误 ValueError: Floating point image RGB values must be in the 0..1 range.
The above gives this error ValueError: Floating point image RGB values must be in the 0..1 range.
我的问题是我应该标准化并取权重的绝对值来克服这个错误还是有其他方法?如果我标准化并使用绝对值,我认为图表的含义会发生变化.
My question is should I normalize and take absolute value of the weights to overcome this error or is there anyother way ?If I normalize and use absolute value I think the meaning of the graphs change.
[[[[ 0.02240197 -1.22057354 -0.55051649]
[-0.50310904 0.00891289 0.15427093]
[ 0.42360783 -0.23392732 -0.56789106]]
[[ 1.12248898 0.99013627 1.6526649 ]
[ 1.09936976 2.39608836 1.83921957]
[ 1.64557672 1.4093554 0.76332706]]
[[ 0.26969245 -1.2997849 -0.64577204]
[-1.88377869 -2.0100112 -1.43068039]
[-0.44531786 -1.67845118 -1.33723605]]]
[[[ 0.71286005 1.45265901 0.64986968]
[ 0.75984162 1.8061738 1.06934202]
[-0.08650422 0.83452386 -0.04468433]]
[[-1.36591709 -2.01630116 -1.54488969]
[-1.46221244 -2.5365622 -1.91758668]
[-0.88827479 -1.59151018 -1.47308767]]
[[ 0.93600738 0.98174071 1.12213969]
[ 1.03908169 0.83749604 1.09565806]
[ 0.71188802 0.85773659 0.86840987]]]
[[[-0.48592842 0.2971966 1.3365227 ]
[ 0.47920835 -0.18186836 0.59673625]
[-0.81358945 1.23862112 0.13635623]]
[[-0.75361633 -1.074965 0.70477796]
[ 1.24439156 -1.53563368 -1.03012812]
[ 0.97597247 0.83084011 -1.81764793]]
[[-0.80762428 -0.62829626 1.37428832]
[ 1.01448071 -0.81775147 -0.41943246]
[ 1.02848887 1.39178836 -1.36779451]]]
...,
[[[ 1.28134537 -0.00482408 0.71610934]
[ 0.95264435 -0.09291686 -0.28001019]
[ 1.34494913 0.64477581 0.96984017]]
[[-0.34442815 -1.40002513 1.66856039]
[-2.21281362 -3.24513769 -1.17751861]
[-0.93520379 -1.99811196 0.72937071]]
[[ 0.63388056 -0.17022935 2.06905985]
[-0.7285465 -1.24722099 0.30488953]
[ 0.24900314 -0.19559766 1.45432627]]]
[[[-0.80684513 2.1764245 -0.73765725]
[-1.35886598 1.71875226 -1.73327696]
[-0.75233924 2.14700699 -0.71064663]]
[[-0.79627383 2.21598244 -0.57396138]
[-1.81044972 1.88310981 -1.63758397]
[-0.6589964 2.013237 -0.48532376]]
[[-0.3710472 1.4949851 -0.30245575]
[-1.25448656 1.20453358 -1.29454732]
[-0.56755757 1.30994892 -0.39370224]]]
[[[-0.67361742 -3.69201088 -1.23768616]
[ 3.12674141 1.70414758 -1.76272404]
[-0.22565465 1.66484773 1.38172317]]
[[ 0.28095332 -2.03035069 0.69989491]
[ 1.97936332 1.76992691 -1.09842575]
[-2.22433758 0.52577412 0.18292744]]
[[ 0.48471382 -1.1984663 1.57565165]
[ 1.09911084 1.31910467 -0.51982772]
[-2.76202297 -0.47073677 0.03936549]]]]
推荐答案
听起来好像您已经知道您的值不在该范围内.是的,您必须将它们重新缩放到 0.0 - 1.0 的范围内.我建议您希望保留负面与正面的可见性,但让 0.5 成为您新的中立"点.进行缩放,使当前的 0.0 值映射到 0.5,您的最极端值(最大震级)缩放到 0.0(如果为负)或 1.0(如果为正).
It sounds as if you already know your values are not in that range. Yes, you must re-scale them to the range 0.0 - 1.0. I suggest that you want to retain visibility of negative vs positive, but that you let 0.5 be your new "neutral" point. Scale such that current 0.0 values map to 0.5, and your most extreme value (largest magnitude) scale to 0.0 (if negative) or 1.0 (if positive).
谢谢你的载体.看起来您的值在 -2.25 到 +2.0 的范围内.我建议重新调整 new = (1/(2*2.25)) * old + 0.5
Thanks for the vectors. It looks like your values are in the range -2.25 to +2.0. I suggest a rescaling new = (1/(2*2.25)) * old + 0.5
这篇关于ValueError:浮点图像 RGB 值必须在 0..1 范围内.使用 matplotlib 时的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持!