Introduction
Graph Neural Networks (GNNs) are a class of deep learning models designed specifically for graph-structured data. Unlike traditional neural networks that operate on fixed-size inputs (images, sequences), GNNs can learn from the topology of a graph—how nodes connect to each other—and use that structure to make predictions.
AstraeaDB's astraea-gnn crate provides a lightweight, educational GNN framework that integrates directly with your graph database. Train models on your graph data without extracting it to external ML frameworks.
What Can GNNs Do?
- Node Classification — Predict labels for nodes (e.g., user categories, fraud detection)
- Link Prediction — Predict missing edges (e.g., friend recommendations)
- Graph Classification — Classify entire graphs (e.g., molecule properties)
- Node Embeddings — Learn dense representations that capture graph structure
Key Concepts
Node Features
Each node has a feature vector (also called an embedding). In AstraeaDB, these come from the embedding field on nodes:
```rust
let node = graph.create_node(
    labels,
    properties,
    Some(vec![0.5, -0.3, 0.8, 0.1]), // Feature vector
)?;
```
Edge Weights
Edges can have weights that control how much influence neighbors have. These weights are the trainable parameters in AstraeaDB's GNN:
```rust
let edge = graph.create_edge(
    source,
    target,
    edge_type,
    properties,
    1.5, // Weight (trainable)
    None,
    None,
)?;
```
Message Passing
The core GNN operation. Each node updates its features by aggregating information from its neighbors:
Message: m_{j→i} = w_{ij} · h_j
Aggregate: M_i = AGG({ m_{j→i} : j ∈ N(i) })
Update: h_i^(new) = σ(M_i)
Where:
- h_j = feature vector of neighbor j
- w_{ij} = weight of the edge from j to i
- N(i) = neighbors of node i
- AGG = aggregation function (sum, mean, or max)
- σ = activation function (ReLU, sigmoid)
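To make the equations above concrete, here is a minimal standalone sketch of one message-passing step for a single node in plain Rust, using weighted-sum aggregation and ReLU. The function and parameter names are illustrative, not the astraea-gnn API:

```rust
use std::collections::HashMap;

/// One message-passing step for node i: M_i = Σ_j w_ij · h_j, then ReLU.
/// Illustrative sketch; not the astraea-gnn implementation.
fn update_node(
    node: u32,
    features: &HashMap<u32, Vec<f32>>,
    neighbors: &[(u32, f32)], // (neighbor id j, edge weight w_ij)
) -> Vec<f32> {
    let dim = features[&node].len();
    let mut agg = vec![0.0f32; dim];
    // Aggregate: M_i = Σ_j w_ij · h_j
    for (j, w) in neighbors {
        for (k, x) in features[j].iter().enumerate() {
            agg[k] += w * x;
        }
    }
    // Update: h_i(new) = ReLU(M_i)
    agg.iter().map(|x| x.max(0.0)).collect()
}

fn main() {
    let mut features = HashMap::new();
    features.insert(0, vec![1.0, 0.0]);
    features.insert(1, vec![0.0, 1.0]);
    features.insert(2, vec![1.0, 1.0]);
    // Node 0 aggregates from neighbor 1 (weight 1.0) and neighbor 2 (weight 0.5)
    let h0 = update_node(0, &features, &[(1, 1.0), (2, 0.5)]);
    println!("{:?}", h0); // [0.5, 1.5]
}
```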
Why GNNs for Graph Data?
| Traditional ML | Graph Neural Networks |
|---|---|
| Treats each node independently | Learns from node + neighborhood |
| Requires manual feature engineering | Learns structural features automatically |
| Fixed input size | Handles variable-size neighborhoods |
| No relationship awareness | Explicitly models connections |
| Loses graph topology | Preserves and exploits structure |
Tensor Operations
The Tensor struct is a 1D differentiable vector for node features with gradient tracking.
Creating Tensors
```rust
use astraea_gnn::Tensor;

// From data (with gradient tracking)
let features = Tensor::new(vec![1.0, 2.0, 3.0], true);

// Zero-filled
let zeros = Tensor::zeros(4, false);

// Single value
let scalar = Tensor::from_scalar(5.0);
```
Arithmetic Operations
```rust
let a = Tensor::new(vec![1.0, 2.0, 3.0], false);
let b = Tensor::new(vec![4.0, 5.0, 6.0], false);

// Element-wise operations
let sum = a.add(&b);       // [5.0, 7.0, 9.0]
let product = a.mul(&b);   // [4.0, 10.0, 18.0]
let scaled = a.scale(2.5); // [2.5, 5.0, 7.5]
```
Reduction Operations
```rust
let v = Tensor::new(vec![1.0, 2.0, 3.0], false);
let a = Tensor::new(vec![1.0, 2.0, 3.0], false);
let b = Tensor::new(vec![4.0, 5.0, 6.0], false);

let dot_product = a.dot(&b); // 1*4 + 2*5 + 3*6 = 32.0
let total = v.sum();         // 6.0
let average = v.mean();      // 2.0
let l2_norm = v.norm();      // sqrt(14) ≈ 3.74
```
Activation Functions
```rust
let v = Tensor::new(vec![-1.0, 0.0, 2.0], false);

let relu = v.relu();   // [0.0, 0.0, 2.0] - max(0, x)
let sig = v.sigmoid(); // [0.27, 0.5, 0.88] - 1/(1+e^-x)
```
Gradient Management
```rust
let t = Tensor::new(vec![1.0, 2.0], true);

// Store computed gradient
t.set_grad(vec![0.1, 0.2]);

// Retrieve gradient
let grad = t.grad(); // Some([0.1, 0.2])

// Clear gradient for next iteration
t.zero_grad();
```
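For intuition, the shape of a gradient-carrying tensor like the one described above can be sketched in a few lines of plain Rust. This toy struct only mirrors the gradient-management API; it is not the crate's actual implementation:

```rust
/// Toy 1D tensor with an optional gradient slot. Illustrative only;
/// astraea-gnn's real Tensor may differ (e.g., interior mutability).
#[derive(Debug, Clone)]
struct Tensor {
    data: Vec<f32>,
    requires_grad: bool,
    grad: Option<Vec<f32>>,
}

impl Tensor {
    fn new(data: Vec<f32>, requires_grad: bool) -> Self {
        Tensor { data, requires_grad, grad: None }
    }
    /// Store a gradient, but only if this tensor tracks gradients.
    fn set_grad(&mut self, g: Vec<f32>) {
        if self.requires_grad {
            self.grad = Some(g);
        }
    }
    fn grad(&self) -> Option<&Vec<f32>> {
        self.grad.as_ref()
    }
    /// Clear the gradient before the next iteration.
    fn zero_grad(&mut self) {
        self.grad = None;
    }
}

fn main() {
    let mut t = Tensor::new(vec![1.0, 2.0], true);
    t.set_grad(vec![0.1, 0.2]);
    println!("grad = {:?}", t.grad());
    t.zero_grad();
    println!("after zero_grad: {:?} (data {:?})", t.grad(), t.data);
}
```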
Complete Tensor API
| Method | Description | Returns |
|---|---|---|
| `new(data, requires_grad)` | Create from vector | `Tensor` |
| `zeros(len, requires_grad)` | Zero-filled tensor | `Tensor` |
| `from_scalar(value)` | Single-value tensor | `Tensor` |
| `add(&other)` | Element-wise addition | `Tensor` |
| `mul(&other)` | Element-wise multiplication | `Tensor` |
| `scale(s)` | Scalar multiplication | `Tensor` |
| `dot(&other)` | Dot product | `f32` |
| `sum()` | Sum all elements | `f32` |
| `mean()` | Mean of elements | `f32` |
| `norm()` | L2 (Euclidean) norm | `f32` |
| `relu()` | ReLU activation | `Tensor` |
| `sigmoid()` | Sigmoid activation | `Tensor` |
Message Passing
The message_passing function is the core GNN operation. It updates all node features based on their neighborhoods.
Basic Usage
```rust
use astraea_gnn::{message_passing, MessagePassingConfig, Tensor};
use std::collections::HashMap;

// Initialize node features
let mut features = HashMap::new();
features.insert(node_a, Tensor::new(vec![1.0, 0.0], false));
features.insert(node_b, Tensor::new(vec![0.0, 1.0], false));
features.insert(node_c, Tensor::new(vec![1.0, 1.0], false));

// Set edge weights
let mut weights = HashMap::new();
weights.insert(edge_ab, 1.0);
weights.insert(edge_bc, 1.0);

// Configure message passing
let config = MessagePassingConfig::default();

// Run one layer
let updated = message_passing(&graph, &features, &weights, &config)?;
```
How It Works
Aggregation Strategies
Choose how to combine messages from multiple neighbors:
| Strategy | Formula | Best For |
|---|---|---|
| `Aggregation::Sum` | Σ messages | Counting patterns, variable-importance weighting |
| `Aggregation::Mean` | (Σ messages) / n | Scale-invariant learning, GCN-style |
| `Aggregation::Max` | element-wise max | Detecting presence of features, GraphSAGE-style |
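The three strategies in the table can be sketched over plain vectors of messages. This standalone example (the `Agg` enum is named here for illustration, mirroring the crate's `Aggregation` variants) shows how the same messages yield different aggregates:

```rust
/// The three aggregation strategies, sketched over plain vectors.
/// Standalone illustration; not the astraea-gnn implementation.
#[derive(Clone, Copy, PartialEq)]
enum Agg { Sum, Mean, Max }

fn aggregate(messages: &[Vec<f32>], strategy: Agg) -> Vec<f32> {
    let dim = messages[0].len();
    match strategy {
        Agg::Sum | Agg::Mean => {
            // Sum element-wise, then optionally divide by the message count
            let mut out = vec![0.0f32; dim];
            for m in messages {
                for k in 0..dim { out[k] += m[k]; }
            }
            if strategy == Agg::Mean {
                let n = messages.len() as f32;
                for x in &mut out { *x /= n; }
            }
            out
        }
        Agg::Max => {
            // Element-wise maximum across all messages
            let mut out = vec![f32::NEG_INFINITY; dim];
            for m in messages {
                for k in 0..dim { out[k] = out[k].max(m[k]); }
            }
            out
        }
    }
}

fn main() {
    let msgs = vec![vec![1.0, 4.0], vec![3.0, 2.0]];
    println!("{:?}", aggregate(&msgs, Agg::Sum));  // [4.0, 6.0]
    println!("{:?}", aggregate(&msgs, Agg::Mean)); // [2.0, 3.0]
    println!("{:?}", aggregate(&msgs, Agg::Max));  // [3.0, 4.0]
}
```

Note that Sum grows with neighborhood size (useful for counting), while Mean does not, which is why Mean is the scale-invariant choice.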
Example Configurations
```rust
use astraea_gnn::{MessagePassingConfig, Aggregation, Activation};

// GCN-style: Mean aggregation with normalization
let gcn_config = MessagePassingConfig {
    aggregation: Aggregation::Mean,
    activation: Activation::ReLU,
    normalize: true,
};

// GraphSAGE-style: Mean without normalization
let sage_config = MessagePassingConfig {
    aggregation: Aggregation::Mean,
    activation: Activation::ReLU,
    normalize: false,
};

// Max-pooling style
let max_config = MessagePassingConfig {
    aggregation: Aggregation::Max,
    activation: Activation::ReLU,
    normalize: false,
};
```
Activation Functions
Non-linear functions applied after aggregation to add expressiveness:
| Activation | Formula | Output Range | Use Case |
|---|---|---|---|
| `Activation::ReLU` | max(0, x) | [0, ∞) | Default choice, introduces sparsity |
| `Activation::Sigmoid` | 1 / (1 + e^(-x)) | (0, 1) | Bounded outputs, probability-like |
| `Activation::None` | x (identity) | (-∞, ∞) | Linear layers, final layer before softmax |
Training Loop
The train_node_classification function provides an end-to-end training pipeline.
Training Configuration
```rust
use astraea_gnn::{TrainingConfig, MessagePassingConfig};

let config = TrainingConfig {
    layers: 2,          // Number of message passing layers
    learning_rate: 0.1, // Gradient descent step size
    epochs: 100,        // Training iterations
    message_passing: MessagePassingConfig::default(),
};
```
Training Data
```rust
use astraea_gnn::TrainingData;
use std::collections::HashMap;

let mut labels = HashMap::new();
labels.insert(node_a, 0); // Class 0
labels.insert(node_b, 0); // Class 0
labels.insert(node_c, 1); // Class 1
labels.insert(node_d, 1); // Class 1
// node_e is unlabeled (will be predicted)

let training_data = TrainingData {
    labels,
    num_classes: 2,
};
```
Running Training
```rust
use astraea_gnn::train_node_classification;

let result = train_node_classification(&graph, &training_data, &config)?;

// Inspect results
println!("Accuracy: {:.2}%", result.accuracy * 100.0);
println!("Loss history: {:?}", result.epoch_losses);
println!("Predictions: {:?}", result.final_predictions);
```
Training Pipeline
Loss Functions
AstraeaDB uses Cross-Entropy Loss for node classification:
Cross-Entropy: Loss = -log(P(true_class))
Total Loss: mean over all labeled nodes
The softmax converts raw feature values (logits) into probabilities, and cross-entropy penalizes low probability assigned to the correct class.
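The softmax-plus-cross-entropy computation can be written out directly. This is a standalone sketch of the standard formulas above, not the crate's internals; the max-subtraction trick is a common addition for numerical stability:

```rust
/// Softmax: convert raw logits into probabilities that sum to 1.
fn softmax(logits: &[f32]) -> Vec<f32> {
    // Subtract the max logit for numerical stability (does not change the result)
    let m = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|x| (x - m).exp()).collect();
    let z: f32 = exps.iter().sum();
    exps.iter().map(|e| e / z).collect()
}

/// Cross-entropy: Loss = -log(P(true_class))
fn cross_entropy(logits: &[f32], true_class: usize) -> f32 {
    -softmax(logits)[true_class].ln()
}

fn main() {
    // Confident and correct -> small loss
    let low = cross_entropy(&[4.0, 0.0], 0);
    // Confident and wrong -> large loss
    let high = cross_entropy(&[4.0, 0.0], 1);
    println!("correct: {:.4}, wrong: {:.4}", low, high);
}
```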
Optimization
AstraeaDB uses numerical gradient descent with finite differences:
```rust
// For each edge weight:
const EPSILON: f32 = 1e-3;

loss_original = compute_loss(weights);
weights[edge] += EPSILON;
loss_perturbed = compute_loss(weights);
weights[edge] -= EPSILON; // Restore

gradient = (loss_perturbed - loss_original) / EPSILON;
weights[edge] -= learning_rate * gradient;
```
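The same forward-difference update can be exercised end to end on a one-parameter toy problem. The `minimize` function here is invented for this sketch; it minimizes f(w) = (w - 3)² exactly as the per-weight loop above would:

```rust
/// Numerical gradient descent with forward differences on a toy loss.
/// Standalone sketch; `minimize` is not part of astraea-gnn.
fn minimize(mut w: f32, learning_rate: f32, steps: usize) -> f32 {
    const EPSILON: f32 = 1e-3;
    let loss = |w: f32| (w - 3.0) * (w - 3.0); // minimum at w = 3
    for _ in 0..steps {
        // Forward-difference estimate of dLoss/dw
        let grad = (loss(w + EPSILON) - loss(w)) / EPSILON;
        w -= learning_rate * grad;
    }
    w
}

fn main() {
    let w = minimize(0.0, 0.1, 200);
    println!("w ≈ {:.3}", w); // converges near the minimum at w = 3
}
```

The small bias of the forward difference (order EPSILON) is visible here: the iterate settles slightly off the true minimum, which is one of the "approximate gradients only" limitations noted below.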
Advantages & Limitations
| Advantages | Limitations |
|---|---|
| Simple to implement | O(E) forward passes per epoch |
| Works with any loss function | Approximate gradients only |
| No symbolic differentiation needed | Slower than backpropagation |
| Numerically stable | Best for small-to-medium graphs |
Node Classification
The primary task supported by AstraeaDB's GNN framework. Given a partially labeled graph, predict labels for unlabeled nodes.
Complete Example
```rust
use astraea_gnn::{
    train_node_classification, TrainingConfig, TrainingData,
    MessagePassingConfig, Aggregation, Activation,
};
use astraea_graph::{Graph, InMemoryStorage};
use std::collections::HashMap;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // 1. Create graph with nodes and embeddings
    let graph = Graph::new(Box::new(InMemoryStorage::new()));

    let n0 = graph.create_node(
        vec!["User".into()],
        serde_json::json!({"name": "Alice"}),
        Some(vec![0.9, 0.1]), // Leans toward class 0
    )?;
    let n1 = graph.create_node(
        vec!["User".into()],
        serde_json::json!({"name": "Bob"}),
        Some(vec![0.8, 0.2]),
    )?;
    let n2 = graph.create_node(
        vec!["User".into()],
        serde_json::json!({"name": "Charlie"}),
        Some(vec![0.2, 0.8]), // Leans toward class 1
    )?;
    let n3 = graph.create_node(
        vec!["User".into()],
        serde_json::json!({"name": "Diana"}),
        Some(vec![0.5, 0.5]), // Ambiguous - will learn from neighbors
    )?;

    // 2. Add edges (social connections)
    graph.create_edge(n0, n1, "FRIENDS".into(), serde_json::json!({}), 1.0, None, None)?;
    graph.create_edge(n1, n3, "FRIENDS".into(), serde_json::json!({}), 1.0, None, None)?;
    graph.create_edge(n2, n3, "FRIENDS".into(), serde_json::json!({}), 1.0, None, None)?;

    // 3. Prepare training data (partial labels)
    let mut labels = HashMap::new();
    labels.insert(n0, 0); // Alice is class 0
    labels.insert(n2, 1); // Charlie is class 1
    // n1 and n3 are unlabeled

    let training_data = TrainingData {
        labels,
        num_classes: 2,
    };

    // 4. Configure training
    let config = TrainingConfig {
        layers: 2,
        learning_rate: 0.1,
        epochs: 50,
        message_passing: MessagePassingConfig {
            aggregation: Aggregation::Mean,
            activation: Activation::ReLU,
            normalize: false,
        },
    };

    // 5. Train
    let result = train_node_classification(&graph, &training_data, &config)?;

    // 6. Results
    println!("Training Accuracy: {:.1}%", result.accuracy * 100.0);
    for (node_id, predicted_class) in &result.final_predictions {
        println!("Node {:?} predicted as class {}", node_id, predicted_class);
    }
    // Diana (n3) should be influenced by both Bob (class 0 neighbor)
    // and Charlie (class 1 neighbor)

    Ok(())
}
```
GNN Architectures
Configure different GNN variants by changing the message passing settings:
GCN (Graph Convolutional Network)
```rust
// Kipf & Welling 2017 style
let gcn = TrainingConfig {
    layers: 2,
    learning_rate: 0.01,
    epochs: 200,
    message_passing: MessagePassingConfig {
        aggregation: Aggregation::Mean,
        activation: Activation::ReLU,
        normalize: true, // Key: L2 normalization
    },
};
```
GraphSAGE
```rust
// Hamilton et al. 2017 style
let sage = TrainingConfig {
    layers: 2,
    learning_rate: 0.1,
    epochs: 100,
    message_passing: MessagePassingConfig {
        aggregation: Aggregation::Mean, // or Max
        activation: Activation::ReLU,
        normalize: false,
    },
};
```
Attention-like (Weighted Sum)
```rust
// Edge weights act like attention scores
let attention = TrainingConfig {
    layers: 3,
    learning_rate: 0.05,
    epochs: 150,
    message_passing: MessagePassingConfig {
        aggregation: Aggregation::Sum, // Weights determine importance
        activation: Activation::Sigmoid,
        normalize: false,
    },
};
```
Multi-Layer GNNs
The `layers` parameter controls how many hops of information each node can see:

- 1 layer: each node sees its immediate (1-hop) neighbors
- 2 layers: each node sees its 2-hop neighborhood
- K layers: each node sees its K-hop neighborhood
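Which nodes fall inside that K-hop receptive field can be computed with a short breadth-first search. This standalone sketch (names hypothetical, plain adjacency lists rather than the graph API) shows how depth widens the neighborhood:

```rust
use std::collections::{HashMap, HashSet};

/// All nodes reachable from `start` within k hops: the set whose features
/// can influence `start` after k message-passing layers. Sketch only.
fn k_hop_neighborhood(
    adj: &HashMap<u32, Vec<u32>>,
    start: u32,
    k: usize,
) -> HashSet<u32> {
    let mut seen: HashSet<u32> = [start].into_iter().collect();
    let mut frontier = vec![start];
    for _ in 0..k {
        let mut next = Vec::new();
        for u in frontier {
            for &v in adj.get(&u).unwrap_or(&Vec::new()) {
                if seen.insert(v) {
                    next.push(v);
                }
            }
        }
        frontier = next;
    }
    seen
}

fn main() {
    // Path graph: 0 - 1 - 2 - 3
    let adj: HashMap<u32, Vec<u32>> = HashMap::from([
        (0, vec![1]),
        (1, vec![0, 2]),
        (2, vec![1, 3]),
        (3, vec![2]),
    ]);
    println!("1 layer reaches {} nodes", k_hop_neighborhood(&adj, 0, 1).len()); // 2
    println!("2 layers reach {} nodes", k_hop_neighborhood(&adj, 0, 2).len()); // 3
}
```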
Comparison
```rust
// Compare different depths
for num_layers in [1, 2, 3, 4] {
    let config = TrainingConfig {
        layers: num_layers,
        learning_rate: 0.1,
        epochs: 50,
        message_passing: MessagePassingConfig::default(),
    };
    let result = train_node_classification(&graph, &data, &config)?;
    println!("{} layers: {:.1}% accuracy", num_layers, result.accuracy * 100.0);
}
```
Example: Social Network Analysis
Predict user interests based on their social connections.
Scenario
Users are connected in a social graph. Some users have labeled interests (Tech, Sports). Predict interests for unlabeled users.
```rust
use astraea_gnn::*;
use std::collections::HashMap;

// Create users with profile embeddings
let alice = graph.create_node(
    vec!["User".into()],
    serde_json::json!({"name": "Alice", "bio": "Software engineer"}),
    Some(vec![0.9, 0.1]), // Tech-oriented embedding
)?;
let bob = graph.create_node(
    vec!["User".into()],
    serde_json::json!({"name": "Bob", "bio": "Basketball fan"}),
    Some(vec![0.1, 0.9]), // Sports-oriented embedding
)?;
let charlie = graph.create_node(
    vec!["User".into()],
    serde_json::json!({"name": "Charlie"}),
    Some(vec![0.5, 0.5]), // Unknown interest
)?;

// Social connections
graph.create_edge(alice, charlie, "FOLLOWS".into(), serde_json::json!({}), 1.0, None, None)?;
graph.create_edge(bob, charlie, "FOLLOWS".into(), serde_json::json!({}), 0.5, None, None)?;

// Labels: 0 = Tech, 1 = Sports
let mut labels = HashMap::new();
labels.insert(alice, 0); // Tech
labels.insert(bob, 1);   // Sports

let data = TrainingData { labels, num_classes: 2 };

// Train
let result = train_node_classification(&graph, &data, &TrainingConfig {
    layers: 2,
    learning_rate: 0.1,
    epochs: 100,
    message_passing: MessagePassingConfig::default(),
})?;

// Charlie's prediction is influenced by both Alice (Tech) and Bob (Sports).
// Edge weights matter: Alice's edge (1.0) vs Bob's edge (0.5)
let charlie_interest = result.final_predictions[&charlie];
println!(
    "Charlie predicted interest: {}",
    if charlie_interest == 0 { "Tech" } else { "Sports" }
);
```
Example: Fraud Detection
Detect fraudulent accounts by learning from their connections to known fraudsters.
```rust
use astraea_gnn::*;
use std::collections::HashMap;

// Create accounts
let legit1 = graph.create_node(
    vec!["Account".into()],
    serde_json::json!({"age_days": 365, "verified": true}),
    Some(vec![0.9, 0.1]), // Legit signal
)?;
let fraud1 = graph.create_node(
    vec!["Account".into()],
    serde_json::json!({"age_days": 7, "verified": false}),
    Some(vec![0.1, 0.9]), // Fraud signal
)?;
let suspicious = graph.create_node(
    vec!["Account".into()],
    serde_json::json!({"age_days": 30, "verified": false}),
    Some(vec![0.5, 0.5]), // Ambiguous
)?;

// Suspicious account transacts with a known fraudster
graph.create_edge(suspicious, fraud1, "TRANSACTED".into(), serde_json::json!({}), 1.0, None, None)?;
graph.create_edge(suspicious, legit1, "TRANSACTED".into(), serde_json::json!({}), 0.1, None, None)?;

// Labels: 0 = Legitimate, 1 = Fraud
let mut labels = HashMap::new();
labels.insert(legit1, 0);
labels.insert(fraud1, 1);

let data = TrainingData { labels, num_classes: 2 };

let result = train_node_classification(&graph, &data, &TrainingConfig {
    layers: 2,
    learning_rate: 0.1,
    epochs: 100,
    message_passing: MessagePassingConfig {
        aggregation: Aggregation::Sum, // Sum to count fraud connections
        activation: Activation::ReLU,
        normalize: false,
    },
})?;

// The suspicious account is heavily connected to fraud (weight 1.0)
// vs weakly connected to legit (weight 0.1)
let is_fraud = result.final_predictions[&suspicious] == 1;
println!("Suspicious account flagged as fraud: {}", is_fraud);
```
Best Practices
1. Feature Engineering
- Use meaningful embeddings for initial node features
- Consider pre-trained embeddings (e.g., from text, images)
- Ensure feature dimensions match `num_classes`, or pad appropriately
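The padding recommendation above can be handled with a one-line helper. `fit_dimension` is a hypothetical name for this sketch, not part of astraea-gnn:

```rust
/// Pad with zeros or truncate a feature vector to a target dimension,
/// e.g., to match num_classes. Hypothetical helper; not in astraea-gnn.
fn fit_dimension(mut features: Vec<f32>, dim: usize) -> Vec<f32> {
    features.resize(dim, 0.0); // pads with 0.0, or truncates if too long
    features
}

fn main() {
    println!("{:?}", fit_dimension(vec![0.9, 0.1, 0.3], 2)); // [0.9, 0.1]
    println!("{:?}", fit_dimension(vec![0.9], 3));           // [0.9, 0.0, 0.0]
}
```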
2. Graph Structure
- Ensure edges exist in both directions if relationships are symmetric
- Use edge weights to encode relationship strength
- Handle isolated nodes (no neighbors) gracefully
3. Training
| Parameter | Recommendation |
|---|---|
| `layers` | Start with 2; increase only if needed |
| `learning_rate` | 0.01 - 0.1; lower for stability |
| `epochs` | 50-200; monitor loss curve |
| `aggregation` | Mean for scale-invariance; Sum for counting |
4. Debugging
```rust
// Monitor training progress
let result = train_node_classification(&graph, &data, &config)?;

// Check if loss is decreasing
for (epoch, loss) in result.epoch_losses.iter().enumerate() {
    println!("Epoch {}: loss = {:.4}", epoch, loss);
}

// If loss is not decreasing:
// - Try a lower learning rate
// - Check feature initialization
// - Verify graph connectivity
```
5. When to Use GNNs
| Good Fit | Poor Fit |
|---|---|
| Node labels correlate with neighbors | Labels are independent of structure |
| Rich graph topology | Sparse or disconnected graphs |
| Semi-supervised (few labels) | All nodes labeled (use simpler methods) |
| Homophily (similar nodes connect) | Random connections |
AstraeaDB GNN Tutorial — Back to Wiki
See also: Vignette: Bitcoin AML with GNNs | GraphRAG with Claude Tutorial