Pytorch expand和repeat
admin
2024-02-27 23:12:00
0

       在torch中,如果要改变某一个tensor的维度,可以利用view、expand、repeat、transpose和permute等方法,这里对这些方法的一些容易混淆的地方做个总结。

​expand和repeat函数是pytorch中常用于进行张量数据复制和维度扩展的函数,但其工作机制差别很大,本文对这两个函数进行对比。

1 torch.expand()

  • 作用: expand()函数可以将张量广播到新的形状。
  • 注意: 只能对维度值为1的维度进行扩展,无需扩展的维度,维度值不变,对应位置可写上原始维度大小或直接写作-1;且扩展的Tensor不会分配新的内存,只是原来的基础上创建新的视图并返回,返回的张量内存是不连续的。类似于numpy中的broadcast_to函数的作用。如果希望张量内存连续,可以调用contiguous函数。

expand函数用于将张量中单数维的数据扩展到指定的size。

首先解释下什么叫单数维(singleton dimensions),张量在某个维度上的size为1,则称为单数维。比如zeros(2,3,4)不存在单数维,而zeros(2,1,4)在第二个维度(即维度1)上为单数维。expand函数仅仅能作用于这些单数维的维度上。
参数*sizes用于逐个指定各个维度扩展后的大小(也可以理解为拓展的次数),对于不需要或者无法(即非单数维)进行扩展的维度,对应位置可写上原始维度大小或直接写作-1。
expand函数可能导致原始张量的升维,其作用在张量前面的维度上(在tensor的低维增加更多维度),因此通过expand函数可将张量数据复制多份(可理解为沿着第一个batch的维度上)。
 

import torch
a = tensor([1, 0, 2])
b = a.expand(2, -1)   # 第一个维度为升维,第二个维度保持原阳
# b为   tensor([[1, 0, 2],  [1, 0, 2]])a = torch.tensor([[1], [0], [2]])
b = a.expand(-1, 2)   # 保持第一个维度,第二个维度只有一个元素,可扩展
# b为  tensor([[1, 1],
#              [0, 0],
#              [2, 2]])a = torch.Tensor([[1,2,3]])
'''
tensor([[1.,2.,3.]]
)
torch.Size([1, 3])
'''aa = a.expand(4, 3)  # 也可写为a.expand(4, -1)  对于某一个维度上的值为1的维度,# 可以在该维度上进行tensor的复制,若大于1则不行
'''
tensor([[1.,2.,3.],[1.,2.,3.],[1.,2.,3.],[1.,2.,3.]]
)
'''a = torch.Tensor([[1,2,3], [4, 5, 6]])
'''
tensor([[1.,2.,3.],[4.,5.,6.]]
)
torch.Size([2, 3])
'''
a = a.expand(4,6) # 最高几个维度的参数必须和原始shape保持一致,否则报错
'''
RuntimeError: The expanded size of the tensor (6) must match 
the existing size (3) at non-singleton dimension 1.
'''aa = a.expand(1,2,3) # 可以在tensor的低维增加更多维度
'''
tensor([[[1.,2.,3.],[4.,5.,6.]]]
)
'''
aaa = a.expand(2,2,3)  # 可以在tensor的低维增加更多维度,同时在新增加的低维度上进行tensor的复制
'''
tensor([[[1.,2.,3.],[4.,5.,6.]],[[1.,2.,3.],[4.,5.,6.]]]
)
'''aaa = a.expand(2,3,2) # 不可在更高维增加维度,否则报错
'''
RuntimeError: The expanded size of the tensor (2) must match the 
existing size (3) at non-singleton dimension 2.
'''aaaa = a.expand(2, -1, -1) # 最高几个维度的参数可以用-1,表示和原始维度一致
'''
tensor([[[1.,2.,3.],[4.,5.,6.]],[[1.,2.,3.],[4.,5.,6.]]]
)
'''# expand返回的张量与原版张量具有相同内存地址
print(aaaa.storage()) # 存储区的数据,说明expand后的a,aa,aaa,aaaa是共享storage的,
# 只是tensor的头信息区设置了不同的数据展示格式,从而使得a,aa,aaa,aaaa呈现不同的tensor形式
'''
1.0
2.0
3.0
4.0
5.0
6.0
'''

1.1 expand_as

 可视为expand的另一种表达,其size通过函数传递的目标张量的size来定义。

import torch
a = torch.tensor([1, 0, 2])
b = torch.zeros(2, 3)
c = a.expand_as(b)  # a照着b的维度大小进行拓展
# c为 tensor([[1, 0, 2],
#        [1, 0, 2]])

2 tensor.repeat()

  • 作用:和expand()作用类似,均是将tensor广播到新的形状。
  • 注意:不允许使用维度-1,1即为不变。

前文提及expand仅能作用于单数维,那对于非单数维的拓展,那就需要借助于repeat函数了。

tensor.repeat(*sizes)

参数*sizes指定了原始张量在各维度上复制的次数。整个原始张量作为一个整体进行复制,这与Numpy中的repeat函数截然不同,而更接近于tile函数的效果。

与expand不同,repeat函数会真正的复制数据并存放于内存中。repeat开辟了新的内存空间,torch.repeat返回的张量在内存中是连续的

import torch
a = torch.tensor([1, 0, 2])
b = a.repeat(3,2)  # 在轴0上复制3份,在轴1上复制2份
# b为 tensor([[1, 0, 2, 1, 0, 2],
#        [1, 0, 2, 1, 0, 2],
#        [1, 0, 2, 1, 0, 2]])import torch
a = torch.Tensor([[1,2,3]])
'''
tensor([[1.,2.,3.]]
)
'''aa = a.repeat(4, 3) # 维度不变,在各个维度上进行数据复制
'''
tensor([[1.,2.,3.,1.,2.,3.,1.,2.,3.],[1.,2.,3.,1.,2.,3.,1.,2.,3.],[1.,2.,3.,1.,2.,3.,1.,2.,3.],[1.,2.,3.,1.,2.,3.,1.,2.,3.]]
)
'''a = torch.Tensor([[1,2,3], [4, 5, 6]])
'''
tensor([[1.,2.,3.],[4.,5.,6.]]
)
'''
aa = a.repeat(4,6) # 维度不变,在各个维度上进行数据复制
'''
tensor([[1.,2.,3.,1.,2.,3.,1.,2.,3.,1.,2.,3.,1.,2.,3.,1.,2.,3.],[4.,5.,6.,4.,5.,6.,4.,5.,6.,4.,5.,6.,4.,5.,6.,4.,5.,6.],[1.,2.,3.,1.,2.,3.,1.,2.,3.,1.,2.,3.,1.,2.,3.,1.,2.,3.],[4.,5.,6.,4.,5.,6.,4.,5.,6.,4.,5.,6.,4.,5.,6.,4.,5.,6.],[1.,2.,3.,1.,2.,3.,1.,2.,3.,1.,2.,3.,1.,2.,3.,1.,2.,3.],[4.,5.,6.,4.,5.,6.,4.,5.,6.,4.,5.,6.,4.,5.,6.,4.,5.,6.],[1.,2.,3.,1.,2.,3.,1.,2.,3.,1.,2.,3.,1.,2.,3.,1.,2.,3.],[4.,5.,6.,4.,5.,6.,4.,5.,6.,4.,5.,6.,4.,5.,6.,4.,5.,6.]]
)
'''aaa = a.repeat(1,2,3) # 可以在tensor的低维增加更多维度,并在各维度上复制数据
'''
tensor([[[1.,2.,3.,1.,2.,3.,1.,2.,3.],[4.,5.,6.,4.,5.,6.,4.,5.,6.],[1.,2.,3.,1.,2.,3.,1.,2.,3.],[4.,5.,6.,4.,5.,6.,4.,5.,6.]]]
)
'''
aaaa = a.repeat(2,3,1) # 可以在tensor的高维增加更多维度,并在各维度上复制数据
'''
tensor([[[1.,2.,3.],[4.,5.,6.],[1.,2.,3.],[4.,5.,6.],[1.,2.,3.],[4.,5.,6.]],[[1.,2.,3.],[4.,5.,6.],[1.,2.,3.],[4.,5.,6.],[1.,2.,3.],[4.,5.,6.]]]
)
'''aaaaa = a.repeat(2, 3, -1) 
'''
RuntimeError: Trying to create tensor with negative dimension -3: [2,6,-3]
'''print(aaaa.storage()) # 存储区的数据,说明repeat后的a,aa,aaa,aaaa是有各自独立的storage的
'''
1.0
2.0
3.0
4.0
5.0
6.0
1.0
2.0
3.0
4.0
5.0
6.0
1.0
2.0
3.0
4.0
5.0
6.0
1.0
2.0
3.0
4.0
5.0
6.0
1.0
2.0
3.0
4.0
5.0
6.0
1.0
2.0
3.0
4.0
5.0
6.0
'''

2.1 repeat_intertile

Pytorch中,与Numpyrepeat函数相类似的函数为torch.repeat_interleave

torch.repeat_interleave(input, repeats, dim=None)

参数input为原始张量,repeats为指定轴上的复制次数,而dim为复制的操作轴,若取值为None则默认将所有元素进行复制,并会返回一个flatten之后一维张量。与repeat将整个原始张量作为整体不同,repeat_interleave操作是逐元素的。

a = torch.tensor([[1], [0], [2]])
b = torch.repeat_interleave(a, repeats=3)   # 结果flatten
# b为tensor([1, 1, 1, 0, 0, 0, 2, 2, 2])c = torch.repeat_interleave(a, repeats=3, dim=1)  # 沿着axis=1逐元素复制
# c为tensor([[1, 1, 1],
#        [0, 0, 0],
#        [2, 2, 2]])

总结
相同:
(1)都可以扩展维度,或在某个维度上进行tensor的复制

区别:
(1)参数意义不同,repeat的参数表示沿某维度的数据复制倍数,可为大于0的任何整数值;expand的参数表示tensor对应的维度上的值,且只有增加新的低维度时表示沿该低维度的数据复制倍数,其他参数必须和原始tensor保持一致
(2)返回的结果的存储区不同,repeat返回的tensor会重新拥有一个独立存储区,而expand返回的tensor则与原始tensor共享存储区

相关内容

热门资讯

别再骂孩子“笨”!学习困难不是... “同样的知识点,别人听一遍就会,你怎么教都不会”“作业磨磨蹭蹭,简单的题也总出错”“明明很努力,成绩...
通辽年货节:消费盛宴燃动冬日,... “这个牛肉干味道超棒,奶茶香得不得了!”“年货都快置办齐全啦!”“用卡消费能凑满减,太划算了!”1月...
黄金白银价格剧烈波动,豫光金铅... 【大河财立方消息】现货黄金和白银价格近期在刷新历史高点后出现大幅波动,2月2日再度跳水。截至发稿,现...
AI眼镜能把Meta从元宇宙的... 出品|虎嗅科技组 作者|赵致格 编辑|苗正卿 头图|视觉中国 “我们的AI眼镜正在重塑人们体验世界的...
马上评|上海定制消费,凭啥吸引... 近期,上海的定制消费引发热议:南外滩轻纺面料市场,众多中外消费者排队定制西装、旗袍、中式礼服,多种语...
黄金暴跌,市场总有轮回。 今天不聊别的,还是聊黄金。但今天为了说清楚黄金,我会先分析同样经历了暴跌的比特币、以太坊等crypt...
2025年净利最高预亏2.9亿... 北京商报讯(记者 丁宁)2月2日,双鹭药业(002038)盘中触及跌停,截至北京商报记者发稿,双鹭药...
盘点2025信托业(二)|新监... 中国网财经2月2日讯 2025年,信托行业监管制度体系经历根本性重塑,“1+N”政策框架落地生根,为...
中国国航2026年春运计划执行... 中国国际航空股份有限公司(下称“国航”)2月2日宣布,2026年春运将全面升级运力投入,在册飞机数量...
原创 今... 一家公司,在不到一个月的时间里,股价像坐上了火箭,一口气连拉17个涨停板,价格翻了4倍多。 就在所有...
脑出血后昏迷不醒有什么治疗方法... 脑出血是一种发病急、进展快、致死致残率极高的急性脑血管疾病,其发病率约占全部脑卒中的20%~30%,...
原创 今... 今天A股跌得让人心慌,商业航天、黄金、农业,热点一个接一个熄火。 但在一片惨绿中,有一个板块却像打了...
2026年存款搬家,有望为A股... 根据财通证券测算,2026年企业与居民中长期存款的到期规模在57.3万亿元,而居民中长期存款到期规模...
两部门发文明确增值税进项税额抵... 财政部 税务总局 关于增值税进项税额抵扣等有关事项的公告 财政部 税务总局公告2026年第13号 根...
兴业银行多位经理被禁止从事银行... 据国家金融监督管理总局大连监管局2月2日公开的行政处罚信息显示,兴业银行股份有限公司大连分行因信贷业...
原创 金... 深圳水贝市场里,一位女士快速计算着价格,掏出银行卡支付了近12万元,买下100克金条。 黄金价格正在...
韩国股市大跌触发熔断机制 1日,韩国首尔,中区一家银行交易室的电子屏显示韩国综合股价指数(KOSPI)。视觉中国/图 韩国股市...
尿色加深、皮肤泛黄时,你的肝可... 生活中,不少人发现尿色变深、皮肤泛黄时,会误以为是喝水少、上火或肤色问题,简单调整后便不再关注。可这...
30亿元红包!千问宣布 2月2日,千问APP宣布投入30亿元启动“春节请客计划”,以免单形式请全国人民在春节期间吃喝玩乐,感...
周生生足金挂坠被检测出含铁银钯... 编者按:维护消费者权益,守护消费安全。央广网啄木鸟消费者投诉平台,保障消费者合法权益,为新消费时代保...