본문 바로가기

AI/Pytorch

Pytorch collate_fn 이란?

DataLoader에는 여러 파라미터가 있어 필요시 적절한 파라미터를 활용해 여러 설정을 줄 수 있다. 그 중에서도 collate_fn은 variable length가 달라서, 패딩해줄 때 사용한다. 

 

from torch.utils.data import Dataset, DataLoader
import torch

class ExampleDataset(Dataset):
    def __init__(self, num):
        self.num = num

    def __len__(self):
        return self.num

    def __getitem__(self, idx):
        return {"X":torch.tensor([idx] * (idx+1), dtype=torch.float32),
                "y": torch.tensor(idx, dtype=torch.float32)}

dataset_example = ExampleDataset(num = 10)
dataloader_example = torch.utils.data.DataLoader(dataset_example, batch_size= 1)
for d in dataloader_example:
    print('X : ',d['X'])
    
'''
X :  tensor([[0.]])
X :  tensor([[1., 1.]])
X :  tensor([[2., 2., 2.]])
X :  tensor([[3., 3., 3., 3.]])
X :  tensor([[4., 4., 4., 4., 4.]])
X :  tensor([[5., 5., 5., 5., 5., 5.]])
X :  tensor([[6., 6., 6., 6., 6., 6., 6.]])
X :  tensor([[7., 7., 7., 7., 7., 7., 7., 7.]])
X :  tensor([[8., 8., 8., 8., 8., 8., 8., 8., 8.]])
X :  tensor([[9., 9., 9., 9., 9., 9., 9., 9., 9., 9.]])
'''

예를 들어 위와 같이 input의 길이가 다른 Dataset이 있을 때, print할 경우 Batch size가 1이라면 문제가 없다.

 

dataloader_example = torch.utils.data.DataLoader(dataset_example, batch_size = 2)
for d in dataloader_example:
    print(d['X'])
    
 '''
 ---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-11-449346496a83> in <module>()
      1 dataloader_example = torch.utils.data.DataLoader(dataset_example, batch_size = 2)
----> 2 for d in dataloader_example:
      3     print(d['X'])

5 frames
/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/collate.py in default_collate(batch)
    136             storage = elem.storage()._new_shared(numel)
    137             out = elem.new(storage).resize_(len(batch), *list(elem.size()))
--> 138         return torch.stack(batch, 0, out=out)
    139     elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
    140             and elem_type.__name__ != 'string_':

RuntimeError: stack expects each tensor to be equal size, but got [1] at entry 0 and [2] at entry 1
 '''

 

그러나 위와 같이 Batch Size를 2로 설정하면, 에러가 발생한다. 이는 같은 배치 안의 input(X)의 길이가 다르기 때문이다. 이 에러를 해결하기 위해 input의 길이를 동일하게 맞춰주어야 함으로, collate_fn을 사용하여 같은 배치 안에 길이가 가장 긴 input에 맞춰 다른 input들에 임의로 0값을 넣었다. (Zero padding)

 

def my_collate_fn(samples):
    collate_X = []
    collate_y = []
    max_len = max([len(sample['X']) for sample in samples])
    for sample in samples:
        diff = max_len-len(sample['X'])
        if diff > 0:
            zero_pad = torch.zeros(size=(diff,))
            collate_X.append(torch.cat([sample['X'], zero_pad], dim=0))
        else:
            collate_X.append(sample['X'])
    collate_y = [sample['y'] for sample in samples]
    return {'X': torch.stack(collate_X),
             'y': torch.stack(collate_y)}

dataloader_example = torch.utils.data.DataLoader(dataset_example,
                                                 batch_size=2,
                                                 collate_fn=my_collate_fn)
for d in dataloader_example:
    print(d['X'], d['y'])
    
'''
tensor([[0., 0.],
        [1., 1.]]) tensor([0., 1.])
tensor([[2., 2., 2., 0.],
        [3., 3., 3., 3.]]) tensor([2., 3.])
tensor([[4., 4., 4., 4., 4., 0.],
        [5., 5., 5., 5., 5., 5.]]) tensor([4., 5.])
tensor([[6., 6., 6., 6., 6., 6., 6., 0.],
        [7., 7., 7., 7., 7., 7., 7., 7.]]) tensor([6., 7.])
tensor([[8., 8., 8., 8., 8., 8., 8., 8., 8., 0.],
        [9., 9., 9., 9., 9., 9., 9., 9., 9., 9.]]) tensor([8., 9.])
'''

 

 

참고

네이버 커넥트재단 - 부스트캠프

pytorch document - https://pytorch.org/tutorials/beginner/text_sentiment_ngrams_tutorial.html