pytorch实战(二)——搭建卷积神经网络(CNN)
admin
2024-02-15 14:02:31
0

笔者以前都是用tensorflow做深度学习,tensorflow系列教程见tensorflow实战——一位安分的码农
后来做目标检测用yolo的时候,发现pytorch真香。yolo系列教程见yolov5实战——一位安分的码农
终于抽出时间系统学习pytorch了,开干!
此文讲的是基于pytorch,利用class和sequential搭建卷积神经网络。
关于如何配置pytorch环境,我很早就做过了,见pytorch实战(一)——环境配置教程(基于Anaconda)

一、利用Class和Sequential搭建CNN

此种方法更易懂,推荐采用Class和Sequential进行搭建

import torch
import torch.nn as nn
import torch.nn.functional as Fdevice = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')class ConvNet(nn.Module):def __init__(self):super(ConvNet, self).__init__()self.layer1 = nn.Sequential(nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=0),nn.ReLU(),nn.BatchNorm2d(num_features=6),nn.MaxPool2d(kernel_size=2),)self.layer2 = nn.Sequential(nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1, padding=0),nn.ReLU(),nn.BatchNorm2d(num_features=16),nn.MaxPool2d(kernel_size=2))self.layer3 = nn.Sequential(# nn.Flatten(),nn.Linear(in_features=16 * 5 * 5, out_features=120),nn.ReLU(),nn.Linear(in_features=120, out_features=84),nn.ReLU(),nn.Linear(in_features=84, out_features=10))def forward(self, x):x = self.layer1(x)x = self.layer2(x)x = x.reshape(x.size(0), -1)x = self.layer3(x)return x# N是批大小; D_in 是输入维度;
# H 是通道数; D_out 是输出维度
N, D_in, H, D_out = 1, 32, 6, 10# 产生输入和输出的随机张量
x = torch.randn(1, 1, 32, 32)
y = torch.randn(1, 10)
x = x.to(device)
y = y.to(device)# 通过实例化上面定义的类来构建我们的模型。
model = ConvNet().to(device)# 构造损失函数和优化器。
# SGD构造函数中对model.parameters()的调用,
# 将包含模型的一部分,即两个nn.Linear模块的可学习参数。
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
for t in range(500):# 前向传播:通过向模型传递x计算预测值yy_pred = model(x)#计算并输出lossloss = criterion(y_pred, y)print(t, loss.item())# 清零梯度,反向传播,更新权重optimizer.zero_grad()loss.backward()optimizer.step()

二、利用Clas搭建CNN

此种方法阅读网络结构的时候,没有第一方法简便,因此不推荐采用此种方法

import torch
import torch.nn as nn
import torch.nn.functional as Fdevice = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')class ConvNet(nn.Module):def __init__(self):super(ConvNet, self).__init__()self.conv1 = nn.Conv2d(1, 6, 5)self.batchNorm1 = nn.BatchNorm2d(6)self.conv2 = nn.Conv2d(6, 16, 5)self.batchNorm2 = nn.BatchNorm2d(16)# an affine operation: y = Wx + bself.fc1 = nn.Linear(16 * 5 * 5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):# Max pooling over a (2, 2) windowx = F.relu(self.conv1(x))x = self.batchNorm1(x)x = F.max_pool2d(x, 2)x = F.relu(self.conv2(x))x = self.batchNorm2(x)x = F.max_pool2d(x, 2)x = x.reshape(x.size(0), -1)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return x# N是批大小; D_in 是输入维度;
# H 是通道数; D_out 是输出维度
N, D_in, H, D_out = 1, 32, 6, 10# 产生输入和输出的随机张量
x = torch.randn(1, 1, 32, 32)
y = torch.randn(1, 10)
x = x.to(device)
y = y.to(device)# 通过实例化上面定义的类来构建我们的模型。
model = ConvNet().to(device)# 构造损失函数和优化器。
# SGD构造函数中对model.parameters()的调用,
# 将包含模型的一部分,即两个nn.Linear模块的可学习参数。
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
for t in range(500):# 前向传播:通过向模型传递x计算预测值yy_pred = model(x)#计算并输出lossloss = criterion(y_pred, y)print(t, loss.item())# 清零梯度,反向传播,更新权重optimizer.zero_grad()loss.backward()optimizer.step()

备注:
pyorch官方教程的中文文档地址为https://pytorch.panchuang.net/ThirdSection/LearningPyTorch/

相关内容

热门资讯

全球首套,中天科技交付220k... IT之家 5 月 11 日消息,据中天科技集团消息,近日,中天科技交付全球首套 220kV 3500...
原创 北... 北京业主刚以488万的价格卖掉了自己的二手房,三天后宁愿付违约金,也要把房子拿回来。转手加价70多万...
“硬科技”场内基金频发溢价风险... 【导读】硬科技场内基金频发溢价风险提示 中国基金报记者天心 日前,多只聚焦海内外半导体芯片方向的场内...
伯希和再闯港股陷更名争议,CE... 5月8日,国内户外运动品牌伯希和(PELLIOT)再度向港交所递交上市申请,中金公司与中信证券担任联...
一季度货币政策报告明确:引导隔... 5月11日,人民银行披露一季度中国货币政策执行报告,指出下一步将引导隔夜利率在政策利率水平附近运行,...
科博会观察|能源转型的“下半场... 今年4月,光伏龙头隆基绿能发布“全栈隆基LONGi ONE”光储融合战略,这场发布会背后是公司对能源...
“茅台魔咒”失灵了?沪指站上4... 11日,沪指走出“八连阳”,站上4200点,创下自2015年6月26日以来的收盘点位新高。 板块方...
沪指涨0.94%站上4200点... 扬子晚报网5月11日讯(记者 范晓林)截至午盘,沪指站上4200点,创业板指大涨突破3900点,为2...
ETF周评:4200点之前,“... “五一”假期后的首个交易周(5月6日至5月8日),A股虽仅有短短三个交易日,却展现出强劲的做多动能。...
集智达GNS-2446主板赋能... 当前医疗自助终端面临四大行业痛点:多任务并发算力瓶颈;外设兼容集成难题;数据安全合规压力;复杂环境稳...
动荡市场中锚定稀缺确定性,新能... 3月以来,美伊冲突导致全球能源价格出现大幅波动。整体上看,本轮地缘冲突的复杂性和影响深度远超以往,加...
CFA协会:未来金融人才需具备... 由特许金融分析师协会(CFA协会)、北京市金融发展促进中心共同主办的2026第五届中国未来金融分析师...
最熟悉的“国民理财神器”,让你... 1万元放进余额宝,一天收益只有0.24元,连个鸡蛋都买不起。这不是某个冷门产品,而是那个曾创下6.7...
Circle从贝莱德等机构融资... 来源:环球市场播报 核心要点 Circle 互联网集团在其全新 Arc 区块链关联代币预售中融资...
张尧浠:美伊局势变数不断 金价... 来源:市场资讯 5月11日:黄金市场上周:国际黄金伦敦金触底回升收涨,再度收取垂线止跌看涨形态,但上...
中澳企业拓展新能源合作 来源:人民日报 2026年澳大利亚智慧能源展日前在悉尼国际会议中心举行。当前,中东局势引发全球能源...
七类技能培训“套路”曝光 中消... 记者今天(11日)从中消协获悉,近年来,各类技能培训迅速扩张,新型培训模式不断涌现,部分经营者借助新...
每日收评沪指涨超1%站上420... 财联社5月11日讯,市场全天震荡走强,沪指站上4200点,创业板指大涨突破3900点,为2015年6...
A股三大指数集体上涨:沪指站上... 观点网讯:5月11日,A股三大指数集体上涨,截至收盘,上证指数涨1.08%站上4200点,深证成指涨...