728x90
많은 데이터를 한 번에 Test하려고 하면
RAM error가 뜸.
이때 어떻게 하냐? Data generator를 이용하면 된다.
방법은 생각보다 간단하다.
Sequence라는 녀석을 상속받는 클래스를 만들고, init과 __len__, __getitem__을 작성해준다!
필자의 경우 files가 입력으로 들어오면, 거기서 image를 로드하면 되는 형태다.
또한 file을 열면 그 파일 내에 X, Y가 둘다 들어가 있다.
__getitem__을 구현해주면 된다. index 부분은 건드리지 않고 나머지 부분을 본인의 로직에 맞게 수정하면 된다.
from tensorflow.keras.preprocessing import image
from tensorflow.keras.utils import Sequence
import math
class DataGenerator(Sequence):
def __init__(self,
files: list,
batch_size: int,
augmentation: bool = False,
shuffle: bool = False,
rescale:bool = True) -> None:
self.files = np.array(files)
self.batch_size = batch_size
self.shuffle = shuffle
self.augmentation = augmentation
self.rescale =rescale
# shuffle for first epoch
if self.shuffle:
self.shuffle_data()
def __len__(self):
return math.ceil(len(self.files) / self.batch_size)
def __getitem__(self, index):
image_batch = self.files[index * self.batch_size:(index + 1) * self.batch_size]
Xs, Ys = [], []
for i in range(len(image_batch)):
image = np.load(image_batch[i])
Y = image[:,:,-1].reshape(120,120,1)
Y = np.where(Y < 0, 0, Y)
X = image[:,:,:4]
if self.rescale:
X = X / 255.
if self.augmentation:
pass
Xs.append(X)
Ys.append(Y)
Xs = np.array(Xs)
Ys = np.array(Ys)
return Xs, Ys
@property
def n_images(self):
return len(self.files)
그러면 아래와 같이 model.predict 함수를 사용할 수 있다.
예전 tensorflow 버전의 경우 predict_generator를 사용해야 했지만, 이젠 model.predict를 써도 된다.
단, keras에서 정의한 Generator를 넣어야 한다.
trainGen = DataGenerator(train_files, 64, False, False, True)
model.predict(trainGen, batch_size=64, verbose=1, workers=8, use_multiprocessing=True)
'Data-science > deep learning' 카테고리의 다른 글
Spatiotemporal CNN for Video Object Segmentation (0) | 2020.11.04 |
---|---|
keras Data generator custom하게 만들기 (0) | 2020.11.02 |
conda install vs pip install (0) | 2020.10.31 |
stylegan2 환경 구축 (0) | 2020.10.28 |
tensorflow 설치시 주의점 (0) | 2020.10.28 |