Quick Start
This section walks through the APIs for common tasks in ML; for more depth on each topic, see the linked references.
Working with Data
PyTorch has two primitives for working with data: torch.utils.data.DataLoader and torch.utils.data.Dataset. Dataset stores the samples and their corresponding labels, and DataLoader wraps an iterable around the Dataset.
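To make the Dataset/DataLoader split concrete, here is a minimal sketch (not part of the quickstart code; the class name and random tensors below are made up for illustration): a custom Dataset only has to implement __len__ and __getitem__, and DataLoader takes care of batching and iteration.

import torch
from torch.utils.data import Dataset, DataLoader

# Hypothetical toy dataset, for illustration only.
class RandomVectorDataset(Dataset):
    def __init__(self, n_samples=256, n_features=8):
        self.features = torch.randn(n_samples, n_features)   # fake inputs
        self.labels = torch.randint(0, 2, (n_samples,))      # fake binary labels

    def __len__(self):
        # Number of samples in the dataset.
        return len(self.features)

    def __getitem__(self, idx):
        # Return one (sample, label) pair.
        return self.features[idx], self.labels[idx]

# DataLoader wraps the Dataset and yields mini-batches.
loader = DataLoader(RandomVectorDataset(), batch_size=32, shuffle=True)
xb, yb = next(iter(loader))
print(xb.shape, yb.shape)  # torch.Size([32, 8]) torch.Size([32])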
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
PyTorch offers domain-specific libraries such as TorchText, TorchVision, and TorchAudio, all of which include datasets. In this tutorial we will use a TorchVision dataset.
The torchvision.datasets module contains Dataset objects for many real-world vision datasets such as CIFAR and COCO (see the torchvision documentation for the full list). In this tutorial we use the FashionMNIST dataset. Every TorchVision Dataset accepts two arguments, transform and target_transform, which modify the samples and the labels respectively.
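The quickstart code below only passes transform=ToTensor(). As a hedged sketch of what target_transform can look like, the Lambda below one-hot encodes the integer class label; this snippet is my own illustration, not part of the quickstart code.

import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda

# Illustrative only: transform turns the PIL image into a [0, 1] float tensor,
# target_transform one-hot encodes the integer label into a length-10 vector.
ds = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
    target_transform=Lambda(
        lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1)
    ),
)
img, label = ds[0]
print(img.shape, label)  # torch.Size([1, 28, 28]) and a one-hot tensor of length 10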
# Download training data from open datasets.
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)

# Download test data from open datasets.
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)
Out:
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz
100%|##########| 26421880/26421880 [00:01<00:00, 16023208.49it/s]
Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
100%|##########| 29515/29515 [00:00<00:00, 270949.40it/s]
Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
100%|##########| 4422102/4422102 [00:00<00:00, 5082787.38it/s]
Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
100%|##########| 5148/5148 [00:00<00:00, 20486031.30it/s]
Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw
We pass the Dataset as an argument to DataLoader. This wraps an iterable over our dataset and supports automatic batching, sampling, shuffling, and multi-process data loading. Here we define a batch size of 64, i.e. each element in the dataloader iterable returns a batch of 64 features and labels.
batch_size = 64
# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)
for X, y in test_dataloader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break
Out:
Shape of X [N, C, H, W]: torch.Size([64, 1, 28, 28])
Shape of y: torch.Size([64]) torch.int64
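The loaders above rely on the defaults; the shuffling and multi-process loading mentioned earlier are opt-in. As a sketch (the specific values are illustrative), they are enabled via the shuffle and num_workers arguments, reusing training_data and batch_size from above.

# Illustrative values: reshuffle the training data each epoch and
# load batches in two background worker processes.
train_dataloader = DataLoader(
    training_data,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)

for X, y in train_dataloader:
    # X: [64, 1, 28, 28] image batch, y: [64] integer class labels
    print(X.shape, y.dtype)
    break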