Restrições de poder representativas e estimativa de erro de generalização para redes neurais de gráfico

Atualmente, uma das tendências no estudo de redes neurais de grafos é a análise do funcionamento de tais arquiteturas, comparação com métodos nucleares, avaliação da complexidade e capacidade de generalização. Tudo isso ajuda a entender os pontos fracos dos modelos existentes e cria espaço para novos.





O trabalho tem como objetivo investigar dois problemas relacionados a redes neurais de grafos. Primeiro, os autores dão exemplos de gráficos que são diferentes em estrutura, mas indistinguíveis para GNNs simples e mais poderosos . Em segundo lugar, eles limitaram o erro de generalização para redes neurais de grafos com mais precisão do que os limites VC.





Introdução

Redes neurais de grafos são modelos que trabalham diretamente com gráficos. Eles permitem que você leve em consideração informações sobre a estrutura. Um GNN típico inclui um pequeno número de camadas que são aplicadas sequencialmente, atualizando as representações de vértices em cada iteração. Exemplos de arquiteturas populares: GCN , GraphSAGE , GAT , GIN .









O processo de atualização de embeddings de vértice para qualquer arquitetura GNN pode ser resumido por duas fórmulas:





a_v ^ {t + 1} = AGG \ esquerda (h_w ^ t: w \ in \ mathcal {N} \ esquerda (v \ direita) \ direita) \\ h_v ^ {t + 1} = COMBINAR \ esquerda (h_v ^ t, a_v ^ {t + 1} \ right),

onde AGG é geralmente uma função invariante a permutações ( soma , média , max etc.), COMBINE é uma função que combina a representação de um vértice e seus vizinhos.





Árvore de cálculo para GNN de duas camadas usando como exemplo o nó A. Fonte: https://eng.uber.com/uber-eats-graph-learning/
Árvore de computação para GNN de duas camadas usando como exemplo o nó A. Fonte: https://eng.uber.com/uber-eats-graph-learning/

Arquiteturas mais avançadas podem considerar informações adicionais, como recursos de borda, ângulos de borda, etc.





O artigo discute a classe GNN para o problema de classificação de grafos. Esses modelos são estruturados assim:





  1. Primeiro, os vértices são embeddings usando L etapas das convoluções do gráfico





  2. (, sum, mean, max)









GNN:





  • (LU-GNN). GCN, GraphSAGE, GAT, GIN





  • CPNGNN, , 1 d, d - ( port numbering)





  • DimeNet, 3D-,





LU-GNN

G G LU-GNN, , , readout-, . CPNGNN G G, .





CPNGNN

, “” , CPNGNN .





S8 S4 , , ( ), , , CPNGNN readout-, , . , .





CPNGNN G2 G1. , DimeNet , , , , \ ângulo A_1B_1C_1 \angle \underline{A}_1\underline{B}_1\underline{C}_1.





DimeNet

DimeNet G4 , G3, . , . , G4 G3 S4 S8, , , DimeNet S4 S8 .





GNN

. , , .





GNN, :





  1. DimeNet





  2. message- m_{uv}^{\left(l\right)} \Phi_{uv} \underline{m}_{uv}^{\left(l\right)} = \underline{f}\left(m_{uv}^{\left(l\right)}, \Phi_{uv}\right)





  3. \left(c_v\left(i\right), t_{i, v}\right), c - i- v, t - .



    :



    h_{v}^{\left( l + 1 \right)} = f \left( h_{v}^{\left( l \right)}, \underline{m}_{c_v\left( 1 \right)v}^{\left( l \right )}, t_{1, v}, ..., \underline{m}_{c_v\left( d \left( v \right ) \right)v}^{\left( l \right )}, t_{ d \left( v \right ), v} \right )





  4. readout-









.





: LU-GNN,





h_v^{l + 1} = \phi \left( W_1x_v + W_2 \rho \left( \sum_{u \in \mathcal{N} \left( v \right)} g\left( h_u^l \right)\right) \right),

\phi,\ g,\ \rho - , x_v - v, , \rho \left(0\right) = 0,\ \forall v:  \lVert x_v \rVert_2 \le B_x,\ \forall x \in \mathbb{R}^r: \lVert \phi \left( x \right ) \rVert_{\infty} \le b < \infty,\ \phi\left( 0 \right ) = 0,\ g\left( 0 \right ) = 0. , \phi,\ g,\ \rho C_{\phi},\ C_{g},\ C_{\rho}, \lVert W_1 \rVert_2 \le B_1,\ \lVert W_2 \rVert_2 \le B_2. W_1,\ W_2,\ \phi,\ g,\ \rho GNN.

. \beta \lVert \beta \rVert_2 \le B_{\beta}.





f\left(G\right) - GNN y \in \{0, 1\}, p\left( f \left( G \right ), y \right ) = y \left( 2f \left( G \right ) - 1 \right ) + \left( 1 - y \right ) \left( 1 - 2 f \left( G \right ) \right ) - , p\left( f \left( G \right ), y \right ) < 0 .





, a = -p\left( f \left( G \right ), y \right ), \mathbb{I}\left[\right] - :





loss_{\gamma}\left( a \right ) = \mathbb{I}\left[ a > 0\right ] + (1 + \frac{a}{\gamma})\mathbb{I}\left[ a \in \left[ \gamma, 0 \right ] \right].

GNN f \{G_j, y_j\}_{j=1}^m:





\hat{\mathcal{R}}\left( f \right) = \frac{1}{m} \sum_{j = 1}^m loss_{\gamma} \left( -p\left( f \left(G_j\right), y_j \right) \right)

, , , , GNN . , (GNN, ), , , .





, :





  • ,





  • ( )





RNN





C  RNN,
C RNN,

\mathcal{C}- “ ”: \mathcal{C} = C_{\phi}C_gC_{\rho}B_2, r - , d - , m - , L - , \gamma- ,





, , \ tilde {\ mathcal {O}} \ left (r ^ 3N / \ sqrt {m} \ right)





GNN, ( readout-), . , , .





( ), . , , , , , , .





Provas e informações mais detalhadas podem ser encontradas lendo o artigo original ou assistindo a um relatório de um dos autores.








All Articles