本文学习纪录
使用LSTM来实现文本匹配任务
使用LSTM(Long Short-Term Memory)网络来实现文本匹配任务是自然语言处理(NLP)中的一个常见应用。文本匹配任务的目标是确定两个文本段落是否在某种程度上相似或相关,例如在问答系统、文档检索、相似问题匹配等场景中非常有用。
模型构建
输入层
:两个独立的输入,分别对应两个文本序列。
LSTM层
:为每个输入文本设计一个LSTM层来捕获序列信息。可以使用双向LSTM(BiLSTM)来获取前后文信息。
相似度计算
:使用余弦相似度、曼哈顿距离、欧式距离等方法计算两个LSTM层的输出向量之间的相似度。
输出层
:根据相似度分数输出匹配程度,可以是二分类(匹配或不匹配)或者回归(相似度得分)。
定义网络
# 定义网络结构
class LSTM(nn.Module):
def __init__(self, vocab_size, hidden_dim, num_layers, embedding_dim, output_dim):
super(LSTM, self).__init__()
self.hidden_dim = hidden_dim # 隐层大小
self.num_layers = num_layers # LSTM层数
# 嵌入层,会对所有词形成一个连续型嵌入向量,该向量的维度为embedding_dim
# 然后利用这个向量来表示该字,而不是用索引继续表示
self.embeddings_x = nn.Embedding(vocab_size + 1, embedding_dim)
self.embeddings_y = nn.Embedding(vocab_size + 1, embedding_dim)
# 定义LSTM层,第一个参数为每个时间步的特征大小,这里就是每个字的维度
# 第二个参数为隐层大小
# 第三个参数为lstm的层数
self.lstm_x = nn.LSTM(embedding_dim, hidden_dim, num_layers)
self.lstm_y = nn.LSTM(embedding_dim, hidden_dim, num_layers)
self.cos_sim = nn.CosineSimilarity(dim=1, eps=1e-6)
# 利用全连接层将其映射为2维,即0和1的概率
self.fc = nn.Linear(1, output_dim)
def forward(self, x_input, y_input):
# 1.首先形成嵌入向量
embeds_x = self.embeddings_x(x_input)
embeds_y = self.embeddings_y(x_input)
# 2.将嵌入向量导入到lstm层
output_x, _ = self.lstm_x(embeds_x)
output_y, _ = self.lstm_x(embeds_y)
timestep, batch_size, hidden_dim = output_x.shape
output_x = output_x.reshape(timestep, batch_size, -1)
output_y = output_y.reshape(timestep, batch_size, -1)
# 3.获取lstm最后一个隐层表示向量
output_x = output_x[-1]
output_y = output_y[-1]
# 4.计算两个向量的余弦相似度
sim = self.cos_sim(output_x, output_y)
sim = sim.view(-1, 1)
# 5.形成最终输出结果
output = self.fc(sim)
return output
模型训练
# 6.模型训练
model = LSTM(vocab_size=vocab_size, hidden_dim=hidden_dim, num_layers=num_layers,
embedding_dim=embedding_dim, output_dim=output_dim)
Configimizer = optim.Adam(model.parameters(), lr=lr) # 优化器
criterion = nn.CrossEntropyLoss() # 多分类损失函数
model.to(device)
loss_meter = meter.AverageValueMeter()
best_acc = 0 # 保存最好准确率
best_model = None # 保存对应最好准确率的模型参数
for epoch in range(epochs):
model.train() # 开启训练模式
epoch_acc = 0 # 每个epoch的准确率
epoch_acc_count = 0 # 每个epoch训练的样本数
train_count = 0 # 用于计算总的样本数,方便求准确率
loss_meter.reset()
train_bar = tqdm(train_loader) # 形成进度条
for data in train_bar:
x_input, y_input, label = data # 解包迭代器中的X和Y
x_input = x_input.long().transpose(1, 0).contiguous()
x_input = x_input.to(device)
y_input = y_input.long().transpose(1, 0).contiguous()
y_input = y_input.to(device)
Configimizer.zero_grad()
# 形成预测结果
output_ = model(x_input, y_input)
# 计算损失
loss = criterion(output_, label.long().view(-1))
loss.backward()
Configimizer.step()
loss_meter.add(loss.item())
# 计算每个epoch正确的个数
epoch_acc_count += (output_.argmax(axis=1) == label.view(-1)).sum()
train_count += len(x_input)
# 每个epoch对应的准确率
epoch_acc = epoch_acc_count / train_count
# 打印信息
print("【EPOCH: 】%s" % str(epoch + 1))
print("训练损失为%s" % (str(loss_meter.mean)))
print("训练精度为%s" % (str(epoch_acc.item() * 100)[:5]) + '%')
# 保存模型及相关信息
if epoch_acc > best_acc:
best_acc = epoch_acc
best_model = model.state_dict()
# 在训练结束保存最优的模型参数
if epoch == epochs - 1:
# 保存模型
torch.save(best_model, './best_model.pkl')
测试语句
try:
# 数据预处理
input_shape = 20 # 序列长度,就是时间步大小,也就是这里的每句话中的词的个数
# 用于测试的话
sentence1 = "我不爱吃剁椒鱼头,但是我爱吃鱼头"
sentence2 = "我爱吃土豆,但是不爱吃地瓜"
# 将对应的字转化为相应的序号
x_input = [[word2idx[word] for word in sentence1]]
x_input = pad_sequences(maxlen=input_shape, sequences=x_input, padding='post', value=0)
x_input = torch.from_numpy(x_input)
y_input = [[word2idx[word] for word in sentence2]]
y_input = pad_sequences(maxlen=input_shape, sequences=y_input, padding='post', value=0)
y_input = torch.from_numpy(y_input)
# 加载模型
model_path = './best_model.pkl'
model = LSTM(vocab_size=vocab_size, hidden_dim=hidden_dim, num_layers=num_layers,
embedding_dim=embedding_dim, output_dim=output_dim)
model.load_state_dict(torch.load(model_path, 'cpu'))
# 模型预测,注意输入的数据第一个input_shape
y_pred = model(x_input.long().transpose(1, 0), y_input.long().transpose(1, 0))
idx2label = {0:"匹配失败!", 1:"匹配成功!"}
print('输入语句: %s \t %s' % (sentence1, sentence2))
print('文本匹配结果: %s' % idx2label[y_pred.argmax().item()])
except KeyError as err:
print("您输入的句子有汉字不在词汇表中,请重新输入!")
print("不在词汇表中的单词为:%s." % err)
数据集为QA_corpus,训练数据10w条,验证集和测试集均为1w条
其中对应模型文件夹下的args.py文件是超参数
QA_corpus
数据集展示