1 Overview
Training this yourself is demanding on hardware: the FlyingChairs dataset alone is about 30 GB. This post therefore only covers how to run the algorithm for inference.
Tips: the following assumes all required libraries are already installed.
2 Code Download
PyTorch: https://github.com/ClementPinard/FlowNetPytorch
We will mainly use the run_inference.py file from this repository:
3 Data Download
Link: https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html#flyingchairs
Downloading the full dataset takes a while depending on your connection, so here are a few test files for convenience:
https://download.csdn.net/download/weixin_44575152/86863644
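Before running anything, you can sanity-check that the test files are laid out the way run_inference.py expects. Below is a minimal sketch that mirrors the script's own pair matching (the data path is an assumption, adjust it to wherever you put the files); FlyingChairs names its frames like 00001_img1.ppm / 00001_img2.ppm:
# Minimal sanity check, mirroring run_inference.py's pair matching.
# The data path is an assumption; adjust it to your setup.
from path import Path

data_dir = Path(r"D:\Data\Flow\FlyingChairs")
first_frames = data_dir.files('*1.ppm')
pairs = [(f, f.parent / (f.stem[:-1] + '2.ppm')) for f in first_frames]
print(sum(1 for _, second in pairs if second.isfile()), "pairs found")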
4 Pre-trained Model Download
Download the PyTorch weights; converting from the original Torch format is a bit of a hassle:
https://drive.google.com/drive/folders/1dTpSyc7rIYYG19p1uiDfilcsmSPNy-_3
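If you are not sure which checkpoint you grabbed, you can peek inside it first. A minimal sketch (the path is an assumption; 'arch' and the optional 'div_flow' entries are what run_inference.py actually reads):
# Inspect a FlowNetPytorch checkpoint before running inference.
# The path is an assumption; adjust it to your setup.
import torch

ckpt = torch.load(r"D:\Data\Flow\FlowNet\model\flownets.pth.tar", map_location="cpu")
print(sorted(ckpt.keys()))  # run_inference.py reads 'arch' and, if present, 'div_flow'
print(ckpt['arch'])         # e.g. 'flownets'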
5 Code Walkthrough
If you just want to run it, follow the steps above and launch run_inference.py, e.g. python run_inference.py --data <dataset dir> --pretrained <model file> --output <output dir>. Otherwise, read the annotated code below. You will need to adjust the paths in the parser.add_argument calls:
- Dataset location: --data
- Pre-trained model location: --pretrained
- Output image location: --output
Try not to put the output folder inside the data folder!
import argparse
import torch
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import models
import torchvision.transforms as transforms
import flow_transforms
import numpy as np
from path import Path
from tqdm import tqdm
from imageio import imread, imwrite
from util import flow2rgb
import warnings
warnings.filterwarnings("ignore")
model_names = sorted(name for name in models.__dict__
                     if name.islower() and not name.startswith("__"))

parser = argparse.ArgumentParser(description='PyTorch FlowNet inference on a folder of img pairs',
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
# Path to the image folder
parser.add_argument('--data', metavar='DIR', default=r"D:\Data\Flow\FlyingChairs",
                    help='path to images folder, image names must match \'[name]0.[ext]\' and \'[name]1.[ext]\'')
# Path to the pre-trained model
parser.add_argument('--pretrained', metavar='PTH', default=r"D:\Data\Flow\FlowNet\model\flownets.pth.tar",
                    help='path to pre-trained model')
# Where to store the results
parser.add_argument('--output', '-o', metavar='DIR', default=r"D:\Data\Flow\FlowNet\data",
                    help='path to output folder. If not set, will be created in data folder')
# Which kind of values to store
parser.add_argument('--output-value', '-v', choices=['raw', 'vis', 'both'], default='both',
                    help='which value to output, between raw input (as a npy file) and color vizualisation (as an image file).'
                         ' If not set, will output both')
# Scale factor the flow was divided by during training
parser.add_argument('--div-flow', default=20, type=float,
                    help='value by which flow will be divided. overwritten if stored in pretrained file')
# Image extensions
parser.add_argument("--img-exts", metavar='EXT', default=['png', 'jpg', 'bmp', 'ppm'], nargs='*', type=str,
                    help="images extensions to glob")
# Maximum flow value
parser.add_argument('--max_flow', default=None, type=float,
                    help='max flow value. Flow map color is saturated above this value. If not set, will use flow map\'s max value')
# If not set, output the raw network output, which is 4x downsampled;
# if set, output a full-resolution flow map with the chosen upsampling
parser.add_argument('--upsampling', '-u', choices=['nearest', 'bilinear'], default=None,
                    help='if not set, will output FlowNet raw input,'
                         'which is 4 times downsampled. If set, will output full resolution flow map, with selected upsampling')
# If set, also output the inverted flow (from image 2 to image 1)
parser.add_argument('--bidirectional', action='store_true',
                    help='if set, will output invert flow (from 1 to 0) along with regular flow')
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
@torch.no_grad()
def main():
    global args, save_path
    args = parser.parse_args()

    # Describe what will be saved
    output_string = ""
    if args.output_value == 'both':
        output_string = "raw output and RGB visualization"
    elif args.output_value == 'raw':
        output_string = "raw output"
    elif args.output_value == 'vis':
        output_string = "RGB visualization"
    print("=> will save " + output_string)

    data_dir = Path(args.data)
    print("=> fetching img pairs in '{}'".format(args.data))
    if args.output is None:
        save_path = data_dir / 'flow'
    else:
        save_path = Path(args.output)
    print('=> will save everything to {}'.format(save_path))
    save_path.makedirs_p()

    # Data loading code
    input_transform = transforms.Compose([
        flow_transforms.ArrayToTensor(),
        # Scale pixel values from [0, 255] to [0, 1], then subtract the per-channel mean
        transforms.Normalize(mean=[0, 0, 0], std=[255, 255, 255]),
        transforms.Normalize(mean=[0.411, 0.432, 0.45], std=[1, 1, 1])
    ])

    img_pairs = []
    for ext in args.img_exts:
        # First frames of each pair: names ending in '1.<ext>' (e.g. 00001_img1.ppm)
        test_files = data_dir.files('*1.{}'.format(ext))
        for file in test_files:
            # Matching second frame: same stem, ending in '2.<ext>' (e.g. 00001_img2.ppm)
            img_pair = file.parent / (file.stem[:-1] + '2.{}'.format(ext))
            if img_pair.isfile():
                # Keep the pair
                img_pairs.append([file, img_pair])

    print('{} samples found'.format(len(img_pairs)))

    # create model
    network_data = torch.load(args.pretrained, map_location=torch.device('cpu'))
    print("=> using pre-trained model '{}'".format(network_data['arch']))
    # Build the network from the architecture name stored in the checkpoint
    model = models.__dict__[network_data['arch']](network_data).to(device)
    model.eval()
    cudnn.benchmark = True

    if 'div_flow' in network_data.keys():
        args.div_flow = network_data['div_flow']

    # Iterate over the image pairs
    for (img1_file, img2_file) in tqdm(img_pairs):
        # Shapes below use FlyingChairs as an example
        # (3, 384, 512)
        img1 = input_transform(imread(img1_file))
        img2 = input_transform(imread(img2_file))
        # (1, 6, 384, 512)
        input_var = torch.cat([img1, img2]).unsqueeze(0)

        if args.bidirectional:
            # feed inverted pair along with normal pair
            inverted_input_var = torch.cat([img2, img1]).unsqueeze(0)
            input_var = torch.cat([input_var, inverted_input_var])

        input_var = input_var.to(device)
        # compute output
        output = model(input_var)
        if args.upsampling is not None:
            # Upsample the flow back to the input resolution
            output = F.interpolate(output, size=img1.size()[-2:], mode=args.upsampling, align_corners=False)
        for suffix, flow_output in zip(['flow', 'inv_flow'], output):
            filename = save_path / '{}{}'.format(img1_file.stem[:-1], suffix)
            if args.output_value in ['vis', 'both']:
                rgb_flow = flow2rgb(args.div_flow * flow_output, max_value=args.max_flow)
                to_save = (rgb_flow * 255).astype(np.uint8).transpose(1, 2, 0)
                imwrite(filename + '.png', to_save)
            if args.output_value in ['raw', 'both']:
                # Make the flow map a HxWx2 array as in .flo files
                to_save = (args.div_flow * flow_output).cpu().numpy().transpose(1, 2, 0)
                np.save(filename + '.npy', to_save)
if __name__ == '__main__':
    main()
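After the script finishes, each raw .npy file holds an H x W x 2 array of (u, v) pixel displacements, as the comment in the save loop notes. A minimal sketch for loading one back (the file name here is hypothetical; use whatever names the script actually produced in your output folder):
# Load a raw flow map saved by the script.
# The file name is hypothetical; list your output folder for the real names.
import numpy as np

flow = np.load(r"D:\Data\Flow\FlowNet\data\00001_imgflow.npy")  # shape (H, W, 2)
u, v = flow[..., 0], flow[..., 1]  # horizontal / vertical displacement in pixels
print(flow.shape, float(abs(u).max()), float(abs(v).max()))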