본문 바로가기

Data-science/deep learning

keras Data generator custom하게 만들기

728x90

많은 데이터를 한 번에 Test하려고 하면

RAM error가 뜸.

이때 어떻게 하냐? Data generator를 이용하면 된다.

방법은 생각보다 간단하다.

Sequence라는 녀석을 상속받는 클래스를 만들고, init과 __len__, __getitem__을 작성해준다!

필자의 경우 files가 입력으로 들어오면, 거기서 image를 로드하면 되는 형태다.

또한 file을 열면 그 파일 내에 X, Y가 둘다 들어가 있다.

__getitem__을 구현해주면 된다. index 부분은 건드리지 않고 나머지 부분을 본인의 로직에 맞게 수정하면 된다.

from tensorflow.keras.preprocessing import image
from tensorflow.keras.utils import Sequence
import math


class DataGenerator(Sequence):
    def __init__(self,
                 files: list,
                 batch_size: int,
                 augmentation: bool = False,
                 shuffle: bool = False,
                 rescale:bool = True) -> None:
        self.files = np.array(files)
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.augmentation = augmentation
        self.rescale =rescale

        # shuffle for first epoch
        if self.shuffle:
            self.shuffle_data()

    def __len__(self):
        return math.ceil(len(self.files) / self.batch_size)

    def __getitem__(self, index):
        image_batch = self.files[index * self.batch_size:(index + 1) * self.batch_size]
        
        Xs, Ys = [], []
        for i in range(len(image_batch)):
            image = np.load(image_batch[i])
            Y = image[:,:,-1].reshape(120,120,1)
            Y = np.where(Y < 0, 0, Y)
            X = image[:,:,:4]
            if self.rescale:
                X = X / 255.            
            if self.augmentation:
                pass

            Xs.append(X)
            Ys.append(Y)

        Xs = np.array(Xs)
        Ys = np.array(Ys)

        return Xs, Ys

    @property
    def n_images(self):
        return len(self.files)

그러면 아래와 같이 model.predict 함수를 사용할 수 있다.

예전 tensorflow 버전의 경우 predict_generator를 사용해야 했지만, 이젠 model.predict를 써도 된다.

단, keras에서 정의한 Generator를 넣어야 한다.

trainGen = DataGenerator(train_files, 64, False, False, True)

model.predict(trainGen, batch_size=64, verbose=1, workers=8, use_multiprocessing=True)