用CIFAR100数据集来训练图像分类

最近在学习如何进行图像分类和识别,比如给一张狗的图片,系统能够准确识别

查下来目前初学者用的最多的是CIFAR10和CIFAR100,

CIFAR100是一个在线数据集,包含了100个分类,每个分类600张图片,供咱们训练使用

接下来直接上代码

代码

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision import models

# 定义数据预处理步骤
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),  # 随机裁剪
    transforms.RandomHorizontalFlip(),     # 随机水平翻转
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),  # 颜色增强
    transforms.RandomRotation(15),         # 随机旋转
    transforms.ToTensor(),
    transforms.Normalize((0.5071, 0.4865, 0.4409), (0.2673, 0.2564, 0.2761))  # 归一化
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5071, 0.4865, 0.4409), (0.2673, 0.2564, 0.2761))
])

# 加载 CIFAR-100 数据集
trainset = torchvision.datasets.CIFAR100(root='./data100', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=4)

testset = torchvision.datasets.CIFAR100(root='./data100', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=4)

# 使用预训练的 ResNet50 模型并调整第一层卷积层和删除池化层
model = models.resnet50(pretrained=False)
model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)  # 适合小图像
model.maxpool = nn.Identity()  # 删除池化层
model.fc = nn.Linear(2048, 100)  # 修改最后一层用于 CIFAR-100 分类

# 定义设备(GPU / CPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)

# 学习率调度器
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

# 训练函数
def train(epoch):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        
        # 应用梯度剪裁,防止梯度爆炸
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5)
        
        optimizer.step()

        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        if batch_idx % 100 == 0:
            print(f'Epoch [{epoch}], Step [{batch_idx}/{len(trainloader)}], Loss: {running_loss / (batch_idx + 1):.3f}, Accuracy: {100. * correct / total:.3f}%')

# 测试函数
def test():
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in testloader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    print(f'Test Accuracy: {100. * correct / total:.3f}%')

# 训练和测试模型
epochs = 70  # 增加训练 epoch
for epoch in range(epochs):
    train(epoch)
    test()
    scheduler.step()  # 更新学习率

# 保存模型
torch.save(model.state_dict(), 'cifar100.pth')

用此代码训练,准确率可达70%,我是RTX3060显卡,训练了2个小时

训练好的模型保存为cifar100.pth,接下来我们进行测试一下

测试代码如下

import torch
import torchvision.transforms as transforms
from PIL import Image
import os
from torchvision import models

# 定义设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 加载训练好的模型
model = models.resnet50(pretrained=False)
model.fc = torch.nn.Linear(2048, 100)  # CIFAR-100 具有 100 个分类
model.load_state_dict(torch.load('cifar100.pth'))  # 加载模型权重
model = model.to(device)
model.eval()  # 设置模型为评估模式

# 定义预处理函数,和训练时的一致
preprocess = transforms.Compose([
    transforms.Resize((32, 32)),  # 将图片调整为与 CIFAR-100 数据集的大小一致
    transforms.ToTensor(),
    transforms.Normalize((0.5071, 0.4865, 0.4409), (0.2673, 0.2564, 0.2761))
])

# 类别标签(CIFAR-100 的类别)
classes = ['apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle', 'bicycle', 'bottle', 
           'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel', 'can', 'castle', 'caterpillar', 'cattle', 
           'chair', 'chimpanzee', 'clock', 'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur', 
           'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster', 'house', 'kangaroo', 'keyboard', 'lamp', 
           'lawn_mower', 'leopard', 'lion', 'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse', 
           'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear', 'pickup_truck', 'pine_tree', 'plain', 
           'plate', 'poppy', 'porcupine', 'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose', 'sea', 'seal', 
           'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake', 'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 
           'table', 'tank', 'telephone', 'television', 'tiger', 'tractor', 'train', 'trout', 'tulip', 'turtle', 'wardrobe', 
           'whale', 'willow_tree', 'wolf', 'woman', 'worm']

# 预测函数
def predict_image(image_path):
    image = Image.open(image_path)  # 打开图片
    image = preprocess(image).unsqueeze(0).to(device)  # 预处理并增加批次维度

    with torch.no_grad():  # 关闭梯度计算
        outputs = model(image)  # 进行前向传播
        _, predicted = outputs.max(1)  # 获取预测的类别索引
    
    return classes[predicted.item()]  # 返回对应的类别名称

# 测试图片目录
image_dir = './test/'

# 遍历目录下的所有图片并进行预测
for img_file in os.listdir(image_dir):
    if img_file.endswith(('jpg', 'jpeg', 'png')):
        img_path = os.path.join(image_dir, img_file)
        prediction = predict_image(img_path)
        print(f'Image: {img_file} is predicted as {prediction}')

测试图片文件夹为test,里面的图片文件为test1.png,test2.png,test3.png…

只要图片在这100个类别中,基本都可以识别,识别准确率也很高