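To add data augmentation for the MNIST dataset, we can use PyTorch's `torchvision.transforms` module. In this example we define a transform pipeline (`transforms.Compose`) that chains a few common augmentations: random rotation (`RandomRotation`), random shear and scaling (`RandomAffine`), and random resized cropping (`RandomResizedCrop`). Because MNIST images are grayscale, color-space augmentations are not applicable. The following example adds augmentation in a subclass of `MNISTDataset`: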
```python
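import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
from sklearn.model_selection import train_test_split
import torchvision.transforms as transforms
from PIL import Image

# Assumes the MNISTDataset class defined above, which stores its inputs
# as self.data_type, self.image and self.label.

# Augmentation pipeline, applied to the training split only
train_transforms = transforms.Compose([
    transforms.RandomRotation(10),  # random rotation within +/-10 degrees
    transforms.RandomAffine(0, shear=10, scale=(0.8, 1.2)),  # random shear and scaling
    # The default scale=(0.08, 1.0) can crop away most of a digit,
    # so keep at least 80% of the image area before resizing to 28x28.
    transforms.RandomResizedCrop(28, scale=(0.8, 1.0)),
    transforms.ToTensor(),  # PIL image -> float tensor in [0, 1]
    transforms.Normalize((0.5,), (0.5,)),  # map [0, 1] to [-1, 1]
])

# Deterministic pipeline for validation and test: same tensor conversion
# and normalization, but no random augmentation.
eval_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])


class EnhancedMNISTDataset(MNISTDataset):
    def __init__(self, data_type, data=None, labels=None):
        super().__init__(data_type, data, labels)
        # Augment only while training, so validation metrics stay comparable
        # across epochs and test samples are still returned as tensors.
        if self.data_type == 'train':
            self.transform = train_transforms
        else:
            self.transform = eval_transforms

    def __getitem__(self, index):
        # Pixels are expected as raw 0-255 values; mode 'L' is 8-bit grayscale
        image = Image.fromarray(self.image[index].astype('uint8'), mode='L')
        image = self.transform(image)
        if self.data_type in ['train', 'val']:
            # One-hot encode the label
            label = torch.zeros(10, dtype=torch.float)
            label[int(self.label[index])] = 1.0
            return image, label
        return image


# Load the training data
df_train = pd.read_csv("./data/train.csv")
# Split features and labels
train_labels = df_train['label'].to_numpy()
train_data = df_train.drop(columns='label').to_numpy()
# Keep the raw 0-255 pixels: ToTensor already scales them to [0, 1].
# Dividing by 255 here would make the uint8 cast in __getitem__
# truncate almost every pixel to 0.
train_data = train_data.reshape(-1, 28, 28).astype('uint8')
# Split into train and val with train_test_split
X_train, X_val, y_train, y_val = train_test_split(
    train_data, train_labels, test_size=0.2, random_state=42
)
# Create the training and validation datasets
train_dataset = EnhancedMNISTDataset("train", X_train, y_train)
val_dataset = EnhancedMNISTDataset("val", X_val, y_val)
# Create the DataLoaders
train_dataloader = DataLoader(train_dataset, batch_size=100, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=100, shuffle=False)
# Test set (no labels)
test_data = pd.read_csv("./data/test.csv").to_numpy().reshape(-1, 28, 28).astype('uint8')
test_dataset = EnhancedMNISTDataset("test", test_data)
test_dataloader = DataLoader(test_dataset, batch_size=100, shuffle=False)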
```
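Note that because MNIST images are grayscale, `Image.fromarray` is called with mode `'L'`, which stands for 8-bit grayscale. The augmentation is applied inside `__getitem__`, so a new random transform is drawn every time a sample is fetched. The model therefore sees a slightly different version of each image in every epoch, which helps it generalize.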