参考资料1:https://blog.csdn.net/bit452/article/details/109686474
参考资料2:http://biranda.top/Pytorch%E5%AD%A6%E4%B9%A0%E7%AC%94%E8%AE%B0009%E2%80%94%E2%80%94%E5%85%B3%E4%BA%8E%E6%95%B0%E6%8D%AE%E9%9B%86/#%E6%95%B0%E6%8D%AE%E9%9B%86
#Definition: Epoch One forward pass and one backward pass of all the training examples.
#Definition: Batch-Size The number of training examples in one forward backward pass.
#Definition: Iteration Number of passes, each pass(传递) using [batch size] number of examples.
# e.g. with 10,000 samples and a batch size of 1,000, one epoch takes 10 iterations
import torch
import numpy as np
#DataSet是抽象类,无法实例化
from torch.utils.data import Dataset
#DataLoader可实例化
from torch.utils.data import DataLoader


class DiabetesDataset(Dataset):
    """Diabetes CSV dataset: every column except the last is a feature, the last is the label.

    Dataset is an abstract class and cannot be instantiated directly; a subclass
    must implement __getitem__ and __len__, otherwise DataLoader raises an error.
    """

    def __init__(self, filepath):
        # Load the whole CSV into memory as float32.
        xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32)
        # Number of samples: xy.shape is (rows, cols).
        self.len = xy.shape[0]
        # All rows, every column except the last -> feature matrix.
        self.x_data = torch.from_numpy(xy[:, :-1])
        # All rows, last column; the [-1] list index keeps it 2-D, shape (N, 1).
        self.y_data = torch.from_numpy(xy[:, [-1]])

    def __getitem__(self, index):
        # Returning two values is equivalent to returning the tuple (x, y).
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        # Dataset length, as required by DataLoader.
        return self.len


dataset = DiabetesDataset('../diabetes.csv')
# num_workers controls how many worker processes read data in parallel.
train_loader = DataLoader(dataset=dataset,
                          batch_size=32,   # samples per batch
                          shuffle=True,    # reshuffle the data every epoch
                          num_workers=2)   # parallel loader workers


class Model(torch.nn.Module):
    """Fully-connected binary classifier: 8 -> 6 -> 4 -> 1, sigmoid after every layer."""

    def __init__(self):
        super(Model, self).__init__()
        self.linear1 = torch.nn.Linear(8, 6)
        self.linear2 = torch.nn.Linear(6, 4)
        self.linear3 = torch.nn.Linear(4, 1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        # The final sigmoid squashes the output into (0, 1), a probability
        # compatible with BCELoss below.
        x = self.sigmoid(self.linear1(x))
        x = self.sigmoid(self.linear2(x))
        x = self.sigmoid(self.linear3(x))
        return x


model = Model()
criterion = torch.nn.BCELoss(reduction='mean')
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# On Windows, when reading data with multiple workers the code that touches the
# data must be wrapped in a function or guarded (e.g. `if __name__ == '__main__':`),
# otherwise a runtime error is raised.
# With num_workers set, running the loop at module scope raises a RuntimeError
# (workers re-import the main module); guarding it with `if __name__ == '__main__':`
# avoids the error.
if __name__ == '__main__':
    for epoch in range(100):
        # enumerate also yields the batch index; the 0 is the starting index, i = 0, 1, 2, ...
        for i, data in enumerate(train_loader, 0):
            # The DataLoader collates batch_size samples into tensors:
            # inputs.shape == (32, 8), except the final partial batch, (23, 8).
            inputs, labels = data
            print(inputs.shape)
            # Forward pass
            y_pred = model(inputs)
            loss = criterion(y_pred, labels)
            print(epoch, i, loss.item())
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            # Update weights
            optimizer.step()