• 如果您觉得本站非常有看点,那么赶紧使用Ctrl+D 收藏吧

pytorch中对tensor操作:分片、索引、压缩、扩充、交换维度、拼接、切割

互联网 diligentman 3天前 3次浏览

pytorch中对tensor操作:分片、索引、压缩、扩充、交换维度、拼接、切割

  • 1 根据维度提取子集
  • 2 对数据进行压缩和扩充:torch.squeeze() 和torch.unsqueeze()
  • 3 对数据维度进行交换:tensor.permute()
  • 4 对数据进行拼接:torch.cat(), torch.stack()
  • 5 对数据进行切割:torch.split()

1 根据维度提取子集

1.0 原始数据情况

import torch
#### 先看一下原始数据
a = torch.tensor([[[1,2,3,4],[5,6,7,8],[9,10,11,12]],
                  [[-1,-2,-3,-4],[-5,-6,-7,-8],[-9,-10,-11,-12]]], dtype=float)
print(a)
# 每个print下面的内容是输出,这里是一个2*3*4的三维矩阵
tensor([[[  1.,   2.,   3.,   4.],
         [  5.,   6.,   7.,   8.],
         [  9.,  10.,  11.,  12.]],

        [[ -1.,  -2.,  -3.,  -4.],
         [ -5.,  -6.,  -7.,  -8.],
         [ -9., -10., -11., -12.]]], dtype=torch.float64)

1.1 根据第一个维度提取一个子集

#### 根据第一个维度提取第一个元素,结果是一个3*4的矩阵
print(a[0]) 
tensor([[ 1.,  2.,  3.,  4.],
        [ 5.,  6.,  7.,  8.],
        [ 9., 10., 11., 12.]], dtype=torch.float64)

#### 根据第一个维度提取前两个元素,结果是一个2*3*4的矩阵,其实等于a,因为a在第一个维度也就两个元素
print(a[0:2])
tensor([[[  1.,   2.,   3.,   4.],
         [  5.,   6.,   7.,   8.],
         [  9.,  10.,  11.,  12.]],

        [[ -1.,  -2.,  -3.,  -4.],
         [ -5.,  -6.,  -7.,  -8.],
         [ -9., -10., -11., -12.]]], dtype=torch.float64)

1.2 根据前两个维度提取一个子集

#### 方法1.2.1
#### 提取第一个维度的第2个元素,再从中提取第二个维度的第3个元素,
#### 结果是一个向量
print(a[1,2])
tensor([ -9., -10., -11., -12.], dtype=torch.float64)

#### 方法1.2.2
#### 提取第一个维度的前两个元素,再从中提取第二个维度的1:2维度(也就是第1个元素)
#### 结果是一个2*1*4的矩阵
print(a[0:2,1:2])
tensor([[[ 5.,  6.,  7.,  8.]],

        [[-5., -6., -7., -8.]]], dtype=torch.float64)

#### 方法1.2.3
#### 提取第一个维度的前两个元素,再从中提取第二个维度的第1个元素
#### 注意结果是一个2*4的矩阵
print(a[0:2,1])
tensor([[ 5.,  6.,  7.,  8.],
        [-5., -6., -7., -8.]], dtype=torch.float64)

#### 方法1.2.4
#### 提取第二个维度的第3个元素,其他维度的元素全部提取
print(a[:,2]) # 同 print(a[:,2,:]),也就是维度比第二个维度大的下标可以忽略,默认全部提取
tensor([[  9.,  10.,  11.,  12.],
        [ -9., -10., -11., -12.]], dtype=torch.float64)

注意:上面的方法1.2.1最后的维数是1维,和原始数据比下降了两个维度。方法1.2.2和1.2.3想要获得的数据是一致的,但是维度不同。方法1.2.3下降了一个维度。从上面我们可以发现,用n个固定的标量来作为下标,会使得结果比原始数据降低n个维数。比如方法1.2.1种有两个固定标量(1和2),所以维数从3维下降成1维。方法1.2.3种有一个固定标量(1),所以下降了两维。方法1.2.4种有一个固定标量(2),所以下降了一维。

1.3 提取某个特定的元素的值

#### 提取第一个维度的第1个元素,再从中提取第二个维度的第3个元素,再从中提取第三个维度的第2个元素
print(a[0,2,1])
tensor(10., dtype=torch.float64)

#### 将这个值从tensor变量转成python中的数值变量
print(a[0,2,1].item())
10.0

2 对数据进行压缩和扩充:torch.squeeze() 和torch.unsqueeze()

2.1 squeeze()将元素个数只有1的维度压缩掉

#### 先看一下b长什么样,是一个2*1*4的3维矩阵
b = a[:,1:2]
print(b)
tensor([[[ 5.,  6.,  7.,  8.]],

        [[-5., -6., -7., -8.]]], dtype=torch.float64)

#### 将第二个维度压缩掉,因为第二个维度的元素个数只有1,所以可以压缩
c = b.squeeze(1) # 等价于 c = torch.squeeze(b,1)
print(c) # 看一下压缩后的结果,是一个2*4的矩阵
tensor([[ 5.,  6.,  7.,  8.],
        [-5., -6., -7., -8.]], dtype=torch.float64)

print(b) # 发现b没有变化,也就是torch.squeeze()会返回一个tensor,而不是inplace的操作
tensor([[[ 5.,  6.,  7.,  8.]],

        [[-5., -6., -7., -8.]]], dtype=torch.float64)
# tensor.squeeze_()是inplace操作
b.squeeze_(1)
print(b)
tensor([[ 5.,  6.,  7.,  8.],
        [-5., -6., -7., -8.]], dtype=torch.float64)

#### 将b中所有只有一个元素的维度都压缩
b = a[:,1:2,2:3]
print(b)
tensor([[[ 7.]],

        [[-7.]]], dtype=torch.float64)
print(b.squeeze())
tensor([ 7., -7.], dtype=torch.float64)

2.2 unsqueeze()对数据进行扩充维度

#### 先看一下数据的情况
b = a[0:2,1]
print(b)
tensor([[ 5.,  6.,  7.,  8.],
        [-5., -6., -7., -8.]], dtype=torch.float64)

print(b.size())
torch.Size([2, 4])

#### 在第一个维度进行扩充
print(b.unsqueeze(0)) # 等价于print(torch.unsqueeze(b,0))
tensor([[[ 5.,  6.,  7.,  8.],
         [-5., -6., -7., -8.]]], dtype=torch.float64)

print(b.unsqueeze(0).size()) 
torch.Size([1, 2, 4])

#### 在第二个维度进行扩充
# 等于将方法1.2.3的结果在第2个维度上进行扩充,变成和方法1.2.2是一样的结果
print(b.unsqueeze(1))
tensor([[[ 5.,  6.,  7.,  8.]],

        [[-5., -6., -7., -8.]]], dtype=torch.float64)
print(b.unsqueeze(1).size())
torch.Size([2, 1, 4])

#### 在第三个维度进行扩充
print(b.unsqueeze(2))
tensor([[[ 5.],
         [ 6.],
         [ 7.],
         [ 8.]],

        [[-5.],
         [-6.],
         [-7.],
         [-8.]]], dtype=torch.float64)
print(torch.unsqueeze(b,2))
torch.Size([2, 4, 1])

3 对数据维度进行交换:tensor.permute()

permute可以对数据维度进行交换,数据本身不变

#### 先看一下原始数据
print(a)
tensor([[[  1.,   2.,   3.,   4.],
         [  5.,   6.,   7.,   8.],
         [  9.,  10.,  11.,  12.]],

        [[ -1.,  -2.,  -3.,  -4.],
         [ -5.,  -6.,  -7.,  -8.],
         [ -9., -10., -11., -12.]]], dtype=torch.float64)
print(a.size())
torch.Size([2, 3, 4])

#### 将原始数据a的第2维给新数据b的第1维;第1维给第二维;第3维给第3维
b= a.permute(1,0,2)
print(b)
tensor([[[  1.,   2.,   3.,   4.],
         [ -1.,  -2.,  -3.,  -4.]],

        [[  5.,   6.,   7.,   8.],
         [ -5.,  -6.,  -7.,  -8.]],

        [[  9.,  10.,  11.,  12.],
         [ -9., -10., -11., -12.]]], dtype=torch.float64)
print(b.size())
torch.Size([3, 2, 4])

#### a本身不变         
print(a)
tensor([[[  1.,   2.,   3.,   4.],
         [  5.,   6.,   7.,   8.],
         [  9.,  10.,  11.,  12.]],

        [[ -1.,  -2.,  -3.,  -4.],
         [ -5.,  -6.,  -7.,  -8.],
         [ -9., -10., -11., -12.]]], dtype=torch.float64)

4 对数据进行拼接:torch.cat(), torch.stack()

4.1 cat
指定维度,利用cat对多个数据进行拼接,拼接前后的总维数不变

#### 先看一下数据
a = torch.tensor([[1,2,3,4],[5,6,7,8]])
b = torch.tensor([[9,10,11,12],[13,14,15,16]])
print(a)
tensor([[1, 2, 3, 4],
        [5, 6, 7, 8]])
print(b)
tensor([[ 9, 10, 11, 12],
        [13, 14, 15, 16]])

#### 根据第一维度拼接
print(torch.cat((a,b),0)) # 等价于print(torch.cat((a,b))
tensor([[ 1,  2,  3,  4],
        [ 5,  6,  7,  8],
        [ 9, 10, 11, 12],
        [13, 14, 15, 16]])

#### 根据第二维度拼接
print(torch.cat((a,b),1))
tensor([[ 1,  2,  3,  4,  9, 10, 11, 12],
        [ 5,  6,  7,  8, 13, 14, 15, 16]])

#### 还可以拼接多个
print(torch.cat((a,b,a))
tensor([[ 1,  2,  3,  4],
        [ 5,  6,  7,  8],
        [ 9, 10, 11, 12],
        [13, 14, 15, 16],
        [ 1,  2,  3,  4],
        [ 5,  6,  7,  8]])

4.2 stack
指定维度,对多个数据进行拼接,拼接后总维数增加1

#### 按照第一个维度堆叠
print(torch.stack((a,b),0)) # 等价于torch.stack((a,b))
tensor([[[ 1,  2,  3,  4],
         [ 5,  6,  7,  8]],

        [[ 9, 10, 11, 12],
         [13, 14, 15, 16]]])
print(torch.stack((a,b)).size()) # 2维变成3维
torch.Size([2, 2, 4])

#### 按照第二个维度堆叠
print(torch.stack((a,b),1))
tensor([[[ 1,  2,  3,  4],
         [ 9, 10, 11, 12]],

        [[ 5,  6,  7,  8],
         [13, 14, 15, 16]]])
         
#### 按照第三个维度堆叠
print(torch.stack((a,b),2))
tensor([[[ 1,  9],
         [ 2, 10],
         [ 3, 11],
         [ 4, 12]],

        [[ 5, 13],
         [ 6, 14],
         [ 7, 15],
         [ 8, 16]]])

#### 和cat一样,可以对多个数据进行stack
print(torch.stack((a,b,a),2))
tensor([[[ 1,  9,  1],
         [ 2, 10,  2],
         [ 3, 11,  3],
         [ 4, 12,  4]],

        [[ 5, 13,  5],
         [ 6, 14,  6],
         [ 7, 15,  7],
         [ 8, 16,  8]]])

5 对数据进行切割:torch.split()

利用split对数据进行切割,split的第二个参数可以是一个数字也可以是一个list,第三个参数是维度。切割后的数据维度和原始数据一致

#### 先看一下数据
a = torch.arange(1,16).reshape(5,3)
print(a)
tensor([[ 1,  2,  3],
        [ 4,  5,  6],
        [ 7,  8,  9],
        [10, 11, 12],
        [13, 14, 15]])
        
#### 均匀切割。根据第一个维度将a切割,每块包含2个元素,最后不足的就有多少输出多少
x = torch.split(a,2,0) # 获得3块结果,每块结果的维度和原始数据一致
print(x[0])
tensor([[1, 2, 3],
        [4, 5, 6]])

print(x[1])
tensor([[ 7,  8,  9],
        [10, 11, 12]])

print(x[2]) # 因为最后一块数据不足,所以只有一行,而不是两行
tensor([[13, 14, 15]])

#### 均匀切割。根据第二个维度进行切割,每块包含2个元素
x = torch.split(a,2,1)
print(x[0])
tensor([[ 1,  2],
        [ 4,  5],
        [ 7,  8],
        [10, 11],
        [13, 14]])

print(x[1])
tensor([[ 3],
        [ 6],
        [ 9],
        [12],
        [15]])

#### 自定义切割。根据第二个维度切割,一共切割成两块,第一个块包含1个元素(也就是1列),第二块包含2个元素(也就是2列)
x = torch.split(a,[1,2],1)
print(x[0])
tensor([[ 1],
        [ 4],
        [ 7],
        [10],
        [13]])

print(x[1])
tensor([[ 2,  3],
        [ 5,  6],
        [ 8,  9],
        [11, 12],
        [14, 15]])


喜欢 (0)