网站首页 > 文章精选 正文
摘要:深度学习还没学完,怎么图深度学习又来了?别怕,这里有份系统教程,可以将0基础的你直接送到图深度学习。还会定期更新哦。
主要是基于图深度学习的入门内容。讲述最基本的基础知识,其中包括深度学习、数学、图神经网络等相关内容。
文章涉及使用到的框架以PyTorch和TensorFlow为主。默认读者已经掌握Python和TensorFlow基础。如有涉及到PyTorch的部分,会顺带介绍相关的入门使用。
本教程主要针对的人群:
- 已经掌握TensorFlow基础应用,并想系统学习的学者。
- PyTorch学习者
- 正在从TensorFlow转型到PyTroch的学习者
- 已经掌握Python,并开始学习人工智能的学者。
本篇文章主要通过一个实例介绍如何在DGL中,搭建带有残差结构的多层GAT模型。它是在教程的第六篇GAT模型 基础上进行的延申。
1. 什么是残差结构
残差结构最早源自于ResNet50模型。
ResNet50模型是ResNet(残差网络)的第1个版本,该模型于2015年由何凯明等提出,模型有50层。
残差结构是ResNet50模型的核心特点,它解决了当时深层神经网络难于的训练问题。该网络借鉴了Highway Network(高速通道网络)的思想。在网络的主处理层旁边开了个额外的通道,使得输入可以直达输出。其结构如图所示。
假设x经过神经网络层处理之后,输出的结果为H(x),则结构中的残差网络输出的结果为Y(x)= H(x)+x。
在2015年的ILSVRC(大规模视觉识别挑战赛)中ResNet模型以成绩为:79.26%的Top-1准确率和94.75%的Top-5准确率,取得了当年比赛的第一名。这个模型简单实用,经常被嵌入其它深层网络结构中,作为特征提取层使用。
2.残差结构的原理
残差网络结构是由若干个残差块组成的深度卷积网络结构,如图所示是一个残差块。
在图中,x是该残差块输入,H(x)是期望输出。identity表示恒等映射,即输入是x,输出也是x。F(x)表示期望输出H(x)与输入x的残差,即F(x) =H(x) -x。
残差结构的基本想法是:假设已经有了一个深度神经网络,在其中再增加几个恒等映射,那么不仅增加了网络的深度,并且至少不会增加误差,这样更深的网络不应该导致误差的增加。因此残差结构学习的是残差。
从图中可以看出,当残差F(x)=0时,H(x) =x,这时网络没有误差。
利用这种残差结构,可以使得网络达到上百层的深度。
这种方式看似解决的梯度越传越小的问题,但是残差连接在正向同样也起到作用,由于正向的作用,导致网络结构已经不再是深层了。而是一个并行的模型,即残差连接的作用是将网络串行改成了并行。本质上起到与多通道卷积一致的效果。
3.残差结构在图神经网络中的应用
如果将图卷积或是图注意力卷积层,当作一个普通的卷积层。则也可以搭建出带有残差结构的图神经网络。在这种神经网络中残差结构同样有效,可以使图神经网络模型的层数达到很深。而它的性能更由于对图卷积或是图注意力卷积层进行简单堆叠的图神经网络模型。
4 实例:用带有残差结构的多层GAT模型实现论文分类
在教程三——全连接神经网络与图卷积中介绍过DGL库中有多种数据集。本例就来使用其中的论文数据集——CORA。
并使用带有残差结构的多层GAT模型对其进行分类。
4.1 代码实现:下载CORA数据集
直接使用dgl.data库中的citation_graph模块即可实现CORA数据集的下载。具体代码如下:
代码文件:code_30_dglGAT.py
import dgl
import torch
from torch import nn
from dgl.data import citation_graph
from dgl.nn.pytorch import GATConv
data = citation_graph.CoraDataset()#下载并加载数据集
代码第6行会自动在后台对CORA数据集进行下载。待数据集下载完成之后,对其进行加载返回data对象。
代码运行后输出如下内容:
Downloading C:\Users\ljh\.dgl/cora.zip from https://s3.us-east-2.amazonaws.com/dgl.ai/dataset/cora_raw.zip...
#Extracting file to C:\Users\ljh\.dgl/cora
系统默认的下载路径为当前用户的.dgl文件夹。以作者的本机为例,下载路径为C:\Users\ljh\.dgl/cora.zip。
代码第6行返回的data对象中含有数据集的样本(features)、标签(labels)以及论文中引用关系的邻接矩阵,还有拆分好的训练、测试、验证数据集掩码。
其中,数据集的样本(features)已经被归一化处理,邻接矩阵是以NetWorkx图的形式存在的。
4.2. 代码实现:加工图数据
编写代码查看data对象中的样本数据。具体代码如下:
代码文件:code_30_dglGAT.py(续)
#输出运算资源请况
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)
features = torch.FloatTensor(data.features).to(device)#获得样本特征
labels = torch.LongTensor(data.labels).to(device)#获得标签
train_mask = torch.BoolTensor(data.train_mask).to(device)#获得训练集掩码
val_mask = torch.BoolTensor(data.val_mask).to(device) #获得验证集掩码
test_mask = torch.BoolTensor(data.test_mask).to(device) #获得测试集掩码
feats_dim = features.shape[1]#获得特征维度
n_classes = data.num_labels#获得类别个数
n_edges = data.graph.number_of_edges()#获得邻接矩阵边数
print("""----数据统计------、
#边数 %d
#样本特征维度 %d
#类别数 %d
#训练样本 %d
#验证样本 %d
#测试样本 %d""" % (n_edges, feats_dim,n_classes,
train_mask.int().sum().item(),val_mask.int().sum().item(),
test_mask.int().sum().item()))#输出结果
g = dgl.DGLGraph(data.graph)#将networkx图转成DGL图
g.add_edges(g.nodes(), g.nodes()) #添加自环
n_edges = g.number_of_edges()
代码第25~27行对邻接矩阵进行加工。需要为邻接矩阵加上自环边。
代码运行后输出如下内容:
----数据统计------
#边数 10556
#样本特征维度 1433
#类别数 7
#训练样本 140
#验证样本 300
#测试样本 1000
4.3 代码实现:用DGL库中的GATConv搭建多层GAT模型
在使用DGL库中的GATConv层时,可以将GATConv层直接当作深度学习中的卷积层,搭建多层图卷积网络。具体代码如下:
代码文件:code_30_dglGAT.py(续)
class GAT(nn.Module):#定义多层GAT模型
def __init__(self,
num_layers,#层数
in_dim, #输入维度
num_hidden,#隐藏层维度
num_classes,#类别个数
heads,#多头注意力的计算次数
activation,#激活函数
feat_drop,#特征层的丢弃率
attn_drop,#注意力分数的丢弃率
negative_slope,#LeakyReLU激活函数的负向参数
residual):#是否使用残差网络结构
super(GAT, self).__init__()
self.num_layers = num_layers
self.gat_layers = nn.ModuleList()
self.activation = activation
self.gat_layers.append(GATConv(in_dim, num_hidden, heads[0],
feat_drop, attn_drop, negative_slope, False, self.activation))
#定义隐藏层
for l in range(1, num_layers):
#多头注意力 the in_dim = num_hidden * num_heads
self.gat_layers.append(GATConv(
num_hidden * heads[l-1], num_hidden, heads[l],
feat_drop, attn_drop, negative_slope, residual, self.activation))
#输出层
self.gat_layers.append(GATConv(
num_hidden * heads[-2], num_classes, heads[-1],
feat_drop, attn_drop, negative_slope, residual, None))
def forward(self, g,inputs):
h = inputs
for l in range(self.num_layers):#隐藏层
h = self.gat_layers[l](g, h).flatten(1)
#输出层
logits = self.gat_layers[-1](g, h).mean(1)
return logits
def getmodel( GAT ): #定义函数实例化模型
#定义模型参数
num_heads = 8
num_layers = 1
num_out_heads =1
heads = ([num_heads] * num_layers) + [num_out_heads]
#实例化模型
model = GAT( num_layers, num_feats, num_hidden= 8,
num_classes = n_classes,
heads = ([num_heads] * num_layers) + [num_out_heads],#总的注意力头数
activation = F.elu, feat_drop=0.6, attn_drop=0.6,
negative_slope = 0.2, residual = True) #使用残差结构
return model
代码第11行设置了激活函数leaky_relu的负向参数,该激活函数在DGL库中的GATConv类在计算注意力时的非线性变换使用。这部分内容请参考教程三——全连接神经网络与图卷积
本节代码所实现的多层GAT网络模型主要结构分为两部分,隐藏层和输出层:
- 隐藏层:根据设置的层数进行多层图注意力网络的叠加。
- 输出层:在隐藏层之后,再叠加一个单层图注意力网络,输出的特征维度与类别数相同。
通过如下两行代码即可将模型结构打印出来:
model = getmodel(GAT)print(model)#输出模型
代码运行后输出如下结果:
结果中的“(0): GATConv”是隐藏层部分;“(1): GATConv”是输出层部分。
4.4 训练模型
训练模型与正常的深度学习训练过程完全一致。具体细节如下:
- 损失函数:torch.nn.CrossEntropyLoss()
- 优化器:torch.optim.Adam
- 学习率:lr=0.005
将前面准备好的图对象g和节点特征features传入模型中model(g,features)即可输出预测结果。
代码运行后,输出如下结果:
如果直接使用单层的GAT模型,其准确率只有0.7800。没有本例中模型的准确率0.8350高。
欢迎各位朋友学习,有问题可以在留言区指出,谢谢支持。
猜你喜欢
- 2025-01-09 图注意力网络论文详解和PyTorch实现
- 2025-01-09 使用scikit-learn为PyTorch 模型进行超参数网格搜索
- 2025-01-09 神经网络调试:梯度可视化
- 2025-01-09 涨姿势!「手动」调试神经网络,可以这样做
- 2025-01-09 深度学习的秘密武器:用 PyTorch 的 torch.nn.ReLU 打造高效模型
- 2025-01-09 #轻松学习深度学习(AI) 4 神经元的一般化
- 2025-01-09 基于深度学习的运动想象脑机接口研究综述
- 2025-01-09 使用多尺度patch合成来做高分辨率的图像复原
- 2025-01-09 神经网络训练tricks
- 2025-01-09 汇总|实时性语义分割算法
- 04-23关于linux coreutils/sort.c源码的延展思考最小堆为什么不用自旋
- 04-23一文精通如何使用二叉树
- 04-23二叉树(Binary Tree)
- 04-23数据结构入门:树(Tree)详细介绍
- 04-23数据结构错题收录(六)
- 04-23Kubernetes原理深度解析:万字图文全总结!
- 04-23一站式速查知识总结,助您轻松驾驭容器编排技术(水平扩展控制)
- 04-23kubectl常用删除命令
- 最近发表
- 标签列表
-
- newcoder (56)
- 字符串的长度是指 (45)
- drawcontours()参数说明 (60)
- unsignedshortint (59)
- postman并发请求 (47)
- python列表删除 (50)
- 左程云什么水平 (56)
- 计算机网络的拓扑结构是指() (45)
- 稳压管的稳压区是工作在什么区 (45)
- 编程题 (64)
- postgresql默认端口 (66)
- 数据库的概念模型独立于 (48)
- 产生系统死锁的原因可能是由于 (51)
- 数据库中只存放视图的 (62)
- 在vi中退出不保存的命令是 (53)
- 哪个命令可以将普通用户转换成超级用户 (49)
- noscript标签的作用 (48)
- 联合利华网申 (49)
- swagger和postman (46)
- 结构化程序设计主要强调 (53)
- 172.1 (57)
- apipostwebsocket (47)
- 唯品会后台 (61)
- 简历助手 (56)
- offshow (61)