2020-ICLR-Memory-Based Graph Networks

 Paper:https://arxiv.org/abs/2002.09518
 Code: https://github.com/amirkhas/GraphMemoryNet
基于内存的图网络
图神经网络(GNN)是一类可对任意拓扑结构的数据进行操作的深度模型。 作者为GNN引入了一个有效的memory layer,该memory layer可以共同学习节点表示并对图谱进行粗化。 作者还介绍了基于此层的两个新网络:基于memory的GNN(MemGNN)和可以学习层次化图形表示的图形记忆网络(GMN)。实验结果表明,所提出的模型在九种图谱分类和回归数据集中有八种获得了state-of-the-art性能。作者还表明,学习的表示形式可能与分子数据中的化学特征相对应。
模型

 作者在第
    
     
      
       
        l
       
      
      
       l
      
     
    ll层定义一个记忆层
    
     
      
       
        
         M
        
        
         l
        
       
       
        :
       
       
        
         R
        
        
         
          
           n
          
          
           l
          
         
         
          ×
         
         
          
           d
          
          
           
            l
           
           
            +
           
           
            1
           
          
         
        
       
      
      
       M^{l}:R^{n_l \times d_{l+1}}
      
     
    Ml:Rnl×dl+1, 它接受大小为 
    
     
      
       
        
         d
        
        
         l
        
       
      
      
       d_l
      
     
    dl 的 
    
     
      
       
        
         n
        
        
         l
        
       
      
      
       n_l
      
     
    nl 查询向量并生成大小为 
    
     
      
       
        
         d
        
        
         
          l
         
         
          +
         
         
          1
         
        
       
      
      
       d_{l+1}
      
     
    dl+1 的 
    
     
      
       
        
         n
        
        
         
          l
         
         
          +
         
         
          1
         
        
       
      
      
       n_{l+1}
      
     
    nl+1 查询向量,使得 
    
     
      
       
        
         n
        
        
         
          l
         
         
          +
         
         
          1
         
        
       
       
        <
       
       
        
         n
        
        
         l
        
       
      
      
       n_{l+1} < n_l
      
     
    nl+1<nl。输入和输出查询分别表示输入图和粗略图的节点表示形式。存储层学会联合粗化输入节点。即池化和变换它们的特征,即表征学习。
记忆层由memory keys的数组即多头memory和卷积层组成。 假设有 ∣ h ∣ |h| ∣h∣个memory head,则将输入查询与每个头的memory key进行比较,从而生成 ∣ h ∣ |h| ∣h∣个注意力矩阵,然后使用卷积层将其汇总到单个注意力矩阵中。
作者将输入查询
    
     
      
       
        
         Q
        
        
         
          (
         
         
          l
         
         
          )
         
        
       
       
        ∈
       
       
        
         R
        
        
         
          
           n
          
          
           l
          
         
         
          ×
         
         
          
           d
          
          
           l
          
         
        
       
      
      
       Q^{(l)} \in R^{n_l×d_l}
      
     
    Q(l)∈Rnl×dl 视为输入图的节点表示,并将键 
    
     
      
       
        
         K
        
        
         
          (
         
         
          l
         
         
          )
         
        
       
       
        ∈
       
       
        
         R
        
        
         
          n
         
         
          
           l
          
          
           +
          
          
           1
          
         
        
       
       
        ×
       
       
        
         d
        
        
         l
        
       
      
      
       K^{(l)} \in R^{n_{l+1}}×d_l
      
     
    K(l)∈Rnl+1×dl 视为查询的聚类质心。为了满足这一假设,作者强制使用聚类友好分布作为键和查询之间的距离度量。使用学生的 t 分布作为核来测量查询 
    
     
      
       
        
         q
        
        
         i
        
       
      
      
       q_i
      
     
    qi 和键 
    
     
      
       
        
         k
        
        
         j
        
       
      
      
       k_j
      
     
    kj 之间的归一化相似性,如下所示:
 
 其中 
    
     
      
       
        
         C
        
        
         
          i
         
         
          j
         
        
       
      
      
       C_{ij}
      
     
    Cij 是查询 
    
     
      
       
        
         q
        
        
         i
        
       
      
      
       q_i
      
     
    qi 和键 
    
     
      
       
        
         k
        
        
         j
        
       
      
      
       k_j
      
     
    kj 之间的归一化分数,即将节点 
    
     
      
       
        i
       
      
      
       i
      
     
    i 分配给聚类 
    
     
      
       
        j
       
      
      
       j
      
     
    j 的概率或查询 
    
     
      
       
        
         q
        
        
         i
        
       
      
      
       q_i
      
     
    qi 和内存键 
    
     
      
       
        
         k
        
        
         j
        
       
      
      
       k_j
      
     
    kj 之间的注意力得分,
    
     
      
       
        τ
       
      
      
       \tau
      
     
    τ 是学生 t 分布的自由度,即温度。为了增加容量,作者将memory keys建模为multi-head阵列, 即
    
     
      
       
        [
       
       
        
         C
        
        
         0
        
        
         
          (
         
         
          l
         
         
          )
         
        
       
       
        .
       
       
        .
       
       
        .
       
       
        
         C
        
        
         h
        
        
         
          (
         
         
          l
         
         
          )
         
        
       
       
        ]
       
       
        ∈
       
       
        
         R
        
        
         
          ∣
         
         
          h
         
         
          ∣
         
         
          ×
         
         
          
           n
          
          
           
            l
           
           
            +
           
           
            1
           
          
         
         
          ×
         
         
          
           n
          
          
           l
          
         
        
       
      
      
       [C^{(l)}_0...C^{(l)}_h] \in R^{|h| \times n_{l+1} \times n_l}
      
     
    [C0(l)...Ch(l)]∈R∣h∣×nl+1×nl, 其中
    
     
      
       
        ∣
       
       
        h
       
       
        ∣
       
      
      
       |h|
      
     
    ∣h∣表示head数量。
为了将上述结果聚合为一组结果,作者将三个维度分别看作深度、高度和宽度,然后使用一个 1X1的卷积核进行聚合降维:
 
 其中 
    
     
      
       
        
         Γ
        
        
         φ
        
       
      
      
       \Gamma_{\varphi}
      
     
    Γφ是由 
    
     
      
       
        φ
       
      
      
       φ
      
     
    φ 参数化的 [1 × 1] 卷积算子,
    
     
      
       
        ∣
       
       
        ∣
       
      
      
       ||
      
     
    ∣∣是串联运算符,
    
     
      
       
        
         C
        
        
         
          (
         
         
          l
         
         
          )
         
        
       
      
      
       C^{(l)}
      
     
    C(l) 是聚合的软赋值矩阵。
内存读取生成一个值矩阵 
    
     
      
       
        
         V
        
        
         
          (
         
         
          l
         
         
          )
         
        
       
       
        ∈
       
       
        
         R
        
        
         
          
           n
          
          
           
            l
           
           
            +
           
           
            1
           
          
         
         
          ×
         
         
          
           d
          
          
           l
          
         
        
       
      
      
       V^{(l)} \in R^{n_{l+1} \times d_l}
      
     
    V(l)∈Rnl+1×dl,该矩阵表示与输入查询位于同一空间中的粗略节点表示形式,并定义为软分配分数和原始查询的乘积,如下所示:
 
 值矩阵被馈送到单层前馈神经网络,以将 
    
     
      
       
        
         R
        
        
         
          
           n
          
          
           
            l
           
           
            +
           
           
            1
           
          
         
         
          ×
         
         
          
           d
          
          
           l
          
         
        
       
      
      
       R^{n_{l+1}×d_l}
      
     
    Rnl+1×dl 的粗嵌入投影到 
    
     
      
       
        
         R
        
        
         
          n
         
         
          
           l
          
          
           +
          
          
           1
          
         
        
       
       
        ×
       
       
        
         d
        
        
         
          l
         
         
          +
         
         
          1
         
        
       
      
      
       R^{n_{l+1}}×d_{l+1}
      
     
    Rnl+1×dl+1 中,如下所示:
 
 其中 
    
     
      
       
        
         Q
        
        
         
          (
         
         
          l
         
         
          +
         
         
          1
         
         
          )
         
        
       
       
        ∈
       
       
        
         R
        
        
         
          
           n
          
          
           
            l
           
           
            +
           
           
            1
           
          
         
         
          ×
         
         
          
           d
          
          
           
            l
           
           
            +
           
           
            1
           
          
         
        
       
      
      
       Q^{(l+1)} \in R^{n_{l+1}×d_{l+1}}
      
     
    Q(l+1)∈Rnl+1×dl+1 是输出查询,
    
     
      
       
        W
       
       
        ∈
       
       
        
         R
        
        
         
          
           d
          
          
           l
          
         
         
          ×
         
         
          
           d
          
          
           
            l
           
           
            +
           
           
            1
           
          
         
        
       
      
      
       W \in R^{d_l×d_{l+1}}
      
     
    W∈Rdl×dl+1 是网络参数,σ是使用 LeakyReLU 实现的非线性。
对于图分类任务,我们可以通过堆叠记忆层最终获得整个图的向量表示,然后用全连接层进行分类:
 
 其中, 
    
     
      
       
        
         Q
        
        
         0
        
       
       
        =
       
       
        
         f
        
        
         q
        
       
       
        (
       
       
        g
       
       
        )
       
      
      
       Q_0 = f_q(g)
      
     
    Q0=fq(g)是将图
    
     
      
       
        g
       
      
      
       g
      
     
    g输入网络
    
     
      
       
        
         f
        
        
         g
        
       
      
      
       f_g
      
     
    fg得到的初始查询表示,也就是初始节点向量。根据
    
     
      
       
        
         f
        
        
         g
        
       
      
      
       f_g
      
     
    fg的不同,作者引出了两种模型,即GMN和MemGNN。
GMN
作者使用带重启的随机游走(RWR)(Pan et al., 2004)来计算拓扑嵌入(基于图结构的表示学习),然后按行对它们进行排序,以强制节点嵌入保持顺序不变。得到包含拓扑信息的节点表示即图扩散矩阵
    
     
      
       
        S
       
       
        ∈
       
       
        
         R
        
        
         
          n
         
         
          ×
         
         
          n
         
        
       
      
      
       S \in R^{n \times n}
      
     
    S∈Rn×n后,初始的查询表示通过两层前向网络计算得到:
 
 其中
    
     
      
       
        
         W
        
        
         0
        
       
       
        ∈
       
       
        
         R
        
        
         
          n
         
         
          ×
         
         
          
           d
          
          
           
            i
           
           
            n
           
          
         
        
       
      
      
       W_0 \in R^{n×d_{in}}
      
     
    W0∈Rn×din和 
    
     
      
       
        
         W
        
        
         1
        
       
       
        ∈
       
       
        
         R
        
        
         
          2
         
         
          
           d
          
          
           
            i
           
           
            n
           
          
         
         
          ×
         
         
          
           d
          
          
           0
          
         
        
       
      
      
       W_1 \in R^{2d_{in}×d_0}
      
     
    W1∈R2din×d0 是参数,
    
     
      
       
        ∣
       
       
        ∣
       
      
      
       ||
      
     
    ∣∣表示拼接操作,
    
     
      
       
        σ
       
      
      
       σ
      
     
    σ是LeakyReLU激活函数。
MemGNN
MemGNN直接使用图神经网络计算初始查询:
 
 其中, 
    
     
      
       
        
         G
        
        
         θ
        
       
      
      
       G_{\theta}
      
     
    Gθ是任意的图神经网络。作者在实现时使用了GAT模型的改进版e-GAT,也就是在计算注意力权重时考虑了边特征。注意力权重计算公式为:
 
 其中
    
     
      
       
        
         h
        
        
         i
        
        
         
          (
         
         
          l
         
         
          )
         
        
       
      
      
       h^{(l)}_i
      
     
    hi(l) 和 
    
     
      
       
        
         h
        
        
         
          i
         
         
          →
         
         
          j
         
        
        
         
          (
         
         
          l
         
         
          )
         
        
       
      
      
       h^{(l)}_{ i→j}
      
     
    hi→j(l) 分别表示节点 
    
     
      
       
        i
       
      
      
       i
      
     
    i 的表示形式和将节点 
    
     
      
       
        i
       
      
      
       i
      
     
    i 连接到第 
    
     
      
       
        L
       
      
      
       L
      
     
    L 层中其单跳邻居节点 
    
     
      
       
        j
       
      
      
       j
      
     
    j 的边的表示形式。
    
     
      
       
        
         W
        
        
         n
        
       
      
      
       W_n
      
     
    Wn 和 
    
     
      
       
        
         W
        
        
         e
        
       
      
      
       W_e
      
     
    We 是可学习的节点和边权重,
    
     
      
       
        W
       
      
      
       W
      
     
    W 是计算注意力分数的单层前馈网络的参数。
    
     
      
       
        σ
       
      
      
       σ
      
     
    σ是使用LeakyReLU实现的非线性。
实验

 表1表明:(i)作者模型在GNN模型中获得了最先进的结果,并显着提高了酶,蛋白质,DD,Collab和RedditBinary数据集的性能,绝对精度分别为14.49%,6.0%,3.76%,2.62%和8.98%。(ii)作者模型在所有数据集上都优于图内核。(iii)与基线GNN相比,两种提出的模型都实现了更好的性能或具有竞争力,(iv)与MemGNN相比,GMN取得了更好的结果,这表明用全局拓扑嵌入替换本地邻接信息为模型提供了更有用的信息
 
 表2所示的结果表明,作者模型在BACE基准上以4.0 AUC-ROC的绝对margin实现了最先进的结果,并且在Tox21数据集上与最先进的GCN模型竞争,即绝对margin为0.001。
 表3表明,作者MemGNN模型在ESOL和亲脂性基准上分别以0.07和0.1 RMSE的绝对margin实现了最先进的结果。
 
 图2可视化了原子上学习的团簇(即,具有相同颜色的原子在同一团簇内),表明团簇主要由有意义的化学亚结构组成,例如碳链和羟基(OH)(即图2a),以及羧基(COOH)和苯环(即图2b)。从化学角度来看,羟基和羧基以及碳链对分子在水或脂质中的溶解度有显着影响。这证实了网络已经学习了对确定分子溶解度至关重要的化学特征。



















