【三维几何学习】从零开始网格上的深度学习-3:Transformer篇(Pytorch)
创始人
2025-05-31 23:16:56
0

本文参加新星计划人工智能(Pytorch)赛道:https://bbs.csdn.net/topics/613989052

从零开始网格上的深度学习-3:Transformer篇

  • 引言
  • 一、概述
  • 二、核心代码
    • 2.1 位置编码
    • 2.2 网络框架
  • 三、基于Transformer的网格分类
    • 3.1 分类结果
    • 3.2 全部代码

引言

本文主要内容如下:

  • 简述网格上的位置编码
  • 参考点云上的Transformer-1:PCT:Point cloud transformer,构造网格分类网络

一、概述

在这里插入图片描述

个人认为对于三角形网格来说,想要将Transformer应用到其上较为重要的一步是位置编码。三角网格在3D空间中如何编码每一个元素的位置,能尽可能保证的泛化性能? 以xyz坐标为例,最好是模型经过对齐的预处理,使朝向一致。或者保证网格水密的情况下使用谱域特征,如热核特征。或者探索其他位置编码等等… 上图为一个外星人x坐标的位置编码可视化

  • 使用简化网格每一个面直接作为一个Token即可,高分辨率的网格(考虑输入特征计算、训练数据对齐等)并不适合深度学习(个人认为)
  • 直接应用现有的Tranformer网络框架、自注意力模块等,细节或参数需要微调

二、核心代码

2.1 位置编码

使用每一个网格面的中心坐标作为位置编码,计算代码在DataLoader中

  • 需要平移到坐标轴原点,并进行尺度归一化
# xyz
xyz_min = np.min(vs[:, 0:3], axis=0)
xyz_max = np.max(vs[:, 0:3], axis=0)
xyz_move = xyz_min + (xyz_max - xyz_min) / 2
vs[:, 0:3] = vs[:, 0:3] - xyz_move
# scale
scale = np.max(vs[:, 0:3])
vs[:, 0:3] = vs[:, 0:3] / scale
# 面中心坐标
xyz = []
for i in range(3):xyz.append(vs[faces[:, i]])
xyz = np.array(xyz)  # 转为np
mean_xyz = xyz.sum(axis=0) / 3

2.2 网络框架

在这里插入图片描述

  • 参考上图PCT框架,修改了部分细节,如减少了Attention模块数量等

在这里插入图片描述

  • 参考上图自注意力模块,个人感觉图中应该有误. 从一个共享权重的Linear里出来了Q、K、VQ、K、VQ、K、V三个矩阵,但VVV的维度和Q、KQ、KQ、K不一致,少画了一个Linear?
class SA(nn.Module):def __init__(self, channels):super().__init__()self.q_conv = nn.Conv1d(channels, channels // 4, 1, bias=False)self.k_conv = nn.Conv1d(channels, channels // 4, 1, bias=False)self.q_conv.weight = self.k_conv.weightself.v_conv = nn.Conv1d(channels, channels, 1, bias=False)self.trans_conv = nn.Conv1d(channels, channels, 1)self.after_norm = nn.BatchNorm1d(channels)self.act = nn.GELU()self.softmax = nn.Softmax(dim=-1)def forward(self, x):x_q = self.q_conv(x).permute(0, 2, 1)x_k = self.k_conv(x)x_v = self.v_conv(x)energy = x_q @ x_kattention = self.softmax(energy)attention = attention / (1e-9 + attention.sum(dim=1, keepdims=True))x_r = x_v @ attentionx_r = self.act(self.after_norm(self.trans_conv(x - x_r)))x = x + x_rreturn xclass TriTransNet(nn.Module):def __init__(self, dim_in, classes_n=30):super().__init__()self.conv_fea = FaceConv(6, 128, 4)self.conv_pos = FaceConv(3, 128, 4)self.bn_fea = nn.BatchNorm1d(128)self.bn_pos = nn.BatchNorm1d(128)self.sa1 = SA(128)self.sa2 = SA(128)self.gp = nn.AdaptiveAvgPool1d(1)self.linear1 = nn.Linear(256, 128, bias=False)self.bn1 = nn.BatchNorm1d(128)self.linear2 = nn.Linear(128, classes_n)self.act = nn.GELU()def forward(self, x, mesh):x = x.permute(0, 2, 1).contiguous()# 位置编码 放到DataLoader中比较好pos = [m.xyz for m in mesh]pos = np.array(pos)pos = torch.from_numpy(pos).float().to(x.device).requires_grad_(True)batch_size, _, N = x.size()x = self.act(self.bn_fea(self.conv_fea(x, mesh).squeeze(-1)))pos = self.act(self.bn_pos(self.conv_pos(pos, mesh).squeeze(-1)))x1 = self.sa1(x + pos)x2 = self.sa2(x1 + pos)x = torch.cat((x1, x2), dim=1)x = self.gp(x)x = x.view(batch_size, -1)x = self.act(self.bn1(self.linear1(x)))x = self.linear2(x)return x

三、基于Transformer的网格分类

数据集是SHREC’11 可参考三角网格(Triangular Mesh)分类数据集 或 MeshCNN

3.1 分类结果

在这里插入图片描述在这里插入图片描述
准确率太低… 可以尝试改进的点:

  • 尝试不同的位置编码(谱域特征),不同的位置嵌入方式 (sum可改为concat)
  • 数据集较小的情况下Transformer略难收敛,加入更多CNN可加速且提升明显 (或者加入降采样)
  • 打印loss进行分析,是否欠拟合,尝试增加网络参数?

基于Transformer的网络在网格分割上的表现会很好,仅用少量参数即可媲美甚至超过基于面卷积的分割结果,个人感觉得益于其近乎全局的感受野…

3.2 全部代码

DataLoader代码请参考2:从零开始网格上的深度学习-1:输入篇(Pytorch)
FaceConv代码请参考3:从零开始网格上的深度学习-2:卷积网络CNN篇

import torch
import torch.nn as nn
import numpy as np
from CNN import FaceConv
from DataLoader_shrec11 import DataLoader
from DataLoader_shrec11 import Meshclass SA(nn.Module):def __init__(self, channels):super().__init__()self.q_conv = nn.Conv1d(channels, channels // 4, 1, bias=False)self.k_conv = nn.Conv1d(channels, channels // 4, 1, bias=False)self.q_conv.weight = self.k_conv.weightself.v_conv = nn.Conv1d(channels, channels, 1, bias=False)self.trans_conv = nn.Conv1d(channels, channels, 1)self.after_norm = nn.BatchNorm1d(channels)self.act = nn.GELU()self.softmax = nn.Softmax(dim=-1)def forward(self, x):x_q = self.q_conv(x).permute(0, 2, 1)x_k = self.k_conv(x)x_v = self.v_conv(x)energy = x_q @ x_kattention = self.softmax(energy)attention = attention / (1e-9 + attention.sum(dim=1, keepdims=True))x_r = x_v @ attentionx_r = self.act(self.after_norm(self.trans_conv(x - x_r)))x = x + x_rreturn xclass TriTransNet(nn.Module):def __init__(self, dim_in, classes_n=30):super().__init__()self.conv_fea = FaceConv(6, 128, 4)self.conv_pos = FaceConv(3, 128, 4)self.bn_fea = nn.BatchNorm1d(128)self.bn_pos = nn.BatchNorm1d(128)self.sa1 = SA(128)self.sa2 = SA(128)self.gp = nn.AdaptiveAvgPool1d(1)self.linear1 = nn.Linear(256, 128, bias=False)self.bn1 = nn.BatchNorm1d(128)self.linear2 = nn.Linear(128, classes_n)self.act = nn.GELU()def forward(self, x, mesh):x = x.permute(0, 2, 1).contiguous()# 位置编码 放到DataLoader中比较好pos = [m.xyz for m in mesh]pos = np.array(pos)pos = torch.from_numpy(pos).float().to(x.device).requires_grad_(True)batch_size, _, N = x.size()x = self.act(self.bn_fea(self.conv_fea(x, mesh).squeeze(-1)))pos = self.act(self.bn_pos(self.conv_pos(pos, mesh).squeeze(-1)))x1 = self.sa1(x + pos)x2 = self.sa2(x1 + pos)x = torch.cat((x1, x2), dim=1)x = self.gp(x)x = x.view(batch_size, -1)x = self.act(self.bn1(self.linear1(x)))x = self.linear2(x)return xif __name__ == '__main__':# 输入data_train = DataLoader(phase='train')         # 训练集data_test = DataLoader(phase='test')           # 测试集print('#train meshes = %d' % len(data_train)) # 输出训练模型个数print('#test  meshes = %d' % len(data_test))  # 输出测试模型个数# 网络net = TriTransNet(data_train.input_n, data_train.class_n)    # 创建网络 以及 优化器optimizer = torch.optim.Adam(net.parameters(), lr=0.001, betas=(0.9, 0.999))net = net.cuda(0)loss_fun = torch.nn.CrossEntropyLoss(ignore_index=-1)num_params = 0for param in net.parameters():num_params += param.numel()print('[Net] Total number of parameters : %.3f M' % (num_params / 1e6))print('-----------------------------------------------')# 迭代训练for epoch in range(1, 201):print('---------------- Epoch: %d -------------' % epoch)for i, data in enumerate(data_train):# 前向传播net.train(True)        # 训练模式optimizer.zero_grad()  # 梯度清零face_features = torch.from_numpy(data['face_features']).float()face_features = face_features.to(data_train.device).requires_grad_(True)labels = torch.from_numpy(data['label']).long().to(data_train.device)out = net(face_features, data['mesh'])     # 输入到网络# 反向传播loss = loss_fun(out, labels)loss.backward()optimizer.step()              # 参数更新# 测试net.eval()acc = 0for i, data in enumerate(data_test):with torch.no_grad():# 前向传播face_features = torch.from_numpy(data['face_features']).float()face_features = face_features.to(data_test.device).requires_grad_(False)labels = torch.from_numpy(data['label']).long().to(data_test.device)out = net(face_features, data['mesh'])# 计算准确率pred_class = out.data.max(1)[1]correct = pred_class.eq(labels).sum().float()acc += correctacc = acc / len(data_test)print('epoch: %d, TEST ACC: %0.2f' % (epoch, acc * 100))

  1. PCT:Point cloud transformer ↩︎

  2. 从零开始网格上的深度学习-1:输入篇(Pytorch) ↩︎

  3. 从零开始网格上的深度学习-2:卷积网络CNN篇 ↩︎

相关内容

热门资讯

王凤英入职小鹏3年终获股权,此... 5月7日消息,小鹏汽车披露的监管及年报信息显示,公司总裁王凤英已正式进入股东名册,入职小鹏3年后股权...
五块钱红酒卖断货,便宜红酒为何... 最近一段时间,中国的酒类消费市场可以说是显得格外奇怪,一方面,各种高端酒特别是白酒的消费量出现了明显...
财联社C50风向指数调查:4月... 财联社5月8日讯(记者 夏淑媛)新一期财联社“C50风向指数”结果显示,市场机构对4月新增人民币贷款...
央视硬刚国际足联拒掏20亿,背... 作者| 史大郎&猫哥 来源| 是史大郎&大猫财经Pro 央视这次太刚了,离世界杯开幕还有1个月,死活...
新CEO上任直接放大招!Air... 快科技5月8日消息,苹果即将上任的CEO John Ternus对未来一系列新产品充满信心,称这些设...
“特朗普拟邀英伟达、波音等CE... 据路透社当地时间5月7日报道,特朗普政府正邀请英伟达、苹果、埃克森美孚、波音等大公司首席执行官,于下...
世界杯,还能看到直播吗? 2026年美加墨世界杯距离开幕,仅剩一个多月时间。多方信息显示,中央广播电视总台(以下简称“央视”)...
机构警告AI芯片热潮风险,超威... 5月7日,据央视财经,隔夜超威半导体公司(AMD)股价飙升近19%,带动AI芯片热潮持续升温。AMD...
银行员工转走储户1800万最新... 银行员工转走储户1800万最新进展:2名储户已收到银行全部款项
原创 中... 1994年,安徽省的经济格局曾发生过一次戏剧性的转折。在那一年,一座名为安庆的城市,其国内生产总值(...
昆都仑区:政策“蓄力”消费焕新 “一台5000多元的空调,叠加‘国补’和商场的以旧换新活动,能优惠1000元左右,旧机还能免费上门拆...
乐悦置业竞得佛山顺德乐从镇一商... 观点网讯:5月6日,佛山市顺德区乐从镇一商业地块成功出让,由广东省乐悦置业有限公司竞得,乐从南区·邻...
原创 亦... 《爱情没有神话》这部剧,一开始的命运颇为多舛,经历了几次撤档的波折后,终于在观众面前亮相,但其首播的...
美联储34年最大分歧叠加油价飙... 美联储按预期维持利率不变,但内部出现34年来最严重分歧,叠加布油创2022年6月以来新高,美债遭抛售...
支付宝消费券回收后,资金是否支... 摘要: 支付宝消费券回收变现后,资金能否直接转入信用卡?本文解答到账方式的相关规则,帮助用户了解资金...
中医介绍5个化痰穴位!收藏这篇... 很多人忽略了“痰”的危害,觉得咳几下就没事,殊不知,肺里的痰长期堆积,只会一步步加重身体负担。 中医...
黄金平台“杰我睿”涉嫌经济犯罪... 红星资本局5月7日消息,深圳水贝知名金店“杰我睿”兑付困难事件有了新进展。日前,深圳市公安局罗湖分局...
多地出台购房新政促楼市升温 记... 今年的“五一”假期,伴随着多个城市楼市新政密集落地,在叠加市场信心持续修复的作用下,房地产市场热度持...
谁是五一“吸金王”?这5座城市... 来源:市场资讯 (来源:21城市观) 哪座城市成为“五一”假期的大赢家? 图源:摄图网 作者|赵晓...
“低招低裁”格局稳固劳动力市场... 智通财经APP获悉,美国上周初请失业金人数在经历前一周回落至近几十年来最低水平后出现小幅反弹,表明尽...