Metric Based Few-Shot Graph Classification

Abstract

Many modern deep-learning techniques do not work without enormous datasets. At the same time, several fields demand methods that work under data scarcity. This problem is even more complex when the samples have varying structures, as in the case of graphs. Graph representation learning techniques have recently proven successful in a variety of domains. Nevertheless, the employed architectures perform poorly when faced with data scarcity. Few-shot learning, on the other hand, allows employing modern deep learning models in scarce-data regimes without waiving their effectiveness. In this work, we tackle the problem of few-shot graph classification, showing that equipping a simple distance metric learning baseline with a state-of-the-art graph embedder yields competitive results on the task. While the simplicity of the architecture is enough to outperform more complex ones, it also allows straightforward additions. To this end, we show that additional improvements may be obtained by encouraging a task-conditioned embedding space. Finally, we propose a MixUp-based online data augmentation technique acting in the latent space and show its effectiveness on the task.

Introduction

Graphs have ruled digital representations since the dawn of Computer Science. Their structure is simple and general, and their structural properties are well studied. Given the success of deep learning in different domains that enjoy a regular structure, such as those found in computer vision (Baumgartner et al. 2017; J. Zhang et al. 2019; Smirnov and Solomon 2021) and natural language processing (Brown et al. 2020; Mikolov et al. 2013; Vaswani et al. 2017; Devlin et al. 2019), a recent line of research has sought to extend it to manifolds and graph-structured data (Bronstein et al. 2017; Hamilton, Ying, and Leskovec 2017; Battaglia et al. 2018). Nevertheless, deep learning expressivity comes at a cost: deep models require vast amounts of data to search the complex hypothesis spaces they define. When data is scarce, these models end up overfitting the training set, hindering their generalization capability on unseen samples. While annotations are usually abundant in computer vision and natural language processing, they are harder to obtain for graph-structured data due to the impossibility or expensiveness of the annotation process (Hu et al. 2020; Sun et al. 2020). This is particularly true when the samples come from specialized domains such as biology, chemistry and medicine (Hassani 2022), where graph-structured data are ubiquitous. For instance, drug testing requires expensive in-vivo and laborious wet experiments to label drugs and protein graphs (Ma et al. 2020).

To address this problem, the field of Few-Shot (FS) learning (Fei-Fei, Fergus, and Perona 2006; Fink 2004) aims at designing models which can effectively operate in scarce data scenarios. While there has been some interest in FS node classification and link prediction, graph-level classification has mostly been overlooked in the FS setting. A general, high-level strategy for learning with little data consists of extracting knowledge from a related domain where the available data is enough to learn a reliable hypothesis, and then adapting it to the task of interest. The latter is usually called the target task, while the data-abundant one is denoted as the source. When the task of interest is classification, categories from the source task are said to be base classes, while those of the target are called novel classes. Distance metric learning techniques have proven to be particularly effective for the problem, as samples embedded in a lower-dimensional space need fewer data to be discriminated. Among these methods, Prototypical Networks (Snell, Swersky, and Zemel 2017) have been one of the most successful, obtaining remarkable results in various tasks without waiving their simplicity. Surprisingly, versions of this approach have been relegated to mere baselines in existing FS graph classification works, without receiving much attention. In this work, we confute this misconception, showing that we can obtain competitive results on the task by equipping Prototypical Networks with a state-of-the-art graph embedder.

As typical in FS learning, we frame tasks as episodes, where an episode is defined by a set of classes and several supervised samples (supports) for each of them (Vinyals et al. 2016a). An example of such an episode is illustrated in the figure below. While a standard Prototypical Network would embed the samples in the same way independently of the episode, we take inspiration from (Oreshkin, López, and Lacoste 2018) and empower the graph embeddings by conditioning them on the particular set of classes seen in the episode. This way, the intermediate features and the final embeddings may be modulated according to what is best for the current episode.

Given the insufficiency of the data to allow learning meaningful features for the target task from scratch, a learning model will often represent samples from the novel classes as a rearrangement of features from the data-abundant base classes. To this end, we propose a novel online data augmentation technique that creates artificial samples from two existing ones as a mixup of their latent representations. Training on these mixups of samples from the base classes forces the model to expect combinations of the base samples already at training time.

In summary, our contribution is three-fold:

  1. Showing that a distance metric learning approach, while simple, works better than more complex meta-learning approaches for the task of FS graph classification;

  2. Equipping the architecture with an episode-adaptive module, enhancing its expressivity to obtain more dynamic representations;

  3. Suggesting a novel online data augmentation technique that creates new artificial samples in the latent space, showing a significant performance gain.

All our results are general and apply equally to all the graph families considered. All our code is available for research purposes.[^1]

An N-way K-shot episode. In this example, there are N=3 classes. Each class has K=4 supports, yielding a support set of size N*K=12. The class information provided by the supports is exploited to classify the queries. We test the classification accuracy on all N classes. In the figure there are Q=2 queries for each class, thus the query set has size N*Q=6.

Related work

Graph Representation Learning

Given graphs' expressive power, graph-based modeling is ubiquitous in a variety of domains, ranging from social networks (Fan et al. 2022; Monti et al. 2019) to chemistry (Duvenaud et al. 2015; Fout et al. 2017; Stokes et al. 2020). In the graph representation learning context, the focus is on finding the best possible representation for these data structures: this usually involves learning a mapping between the discrete graph space and a continuous latent space in \(\mathbb{R}^d\) that should encode both topological features and any signal defined over nodes or edges. These features are usually obtained by graph convolutions (Kipf and Welling 2017; Velickovic et al. 2018). Due to its theoretical expressive power, we employ a Graph Isomorphism Network (Xu et al. 2019) as graph embedder. While the essence of supervised DL approaches requires large volumes of annotated data to learn from (Lake et al. 2017), the intrinsic structural complexity of graphs and their application in very specialized fields often give rise to annotation problems. When this happens, the few available supervised samples are not enough to fit deep models, consequently hindering their performance.

Few-Shot Learning

FS learning has gained attention as a means to tighten the gap between machine learning models and human-like learning capability. In fact, the capacity to generalize from few examples is a hallmark of human intelligence already from infancy, as children can, for example, learn words they have heard only once (Carey and Bartlett 1978). In practice, several high-level paradigms are used in the literature to solve tasks in a FS scenario: i) transfer learning techniques (Liu et al. 2018; Luo et al. 2017) aim at transferring the knowledge gained from a data-abundant task to a task with scarce data; ii) meta-learning techniques (Finn, Abbeel, and Levine 2017; Yoon et al. 2018; Ravi and Larochelle 2017) more generally introduce a meta-learning procedure to gradually learn meta-knowledge that generalizes across several tasks; iii) data augmentation works (Wu et al. 2018; Gao et al. 2018; Tsai and Salakhutdinov 2017) seek to augment the data by applying transformations to the available samples to generate new ones while preserving specific properties. We refer the reader to (Yaqing Wang et al. 2020) for an extensive treatment of the matter. Particularly relevant to our work are distance metric learning approaches: in this direction, (Vinyals et al. 2016a) suggest embedding both supports and queries and then labeling the query with the label of its nearest neighbor in the embedding space. By obtaining a class distribution for the query using a softmax over the distances from the supports, they then learn the embedding space by minimizing the negative log-likelihood. (Snell, Swersky, and Zemel 2017) generalize this intuition by allowing \(k\) supports per class to be aggregated to form prototypes. Given its effectiveness and simplicity, we chose this approach as the starting point for our architecture.

Graph Data Augmentation

Data augmentation follows the idea that in the working domain, there exist transformations that can be applied to samples to generate new ones in a controlled way (e.g., preserving the sample class in a classification setting while changing its content). Synthetic samples can thus meet the needs of large neural networks that require training with high volumes of data (Yaqing Wang et al. 2020). In Euclidean domains (e.g., images), this can often be achieved by simple rotations and translations (Benaim and Wolf 2018; Santoro et al. 2016). Unfortunately, in the graph domain, it is challenging to define such transformations on a given graph sample while keeping control of its properties. To this end, (Park, Shim, and Yang 2022; Yiwei Wang et al. 2020; H. Guo and Mao 2021; Han et al. 2022) propose to augment graph data directly in the data space, while (Yiwei Wang et al. 2021) interpolates latent representations to create novel ones. We also operate in the latent space but, differently from (Yiwei Wang et al. 2021), we create a new sample by selecting certain features from one representation and the remaining ones from the other through a random gating vector. This yields artificial samples that are random compositions of the features of existing samples, rather than linear interpolations of their features.

Few-Shot Graph Representation Learning

FS graph representation learning is concerned with applying graph representation learning techniques in scarce data scenarios. Similarly to standard graph representation learning, it tackles tasks at different levels of granularity: node-level (Zhou et al. 2019; S. Zhang et al. 2019; Yao et al. 2019; N. Wang et al. 2020; Ding et al. 2020), edge-level (Baek, Lee, and Hwang 2020; Sheng et al. 2020; S. Wang et al. 2021; Lv et al. 2019) and graph-level (Z. Guo et al. 2021; Yaqing Wang et al. 2021; Chauhan, Nathani, and Kaul 2020; Ma et al. 2020; Jiang et al. 2021). In this work, we tackle the problem of FS graph classification. The first work to study graph classification in a FS setting can be traced to (Chauhan, Nathani, and Kaul 2020) (GSM). Its contributions resulted in a helpful baseline method and datasets for subsequent works in the literature. Later works adapted MAML (Finn, Abbeel, and Levine 2017) to the graph setting (Ma et al. 2020) and proposed a modified version of Prototypical Networks leveraging domain-specific priors (Jiang et al. 2021). Differently from the latter, we do not make any such assumption and do not change the original loss function, obtaining superior performance.

Approach

Setting and Notation

In FS graph classification, each sample is a tuple \((\mathcal{G}=(\mathcal{V},\mathcal{E}), y)\), where \(\mathcal{G}=(\mathcal{V},\mathcal{E})\) is a graph with node set \(\mathcal{V}\) and edge set \(\mathcal{E}\), while \(y\) is a graph-level class. Given a set of data-abundant base classes \(C_{\text{b}}\), we aim to classify a set of data-scarce novel classes \(C_{\text{n}}\). We cast this problem through an episodic framework (Vinyals et al. 2016b): during training, we mimic the FS setting by dividing the base training data into episodes. Each episode \(e\) is an \(N\)-way \(K\)-shot classification task, with its own train (\(D_{\text{train}}\)) and test (\(D_{\text{test}}\)) data. For each of the \(N\) classes, \(D_{\text{train}}\) contains \(K\) corresponding support graphs, while \(D_{\text{test}}\) contains \(Q\) query graphs. A schematic visualization of an episode is given in the figure in the introduction; a minimal code sketch of the sampling procedure follows.
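To make the episodic setup concrete, the sketch below shows how an \(N\)-way \(K\)-shot episode with \(Q\) queries per class could be sampled from the base classes. The `sample_episode` helper and the (graph, label) dataset format are illustrative assumptions, not our actual data loader.

```python
import random
from collections import defaultdict

def sample_episode(dataset, n_way=3, k_shot=5, q_queries=2):
    """Draw an N-way K-shot episode from an iterable of (graph, label) pairs.

    Hypothetical helper: names and the dataset format are illustrative.
    """
    by_class = defaultdict(list)
    for graph, label in dataset:
        by_class[label].append(graph)

    # Keep only base classes with enough samples for supports and queries.
    eligible = [c for c, gs in by_class.items() if len(gs) >= k_shot + q_queries]
    episode_classes = random.sample(eligible, n_way)

    support, query = [], []  # D_train and D_test of the episode
    for episode_label, cls in enumerate(episode_classes):
        graphs = random.sample(by_class[cls], k_shot + q_queries)
        support += [(g, episode_label) for g in graphs[:k_shot]]
        query += [(g, episode_label) for g in graphs[k_shot:]]
    return support, query
```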

Prototypical Networks architecture. A graph encoder embeds the support graphs; the embeddings that belong to the same class are averaged to obtain the class prototype p. To classify a query graph q, it is embedded in the same space as the supports. The distances in the latent space between the query and the prototypes determine the similarities and thus the probability distribution of the query over the different classes, computed by a softmax over the negative distances as detailed below.

Prototypical Network (PN) Architecture

We build our network upon the simple-yet-effective idea of Prototypical Networks (Snell, Swersky, and Zemel 2017), originally proposed for FS image classification. We employ a state-of-the-art Graph Neural Network (GNN) as node embedder, composed of a set of layers of GIN convolutions (Xu et al. 2019), each equipped with an MLP regularized with GraphNorm (Cai et al. 2021). In practice, each sample is first passed through a set of convolutions, obtaining a hidden representation \(\mathbf{h}^{(l)}\) for each layer. According to (Xu et al. 2019), the latter is obtained by updating the hidden representation of each node \(v\) at each layer as \[\begin{aligned} \label{GIN-agg} \mathbf{h}_v^{(l)} = {\rm MLP}^{(l)} \left( \left( 1 + \epsilon^{(l)} \right) \cdot \mathbf{h}_v^{(l-1)} + \sum\nolimits_{u \in \mathcal{N}(v)} \mathbf{h}_u^{(l-1)}\right),\end{aligned}\] where \(\epsilon^{(l)}\) is a learnable parameter. Following (Xu et al. 2018), the final \(d\)-dimensional node embedding \(\mathbf{h}_v \in \mathbb{R}^{d}\) is then given by the concatenation of the outputs of all the layers \[\mathbf{h}_v = \text{CONCAT}\left( \left\{ \mathbf{h}_v^{(l)} \right\}_{l=1}^{L} \right)\] where \(L\) is the total number of layers in the network. The graph-level embedding is then obtained by employing a global pooling function, such as the mean or the sum. While the sum is a more expressive pooling function for GNNs (Xu et al. 2019), we observed the mean to behave better for the task in most considered datasets; it is therefore the aggregation function of choice when not specified differently. The \(K\) embedded supports \(\mathbf{s}_1^{(n)}, \dots, \mathbf{s}_K^{(n)}\) for each class \(n\) are then aggregated to form the class prototypes \(\mathbf{p}^{(n)}\), \[\mathbf{p}^{(n)} = \frac{1}{K} \sum_{k=1}^{K} \mathbf{s}_k^{(n)}.\] In the same way, the \(Q\) query graphs for each class \(n\) are embedded to obtain \(\mathbf{q}_1^{(n)}, \dots, \mathbf{q}_Q^{(n)}\). To compare each query graph embedding \(\mathbf{q}\) with the class prototypes \(\mathbf{p}_1, \dots, \mathbf{p}_N\), we use an \(\mathcal{L}_2\)-metric scaled by a learnable temperature factor \(\alpha\), as suggested in (Oreshkin, López, and Lacoste 2018). We refer to this metric as \(d_{\alpha}\). The class probability distribution \(\boldsymbol{\rho}\) for the query is finally computed by taking the softmax over the negative distances \[\boldsymbol{\rho}_n = \frac{\exp \left( -d_{\alpha}(\mathbf{q}, \mathbf{p}_n) \right)}{\sum_{n'=1}^{N} \exp(-d_{\alpha}(\mathbf{q}, \mathbf{p}_{n'}))}. \label{eq:proto-class-distr}\] The model is then trained end-to-end by minimizing via SGD the negative log-probability \(\mathcal{L}(\phi) = -\log \boldsymbol{\rho}_n\) of the true class \(n\). We will refer to this approach without additions as PN in the experiments.
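The following is a minimal sketch of this forward pass, assuming PyTorch Geometric; GraphNorm, batching details, and the layer sizes are omitted or illustrative, so this is not our exact implementation.

```python
import torch
from torch import nn
from torch_geometric.nn import GINConv, global_mean_pool

class GINEmbedder(nn.Module):
    """GIN convolutions whose per-layer outputs are concatenated and mean-pooled
    into a graph-level embedding (GraphNorm omitted for brevity)."""
    def __init__(self, in_dim, hidden_dim=32, num_layers=2):
        super().__init__()
        self.convs = nn.ModuleList()
        for l in range(num_layers):
            mlp = nn.Sequential(
                nn.Linear(in_dim if l == 0 else hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
            )
            self.convs.append(GINConv(mlp, train_eps=True))  # learnable epsilon

    def forward(self, x, edge_index, batch):
        layer_outputs = []
        for conv in self.convs:
            x = conv(x, edge_index)
            layer_outputs.append(x)
        h = torch.cat(layer_outputs, dim=-1)   # concatenation of all layer outputs
        return global_mean_pool(h, batch)      # graph-level embedding

def prototypical_logits(support_emb, support_y, query_emb, n_way, alpha):
    """Prototypes as per-class support means; scaled squared L2 distances as logits,
    so that a softmax over them yields the class distribution rho."""
    prototypes = torch.stack(
        [support_emb[support_y == c].mean(dim=0) for c in range(n_way)]
    )
    sq_dists = torch.cdist(query_emb, prototypes) ** 2
    return -alpha * sq_dists

# Training: cross-entropy of the softmax over the scaled negative distances, e.g.
# loss = torch.nn.functional.cross_entropy(
#     prototypical_logits(s_emb, s_y, q_emb, n_way, alpha), q_y)
```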

Task-Adaptive Embedding (TAE)

Until now, our module computes the embeddings regardless of the specific composition of the episode. Our intuition is that the context in which a graph appears should influence its representation. In practice, inspired by (Oreshkin, López, and Lacoste 2018), we condition the embeddings on the particular task (episode) for which they are computed. Such influence will be expressed by a translation \(\boldsymbol{\beta}\) and a scaling \(\boldsymbol{\gamma}\).

First of all, given an episode \(e\), we compute an episode representation \(\mathbf{p}_{\mathbf{e}}\) as the mean of the prototypes \(\mathbf{p}_n\) for the classes \(n = 1, \dots, N\) in the episode. We consider \(\mathbf{p}_{\mathbf{e}}\) as a prototype for the episode and a proxy for the task. Then, we feed it to a Task Embedding Network (TEN), composed of two distinct residual MLPs. These output a shift vector \(\boldsymbol{\beta}_\ell\) and a scale vector \(\boldsymbol{\gamma}_\ell\), respectively, for each layer of the graph embedding module. At layer \(\ell\), the output \(\mathbf{h}_{\ell}\) is then conditioned on the episode by transforming it as \[\mathbf{h}_{\ell}' = \boldsymbol{\gamma}_\ell \odot \mathbf{h}_{\ell}+\boldsymbol{\beta}_\ell.\] As in (Oreshkin, López, and Lacoste 2018), at each layer \(\boldsymbol{\gamma}\) and \(\boldsymbol{\beta}\) are multiplied by two \(L_{2}\)-penalized scalars \(\gamma_{0}\) and \(\beta_{0}\), so as to promote significant conditioning only when useful. Wrapping up, defining \(g_{\Theta}\) and \(h_{\Phi}\) to be the predictors for the shift and scale vectors respectively, the actual conditioning vectors applied to the hidden representation are \(\boldsymbol{\beta}= \beta_0 g_{\Theta}(\mathbf{p}_{\mathbf{e}})\) and \(\boldsymbol{\gamma}= \gamma_0 h_{\Phi}(\mathbf{p}_{\mathbf{e}}) + \mathbf{1}\). When we use this improvement in our experiments, we add the label TAE to the method name.
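A minimal sketch of the TEN for one layer follows, assuming simple two-layer MLPs for \(g_{\Theta}\) and \(h_{\Phi}\); the exact structure of the residual MLPs is not spelled out above, so the sizes and layout here are illustrative.

```python
import torch
from torch import nn

class TaskEmbeddingNetwork(nn.Module):
    """Maps the episode prototype p_e to a scale (gamma) and shift (beta) vector for
    one embedder layer: gamma = gamma_0 * h_Phi(p_e) + 1, beta = beta_0 * g_Theta(p_e).
    gamma_0 and beta_0 are the L2-penalized scalars discussed in the text."""
    def __init__(self, dim):
        super().__init__()
        self.scale_mlp = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim))
        self.shift_mlp = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim))
        self.gamma0 = nn.Parameter(torch.zeros(1))  # penalized with an L2 term in the loss
        self.beta0 = nn.Parameter(torch.zeros(1))

    def forward(self, p_e):
        gamma = self.gamma0 * self.scale_mlp(p_e) + 1.0
        beta = self.beta0 * self.shift_mlp(p_e)
        return gamma, beta

# Conditioning a layer's hidden representation on the episode:
# gamma, beta = ten(p_e)
# h_l = gamma * h_l + beta
```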

MixUp (MU) Embedding Augmentation

Typical learning pipelines rely on data augmentation to overcome limited variability in the dataset. While this is mainly performed to obtain invariance to specific transformations, we use it to improve our embedding representation, promoting generalization on unseen feature combinations. In practice, given an episode \(e\), we randomly sample, for each pair of classes \(n_1, n_2\), two graphs \(\mathcal{G}^{(1)}\) and \(\mathcal{G}^{(2)}\) from the corresponding support sets. Then, we compute their embeddings \(\mathbf{s}^{(1)}\) and \(\mathbf{s}^{(2)}\), as well as their class probability distributions \(\boldsymbol{\rho}^{(1)}\) and \(\boldsymbol{\rho}^{(2)}\) according to the class-distribution equation above. Next, we randomly draw a boolean mask \(\boldsymbol{\sigma} \in \{0,1\}^d\). We can then obtain a novel synthetic example by mixing the features of the two graphs in the latent space \[\tilde{\mathbf{s}} = \boldsymbol{\sigma} \odot \mathbf{s}^{(1)} + (\mathbf{1} - \boldsymbol{\sigma}) \odot \mathbf{s}^{(2)},\] where \(\mathbf{1}\) is a \(d\)-dimensional vector of ones. Finally, we craft a synthetic class probability distribution \(\tilde{\boldsymbol{\rho}}\) for this example by linear interpolation \[\tilde{\boldsymbol{\rho}} =\lambda \boldsymbol{\rho}^{(1)} + (1 - \lambda) \boldsymbol{\rho}^{(2)} \label{eq:linear-interpolation-distr}\] where \(\lambda = \frac{1}{d} \sum_{i=1}^{d} \boldsymbol{\sigma}_i\) represents the percentage of features sampled from the first sample. If we then compute the class distribution \(\boldsymbol{\rho}\) for \(\tilde{\mathbf{s}}\) according to the same equation, we can ask it to be similar to \(\tilde{\boldsymbol{\rho}}\) by adding the following regularizing term to the training loss \[\mathcal{L}_{\text{MU}} = \| \boldsymbol{\rho} - \tilde{\boldsymbol{\rho}}\|_2^2.\] Intuitively, by adopting this online data augmentation procedure, the network is faced with new feature combinations during training, helping it explore unseen regions of the embedding space. The overall procedure is summarized in the figure below, followed by a code sketch.

Mixup procedure. Each graph is embedded into a latent representation. We generate a random boolean mask \(\boldsymbol{\sigma}\) and its complement \(\mathbf{1} - \boldsymbol{\sigma}\), which describe the features to select from \(\mathbf{s}^{(1)}\) and \(\mathbf{s}^{(2)}\). The selected features are then recomposed to generate the novel latent vector \(\tilde{\mathbf{s}}\).
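The sketch below shows the latent MixUp for a single pair of already-computed embeddings and class distributions; the function name and tensor shapes are illustrative.

```python
import torch

def latent_mixup(s1, s2, rho1, rho2):
    """Gating-based latent MixUp: each embedding dimension is taken from s1 or s2
    according to a random boolean mask, and the target class distribution is the
    interpolation weighted by the fraction of dimensions taken from s1.
    Shapes (illustrative): s1, s2 are (d,); rho1, rho2 are (N,)."""
    sigma = torch.randint(0, 2, s1.shape, device=s1.device).to(s1.dtype)
    s_mix = sigma * s1 + (1.0 - sigma) * s2
    lam = sigma.mean()                        # percentage of features from the first sample
    rho_mix = lam * rho1 + (1.0 - lam) * rho2
    return s_mix, rho_mix

# The MU regularizer then asks the class distribution predicted for s_mix
# (softmax over scaled distances to the prototypes) to match rho_mix:
# loss_mu = ((predicted_rho(s_mix) - rho_mix) ** 2).sum()
```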

Experiments

Datasets

We benchmark our approach over two groups of datasets. The first group was introduced in (Chauhan, Nathani, and Kaul 2020) and consists of:

  1. TRIANGLES, a collection of graphs labeled \(i=1, \dots, 10\), where \(i\) is the number of triangles in the graph.

  2. ENZYMES, a dataset of protein tertiary structures from the BRENDA database (Chang et al. 2020); each label corresponds to a different top-level enzyme.

  3. Letter-High, a collection of graph-represented letter drawings from the English alphabet; each drawing is labeled with the corresponding letter.

  4. Reddit-12K, a social network dataset where graphs represent threads, with edges connecting interacting users. The label of a thread is given by the corresponding discussion forum.

We will refer to this set of datasets as \(\mathcal{D}_\text{A}\). The second set of datasets was introduced in (Ma et al. 2020), and consists of:

  1. Graph-R52, a textual dataset in which each graph represents a different text, with words being connected by an edge if they appear together in a sliding window.

  2. COIL-DEL, a collection of graph-represented images obtained through corner detection and Delaunay triangulation.

We will refer to this set of datasets as \(\mathcal{D}_\text{B}\). The overall dataset statistics are reported in .

It is important to note that only the datasets in \(\mathcal{D}_\text{B}\) have enough classes to permit a disjoint set of classes for validation. In contrast, for the first four datasets, existing works use a disjoint subset of the training samples as a validation set. We argue that this setting is critically unfit for FS learning, as the validation set does not constitute a good proxy for the actual testing environment, since its classes are not novel. Moreover, the lack of a reliable validation set prevents the usage of early stopping, as there is no way to decide on a good stopping criterion for samples from unseen classes. We nevertheless report results in this evaluation setting for the sake of comparison.

Baselines

We extensively compare our results to the other existing works in the current literature, i.e., the pioneering work GSM (Chauhan, Nathani, and Kaul 2020), the current state-of-the-art method AS-MAML (Ma et al. 2020), and the most recent SMF-GIN (Jiang et al. 2021). As in (Ma et al. 2020), we directly report the results published by the respective authors in their works.

GSM

As a first step, the authors compute graph prototypes and then cluster them based on their spectral properties to create super-classes. These are then used to generate a super-graph which is employed to separate the novel graphs. As previously stated, the class prototypes are computed from the spectral properties of the graphs rather than the average of their embeddings, as usually done in prototypical networks.

AS-MAML

This solution follows the meta-learning framework, and it is based on the popular Model-Agnostic Meta-Learning (MAML) architecture (Finn, Abbeel, and Levine 2017). The authors moreover propose a reinforcement learning-based adaptive step controller to improve the robustness of the graph meta-learner.

SMF-GIN

A GNN is employed to encode both global (attention over different GNN layer encodings) and local (attention over different substructure encodings) properties. We point out that the authors include a ProtoNet-based baseline. However, their implementation does not accurately follow the original one and, differently from us, leverages domain-specific prior knowledge.

Experimental details

Our graph embedder is composed of two GIN layers followed by a mean pooling layer, and the dimension of the resulting embeddings is set to \(64\). Furthermore, both the latent MixUp regularizer and the \(L_2\) regularizer of the task-adaptive embedding are weighted at \(0.1\). The framework is trained with a batch size of \(32\) using the Adam optimizer with a learning rate of \(0.0001\). We implement our framework with PyTorch Lightning (Falcon et al. 2019) using PyTorch Geometric (Fey and Lenssen 2019), and Weights & Biases (Biewald 2020) to log the experiment results. The specific configurations of all our approaches are reported with the released code.
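For reference, the stated hyperparameters can be gathered into a configuration sketch like the one below; the dictionary keys and the loss composition around them are illustrative rather than our exact code.

```python
import torch

# Settings stated in the text; the wiring around them (model, loss terms) is illustrative.
config = {
    "gin_layers": 2,
    "embedding_dim": 64,
    "pooling": "mean",
    "batch_size": 32,
    "learning_rate": 1e-4,
    "mixup_weight": 0.1,   # weight of the latent MixUp regularizer
    "tae_l2_weight": 0.1,  # L2 penalty on the TAE scalars gamma_0 and beta_0
}

# optimizer = torch.optim.Adam(model.parameters(), lr=config["learning_rate"])
# loss = ce_loss + config["mixup_weight"] * mu_loss + config["tae_l2_weight"] * tae_l2
```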

Results

Benchmark \(\mathcal{D}_\text{A}\)

As can be seen in the table below, the proposed approach compares favorably on the large majority of the considered datasets. We obtain an overall margin of \(+7.24\%\) and \(+4.9\%\) accuracy compared to GSM (Chauhan, Nathani, and Kaul 2020) for \(k=5, 10\) respectively, while the improvements are \(+0.96\%\) and \(+0.86\%\) with respect to AS-MAML (Ma et al. 2020). We note that the latter margins are only computed over TRIANGLES and Letter-High, as the authors chose not to evaluate on the two remaining datasets in the \(\mathcal{D}_\text{A}\) benchmark. However, we again stress the partial inadequacy of these datasets as a realistic evaluation tool due to the lack of a disjoint set of classes for the validation set.

Macro accuracy scores over different k-shot settings and architectures. They are partitioned into baselines (upper section) and our full architecture (lower section). The best scores are in bold.
| Method | TRIANGLES 5-shot | TRIANGLES 10-shot | Letter-High 5-shot | Letter-High 10-shot | ENZYMES 5-shot | ENZYMES 10-shot | Reddit 5-shot | Reddit 10-shot |
|---|---|---|---|---|---|---|---|---|
| GSM | 71.40 | 75.60 | 69.91 | 73.28 | 55.42 | \(\mathbf{60.64}\) | 41.59 | 45.67 |
| SMF-GIN | 79.85 | - | - | - | - | - | - | - |
| AS-MAML | 86.47 | 87.26 | 76.29 | 77.87 | - | - | - | - |
| PN+TAE+MU | \(\mathbf{87.47}\) | \(\mathbf{87.59}\) | \(\mathbf{77.21}\) | \(\mathbf{79.26}\) | \(\mathbf{56.86}\) | 59.39 | \(\mathbf{45.75}\) | \(\mathbf{48.55}\) |

Benchmark \(\mathcal{D}_\text{B}\)

Consistently, our approach yields significant improvements over both COIL-DEL and Graph-R52. In fact, the table below shows that the model obtains an average margin of \(+4.37\%\) and \(+4.53\%\) over AS-MAML (Ma et al. 2020) for \(k=5, 10\) respectively. Interestingly, a vanilla Prototypical Network (PN) architecture with the proposed graph embedder is already sufficient to obtain state-of-the-art results. The subsequent rows then show the efficacy of the proposed improvements. Task-adaptive embedding (TAE) provides the most significant gain, yielding an average increment of \(+2.82\%\) and \(+2.42\%\) for the \(5\)-shot and \(10\)-shot cases, respectively. The proposed online data augmentation technique (MU) then provides an additional boost, especially on COIL-DEL, where its addition yields a \(+0.65\%\) and \(+1.72\%\) improvement in accuracy for \(k=5, 10\). We notice that on this benchmark we obtain a large margin of improvement w.r.t. the state of the art. These are also the datasets with the more grounded evaluation setting, given the availability of a proper validation set.

Macro accuracy scores over different k-shot settings and architectures. They are partitioned into baselines (upper section) and our architectures (lower section). AS-MAML is the only baseline available on these two datasets. The best scores are in bold.
| Method | Graph-R52 5-shot | Graph-R52 10-shot | COIL-DEL 5-shot | COIL-DEL 10-shot | Mean 5-shot | Mean 10-shot |
|---|---|---|---|---|---|---|
| AS-MAML | 75.33 | 78.33 | 81.55 | 84.75 | 78.44 | 81.54 |
| PN | 73.11 | 78.04 | 85.57 | 87.28 | 79.34 | 82.66 |
| PN+TAE | 77.91 | 81.32 | 86.41 | 88.85 | 82.16 | 85.08 |
| PN+TAE+MU | \(\mathbf{77.92}\) | \(\mathbf{81.57}\) | \(\mathbf{87.70}\) | \(\mathbf{90.57}\) | \(\mathbf{82.81}\) | \(\mathbf{86.07}\) |

Qualitative analysis

The latent space learned by the graph embedder is the core element of our approach, since it determines the prototypes and the subsequent sample classification. To provide better insight into the peculiarities of our method, the figure below depicts a t-SNE representation of the learned embeddings for novel classes. Each row represents a different episode, while the columns show the embeddings obtained with our approach and its further refinements. We also highlight the queries (crosses), the supports (circles) and the prototypes (stars). As can be seen, our approach separates samples belonging to novel classes into clearly defined clusters. Already in the PN model, some classes naturally cluster in different regions of the embedding space. The TAE regularization improves the class separation without significantly changing the arrangement of the clusters in the space. Our insight is that the context may let the network reorganize the already-seen space without moving far from the representation it has already obtained. Finally, MU allows better use of previously unexplored regions, as expected from this kind of data augmentation. We show that our feature recombination helps the network better generalize and anticipate the arrival of novel classes.

Visualization of the latent spaces of novel episodes from the COIL-DEL dataset, through t-SNE dimensionality reduction. Each row is a different episode; the colors represent novel classes, the crosses are the queries, the circles are the supports and the stars are the prototypes. The left column is produced with the base model PN, the middle one with the PN+TAE model, and the right one with the full model PN+TAE+MU. This comparison shows that the TAE and MU regularizations improve the class separation in the latent space; in particular, our novel data augmentation MU is essential to obtain accurate latent clusters.

Conclusions

Limitations.

Employing a graph neural network embedder, the proposed approach may inherit known issues such as the presence of information bottlenecks (Topping et al. 2021) and over-smoothing. These may be aggravated by the additional aggregation required to compute the prototypes, as the readout function used to obtain a graph-level representation is already an aggregation of the node embeddings. Also, the nearest-neighbour association in the final embedding space assumes that it enjoys a Euclidean metric. While this is an excellent local approximation, we expect it may lead to imprecision. To overcome this, further improvements can be inspired by the Computer Vision community.

Future works.

In future work, we aim to enrich the latent space defined by the architecture, for instance by forcing the class prototypes in each episode to be sampled from a learnable distribution rather than directly computed as the mean of the supports. Moreover, it may be worth introducing an attention layer to let supports (or prototypes directly) affect each other explicitly, rather than implicitly as currently happens through the task embedding module. We also believe data augmentation is a crucial technique for the future of this task: the capacity to meaningfully inflate the small available datasets may result in a significant performance improvement. In this regard, we plan to extensively test the existing graph data augmentation techniques in the few-shot scenario and to build upon MixUp to exploit different mixing strategies, such as non-linear interpolation.

Conclusions.

In this paper, we tackle the problem of FS graph classification, an under-explored problem in the broader machine learning community. The problem is particularly relevant in specialized domains such as biology and chemistry, given the nature of the samples to annotate. We propose a distance metric learning approach to tackle the problem, as samples embedded in a lower-dimensional space need fewer data to be discriminated. In this regard, we suggest adapting Prototypical Networks to the graph modality, showing that its simplicity does not undermine its effectiveness in the task. We then suggest valuable additions to the architecture, adapting a task-adaptive embedding procedure and designing a novel online graph data augmentation technique. Lastly, we prove their benefits for the problem over several datasets. We hope this work encourages researchers and practitioners to reconsider the effectiveness of distance metric learning when dealing with graph-structured data. In fact, we believe metric learning to be incredibly fit for dealing with graphs, considering that the latent spaces encoded by graph neural networks are known to capture both topological features and node signals efficaciously.

References

Baek, Jinheon, Dong Bok Lee, and Sung Ju Hwang. 2020. “Learning to Extrapolate Knowledge: Transductive Few-Shot Out-of-Graph Link Prediction.” In Advances in Neural Information Processing Systems 33: Annual Conference on Neural Information Processing Systems 2020, NeurIPS 2020, December 6-12, 2020, Virtual.

Battaglia, Peter, Jessica Blake Chandler Hamrick, Victor Bapst, Alvaro Sanchez, Vinicius Zambaldi, Mateusz Malinowski, Andrea Tacchetti, et al. 2018. “Relational Inductive Biases, Deep Learning, and Graph Networks.” arXiv.

Baumgartner, Christian F, Lisa M Koch, Marc Pollefeys, and Ender Konukoglu. 2017. “An Exploration of 2d and 3d Deep Learning Techniques for Cardiac MR Image Segmentation.” In International Workshop on Statistical Atlases and Computational Models of the Heart. Springer.

Benaim, Sagie, and Lior Wolf. 2018. “One-Shot Unsupervised Cross Domain Translation.” In Advances in Neural Information Processing Systems 31: Annual Conference on Neural Information Processing Systems 2018, NeurIPS 2018, December 3-8, 2018, Montréal, Canada.

Biewald, Lukas. 2020. “Experiment Tracking with Weights and Biases.”

Bronstein, Michael M., Joan Bruna, Yann LeCun, Arthur Szlam, and Pierre Vandergheynst. 2017. “Geometric Deep Learning: Going Beyond Euclidean Data.” IEEE Signal Processing Magazine 34 (4). https://doi.org/10.1109/MSP.2017.2693418.

Brown, Tom B., Benjamin Mann, Nick Ryder, Melanie Subbiah, Jared Kaplan, Prafulla Dhariwal, Arvind Neelakantan, et al. 2020. “Language Models Are Few-Shot Learners.” In Advances in Neural Information Processing Systems 33: Annual Conference on Neural Information Processing Systems 2020, NeurIPS 2020, December 6-12, 2020, Virtual.

Cai, Tianle, Shengjie Luo, Keyulu Xu, Di He, Tie-Yan Liu, and Liwei Wang. 2021. “GraphNorm: A Principled Approach to Accelerating Graph Neural Network Training.” In 2021 International Conference on Machine Learning.

Carey, Susan, and E. Bartlett. 1978. “Acquiring a Single New Word.” Proceedings of the Stanford Child Language Conference 15.

Chang, Antje, Lisa Jeske, Sandra Ulbrich, Julia Hofmann, Julia Koblitz, Ida Schomburg, Meina Neumann-Schaal, Dieter Jahn, and Dietmar Schomburg. 2020. “BRENDA, the ELIXIR core data resource in 2021: new developments and updates.” Nucleic Acids Research 49 (D1). https://doi.org/10.1093/nar/gkaa1025.

Chauhan, Jatin, Deepak Nathani, and Manohar Kaul. 2020. “Few-Shot Learning on Graphs via Super-Classes Based on Graph Spectral Measures.” In 8th International Conference on Learning Representations, ICLR 2020, Addis Ababa, Ethiopia, April 26-30, 2020. OpenReview.net.

Devlin, Jacob, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. 2019. “BERT: Pre-Training of Deep Bidirectional Transformers for Language Understanding.” In Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long and Short Papers). Minneapolis, Minnesota: Association for Computational Linguistics. https://doi.org/10.18653/v1/N19-1423.

Ding, Kaize, Jianling Wang, Jundong Li, Kai Shu, Chenghao Liu, and Huan Liu. 2020. “Graph Prototypical Networks for Few-Shot Learning on Attributed Networks.” In CIKM ’20: The 29th ACM International Conference on Information and Knowledge Management, Virtual Event, Ireland, October 19-23, 2020. ACM. https://doi.org/10.1145/3340531.3411922.

Duvenaud, David, Dougal Maclaurin, Jorge Aguilera-Iparraguirre, Rafael Gómez-Bombarelli, Timothy Hirzel, Alán Aspuru-Guzik, and Ryan P. Adams. 2015. “Convolutional Networks on Graphs for Learning Molecular Fingerprints.” In Advances in Neural Information Processing Systems 28: Annual Conference on Neural Information Processing Systems 2015, December 7-12, 2015, Montreal, Quebec, Canada.

Falcon, William, et al. 2019. “PyTorch Lightning.” GitHub. https://github.com/PyTorchLightning/pytorch-lightning.

Fan, Wenqi, Yao Ma, Qing Li, Jianping Wang, Guoyong Cai, Jiliang Tang, and Dawei Yin. 2022. “A Graph Neural Network Framework for Social Recommendations.” IEEE Transactions on Knowledge and Data Engineering 34 (5).

Fei-Fei, Li, R. Fergus, and P. Perona. 2006. “One-Shot Learning of Object Categories.” IEEE Transactions on Pattern Analysis and Machine Intelligence 28 (4). https://doi.org/10.1109/TPAMI.2006.79.

Fey, Matthias, and Jan E. Lenssen. 2019. “Fast Graph Representation Learning with PyTorch Geometric.” In ICLR Workshop on Representation Learning on Graphs and Manifolds.

Fink, Michael. 2004. “Object Classification from a Single Example Utilizing Class Relevance Metrics.” In Advances in Neural Information Processing Systems 17 [Neural Information Processing Systems, NIPS 2004, December 13-18, 2004, Vancouver, British Columbia, Canada].

Finn, Chelsea, Pieter Abbeel, and Sergey Levine. 2017. “Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks.” In Proceedings of the 34th International Conference on Machine Learning, ICML 2017, Sydney, NSW, Australia, 6-11 August 2017. Vol. 70. Proceedings of Machine Learning Research. PMLR.

Fout, Alex, Jonathon Byrd, Basir Shariat, and Asa Ben-Hur. 2017. “Protein Interface Prediction Using Graph Convolutional Networks.” In Advances in Neural Information Processing Systems 30: Annual Conference on Neural Information Processing Systems 2017, December 4-9, 2017, Long Beach, CA, USA.

Gao, Hang, Zheng Shou, Alireza Zareian, Hanwang Zhang, and Shih-Fu Chang. 2018. “Low-Shot Learning via Covariance-Preserving Adversarial Augmentation Networks.” In Advances in Neural Information Processing Systems 31: Annual Conference on Neural Information Processing Systems 2018, NeurIPS 2018, December 3-8, 2018, Montréal, Canada.

Guo, Hongyu, and Yongyi Mao. 2021. “Intrusion-Free Graph Mixup.”

Guo, Zhichun, Chuxu Zhang, Wenhao Yu, John Herr, Olaf Wiest, Meng Jiang, and Nitesh V Chawla. 2021. “Few-Shot Graph Learning for Molecular Property Prediction.” arXiv Preprint arXiv:2102.07916.

Hamilton, William L., Rex Ying, and Jure Leskovec. 2017. “Representation Learning on Graphs: Methods and Applications.” ArXiv abs/1709.05584.

Han, Xiaotian, Zhimeng Jiang, Ninghao Liu, and Xia Hu. 2022. “G-Mixup: Graph Data Augmentation for Graph Classification.” https://arxiv.org/abs/2202.07179.

Hassani, Kaveh. 2022. “Cross-Domain Few-Shot Graph Classification.” https://arxiv.org/abs/2201.08265.

Hu, Weihua, Bowen Liu, Joseph Gomes, Marinka Zitnik, Percy Liang, Vijay S. Pande, and Jure Leskovec. 2020. “Strategies for Pre-Training Graph Neural Networks.” In 8th International Conference on Learning Representations, ICLR 2020, Addis Ababa, Ethiopia, April 26-30, 2020. OpenReview.net.

Jiang, Shunyu, Fuli Feng, Weijian Chen, Xiang Li, and Xiangnan He. 2021. “Structure-Enhanced Meta-Learning for Few-Shot Graph Classification.” AI Open 2.

Kipf, Thomas N., and Max Welling. 2017. “Semi-Supervised Classification with Graph Convolutional Networks.” In 5th International Conference on Learning Representations, ICLR 2017, Toulon, France, April 24-26, 2017, Conference Track Proceedings. OpenReview.net.

Lake, Brenden M., Tomer D. Ullman, Joshua B. Tenenbaum, and Samuel J. Gershman. 2017. “Building Machines That Learn and Think Like People.” Behavioral and Brain Sciences 40. https://doi.org/10.1017/S0140525X16001837.

Liu, Bo, Xudong Wang, Mandar Dixit, Roland Kwitt, and Nuno Vasconcelos. 2018. “Feature Space Transfer for Data Augmentation.” In 2018 IEEE Conference on Computer Vision and Pattern Recognition, CVPR 2018, Salt Lake City, UT, USA, June 18-22, 2018. IEEE Computer Society. https://doi.org/10.1109/CVPR.2018.00947.

Luo, Zelun, Yuliang Zou, Judy Hoffman, and Li Fei-Fei. 2017. “Label Efficient Learning of Transferable Representations Across Domains and Tasks.” In Proceedings of the 31st International Conference on Neural Information Processing Systems. NIPS’17. Red Hook, NY, USA: Curran Associates Inc.

Lv, Xin, Yuxian Gu, Xu Han, Lei Hou, Juanzi Li, and Zhiyuan Liu. 2019. “Adapting Meta Knowledge Graph Information for Multi-Hop Reasoning over Few-Shot Relations.” In Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP). Hong Kong, China: Association for Computational Linguistics. https://doi.org/10.18653/v1/D19-1334.

Ma, Ning, Jiajun Bu, Jieyu Yang, Zhen Zhang, Chengwei Yao, Zhi Yu, Sheng Zhou, and Xifeng Yan. 2020. “Adaptive-Step Graph Meta-Learner for Few-Shot Graph Classification.” In CIKM ’20: The 29th ACM International Conference on Information and Knowledge Management, Virtual Event, Ireland, October 19-23, 2020. ACM. https://doi.org/10.1145/3340531.3411951.

Mikolov, Tomás, Ilya Sutskever, Kai Chen, Gregory S. Corrado, and Jeffrey Dean. 2013. “Distributed Representations of Words and Phrases and Their Compositionality.” In Advances in Neural Information Processing Systems 26: 27th Annual Conference on Neural Information Processing Systems 2013. Proceedings of a Meeting Held December 5-8, 2013, Lake Tahoe, Nevada, United States.

Monti, Federico, Fabrizio Frasca, Davide Eynard, Damon Mannion, and Michael M Bronstein. 2019. “Fake News Detection on Social Media Using Geometric Deep Learning.” ICLR.

Oreshkin, Boris N., Pau Rodrı́guez López, and Alexandre Lacoste. 2018. “TADAM: Task Dependent Adaptive Metric for Improved Few-Shot Learning.” In Advances in Neural Information Processing Systems 31: Annual Conference on Neural Information Processing Systems 2018, NeurIPS 2018, December 3-8, 2018, Montréal, Canada.

Park, Joonhyung, Hajin Shim, and Eunho Yang. 2022. “Graph Transplant: Node Saliency-Guided Graph Mixup with Local Structure Preservation.” In Proceedings of the First MiniCon Conference.

Ravi, Sachin, and Hugo Larochelle. 2017. “Optimization as a Model for Few-Shot Learning.” In 5th International Conference on Learning Representations, ICLR 2017, Toulon, France, April 24-26, 2017, Conference Track Proceedings. OpenReview.net.

Santoro, Adam, Sergey Bartunov, Matthew Botvinick, Daan Wierstra, and Timothy P. Lillicrap. 2016. “Meta-Learning with Memory-Augmented Neural Networks.” In Proceedings of the 33nd International Conference on Machine Learning, ICML 2016, New York City, NY, USA, June 19-24, 2016. Vol. 48. JMLR Workshop and Conference Proceedings. JMLR.org.

Sheng, Jiawei, Shu Guo, Zhenyu Chen, Juwei Yue, Lihong Wang, Tingwen Liu, and Hongbo Xu. 2020. “Adaptive Attentional Network for Few-Shot Knowledge Graph Completion.” In Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing (EMNLP). Online: Association for Computational Linguistics. https://doi.org/10.18653/v1/2020.emnlp-main.131.

Smirnov, Dmitriy, and Justin Solomon. 2021. “HodgeNet: Learning Spectral Geometry on Triangle Meshes.” ACM Transactions on Graphics (TOG) 40 (4).

Snell, Jake, Kevin Swersky, and Richard S. Zemel. 2017. “Prototypical Networks for Few-Shot Learning.” In Advances in Neural Information Processing Systems 30: Annual Conference on Neural Information Processing Systems 2017, December 4-9, 2017, Long Beach, CA, USA.

Stokes, Jonathan M., Kevin Yang, Kyle Swanson, Wengong Jin, Andres Cubillos-Ruiz, Nina M. Donghia, Craig R. MacNair, et al. 2020. “A Deep Learning Approach to Antibiotic Discovery.” Cell 180 (4).

Sun, Fan-Yun, Jordan Hoffmann, Vikas Verma, and Jian Tang. 2020. “InfoGraph: Unsupervised and Semi-Supervised Graph-Level Representation Learning via Mutual Information Maximization.” In 8th International Conference on Learning Representations, ICLR 2020, Addis Ababa, Ethiopia, April 26-30, 2020. OpenReview.net.

Topping, Jake, Francesco Di Giovanni, Benjamin Paul Chamberlain, Xiaowen Dong, and Michael M. Bronstein. 2021. “Understanding over-Squashing and Bottlenecks on Graphs via Curvature.” arXiv.

Tsai, Yao-Hung Hubert, and Ruslan Salakhutdinov. 2017. “Improving One-Shot Learning Through Fusing Side Information.” CoRR abs/1710.08347. https://arxiv.org/abs/1710.08347.

Vaswani, Ashish, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, and Illia Polosukhin. 2017. “Attention Is All You Need.” In Advances in Neural Information Processing Systems 30: Annual Conference on Neural Information Processing Systems 2017, December 4-9, 2017, Long Beach, CA, USA.

Velickovic, Petar, Guillem Cucurull, Arantxa Casanova, Adriana Romero, Pietro Liò, and Yoshua Bengio. 2018. “Graph Attention Networks.” In 6th International Conference on Learning Representations, ICLR 2018, Vancouver, BC, Canada, April 30 - May 3, 2018, Conference Track Proceedings. OpenReview.net.

Vinyals, Oriol, Charles Blundell, Tim Lillicrap, Koray Kavukcuoglu, and Daan Wierstra. 2016a. “Matching Networks for One Shot Learning.” In Advances in Neural Information Processing Systems 29: Annual Conference on Neural Information Processing Systems 2016, December 5-10, 2016, Barcelona, Spain.

———. 2016b. “Matching Networks for One Shot Learning.” In Advances in Neural Information Processing Systems 29: Annual Conference on Neural Information Processing Systems 2016, December 5-10, 2016, Barcelona, Spain.

Wang, Ning, Minnan Luo, Kaize Ding, Lingling Zhang, Jundong Li, and Qinghua Zheng. 2020. “Graph Few-Shot Learning with Attribute Matching.” In CIKM ’20: The 29th ACM International Conference on Information and Knowledge Management, Virtual Event, Ireland, October 19-23, 2020. ACM. https://doi.org/10.1145/3340531.3411923.

Wang, Song, Xiao Huang, Chen Chen, Liang Wu, and Jundong Li. 2021. “REFORM: Error-Aware Few-Shot Knowledge Graph Completion.” In Proceedings of the 30th ACM International Conference on Information i& Knowledge Management. New York, NY, USA: Association for Computing Machinery.

Wang, Yaqing, Abulikemu Abuduweili, Quanming Yao, and Dejing Dou. 2021. “Property-Aware Relation Networks for Few-Shot Molecular Property Prediction.” In Advances in Neural Information Processing Systems.

Wang, Yaqing, Quanming Yao, James T. Kwok, and Lionel M. Ni. 2020. “Generalizing from a Few Examples: A Survey on Few-Shot Learning.” ACM Comput. Surv. 53 (3).

Wang, Yiwei, Wei Wang, Yuxuan Liang, Yujun Cai, and Bryan Hooi. 2020. “GraphCrop: Subgraph Cropping for Graph Classification.” https://arxiv.org/abs/2009.10564.

———. 2021. “Mixup for Node and Graph Classification.” In Proceedings of the Web Conference 2021. WWW ’21. New York, NY, USA: Association for Computing Machinery.

Wu, Yu, Yutian Lin, Xuanyi Dong, Yan Yan, Wanli Ouyang, and Yi Yang. 2018. “Exploit the Unknown Gradually: One-Shot Video-Based Person Re-Identification by Stepwise Learning.” In 2018 IEEE Conference on Computer Vision and Pattern Recognition, CVPR 2018, Salt Lake City, UT, USA, June 18-22, 2018. IEEE Computer Society. https://doi.org/10.1109/CVPR.2018.00543.

Xu, Keyulu, Weihua Hu, Jure Leskovec, and Stefanie Jegelka. 2019. “How Powerful Are Graph Neural Networks?” In 7th International Conference on Learning Representations, ICLR 2019, New Orleans, LA, USA, May 6-9, 2019. OpenReview.net.

Xu, Keyulu, Chengtao Li, Yonglong Tian, Tomohiro Sonobe, Ken-ichi Kawarabayashi, and Stefanie Jegelka. 2018. “Representation Learning on Graphs with Jumping Knowledge Networks.” In Proceedings of the 35th International Conference on Machine Learning, ICML 2018, Stockholmsmässan, Stockholm, Sweden, July 10-15, 2018. Vol. 80. Proceedings of Machine Learning Research. PMLR.

Yao, Huaxiu, Chuxu Zhang, Ying Wei, Meng Jiang, Suhang Wang, Junzhou Huang, Nitesh V. Chawla, and Zhenhui Li. 2019. “Graph Few-Shot Learning via Knowledge Transfer.” CoRR abs/1910.03053. https://arxiv.org/abs/1910.03053.

Yoon, Jaesik, Taesup Kim, Ousmane Dia, Sungwoong Kim, Yoshua Bengio, and Sungjin Ahn. 2018. “Bayesian Model-Agnostic Meta-Learning.” In Advances in Neural Information Processing Systems 31: Annual Conference on Neural Information Processing Systems 2018, NeurIPS 2018, December 3-8, 2018, Montréal, Canada.

Zhang, Jiaying, Xiaoli Zhao, Zheng Chen, and Zhejun Lu. 2019. “A Review of Deep Learning-Based Semantic Segmentation for Point Cloud.” IEEE Access 7.

Zhang, Shengzhong, Ziang Zhou, Zengfeng Huang, and Zhongyu Wei. 2019. “Few-Shot Classification on Graphs with Structural Regularized GCNs.”

Zhou, Fan, Chengtai Cao, Kunpeng Zhang, Goce Trajcevski, Ting Zhong, and Ji Geng. 2019. “Meta-GNN: On Few-Shot Node Classification in Graph Meta-Learning.” In Proceedings of the 28th ACM International Conference on Information and Knowledge Management, CIKM 2019, Beijing, China, November 3-7, 2019. ACM. https://doi.org/10.1145/3357384.3358106.
