在 PyTorch 中使用 ResNet 进行迁移学习
介绍
要使用深度学习解决复杂的图像分析问题,网络深度(堆叠数百层)对于从训练数据中提取关键特征并学习有意义的模式非常重要。但是,由于梯度的存在,添加神经层在计算上可能很昂贵,而且存在问题。在本指南中,您将了解深度神经网络的问题、ResNet 如何提供帮助以及如何在迁移学习中使用 ResNet。
重要提示:我强烈建议您在进一步阅读有关 ResNet 和迁移学习的内容之前先了解 CNN 的基础知识。阅读本使用 PyTorch 进行图像分类指南,了解 CNN 的详细描述。
问题
让我们看看残差网络(ResNet)如何拉平曲线。
残差网络
残差网络(简称 ResNet)是一种人工神经网络,它利用跳过连接或捷径跳过某些层来帮助构建更深的神经网络。您将看到跳过如何帮助构建更深的网络层,而不会陷入梯度消失的问题。
ResNet 有不同的版本,包括 ResNet-18、ResNet-34、ResNet-50 等。虽然架构相同,但数字表示层数。
要创建残差块,请在普通神经网络中添加指向主路径的快捷方式,如下图所示。
通过上面的数学计算,我们可以得出结论:
- 残差网络的恒等函数更容易学习
- 最好跳过 1、2 和 3 层。身份函数将与输出函数很好地映射,而不会损害 NN 性能。它将确保较高层的性能与较低层一样好。
ResNet 块
ResNet 中使用的块主要有两种类型,主要取决于输入和输出维度是否相同或不同。
- 身份块:当输入和输出激活维度相同时。
- 卷积块:当输入和输出激活维度彼此不同时。
例如,为了将激活维度 (HxW) 减少 2 倍,可以使用步幅为 2 的 1x1 卷积。
下图展示了残差块的外观以及这些块里面的内容。
数据准备
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import torchvision
from torchvision import *
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import time
import copy
import os
batch_size = 128
learning_rate = 1e-3
transforms = transforms.Compose(
[
transforms.ToTensor()
])
train_dataset = datasets.ImageFolder(root='/input/fruits-360-dataset/fruits-360/Training', transform=transforms)
test_dataset = datasets.ImageFolder(root='/input/fruits-360-dataset/fruits-360/Test', transform=transforms)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
def imshow(inp, title=None):
inp = inp.cpu() if device else inp
inp = inp.numpy().transpose((1, 2, 0))
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
inp = std * inp + mean
inp = np.clip(inp, 0, 1)
plt.imshow(inp)
if title is not None:
plt.title(title)
plt.pause(0.001)
images, labels = next(iter(train_dataloader))
print("images-size:", images.shape)
out = torchvision.utils.make_grid(images)
print("out-size:", out.shape)
imshow(out, title=[train_dataset.classes[x] for x in labels])
使用 Pytorch 进行迁移学习
迁移学习 (TL)的主要目的是快速实现模型。为了解决当前问题,模型不会从头开始创建 DNN(密集神经网络),而是会迁移从执行相同任务的不同数据集中学到的特征。这种迁移也称为知识迁移。
Pytorch API使用models.resnet18(pretrained=True) (来自 TorchVision模型库的函数)调用预先训练的ResNet18模型。ResNet-18 架构如下所述。
net = models.resnet18(pretrained=True)
net = net.cuda() if device else net
net
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.0001, momentum=0.9)
def accuracy(out, labels):
_,pred = torch.max(out, dim=1)
return torch.sum(pred==labels).item()
num_ftrs = net.fc.in_features
net.fc = nn.Linear(num_ftrs, 128)
net.fc = net.fc.cuda() if use_cuda else net.fc
最后,添加一个全连接层进行分类,指定类别和特征数量(FC 128)。
n_epochs = 5
print_every = 10
valid_loss_min = np.Inf
val_loss = []
val_acc = []
train_loss = []
train_acc = []
total_step = len(train_dataloader)
for epoch in range(1, n_epochs+1):
running_loss = 0.0
correct = 0
total=0
print(f'Epoch {epoch}\n')
for batch_idx, (data_, target_) in enumerate(train_dataloader):
data_, target_ = data_.to(device), target_.to(device)
optimizer.zero_grad()
outputs = net(data_)
loss = criterion(outputs, target_)
loss.backward()
optimizer.step()
running_loss += loss.item()
_,pred = torch.max(outputs, dim=1)
correct += torch.sum(pred==target_).item()
total += target_.size(0)
if (batch_idx) % 20 == 0:
print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
.format(epoch, n_epochs, batch_idx, total_step, loss.item()))
train_acc.append(100 * correct / total)
train_loss.append(running_loss/total_step)
print(f'\ntrain-loss: {np.mean(train_loss):.4f}, train-acc: {(100 * correct/total):.4f}')
batch_loss = 0
total_t=0
correct_t=0
with torch.no_grad():
net.eval()
for data_t, target_t in (test_dataloader):
data_t, target_t = data_t.to(device), target_t.to(device)
outputs_t = net(data_t)
loss_t = criterion(outputs_t, target_t)
batch_loss += loss_t.item()
_,pred_t = torch.max(outputs_t, dim=1)
correct_t += torch.sum(pred_t==target_t).item()
total_t += target_t.size(0)
val_acc.append(100 * correct_t/total_t)
val_loss.append(batch_loss/len(test_dataloader))
network_learned = batch_loss < valid_loss_min
print(f'validation loss: {np.mean(val_loss):.4f}, validation acc: {(100 * correct_t/total_t):.4f}\n')
if network_learned:
valid_loss_min = batch_loss
torch.save(net.state_dict(), 'resnet.pt')
print('Improvement-Detected, save-model')
net.train()
如果增加时期数,准确度将进一步提高。
fig = plt.figure(figsize=(20,10))
plt.title("Train-Validation Accuracy")
plt.plot(train_acc, label='train')
plt.plot(val_acc, label='validation')
plt.xlabel('num_epochs', fontsize=12)
plt.ylabel('accuracy', fontsize=12)
plt.legend(loc='best')
def visualize_model(net, num_images=4):
images_so_far = 0
fig = plt.figure(figsize=(15, 10))
for i, data in enumerate(test_dataloader):
inputs, labels = data
if use_cuda:
inputs, labels = inputs.cuda(), labels.cuda()
outputs = net(inputs)
_, preds = torch.max(outputs.data, 1)
preds = preds.cpu().numpy() if use_cuda else preds.numpy()
for j in range(inputs.size()[0]):
images_so_far += 1
ax = plt.subplot(2, num_images//2, images_so_far)
ax.axis('off')
ax.set_title('predictes: {}'.format(test_dataset.classes[preds[j]]))
imshow(inputs[j])
if images_so_far == num_images:
return
plt.ion()
visualize_model(net)
plt.ioff()
结论
该模型的准确率为 97%,非常棒,它可以正确地预测水果。
本指南简要概述了深度神经网络面临的问题、ResNet 如何帮助克服这一问题,以及如何在迁移学习中使用 ResNet 来加速 CNN 的开发。我强烈建议您通过上述资源、执行 EDA 和更好地了解您的数据来了解更多信息。尝试通过冻结和解冻层、增加 ResNet 层数和调整学习率来自定义模型。阅读此帖子以了解更多数学背景知识。如果您还有任何问题,请随时通过CodeAlphabet与我联系。
迁移学习通过将知识迁移到新任务来适应新领域。ResNet 的概念正在创造新的研究角度,使解决现实问题日益高效。
免责声明:本内容来源于第三方作者授权、网友推荐或互联网整理,旨在为广大用户提供学习与参考之用。所有文本和图片版权归原创网站或作者本人所有,其观点并不代表本站立场。如有任何版权侵犯或转载不当之情况,请与我们取得联系,我们将尽快进行相关处理与修改。感谢您的理解与支持!
请先 登录后发表评论 ~