파이토치의 구성요소
- torch: 메인 네임스페이스, 텐서 등의 다양한 수학 함수가 포함
- torch.autograd: 자동 미분 기능을 제공하는 라이브러리
- torch.nn: 신경망 구축을 위한 데이터 구조나 레이어 등의 라이브러리
- torch.multiprocessing: 병럴처리 기능을 제공하는 라이브러리
- torch.optim: SGD(Stochastic Gradient Descent)를 중심으로 한 파라미터 최적화 알고리즘 제공
- torch.utils: 데이터 조작 등 유틸리티 기능 제공
- torch.onnx: ONNX(Open Neural Network Exchange), 서로 다른 프레임워크 간의 모델을 공유할 때 사용
데이터 준비
파이토치에서는 데이터 준비를 위해 torch.utils.data의 Dataset과 DataLoader 사용 가능
- Dataset에는 다양한 데이터셋이 존재 (MNIST, FashionMNIST, CIFAR10, ...)
- Vision Dataset: https://pytorch.org/vision/stable/datasets.html
- Text Dataset: https://pytorch.org/text/stable/datasets.html
- Audio Dataset: https://pytorch.org/audio/stable/datasets.html
- DataLoader와 Dataset을 통해 batch_size, train 여부, transform 등을 인자로 넣어 데이터를 어떻게 load할 것인지 정해줄 수 있음
토치비전(torchvision)은 파이토치에서 제공하는 데이터셋들이 모여있는 패키지from torch.utils.data import Dataset, DataLoader
- transforms: 전처리할 때 사용하는 메소드 (https://pytorch.org/docs/stable/torchvision/transforms.html)
- transforms에서 제공하는 클래스 이외는 일반적으로 클래스를 따로 만들어 전처리 단계를 진행
import torchvision.transforms as transforms
from torchvision import datasets
DataLoader의 인자로 들어갈 transform을 미리 정의할 수 있고, Compose를 통해 리스트 안에 순서대로 전처리 진행
ToTensor()를 하는 이유는 torchvision이 PIL Image 형태로만 입력을 받기 때문에 데이터 처리를 위해서 Tensor형으로 변환 필요
mnist_transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize(mean=(0.5,),std=(1.0,))])
trainset = datasets.MNIST(root='/content/',
train=True, download=True,
transform=mnist_transform)
testset = datasets.MNIST(root='/content/',
train=False, download=True,
transform=mnist_transform)
DataLoader는 데이터 전체를 보관했다가 실제 모델 학습을 할 때 batch_size 크기만큼 데이터를 가져옴
train_loader = DataLoader(trainset, batch_size = 8, shuffle=True, num_workers=2)
test_loader = DataLoader(testset, batch_size = 8, shuffle=False, num_workers=2)
dataiter = iter(train_loader)
images, labels = dataiter.next()
images.shape, labels.shape
(torch.Size([8, 1, 28, 28]), torch.Size([8]))
28x28 이미지 인데, 1은 흑백, 8개(배치사이즈)의 흑백사진
torch_image = torch.squeeze(images[0])
torch_image.shape
torch.Size([28, 28])
import matplotlib.pyplot as plt
figure = plt.figure(figsize=(12,6))
cols, rows = 4, 2
for i in range(1, cols*rows + 1):
sample_idx = torch.randint(len(trainset), size=(1, )).item()
img, label = trainset[sample_idx]
figure.add_subplot(rows, cols, i)
plt.title(label)
plt.axis('off')
plt.imshow(img.squeeze(), cmap='gray')
plt.show()
Reference
'AI > Pytorch' 카테고리의 다른 글
Pytorch collate_fn 이란? (0) | 2023.02.27 |
---|---|
[Pytorch] 파이토치 모델 정의, 사전 학습 모델 (0) | 2022.08.12 |