
DataLoader에는 여러 파라미터가 있어 필요시 적절한 파라미터를 활용해 여러 설정을 줄 수 있다. 그 중에서도 collate_fn은 variable length가 달라서, 패딩해줄 때 사용한다.
from torch.utils.data import Dataset, DataLoader
import torch
class ExampleDataset(Dataset):
def __init__(self, num):
self.num = num
def __len__(self):
return self.num
def __getitem__(self, idx):
return {"X":torch.tensor([idx] * (idx+1), dtype=torch.float32),
"y": torch.tensor(idx, dtype=torch.float32)}
dataset_example = ExampleDataset(num = 10)
dataloader_example = torch.utils.data.DataLoader(dataset_example, batch_size= 1)
for d in dataloader_example:
print('X : ',d['X'])
'''
X : tensor([[0.]])
X : tensor([[1., 1.]])
X : tensor([[2., 2., 2.]])
X : tensor([[3., 3., 3., 3.]])
X : tensor([[4., 4., 4., 4., 4.]])
X : tensor([[5., 5., 5., 5., 5., 5.]])
X : tensor([[6., 6., 6., 6., 6., 6., 6.]])
X : tensor([[7., 7., 7., 7., 7., 7., 7., 7.]])
X : tensor([[8., 8., 8., 8., 8., 8., 8., 8., 8.]])
X : tensor([[9., 9., 9., 9., 9., 9., 9., 9., 9., 9.]])
'''
예를 들어 위와 같이 input의 길이가 다른 Dataset이 있을 때, print할 경우 Batch size가 1이라면 문제가 없다.
dataloader_example = torch.utils.data.DataLoader(dataset_example, batch_size = 2)
for d in dataloader_example:
print(d['X'])
'''
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-11-449346496a83> in <module>()
1 dataloader_example = torch.utils.data.DataLoader(dataset_example, batch_size = 2)
----> 2 for d in dataloader_example:
3 print(d['X'])
5 frames
/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/collate.py in default_collate(batch)
136 storage = elem.storage()._new_shared(numel)
137 out = elem.new(storage).resize_(len(batch), *list(elem.size()))
--> 138 return torch.stack(batch, 0, out=out)
139 elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
140 and elem_type.__name__ != 'string_':
RuntimeError: stack expects each tensor to be equal size, but got [1] at entry 0 and [2] at entry 1
'''
그러나 위와 같이 Batch Size를 2로 설정하면, 에러가 발생한다. 이는 같은 배치 안의 input(X)의 길이가 다르기 때문이다. 이 에러를 해결하기 위해 input의 길이를 동일하게 맞춰주어야 함으로, collate_fn을 사용하여 같은 배치 안에 길이가 가장 긴 input에 맞춰 다른 input들에 임의로 0값을 넣었다. (Zero padding)
def my_collate_fn(samples):
collate_X = []
collate_y = []
max_len = max([len(sample['X']) for sample in samples])
for sample in samples:
diff = max_len-len(sample['X'])
if diff > 0:
zero_pad = torch.zeros(size=(diff,))
collate_X.append(torch.cat([sample['X'], zero_pad], dim=0))
else:
collate_X.append(sample['X'])
collate_y = [sample['y'] for sample in samples]
return {'X': torch.stack(collate_X),
'y': torch.stack(collate_y)}
dataloader_example = torch.utils.data.DataLoader(dataset_example,
batch_size=2,
collate_fn=my_collate_fn)
for d in dataloader_example:
print(d['X'], d['y'])
'''
tensor([[0., 0.],
[1., 1.]]) tensor([0., 1.])
tensor([[2., 2., 2., 0.],
[3., 3., 3., 3.]]) tensor([2., 3.])
tensor([[4., 4., 4., 4., 4., 0.],
[5., 5., 5., 5., 5., 5.]]) tensor([4., 5.])
tensor([[6., 6., 6., 6., 6., 6., 6., 0.],
[7., 7., 7., 7., 7., 7., 7., 7.]]) tensor([6., 7.])
tensor([[8., 8., 8., 8., 8., 8., 8., 8., 8., 0.],
[9., 9., 9., 9., 9., 9., 9., 9., 9., 9.]]) tensor([8., 9.])
'''
참고
네이버 커넥트재단 - 부스트캠프
pytorch document - https://pytorch.org/tutorials/beginner/text_sentiment_ngrams_tutorial.html
'AI > Pytorch' 카테고리의 다른 글
[Pytorch] 파이토치 모델 정의, 사전 학습 모델 (0) | 2022.08.12 |
---|---|
[Pytorch] 파이토치의 구성요소, 데이터 준비 (0) | 2022.08.12 |