Tree Matching Networks: What If We Gave Neural Networks the Parse Tree Instead?
Jason Lunder
When we feed text to a transformer, the model has to figure out the structure of language entirely from data. Subject-verb relationships, modifier scope, negation boundaries: all of it must be learned implicitly from millions (or billions) of examples. But what if we just… told the model the structure upfront?
That’s the core idea behind Tree Matching Networks (TMN), my recent research exploring whether dependency parse trees can serve as an efficient structural prior for natural language inference. The short version: a graph neural network operating on parse trees achieved 75.20% accuracy on the SNLI benchmark versus 35.38% for a BERT baseline of comparable size, both trained from scratch on the same data and the same hardware.
This post walks through the motivation, architecture, and what I think it means for the field.
The question behind the question
Transformers work. That’s not in dispute. But why they work, and specifically what structural information they’re learning internally, is still an active area of research. Attention patterns in BERT have been shown to loosely correspond to dependency relations (Clark et al., 2019), suggesting the model spends some of its capacity rediscovering linguistic structure that we already know how to extract.
This raises a practical question: if we hand the model that structure explicitly, can we get comparable performance with less compute? Not as a replacement for transformers, but as a way to understand what they’re actually doing and whether there’s a more efficient path.
From tokens to trees
A standard NLP pipeline treats a sentence as a sequence of tokens. A dependency parser instead produces a tree, where each word is a node and edges represent grammatical relationships.
Take the sentence “The cat sat on the mat.” As a token sequence, the model sees six tokens in order. As a dependency tree, it sees that sat is the root, cat is its subject (nsubj), mat is the object of the preposition (pobj), and each the modifies its noun (det). The hierarchical relationships are explicit.
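To make the contrast concrete, here is a minimal hand-built sketch of that dependency tree as a neural network might receive it: nodes plus labeled edges instead of a flat token list. The exact data structure is my own illustration, not the format TMN uses internally; the relation labels follow standard dependency-parsing conventions.

```python
# Hand-built dependency tree for "The cat sat on the mat" —
# what the model gets instead of a flat token sequence.

tokens = ["The", "cat", "sat", "on", "the", "mat"]

# (head_index, dependent_index, relation); head -1 marks the root.
edges = [
    (-1, 2, "root"),   # "sat" is the root of the tree
    (2, 1, "nsubj"),   # "cat" is the subject of "sat"
    (1, 0, "det"),     # "The" modifies "cat"
    (2, 3, "prep"),    # "on" attaches to "sat"
    (3, 5, "pobj"),    # "mat" is the object of the preposition
    (5, 4, "det"),     # "the" modifies "mat"
]

def children(idx):
    """Return (dependent_index, relation) pairs governed by the word at idx."""
    return [(d, rel) for h, d, rel in edges if h == idx]

# The verb's structural context is explicit — no attention needed to find it.
print(children(2))  # → [(1, 'nsubj'), (3, 'prep')]
```

Notice that the subject and prepositional attachment of sat are available by direct lookup; a sequence model has to infer them.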
This representation isn’t new; dependency parsing has been a core NLP tool for decades. What’s new here is using these trees as the input representation for a neural network designed to compare sentence pairs, rather than using them as a feature extraction step or auxiliary training signal.
Architecture: adapting graph matching for language
TMN builds on Graph Matching Networks (GMN), originally designed for comparing arbitrary graph structures. The key adaptation is replacing generic graph inputs with dependency parse trees generated by DiaParser and enriched with word embeddings from SpaCy.
The architecture has three main components:
1. Node and edge encoding. Each word in the parse tree becomes a node with an initial embedding. Dependency relation labels (nsubj, dobj, amod, etc.) become edge features. This gives the network both semantic (what the word means) and syntactic (how it relates to other words) information from the start.
2. Message-passing propagation. Nodes iteratively update their representations by aggregating information from their neighbors in the tree. After several rounds of propagation, each node’s representation encodes not just its own meaning but its structural context: a verb knows about its subject and object, a modifier knows what it modifies. This is the core GNN mechanism: local structure gets encoded into node states through repeated message passing.
3. Cross-graph matching. For comparing two sentences (as required for natural language inference), TMN computes attention-weighted correspondences between nodes across the two trees. This cross-graph attention allows the model to align semantically similar subtrees between the premise and hypothesis, then aggregate these alignments into a final similarity score.
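The three components above can be sketched in a few dozen lines of numpy. This is an illustrative toy, not the actual TMN update equations: the propagation step here is plain neighbor averaging and the matching step is softmax dot-product attention, standing in for the learned message and attention functions.

```python
import numpy as np

rng = np.random.default_rng(0)

def propagate(h, edges, steps=3):
    """Toy message passing: each node averages its neighbors' states and
    mixes the result into its own state (illustrative, not TMN's update)."""
    for _ in range(steps):
        msgs = np.zeros_like(h)
        counts = np.zeros(len(h))
        for u, v in edges:            # treat tree edges as undirected
            msgs[v] += h[u]; counts[v] += 1
            msgs[u] += h[v]; counts[u] += 1
        counts = np.maximum(counts, 1)
        h = 0.5 * h + 0.5 * msgs / counts[:, None]
    return h

def cross_attention(h_a, h_b):
    """Attention-weighted correspondences between two trees' node states."""
    scores = h_a @ h_b.T
    w = np.exp(scores - scores.max(axis=1, keepdims=True))
    w /= w.sum(axis=1, keepdims=True)
    return w @ h_b                    # each node in A gets a soft match in B

# Toy premise (4 nodes) and hypothesis (3 nodes), 8-dim node states.
h_p = rng.normal(size=(4, 8))
h_h = rng.normal(size=(3, 8))
edges_p = [(0, 1), (1, 2), (1, 3)]

h_p = propagate(h_p, edges_p)
aligned = cross_attention(h_p, h_h)
print(aligned.shape)  # → (4, 8): one aligned vector per premise node
```

After propagation each premise node carries its structural context, and the cross-attention step aligns it against the hypothesis tree before everything is aggregated into a similarity score.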
Training: making it work on one GPU
One of the practical constraints of this research was compute: a single NVIDIA RTX 3090 on a desktop machine. No cluster. No cloud budget. This shaped the training strategy significantly.
Multi-stage contrastive learning. Rather than jumping straight to 3-class NLI classification (entailment, neutral, contradiction), TMN uses a staged approach:
- Pretraining with InfoNCE contrastive loss on sentence pairs from WikiQS and AmazonQA (~7M sentences). This teaches the model basic semantic similarity: push similar sentences together, dissimilar ones apart.
- Primary training with multi-objective InfoNCE on SNLI’s 550K labeled pairs. Instead of hard classification, this stage uses a softer objective: entailed pairs should be more similar than neutral pairs, which should be more similar than contradictory pairs. The relative ordering matters more than absolute boundaries.
This staged approach creates a smoother learning curve that works within the constraints of limited compute. With larger models and more GPUs, you can often get away with training everything end-to-end. With 280 GPU-hours on a single card, curriculum matters.
Why InfoNCE specifically? It’s well-understood for embedding tasks, works naturally with the similarity-based framing of NLI, and extends cleanly to the multi-class setting. The multi-objective variant lets us express the three-way relationship (entailment > neutral > contradiction in similarity) without forcing hard decision boundaries during training.
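For readers unfamiliar with the loss, here is the standard single-positive InfoNCE in numpy, applied to a toy similarity matrix. This shows the base objective only; the multi-objective ordering variant described above (entailment > neutral > contradiction) is a generalization on top of it, and the temperature value here is an arbitrary illustration.

```python
import numpy as np

def info_nce(sim, positive_idx, temperature=0.1):
    """Standard InfoNCE: each row of `sim` holds one anchor's similarities
    to all candidates; the loss pushes the positive above the rest."""
    logits = sim / temperature
    logits -= logits.max(axis=1, keepdims=True)   # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(sim)), positive_idx].mean()

# Toy batch: 3 anchors vs 3 candidates; positives sit on the diagonal.
sim = np.array([
    [0.9, 0.1, 0.0],
    [0.2, 0.8, 0.1],
    [0.0, 0.1, 0.7],
])
loss = info_nce(sim, positive_idx=np.array([0, 1, 2]))
print(float(loss))  # small, since each positive already dominates its row
```

Because the loss is defined over relative similarities, extending it to enforce an ordering across the three NLI classes is natural, which is part of why it fits the staged curriculum.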
Results
The primary comparison: TreeMatchingNet (36M parameters) achieves 75.20% accuracy on the SNLI test set. BertMatchingNet (41M parameters, trained identically from scratch) achieves 35.38%, just above the 33.33% random baseline.
An important caveat up front: neither model is pretrained on large corpora. The comparison is controlled by design, isolating structural inductive bias rather than pretraining data. The BERT baseline here is not standard pretrained BERT; it is BERT trained from scratch on the same ~7M sentences as TMN. This makes the comparison fairer for understanding what architecture contributes, but it means absolute accuracy numbers shouldn’t be read against production systems.
The BERT failure mode is worth understanding. BertMatchingNet doesn’t simply underperform; it predicts Entailment for every test example, achieving 100% recall on that class and 0% on the other two. When we remove the cross-graph attention and train BertEmbeddingNet (which processes sentences independently), BERT recovers to 45.78% and learns non-trivial class distinctions. This suggests the cross-attention mechanism specifically is interfering with BERT’s learning under the randomized pairing regime used during pretraining, rather than the training protocol generally.
TMN responds differently. TreeEmbeddingNet (without cross-graph attention) reaches 57.57%, while TreeMatchingNet (with cross-graph attention) reaches 75.20%, a gap of 17.6 percentage points. The GNN-based matching architecture appears compatible with this regime in a way the transformer-based approach is not.
Scaling across three model sizes shows consistent improvement: Small (1M params, 60.53%), Medium (19M params, 68.81%), Large (36M params, 75.20%). Gains per additional parameter decrease at larger scales, consistent with the aggregation bottleneck described below. It is also consistent with a broader pattern in the literature: transformer models follow reliable power-law scaling with parameters and compute (Kaplan et al., 2020), whereas structure-based feedforward architectures have not been shown to do the same. TMN improves with scale, but the results here don’t provide evidence the trend would continue at much larger scales.
TMN also transfers to a different task with no additional training. On the SemEval semantic textual similarity benchmark, TreeMatchingNet achieves a Pearson correlation of 0.716; BertMatchingNet achieves 0.003 on the same task. The structural representations generalize across tasks, which is at least consistent with the model encoding something about sentence relationships rather than surface patterns.
The aggregation bottleneck
One architectural finding has held up consistently: the bottleneck in this architecture appears to be in the aggregation step (how node-level representations get combined into a sentence-level embedding) rather than in the propagation step (how nodes exchange information within the tree).
The current implementation uses weighted pooling, collapsing all node representations into a single sentence embedding. If a verb node has learned about its subject and object through message passing, that information gets averaged together with every other node equally, potentially losing the hierarchical signal the propagation worked to encode.
My thesis work investigates replacing this with multi-headed self-attention aggregation: essentially using a small transformer as the aggregation function. The hypothesis is that more expressive aggregation could better preserve what the message-passing step captures, rather than collapsing it into a flat average.
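The contrast between the two aggregation strategies can be sketched as follows. Both functions here are simplified stand-ins I wrote for illustration: the weighted pool mirrors the current collapse-to-one-vector behavior, and the multi-head variant gestures at the proposed attention aggregation (a real version would use learned query/key/value projections rather than the mean-query shortcut used here).

```python
import numpy as np

rng = np.random.default_rng(1)
h = rng.normal(size=(6, 8))   # node states after message passing

def weighted_pool(h, gate_scores):
    """Current-style aggregation: one softmax-gated average collapses
    all node states into a single sentence vector."""
    w = np.exp(gate_scores - gate_scores.max())
    w /= w.sum()
    return w @ h

def self_attention_pool(h, n_heads=2):
    """Sketch of the proposed alternative: per-head attention over nodes,
    concatenating head outputs instead of averaging everything once."""
    d = h.shape[1] // n_heads
    outs = []
    for k in range(n_heads):
        hk = h[:, k * d:(k + 1) * d]
        scores = hk @ hk.mean(axis=0)          # mean state as a crude query
        w = np.exp(scores - scores.max())
        w /= w.sum()
        outs.append(w @ hk)
    return np.concatenate(outs)

pooled = weighted_pool(h, rng.normal(size=6))
attn = self_attention_pool(h)
print(pooled.shape, attn.shape)  # both (8,), but attn keeps per-head structure
```

The point of the comparison: both produce a vector of the same size, but the attention version can weight nodes differently per head, so the hierarchical signal built up during propagation is not forced through a single flat average.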
What this means (and what it doesn’t)
What it means:
- Explicit structural encoding provides a useful inductive bias at small scales. When you can’t afford to spend billions of parameters on implicitly learning structure, providing it directly appears to help.
- There’s a genuine question about what transformers learn versus what could be provided upfront. Parse trees encode linguistic knowledge that otherwise must be extracted from data.
- Controlled comparisons at moderate scale can still be informative, even on a single consumer GPU.
What it doesn’t mean:
- Trees will replace transformers. They won’t. Transformer models follow consistent scaling laws that structure-based feedforward architectures don’t appear to match (Kaplan et al., 2020). The advantage shown here at moderate scales may not hold as parameters grow into the billions.
- Structure-based approaches are production-ready. This is research exploring a specific hypothesis about the role of structure in language understanding, under constrained conditions.
Where this goes next
The immediate next step is the self-attention aggregation experiments for my thesis. Beyond that, the research direction I find most compelling is what I call the “propagation-then-transformer” approach:
- Parse each sentence into a dependency tree
- Run GNN propagation to produce structure-aware node representations
- Feed those representations (instead of token embeddings) into a transformer
After propagation, the node states are enriched token embeddings; they carry the same kind of information as standard embeddings, plus explicit structural context. The question is whether starting from structurally-informed representations lets you train a smaller or more efficient transformer.
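As a toy end-to-end sketch, the three steps look like this. Every component is a stub of my own: in a real system the parser would be DiaParser, the GNN the TMN propagation step, and the final stage an actual transformer encoder consuming the enriched node states rather than raw token embeddings.

```python
import numpy as np

def parse(sentence):
    """Stub parser: one node per word, chained edges (a real dependency
    parser would produce the actual tree)."""
    tokens = sentence.split()
    edges = [(i, i + 1) for i in range(len(tokens) - 1)]
    return tokens, edges

def gnn_propagate(states, edges, steps=2):
    """Stub propagation: neighbor mixing over the tree edges."""
    for _ in range(steps):
        new = states.copy()
        for u, v in edges:
            new[u] += 0.5 * states[v]
            new[v] += 0.5 * states[u]
        states = new / np.abs(new).max()
    return states

def transformer_encode(states):
    """Stub 'transformer': one round of softmax self-attention, pooled."""
    scores = states @ states.T
    w = np.exp(scores - scores.max(axis=1, keepdims=True))
    w /= w.sum(axis=1, keepdims=True)
    return (w @ states).mean(axis=0)

rng = np.random.default_rng(2)
tokens, edges = parse("The cat sat on the mat")        # 1. parse
states = rng.normal(size=(len(tokens), 8))             # stand-in embeddings
states = gnn_propagate(states, edges)                  # 2. GNN propagation
sentence_vec = transformer_encode(states)              # 3. transformer stage
print(sentence_vec.shape)  # → (8,)
```

The interesting question is step 3: whether the transformer can be smaller or train faster because step 2 already handed it structural context.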
Testing this properly requires compute beyond what’s available for a master’s thesis. But the TMN results suggest the structural prior has real value, and figuring out how to combine it with the scaling properties of transformers seems like a question worth pursuing.
Try it yourself
The code is available on GitHub, and the paper is on arXiv.
If you have questions or want to discuss tree-based NLP approaches, feel free to reach out.
Jason Lunder is an ML engineer and researcher at Intellipat Inc. and an MS Computer Science candidate at Eastern Washington University. His research focuses on tree-based architectures for natural language processing.