본문 바로가기

AI/Pytorch

[Pytorch] 파이토치의 구성요소, 데이터 준비

파이토치의 구성요소

  • torch: 메인 네임스페이스, 텐서 등의 다양한 수학 함수가 포함
  • torch.autograd: 자동 미분 기능을 제공하는 라이브러리
  • torch.nn: 신경망 구축을 위한 데이터 구조나 레이어 등의 라이브러리
  • torch.multiprocessing: 병럴처리 기능을 제공하는 라이브러리
  • torch.optim: SGD(Stochastic Gradient Descent)를 중심으로 한 파라미터 최적화 알고리즘 제공
  • torch.utils: 데이터 조작 등 유틸리티 기능 제공
  • torch.onnx: ONNX(Open Neural Network Exchange), 서로 다른 프레임워크 간의 모델을 공유할 때 사용

데이터 준비

파이토치에서는 데이터 준비를 위해 torch.utils.data의 Dataset과 DataLoader 사용 가능

import torchvision.transforms as transforms
from torchvision import datasets

DataLoader의 인자로 들어갈 transform을 미리 정의할 수 있고, Compose를 통해 리스트 안에 순서대로 전처리 진행

ToTensor()를 하는 이유는 torchvision이 PIL Image 형태로만 입력을 받기 때문에 데이터 처리를 위해서 Tensor형으로 변환 필요

mnist_transform = transforms.Compose([transforms.ToTensor(),
                                      transforms.Normalize(mean=(0.5,),std=(1.0,))])

trainset = datasets.MNIST(root='/content/',
                         train=True, download=True,
                         transform=mnist_transform)
testset = datasets.MNIST(root='/content/',
                         train=False, download=True,
                         transform=mnist_transform)

DataLoader는 데이터 전체를 보관했다가 실제 모델 학습을 할 때 batch_size 크기만큼 데이터를 가져옴

train_loader = DataLoader(trainset, batch_size = 8, shuffle=True, num_workers=2)
test_loader = DataLoader(testset, batch_size = 8, shuffle=False, num_workers=2)

dataiter = iter(train_loader)
images, labels = dataiter.next()
images.shape, labels.shape

(torch.Size([8, 1, 28, 28]), torch.Size([8]))

28x28 이미지 인데, 1은 흑백, 8개(배치사이즈)의 흑백사진

torch_image = torch.squeeze(images[0])
torch_image.shape

torch.Size([28, 28])

import matplotlib.pyplot as plt

figure = plt.figure(figsize=(12,6))
cols, rows = 4, 2
for i in range(1, cols*rows + 1):
  sample_idx = torch.randint(len(trainset), size=(1, )).item()
  img, label = trainset[sample_idx]
  figure.add_subplot(rows, cols, i)
  plt.title(label)
  plt.axis('off')
  plt.imshow(img.squeeze(), cmap='gray')
plt.show()

Reference

이수안컴퓨터연구소 - 파이토치 한번에 끝내기 PyTorch Full Tutorial Course

'AI > Pytorch' 카테고리의 다른 글

Pytorch collate_fn 이란?  (0) 2023.02.27
[Pytorch] 파이토치 모델 정의, 사전 학습 모델  (0) 2022.08.12