【NLP】LSTM 唐诗生成器 pytorch 版
admin
2024-04-03 13:13:31
0

参考这篇文章LSTM唐诗生成器Keras版

将相关的 keras 模型代码进行修改,改成对应的 pytorch 模型,现将有区别的部分放在这里。

训练模型

搭建网络

# 把keras 模型改成 pytorch 模型
# 建立LSTM模型
import torch
import torch.nn as nn
import torch.nn.functional as F# 设置 CUDA
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model = Sequential()
# model.add(Embedding(10000, 128, input_length=20))
# model.add(LSTM(128, return_sequences=True))
# model.add(Dropout(0.2))
# model.add(LSTM(128))
# model.add(Dropout(0.2))
# model.add(Dense(10000, activation='softmax'))# 参考上述的 keras 模型,建立 pytorch 模型# 第二层 LSTM 只取最后一个输出,所以 return_sequences=Falseclass LSTMNet(nn.Module):def __init__(self):super(LSTMNet, self).__init__()self.embedding = nn.Embedding(10000, 128)self.lstm1 = nn.LSTM(input_size=128, hidden_size=128, num_layers=1, batch_first=True)self.dropout1 = nn.Dropout(0.2)self.lstm2 = nn.LSTM(input_size=128, hidden_size=128, num_layers=1, batch_first=True)self.dropout2 = nn.Dropout(0.2)self.fc = nn.Linear(128, 10000)def forward(self, x):x = self.embedding(x) # [batch_size, seq_len, embedding_size]x, _ = self.lstm1(x)  # [batch_size, seq_len, hidden_size]x = self.dropout1(x)  # [batch_size, seq_len, hidden_size]x, _ = self.lstm2(x)  # [batch_size, seq_len, hidden_size]x = self.dropout2(x)  # [batch_size, seq_len, hidden_size]x = x[:, -1, :] #       这里-1的意思是:取最后一个输出 [batch_size, hidden_size]x = self.fc(x)  #       [batch_size, 10000]return x
# 实例化模型
model = LSTMNet().to(device)
model

LSTMNet(
(embedding): Embedding(10000, 128)
(lstm1): LSTM(128, 128, batch_first=True)
(dropout1): Dropout(p=0.2, inplace=False)
(lstm2): LSTM(128, 128, batch_first=True)
(dropout2): Dropout(p=0.2, inplace=False)
(fc): Linear(in_features=128, out_features=10000, bias=True)
)

Pytorch 数据转换

注意:因为 y_train 和 y_test [batch, 1] 最后一个维度是没用的,
所以要把它去掉,变成 [batch] 才能正常给交叉熵损失函数计算

# 先把 x_train, x_test, y_train, y_test 转化为 tensor
x_train = torch.tensor(x_train).to(device)
x_test = torch.tensor(x_test).to(device)
y_train = torch.tensor(y_train).to(device)
y_test = torch.tensor(y_test).to(device)
# 测试样本能否正常输入网络
pred = model(x_train[0:3].to(device))
print(x_train[0:3].shape) # [3, 20] # 3个样本,每个样本20个词
print(pred.shape) # [3, 10000]     #  3个样本,每个样本10000个分类

torch.Size([3, 20])
torch.Size([3, 10000])

# 因为 y_train 和 y_test [batch, 1] 最后一个维度是没用的,
# 所以要把它去掉,变成 [batch] 才能正常给交叉熵损失函数计算
y_train = y_train.squeeze()
y_test = y_test.squeeze()# 转化成 Long
y_train = y_train.long()
y_test = y_test.long()# 查看形状
y_train.shape,y_test.shape

(torch.Size([39405]), torch.Size([16889]))

训练模型

# 训练模型
import torch.optim as optim
from tqdm import tqdm
optimizer = optim.Adam(model.parameters(), lr=0.001)batch_size = 256
epochs = 20# 注意,这里 y_train, y_test 的形状都是 [batch, 1] ,也就是说,并不是 one-hot 编码
# 所以,损失函数用的是 CrossEntropyLossloss_func = nn.CrossEntropyLoss()
for epoch in range(epochs):print('Epoch: ', epoch)for i in tqdm(range(0, len(x_train), batch_size)):x_batch = x_train[i:i+batch_size]y_batch = y_train[i:i+batch_size]pred = model(x_batch)loss = loss_func(pred, y_batch)optimizer.zero_grad()loss.backward()optimizer.step()# 每个 epoch 结束后,计算一下准确率# 训练集准确率pred = model(x_train)pred = torch.argmax(pred, dim=1)acc = (pred == y_train).sum().item() / len(y_train)print('Train acc: ', acc)# 测试集准确率pred = model(x_test)pred = torch.argmax(pred, dim=1)acc = (pred == y_test).sum().item() / len(y_test)print('Test acc: ', acc)

Epoch: 0
100%|██████████| 154/154 [00:38<00:00, 4.01it/s]
Train acc: 0.10216977540921203
Test acc: 0.10320326839955
Epoch: 1

Epoch: 19
100%|██████████| 154/154 [00:37<00:00, 4.09it/s]
Train acc: 0.20576069026773253
Test acc: 0.17970276511338742

test_string = '白日依山盡,黃河入海流,欲窮千里目,更上一'for i in range(300):# 循环 300 步,每步都要预测一个字test_string_token = tokenizer.texts_to_sequences([test_string[-20:]]) # 取最后20个字test_string_mat = np.array(test_string_token)pred = model(torch.tensor(test_string_mat).to(device)) # pred 的形状是 [1, 10000]pred_argmax = torch.argmax(pred, dim=1).item()         # pred_argmax 的形状是 [1]# 把预测的字转化为文字tokenizer.index_word[pred_argmax]test_string = test_string + tokenizer.index_word[pred_argmax]
print(test_string)

相关内容

热门资讯

原创 壳... 编辑:XL 国际能源圈最近炸开了锅,壳牌这家百年石油巨头在2026年3月与委内瑞拉政府正式签署多项油...
存储热潮愈演愈烈!奖金拿到手软... 财联社5月24日讯(编辑 卞纯)在席卷全球的存储芯片热潮中,韩国“存储芯片双雄”SK海力士和三星无疑...
揽牌、合作、生态,跨境支付头部... 近日,国内头部跨境支付机构密集落地海外重要布局,一方面,连连数字、PingPong两家公司相继在中东...
原创 帮... 老铁们,周末好!我是帮主郑重。刚扫了一眼下周的财经日历,好家伙,事件一个接一个,堪称“消息面轰炸周”...
海南省住建厅与中国石化海南石油... 5月22日,中国石化海南石油分公司代表、党委书记李新强、总经理蔡文东一行赴海南省住建厅拜访交流。省住...
原创 金... 2026年5月22日,国际黄金价格报4536.7美元/盎司,较前期高点5597美元回落约1100美元...
“双标”换卡背后,银行还需多些... 新华社记者 颜之宏、杨深深 持到期银行卡和身份证去银行网点换新卡,却被要求“必须交回旧卡才能取新卡”...
“离境退税2.0”带动“中国购... 【环球时报综合报道】编者的话:5月18日,商务部等6部门联合发布《关于加力优化离境退税措施扩大入境消...
一年烧掉2000亿、市值蒸发3... 商业润点 |Biz Run Review 三国归晋,用了六十年。即时零售的"三国杀",才刚刚开局...
原创 金... 2026年5月22日,国内黄金市场呈现出令人咋舌的价格鸿沟。基础金价徘徊在每克995.3元,而回收价...
原创 人... SpaceX的星舰V3终于在全球瞩目中成功升空。北京时间5月23日清晨,这颗高达124米的巨型火箭顺...
原创 被... 5月19日,欧洲议会掀起了一场引人注目的风暴,以压倒性的票数通过了最新的钢铁进口规定。 这套规则...
光纤量价齐升,烽火通信加快布局... 烽火通信(600498)5月22日披露的投资者关系活动记录表显示,公司于5月21日参加了中国信息通信...
原创 突... 今天5月24日一大早,打开行情一看,国际现货黄金报4508.25美元/盎司,单日跌了26.68美元,...
企业快讯 | 携手联通!狄耐克... 狄耐克 厦门总商会副会长企业 厦门狄耐克智能科技股份有限公司 与中国联通厦门分公司 将5G智慧“嵌入...
美银策略师警告:SpaceX与... 环球网 据彭博社报道,美国银行首席投资策略师迈克尔·哈特奈特(Michael Hartnett)最新...
卸任55天后,知名基金经理任相... 【导读】卸任55天后,知名基金经理任相栋“奔私”谜底揭晓 见习记者 闫军 知名基金经理任相栋“奔私”...
原创 大... “免签+手机刷一切”就能让老外连夜订机票?2026年一季度,阿根廷人来华暴涨九倍,北京三源里菜市场三...
从泰山顶峰掉落!“大佬背后的大... 文/刘工昌 他曾是柳传志的“大哥”,助力联想完成混合所有制改革;是史玉柱眼中的“贵人”,帮他东山再起...
原创 2... 最近网上流传出一份2030年GDP10强预测榜单,其中一些城市位次的变化也挺有趣的。上海排在第一,深...