본문 바로가기

Data-science/deep learning

[pytorch] torch.utils.data.DataLoader 이용시 파일 경로 출력

728x90

배경

efficientNet을 이용한 간단한 분류기 만들기

github.com/narumiruna/efficientnet-pytorch

 

narumiruna/efficientnet-pytorch

A PyTorch implementation of "EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks". - narumiruna/efficientnet-pytorch

github.com

shutil 을 이용해서 파일을 분류하고 분류 결과대로 저장하고 싶었다. 그래서 파일명이 필요했다.

과정

from torchvision.datasets import ImageFolder

보통 Image data의 경우 ImageFolder를 쓴다.

from torch.utils.data import DataLoader

ImageFolder class로 형성한 데이터 셋을 DataLoader가 인자로 받는다.

xx = datasets.ImageFolder(os.path.join('data/example', 'train'), transform=transform)
yy = DataLoader(dataset=xx, shuffle=False, batch_size=16)

이때 yy안에 dataset이 정의 되어있다.

그리고 파일이름은 xx.imgs 내에 있다.

filenames = np.array(yy.dataset.imgs)[:, 0]

imgs를 열면 아마 tuple 형식으로 두 번째 값이 label 일 것이다.

그러니 numpy array로 바꿔서 저런식으로 열면

파일 경로가 출력될 것이다.

전체 코드

test 함수만 보면 된다. 2가지 클래스를 분류한 뒤 특정 폴더로 이미지를 복사하게끔 하는 코드.

import argparse

import mlconfig
import torch
import torch.nn.functional as F
from tqdm import tqdm

from efficientnet import models
from efficientnet.datasets.photo2illu import P2LDataLoader
from efficientnet.metrics import Accuracy, Average
from efficientnet.models.efficientnet import params


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', type=str, default='configs/photo2illu-test.yaml')
    parser.add_argument('--arch', type=str, default='efficientnet_b0')
    parser.add_argument('-r', '--root', type=str, default='test')
    parser.add_argument('--batch-size', type=int, default=128)
    parser.add_argument('-w', '--weight', type=str, default=None)
    parser.add_argument('--num-workers', type=int, default=8)
    parser.add_argument('--no-cuda', action='store_true')
    return parser.parse_args()


def evaluate(model, valid_loader, device):
    model.eval()

    valid_loss = Average()
    valid_acc = Accuracy()

    with torch.no_grad():
        valid_loader = tqdm(valid_loader, desc='Validate', ncols=0)
        for x, y in valid_loader:
            x = x.to(device)
            y = y.to(device)

            output = model(x)
            loss = F.cross_entropy(output, y)

            valid_loss.update(loss.item(), number=x.size(0))
            valid_acc.update(output, y)

            valid_loader.set_postfix_str(f'valid loss: {valid_loss}, valid acc: {valid_acc}.')

    return valid_loss, valid_acc

def test(model, valid_loader, device):
    model.eval()
    from torchvision.utils import save_image
    import os
    import numpy as np
    import shutil
    save_photo = 'test_results3/photo'
    save_illu = 'test_results3/illu'
    os.makedirs(save_photo, exist_ok=True)
    os.makedirs(save_illu, exist_ok=True)
    dataset = np.array(valid_loader.dataset.imgs)
#     print(dataset)
    idx = 0
    batch_size = valid_loader.batch_size
    with torch.no_grad():
        
        valid_loader = tqdm(valid_loader, desc='Validate', ncols=0)
        for x, y in valid_loader:
            imgs = dataset[idx:idx+batch_size, 0]
            x= x.to(device)
            y= y.to(device)
            output = model(x)
            outs = torch.argmax(output, dim=1)
            for i, out in enumerate(outs):
                this_name = imgs[i]
                if out.item() == 0:
                    new_path = os.path.join(save_illu, os.path.basename(this_name))
#                     save_image(x[i], os.path.join(save_illu, f'{i}_{output[i][out]}.png'))
                if out.item() == 1:
                    new_path = os.path.join(save_photo, os.path.basename(this_name))
#                     save_image(x[i], os.path.join(save_photo, f'{i}_{output[i][out]}.png'))
                shutil.copy(imgs[i], new_path)
#             print(out, y)
            idx += batch_size

if __name__ == '__main__':
    args = parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() and not args.no_cuda else 'cpu')
    config = mlconfig.load(args.config)

    model = config.model()
#     model = getattr(models, args.arch)(pretrained=(args.weight is None))
    if args.weight is not None:
        state_dict = torch.load(args.weight, map_location='cpu')
        model.load_state_dict(state_dict['model'])
    model.to(device)

    image_size = params[args.arch][2]
    valid_loader = config.dataset()
#     P2LDataLoader(args.root, image_size, False, args.batch_size, num_workers=args.num_workers)

    test(model, valid_loader, device)