Python - AI - Tensor - リサイズ(reshape)

クラウディア 
1. 概要
2. permute・transpose 次元の並べ替え
3. reshape・view 次元・要素を変更

1. 概要

 前ページで、「interpolate」を用いて、「tensor」をリサイズする方法についてメモしました。  「interpolate」は、リサイズすると同時に、間を穴埋めするような方式ですが、別のやり方もあるようです。
【PyTorch 入門】PyTorch の次元操作 permute, transpose, reshape, view って何が違うの?

2. permute・transpose 次元の並べ替え

 サイズ自体は、変わりませんが、「permute」「transpose」というメソッドで、次元の並べ替えを行うことができます。

import os
import torch

x = torch.tensor([[[0.0, 1.0],[2.0, 3.0], [ 4.0, 5.0 ]]])
print(x.shape)
print(x, os.linesep)

x_permute = torch.permute(x, (2, 0, 1))
print(x_permute.shape)
print(x_permute)
 というソースを書いて、実行すると下記の結果が得られます。

torch.Size([1, 3, 2])
tensor([[[0., 1.],
         [2., 3.],
         [4., 5.]]])

torch.Size([2, 1, 3])
tensor([[[0., 2., 4.]],

        [[1., 3., 5.]]])
 どことどこの次元が入れ替わったのか、ちとわかりにくいですが。  3次元のものを 0、1、2 が、それぞれの次元とすると、2、0、1 の順に入れ替えているわけです。  「transpose」も、時点の並べ替えを行います。

import os
import torch

x = torch.tensor([[[0.0, 1.0], [2.0, 3.0], [ 4.0, 5.0 ]]])
print(x.shape)
print(x, os.linesep)

x_transpose = torch.transpose(x, 2, 1)
print(x_transpose.shape)
print(x_transpose)
 というソースを書いて、実行すると下記の結果が得られます。

torch.Size([1, 3, 2])
tensor([[[0., 1.],
         [2., 3.],
         [4., 5.]]])

torch.Size([1, 2, 3])
tensor([[[0., 2., 4.],
         [1., 3., 5.]]])

3. reshape・view 次元・要素を変更

 「reshape」「view」というメソッドで、次元・要素を変更できます。

import os
import torch

x = torch.tensor([[[ 0.0, 1.0 ], [ 2.0, 3.0 ], [ 4.0, 5.0 ]]])
print(x.shape)
print(x, os.linesep)

x_reshape = torch.reshape(x, (1, 1, 2, 3))
print(x_reshape.shape)
print(x_reshape)
 というソースを書いて、実行すると下記の結果が得られます。

torch.Size([1, 3, 2])
tensor([[[0., 1.],
         [2., 3.],
         [4., 5.]]])

torch.Size([1, 1, 2, 3])
tensor([[[[0., 1., 2.],
          [3., 4., 5.]]]])
 次元や要素が変わっています。  気を付けなければならないのは、すべての次元の要素数を掛け合わせた数を同じにしなければ、エラーになるそうです。  「view」について、調べます。

import os
import torch

x = torch.tensor([[ 0.0, 1.0 ], [ 2.0, 3.0 ], [ 4.0, 5.0 ]])
print(x.shape)
print(x, os.linesep)

x_viewe = torch.t(x).view(2, 3)
print(x_viewe.shape)
print(x_viewe)
 というソースを書いて、実行すると下記の結果が得られます。

torch.Size([3, 2])
tensor([[0., 1.],
        [2., 3.],
        [4., 5.]])

torch.Size([2, 3])
tensor([[0., 2., 4.],
        [1., 3., 5.]])
 「view」は、二次元でないとまずいようで。  三次元でやろうとしたら、下記のようなエラーになりました。

    x_viewe = torch.t(x).view(1, 2, 3)
              ^^^^^^^^^^
RuntimeError: t() expects a tensor with <= 2 dimensions, but self is 3D
 少なくとも、今のところ、「view」よりは、「reshape」の方が使い勝手がいいように思います。
AbemaTV 無料体験