Note
Click here to download the full example code
基础知识 || 快速入门 || 张量 || 数据集与数据加载器 || Transforms || 构建神经网络 || 自动微分 || 优化模型参数 || 保存和加载模型
Transforms¶
数据并不总是以训练机器学习算法所需的最终处理形式呈现。我们使用**transforms**来对数据进行一些处理,使其适用于训练。
所有 TorchVision 数据集都有两个参数 - transform 用于修改特征,target_transform 用于修改标签 - 它们接受包含转换逻辑的可调用对象。torchvision.transforms 模块提供了几种常用的转换。
FashionMNIST 的特征是以 PIL 图像格式呈现的,标签是整数。对于训练,我们需要将特征转换为归一化的张量,
将标签转换为编码的张量。为了进行这些转换,我们使用了 ToTensor
和 Lambda
。
import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
ds = datasets.FashionMNIST(
root="data",
train=True,
download=True,
transform=ToTensor(),
target_transform=Lambda(lambda y: torch.zeros(
10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
)
Lambda Transforms¶
Lambda transforms 应用任何用户定义的 lambda 函数。这里,我们定义一个函数将整数转换为独热编码的张量。
它首先创建一个大小为 10(我们数据集中标签的数量)的零张量,然后调用 scatter_,
在由标签 y
指定的索引上赋值为 1
。
target_transform = Lambda(lambda y: torch.zeros(
10, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1))
延伸阅读¶
Total running time of the script: ( 0 minutes 0.000 seconds)