Radi_tech’s blog

Radiological technologist in Japan / MRI / AI / Deep learning / MATLAB / R / Python

【Python】PyTorch prediction 推論を行う


PyTorchで保存したモデルを使って推論（prediction）を行う




radi-tech.hatenablog.com


radi-tech.hatenablog.com


# module import
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision

from torch.utils.data import Dataset, DataLoader,TensorDataset,random_split,SubsetRandomSampler, ConcatDataset
from torch.nn import functional as F

from torchvision import datasets,transforms
import torchvision.transforms as transforms

from sklearn.metrics import classification_report


# Select the compute device: the first CUDA GPU when available, else the CPU.
# NOTE(review): `device` is not used further down -- the model is moved to
# "cpu" explicitly and batches are never transferred, so inference is CPU-only.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Change into the working directory so the relative output paths used later
# (e.g. the saved PNG figures) are written there.
ws = "working spaceのフォルダ"
os.chdir(ws)


# Folder holding the test images. ImageFolder expects one sub-folder per
# class and assigns class indices from the sorted sub-folder names.
test_data_dir = 'test image を入れているフォルダ'


# Preprocessing applied to every test image.
# BUG FIX: this was previously written as `data_transforms = { Compose([...]), }`,
# i.e. a *set* containing the Compose object rather than the Compose itself.
# A set is not callable, so loading an image raised a TypeError. The script
# only ran after overwriting `test_dataset.transform` with a bare ToTensor(),
# which silently dropped the Normalize step. Passing the Compose directly
# applies the intended preprocessing and makes that workaround unnecessary.
# NOTE(review): assumes the model was trained with this same normalization;
# confirm against the training script.
data_transforms = transforms.Compose([
    transforms.ToTensor(),
    # Standard ImageNet per-channel mean / std.
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

test_dataset = torchvision.datasets.ImageFolder(test_data_dir, transform=data_transforms)


# Wrap the dataset in a DataLoader. No shuffling, so batches follow dataset
# order (labels and predictions collected later stay aligned with the files).
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=16, shuffle=False)


# Load the saved model object. `.eval()` is called on the result, so the file
# presumably holds a whole pickled nn.Module (torch.save(model, ...)), not
# just a state_dict -- confirm against the training script.
model_path = "modelのpath"
# map_location="cpu" remaps any CUDA tensors onto the CPU. Without it, a model
# saved from a GPU cannot be loaded on a machine without CUDA, even though
# inference below runs on the CPU anyway.
# SECURITY: torch.load unpickles arbitrary objects; only load trusted files.
# NOTE: PyTorch >= 2.6 defaults to weights_only=True, which rejects whole
# pickled models; pass weights_only=False there (trusted files only).
model = torch.load(model_path, map_location="cpu")
model.eval()  # inference mode: disables dropout, freezes batch-norm statistics
model.to("cpu")

# Predict a class for every test image, then print per-class
# precision / recall / F1 via sklearn's classification_report.
pred = []  # predicted class index, one per image
Y = []     # ground-truth class index, one per image

# Inference only: no autograd graph is needed for any batch.
with torch.no_grad():
    for images, labels in test_loader:
        logits = model(images)
        # argmax of each sample's output row -> predicted class index
        pred.extend(int(row.argmax()) for row in logits)
        Y.extend(int(label) for label in labels)

print(classification_report(Y, pred))



# ROC curve
import numpy as np
from sklearn.metrics import roc_curve, auc

# Collect, for every test image:
#   y_true  - ground-truth class index
#   y_pred  - predicted class index (argmax), used for the confusion matrix
#   y_score - softmax probability of class index 1, used for the ROC curve
# BUG FIX: the ROC curve used to be computed from the hard argmax labels.
# Hard labels give a single operating point, so the "curve" was just two
# straight segments and the AUC was not the threshold-independent AUC.
# roc_curve needs a continuous score, so the class-1 probability is passed.
# The previous keepdim=True also produced a list of shape-(1,) arrays; the
# predictions are now a flat list of class indices.
y_pred = []
y_true = []
y_score = []

with torch.no_grad():
    for data, target in test_loader:
        output = model(data)
        # NOTE(review): assumes a two-class output of shape (batch, 2).
        # Class index 1 is the positive class, matching roc_curve's implicit
        # pos_label=1 for {0, 1} labels used before. ROC depends only on the
        # ranking of scores, so this is fine even if the model already ends
        # in a softmax.
        probs = torch.softmax(output, dim=1)
        y_score.extend(probs[:, 1].cpu().numpy())
        y_pred.extend(output.argmax(dim=1).cpu().numpy())
        y_true.extend(target.cpu().numpy())


# False positive rate & true positive rate at every score threshold
fpr, tpr, thresholds = roc_curve(y_true, y_score, pos_label=1)
roc_auc = auc(fpr, tpr)


# plot
import matplotlib.pyplot as plt

# Draw the ROC curve against the chance diagonal, save it as a 300-dpi PNG in
# the current working directory, then display it.
fig = plt.figure()
ax = fig.add_subplot(111)

roc_label = 'ROC curve (AUC = %0.2f)' % roc_auc
ax.plot(fpr, tpr, color='darkorange', lw=2, label=roc_label)
ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')  # chance line
ax.set_xlim([0.0, 1.0])
ax.set_ylim([0.0, 1.05])
ax.set_aspect('equal')

ax.set_xlabel('False Positive Rate')
ax.set_ylabel('True Positive Rate')
ax.set_title('Receiver operating characteristic example')
ax.legend(loc="lower right")
fig.savefig('test_ROC.png', format="png", dpi=300)
plt.show()


# confusion matrix
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from sklearn.metrics import confusion_matrix

# Rows = true class, columns = predicted class (sklearn convention).
# labels=[0, 1] always yields a 2x2 matrix, even if one class is absent from
# y_true / y_pred; otherwise the 2-label DataFrame below would fail.
# NOTE(review): assumes class index 0 is the "Fracture" folder and 1 the
# "Normal" folder (ImageFolder sorts folder names); confirm.
cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
cm = pd.DataFrame(data=cm, index=["Fracture", "Normal"],
                  columns=["Fracture", "Normal"])

# Start a fresh figure so the heatmap never draws onto a still-open ROC figure.
plt.figure()
# fmt="d" prints integer counts; seaborn's default ".2g" would show e.g. 123
# as "1.2e+02".
sns.heatmap(cm, square=True, cbar=True, annot=True, fmt="d", cmap='Blues')

plt.xlabel("Prediction label", fontsize=10, fontstyle='italic')
plt.ylabel("Correct label", fontsize=10, fontstyle='italic')
# save as PNG with 300 dpi resolution
plt.savefig('test_confusion_matrix.png', format="png", dpi=300)


radi-tech.hatenablog.com