torchnet package (2)

torchnet
torch7

Dataset Iterators

尽管是用for loop语句很容易处理Dataset,但有时希望以on-the-fly manner或者在线程中读取数据,这时候Dataset Iterator就是个好的选择

注意,iterators是用于特殊情况的,一般情况下还是使用Dataset比较好

Iteartor 的两个主要方法:

* run() 返回一个Lua 迭代器,也可以使用()操作符,因为iterator源码中定义了__call事件

* exec(funcname,...) 在指定的dataset上执行funcname方法,funcname是dataset自己的方法,比如size

  • tnt.DatasetIterator(self,dataset[,perm][,filter][,transform])

    The default dataset iterator

    perm(idx), 实现shuffle功能,即对idx进行变换,更复杂的变换可以使用ShuffleDataset

    filter(sample), 闭包函数,筛选样本是否用于迭代,返回bool值

    transform(sample),闭包函数,实现对样本的变换,更复杂的变换可以结合TransformDataset和transform.compose等实现

  1. ldata = tnt.ListData{list=torch.range(1,10):long(),load = function(x) return {x,x+1} end} 

  2. dIter = tnt.DatasetIterator{dataset = ldata,filter = function(x) if x[1]<2 then return false else return true end end} 

  3. for v in dIter:run() 

  4. print(v) 

  5. end 

  • tnt.ParallelDatasetIterator(self[,init],closure,nthread[,perm][,filter][,transform][,ordered])

    这个才是迭代器的重点,用于以多线程方式迭代数据。

nthreads 指定了线程的个数

init(threadid) 闭包函数,指定了线程threadid的初始化工作,如果啥都不做可以省略

closure(threadid) 每个线程的job,返回的必须时tnt.Dataset的一个实例

perm(idx) 用于shuffle

filter(sample) 闭包函数,指定哪些样本不用于迭代

transform(sample) 对样本进行变换,在filter之前执行

order 线程之间数据的处理是否有序,主要是为了程序的可重现性,当order=true时,多次执行程序,顺序是相同的

  1. tnt=require'torchnet' 

  2. local list=torch.Tensor{{2,2},{2,2},{2,2},{2,2}}:long() 

  3. ldata = tnt.ListDataset{list=list,load=function(x) return torch.Tensor(x[1],x[2]) end} 

  4. local bdata = tnt.BatchDataset{batchsize=2,dataset = tnt.TransformDataset{dataset = ldata,transform=function(x) return 2*x end}} 

  5. Padata = tnt.ParallelDatasetIterator{ 

  6. nthread = 4, 

  7. init = function(tid) 

  8. print ('init thread id: '.. tid) 

  9. tnt=require'torchnet' 

  10. end, 

  11. closure = function(tid) 

  12. print('closure of threadid: '.. tid) 

  13. return bdata 

  14. end 

  15. }  

尤其需要注意的是,closure中的所有upvalues都必须是可序列化的,最好是避免使用upvalues,并保证closure中使用的package都在init中require

tnt.Engine

在网络训练的过程中,都是计算前向误差,误差反传,更新权重这些过程,只是模型,数据和评价函数不同而已,所以Engine给训练过程提供了一个模板,该模板建立了model,DatasetIterator,Criterion和Meter之间的联系

engine=tnt.Engine()包含两个主要方法

* engine:train() 在数据集上训练数据

* engine:test() 评估模型,可选

Engine不仅实现了训练和评估的一般模板,还提供了许多接口,用于控制训练过程

  • tnt.SGDEngine

    SGDEngine 模块在train过程中使用Stochastic Gradient Descent方法训练,模块包含数据采样,前向传递,反向传递,参数更新等,还有一些钩子函数

    hooks = {

    ['onStart'] = function() end, --用于训练开始前的设置和初始化

    ['onStartEpoch'] = function() end, -- 每一个epoch前的操作

    ['onSample'] = function() end, -- 每次采样一个样本之后的操作

    ['onForward'] = function() end, -- 在model:forward()之后的操作

    ['onForwardCriterion'] = function() end, -- 前向计算损失函数之后的操作

    ['onBackwardCriterion'] = function() end, -- 反向计算损失误差之后的操作

    ['onBackward'] = function() end, -- 反向传递误差之后的操作

    ['onUpdate'] = function() end, -- 权重参数更新之后的操作

    ['onEndEpoch'] = function() end, -- 每一个epoch结束时的操作

    ['onEnd'] = function() end, -- 整个训练过程结束后的收拾现场

    }

    可以发现Engine给的hook函数还是很全面的,几乎训练过程的每一个节点都允许用户制定操作,使用hook函数

  1. local engine = SGDEngine() 

  2. local meter = tnt.AverageValueMeter() 

  3. engine.hooks.onStartEpoch = function(state) meter:reset() end 

一般而言,训练过程最少应该知道训练模型,损失函数,数据和学习率,这里学习方法已经知道了SGD,Engine用到的数据是tnt.DatasetIterator类型的。 评估过程只需要数据和模型就可以了

外部可以通过state变量与Engine训练过程交互

state = {

['network'] = network, --设置了model

['criterion'] = criterion, -- 设置损失函数

['iterator'] = iterator, -- 数据迭代器

['lr'] = lr, -- 学习率

['lrcriterion'] = lrcriterion, --

['maxepoch'] = maxepoch, --最大epoch数

['sample'] = {}, -- 当前采集的样本,可以在onSample中通过该阈值查看采样样本

['epoch'] = 0 , -- 当前的epoch

['t'] = 0, -- 已经训练样本的个数

['training'] = true -- 训练过程

}

评估时需要指定:

state = {

['netwrok'] = network

['iterator'] = iterator

['criterion'] = criterion

}

  • tnt.OptimEngine

    这个方法和SGDEngine的最大的区别在于封装了optim中的多种优化方法。在训练开始的时候,engine会通过getParameters获取model的参数

    train需要附加两个量:

    • optimMethod 优化方法,比如optim.sgd

    • config 优化方法对应的参数

      Example:

  1. local engine = tnt.OptimEngine{ 

  2. network = network, 

  3. criterion=criterion, 

  4. iterator = iterator, 

  5. optimMethod = optim.sgd, 

  6. config = { 

  7. learningRate = 0.1, 

  8. momentum = 0.9, 

  9. }, 



tnt.Meter

和Engine配合使用,用于measure the model.

几乎所有的meters都会有3个方法:

* add() 给待统计的meter添加一个观测值,其输入参数一般形式为(output,value),output为model的输出,target为真实值

* value() 获得待统计的meter的当前值

* reset() 重新计数

Meter的使用示例:

  1. local meter = tnt.<Measure>Meter() -- <Measure> 可以选择具体的度量 

  2. for state,event in tnt.<Optimization>Engine:train{ --定义Engine 

  3. network = network, 

  4. criterion=criterion, 

  5. iterator=iterator, 

  6. } do 

  7. if state == 'start-epoch' then  

  8. meter:reset() -- reset meter 

  9. elseif state == 'forward-criterion' then 

  10. meter:add(state.network.output,sample.target) 

  11. elseif state == 'end-epoch' then 

  12. print('value of meter:) .. meter:value()) 

  13. end 

  14. end 

  • tnt.APMeter(self)

    评估每一类的平均正确率

    APMeter的操作对象是一个torchnet package (2)-LMLPHP的Tensor,表示N个样本对应在K类中的值,另外可选的一个torchnet package (2)-LMLPHP的 Tensor表示每个样本的权重

  1. target = torch.Tensor{ 

  2. {0,0,0,1},{0,0,1,0},{0,1,0,0},{1,0,0,0},{1,0,0,0}} 

  3. apm = tnt.APMeter() 

  4. for i=1,5 do 

  5. apm:add{output=torch.rand(1,4),target=target[i]:size(1,4)} -- 注意N*K的Tensor 

  6. end 

  7. print(apm:value()) 

  • tnt.AverageValueMeter(self)

    用于统计任意添加的变量的方差和均值,可以用来测量平均损失等

    add()的输入必须时number类型,另外在add的时候可以有一个可选的参数n,表示对应值的权重

  1. avm = tnt.AverageValueMeter() 

  2. for i=1,10 do  

  3. avm:add(i,10-i) 

  4. end 

  5. print(avm:value()) -- 输出 4 2.4720... 

  • tnt.AUCMeter(self)

    对于二分类问题计算Area Under Curve (AUC).

    AUCMeter操作的变量是1D的tensor

  • tnt.ConfusionMeter(self,k[,nirmalized])

    多类之间的混淆矩阵,注意不是多类多标签问题,多标签是指一个类的实例可能分配多个标签,这类问题参见tnt.MultiLabelConfusionMeter

    初始化的时候,需要指定类别数k,normalized指定是否将confuse matrix 归一化,归一化之后输出的是百分比,否则是数值

    add(output,target) 输入都是torchnet package (2)-LMLPHP的tensor,这里为什么每次都是N个样本一起输入呢?这是因为往往训练模型都是Batch模式处理的,target可以是N的tensor,每个值表示对应类别标号,也可以时NK的tensor表示类别的one-hot vector

    value()返回K
    K的混淆矩阵行表示groundtruth,列表示predicted targets

  • tnt.mAPMeter(self)

    统计所有类别之间的平均正确率,和APMeter参数完全一致,不同的时value()返回的是多个类别总的正确率

  • tnt.MovingAverageValueMeter(self,windowsize)

    该meter和AverageValueMeter非常类似,输入的也是number,不同在于他统计的不是所有的number的均值和方差,而是往前windowsize时间窗内的numbers的均值和方差,windowsize在初始化时需要指定

  • tnt.MultiLabelConfusionMeter(self,k[,normalized])

    多类多标签混淆矩阵,这个没接触过,不知道理解对不对,先放这吧,需要的时候再看

  • tnt.ClassErrorMeter(self[,topk][,accuracy])

    参数: topk = table

    accuracy = boolean

    该meter用于统计分类误差,topk是一个table指定分别统计前k类预测误差,如ImageNet Competition中的Top5类误差,accuracy表示返回的是正确了还是错误率,accuracy=true,返回的就是1-error

    add(output,target),output是一个torchnet package (2)-LMLPHP的tensor,target可以使一个N的tensor也可以是一个torchnet package (2)-LMLPHP的tensor,参考之前的AUCMeter

    value()返回的时topk误差,value(k)返回的是第topk类误差

  • tnt.TimeMeter(self[,unit])

    这个Meter用于统计events之间的时间,也可以用来统计batch数据的平均处理数据。她很特别!

    unit在初始的时候给定,是一个布尔值,默认false,当设置为true时,返回值将会被incUnit()值平均,计算平均时间消耗。

    tnt.TimeMeter提供的方法有:

    • reset() 重置timer,unit counter

    • stop() stop the timer

    • resume() 唤醒timer

    • incUnit() uint+1

    • value() 返回从reset()到现在的时间消耗

  • tnt.PrecisionAtKMeter(self[,topk][,dim][,online])

待补充
  • tnt.RecallMeter(self[,threshold][,preclass])

    统计threshold下的召回率,threshold是一个table类型,每个元素是一个阈值,默认值为0.5. perclass是一个布尔值,表示是单独统计每一类的召回率还是统计整个召回率,默认值是false

    add(output,target) output是N*K的概率矩阵,行和为1;target是NK的二值矩阵,不一定行和为1,如{0,1,0,1}

    value()返回的是table值,对应的是threshold table中指定阈值下的召回率,如果perclass = true,那么table的每个元素就是一个table

  • tnt.PrecisionMeter(self[,threshold][,perclass])

    参考RecallMeter,这里计算的是正确率

  • tnt.NDCGMeter(self[,K])

    计算normalized discounted cumulative gain,没使用过。。。。

tnt.Log

Log是一个由sting key索引的table,这些keys必须在构造函数中指定,有一个特殊的键 __status__可以在log:status()函数中设置用于记录一些基本的messages

Log中提供的一些closures以及对应attached events

* onSet(log,key,value) 对应着给键赋值 log:set{}

* onGet(log,key) 对应着读取key对应的值 log:get()

* onFlush(log) 对应着清空log log:flush()

* onClose(log) 对应log:close() 关闭log

示例:

  1. tnt = require'torchnet' 

  2. logtext = require 'torchnet.log.view.text' 

  3. logstatus = require 'torchnet.log.view.status' 

  4. log = tnt.log{ 

  5. keys = {'loss','accuracy'} 

  6. onFlush = { 

  7. -- write out all keys in "log" file 

  8. logtext{filename='log.txt', keys={"loss", "accuracy"}, format={"%10.5f", "%3.2f"}}, 

  9. -- write out loss in a standalone file 

  10. logtext{filename='loss.txt', keys={"loss"}}, 

  11. -- print on screen too 

  12. logtext{keys={"loss", "accuracy"}}, 

  13. }, 

  14. onSet = { 

  15. -- add status to log 

  16. logstatus{filename='log.txt'}, 

  17. -- print status to screen 

  18. logstatus{}, 






  19. -- set values 

  20. log:set{ 

  21. loss = 0.1, 

  22. accuracy = 97 




  23. -- write some info 

  24. log:status("hello world") 


  25. -- flush out log 

  26. log:flush() 



后面我们来看一个具体的例子,以VGG16为例实现一个Siamese CNN网络计算patch之间的相似度


04-14 05:12