目录
链路预测是指在一个给定的网络中,根据已有的网络结构信息,尝试预测两个节点之间是否存在连接或者可能会建立连接的概率。这在社交网络分析、生物信息学、推荐系统等领域中都有广泛的应用。
在复杂网络中,链路预测可以帮助我们理解网络的演化过程、发现隐藏的关系和未知的连接,以及预测未来的网络演化趋势。
1、常见的链路预测方法
链路预测并非一种绝对准确的预测方法,因为网络的演化和连接行为具有一定的随机性。
2、图神经网络上的链路预测
图神经网络(Graph Neural Networks,简称GNN)可以用于链路预测任务。GNN是一类专门用于处理图结构数据的深度学习模型,能够学习节点和边的特征表示,并在此基础上进行预测任务。
步骤:
3、使用PyTorch和DGL库实现图神经网络进行链路预测
导入必要的库,包括PyTorch和DGL。
import torch
import torch.nn as nn
import dgl
定义图神经网络模型 GNNLinkPredict
,模型包含两个图卷积层,输入特征维度为2,输出特征维度为1。
# 定义图神经网络模型
class GNNLinkPredict(nn.Module):
def __init__(self, in_feats, hidden_size, out_feats):
super(GNNLinkPredict, self).__init__()
self.conv1 = dgl.nn.GraphConv(in_feats, hidden_size)
self.conv2 = dgl.nn.GraphConv(hidden_size, out_feats)
def forward(self, g, features):
x = torch.relu(self.conv1(g, features))
x = torch.relu(self.conv2(g, x))
return x
创建示例图数据 g
,其中包括5个节点和7条边。定义节点特征 features
,每个节点有两个特征值。定义标签 labels
,表示边的连接情况。
# 构建示例图数据
# 创建一个有向图
g = dgl.DGLGraph()
g.add_nodes(5)
g.add_edges([0, 0, 0, 1, 1, 2, 3], [1, 2, 3, 2, 4, 3, 4])
# 定义节点特征
features = torch.tensor([
[0.2, 0.4],
[0.3, 0.5],
[0.4, 0.6],
[0.5, 0.7],
[0.6, 0.8]
])
# 定义标签(边是否存在连接)
labels = torch.tensor([1, 1, 1, 0, 0, 1, 0], dtype=torch.float32)
划分训练集和测试集,使用布尔类型的掩码 train_mask
和 test_mask
表示。
# 划分训练集和测试集
train_mask = torch.tensor([True, True, True, False, False])
test_mask = torch.tensor([False, False, False, True, True])
创建图神经网络模型实例 model
。
定义优化器和损失函数,这里使用Adam优化器和二分类的交叉熵损失函数。
# 创建图神经网络模型
model = GNNLinkPredict(in_feats=2, hidden_size=16, out_feats=1)
# 定义优化器和损失函数
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = nn.BCEWithLogitsLoss()
进行模型训练。循环迭代多个epoch,在每个epoch中执行以下步骤
# 训练模型
for epoch in range(50):
model.train()
logits = model(g, features)
pred = logits.squeeze()
loss = criterion(pred[train_mask], labels[train_mask])
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 打印训练损失
print(f"Epoch: {epoch + 1}, Loss: {loss.item()}")
在测试集上评估模型。将模型设置为评估模式 model.eval()
,然后使用训练好的模型对测试集进行预测。通过将预测结果应用sigmoid函数将其映射到0-1之间,并使用四舍五入将其转换为0或1的预测标签。计算预测准确率并输出。
# 在测试集上评估模型
model.eval()
with torch.no_grad():
logits = model(g, features)
pred = logits.squeeze()
pred = torch.sigmoid(pred) # 使用sigmoid函数将预测值映射到0-1之间
pred_labels = torch.round(pred) # 四舍五入为0或1的预测标签
accuracy = (pred_labels[test_mask] == labels[test_mask]).float().mean()
print(f"Accuracy: {accuracy.item()}")
汇总的代码:
# https://www.dgl.ai/pages/start.html
import torch
import torch.nn as nn
import dgl
# 定义图神经网络模型
class GNNLinkPredict(nn.Module):
def __init__(self, in_feats, hidden_size, out_feats):
super(GNNLinkPredict, self).__init__()
self.conv1 = dgl.nn.GraphConv(in_feats, hidden_size)
self.conv2 = dgl.nn.GraphConv(hidden_size, out_feats)
def forward(self, g, features):
x = torch.relu(self.conv1(g, features))
x = torch.relu(self.conv2(g, x))
return x
# 构建示例图数据
# 创建一个有向图
g = dgl.DGLGraph()
g.add_nodes(5)
g.add_edges([0, 0, 0, 1, 1, 2, 3], [1, 2, 3, 2, 4, 3, 4])
# 添加自环
g = dgl.add_self_loop(g)
# 定义节点特征
features = torch.tensor([
[0.2, 0.4],
[0.3, 0.5],
[0.4, 0.6],
[0.5, 0.7],
[0.6, 0.8]
])
# 定义标签(边是否存在连接)
labels = torch.tensor([1, 1, 1, 0, 0, 1, 0], dtype=torch.float32)
# 划分训练集和测试集
train_mask = torch.tensor([True, True, True, False, False, False, False])
test_mask = torch.tensor([False, False, False, True, True, True, True])
# 创建图神经网络模型
model = GNNLinkPredict(in_feats=2, hidden_size=16, out_feats=1)
# 定义优化器和损失函数
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = nn.BCEWithLogitsLoss()
# 训练模型
for epoch in range(50):
model.train()
logits = model(g, features)
pred = logits.squeeze()
loss = criterion(pred[train_mask], labels[train_mask])
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 打印训练损失
print(f"Epoch: {epoch + 1}, Loss: {loss.item()}")
# 在测试集上评估模型
model.eval()
with torch.no_grad():
logits = model(g, features)
pred = logits.squeeze()
pred = torch.sigmoid(pred) # 使用sigmoid函数将预测值映射到0-1之间
pred_labels = torch.round(pred) # 四舍五入为0或1的预测标签
accuracy = (pred_labels[test_mask] == labels[test_mask]).float().mean()
print(f"Accuracy: {accuracy.item()}")
留下个问题有空再解决。
关于复杂网络建模,我前面写了很多,大家可以学习参考。
【复杂网络建模】——Python通过平均度和随机概率构建ER网络