|
7 | 7 | from model import AlexNet
|
8 | 8 | import os
|
9 | 9 | import json
|
| 10 | +import time |
| 11 | + |
| 12 | +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") |
10 | 13 |
|
11 | 14 | data_transform = {
|
12 | 15 | "train": transforms.Compose([transforms.RandomResizedCrop(224),
|
|
31 | 34 | with open('class_indices.json', 'w') as json_file:
|
32 | 35 | json_file.write(json_str)
|
33 | 36 |
|
34 |
| -batch_size = 32 |
| 37 | +batch_size = 64 |
35 | 38 | train_loader = torch.utils.data.DataLoader(train_dataset,
|
36 | 39 | batch_size=batch_size, shuffle=True,
|
37 |
| - num_workers=0) |
| 40 | + num_workers=16) |
38 | 41 |
|
39 | 42 | validate_dataset = datasets.ImageFolder(root=image_path + "/val",
|
40 | 43 | transform=data_transform["val"])
|
41 | 44 | val_num = len(validate_dataset)
|
42 | 45 | validate_loader = torch.utils.data.DataLoader(validate_dataset,
|
43 | 46 | batch_size=batch_size, shuffle=False,
|
44 |
| - num_workers=0) |
| 47 | + num_workers=16) |
45 | 48 |
|
46 | 49 | # test_data_iter = iter(validate_loader)
|
47 | 50 | # test_image, test_label = test_data_iter.next()
|
|
54 | 57 |
|
55 | 58 |
|
56 | 59 | net = AlexNet(num_classes=5, init_weights=True)
|
| 60 | + |
| 61 | +net.to(device) |
57 | 62 | loss_function = nn.CrossEntropyLoss()
|
58 | 63 | pata = list(net.parameters())
|
59 |
| -optimizer = optim.Adam(net.parameters(), lr=0.0005) |
| 64 | +optimizer = optim.Adam(net.parameters(), lr=0.0002) |
60 | 65 |
|
61 |
| -for epoch in range(10): |
| 66 | +save_path = './AlexNet.pth' |
| 67 | +best_acc = 0.0 |
| 68 | +for epoch in range(15): |
62 | 69 | # train
|
63 | 70 | net.train()
|
64 | 71 | running_loss = 0.0
|
| 72 | + t1 = time.perf_counter() |
65 | 73 | for step, data in enumerate(train_loader, start=0):
|
66 | 74 | images, labels = data
|
67 | 75 | # imshow(torchvision.utils.make_grid(images))
|
68 | 76 | # print(' '.join('%5s' % flower_set[labels[j]] for j in range(8)))
|
69 | 77 | optimizer.zero_grad()
|
70 |
| - outputs = net(images) |
71 |
| - loss = loss_function(outputs, labels) |
| 78 | + outputs = net(images.to(device)) |
| 79 | + loss = loss_function(outputs, labels.to(device)) |
72 | 80 | loss.backward()
|
73 | 81 | optimizer.step()
|
74 | 82 |
|
75 | 83 | # print statistics
|
76 | 84 | running_loss += loss.item()
|
| 85 | + # print train process |
| 86 | + rate = (step + 1) / len(train_loader) |
| 87 | + a = "*" * int(rate * 50) |
| 88 | + b = "." * int((1 - rate) * 50) |
| 89 | + print("\rtrain loss: {:^3.0f}%[{}->{}]{:.3f}".format(int(rate * 100), a, b, loss), end="") |
| 90 | + print() |
| 91 | + print(time.perf_counter()-t1) |
77 | 92 |
|
78 | 93 | # validate
|
79 | 94 | net.eval()
|
80 | 95 | acc = 0.0 # accumulate accurate number / epoch
|
81 | 96 | with torch.no_grad():
|
82 | 97 | for data_test in validate_loader:
|
83 | 98 | test_images, test_labels = data_test
|
84 |
| - outputs = net(test_images) |
| 99 | + outputs = net(test_images.to(device)) |
85 | 100 | predict_y = torch.max(outputs, dim=1)[1]
|
86 |
| - acc += (predict_y == test_labels).sum().item() |
| 101 | + acc += (predict_y == test_labels.to(device)).sum().item() |
| 102 | + accurate_test = acc / val_num |
| 103 | + if accurate_test > best_acc: |
| 104 | + best_acc = accurate_test |
| 105 | + torch.save(net.state_dict(), save_path) |
87 | 106 | print('[epoch %d] train_loss: %.3f test_accuracy: %.3f' %
|
88 | 107 | (epoch + 1, running_loss / step, acc / val_num))
|
89 | 108 |
|
90 |
| - |
91 | 109 | print('Finished Training')
|
92 |
| -save_path = './AlexNet.pth' |
93 |
| -torch.save(net.state_dict(), save_path) |
|
0 commit comments