PyTorch深度学习实践——加载数据集
admin
2024-02-11 04:50:16
0

参考资料

参考资料1:https://blog.csdn.net/bit452/article/details/109686474
参考资料2:http://biranda.top/Pytorch%E5%AD%A6%E4%B9%A0%E7%AC%94%E8%AE%B0009%E2%80%94%E2%80%94%E5%85%B3%E4%BA%8E%E6%95%B0%E6%8D%AE%E9%9B%86/#%E6%95%B0%E6%8D%AE%E9%9B%86

加载数据集

#Definition: Epoch   One forward pass and one backward pass of all the training examples.
#Definition: Batch-Size   The number of training examples in one forward backward pass.
#Definition: Iteration   Number of passes, each pass(传递) using [batch size] number of examples.
#例如有10000 个样本, Batch—Size为1000,Iteration为10import torch
import numpy as np
#DataSet是抽象类,无法实例化
from torch.utils.data import Dataset
#DataLoader可实例化
from torch.utils.data import DataLoaderclass DiabetesDataset(Dataset):def __init__(self,filepath):xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32)#获得数据集长度self.len=xy.shape[0] #shape(多少行,多少列)self.x_data = torch.from_numpy(xy[:, :-1]) #取所有行,除最后一列的所有列self.y_data = torch.from_numpy(xy[:, [-1]])#取所有行,和最后一列#↓要实现__getitem__和__len__,不然DataLoader时会报错。#获得索引方法def __getitem__(self, index):return self.x_data[index], self.y_data[index] #return x,y 相当于返回元组(x,y)#获得数据集长度def __len__(self):return self.lendataset = DiabetesDataset('../diabetes.csv')
#num_workers表示多线程的读取
train_loader = DataLoader(dataset=dataset,batch_size=32,shuffle=True,num_workers=2)# batch_size是batch数量,shuffle是是否打乱顺序,num_workers是设计几个线程class Model(torch.nn.Module):def __init__(self):super(Model, self).__init__()self.linear1 = torch.nn.Linear(8, 6)self.linear2 = torch.nn.Linear(6, 4)self.linear3 = torch.nn.Linear(4, 1)self.sigmoid = torch.nn.Sigmoid()def forward(self, x):x = self.sigmoid(self.linear1(x))x = self.sigmoid(self.linear2(x))x = self.sigmoid(self.linear3(x))return xmodel = Model()criterion = torch.nn.BCELoss(reduction='mean')optimizer = torch.optim.SGD(model.parameters(), lr=0.1)#在windows中利用多线程读取,需要将主程序(对数据操作的程序)封装到函数中.这里不加main的话,会报错
#如果设置了num_workers之后,会出现运行时错误,需要将其封装到函数里面或者if语句里面,比如这里加上了if __name__ =='__main__':就不报错了
if __name__ =='__main__':for epoch in range(100):#enumerate:可获得当前迭代的次数for i,data in enumerate(train_loader,0):#enumerate中的o是指起始位置(下标)为0,即i从0开始,i=0,1,2......#准备数据dataloader会将按batch_size返回的数据整合成矩阵加载inputs, labels = data#input.shape=([32, 8]),最后一次是([23, 8])print(inputs.shape)#前馈y_pred = model(inputs)loss = criterion(y_pred, labels)print(epoch, i, loss.item())#反向传播optimizer.zero_grad()loss.backward()#更新optimizer.step()

相关内容

热门资讯

消息称百度旗下昆仑芯瞄准500... 6 月 29 日消息,据《The Information》昨日援引知情人士消息,百度旗下 AI 芯片...
打造夏日消费新场景 第35届北... 北京商报讯(记者 翟枫瑞)6月29日消息,第35届北京国际燕京啤酒文化节新闻发布会在京举行。本届啤酒...
社保基金持仓数据出炉,一季度增... 最近各大上市公司一季度财报都公开了,咱们国家社保基金的持仓数据也全部曝光。目前社保拿着比亚迪价值44...
36氪首发 | 海思、中兴团队... 作者 | 乔钰杰 编辑 | 袁斯来 硬氪获悉,广州宸思通讯科技有限公司(以下简称“宸思科技”)近日完...
两天蒸发47亿市值!一纸税务通... 一纸税务通知书,能让一家百亿龙头两天蒸发47亿市值。 6月22日,北大荒(600598.SH)公告称...
SK海力士将投资1100万亿韩... SK集团会长崔泰源6月29日在韩国“三大重大计划”发布会上宣布,公司将投资1100万亿韩元扩大半导体...
两只A股,终止上市! 两家A股公司,即将摘牌。 6月29日,退市沪科(600608.SH)公告称,上海证券交易所将在202...
原创 M... 一家成立近十年的自动驾驶公司,在IPO时吸引了14家基石投资者认购近一半的发行股份,其中不乏奔驰、比...
基金忠言|国寿安保滤镜碎,三年... 图片来源:视觉中国 蓝鲸新闻6月29日讯(记者 祁和忠)保险系基金公司国寿安保总经理换人了。 6月2...
三星电机计划加码玻璃基板!相关... 6月29日,玻璃基板概念股午后有所回升, 华工科技(000988.SZ)逼近涨停, 彩虹股份(600...
拉萨海关持续壮大外贸经营主体 ...   新华网拉萨6月28日电(记者蒋梦辰)近日,记者从拉萨海关获悉,今年前5个月,西藏有进出口实绩的外...
机构:二季报临近,医药生物板块... 6月29日,华源证券发布了一篇医药生物行业的研究报告,报告指出,业绩期临近,产业链景气度有望再次迎来...
每日收评科创50放量涨超4.5... 财联社6月29日讯,三大指数全线收红,创业板指探底回升,科创50指数大涨4.61%。沪深两市成交额3...
6月多地土拍结构性升温:深圳单... 进入2026年6月,不少城市核心区地块集中诞生高溢价宗地,热度突出的城市包含深圳、杭州、长沙。 其中...
业绩炸裂!盛达资源半年预盈3.... 6月29日,贵金属矿山龙头盛达资源(000603.SZ)发布 2026 年半年度业绩预告,上半年业绩...
A股午后拉升三大股指收涨:半导... A股三大股指6月29日开盘涨跌互现。早盘沪强深弱,创指一度跌超2%。半导体午后拉升,带动两市上涨,沪...
原创 空... 前言 大家好,我是老金。 这几天,两幅极度割裂的画面放在一起,把我看笑了。 一边是在持续的热浪下,欧...
澳大利亚审慎监管局拟放宽银行风... 澳大利亚审慎监管局(APRA)6月29日就修改 银行信用风险资本设定公开征求意见,旨在加大信贷投放以...
全民炒股,急踩刹车!韩国股市突... 屈红燕/证券时报网 全民狂欢、交易高度拥挤、杠杆资金猛增、新入市投资者表现激进、大型IPO吸金等现象...