TORCH.UTILS.DATA 공식 문서 파훼하기

데이터 피딩에 활용되는 PyTorch 클래스를 알아보자

Posted by devfon on May 26, 2020

PyTorch 데이터 로딩의 중심에는 torch.utils.data.DataLoader 클래스가 있습니다. DataLoaderDataset에 대한 Python Iterable 클래스입니다. DataLoader에는 다양한 옵션이 존재하는데, 여러 옵션을 활용해 다음과 같이 DataLoader를 초기화 할 수 있습니다.

1
2
3
4
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None)

이제 DataLoader 클래스 초기화에 활용되는 다양한 옵션들에 대해 알아보도록 하겠습니다.


Dataset 종류

DataLoader 초기화에 있어 가장 중요한 인자는 dataset 이며, 이는 데이터를 끌어올 데이터셋 객체를 지칭합니다. PyTorch는 두 개 유형의 데이터셋을 지원합니다.

Map-style 데이터셋

Map-style 데이터셋__getitem__()__len__()을 구현해 인덱스, 키 등을 활용한 데이터 샘플 매핑을 수행합니다. 예를 들어, Map-style 데이터셋dataset[idx]와 같이 접근되었을 때, idx번 째 이미지와 해당 이미지의 라벨을 매핑해 반환하게 됩니다.


Iterable-style 데이터셋

Iterable-style 데이터셋__iter__()를 구현하는 IterableDataset의 서브 클래스 인스턴스입니다. 이는 데이터 샘플에 대한 Iterable 객체입니다. 해당 데이터셋은 데이터 샘플을 임의로 읽는 작업이 비싼 연산이거나 부적절할 때, 혹은 배치 사이즈__iter__() 연산을 통해 읽어온 데이터 갯수에 의해 정해지게끔 하고 싶을 때 사용하기 적합한 데이터셋입니다.

예를 들어, iter(dataset)이 호출되었을 때, 데이터셋은 데이터베이스, 원격 서버, 혹은 실시간으로 생성되는 로그 등에서 데이터 스트림을 읽어와 반환합니다.

Note: IterableDataset을 멀티 프로세스로 활용하면, 동일한 데이터셋 객체각 워커 프로세스에 복제됩니다. 따라서 중복 데이터가 모델에 Feeding 되지 않도록 하기 위해서는 추가 설정이 수행되어야 합니다.

1
2
3
4
5
6
7
8
9
10
11
12
# 중복 데이터 방지하는 예시
def __iter__(self):
  worker_info = torch.utils.data.get_worker_info()
  if worker_info is None:  # 싱글 프로세스로 데이터 로딩할 경우, Full Iterator를 반환
    iter_start = self.start
    iter_end = self.end
  else:  # 멀티 프로세스로 데이터 로딩할 경우, 워크 로드 분배
    per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
    worker_id = worker_info.id
    iter_start = self.start + worker_id * per_worker
    iter_end = min(iter_start + per_worker, self.end)
  return iter(range(iter_start, iter_end))


데이터 로딩 순서와 Sampler

Iterable-style 데이터셋의 데이터 로딩 순서는 전적으로 사용자가 정의한 Iterable에 의해 결정됩니다. 그리고 이러한 특성은 Chunk 단위로 데이터를 읽는다거나, 동적인 배치 사이즈를 설정한다거나 등의 작업을 구현하기 쉽게 도와줍니다. 단순히 IterableBatch 단위의 샘플을 yield만 하면 되기 때문입니다.

따라서 이번 섹션에서는 Map-style 데이터셋만 다룰 예정입니다. SamplerMap-style 데이터셋의 데이터 로딩에 사용되는 인덱스 혹은 키 시퀀스를 정하기 위해 사용되는 클래스입니다. 즉, Sampler 클래스는 데이터셋에 접근할 수 있는 인덱스에 대한 Iterable 객체입니다. 예를 들어, Stochastic Gradient Decent (SGD)에 있어 Sampler는 인덱스로 구성된 시퀀스를 순열한 후, 매번 한 개의 인덱스를 yield 합니다. 혹은 미니 배치 SGD를 위해 특정 갯수만큼의 인덱스를 yield 할 수도 있습니다.

Sequential 혹은 Shuffled Sampler의 경우, DataLoader 인스턴스 초기화 옵션에 사용되는 shuffle 인자에 따라 결정 및 구현됩니다. 혹은 사용자가 직접 구현한대로 인덱스 및 키 시퀀스를 yield 해주는 커스텀 Samplersampler 인자에 넣어줄 수도 있습니다.

커스텀 Sampler가 배치 단위로 인덱스를 yield 하도록 설정하고자 한다면, batch_sampler 인자에 커스텀 Sampler를 넘겨주면 됩니다. 혹은 단순 배치만을 위해 커스텀 Sampler를 구현하고자 하는 것이라면, 이는 batch_sizedrop_last 인자를 통해 해당 설정이 가능합니다.

역자주: drop_last 인자의 경우, 배치 크기를 채우지 못한 마지막 불완전 배치를 사용할 것인지, 사용하지 않을 것인지를 결정하는 Boolean 인자입니다.

Note: Iterable-style 데이터셋의 경우, 인덱스 혹은 키 등의 개념이 없기 때문에 samplerbatch_sampler 인자를 사용할 수 없습니다.


배치 혹은 Non-배치 데이터의 로딩

DataLoader 클래스는 batch_size, drop_last 그리고 batched_sampler 인자를 통해 개별 데이터 샘플을 배치로 엮어주는 기능을 제공합니다.

자동 배치 (기본 값)

자동 배치는 가장 흔히 사용되는 옵션입니다. 자동 배치의 경우, 미니 배치 크기만큼의 데이터를 읽어와 이를 배치 샘플로 합쳐줍니다. 때문에 배치 샘플 텐서의 차원 하나는 배치 사이즈를 나타내며, 주로 첫 번째 차원이 이를 나타냅니다.

batch_sizeNone이 아니라면, DataLoader배치 샘플yield하게 됩니다. 앞서 설명한 것과 마찬가지로, Map-style 데이터셋의 경우, 인덱스 시퀀스를 yield 해주는 Sampler를 구현해 batch_sampler 인자에 해당 Sampler 객체를 넘겨주는 방식으로도 배치 샘플을 모델에 Feeding 할 수 있습니다.

Note: batch_sizedrop_last 인자는 sampler로부터 batch_sampler를 구성하기 위해 사용됩니다. Map-style 데이터셋의 경우, sampler가 사용자에 의해 전달되거나, shuffle 인자에 의해 결정됩니다. Iterable-style 데이터셋의 경우 sampler에 더미 값 [1, 1, ..., 1]이 들어가게 됩니다.

Note: Iterable-style 데이터셋을 멀티 프로세서로 활용한다면, drop__last 인자를 통해 각 워커가 복제해 사용하는 데이터셋에서 마지막 미완성 배치를 활용할지 말지 정하게 됩니다.

Sampler가 내놓은 인덱스 시퀀스를 활용해 데이터 샘플을 읽어온 후에는 collate_fn 인자로 등록된 함수샘플 리스트배치 샘플로 합치는데 활용됩니다.

따라서 Map-style 데이터셋에서 데이터를 읽어오는 과정은 다음과 같아집니다:

1
2
for indices in batch_sampler:
  yield collate_fn([dataset[i] for i in indices])

그리고 Iterable-style 데이터셋에서 데이터를 읽어오는 과정은 다음과 같아집니다:

1
2
3
dataset_iter = iter(dataset)
for indices in batch_sampler:
  yield collate_fn([next(dataset_iter) for _ in indices])

추가적으로 배치 내 존재하는 시퀀셜 데이터최대 길이만큼 패딩한다거나 등의 추가 작업을 수행할 수 있는 커스텀 collate_fn을 구현해 활용할 수도 있습니다.


자동 배치 해제

특수한 경우에 사용자는 데이터셋 코드에 있어 배치를 직접 관리하거나, 개별 데이터 샘플을 읽어들여야 할 수도 있습니다. 예를 들어, 때로는 데이터를 데이터베이스에서 직접 Bulk로 읽거나, 메모리에서 연속 Chunk로 읽어오는 것이 (즉, 이미 배치 상태인 데이터를 바로 읽어오는 것이) 더 저렴한 연산일 수 있습니다. 혹은 배치 사이즈가 데이터에 따라 다르게 적용되어야 하는 경우도 있을 것입니다. 이러한 경우, 자동 배치를 해제해 DataLoaderdataset 객체의 각 샘플을 반환하게끔 하는 것이 좋습니다.

batch_sizebatch_sampler가 모두 None 일 때, 자동 배치가 해제됩니다. 그리고 dataset을 통해 얻어진 개별 데이터 샘플은 collate_fn 함수를 거쳐 모델에 Feeding 됩니다.

자동 배치 설정이 해제된 경우, 기본 collate_fn이 수행하는 작업은 단순히 NumPy 배열PyTorch 텐서컨버팅해주는 것입니다.

이 경우, Map-style 데이터셋에서 데이터를 읽어오는 과정이 다음과 같아집니다:

1
2
for index in sampler:
  yield collate_fn(dataset[index])

그리고 Iterable-style 데이터셋에서 데이터를 읽어오는 과정은 다음과 같아집니다:

1
2
for data in iter(dataset):
  yield collate_fn(data)


collate_fn 활용하기

collate_fn자동 배치 설정에 따라 다르게 적용됩니다.

자동 배치가 해제된 경우, collate_fn개별 데이터 샘플에 대해 적용됩니다. 이때의 collate_fn은 앞서 언급했듯 단순히 NumPy 배열PyTorch 텐서로 변환하는 작업을 수행하게 됩니다.

자동 배치가 설정된 경우, collate_fn데이터 샘플 리스트에 대해 적용됩니다. 이때의 collate_fn은 리스트에 포함되어 있는 데이터 샘플들을 배치 샘플로 합치는 작업을 수행합니다.

예를 들어, 각 데이터 샘플이 3 채널 이미지정수 클래스 라벨로 구성되어 dataset의 반환 값이 (image, class_index)의 튜플인 경우, 기본 collate_fn은 해당 튜플로 구성된 리스트를 하나의 튜플로 합쳐 배치 이미지 텐서배치 클래스 라벨 텐서를 생성합니다. 그리고 특히, 기본 collate_fn은 다음과 같은 특성을 지닙니다:

  • 항상 텐서 가장 앞에 배치 크기 차원을 추가합니다.
  • NumPy 배열Python 수치 값PyTorch 텐서로 컨버팅합니다.
  • list, tuple, dictionary, namedtuple 등 입력 자료형의 구조를 보존합니다. 예를 들어, 각 데이터 샘플이 딕셔너리였을 경우, 결과 역시 동일한 키와 값으로 구성된 딕셔너리로 나오게 됩니다. 다만, 키에 상응하는 값은 배치 텐서로 변환되어 반환됩니다.

또한 사용자는 배치 사이즈를 첫 번째 차원으로 사용하지 않게끔 한다거나, 시퀀셜 데이터에 패딩을 적용한다거나 등으로 collate_fn을 커스터마이즈해 활용할 수도 있습니다.


멀티 프로세스를 활용한 데이터 로딩

플랫폼 별 특이점

Windows 운영체제에서는 spawn()을 활용해 멀티 프로세싱이 진행됩니다. spawn()을 활용한 멀티 프로세싱에서는 또 다른 인터프리터가 메인 스크립트를 실행시키는 방식으로 진행됩니다. 따라서 이같은 분리된 직렬화에 있어서는 아래와 같은 두 가지 스텝을 고려해야 제대로 된 멀티 프로세스 데이터 로딩이 가능해집니다.

  1. 메인 스크립트 내 코드들을 if __name__ == "__main__" 으로 감싸주어야, 각 worker 프로세스가 실행될 때 나머지 코드들이 재실행되지 않게 됩니다. 즉, datasetDataLoader 인스턴스를 생성하는 로직을 해당 블락에 포함시켜줌으로써 해당 로직들이 중복 실행되지 않게 합니다.

  2. 커스텀으로 작성한 collate_fnworker_init_fn 혹은 데이터셋 코드 등은 __main__ 블락의 바깥인 탑 레벨에서 정의를 해주어야 합니다. 이렇게 해줌으로써 각 worker 프로세스들이 해당 커스텀 코드들을 활용할 수 있게 됩니다.


멀티 프로세스 상황에서의 임의성

PyTorch에서 각 worker는 디폴트로 base_seed + worker_id의 시드를 지니게 됩니다. worked_id를 활용해 worker 간 서로 다른 시드를 지니게 되는 PyTorch와 달리 NumPy와 같은 다른 라이브러리에서 사용되는 시드는 worker들 간 중복이 되어 동일한 난수를 발생시킬 수도 있습니다.

그리고 이를 방지하기 위해 worker_init_fn에서 개별 worker의 시드를 받아 데이터 로딩 이전에 다른 라이브러리들의 시드 번호를 변경해줄 수 있습니다. 각 workerPyTorch 시드는 torch.utils.data.get_worker_info() 혹은 torch.initial_seed() API를 통해 확인할 수 있습니다.