vlambda博客
学习文章列表

【图像识别】基于Matlab的迁移学习的图像分类案例

大家好
我们今天来讲一讲如何
用Matlab做一个新的迁移学习
您可能需要的基础知识
Matlab编程Deep learning的基础知识

一、什么是迁移学习?

以图像识别为例。如果你想构建一个神经网络,让它能够识别马匹,但是手上又没有任何公开的算法可以完成这项任务。这时,借助迁移学习,你可以从一个原本是用来识别其它动物的现成的卷积神经网络(CNN)入手,对其进行调整并训练它识别马匹。
深度学习应用中常常用到迁移学习。可以采用预训练的网络,基于它学习新任务。与使用随机初始化的权重从头训练网络相比,通过迁移学习微调网络要更快更简单。我们可以使用较少数量的训练图像快速地将已学习的特征迁移到新任务。

二、网络的创建和数据的导入

加载数据

解压缩新图像并加载这些图像作为图像数据存储。imageDatastore 根据文件夹名称自动标注图像,并将数据存储为 ImageDatastore 对象。通过图像数据存储可以存储大图像数据,包括无法放入内存的数据,并在卷积神经网络的训练过程中高效分批读取图像。
unzip('MerchData.zip');imds = imageDatastore('MerchData', ... 'IncludeSubfolders',true, ... 'LabelSource','foldernames');
将数据划分为 训练数据集和验证数据集。 70% 的图像用于训练, 30% 的图像用于验证。splitEachLabel 将 images 数据存储拆分为两个新的数据存储
[imdsTrain,imdsValidation] = splitEachLabel(imds,0.7,'randomized');

这个非常小的数据集现在包含 55 个训练图像和 20 个验证图像。

numTrainImages = numel(imdsTrain.Labels);idx = randperm(numTrainImages,16);figurefor i = 1:16 subplot(4,4,i) I = readimage(imdsTrain,idx(i)); imshow(I)end

【图像识别】基于Matlab的迁移学习的图像分类案例

加载预训练网络

加载预训练的 AlexNet 神经网络。如果未安装 Deep Learning Toolbox™ Model for AlexNet Network,则软件会提供下载链接。AlexNet 已基于超过一百万个图像进行训练,可以将图像分为 1000 个对象类别(例如键盘、鼠标、铅笔和多种动物)。因此,该模型已基于大量图像学习了丰富的特征表示。
net = alexnet;

使用 analyzeNetwork 可以交互可视方式呈现网络架构以及有关网络层的详细信息。

analyzeNetwork(net)

【图像识别】基于Matlab的迁移学习的图像分类案例

第一层(图像输入层)需要大小为 227×227×3 的输入图像

其中 3 是颜色通道数

inputSize = 1×3 227 227 3

三、网络的训练

替换最终层

预训练网络 net 的最后三层针对 1000 个类进行配置。必须针对新分类问题微调这三个层。从预训练网络中提取除最后三层之外的所有层。
layersTransfer = net.Layers(1:end-3);
通过将最后三层替换为全连接层、softmax 层和分类输出层,将层迁移到新分类任务。根据新数据指定新的全连接层的选项。将全连接层设置为大小与新数据中的类数相同。要使新层中的学习速度快于迁移的层,请增大全连接层的 WeightLearnRateFactor 和 BiasLearnRateFactor 值。
numClasses = numel(categories(imdsTrain.Labels))numClasses = 5
layers = [ layersTransfer fullyConnectedLayer(numClasses,'WeightLearnRateFactor',20,'BiasLearnRateFactor',20) softmaxLayer classificationLayer];

训练网络

网络要求输入图像的大小为 227×227×3,但图像数据存储中的图像具有不同大小。使用增强的图像数据存储可自动调整训练图像的大小。指定要对训练图像额外执行的增强操作:沿垂直轴随机翻转训练图像,以及在水平和垂直方向上随机平移训练图像最多 30 个像素。数据增强有助于防止网络过拟合和记忆训练图像的具体细节。
pixelRange = [-30 30];imageAugmenter = imageDataAugmenter( ... 'RandXReflection',true, ... 'RandXTranslation',pixelRange, ... 'RandYTranslation',pixelRange);augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain, ... 'DataAugmentation',imageAugmenter);

对验证图像进行分类

使用经过微调的网络对验证图像进行分类

[YPred,scores] = classify(netTransfer,augimdsValidation);

显示四个示例验证图像及预测的标签。

idx = randperm(numel(imdsValidation.Files),4);figurefor i = 1:4 subplot(2,2,i) I = readimage(imdsValidation,idx(i)); imshow(I) label = YPred(idx(i)); title(string(label));end

计算针对验证集的分类准确度。准确度是网络预测正确的标签的比例

YValidation = imdsValidation.Labels;accuracy = mean(YPred == YValidation)
accuracy = 1



今天你学废了吗???

推荐阅读: