本文介绍了如何在Pytorch中可视化网络?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.models as models
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.autograd import Variable
from torchvision.models.vgg import model_urls
from torchviz import make_dot

batch_size = 3
learning_rate =0.0002
epoch = 50

resnet = models.resnet50(pretrained=True)
print resnet
make_dot(resnet)

我想从pytorch模型中可视化 resnet .我该怎么做?我尝试使用 torchviz ,但出现错误:

I want to visualize resnet from the pytorch models. How can I do it? I tried to use torchviz but it gives an error:

'ResNet' object has no attribute 'grad_fn'

推荐答案

make_dot 需要一个变量(即带有 grad_fn 的张量),而不是模型本身.
尝试:

make_dot expects a variable (i.e., tensor with grad_fn), not the model itself.
try:

x = torch.zeros(1, 3, 224, 224, dtype=torch.float, requires_grad=False)
out = resnet(x)
make_dot(out)  # plot graph of variable, not of a nn.Module

这篇关于如何在Pytorch中可视化网络?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持!

07-22 15:45