数据:
【MATLAB实战】基于UNet的肺结节的检测-LMLPHP
【MATLAB实战】基于UNet的肺结节的检测-LMLPHP
训练过程图
【MATLAB实战】基于UNet的肺结节的检测-LMLPHP
算法简介:
UNet网络是分割任务中的一个经典模型,因其整体形状与"U"相似而得名,"U"形结构有助于捕获多尺度信息,并促进了特征的精确重建,该网络整体由编码器,解码器以及跳跃连接三部分组成。
编码器由一系列卷积层(Convenlution)和池化层(Polling) 组成,用于逐步降低输入图像的空间尺寸和通道数,同时提取图像的高层特征信息:
解码器由一系列上采样层和卷积层组成,用于逐步还原特征图的空间尺寸和通道细节信息。
在编码器和解码器通过跳跃连接将编码器的某一层的特征图与对应的解码器层的特征图连接起来。正是因为"U"结构的有效性,UNet网络被许多学者沿用至今。
UNet网络的基本块由两个卷积和ReLu激活函数构成,使用3x3的卷积核尺寸来捕捉上下文信息。网络的左侧部分构成UNet的编码器,负责从输入影像中提取特征信息。
在UNet的解码阶段,解码基本块与编码阶段的基本块是一一对应的。从瓶颈层出发,通过上采样将特征图放大,然后经过解码基本块进行特征信息的解码重建。
此外跳跃连接(copyandcrop)的设计允许将编码阶段提取的特征信息传递到解码基本块中,有助于进一步恢复细节信息。整个过程重复进行四次,完成对病灶区域的分割。
【MATLAB实战】基于UNet的肺结节的检测-LMLPHP

运行视频:

【MATLAB实战】基于UNet的肺结节的检测

代码:

function expName = unet(expName, size, encoderDepth, filters, batchsize, epochs, useDataAugmentation, L2Reg, lr, ...
    gradientclipping, path, splits, folders, savePredictionsFolder, classNames, labelIDs, valPat, isMAT )

% 创建unet 网络
numClasses  = length(classNames);
lgraph = unetLayers(size,numClasses,'EncoderDepth',encoderDepth, 'NumFirstEncoderFilters', filters)

%%     TRAIN       %%
% 加载训练集
if isMAT == true
    imdsTrain = imageDatastore(strcat(path, '/', splits(1), '/', folders(1)), 'FileExtensions','.mat', 'ReadFcn', @loadMAT);
else
    imdsTrain = imageDatastore(strcat(path, '/', splits(1), '/', folders(1)));
end
    pxdsTrain = pixelLabelDatastore(strcat(path, '/', splits(1), '/', folders(2)),classNames,labelIDs);

    tbl = countEachLabel(pxdsTrain)

     imageFreq = tbl.PixelCount ./ tbl.ImagePixelCount;
     medFreqClassWeights = median(imageFreq) ./ imageFreq

layer_to_add = [pixelClassificationLayer('Classes',classNames,'ClassWeights',medFreqClassWeights,'Name','Segmentation-Layer')];
% 替换层
lgraph = replaceLayer(lgraph,'Segmentation-Layer',layer_to_add);
Display the network.
analyzeNetwork(lgraph)

% 创建 datastore 
if useDataAugmentation == false % 不适用图像增强
%     testpxds = pixelLabelDatastore(testlabelDir,classNames,labelIDs);
    ds = pixelLabelImageDatastore(imdsTrain,pxdsTrain);
else % 图像增强
    augmenter = imageDataAugmenter('RandXReflection',true, 'RandYReflection',true)%,'RandRotation',[-10 10], 'RandXTranslation', [-5 5], 'RandYTranslation', [-5 5]);  
    ds = pixelLabelImageDatastore(imdsTrain,pxdsTrain, 'DataAugmentation', augmenter);
end 


% 加载验证集.
if isMAT==true
    imdsVal = imageDatastore(strcat(path, '/', splits(3), '/', folders(1)), 'FileExtensions','.mat', 'ReadFcn', @loadMAT);
else
    imdsVal = imageDatastore(strcat(path, '/', splits(3), '/', folders(1)));
end
pxdsVal = pixelLabelDatastore(strcat(path, '/', splits(3), '/', folders(2)),classNames,labelIDs);  
valData = pixelLabelImageDatastore(imdsVal, pxdsVal); 
valFreq = floor(length(ds.Images)/batchsize)

% 设置 OPTIONS
options = trainingOptions(...
    'adam', ...    
    ... 'rmsprop' 
    ... 'sgdm','Momentum', 0.9, ...
    'InitialLearnRate',lr, ...
    ... 'LearnRateSchedule','piecewise', ...
    ... 'LearnRateDropFactor', dropfactor, ...
    ... 'LearnRateDropPeriod', 1, ...
    'MaxEpochs',epochs, ...
    'VerboseFrequency',10, ...
    'MiniBatchSize' , batchsize, ...
    'Plots','training-progress', ...
    'L2Regularization',L2Reg, ...,
    'ValidationData',valData, ...,
    'ValidationFrequency', valFreq,...,
    'ValidationPatience', valPat,...,
    ... 'GradientThresholdMethod','l2norm',...
    ... 'GradientThreshold',gradientclipping, ...
    'Shuffle','every-epoch', ...
    'ExecutionEnvironment', 'gpu');

% 训练网络
tic;
[net,info] = trainNetwork(ds,lgraph,options)
traintime=toc;
save net.mat net;
%%         TEST       %%
% 加载测试集
if isMAT==true
    imdsTest = imageDatastore(strcat(path, '/', splits(2), '/', folders(1)), 'FileExtensions','.mat', 'ReadFcn', @loadMAT);
else
    imdsTest = imageDatastore(strcat(path, '/', splits(2), '/', folders(1)));
end
pxdsTest = pixelLabelDatastore(strcat(path, '/', splits(2), '/', folders(2)),classNames,labelIDs);

% 在测试映像上运行网络。预测标签作为pixelLabelDatastore返回。
tic
pxdsResults = semanticseg(imdsTest,net, 'MiniBatchSize',batchsize,"WriteLocation", savePredictionsFolder);
toc

% 计算混淆矩阵和分割度量(根据实际情况评估预测结果)
metrics = evaluateSemanticSegmentation(pxdsResults,pxdsTest)
metrics.ClassMetrics
metrics.NormalizedConfusionMatrix
metrics.ConfusionMatrix
metrics.DataSetMetrics
% 设置保存目录

mkdir(['ExpUNet/',expName]);
% logs 保存
save(['ExpUNet/',expName, '/results'])
% 保存评价指标
writetable(metrics.DataSetMetrics,['ExpUNet/',expName,'/dataset.csv'])
writetable(metrics.ClassMetrics, ['ExpUNet/',expName,'/classmetrics.csv'])
writetable(metrics.ConfusionMatrix, ['ExpUNet/',expName,'/confusionmatrix.csv'])
writetable(metrics.NormalizedConfusionMatrix, ['ExpUNet/',expName,'/normconfusionmatrix.csv'])
% 测试 6 张图像 结果保存
saveTestImages(net, imdsTest, pxdsTest,classNames, ['ExpUNet/',expName,'/ejemplos.png'], labelIDs, isMAT)

end


function final_matrix = loadMAT(filename)
    load(filename)
end

代码链接:https://download.csdn.net/download/qq_45047246/89565243

07-22 12:24