본문 바로가기
딥러닝&머신러닝/파이토치 기본 문법

[PyTorch] Tensor 합치기: cat(), stack()

by David.Ho 2023. 1. 9.
728x90
반응형

 

실험이 돌아가는 동안 심심하니까 하는 포스팅. PyTorch에서 tensor를 합치는 2가지 방법이 있는데 cat과 stack이다. 두가지는 현재 차원의 수를 유지하느냐 확장하느냐의 차이가 있다. 그림과 코드를 통해 사용법을 알아보자.

Cat함수란?

cat함수는 concatenate를 해주는 함수이고 concatenate하고자 하는 차원을 증가시킨다 (차원의 수는 유지된다). concatenate하고자하는 차원을 지정해주면 그 차원으로 두 tensor의 차원을 더한 값으로 차원이 변경된다. concatenate하고자하는 dimension을 지정해주지 않으면 default=0으로 설정된다.

Cat함수의 시각화

Python 코드

import torch

batch_size, N, K = 3, 10, 256

x = torch.rand(batch_size, N, K) # [M, N, K]
y = torch.rand(batch_size, N, K) # [M, N, K]

output1 = torch.cat([x,y], dim=1) #[M, N+N, K]
output2 = torch.cat([x,y], dim=2) #[M, N, K+K]

Stack함수란?

stack함수는 지정하는 차원으로 확장하여 tensor를 쌓아주는 함수이다. (지정하는 차원에 새로운 차원이 생긴다=차원의 수가 증가한다) tensor를 쌓아주는 함수이기 때문에 두 tensor의 차원이 정확히 일치해야 쌓을 수 있다. stack 하고자하는 dimension을 지정해주지 않으면 default=0으로 설정된다.

Stack함수의 시각화

Python 코드

import torch

batch_size, N, K = 3, 10, 256

x = torch.rand(batch_size, N, K) # [M, N, K]
y = torch.rand(batch_size, N, K) # [M, N, K]

output = torch.stack([x,y], dim=1) #[M, 2, N, K]

Cat 함수 활용: Tensor list를 한번에 tensor로 만들기

import torch

#(....중략)

out_list = []
for data in datas:
    out = model(data)
    out_list.append(out)
output = torch.cat(out_list, 0)
# same as --> output = torch.cat(out_list, dim=0) 

# 참고, numpy로 변환
output_np = output.detach().cpu().numpy()

Stack 함수 활용: Tensor list를 한번에 tensor로 만들기

import torch

#(....중략)

out_list = []
for data in datas:
    out = model(data)
    out_list.append(out)
output = torch.stack(out_list, 0)
728x90
반응형

댓글