Chapter 12: Graph Neural Networks
Train neural networks directly on graph structure inside AstraeaDB. Learn how message passing, differentiable edge weights, and in-database training eliminate the need to export data to external ML frameworks.
12.1 What Are GNNs?
Neural Networks on Graphs
Traditional neural networks operate on fixed-size inputs: images have a grid of pixels, text has a sequence of tokens. But graphs have no fixed structure -- each node can have a different number of neighbors, and the topology varies throughout the graph. Graph Neural Networks (GNNs) are a family of neural network architectures designed specifically for this irregular, relational data.
The Key Idea: Message Passing
GNNs work through a process called message passing. In each layer of the network, every node:
- Gathers feature vectors from its neighbors (and their edge weights).
- Aggregates those features into a single summary (e.g., sum, mean, or max).
- Updates its own feature vector based on the aggregated neighborhood information.
After multiple rounds of message passing, each node's feature vector encodes information not just about itself, but about its entire local neighborhood. A node that is 3 message-passing layers deep effectively "sees" all nodes within 3 hops.
Message Passing Illustrated
Consider a simple graph where we want to classify node C:
```
Round 0 (initial features):
    A [0.1, 0.9]   B [0.8, 0.2]   C [0.5, 0.5]   D [0.3, 0.7]
        \              |              |              /
         +----------------- edges to C -------------+

Round 1 (after 1 layer of message passing):
    C gathers from neighbors: A, B, D
    C aggregates: mean([0.1, 0.9], [0.8, 0.2], [0.3, 0.7]) = [0.4, 0.6]
    C updates:    combine([0.5, 0.5], [0.4, 0.6]) = [0.45, 0.55]  (new features)

Round 2 (after 2 layers):
    C now incorporates 2-hop neighborhood information
    Its features encode the structure of the local graph
```
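The round-1 arithmetic above can be reproduced in a few lines of plain Python. This is a minimal sketch of mean aggregation plus one simple update rule (element-wise averaging of a node's old features with its neighborhood summary); it illustrates the idea and is not AstraeaDB's internal implementation:

```python
def mean_aggregate(neighbor_features):
    """Element-wise mean over the neighbors' feature vectors."""
    n = len(neighbor_features)
    dim = len(neighbor_features[0])
    return [sum(f[i] for f in neighbor_features) / n for i in range(dim)]

def update(own, aggregated):
    """Combine a node's own features with the neighborhood summary.
    One simple choice: average the two vectors element-wise."""
    return [(a + b) / 2 for a, b in zip(own, aggregated)]

# Node C's neighbors A, B, D, using the Round 0 features from the diagram
neighbors = [[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]]
agg = mean_aggregate(neighbors)   # ~ [0.4, 0.6]
new_c = update([0.5, 0.5], agg)   # ~ [0.45, 0.55]
print(agg, new_c)
```

Stacking this step once per layer is all a message-passing round does; everything else (learned weights, activations) refines how the aggregation and update are computed.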
Node Classification
The most common GNN task is node classification: predict a label for each node based on its features and its neighborhood structure. Examples include:
- Fraud detection: Predict whether a user account is fraudulent based on its transaction patterns and the accounts it connects to.
- Document categorization: Classify papers in a citation network by their topic, using both the paper's content and what it cites.
- Social role prediction: Predict a user's role or interests based on their connections.
12.2 Setting Up Training Data
To train a GNN in AstraeaDB, your nodes need two things:
- Feature vectors (embeddings): Numeric representations of each node's attributes. These are the same embeddings used for vector search (Chapter 8).
- Ground truth labels: Known classifications for a subset of nodes. The GNN learns to predict labels for the remaining unlabeled nodes.
Creating Labeled Nodes
```python
from astraeadb import AstraeaClient

with AstraeaClient("127.0.0.1", 7687) as client:
    # Create nodes with embeddings and known labels
    alice = client.create_node(
        labels=["User"],
        properties={"name": "Alice", "label": "legitimate"},
        embedding=[0.1, 0.2, 0.05, 0.3, 0.15, 0.1, 0.25, 0.08],
    )
    bob = client.create_node(
        labels=["User"],
        properties={"name": "Bob", "label": "legitimate"},
        embedding=[0.12, 0.18, 0.07, 0.28, 0.14, 0.11, 0.22, 0.09],
    )
    bad_bot = client.create_node(
        labels=["User"],
        properties={"name": "BadBot", "label": "fraudulent"},
        embedding=[0.9, 0.8, 0.85, 0.7, 0.95, 0.88, 0.75, 0.92],
    )
    shill = client.create_node(
        labels=["User"],
        properties={"name": "Shill42", "label": "fraudulent"},
        embedding=[0.88, 0.82, 0.9, 0.72, 0.91, 0.85, 0.78, 0.89],
    )

    # A node with no label -- the GNN will predict this
    unknown = client.create_node(
        labels=["User"],
        properties={"name": "NewUser99"},  # no "label" property
        embedding=[0.7, 0.6, 0.75, 0.55, 0.8, 0.65, 0.58, 0.72],
    )

    # Create edges (transaction relationships)
    client.create_edge(alice["node_id"], bob["node_id"], "TRANSACTED", {"amount": 50.0})
    client.create_edge(bad_bot["node_id"], shill["node_id"], "TRANSACTED", {"amount": 10000.0})
    client.create_edge(shill["node_id"], bad_bot["node_id"], "TRANSACTED", {"amount": 9500.0})
    client.create_edge(unknown["node_id"], bad_bot["node_id"], "TRANSACTED", {"amount": 8000.0})
    client.create_edge(unknown["node_id"], shill["node_id"], "TRANSACTED", {"amount": 7500.0})
```
```r
source("r_client.R")

client <- AstraeaClient$new("127.0.0.1", 7687)
client$connect()

# Create labeled nodes with embeddings
alice <- client$create_node(
  labels = list("User"),
  properties = list(name = "Alice", label = "legitimate"),
  embedding = c(0.1, 0.2, 0.05, 0.3, 0.15, 0.1, 0.25, 0.08)
)
bob <- client$create_node(
  labels = list("User"),
  properties = list(name = "Bob", label = "legitimate"),
  embedding = c(0.12, 0.18, 0.07, 0.28, 0.14, 0.11, 0.22, 0.09)
)
bad_bot <- client$create_node(
  labels = list("User"),
  properties = list(name = "BadBot", label = "fraudulent"),
  embedding = c(0.9, 0.8, 0.85, 0.7, 0.95, 0.88, 0.75, 0.92)
)
shill <- client$create_node(
  labels = list("User"),
  properties = list(name = "Shill42", label = "fraudulent"),
  embedding = c(0.88, 0.82, 0.9, 0.72, 0.91, 0.85, 0.78, 0.89)
)

# Unlabeled node for prediction
unknown <- client$create_node(
  labels = list("User"),
  properties = list(name = "NewUser99"),
  embedding = c(0.7, 0.6, 0.75, 0.55, 0.8, 0.65, 0.58, 0.72)
)

# Create edges
client$create_edge(alice$node_id, bob$node_id, "TRANSACTED", list(amount = 50.0))
client$create_edge(bad_bot$node_id, shill$node_id, "TRANSACTED", list(amount = 10000.0))
client$create_edge(shill$node_id, bad_bot$node_id, "TRANSACTED", list(amount = 9500.0))
client$create_edge(unknown$node_id, bad_bot$node_id, "TRANSACTED", list(amount = 8000.0))
client$create_edge(unknown$node_id, shill$node_id, "TRANSACTED", list(amount = 7500.0))

client$close()
```
```go
package main

import (
	"context"

	astraeadb "github.com/AstraeaDB/AstraeaDB-Official"
)

func main() {
	client := astraeadb.NewClient(astraeadb.WithAddress("127.0.0.1", 7687))
	ctx := context.Background()
	client.Connect(ctx)
	defer client.Close()

	// Create labeled nodes
	alice, _ := client.CreateNode(ctx, []string{"User"},
		map[string]any{"name": "Alice", "label": "legitimate"},
		[]float32{0.1, 0.2, 0.05, 0.3, 0.15, 0.1, 0.25, 0.08})
	bob, _ := client.CreateNode(ctx, []string{"User"},
		map[string]any{"name": "Bob", "label": "legitimate"},
		[]float32{0.12, 0.18, 0.07, 0.28, 0.14, 0.11, 0.22, 0.09})
	badBot, _ := client.CreateNode(ctx, []string{"User"},
		map[string]any{"name": "BadBot", "label": "fraudulent"},
		[]float32{0.9, 0.8, 0.85, 0.7, 0.95, 0.88, 0.75, 0.92})
	shill, _ := client.CreateNode(ctx, []string{"User"},
		map[string]any{"name": "Shill42", "label": "fraudulent"},
		[]float32{0.88, 0.82, 0.9, 0.72, 0.91, 0.85, 0.78, 0.89})
	unknown, _ := client.CreateNode(ctx, []string{"User"},
		map[string]any{"name": "NewUser99"},
		[]float32{0.7, 0.6, 0.75, 0.55, 0.8, 0.65, 0.58, 0.72})

	// Create edges
	client.CreateEdge(ctx, alice.NodeID, bob.NodeID, "TRANSACTED", map[string]any{"amount": 50.0})
	client.CreateEdge(ctx, badBot.NodeID, shill.NodeID, "TRANSACTED", map[string]any{"amount": 10000.0})
	client.CreateEdge(ctx, shill.NodeID, badBot.NodeID, "TRANSACTED", map[string]any{"amount": 9500.0})
	client.CreateEdge(ctx, unknown.NodeID, badBot.NodeID, "TRANSACTED", map[string]any{"amount": 8000.0})
	client.CreateEdge(ctx, unknown.NodeID, shill.NodeID, "TRANSACTED", map[string]any{"amount": 7500.0})
}
```
```java
import com.astraeadb.unified.UnifiedClient;
import java.util.List;
import java.util.Map;

try (var client = UnifiedClient.builder()
        .host("127.0.0.1").port(7687).build()) {
    client.connect();

    // Create labeled nodes with embeddings
    var alice = client.createNode(
        List.of("User"),
        Map.of("name", "Alice", "label", "legitimate"),
        new float[]{0.1f, 0.2f, 0.05f, 0.3f, 0.15f, 0.1f, 0.25f, 0.08f});
    var bob = client.createNode(
        List.of("User"),
        Map.of("name", "Bob", "label", "legitimate"),
        new float[]{0.12f, 0.18f, 0.07f, 0.28f, 0.14f, 0.11f, 0.22f, 0.09f});
    var badBot = client.createNode(
        List.of("User"),
        Map.of("name", "BadBot", "label", "fraudulent"),
        new float[]{0.9f, 0.8f, 0.85f, 0.7f, 0.95f, 0.88f, 0.75f, 0.92f});
    var shill = client.createNode(
        List.of("User"),
        Map.of("name", "Shill42", "label", "fraudulent"),
        new float[]{0.88f, 0.82f, 0.9f, 0.72f, 0.91f, 0.85f, 0.78f, 0.89f});
    var unknown = client.createNode(
        List.of("User"),
        Map.of("name", "NewUser99"),
        new float[]{0.7f, 0.6f, 0.75f, 0.55f, 0.8f, 0.65f, 0.58f, 0.72f});

    // Create edges
    client.createEdge(alice.getNodeId(), bob.getNodeId(), "TRANSACTED", Map.of("amount", 50.0));
    client.createEdge(badBot.getNodeId(), shill.getNodeId(), "TRANSACTED", Map.of("amount", 10000.0));
    client.createEdge(shill.getNodeId(), badBot.getNodeId(), "TRANSACTED", Map.of("amount", 9500.0));
    client.createEdge(unknown.getNodeId(), badBot.getNodeId(), "TRANSACTED", Map.of("amount", 8000.0));
    client.createEdge(unknown.getNodeId(), shill.getNodeId(), "TRANSACTED", Map.of("amount", 7500.0));
}
```
Understanding the Data
In this fraud detection scenario:
- Alice and Bob are labeled as "legitimate". They have a single small transaction between them.
- BadBot and Shill42 are labeled as "fraudulent". They have large, circular transactions (a classic fraud pattern).
- NewUser99 has no label. It transacts heavily with both known fraudulent accounts. The GNN should predict this as "fraudulent" based on its neighborhood.
Message Passing Configuration
AstraeaDB supports several aggregation functions and activation functions for the message passing layers:
| Setting | Options | Description |
|---|---|---|
| Aggregation | Sum, Mean, Max | How neighbor features are combined. Sum preserves magnitude; Mean normalizes by degree; Max takes the strongest signal. |
| Activation | ReLU, Sigmoid | Non-linearity applied after aggregation. ReLU is the default for hidden layers; Sigmoid is used for the output layer in binary classification. |
| Layers | 1-5 (typical) | Number of message passing rounds. More layers = larger receptive field, but risk of over-smoothing. |
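To make the aggregation options concrete, here is a small stand-alone Python comparison. It is purely illustrative (not AstraeaDB's internal code) and shows how the same two-neighbor example produces different summaries under each setting:

```python
def aggregate(neighbor_features, how="mean"):
    """Combine neighbor feature vectors element-wise by sum, mean, or max."""
    dim = len(neighbor_features[0])
    cols = [[f[i] for f in neighbor_features] for i in range(dim)]
    if how == "sum":
        return [sum(c) for c in cols]          # preserves magnitude; grows with degree
    if how == "mean":
        return [sum(c) / len(c) for c in cols]  # normalizes by degree
    if how == "max":
        return [max(c) for c in cols]           # strongest signal per dimension
    raise ValueError(f"unknown aggregation: {how}")

neighbors = [[1.0, 0.0], [3.0, 2.0]]
print(aggregate(neighbors, "sum"))   # [4.0, 2.0]
print(aggregate(neighbors, "mean"))  # [2.0, 1.0]
print(aggregate(neighbors, "max"))   # [3.0, 2.0]
```

Note how sum is the only choice whose output scales with the number of neighbors, which is why mean is often preferred when node degrees vary widely.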
12.3 Training a Model
AstraeaDB's GNN training runs entirely inside the database. There is no need to export your graph to PyTorch Geometric, DGL, or any external framework. This eliminates the data engineering overhead of ETL pipelines and ensures the model always trains on the most current data.
The Training Loop
Each training epoch follows these four steps:
- Forward pass (message passing): For each layer, every node aggregates its neighbors' features (weighted by edge weights), applies the activation function, and produces updated features. After all layers, each node's features encode its neighborhood structure.
- Loss computation: For nodes with known labels, the model's predictions are compared against the ground truth. AstraeaDB uses cross-entropy loss for classification tasks.
- Gradient estimation: AstraeaDB uses numerical gradients (finite differences) rather than automatic differentiation. For each trainable parameter, the loss is evaluated at `param + epsilon` and `param - epsilon`, and the gradient is approximated as `(loss+ - loss-) / (2 * epsilon)`.
- Weight update: All trainable parameters (edge weights, layer weights) are updated using gradient descent: `param = param - learning_rate * gradient`.
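The central-difference scheme and the update rule from steps 3 and 4 can be demonstrated on a toy loss function. This sketch only assumes the formulas quoted above, nothing about AstraeaDB's internals; the quadratic loss is a made-up example:

```python
def numerical_gradient(loss_fn, params, epsilon=1e-5):
    """Central-difference gradient: perturb one parameter at a time and
    approximate d(loss)/d(param) as (loss(p+eps) - loss(p-eps)) / (2*eps)."""
    grads = []
    for i in range(len(params)):
        plus = params.copy();  plus[i] += epsilon
        minus = params.copy(); minus[i] -= epsilon
        grads.append((loss_fn(plus) - loss_fn(minus)) / (2 * epsilon))
    return grads

def gradient_descent_step(params, grads, learning_rate=0.01):
    """param = param - learning_rate * gradient, applied to every parameter."""
    return [p - learning_rate * g for p, g in zip(params, grads)]

# Toy quadratic loss: L(w) = (w0 - 3)^2 + (w1 + 1)^2, minimized at (3, -1)
loss = lambda w: (w[0] - 3) ** 2 + (w[1] + 1) ** 2
w = [0.0, 0.0]
g = numerical_gradient(loss, w)   # analytic gradient at (0, 0) is [-6, 2]
w = gradient_descent_step(w, g)   # moves w toward the minimum
print(g, w)
```

The cost of this approach is two loss evaluations per parameter per epoch, which is why `epsilon` and the parameter count both matter for training time.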
Conceptual Training Configuration
GNN training is currently configured via the Rust API, with client wrapper support planned for future releases.
```rust
use astraeadb::gnn::{GnnConfig, Aggregation, Activation};

let config = GnnConfig {
    // Network architecture
    num_layers: 2,
    hidden_dim: 16,
    output_classes: 2,               // "legitimate" vs "fraudulent"
    aggregation: Aggregation::Mean,
    activation: Activation::ReLU,

    // Training parameters
    epochs: 50,
    learning_rate: 0.01,
    label_key: "label".to_string(),  // property containing ground truth

    // Label mapping
    label_map: vec![
        ("legitimate".to_string(), 0),
        ("fraudulent".to_string(), 1),
    ],

    // Gradient estimation
    epsilon: 1e-5,
};

// Run training
let result = graph.train_gnn(config)?;

// Inspect training progress
for (epoch, loss) in result.epoch_losses.iter().enumerate() {
    println!("Epoch {}: loss = {:.4}", epoch, loss);
}
println!("Final accuracy: {:.1}%", result.accuracy * 100.0);
```
Expected Training Output
During training, you should see the loss decrease as the model learns:
```
Epoch 0:  loss = 0.6931   (random initialization, ~coin flip)
Epoch 5:  loss = 0.5412
Epoch 10: loss = 0.3876
Epoch 20: loss = 0.1923
Epoch 30: loss = 0.0847
Epoch 50: loss = 0.0234

Final accuracy: 100.0% (on labeled nodes)
```
The loss starts near 0.693 (the natural log of 2, which is the maximum entropy for binary classification -- essentially random guessing). As training proceeds, the loss drops as the model learns to distinguish legitimate from fraudulent patterns.
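That starting value is easy to verify: a model that initially assigns probability 0.5 to every class incurs a cross-entropy loss of exactly ln 2 per example. A quick illustrative check in Python:

```python
import math

def binary_cross_entropy(p_predicted, y_true):
    """Cross-entropy for one example: -[y*ln(p) + (1-y)*ln(1-p)],
    where p is the predicted probability of class 1 and y is 0 or 1."""
    return -(y_true * math.log(p_predicted)
             + (1 - y_true) * math.log(1 - p_predicted))

# An untrained model that outputs 0.5 regardless of the true label
print(round(binary_cross_entropy(0.5, 1), 4))  # 0.6931 -- ln(2), the coin-flip baseline
print(round(binary_cross_entropy(0.5, 0), 4))  # 0.6931 -- same for either label
```

Any epoch-0 loss far from 0.693 on a balanced binary task is a useful smoke-test signal that something (labels, initialization, class mapping) is misconfigured.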
Making Predictions
After training, use the trained model to classify unlabeled nodes:
```rust
// Get predictions for all unlabeled nodes
let predictions = graph.predict_gnn()?;

for (node_id, prediction) in predictions {
    let label = if prediction.class_id == 0 { "legitimate" } else { "fraudulent" };
    println!("Node {}: {} (confidence: {:.1}%)",
        node_id, label, prediction.confidence * 100.0);
}

// Output:
// Node nd-5 (NewUser99): fraudulent (confidence: 94.2%)
```
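The `class_id`/`confidence` pair is the standard output shape for a classifier: take the raw output-layer scores, normalize them into probabilities, and report the top class. The sketch below uses a softmax for the normalization step; this is a common convention, not a statement about how AstraeaDB computes confidence internally, and the logit values are made up:

```python
import math

def softmax(logits):
    """Numerically stable softmax: shift by the max, exponentiate, normalize."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def predict(logits, class_names):
    """Return (predicted class, confidence) from raw output-layer scores."""
    probs = softmax(logits)
    class_id = probs.index(max(probs))
    return class_names[class_id], probs[class_id]

# Hypothetical output-layer scores for NewUser99
label, conf = predict([-1.2, 1.6], ["legitimate", "fraudulent"])
print(label, round(conf, 3))  # the larger logit wins; confidence is its probability
```

A confidence near 50% means the model is effectively guessing for that node, which often happens for nodes with few or conflicting neighbors.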
12.4 Differentiable Edge Weights
AstraeaDB's Unique Innovation
In most graph databases, edge weights are static values set at creation time. In AstraeaDB, edge weights are differentiable tensors -- they are trainable parameters that the GNN can update during the learning process.
This is a fundamental architectural decision rooted in AstraeaDB's "Vector-Property Graph" data model (described in the project architecture). Edges do not just carry metadata; they carry learnable parameters that encode the strength and nature of relationships.
How It Works
- Initialization: All edges start with a uniform weight (typically 1.0). At this point, the GNN treats all connections equally.
- Training: During each epoch, the gradient of the loss function with respect to each edge weight is computed. Edges that contribute to correct predictions receive gradient updates that strengthen them; edges that contribute to incorrect predictions are weakened.
- After training: Edge weights encode learned relationship strengths. An edge with a high weight (e.g., 2.3) indicates a strong, predictive connection. An edge with a low weight (e.g., 0.1) indicates a weak or irrelevant connection.
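The per-edge update in step 2 is ordinary gradient descent applied to each weight. The following sketch shows the mechanism; the gradient values are invented for illustration, and real gradients would come from the loss on the labeled nodes:

```python
def update_edge_weight(weight, gradient, learning_rate=0.1):
    """Gradient descent on a single edge weight:
    a negative gradient (edge helps correct predictions) increases the weight,
    a positive gradient (edge hurts predictions) decreases it."""
    return weight - learning_rate * gradient

# Both edges start at the uniform initial weight 1.0
helpful_edge, noisy_edge = 1.0, 1.0
for _ in range(10):
    helpful_edge = update_edge_weight(helpful_edge, gradient=-0.5)  # pretend: consistently helpful
    noisy_edge = update_edge_weight(noisy_edge, gradient=+0.4)      # pretend: consistently harmful
print(helpful_edge, noisy_edge)  # strengthened above 1.0 vs. weakened below 1.0
```

Over many epochs this divergence is what separates predictive edges (weights well above 1.0) from irrelevant ones (weights near zero).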
What This Means in Practice
Consider the fraud detection example. Before training, all TRANSACTED edges have a weight of 1.0. After training:
```
Edge weights after training:

Alice     --TRANSACTED--> Bob       weight: 0.15  (low  -- normal transaction, not predictive)
BadBot    --TRANSACTED--> Shill42   weight: 2.41  (high -- circular fraud pattern)
Shill42   --TRANSACTED--> BadBot    weight: 2.38  (high -- circular fraud pattern)
NewUser99 --TRANSACTED--> BadBot    weight: 1.87  (high -- connected to known fraud)
NewUser99 --TRANSACTED--> Shill42   weight: 1.92  (high -- connected to known fraud)
```
The model has learned that:
- Circular transactions between accounts are highly predictive of fraud (BadBot <-> Shill42 edges have the highest weights).
- Normal transactions between legitimate users are not predictive (Alice -> Bob has a low weight).
- Connections to known fraudulent accounts are suspicious (NewUser99's edges have elevated weights).
Querying Learned Weights
After training, the learned edge weights persist in the graph and can be queried like any other edge property:
```python
with AstraeaClient("127.0.0.1", 7687) as client:
    # Query edges with high learned weights (suspicious connections)
    edges = client.query("""
        MATCH (a:User)-[t:TRANSACTED]->(b:User)
        WHERE t.weight > 1.5
        RETURN a.name, b.name, t.weight, t.amount
        ORDER BY t.weight DESC
    """)

    print("High-weight edges (suspicious connections):")
    for row in edges:
        print(f"  {row['a.name']} -> {row['b.name']}: "
              f"weight={row['t.weight']:.2f}, amount={row['t.amount']}")
```
```r
edges <- client$query("
  MATCH (a:User)-[t:TRANSACTED]->(b:User)
  WHERE t.weight > 1.5
  RETURN a.name, b.name, t.weight, t.amount
  ORDER BY t.weight DESC
")

cat("High-weight edges (suspicious connections):\n")
for (row in edges) {
  cat(sprintf("  %s -> %s: weight=%.2f, amount=%s\n",
              row$"a.name", row$"b.name", row$"t.weight", row$"t.amount"))
}
```
```go
edges, _ := client.Query(ctx, `
    MATCH (a:User)-[t:TRANSACTED]->(b:User)
    WHERE t.weight > 1.5
    RETURN a.name, b.name, t.weight, t.amount
    ORDER BY t.weight DESC
`)

fmt.Println("High-weight edges (suspicious connections):")
for _, row := range edges {
	fmt.Printf("  %s -> %s: weight=%.2f, amount=%v\n",
		row["a.name"], row["b.name"], row["t.weight"], row["t.amount"])
}
```
```java
var edges = client.query("""
    MATCH (a:User)-[t:TRANSACTED]->(b:User)
    WHERE t.weight > 1.5
    RETURN a.name, b.name, t.weight, t.amount
    ORDER BY t.weight DESC
    """);

System.out.println("High-weight edges (suspicious connections):");
for (var row : edges) {
    System.out.printf("  %s -> %s: weight=%.2f, amount=%s%n",
        row.get("a.name"), row.get("b.name"),
        row.get("t.weight"), row.get("t.amount"));
}
```
Use Cases for Differentiable Edges
| Domain | Initial State | After Training |
|---|---|---|
| Fraud Detection | All transaction edges weight = 1.0 | Suspicious circular transactions get high weights; normal transactions get low weights |
| Recommendation | All user-item edges weight = 1.0 | Strong preferences amplified; weak or incidental interactions dampened |
| Knowledge Graphs | All "related_to" edges weight = 1.0 | Causally important relationships get high weights; coincidental co-occurrences get low weights |
| Network Security | All connection edges weight = 1.0 | Lateral movement paths get high weights; routine traffic gets low weights |
The Bigger Picture
Differentiable edge weights transform AstraeaDB from a passive data store into an active learning system. The database does not just store your graph -- it learns from it. By running the training loop inside the database, you eliminate the traditional ML pipeline of "export data, train model, import predictions" and replace it with a single train_gnn call that updates the graph in place.
This is the realization of the "Differentiable Traversal" concept from AstraeaDB's architecture: the query execution plan itself is differentiable, and backpropagation updates edge weights directly inside the database.