Shortcuts

基础知识 || 快速入门 || 张量 || 数据集与数据加载器 || Transforms || 构建神经网络 || 自动微分 || 优化模型参数 || 保存和加载模型

Transforms

数据并不总是以训练机器学习算法所需的最终处理形式呈现。我们使用**transforms**来对数据进行一些处理,使其适用于训练。

所有 TorchVision 数据集都有两个参数 - transform 用于修改特征,target_transform 用于修改标签 - 它们接受包含转换逻辑的可调用对象。torchvision.transforms 模块提供了几种常用的转换。

FashionMNIST 的特征是以 PIL 图像格式呈现的,标签是整数。对于训练,我们需要将特征转换为归一化的张量, 将标签转换为编码的张量。为了进行这些转换,我们使用了 ToTensorLambda

import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda

ds = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
    target_transform=Lambda(lambda y: torch.zeros(
        10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
)

ToTensor()

ToTensor 将 PIL 图像或 NumPy ndarray 转换为 FloatTensor,并将图像的像素强度值缩放到范围 [0., 1.]。

Lambda Transforms

Lambda transforms 应用任何用户定义的 lambda 函数。这里,我们定义一个函数将整数转换为独热编码的张量。 它首先创建一个大小为 10(我们数据集中标签的数量)的零张量,然后调用 scatter_, 在由标签 y 指定的索引上赋值为 1

target_transform = Lambda(lambda y: torch.zeros(
    10, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1))

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources