Radi_tech’s blog

Radiological technologist in Japan / MRI / AI / Deep learning / MATLAB / R / Python

【Python】Pytorch k-fold cross validation


Pytorchでk-fold cross validationを行う

radi-tech.hatenablog.com

# moduleのimport  いらないのもあるかも。。。
import os
import random
import numpy as np


import matplotlib.pyplot as plt

from sklearn.model_selection import KFold


import torch
import torch.nn as nn
import torch.optim as optim
import torchvision

from torch.utils.data import Dataset, DataLoader,TensorDataset,random_split,SubsetRandomSampler, ConcatDataset
from torch.nn import functional as F

from torchvision import datasets,transforms
import torchvision.transforms as transforms


# Move into the working directory.
ws = "work spaceのファルダを書く"
os.chdir(ws)


# Directories holding the images (ImageFolder layout: one sub-folder per class).
train_data_dir = 'trainのファルダパス'
test_data_dir  = 'validiationのフォルダパス'


# Pre-processing pipeline.
# BUG FIX: this was previously a SET literal `{Compose([...])}` — a set is not
# callable, so passing it as `transform=` would crash as soon as a sample was
# loaded.  The script only ran because the transform was later overwritten with
# a bare ToTensor(), which silently discarded the Normalize step.  Define one
# proper Compose and use it consistently instead.
data_transforms = torchvision.transforms.Compose([
    #torchvision.transforms.RandomResizedCrop(224),
    torchvision.transforms.ToTensor(),
    # ImageNet channel means / stds — appropriate for the pretrained ResNet
    # used later in this script.
    torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


# Build both datasets and concatenate them; the k-fold split is performed
# later over the combined dataset.
train_dataset = torchvision.datasets.ImageFolder(train_data_dir, transform=data_transforms)
test_dataset  = torchvision.datasets.ImageFolder(test_data_dir,  transform=data_transforms)

dataset = ConcatDataset([train_dataset, test_dataset])


# Number of training samples and the list of class labels.
m = len(train_dataset)
class_names = train_dataset.classes


# Device selection and experiment hyper-parameters.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Fix the RNG so the k-fold shuffling / weight init are reproducible.
torch.manual_seed(42)

num_epochs = 10
batch_size = 16
k = 5
splits = KFold(n_splits=k, shuffle=True, random_state=42)
foldperf = {}


# Pretrained ResNet-50 with its classifier head replaced to match our classes.
model = torchvision.models.resnet50(pretrained=True)
model.fc = nn.Linear(model.fc.in_features, len(class_names))
model = model.to(device)

# Loss function, optimizer, and learning-rate schedule.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

# Helper: run one training epoch.
def train_epoch(model, device, dataloader, loss_fn, optimizer):
    """Train for one epoch; return (summed loss, number of correct predictions)."""
    running_loss = 0.0
    n_correct = 0
    model.train()
    for batch_images, batch_labels in dataloader:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()
        logits = model(batch_images)
        batch_loss = loss_fn(logits, batch_labels)
        batch_loss.backward()
        optimizer.step()

        # Accumulate the loss weighted by batch size, plus the hit count.
        running_loss += batch_loss.item() * batch_images.size(0)
        _, predicted = torch.max(logits.data, 1)
        n_correct += (predicted == batch_labels).sum().item()

    return running_loss, n_correct

def valid_epoch(model, device, dataloader, loss_fn):
    """Evaluate for one epoch; return (summed loss, number of correct predictions)."""
    running_loss = 0.0
    n_correct = 0
    model.eval()
    # No gradients needed for evaluation.
    with torch.no_grad():
        for batch_images, batch_labels in dataloader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)

            logits = model(batch_images)
            batch_loss = loss_fn(logits, batch_labels)

            running_loss += batch_loss.item() * batch_images.size(0)
            _, predicted = torch.max(logits.data, 1)
            n_correct += (predicted == batch_labels).sum().item()

    return running_loss, n_correct


# K-fold cross-validated training.
#
# BUG FIX: the original trained ONE model (created once, above) across every
# fold, so each later fold started from weights already fitted on other folds'
# validation samples — the cross-validation scores were meaningless.  The
# model, optimizer and scheduler are now re-initialized at the start of each
# fold so folds are independent.
# BUG FIX: the StepLR scheduler was created but scheduler.step() was never
# called, so the lr=0.001 / step_size=7 / gamma=0.1 schedule never took
# effect; it is now stepped once per epoch.
history = {'train_loss': [], 'test_loss': [], 'train_acc': [], 'test_acc': []}

for fold, (train_idx, val_idx) in enumerate(splits.split(np.arange(len(dataset)))):

    print('Fold {}'.format(fold + 1))

    train_sampler = SubsetRandomSampler(train_idx)
    test_sampler = SubsetRandomSampler(val_idx)
    train_loader = DataLoader(dataset, batch_size=batch_size, sampler=train_sampler)
    test_loader = DataLoader(dataset, batch_size=batch_size, sampler=test_sampler)

    # Fresh model / optimizer / scheduler for this fold.
    model = torchvision.models.resnet50(pretrained=True)
    model.fc = nn.Linear(model.fc.in_features, len(class_names))
    model = model.to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

    for epoch in range(num_epochs):
        train_loss, train_correct = train_epoch(model, device, train_loader, criterion, optimizer)
        test_loss, test_correct = valid_epoch(model, device, test_loader, criterion)
        scheduler.step()  # advance the LR schedule once per epoch

        # Average over the number of samples each sampler actually draws.
        train_loss = train_loss / len(train_loader.sampler)
        train_acc = train_correct / len(train_loader.sampler) * 100
        test_loss = test_loss / len(test_loader.sampler)
        test_acc = test_correct / len(test_loader.sampler) * 100

        print("Epoch:{}/{} AVG Training Loss:{:.3f} AVG Test Loss:{:.3f} AVG Training Acc {:.2f} % AVG Test Acc {:.2f} %".format(epoch + 1,
                                                                                                             num_epochs,
                                                                                                             train_loss,
                                                                                                             test_loss,
                                                                                                             train_acc,
                                                                                                             test_acc))
        history['train_loss'].append(train_loss)
        history['test_loss'].append(test_loss)
        history['train_acc'].append(train_acc)
        history['test_acc'].append(test_acc)

    # Record a snapshot of the metrics so far in foldperf (declared above but
    # previously never populated).
    foldperf['fold{}'.format(fold + 1)] = {key: vals[:] for key, vals in history.items()}

    # Save this fold's trained model, e.g. "001_model.pth".
    sv_path = f"{fold + 1:0=3}_model.pth"
    torch.save(model, sv_path)