-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
105 lines (87 loc) · 3.6 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import timm
import random
import numpy as np
# One model is trained per entry; each value is used as the RNG seed of that
# run and also appears in the name of the checkpoint file.
model_num = [4, 8, 12, 16, 20, 24, 28, 32, 36]
total_epoch = 100  # total epoch
lr = 0.001  # initial learning rate


def seed_everything(seed_number):
    """Seed the Python, NumPy and PyTorch RNGs and pin the cuDNN flags.

    Args:
        seed_number: Integer seed applied to ``random``, ``numpy.random`` and
            ``torch``.
    """
    random.seed(seed_number)
    np.random.seed(seed_number)
    torch.manual_seed(seed_number)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def build_dataloaders():
    """Create the CIFAR-10 train and test data loaders.

    The dataset is stored under ``./data`` and downloaded when missing.

    Returns:
        A ``(trainloader, testloader)`` tuple. The train loader shuffles and
        uses batches of 128; the test loader keeps order and uses batches
        of 100.
    """
    # Normalisation constants shared by the train and test pipelines.
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)

    # Training pipeline: resize, random 224x224 crop and TrivialAugmentWide.
    transform_train = transforms.Compose([
        transforms.Resize(256),
        transforms.RandomCrop(224),
        transforms.TrivialAugmentWide(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    # Test pipeline: resize and center crop only, no random augmentation.
    transform_test = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    # Load the CIFAR-10 dataset
    trainset = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True, transform=transform_train)
    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=128, shuffle=True, num_workers=16)
    testset = torchvision.datasets.CIFAR10(
        root='./data', train=False, download=True, transform=transform_test)
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=100, shuffle=False, num_workers=16)
    return trainloader, testloader


def train_one_epoch(model, loader, criterion, optimizer, device, epoch,
                    log_interval=100):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Network to optimise; switched to training mode.
        loader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Loss callable applied to ``(outputs, labels)``.
        optimizer: Optimizer that owns the parameters of ``model``.
        device: Device the batches are moved to before the forward pass.
        epoch: Zero-based epoch index, only used in the progress message.
        log_interval: Number of batches averaged per progress message.
    """
    model.train()
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(loader):
        # Move the input data to the same device as the model
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        # Report the mean loss of the last `log_interval` batches.
        if i % log_interval == log_interval - 1:
            print('[%d, %5d] loss: %.3f'
                  % (epoch + 1, i + 1, running_loss / log_interval))
            running_loss = 0.0


def evaluate(model, loader, device):
    """Print and return the top-1 accuracy of ``model`` on ``loader``.

    Args:
        model: Network to evaluate; switched to evaluation mode.
        loader: Iterable yielding ``(images, labels)`` batches.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Accuracy in percent (0-100) as a float.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            # The predicted class is the index of the largest output.
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total
    print('Accuracy of the network on the %d test images: %f %%'
          % (total, accuracy))
    return accuracy


def run_seed(seed_number):
    """Fine-tune one ResNet-18 on CIFAR-10 with ``seed_number`` and save it.

    The final weights are written to
    ``./resnet18_cifar10_<lr>_<seed_number>.pth``.

    Args:
        seed_number: RNG seed of this run; also part of the checkpoint name.
    """
    # fix random seed
    seed_everything(seed_number)

    # Check if GPU is available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)

    trainloader, testloader = build_dataloaders()

    # Define the ResNet-18 model with pre-trained weights
    model = timm.create_model('resnet18', pretrained=True, num_classes=10)
    model = model.to(device)

    # Define the loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adamax(model.parameters(), lr=lr)
    # The learning rate is multiplied by 0.1 every 40 epochs.
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=40, gamma=0.1)

    # Train the model, evaluating on the test split after every epoch
    for epoch in range(total_epoch):
        train_one_epoch(model, trainloader, criterion, optimizer, device, epoch)
        evaluate(model, testloader, device)
        scheduler.step()
    print('Finished Training')

    # Save the checkpoint of the last model
    PATH = './resnet18_cifar10_%f_%d.pth' % (lr, seed_number)
    torch.save(model.state_dict(), PATH)


def main():
    """Train and save one model for every seed listed in ``model_num``."""
    for s in model_num:
        run_seed(s)


# The data loaders start worker processes (num_workers=16). Where workers are
# started by re-importing this module, the guard keeps them from launching the
# whole training again.
if __name__ == "__main__":
    main()