0702-计算机视觉工具包torchvision

pytorch完整教程目录:https://www.cnblogs.com/nickchen121/p/14662511.html

一、torchvision 概述

计算机视觉是深度学习中最重要的一类应用,为了方便研究者使用,torch 专门开发了一个视觉工具包 torchvision,这个包独立于 torch,需要使用 pip install torchvision 进行安装。

之前的我们已经使用过它的部分功能,在这里我们在做一个系统的介绍,它主要包含以下三个功能:

  • models:提供深度学习中各种经典网络的网络结构以及训练好的模型,包括 Alex-Net、VGG 系列、ResNet 系列、Inception 系列等
  • datasets:提供常用的数据集加载,设计上都是集成 torch.utils.data.Dataset,主要包括 MNIST、CIFAR10/100、ImageNet、COCO 等
  • transforms:提供常用的数据预处理操作,主要包括对 Tensor 以及 PIL Image 对象的操作

二、通过 torchvision 加载模型

from torchvision import models
from torch import nn

# 加载预训练好的模型,如果不存在会下载
# 预训练好的模型保存在 ~/.torch/modes/ 下面
resnet34 = models.resnet34(pretrained=True, num_classes=1000)

# 修改最后的全连接层为 10 分类问题(默认是 ImageNet 上的 1000 分类)
resnet34.fc = nn.Linear(512, 10)

三、通过 torchvision 加载并处理数据集

from torchvision import datasets
from torchvision import transforms as T
# 指定数据集路径为 data,如果数据集不存在则进行下载
# 通过 train=False 获取测试集

normalize = T.Normalize(mean=[0.4, 0.4, 0.4], std=[0.2, 0.2, 0.2])
transform = T.Compose([
    T.RandomResizedCrop(224),
    T.RandomHorizontalFlip(),
    T.ToTensor(),  # 把图片转成 Tensor,归一化至 [0,1]
    T.Lambda(lambda x: x.repeat(3, 1, 1)),  # 把图片转为 3 通道的
    normalize,
])

dataset = datasets.MNIST('data/',
                         download=True,
                         train=False,
                         transform=transform)

Transforms 中涵盖了大部分对 Tensor 和 PIL Image 的常用处理,这个转换通常分为两步:

  1. 第一步:构建转换操作,例如 transf = transforms.Normalize(mean=x, std=y)
  2. 第二步:执行转换操作,例如 otuput = transf(inp)
import torch as t

# 构建随机噪声,图片如下图所示
to_pil = T.ToPILImage()
to_pil(t.rand(3, 64, 64))

四、通过 torchvision 拼接并保存图片

torchvision 还提供了两个常用的函数:

  1. make_grid,它能把多张图片拼接在一个网格中
  2. save_img,它能把 Tensor 保存成图片
len(dataset)
10000
from torch.utils.data import DataLoader

dataloader = DataLoader(dataset, shuffle=True, batch_size=16)
from torchvision.utils import make_grid, save_image
dataiter = iter(dataloader)
dataiter
img = make_grid(next(dataiter)[0], 4)  # 拼接成 4*4 网格图片,并且会转成 3 通道,如下图所示
to_img = T.ToPILImage()
to_img(img)
save_image(img, 'a.png')
from PIL import Image
Image.open('a.png')
内容来源于网络如有侵权请私信删除