Transfer Learning tutorial
http://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html
Author: Sasank Chilamkurthy
In this tutorial, you will learn how to train a network using transfer learning. You can read more about transfer learning in the cs231n notes.
Quoting those notes:
In practice, very few people train a convolutional network from scratch, because it is relatively rare to have a dataset of sufficient size. Instead, it is common to pretrain a ConvNet on a large dataset such as ImageNet, and then use that ConvNet either as an initialization or as a fixed feature extractor for the task of interest.
These two major transfer learning scenarios look as follows:
- Finetuning the convnet: instead of random initialization, we initialize the network with a pretrained one, such as a network trained on the ImageNet 1000 dataset, and then train as usual.
- ConvNet as fixed feature extractor: we freeze the weights of the whole network except the final fully connected layer, replace that last layer with a new one initialized with random weights, and train only this layer. (A brief sketch contrasting the two scenarios follows this list.)
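To make the difference concrete before walking through the full code, here is a minimal sketch contrasting the two scenarios. The complete, runnable versions appear later in this tutorial; model and optimizer here are just illustrative names.

import torch.nn as nn
import torch.optim as optim
from torchvision import models

# Scenario 1: finetuning -- start from pretrained weights, optimize everything
model = models.resnet18(pretrained=True)
model.fc = nn.Linear(model.fc.in_features, 2)  # new head for our 2 classes
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Scenario 2: fixed feature extractor -- freeze the backbone and
# optimize only the newly constructed final layer
model = models.resnet18(pretrained=True)
for param in model.parameters():
    param.requires_grad = False
model.fc = nn.Linear(model.fc.in_features, 2)  # requires_grad=True by default
optimizer = optim.SGD(model.fc.parameters(), lr=0.001, momentum=0.9)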
# License: BSD
# Author: Sasank Chilamkurthy

from __future__ import print_function, division

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy

plt.ion()   # interactive mode
Load Data
Let's load the data using the torchvision and torch.utils.data packages.
The problem we are solving here is training a model to classify ants and bees. We have about 120 training images each for ants and bees, and 75 validation images for each class. Usually this is a very small dataset to generalize on when training from scratch; since we are using transfer learning, however, we should be able to generalize reasonably well.
This dataset is a small subset of imagenet.
Note: download the data from here and extract it into the current directory.
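For reference, datasets.ImageFolder (used in the code below) derives the class labels from subdirectory names, so after extracting the archive the layout should look roughly like this (a sketch; actual file names will differ):

hymenoptera_data/
    train/
        ants/
            xxx.jpg
            ...
        bees/
            ...
    val/
        ants/
            ...
        bees/
            ...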
# Data augmentation and normalization for training
# Just normalization for validation
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomSizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        # ImageNet channel means and standard deviations
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Scale(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

data_dir = 'hymenoptera_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                              shuffle=True, num_workers=4)
               for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes

use_gpu = torch.cuda.is_available()
Visualize a few images
Let's visualize a few training images so as to understand the data augmentation.
def imshow(inp, title=None):
    """Imshow for Tensor."""
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated


# Get a batch of training data
inputs, classes = next(iter(dataloaders['train']))

# Make a grid from batch
out = torchvision.utils.make_grid(inputs)

imshow(out, title=[class_names[x] for x in classes])
Training the model
Now, let's write a general function to train a model. Here we will illustrate:
- Scheduling the learning rate
- Saving the best model
In the code below, the scheduler parameter is an LR scheduler object from torch.optim.lr_scheduler.
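As a quick illustration (not from the original tutorial) of what the StepLR scheduler used below actually does: it multiplies the learning rate by gamma every step_size epochs.

# Hypothetical standalone illustration of StepLR(step_size=7, gamma=0.1)
# with a base learning rate of 0.001, over the 25 epochs used below:
base_lr, step_size, gamma = 0.001, 7, 0.1
for epoch in range(25):
    # epochs 0-6 -> 0.001, 7-13 -> 0.0001, 14-20 -> 1e-05, 21-24 -> 1e-06
    print(epoch, base_lr * gamma ** (epoch // step_size))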
def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                scheduler.step()
                model.train(True)  # Set model to training mode
            else:
                model.train(False)  # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            for data in dataloaders[phase]:
                # get the inputs
                inputs, labels = data

                # wrap them in Variable
                if use_gpu:
                    inputs = Variable(inputs.cuda())
                    labels = Variable(labels.cuda())
                else:
                    inputs, labels = Variable(inputs), Variable(labels)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                outputs = model(inputs)
                _, preds = torch.max(outputs.data, 1)
                loss = criterion(outputs, labels)

                # backward + optimize only if in training phase
                if phase == 'train':
                    loss.backward()
                    optimizer.step()

                # statistics
                running_loss += loss.data[0] * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model
Visualizing the model predictions
Let's write a function to display predictions for a few images.
def visualize_model(model, num_images=6):
    images_so_far = 0
    fig = plt.figure()

    for i, data in enumerate(dataloaders['val']):
        inputs, labels = data
        if use_gpu:
            inputs, labels = Variable(inputs.cuda()), Variable(labels.cuda())
        else:
            inputs, labels = Variable(inputs), Variable(labels)

        outputs = model(inputs)
        _, preds = torch.max(outputs.data, 1)

        for j in range(inputs.size()[0]):
            images_so_far += 1
            ax = plt.subplot(num_images//2, 2, images_so_far)
            ax.axis('off')
            ax.set_title('predicted: {}'.format(class_names[preds[j]]))
            imshow(inputs.cpu().data[j])

            if images_so_far == num_images:
                return
Finetuning the convnet
Load a pretrained model and reset the final fully connected layer.
model_ft = models.resnet18(pretrained=True)
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, 2)

if use_gpu:
    model_ft = model_ft.cuda()

criterion = nn.CrossEntropyLoss()

# Observe that all parameters are being optimized
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)

# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
Train and evaluate
This should take around 15-25 minutes on CPU. On GPU it takes less than a minute.
model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler,
                       num_epochs=25)
Out:
Epoch 0/24
----------
train Loss: 0.6603 Acc: 0.6393
val Loss: 0.2358 Acc: 0.9020

Epoch 1/24
----------
train Loss: 0.5158 Acc: 0.7623
val Loss: 0.2329 Acc: 0.9020

Epoch 2/24
----------
train Loss: 0.4975 Acc: 0.7951
val Loss: 0.1763 Acc: 0.9281

Epoch 3/24
----------
train Loss: 0.4227 Acc: 0.8197
val Loss: 0.3499 Acc: 0.8562

Epoch 4/24
----------
train Loss: 0.4385 Acc: 0.8484
val Loss: 0.3266 Acc: 0.8824

Epoch 5/24
----------
train Loss: 0.5485 Acc: 0.7951
val Loss: 0.4423 Acc: 0.8627

Epoch 6/24
----------
train Loss: 0.4820 Acc: 0.8443
val Loss: 0.2096 Acc: 0.9346

Epoch 7/24
----------
train Loss: 0.4656 Acc: 0.8320
val Loss: 0.2239 Acc: 0.9150

Epoch 8/24
----------
train Loss: 0.2709 Acc: 0.9057
val Loss: 0.2294 Acc: 0.9281

Epoch 9/24
----------
train Loss: 0.4037 Acc: 0.8525
val Loss: 0.2265 Acc: 0.9216

Epoch 10/24
----------
train Loss: 0.4505 Acc: 0.8443
val Loss: 0.2208 Acc: 0.9150

Epoch 11/24
----------
train Loss: 0.3570 Acc: 0.8607
val Loss: 0.2022 Acc: 0.9412

Epoch 12/24
----------
train Loss: 0.3294 Acc: 0.8811
val Loss: 0.2164 Acc: 0.9216

Epoch 13/24
----------
train Loss: 0.2821 Acc: 0.8811
val Loss: 0.2422 Acc: 0.9085

Epoch 14/24
----------
train Loss: 0.2427 Acc: 0.8852
val Loss: 0.2083 Acc: 0.9346

Epoch 15/24
----------
train Loss: 0.4063 Acc: 0.8525
val Loss: 0.2172 Acc: 0.9216

Epoch 16/24
----------
train Loss: 0.2128 Acc: 0.9344
val Loss: 0.2192 Acc: 0.9281

Epoch 17/24
----------
train Loss: 0.2688 Acc: 0.8893
val Loss: 0.2505 Acc: 0.9020

Epoch 18/24
----------
train Loss: 0.2448 Acc: 0.9221
val Loss: 0.2165 Acc: 0.9150

Epoch 19/24
----------
train Loss: 0.2102 Acc: 0.9057
val Loss: 0.2070 Acc: 0.9216

Epoch 20/24
----------
train Loss: 0.3172 Acc: 0.8689
val Loss: 0.2137 Acc: 0.9216

Epoch 21/24
----------
train Loss: 0.2731 Acc: 0.9098
val Loss: 0.2309 Acc: 0.9085

Epoch 22/24
----------
train Loss: 0.2512 Acc: 0.9098
val Loss: 0.2228 Acc: 0.9412

Epoch 23/24
----------
train Loss: 0.2668 Acc: 0.8934
val Loss: 0.2455 Acc: 0.9150

Epoch 24/24
----------
train Loss: 0.2117 Acc: 0.9057
val Loss: 0.2194 Acc: 0.9281

Training complete in 0m 54s
Best val Acc: 0.941176
visualize_model(model_ft)
ConvNet as fixed feature extractor
Here, we freeze the whole network except the final layer. To freeze the parameters so that the gradients are not computed in backward(), we need to set requires_grad = False.
You can read more about this in the documentation.
model_conv = torchvision.models.resnet18(pretrained=True)
for param in model_conv.parameters():
    param.requires_grad = False

# Parameters of newly constructed modules have requires_grad=True by default
num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_ftrs, 2)

if use_gpu:
    model_conv = model_conv.cuda()

criterion = nn.CrossEntropyLoss()

# Observe that only parameters of final layer are being optimized as
# opposed to before.
optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=0.001,
                           momentum=0.9)

# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)
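As a quick sanity check (my addition, not part of the original tutorial), you can confirm that only the newly constructed fc layer will receive gradient updates:

# Count the parameters that still require gradients after freezing.
trainable = [p for p in model_conv.parameters() if p.requires_grad]
print(len(trainable))                     # 2: the fc weight and its bias
print(sum(p.numel() for p in trainable))  # 512 * 2 + 2 = 1026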
Train and evaluate
On CPU this takes about half the time of the previous scenario. This is expected, since gradients don't need to be computed for most of the network; the forward pass, however, still does.
model_conv = train_model(model_conv, criterion, optimizer_conv,
                         exp_lr_scheduler, num_epochs=25)
Out:
Epoch 0/24
----------
train Loss: 0.5227 Acc: 0.7213
val Loss: 0.4018 Acc: 0.7974

Epoch 1/24
----------
train Loss: 0.5609 Acc: 0.7295
val Loss: 0.1907 Acc: 0.9412

Epoch 2/24
----------
train Loss: 0.4098 Acc: 0.8238
val Loss: 0.3553 Acc: 0.8366

Epoch 3/24
----------
train Loss: 0.4917 Acc: 0.8074
val Loss: 0.1798 Acc: 0.9346

Epoch 4/24
----------
train Loss: 0.3917 Acc: 0.8197
val Loss: 0.1668 Acc: 0.9477

Epoch 5/24
----------
train Loss: 0.3302 Acc: 0.8443
val Loss: 0.1861 Acc: 0.9281

Epoch 6/24
----------
train Loss: 0.2965 Acc: 0.8811
val Loss: 0.2005 Acc: 0.9150

Epoch 7/24
----------
train Loss: 0.3125 Acc: 0.8770
val Loss: 0.1726 Acc: 0.9542

Epoch 8/24
----------
train Loss: 0.3310 Acc: 0.8730
val Loss: 0.1651 Acc: 0.9477

Epoch 9/24
----------
train Loss: 0.3654 Acc: 0.8689
val Loss: 0.1980 Acc: 0.9346

Epoch 10/24
----------
train Loss: 0.4297 Acc: 0.8279
val Loss: 0.1984 Acc: 0.9412

Epoch 11/24
----------
train Loss: 0.3352 Acc: 0.8443
val Loss: 0.2008 Acc: 0.9346

Epoch 12/24
----------
train Loss: 0.4393 Acc: 0.7992
val Loss: 0.1606 Acc: 0.9412

Epoch 13/24
----------
train Loss: 0.3489 Acc: 0.8566
val Loss: 0.1609 Acc: 0.9346

Epoch 14/24
----------
train Loss: 0.3435 Acc: 0.8443
val Loss: 0.1576 Acc: 0.9412

Epoch 15/24
----------
train Loss: 0.3648 Acc: 0.8525
val Loss: 0.1817 Acc: 0.9412

Epoch 16/24
----------
train Loss: 0.2489 Acc: 0.8975
val Loss: 0.1781 Acc: 0.9412

Epoch 17/24
----------
train Loss: 0.2831 Acc: 0.8730
val Loss: 0.1822 Acc: 0.9412

Epoch 18/24
----------
train Loss: 0.3642 Acc: 0.8074
val Loss: 0.1690 Acc: 0.9281

Epoch 19/24
----------
train Loss: 0.4191 Acc: 0.7746
val Loss: 0.1566 Acc: 0.9477

Epoch 20/24
----------
train Loss: 0.3356 Acc: 0.8607
val Loss: 0.1744 Acc: 0.9346

Epoch 21/24
----------
train Loss: 0.3974 Acc: 0.8443
val Loss: 0.1585 Acc: 0.9542

Epoch 22/24
----------
train Loss: 0.3151 Acc: 0.8525
val Loss: 0.1850 Acc: 0.9346

Epoch 23/24
----------
train Loss: 0.3539 Acc: 0.8525
val Loss: 0.1691 Acc: 0.9412

Epoch 24/24
----------
train Loss: 0.3858 Acc: 0.8361
val Loss: 0.1720 Acc: 0.9412

Training complete in 0m 27s
Best val Acc: 0.954248
visualize_model(model_conv)

plt.ioff()
plt.show()
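If you want to keep the trained weights for later use, a common pattern is to save the model's state_dict; a minimal sketch, with an arbitrary file name:

# Persist the best weights (train_model already loaded them into model_conv)
torch.save(model_conv.state_dict(), 'model_conv.pth')

# Later: rebuild the same architecture and load the weights back
model = torchvision.models.resnet18(pretrained=False)
model.fc = nn.Linear(model.fc.in_features, 2)
model.load_state_dict(torch.load('model_conv.pth'))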