Pytorch expand和repeat
admin
2024-02-27 23:12:00
0

       在torch中,如果要改变某一个tensor的维度,可以利用view、expand、repeat、transpose和permute等方法,这里对这些方法的一些容易混淆的地方做个总结。

​expand和repeat函数是pytorch中常用于进行张量数据复制和维度扩展的函数,但其工作机制差别很大,本文对这两个函数进行对比。

1 torch.expand()

  • 作用: expand()函数可以将张量广播到新的形状。
  • 注意: 只能对维度值为1的维度进行扩展,无需扩展的维度,维度值不变,对应位置可写上原始维度大小或直接写作-1;且扩展的Tensor不会分配新的内存,只是原来的基础上创建新的视图并返回,返回的张量内存是不连续的。类似于numpy中的broadcast_to函数的作用。如果希望张量内存连续,可以调用contiguous函数。

expand函数用于将张量中单数维的数据扩展到指定的size。

首先解释下什么叫单数维(singleton dimensions),张量在某个维度上的size为1,则称为单数维。比如zeros(2,3,4)不存在单数维,而zeros(2,1,4)在第二个维度(即维度1)上为单数维。expand函数仅仅能作用于这些单数维的维度上。
参数*sizes用于逐个指定各个维度扩展后的大小(也可以理解为拓展的次数),对于不需要或者无法(即非单数维)进行扩展的维度,对应位置可写上原始维度大小或直接写作-1。
expand函数可能导致原始张量的升维,其作用在张量前面的维度上(在tensor的低维增加更多维度),因此通过expand函数可将张量数据复制多份(可理解为沿着第一个batch的维度上)。
 

import torch
a = tensor([1, 0, 2])
b = a.expand(2, -1)   # 第一个维度为升维,第二个维度保持原阳
# b为   tensor([[1, 0, 2],  [1, 0, 2]])a = torch.tensor([[1], [0], [2]])
b = a.expand(-1, 2)   # 保持第一个维度,第二个维度只有一个元素,可扩展
# b为  tensor([[1, 1],
#              [0, 0],
#              [2, 2]])a = torch.Tensor([[1,2,3]])
'''
tensor([[1.,2.,3.]]
)
torch.Size([1, 3])
'''aa = a.expand(4, 3)  # 也可写为a.expand(4, -1)  对于某一个维度上的值为1的维度,# 可以在该维度上进行tensor的复制,若大于1则不行
'''
tensor([[1.,2.,3.],[1.,2.,3.],[1.,2.,3.],[1.,2.,3.]]
)
'''a = torch.Tensor([[1,2,3], [4, 5, 6]])
'''
tensor([[1.,2.,3.],[4.,5.,6.]]
)
torch.Size([2, 3])
'''
a = a.expand(4,6) # 最高几个维度的参数必须和原始shape保持一致,否则报错
'''
RuntimeError: The expanded size of the tensor (6) must match 
the existing size (3) at non-singleton dimension 1.
'''aa = a.expand(1,2,3) # 可以在tensor的低维增加更多维度
'''
tensor([[[1.,2.,3.],[4.,5.,6.]]]
)
'''
aaa = a.expand(2,2,3)  # 可以在tensor的低维增加更多维度,同时在新增加的低维度上进行tensor的复制
'''
tensor([[[1.,2.,3.],[4.,5.,6.]],[[1.,2.,3.],[4.,5.,6.]]]
)
'''aaa = a.expand(2,3,2) # 不可在更高维增加维度,否则报错
'''
RuntimeError: The expanded size of the tensor (2) must match the 
existing size (3) at non-singleton dimension 2.
'''aaaa = a.expand(2, -1, -1) # 最高几个维度的参数可以用-1,表示和原始维度一致
'''
tensor([[[1.,2.,3.],[4.,5.,6.]],[[1.,2.,3.],[4.,5.,6.]]]
)
'''# expand返回的张量与原版张量具有相同内存地址
print(aaaa.storage()) # 存储区的数据,说明expand后的a,aa,aaa,aaaa是共享storage的,
# 只是tensor的头信息区设置了不同的数据展示格式,从而使得a,aa,aaa,aaaa呈现不同的tensor形式
'''
1.0
2.0
3.0
4.0
5.0
6.0
'''

1.1 expand_as

 可视为expand的另一种表达,其size通过函数传递的目标张量的size来定义。

import torch
a = torch.tensor([1, 0, 2])
b = torch.zeros(2, 3)
c = a.expand_as(b)  # a照着b的维度大小进行拓展
# c为 tensor([[1, 0, 2],
#        [1, 0, 2]])

2 tensor.repeat()

  • 作用:和expand()作用类似,均是将tensor广播到新的形状。
  • 注意:不允许使用维度-1,1即为不变。

前文提及expand仅能作用于单数维,那对于非单数维的拓展,那就需要借助于repeat函数了。

tensor.repeat(*sizes)

参数*sizes指定了原始张量在各维度上复制的次数。整个原始张量作为一个整体进行复制,这与Numpy中的repeat函数截然不同,而更接近于tile函数的效果。

与expand不同,repeat函数会真正的复制数据并存放于内存中。repeat开辟了新的内存空间,torch.repeat返回的张量在内存中是连续的

import torch
a = torch.tensor([1, 0, 2])
b = a.repeat(3,2)  # 在轴0上复制3份,在轴1上复制2份
# b为 tensor([[1, 0, 2, 1, 0, 2],
#        [1, 0, 2, 1, 0, 2],
#        [1, 0, 2, 1, 0, 2]])import torch
a = torch.Tensor([[1,2,3]])
'''
tensor([[1.,2.,3.]]
)
'''aa = a.repeat(4, 3) # 维度不变,在各个维度上进行数据复制
'''
tensor([[1.,2.,3.,1.,2.,3.,1.,2.,3.],[1.,2.,3.,1.,2.,3.,1.,2.,3.],[1.,2.,3.,1.,2.,3.,1.,2.,3.],[1.,2.,3.,1.,2.,3.,1.,2.,3.]]
)
'''a = torch.Tensor([[1,2,3], [4, 5, 6]])
'''
tensor([[1.,2.,3.],[4.,5.,6.]]
)
'''
aa = a.repeat(4,6) # 维度不变,在各个维度上进行数据复制
'''
tensor([[1.,2.,3.,1.,2.,3.,1.,2.,3.,1.,2.,3.,1.,2.,3.,1.,2.,3.],[4.,5.,6.,4.,5.,6.,4.,5.,6.,4.,5.,6.,4.,5.,6.,4.,5.,6.],[1.,2.,3.,1.,2.,3.,1.,2.,3.,1.,2.,3.,1.,2.,3.,1.,2.,3.],[4.,5.,6.,4.,5.,6.,4.,5.,6.,4.,5.,6.,4.,5.,6.,4.,5.,6.],[1.,2.,3.,1.,2.,3.,1.,2.,3.,1.,2.,3.,1.,2.,3.,1.,2.,3.],[4.,5.,6.,4.,5.,6.,4.,5.,6.,4.,5.,6.,4.,5.,6.,4.,5.,6.],[1.,2.,3.,1.,2.,3.,1.,2.,3.,1.,2.,3.,1.,2.,3.,1.,2.,3.],[4.,5.,6.,4.,5.,6.,4.,5.,6.,4.,5.,6.,4.,5.,6.,4.,5.,6.]]
)
'''aaa = a.repeat(1,2,3) # 可以在tensor的低维增加更多维度,并在各维度上复制数据
'''
tensor([[[1.,2.,3.,1.,2.,3.,1.,2.,3.],[4.,5.,6.,4.,5.,6.,4.,5.,6.],[1.,2.,3.,1.,2.,3.,1.,2.,3.],[4.,5.,6.,4.,5.,6.,4.,5.,6.]]]
)
'''
aaaa = a.repeat(2,3,1) # 可以在tensor的高维增加更多维度,并在各维度上复制数据
'''
tensor([[[1.,2.,3.],[4.,5.,6.],[1.,2.,3.],[4.,5.,6.],[1.,2.,3.],[4.,5.,6.]],[[1.,2.,3.],[4.,5.,6.],[1.,2.,3.],[4.,5.,6.],[1.,2.,3.],[4.,5.,6.]]]
)
'''aaaaa = a.repeat(2, 3, -1) 
'''
RuntimeError: Trying to create tensor with negative dimension -3: [2,6,-3]
'''print(aaaa.storage()) # 存储区的数据,说明repeat后的a,aa,aaa,aaaa是有各自独立的storage的
'''
1.0
2.0
3.0
4.0
5.0
6.0
1.0
2.0
3.0
4.0
5.0
6.0
1.0
2.0
3.0
4.0
5.0
6.0
1.0
2.0
3.0
4.0
5.0
6.0
1.0
2.0
3.0
4.0
5.0
6.0
1.0
2.0
3.0
4.0
5.0
6.0
'''

2.1 repeat_intertile

Pytorch中,与Numpyrepeat函数相类似的函数为torch.repeat_interleave

torch.repeat_interleave(input, repeats, dim=None)

参数input为原始张量,repeats为指定轴上的复制次数,而dim为复制的操作轴,若取值为None则默认将所有元素进行复制,并会返回一个flatten之后一维张量。与repeat将整个原始张量作为整体不同,repeat_interleave操作是逐元素的。

a = torch.tensor([[1], [0], [2]])
b = torch.repeat_interleave(a, repeats=3)   # 结果flatten
# b为tensor([1, 1, 1, 0, 0, 0, 2, 2, 2])c = torch.repeat_interleave(a, repeats=3, dim=1)  # 沿着axis=1逐元素复制
# c为tensor([[1, 1, 1],
#        [0, 0, 0],
#        [2, 2, 2]])

总结
相同:
(1)都可以扩展维度,或在某个维度上进行tensor的复制

区别:
(1)参数意义不同,repeat的参数表示沿某维度的数据复制倍数,可为大于0的任何整数值;expand的参数表示tensor对应的维度上的值,且只有增加新的低维度时表示沿该低维度的数据复制倍数,其他参数必须和原始tensor保持一致
(2)返回的结果的存储区不同,repeat返回的tensor会重新拥有一个独立存储区,而expand返回的tensor则与原始tensor共享存储区

相关内容

热门资讯

4月份银行理财规模环比增加2.... 钛媒体App 5月16日消息,银行理财市场在4月份迎来规模与收益的双增长。据华源证券廖志明团队发布的...
【光明日报】黑龙江:免签红利释... 5月10日早上7时,一辆国际大巴缓缓停靠在黑龙江绥芬河公路口岸入境大厅前。游客们提着大包小裹,依次走...
又一跨国高端化工合作项目落子乐... 5月15日,福华化学携手瑞士特种化学品企业科莱恩打造的创新型高端磷系无卤阻燃剂项目(以下简称“福华科...
鸿蒙智行:已拥有1951家销售... IT之家 5 月 15 日消息,鸿蒙智行智界 V9 发布会正在进行,官方透露目前已拥有 1951 家...
黄金、白银,直线大跌! 5月15日晚间,贵金属价格突然大跌! 截至记者发稿时,现货黄金跌超2%,暂报4553美元/盎司附近。...
央视《焦点访谈》聚焦!万兴科技... 深圳商报·读创客户端首席记者 谢惠茜 5月14日,中央电视台《焦点访谈》推出专题节目《扩能提质强服务...
东方嘉富人寿董事长履职半年被换... 文|达摩财经 东方嘉富人寿再度进行人事调整。 5月13日,东方嘉富人寿发布公告称,自2026年4月...
重返西决!文班19+6卡斯尔3... 【搜狐体育战报】北京时间5月16日NBA季后赛,客场作战的马刺以139-109击败森林狼,总比分4-...
原创 美... 十万亿美债为什么还没有崩盘?或许答案在于,中国的存在让局势与众不同。现在的美债就像一张看似脆弱的网,...
原创 茅... 最近打开股票软件看白酒板块,是不是心里拔凉拔凉的? 茅台又回到1300元区间了,五粮液跌破90元,洋...
茅台宣布涨价 5月15日深夜,“i茅台”APP发布公告称,按照随行就市、供需适配、量价平衡、相对平稳的原则,贵州茅...
最高涨200元!茅台官宣4款产... 贵州茅台(600519.SH)凌晨宣布涨价几款产品。 茅台数字营销平台“i茅台”今日(5月16日)发...
面向地方国资产融转型全链条,X... 5月15日,XOD创新投融资模式3.0产品发布会在广州举办。该产品主要面向地方国资产融协同创新转型提...
2026Q1:10家上市商超9... 截至4月30日,所有A股上市公司2026年Q1财报全部出炉,传统商超也晒出自己的成绩单。10家披露的...
入主盟科药业失利后,拟借款2.... 来源:时代周报-时代在线 继去年试图通过定增入主盟科药业(688373.SH)失败后,海鲸药业再度出...
同比激增86%、规模突破760... 图片来源:界面图库 界面新闻记者 | 孙艺真 今年以来,证券行业融资补血热潮持续升温。前5个月...
促进青年消费,扶持青年创业,上... 5月14日,上海市政协团青界别、经济界跨界别活动在市政协全过程人民民主实践点举行。 今年初,团市委立...
苹果股价昨日创收盘新高,站上3... IT之家 5 月 16 日消息,苹果公司股价昨日(5 月 15 日)收于 300.23 美元,首次站...
杭州首批配售型保障房正式入市 杭州首批配售型保障房正式入市 价格约为周边商品房5折,18日开始报名 不能入市交易,可由政府指定机构...
“后巴菲特时代”,伯克希尔调仓... 当地时间5月15日,伯克希尔披露了2026年一季度美股持仓报告。这是伯克希尔在巴菲特卸任CEO并由阿...