深度学习-4 Linear regression for Pytorch
创始人
2025-05-28 08:56:58
0

线性回归的简洁复现

        • 1-torch.utils.data 数据处理
        • 2-torch.nn.Module 模型的定义
        • 3-nn.Sequential 容器搭建网络
        • 4-net.parameters() 查看模型学习的参数
        • 6-线性回归代码实现

1-torch.utils.data 数据处理

用于读取数据,高效的处理数据方面
在这里插入图片描述

from torch.utils.data import DataLoader
dataset = my_dataset()
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)
for i, batch_data in enumerate(dataloader):# 处理batch_data

2-torch.nn.Module 模型的定义

写一个自己的网络

from torch import nnclass LinearNet(nn.Module):def __init__(self,n_feature):super(LinearNet,self).__init__()self.linear = nn.Linear(n_feature,1)#定义前向传播def forward(self,x):y = self.linear(x)return y
net = LinearNet(num_inputs)
print(net)
#输出网络结构
==> LinearNet((linear): Linear(in_features=2, out_features=1, bias=True))
net.linear#Linear(in_features=2, out_features=1, bias=True)
net.linear.weight
#net[0]这样根据下标访问子模块的写法只有当net是个ModuleList或者Sequential实例时才可以
#这里不允许

3-nn.Sequential 容器搭建网络

利用序列搭建网络,Sequential是一个有序的容器,网络层将按照在传入Sequential的顺序依次被添加到计算图中

# 写法一
net = nn.Sequential(nn.Linear(num_inputs, 1)# 此处还可以传入其他层)
# 写法二
net = nn.Sequential()
net.add_module('linear', nn.Linear(num_inputs, 1))
# net.add_module ......
# 写法三
from collections import OrderedDict
net = nn.Sequential(OrderedDict([('linear', nn.Linear(num_inputs, 1))# ......]))

4-net.parameters() 查看模型学习的参数

net.parameters() 来查看模型所有的可学习参数,此函数将返回一个生成器。

for param in net.parameters():print(param)

torch.nn仅支持输入一个batch的样本不支持单个样本输入,如果只有单个样本,可使用input.unsqueeze(0)来添加一维。

参数初始化
init.normal_将权重参数每个元素初始化为随机采样于均值为0、标准差为0.01的正态分布。
偏差会初始化为零。

from torch.nn import init
init.normal_(net[0].weight, mean=0, std=0.01)
init.constant_(net[0].bias, val=0)  
# 也可以直接修改bias的data: net[0].bias.data.fill_(0)

6-线性回归代码实现

import torch
import numpy as np
import torch.utils.data as Data
from torch.nn import init#创建数据集
true_w = [2, -3.4]
true_b = 4.2
X = torch.tensor(np.random.normal(0,1,(1000,2)),dtype=torch.float)
Y = true_w[0]*X[:,0]+true_w[1]*X[:,1]+true_b
Y += torch.tensor(np.random.normal(0,0.01,size=y.size()),dtype=torch.float)#读取数据集
batch_size=10
dataset = Data.TensorDataset(X,Y)#将训练集的特征和标签组合
data_iter = Data.DataLoader(dataset,batch_size,shuffle=True)#随机读取小批量#定义模型
net=torch.nn.Sequential(torch.nn.Linear(2,1))
#参数初始化
init.normal_(net[0].weight,0,1)
init.constant_(net[0].bias,0)#定义损失函数 优化器
loss=torch.nn.MSELoss()
optimizer = torch.optim.SGD(net.parameters(),lr=0.01,momentum=0.9)num_epocs=10
for epoch in range(1,num_epochs+1):for X,y in data_iter:y_pre=net(X)l = loss(y_pre,y.view(-1,1))optimizer.zero_grad()#梯度清零l.backward()#反向传播optimizer.step()print('epoch %d,loss:%f'%(epoch,l.item()))
print(true_w ,net[0].weight)
print(true_b ,net[0].bias)

相关内容

热门资讯

刚刚,大跳水!超42万人爆仓!... 来源:券商中国 加密货币,遭遇抛售潮! 凯文·沃什被提名为下一任美联储主席所产生的后续效应,正持续波...
做好银行网点“加减法” 国家金融监督管理总局网站披露的信息显示,2025年共有约1.1万家银行业金融机构的线下网点获准退出,...
金价暴跌引热议,网友:商场门口... 来源:中国基金报 随着国际金价急速下跌,国内首饰金价也迎来大幅回调。 1月31日,老庙报1546元/...
内蒙古一银行员工将储户220万... 内蒙古一银行员工将储户220万元存款转走并挥霍,银行称员工已离岗不愿承担赔偿 1月31日,有媒体报...
老年医学科进修轶事|老年医学如... 和年苑,北京协和医院老年医学科公众号,传递老年医学的价值和声音 在这里,了解当代老年医学 Autum...
和讯投顾余兴栋:周五杀跌,下周... 周五大盘大幅度的杀跌又探底回升,收出一根长长的下影线,不少的朋友又在问我,那这根k线是不是就意味着调...
【数智周报】马化腾评豆包手机;... 【数智周报将整合本周最重要的企业级服务、云计算、大数据领域的前沿趋势、重磅政策及行研报告。】 观点马...
和美字节,用字节连接和美 和美字节(Hemei Byte),是杭州桑桥网络科技有限公司于 2026 年 1 月完成品牌升级后启...
仙乐健康56岁副总姚壮民业务员... 瑞财经 刘治颖 1月29日,仙乐健康科技股份有限公司(以下简称:仙乐健康)向港交所主板递交上市申请书...
詹姆斯下家概率:骑士最高退役第... 近日,有关詹姆斯的未来引发了大众的热议,相关机构也更新了这位巨星的下家概率,回归骑士是最大可能。 相...
原创 猛... 在国际金价屡创历史新高之时,资本市场正经历一场有趣的分化:有人急于套现离场,有人却大举加码。近日,一...
原创 男... 在爱情的海洋中,星座与情感交织出无数动人的故事。当一个男性用以下这四个称呼来称呼你时,他的爱情之舟正...
民航持续回暖:南航、海航预计去... 时隔五年,南航预计在三大航中率先实现年度扭亏。 截至1月30日晚间,中国国航(601111.SH)、...
公募加仓非银金融,后市机会如何... 基金增配保险、券商股。 最新数据显示,公募基金2025年四季度的非银金融仓位提高1个百分点。继有色金...
赵慧芳主任中医治疗产后“月子病... 赵慧芳主任中医治疗产后“月子病”的临床智慧 产后调理是中华民族传承千年的养生智慧,在中医理论中占据重...
江西万年青水泥股份有限公司20... 本公司及董事会全体成员保证信息披露的内容真实、准确、完整,没有虚假记载、误导性陈述或重大遗漏。 一、...
科学应对甲状腺结节,别让“结节... 随着健康意识的提升 超声检查在体检中普及率不断提高 甲状腺结节的检出率也显著上升 不少人拿着“结节”...
春节前,政府债发行提速 来源:郁言债市 01 1月资金面,两轮波动,中枢平稳 回顾开年以来资金利率走势,月内资金经历两轮波动...
【央行多措并举护航,专家预期节... 【央行多措并举护航,专家预期节前流动性保持充裕】1月29日,中国人民银行以固定利率、数量招标方式开展...
季节性因素叠加市场需求不足,1... 来源:界面新闻 记者 辛圆 国家统计局周六公布数据显示,1月份,中国制造业采购经理人指数(PM...