[Python] Pytorch实现ListTensor转Tensor,reshape拼接等操作

1948 0
王子 2022-11-6 10:13:07 | 显示全部楼层 |阅读模式
目录

    一、List Tensor转Tensor (torch.cat)
      高维tensor
    二、List Tensor转Tensor (torch.stack)

持续更新一些常用的Tensor操作,比如List,Numpy,Tensor之间的转换,Tensor的拼接,维度的变换等操作。
其它Tensor操作如 einsum等见:待更新。
用到两个函数:
    torch.cattorch.stack

一、List Tensor转Tensor (torch.cat)


  1. // An highlighted block
  2. >>> t1 = torch.FloatTensor([[1,2],[5,6]])
  3. >>> t2 = torch.FloatTensor([3,4],[7,8]])
  4. >>> l = []
  5. >>> l.append(t1)
  6. >>> l.append(t2)
  7. >>> ta = torch.cat(l,dim=0)
  8. >>> ta = torch.cat(l,dim=0).reshape(2,2,2)
  9. >>> tb = torch.cat(l,dim=1).reshape(2,2,2)
  10. >>> ta
  11. tensor([[[1., 2.],
  12.          [5., 6.]],
  13.         [[3., 4.],
  14.          [7., 8.]]])
  15. >>> tb
  16. tensor([[[1., 2.],
  17.          [3., 4.]],
  18.         [[5., 6.],
  19.          [7., 8.]]])
复制代码
高维tensor

** 如果理解了2D to 3DTensor,以此类推,不难理解3D to 4D,看下面代码即可明白:**
  1. >>> t1 = torch.range(1,8).reshape(2,2,2)
  2. >>> t2 = torch.range(11,18).reshape(2,2,2)
  3. >>> l = []
  4. >>> l.append(t1)
  5. >>> l.append(t2)
  6. >>> torch.cat(l,dim=2).reshape(2,2,2,2)
  7. tensor([[[[ 1.,  2.],
  8.           [11., 12.]],
  9.          [[ 3.,  4.],
  10.           [13., 14.]]],
  11.         [[[ 5.,  6.],
  12.           [15., 16.]],
  13.          [[ 7.,  8.],
  14.           [17., 18.]]]])
  15. >>> torch.cat(l,dim=1).reshape(2,2,2,2)
  16. tensor([[[[ 1.,  2.],
  17.           [ 3.,  4.]],
  18.          [[11., 12.],
  19.           [13., 14.]]],
  20.         [[[ 5.,  6.],
  21.           [ 7.,  8.]],
  22.          [[15., 16.],
  23.           [17., 18.]]]])
  24. >>> torch.cat(l,dim=0).reshape(2,2,2,2)
  25. tensor([[[[ 1.,  2.],
  26.           [ 3.,  4.]],
  27.          [[ 5.,  6.],
  28.           [ 7.,  8.]]],
  29.         [[[11., 12.],
  30.           [13., 14.]],
  31.          [[15., 16.],
  32.           [17., 18.]]]])
复制代码
二、List Tensor转Tensor (torch.stack)



代码:
  1. import torch
  2. t1 = torch.FloatTensor([[1,2],[5,6]])
  3. t2 = torch.FloatTensor([[3,4],[7,8]])
  4. l = [t1, t2]
  5. t3 = torch.stack(l, dim=2)
  6. print(t3.shape)
  7. print(t3)
  8. ## output:
  9. ## torch.Size([2, 2, 2])
  10. ## tensor([[[1., 3.],
  11. ##          [2., 4.]],
  12. ##        [[5., 7.],
  13. ##         [6., 8.]]])
复制代码
以上为个人经验,希望能给大家一个参考,也希望大家多多支持中国红客联盟。

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

×
您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

中国红客联盟公众号

联系站长QQ:5520533

admin@chnhonker.com
Copyright © 2001-2025 Discuz Team. Powered by Discuz! X3.5 ( 粤ICP备13060014号 )|天天打卡 本站已运行