
# Differences between torch.cat() and torch.stack()

• 1. torch.cat()
• 2. torch.stack()

# 1. torch.cat()

```python
torch.cat(tensors, dim=0)
```

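• `tensors`: a sequence of tensors. They must all have the same shape except in the concatenation dimension `dim`.
• `dim`: the dimension along which the tensors are concatenated (default 0). The result has the same number of dimensions as the inputs.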
```python
import torch

a = torch.rand(2, 3)
b = torch.rand(2, 3)
c = torch.cat((a, b))
print(a.size(), b.size(), c.size())
```
```
torch.Size([2, 3]) torch.Size([2, 3]) torch.Size([4, 3])
```
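The shape rule behind this output can be checked directly. The following is a minimal sketch with shapes chosen arbitrarily for illustration. The inputs may differ in size along the concatenation dimension (here dim 0), and the result's size there is the sum of the inputs' sizes. `torch.cat` also accepts a list instead of a tuple.

```python
import torch

# Illustrative shapes: all inputs share dim 1 (size 3) but differ in dim 0
parts = [torch.rand(1, 3), torch.rand(2, 3), torch.rand(4, 3)]
joined = torch.cat(parts, dim=0)

# Size along dim 0 is the sum of the inputs' sizes; dim 1 is unchanged
print(joined.size())  # torch.Size([7, 3])
assert joined.size(0) == sum(p.size(0) for p in parts)
```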

```python
d = torch.rand(2, 4)
print(torch.cat((a, d)))
```
```
RuntimeError: Sizes of tensors must match except in dimension 0. Expected size 3 but got size 4 for tensor number 1 in the list.
```
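The error occurs because the two tensors differ in dim 1, which is not the dimension being concatenated. Concatenating the same pair along dim 1 is valid, because both tensors have 2 rows (they agree in dim 0). The sketch below recreates `a` and `d` with the same shapes as above:

```python
import torch

a = torch.rand(2, 3)
d = torch.rand(2, 4)

# Both tensors have 2 rows, so they can be joined side by side along dim 1
e = torch.cat((a, d), dim=1)
print(e.size())  # torch.Size([2, 7])

# Slicing along dim 1 recovers the original pieces
assert torch.equal(e[:, :3], a)
assert torch.equal(e[:, 3:], d)
```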

```python
print(a)
print(torch.cat((a, a, a), dim=0))
print(torch.cat((a, a, a), dim=1))
```
```
tensor([[0.2381, 0.7100, 0.8150],
        [0.5190, 0.5829, 0.9186]])
tensor([[0.2381, 0.7100, 0.8150],
        [0.5190, 0.5829, 0.9186],
        [0.2381, 0.7100, 0.8150],
        [0.5190, 0.5829, 0.9186],
        [0.2381, 0.7100, 0.8150],
        [0.5190, 0.5829, 0.9186]])
tensor([[0.2381, 0.7100, 0.8150, 0.2381, 0.7100, 0.8150, 0.2381, 0.7100, 0.8150],
        [0.5190, 0.5829, 0.9186, 0.5190, 0.5829, 0.9186, 0.5190, 0.5829, 0.9186]])
```
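Concatenating copies of one tensor is the same as tiling it. The two outputs above could therefore also be produced with `Tensor.repeat`, which takes one repeat count per dimension. Negative `dim` values count from the end, so for a 2-D tensor `dim=-1` means the same as `dim=1`. Here is a quick check on a random `a` of the same shape:

```python
import torch

a = torch.rand(2, 3)

# Three copies stacked vertically: same as tiling 3 times along dim 0
assert torch.equal(torch.cat((a, a, a), dim=0), a.repeat(3, 1))

# Three copies side by side: same as tiling 3 times along dim 1
assert torch.equal(torch.cat((a, a, a), dim=1), a.repeat(1, 3))

# For a 2-D tensor, dim=-1 refers to the last dimension, i.e. dim=1
assert torch.equal(torch.cat((a, a, a), dim=-1), torch.cat((a, a, a), dim=1))
```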

# 2. torch.stack()

```python
torch.stack(tensors, dim=0)
```

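• `tensors`: a sequence of tensors, all with exactly the same shape.
• `dim`: the position at which the new dimension is inserted (default 0). It must be between 0 and the number of dimensions of the input tensors, inclusive. The result has one more dimension than the inputs.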
```python
import torch

a = torch.rand((2, 3))
b = torch.rand((2, 3))
c = torch.stack((a, b))
print(a.size(), b.size(), c.size())
```
```
torch.Size([2, 3]) torch.Size([2, 3]) torch.Size([2, 2, 3])
```
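This is the key difference between the two functions. `torch.cat` joins tensors along an existing dimension. `torch.stack` first adds a new dimension of size 1 to each input at position `dim`, then concatenates along that new dimension. The equivalence can be checked with `unsqueeze`. The helper name `stack_via_cat` is hypothetical and exists only for this illustration.

```python
import torch

def stack_via_cat(tensors, dim=0):
    # Hypothetical helper: insert a size-1 axis at `dim`, then concatenate along it
    return torch.cat([t.unsqueeze(dim) for t in tensors], dim=dim)

a = torch.rand(2, 3)
b = torch.rand(2, 3)

for dim in range(a.dim() + 1):  # valid positions for the new axis: 0, 1, 2
    assert torch.equal(torch.stack((a, b), dim=dim), stack_via_cat((a, b), dim=dim))

print(torch.stack((a, b)).size())  # torch.Size([2, 2, 3])
```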

```python
d = torch.rand(2, 4)
print(torch.stack((a, d)))
```
```
RuntimeError: stack expects each tensor to be equal size, but got [2, 3] at entry 0 and [2, 4] at entry 1
```
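`torch.stack` requires every input to have exactly the same shape. This makes it stricter than `torch.cat`, which only requires the inputs to agree outside the concatenation dimension. In the sketch below (shapes chosen for illustration), a 2×3 tensor and a 3×3 tensor can be concatenated along dim 0 but cannot be stacked:

```python
import torch

a = torch.rand(2, 3)
f = torch.rand(3, 3)  # differs from `a` only in dim 0

# cat along dim 0 works: only dim 1 has to match
print(torch.cat((a, f)).size())  # torch.Size([5, 3])

# stack fails: every input must have exactly the same shape
try:
    torch.stack((a, f))
    stack_failed = False
except RuntimeError as err:
    stack_failed = True
    print("RuntimeError:", err)
```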

```python
x = torch.arange(1, 7).reshape((3, 2))
y = torch.arange(10, 70, 10).reshape((3, 2))
z = torch.arange(100, 700, 100).reshape((3, 2))
print(x)
print(y)
print(z)
```
```
tensor([[1, 2],
        [3, 4],
        [5, 6]])
tensor([[10, 20],
        [30, 40],
        [50, 60]])
tensor([[100, 200],
        [300, 400],
        [500, 600]])
```
```python
m = torch.stack((x, y, z))
print(m)
```
```
tensor([[[  1,   2],
         [  3,   4],
         [  5,   6]],

        [[ 10,  20],
         [ 30,  40],
         [ 50,  60]],

        [[100, 200],
         [300, 400],
         [500, 600]]])
```
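With the default `dim=0`, the new axis comes first, so indexing the result along dim 0 returns the inputs in order. The sketch below rebuilds `x`, `y`, and `z` as above and checks this:

```python
import torch

x = torch.arange(1, 7).reshape((3, 2))
y = torch.arange(10, 70, 10).reshape((3, 2))
z = torch.arange(100, 700, 100).reshape((3, 2))

m = torch.stack((x, y, z))
print(m.size())  # torch.Size([3, 3, 2])

# m[k] is the k-th input tensor
for k, t in enumerate((x, y, z)):
    assert torch.equal(m[k], t)
```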
```python
n = torch.stack((x, y, z), 1)
print(n)
```
```
tensor([[[  1,   2],
         [ 10,  20],
         [100, 200]],

        [[  3,   4],
         [ 30,  40],
         [300, 400]],

        [[  5,   6],
         [ 50,  60],
         [500, 600]]])
```
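With `dim=1`, the new axis sits in the middle. `n[i]` collects row `i` of every input, so `n[i, k]` is row `i` of the `k`-th tensor, and the slice `n[:, k]` recovers the `k`-th tensor itself. The sketch below checks both readings:

```python
import torch

x = torch.arange(1, 7).reshape((3, 2))
y = torch.arange(10, 70, 10).reshape((3, 2))
z = torch.arange(100, 700, 100).reshape((3, 2))
inputs = (x, y, z)

n = torch.stack(inputs, 1)
print(n.size())  # torch.Size([3, 3, 2])

for k, t in enumerate(inputs):
    # Slicing along the new axis (dim 1) returns the k-th input
    assert torch.equal(n[:, k], t)
    for i in range(t.size(0)):
        # n[i, k] is row i of the k-th input
        assert torch.equal(n[i, k], t[i])
```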
```python
h = torch.stack((x, y, z), 2)
print(h)
```
```
tensor([[[  1,  10, 100],
         [  2,  20, 200]],

        [[  3,  30, 300],
         [  4,  40, 400]],

        [[  5,  50, 500],
         [  6,  60, 600]]])
```
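In general, for any valid `dim`, the output shape is the input shape with `len(tensors)` inserted at position `dim`. Selecting index `k` along `dim` returns the `k`-th input. A negative `dim` counts from the end of the *output* shape, so for these 2-D inputs `dim=-1` is the same as `dim=2`. The sketch below checks every position:

```python
import torch

x = torch.arange(1, 7).reshape((3, 2))
y = torch.arange(10, 70, 10).reshape((3, 2))
z = torch.arange(100, 700, 100).reshape((3, 2))
inputs = (x, y, z)

for dim in range(x.dim() + 1):  # 0, 1, 2
    s = torch.stack(inputs, dim)
    # Expected shape: input shape with len(inputs) inserted at position `dim`
    expected = list(x.shape)
    expected.insert(dim, len(inputs))
    assert list(s.shape) == expected
    # Selecting index k along the new axis returns the k-th input
    for k, t in enumerate(inputs):
        assert torch.equal(s.select(dim, k), t)

# dim=-1 counts from the end of the output shape, i.e. dim=2 here
h = torch.stack(inputs, 2)
assert torch.equal(torch.stack(inputs, -1), h)
print(h.size())  # torch.Size([3, 2, 3])
```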