Tech/PyTorch

PyTorch Basics

@~@ 2024. 7. 3. 17:24

1. view와 reshape

- view와 reshape은 contiguity 보장의 차이

a = torch.zeros(3, 2)
b = a.view(2, 3)
a.fill_(1)

a = torch.zeros(3, 2)
b = a.t().reshape(6)
a.fill_(1)

 

 

2. squeeze와 unsqueeze

- squeeze: 차원의 개수가 1인 차원을 삭제 

- unsqueeze: 차원의 개수가 1인 차원을 추가

 

3. mm

- 행벡터 * 열벡터 연산(내적)은 dot을 사용

- 단, 행렬*행렬 연산(행렬곱셈)은 mm을 사용한다는 차이가 존재 

- mm: 벡터 간 연산을 지원하지 X 

 

4. matmul의 브로드캐스트

a = torch.rand(5, 2, 3)
b = torch.rand(5)
a.mm(b)

- mm은 위 matrix연산이 불가능함.

 

- 하지만 matmul은 가능 -> matmul이 broadcasting 연산을 지원하기 때문에 가능한 일

a = torch.rand(5, 2, 3)
b = torch.rand(5)
a.matmul(b)

이 때!

위 코드의 torch.rand(5, 2, 3)에서 5는 batch로 보고

아래 코드처럼 mm 연산은 5번 수행하면

같은 결과가 나온다. 

a[0].mm(torch.unsqueeze(b,1))
a[1].mm(torch.unsqueeze(b,1))
a[2].mm(torch.unsqueeze(b,1))
a[3].mm(torch.unsqueeze(b,1))
a[4].mm(torch.unsqueeze(b,1))