pytorch实战(二)——搭建卷积神经网络(CNN)
admin
2024-02-15 14:02:31
0

笔者以前都是用tensorflow做深度学习,tensorflow系列教程见tensorflow实战——一位安分的码农
后来做目标检测用yolo的时候,发现pytorch真香。yolo系列教程见yolov5实战——一位安分的码农
终于抽出时间系统学习pytorch了,开干!
此文讲的是基于pytorch,利用class和sequential搭建卷积神经网络。
关于如何配置pytorch环境,我很早就做过了,见pytorch实战(一)——环境配置教程(基于Anaconda)

一、利用Class和Sequential搭建CNN

此种方法更易懂,推荐采用Class和Sequential进行搭建

import torch
import torch.nn as nn
import torch.nn.functional as Fdevice = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')class ConvNet(nn.Module):def __init__(self):super(ConvNet, self).__init__()self.layer1 = nn.Sequential(nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=0),nn.ReLU(),nn.BatchNorm2d(num_features=6),nn.MaxPool2d(kernel_size=2),)self.layer2 = nn.Sequential(nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1, padding=0),nn.ReLU(),nn.BatchNorm2d(num_features=16),nn.MaxPool2d(kernel_size=2))self.layer3 = nn.Sequential(# nn.Flatten(),nn.Linear(in_features=16 * 5 * 5, out_features=120),nn.ReLU(),nn.Linear(in_features=120, out_features=84),nn.ReLU(),nn.Linear(in_features=84, out_features=10))def forward(self, x):x = self.layer1(x)x = self.layer2(x)x = x.reshape(x.size(0), -1)x = self.layer3(x)return x# N是批大小; D_in 是输入维度;
# H 是通道数; D_out 是输出维度
N, D_in, H, D_out = 1, 32, 6, 10# 产生输入和输出的随机张量
x = torch.randn(1, 1, 32, 32)
y = torch.randn(1, 10)
x = x.to(device)
y = y.to(device)# 通过实例化上面定义的类来构建我们的模型。
model = ConvNet().to(device)# 构造损失函数和优化器。
# SGD构造函数中对model.parameters()的调用,
# 将包含模型的一部分,即两个nn.Linear模块的可学习参数。
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
for t in range(500):# 前向传播:通过向模型传递x计算预测值yy_pred = model(x)#计算并输出lossloss = criterion(y_pred, y)print(t, loss.item())# 清零梯度,反向传播,更新权重optimizer.zero_grad()loss.backward()optimizer.step()

二、利用Clas搭建CNN

此种方法阅读网络结构的时候,没有第一方法简便,因此不推荐采用此种方法

import torch
import torch.nn as nn
import torch.nn.functional as Fdevice = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')class ConvNet(nn.Module):def __init__(self):super(ConvNet, self).__init__()self.conv1 = nn.Conv2d(1, 6, 5)self.batchNorm1 = nn.BatchNorm2d(6)self.conv2 = nn.Conv2d(6, 16, 5)self.batchNorm2 = nn.BatchNorm2d(16)# an affine operation: y = Wx + bself.fc1 = nn.Linear(16 * 5 * 5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):# Max pooling over a (2, 2) windowx = F.relu(self.conv1(x))x = self.batchNorm1(x)x = F.max_pool2d(x, 2)x = F.relu(self.conv2(x))x = self.batchNorm2(x)x = F.max_pool2d(x, 2)x = x.reshape(x.size(0), -1)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return x# N是批大小; D_in 是输入维度;
# H 是通道数; D_out 是输出维度
N, D_in, H, D_out = 1, 32, 6, 10# 产生输入和输出的随机张量
x = torch.randn(1, 1, 32, 32)
y = torch.randn(1, 10)
x = x.to(device)
y = y.to(device)# 通过实例化上面定义的类来构建我们的模型。
model = ConvNet().to(device)# 构造损失函数和优化器。
# SGD构造函数中对model.parameters()的调用,
# 将包含模型的一部分,即两个nn.Linear模块的可学习参数。
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
for t in range(500):# 前向传播:通过向模型传递x计算预测值yy_pred = model(x)#计算并输出lossloss = criterion(y_pred, y)print(t, loss.item())# 清零梯度,反向传播,更新权重optimizer.zero_grad()loss.backward()optimizer.step()

备注:
pyorch官方教程的中文文档地址为https://pytorch.panchuang.net/ThirdSection/LearningPyTorch/

相关内容

热门资讯

花旗在等待日本买家重回债市 奢... 来源:环球市场播报 日元自2024年8月份以来最强劲的三日涨势,仍不足以让花旗集团策略师Daniel...
和讯投顾徐剑波:震荡轮动 现在市场进入了真空期,又到了量化完全主导的行情了,我们盯好大盘这三个细节就够了。和讯投顾徐剑波分析,...
原创 短... 谁能想到,今天在短剧圈崭露头角、每一部作品都掀起热潮的四位演员,曾经在长剧领域几乎无名,甚至因为戏份...
我国竹产业年产值超5200亿元 新华社北京1月27日电(记者黄垚)记者27日从国家林草局获悉,近年来我国竹产业规模持续壮大,初步形成...
钧达股份H股盘中跌超15% 上证报中国证券网讯(记者 何治民)1月27日,钧达股份H股持续下挫,一度跌超15%。截至10时36分...
原创 空... 倾尽六个钱包,压上毕生积蓄,换来的“梦中情房”,那个被吹嘘得如同“会呼吸的空中花园”的居所,最终却让...
易方达黄金主题LOF:暂停申购... 每经编辑|张锦河 1月27日,易方达黄金主题LOF公告,1月28日起暂停A类人民币份额申购及定期定...
呼和浩特外贸成绩单里的“马力” ●王英 2025年,呼和浩特市外贸进出口总值264.3亿元,同比增长16.53%,其中,呼和浩特综合...
黄金下跌,白银深度回调!事关降... 1月27日晚间,黄金出现下跌,白银深度回调。 截至发稿,纽约期金报5094.1美元/盎司。 纽约期银...
寒假健康不“放假”丨爆笑情景剧... 1月17日,由长春市卫生健康委、长春市中医药管理局主办的“乐享寒假 健康相伴”健康科普宣传体验活动在...
原创 A... 来源:互联网江湖 作者:刘致呈 腾讯做AI社交的消息,爆了。 AI、社交这几乎是当今科技行业最有含金...
荷兰下议院批准为银行奖金上限制... 荷兰下议院批准为银行奖金上限制度松绑。
港股异动 | 钧达股份H股盘中... 1月27日,钧达股份H股持续下挫,一度跌超15%。截至10时36分,钧达股份H股跌13.83%,报3...
美光宣布NAND新厂建设,总投... 周二,美光科技宣布将在未来十年向新加坡追加投资240亿美元,用于建设新的NAND闪存晶圆厂,以应对人...
康宁美股盘前飙升超7%!报道:... 科技巨头meta已与老牌玻璃制造商康宁达成一项价值高达60亿美元的长期供货协议,以获取其数据中心所需...
广东江门50场重点促消费活动助... 中新社江门1月27日电 (记者 郭军)记者27日从江门市商务局了解到,江门紧扣“广货行天下”主题,将...
创始人丁文军“离场”,腾讯、红... 1月26日,南都湾财社记者从重庆市市场监管局公示的《经营者集中简易案件公示表》中获悉,川香四溢(上海...
你手里有“睡眠卡”吗?银行在清... 银行业加强对长期不动户的管理并非等同于销户,且卡里的钱并不会被“清零”。
万科“22万科MTN005”宽... 1月27日,万科A(000002.SZ)公告,根据关于万科企业(02202.HK)2022年度第五期...
坦洲创投基金签约 市镇合作招大... 1月27日,中山坦洲创业投资基金合伙企业(有限合伙)项目签约仪式成功举行。该基金由坦洲镇属企业中山市...