- 1. 概要
- 2. permute・transpose 次元の並べ替え
- 3. reshape・view 次元・要素を変更
1. 概要
前ページで、「interpolate」を用いて、「tensor」をリサイズする方法についてメモしました。
「interpolate」は、リサイズすると同時に、間を穴埋めするような方式ですが、別のやり方もあるようです。
「【PyTorch 入門】PyTorch の次元操作 permute, transpose, reshape, view って何が違うの?」
2. permute・transpose 次元の並べ替え
サイズ自体は、変わりませんが、「permute」「transpose」というメソッドで、次元の並べ替えを行うことができます。
import os
import torch
x = torch.tensor([[[0.0, 1.0],[2.0, 3.0], [ 4.0, 5.0 ]]])
print(x.shape)
print(x, os.linesep)
x_permute = torch.permute(x, (2, 0, 1))
print(x_permute.shape)
print(x_permute)
というソースを書いて、実行すると下記の結果が得られます。
torch.Size([1, 3, 2])
tensor([[[0., 1.],
[2., 3.],
[4., 5.]]])
torch.Size([2, 1, 3])
tensor([[[0., 2., 4.]],
[[1., 3., 5.]]])
どことどこの次元が入れ替わったのか、ちとわかりにくいですが。
3次元のものを 0、1、2 が、それぞれの次元とすると、2、0、1 の順に入れ替えているわけです。
「transpose」も、時点の並べ替えを行います。
import os
import torch
x = torch.tensor([[[0.0, 1.0], [2.0, 3.0], [ 4.0, 5.0 ]]])
print(x.shape)
print(x, os.linesep)
x_transpose = torch.transpose(x, 2, 1)
print(x_transpose.shape)
print(x_transpose)
というソースを書いて、実行すると下記の結果が得られます。
torch.Size([1, 3, 2])
tensor([[[0., 1.],
[2., 3.],
[4., 5.]]])
torch.Size([1, 2, 3])
tensor([[[0., 2., 4.],
[1., 3., 5.]]])
3. reshape・view 次元・要素を変更
「reshape」「view」というメソッドで、次元・要素を変更できます。
import os
import torch
x = torch.tensor([[[ 0.0, 1.0 ], [ 2.0, 3.0 ], [ 4.0, 5.0 ]]])
print(x.shape)
print(x, os.linesep)
x_reshape = torch.reshape(x, (1, 1, 2, 3))
print(x_reshape.shape)
print(x_reshape)
というソースを書いて、実行すると下記の結果が得られます。
torch.Size([1, 3, 2])
tensor([[[0., 1.],
[2., 3.],
[4., 5.]]])
torch.Size([1, 1, 2, 3])
tensor([[[[0., 1., 2.],
[3., 4., 5.]]]])
次元や要素が変わっています。
気を付けなければならないのは、すべての次元の要素数を掛け合わせた数を同じにしなければ、エラーになるそうです。
「view」について、調べます。
import os
import torch
x = torch.tensor([[ 0.0, 1.0 ], [ 2.0, 3.0 ], [ 4.0, 5.0 ]])
print(x.shape)
print(x, os.linesep)
x_viewe = torch.t(x).view(2, 3)
print(x_viewe.shape)
print(x_viewe)
というソースを書いて、実行すると下記の結果が得られます。
torch.Size([3, 2])
tensor([[0., 1.],
[2., 3.],
[4., 5.]])
torch.Size([2, 3])
tensor([[0., 2., 4.],
[1., 3., 5.]])
「view」は、二次元でないとまずいようで。
三次元でやろうとしたら、下記のようなエラーになりました。
x_viewe = torch.t(x).view(1, 2, 3)
^^^^^^^^^^
RuntimeError: t() expects a tensor with <= 2 dimensions, but self is 3D
少なくとも、今のところ、「view」よりは、「reshape」の方が使い勝手がいいように思います。
|