已经是最新一篇文章了!
已经是最后一篇文章了!
GNN、GCN与GAT详解
前言
图神经网络(GNN)专门处理图结构数据,能够捕捉节点之间的关系。本文介绍GNN的基本概念、GCN和GAT等经典架构。
图数据基础
图的表示
# Third-party imports: NumPy for array math, matplotlib for the plots below.
import numpy as np
import matplotlib.pyplot as plt
# Fix the RNG seed so the randomly generated node features are reproducible.
np.random.seed(42)
# 图的基本表示
class Graph:
    """Minimal graph backed by a dense adjacency matrix.

    Edges are stored as 1-entries in an (N, N) float matrix; optional
    node features are attached as a separate array.
    """

    def __init__(self, num_nodes):
        self.num_nodes = num_nodes
        # Dense (num_nodes x num_nodes) matrix; entry 1 marks an edge.
        self.adj_matrix = np.zeros((num_nodes, num_nodes))
        self.node_features = None

    def add_edge(self, i, j, bidirectional=True):
        """Insert edge i -> j (and j -> i unless bidirectional=False)."""
        self.adj_matrix[i, j] = 1
        if bidirectional:
            self.adj_matrix[j, i] = 1

    def set_features(self, features):
        """Attach a (num_nodes, feature_dim) node-feature array."""
        self.node_features = features

    def get_neighbors(self, node):
        """Return the indices of nodes adjacent to `node`."""
        (neighbors,) = np.where(self.adj_matrix[node] == 1)
        return neighbors

    def get_degree(self, node):
        """Return the out-degree of `node` as a plain int."""
        return int(self.adj_matrix[node].sum())
# Build the 6-node example graph used throughout the article.
graph = Graph(6)
edges = [(0, 1), (0, 2), (1, 2), (1, 3), (2, 4), (3, 4), (3, 5), (4, 5)]
for u, v in edges:
    graph.add_edge(u, v)

# One random 4-dimensional feature vector per node (seeded above, so reproducible).
features = np.random.randn(6, 4)
graph.set_features(features)

print("邻接矩阵:")
print(graph.adj_matrix)
print(f"\n节点特征形状: {graph.node_features.shape}")
print(f"节点0的邻居: {graph.get_neighbors(0)}")
可视化图
def visualize_graph(adj_matrix, node_labels=None):
    """Draw the graph with nodes on the unit circle and edges as gray lines.

    Args:
        adj_matrix: square (N, N) adjacency matrix; 1 marks an edge.
        node_labels: optional sequence of per-node label strings;
            defaults to the node index.
    """
    num_nodes = adj_matrix.shape[0]

    # Circular layout: spread the nodes evenly on the unit circle.
    angles = np.linspace(0, 2 * np.pi, num_nodes, endpoint=False)
    positions = np.column_stack([np.cos(angles), np.sin(angles)])

    fig, ax = plt.subplots(figsize=(8, 8))

    # Draw edges first so the nodes (higher zorder) sit on top of them.
    for i in range(num_nodes):
        for j in range(i + 1, num_nodes):
            if adj_matrix[i, j] == 1:
                xs = [positions[i, 0], positions[j, 0]]
                ys = [positions[i, 1], positions[j, 1]]
                ax.plot(xs, ys, 'gray', linewidth=2, alpha=0.5)

    # Draw the nodes.
    ax.scatter(positions[:, 0], positions[:, 1], s=500, c='lightblue',
               edgecolors='black', linewidths=2, zorder=5)

    # Label each node at its center.
    for i in range(num_nodes):
        label = node_labels[i] if node_labels else str(i)
        ax.text(positions[i, 0], positions[i, 1], label,
                ha='center', va='center', fontsize=12, fontweight='bold')

    ax.set_aspect('equal')
    ax.axis('off')
    ax.set_title('图结构可视化')
    plt.tight_layout()
    plt.show()

visualize_graph(graph.adj_matrix)
消息传递机制
核心思想
GNN的核心是消息传递(Message Passing):
- 聚合(Aggregate):收集邻居信息
- 更新(Update):结合自身特征更新
def simple_message_passing(adj_matrix, features, iterations=2):
    """Demo of synchronous message passing on a dense adjacency matrix.

    Each round, every node adds the mean of its neighbors' *current*
    features to its own feature vector (all nodes update from the same
    snapshot). Prints the feature range after each round.

    Args:
        adj_matrix: (N, N) adjacency matrix with 0/1 entries.
        features: (N, D) node-feature array; not modified in place.
        iterations: number of message-passing rounds.

    Returns:
        (N, D) array of updated node features.
    """
    h = features.copy()
    num_nodes = features.shape[0]
    for step in range(iterations):
        nxt = np.zeros_like(h)
        for node in range(num_nodes):
            nbrs = np.where(adj_matrix[node] == 1)[0]
            # Sum of neighbor features; zero vector for isolated nodes.
            if len(nbrs) > 0:
                agg = h[nbrs].sum(axis=0)
            else:
                agg = np.zeros(h.shape[1])
            # Self feature plus neighbor mean (degree clamped to avoid /0).
            nxt[node] = h[node] + agg / max(len(nbrs), 1)
        h = nxt
        print(f"迭代 {step+1}: 特征范围 [{h.min():.3f}, {h.max():.3f}]")
    return h
# Smoke-test the message-passing demo on the example graph's features.
updated_features = simple_message_passing(graph.adj_matrix, graph.node_features)
图卷积网络(GCN)
GCN层
GCN的传播规则:
\[H^{(l+1)} = \sigma(\tilde{D}^{-\frac{1}{2}} \tilde{A} \tilde{D}^{-\frac{1}{2}} H^{(l)} W^{(l)})\]

其中 $\tilde{A} = A + I$ 是加入自环的邻接矩阵,$\tilde{D}$ 是 $\tilde{A}$ 对应的度矩阵。
class GCNLayer:
    """Single graph-convolution layer (NumPy demo implementation).

    Implements H' = ReLU(D^{-1/2} (A + I) D^{-1/2} H W) from
    Kipf & Welling (2017).
    """

    def __init__(self, input_dim, output_dim):
        # Small random weights; this demo has no bias term.
        self.W = np.random.randn(input_dim, output_dim) * 0.01

    def forward(self, adj_matrix, features):
        """Propagate `features` over the graph and apply ReLU.

        Args:
            adj_matrix: (N, N) adjacency matrix with 0/1 entries.
            features: (N, input_dim) node-feature matrix.

        Returns:
            (N, output_dim) activated node embeddings.
        """
        # Add self-loops so every node keeps its own signal
        # (this also guarantees degree >= 1, so no division by zero below).
        A_hat = adj_matrix + np.eye(adj_matrix.shape[0])
        # Symmetric degree normalization. The degree matrix is diagonal,
        # so invert it element-wise instead of via np.linalg.inv, which
        # would run a general O(n^3) matrix inversion for no benefit.
        deg = A_hat.sum(axis=1)
        D_inv_sqrt = np.diag(1.0 / np.sqrt(deg))
        A_norm = D_inv_sqrt @ A_hat @ D_inv_sqrt
        # Aggregate, transform, then ReLU.
        out = A_norm @ features @ self.W
        return np.maximum(0, out)
class SimpleGCN:
    """Two-layer GCN assembled from GCNLayer blocks."""

    def __init__(self, input_dim, hidden_dim, output_dim):
        self.layer1 = GCNLayer(input_dim, hidden_dim)
        self.layer2 = GCNLayer(hidden_dim, output_dim)

    def forward(self, adj_matrix, features):
        """Run both layers over the same graph; return node embeddings."""
        hidden = self.layer1.forward(adj_matrix, features)
        return self.layer2.forward(adj_matrix, hidden)
# Smoke-test the NumPy GCN: 4-d inputs -> 2-d embeddings on the demo graph.
gcn = SimpleGCN(input_dim=4, hidden_dim=16, output_dim=2)
output = gcn.forward(graph.adj_matrix, graph.node_features)
print(f"输入特征: {graph.node_features.shape}")
print(f"输出特征: {output.shape}")
PyTorch实现
GCN完整实现
try:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class GCNConv(nn.Module):
        """Dense GCN convolution: H' = D^{-1/2} (A + I) D^{-1/2} X W + b."""

        def __init__(self, in_features, out_features):
            super().__init__()
            # Weight lives inside the Linear; bias kept as its own parameter.
            self.linear = nn.Linear(in_features, out_features, bias=False)
            self.bias = nn.Parameter(torch.zeros(out_features))

        def forward(self, x, adj):
            # Add self-loops.
            adj = adj + torch.eye(adj.size(0), device=adj.device)
            # Degrees and their inverse square roots (0 for isolated nodes).
            deg = adj.sum(dim=1)
            deg_inv_sqrt = deg.pow(-0.5)
            deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
            # Symmetric normalization via broadcasting: d_i^{-1/2} A_ij d_j^{-1/2}.
            adj_norm = deg_inv_sqrt.unsqueeze(1) * adj * deg_inv_sqrt.unsqueeze(0)
            # Transform features, then aggregate over the normalized graph.
            support = self.linear(x)
            return torch.mm(adj_norm, support) + self.bias

    class GCN(nn.Module):
        """Two GCNConv layers with ReLU + dropout and log-softmax output."""

        def __init__(self, in_features, hidden_features, out_features, dropout=0.5):
            super().__init__()
            self.conv1 = GCNConv(in_features, hidden_features)
            self.conv2 = GCNConv(hidden_features, out_features)
            self.dropout = dropout

        def forward(self, x, adj):
            h = F.relu(self.conv1(x, adj))
            h = F.dropout(h, p=self.dropout, training=self.training)
            return F.log_softmax(self.conv2(h, adj), dim=1)

    # Smoke test on the 6-node demo graph.
    x = torch.randn(6, 4)
    adj = torch.tensor(graph.adj_matrix, dtype=torch.float32)
    model = GCN(in_features=4, hidden_features=16, out_features=3)
    out = model(x, adj)
    print(f"GCN输出: {out.shape}")
except ImportError:
    print("PyTorch未安装")
图注意力网络(GAT)
注意力机制
GAT使用注意力机制学习邻居的重要性:
\[\alpha_{ij} = \frac{\exp(\text{LeakyReLU}(a^T [Wh_i \| Wh_j]))}{\sum_{k \in \mathcal{N}_i} \exp(\text{LeakyReLU}(a^T [Wh_i \| Wh_k]))}\]

try:
class GATLayer(nn.Module):
"""GAT层"""
def __init__(self, in_features, out_features, num_heads=1, concat=True):
super().__init__()
self.num_heads = num_heads
self.out_features = out_features
self.concat = concat
# 每个头的变换
self.W = nn.Linear(in_features, out_features * num_heads, bias=False)
# 注意力参数
self.a = nn.Parameter(torch.zeros(num_heads, 2 * out_features))
nn.init.xavier_uniform_(self.a)
self.leaky_relu = nn.LeakyReLU(0.2)
def forward(self, x, adj):
N = x.size(0)
# 线性变换
h = self.W(x).view(N, self.num_heads, self.out_features)
# 计算注意力系数
# [N, heads, out] -> [N, N, heads, 2*out]
a_input = torch.cat([
h.unsqueeze(1).repeat(1, N, 1, 1),
h.unsqueeze(0).repeat(N, 1, 1, 1)
], dim=-1)
# [N, N, heads]
e = self.leaky_relu(torch.einsum('ijhd,hd->ijh', a_input, self.a))
# 只对邻居计算注意力
adj_3d = adj.unsqueeze(-1).repeat(1, 1, self.num_heads)
e = e.masked_fill(adj_3d == 0, float('-inf'))
attention = F.softmax(e, dim=1)
attention = attention.masked_fill(torch.isnan(attention), 0)
# 聚合
# [N, heads, out]
h_prime = torch.einsum('ijh,jhd->ihd', attention, h)
if self.concat:
return h_prime.view(N, -1)
else:
return h_prime.mean(dim=1)
class GAT(nn.Module):
"""GAT模型"""
def __init__(self, in_features, hidden_features, out_features,
num_heads=8, dropout=0.6):
super().__init__()
self.dropout = dropout
self.gat1 = GATLayer(in_features, hidden_features, num_heads, concat=True)
self.gat2 = GATLayer(hidden_features * num_heads, out_features, 1, concat=False)
def forward(self, x, adj):
x = F.dropout(x, p=self.dropout, training=self.training)
x = F.elu(self.gat1(x, adj))
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.gat2(x, adj)
return F.log_softmax(x, dim=1)
# 测试GAT
gat_model = GAT(in_features=4, hidden_features=8, out_features=3, num_heads=4)
gat_out = gat_model(x, adj)
print(f"GAT输出: {gat_out.shape}")
except NameError:
print("需要先导入PyTorch")
使用PyG库
try:
    import torch_geometric
    from torch_geometric.nn import GCNConv, GATConv
    from torch_geometric.data import Data

    # Edge list in COO layout: row 0 holds sources, row 1 holds targets.
    edge_index = torch.tensor([[0, 0, 1, 1, 2, 3, 3, 4],
                               [1, 2, 2, 3, 4, 4, 5, 5]], dtype=torch.long)
    x = torch.randn(6, 4)
    data = Data(x=x, edge_index=edge_index)
    print(f"节点数: {data.num_nodes}")
    print(f"边数: {data.num_edges}")
    print(f"特征维度: {data.num_node_features}")

    class PyGGCN(nn.Module):
        """Two-layer GCN built from PyG's library GCNConv."""

        def __init__(self):
            super().__init__()
            self.conv1 = GCNConv(4, 16)
            self.conv2 = GCNConv(16, 3)

        def forward(self, data):
            x, edge_index = data.x, data.edge_index
            h = F.relu(self.conv1(x, edge_index))
            h = F.dropout(h, training=self.training)
            return F.log_softmax(self.conv2(h, edge_index), dim=1)

    pyg_gcn = PyGGCN()
    out = pyg_gcn(data)
    print(f"PyG GCN输出: {out.shape}")
except ImportError:
    print("PyTorch Geometric未安装")
    print("安装命令: pip install torch-geometric")
GNN应用场景
应用领域
| 应用 | 图结构 | 任务 |
|---|---|---|
| 社交网络 | 用户-好友 | 节点分类、链接预测 |
| 分子预测 | 原子-键 | 属性预测 |
| 推荐系统 | 用户-物品 | 推荐 |
| 知识图谱 | 实体-关系 | 链接预测 |
| 交通预测 | 路口-道路 | 流量预测 |
常见问题
Q1: GCN和GAT的区别?
| 特性 | GCN | GAT |
|---|---|---|
| 邻居权重 | 固定(度归一化) | 学习(注意力) |
| 表达能力 | 较弱 | 较强 |
| 计算复杂度 | 低 | 高 |
Q2: 如何处理大规模图?
- 采样方法(GraphSAGE)
- 小批量训练
- 分布式计算
Q3: GNN的局限性?
- 过平滑问题(层数增加)
- 计算复杂度
- 难以处理异构图
Q4: 如何选择GNN架构?
- 节点分类:GCN、GAT
- 图分类:加入池化层
- 链接预测:编码节点对
总结
| 模型 | 特点 | 适用场景 |
|---|---|---|
| GCN | 简单高效 | 一般图任务 |
| GAT | 注意力加权 | 需要区分邻居重要性 |
| GraphSAGE | 采样聚合 | 大规模图 |
参考资料
- Kipf, T. & Welling, M. (2017). “Semi-Supervised Classification with Graph Convolutional Networks”
- Veličković, P. et al. (2018). “Graph Attention Networks”
- Hamilton, W. et al. (2017). “Inductive Representation Learning on Large Graphs”
- PyTorch Geometric官方文档
版权声明: 如无特别声明,本文版权归 sshipanoo 所有,转载请注明本文链接。
(采用 CC BY-NC-SA 4.0 许可协议进行授权)
本文标题:《 机器学习基础系列——图神经网络 》
本文链接:http://localhost:3015/ai/%E5%9B%BE%E7%A5%9E%E7%BB%8F%E7%BD%91%E7%BB%9C.html
本文距离最后一次更新已有一段时间,文章中的某些内容可能已过时!