Introduction

Graph Neural Networks (GNNs) are a class of deep learning models designed specifically for graph-structured data. Unlike traditional neural networks that operate on fixed-size inputs (images, sequences), GNNs can learn from the topology of a graph—how nodes connect to each other—and use that structure to make predictions.

AstraeaDB's astraea-gnn crate provides a lightweight, educational GNN framework that integrates directly with your graph database. Train models on your graph data without extracting it to external ML frameworks.

Graph Neural Network Architecture

     Input Graph      Message Passing                 Output
     ───────────      ───────────────                 ──────
        (A)            Layer 1       Layer 2
        │\                │             │
        │ \               ▼             ▼
        │  \         ┌─────────┐   ┌─────────┐
       (B)──(C) ───► │Aggregate│──►│Aggregate│──► Predictions
        │  /         │Neighbors│   │Neighbors│    per Node
        │ /          └─────────┘   └─────────┘
       (D)                │             │
                          ▼             ▼
                       h_i^(1)       h_i^(2)      argmax → class

Each node learns from its neighbors through iterative message passing. After K layers, each node's representation encodes K-hop neighborhood information.

What Can GNNs Do?

Key Concepts

Node Features

Each node has a feature vector (also called an embedding). In AstraeaDB, these come from the embedding field on nodes:

let node = graph.create_node(
    labels,
    properties,
    Some(vec![0.5, -0.3, 0.8, 0.1])  // Feature vector
)?;

Edge Weights

Edges can have weights that control how much influence neighbors have. These weights are the trainable parameters in AstraeaDB's GNN:

let edge = graph.create_edge(
    source, target, edge_type, properties,
    1.5,   // Weight (trainable)
    None, None
)?;

Message Passing

The core GNN operation. Each node updates its features by aggregating information from its neighbors:

Message: m_(j→i) = w_ij × h_j

Aggregate: M_i = AGG({ m_(j→i) : j ∈ N(i) })

Update: h_i^(new) = σ(M_i)

Where:

h_j is the feature vector of neighbor j
w_ij is the weight of the edge between i and j
N(i) is the set of neighbors of node i
AGG is the aggregation function (Sum, Mean, or Max)
σ is the activation function (e.g. ReLU or Sigmoid)
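As a quick numeric check of these three equations, here is a minimal standalone sketch (plain Rust, independent of the astraea-gnn API) that computes one sum-aggregated, ReLU-activated update for a node with two neighbors:

```rust
// One node update: message -> sum-aggregate -> ReLU.
fn update_node(neighbor_feats: &[Vec<f32>], weights: &[f32]) -> Vec<f32> {
    let dim = neighbor_feats[0].len();
    let mut agg = vec![0.0f32; dim];
    // Message + aggregate: M_i = Σ_j w_ij × h_j
    for (h, &w) in neighbor_feats.iter().zip(weights) {
        for d in 0..dim {
            agg[d] += w * h[d];
        }
    }
    // Update: h_i^(new) = ReLU(M_i)
    agg.iter().map(|&x| x.max(0.0)).collect()
}

fn main() {
    // Messages: 2.0 × [1, -2] = [2, -4]; 1.0 × [0.5, 1] = [0.5, 1]
    // Sum: [2.5, -3]; ReLU: [2.5, 0]
    let h_new = update_node(&[vec![1.0, -2.0], vec![0.5, 1.0]], &[2.0, 1.0]);
    println!("{:?}", h_new); // [2.5, 0.0]
}
```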

Why GNNs for Graph Data?

Traditional ML                      | Graph Neural Networks
------------------------------------|------------------------------------------
Treats each node independently      | Learns from node + neighborhood
Requires manual feature engineering | Learns structural features automatically
Fixed input size                    | Handles variable-size neighborhoods
No relationship awareness           | Explicitly models connections
Loses graph topology                | Preserves and exploits structure
Example: In fraud detection, a legitimate user connected to many fraudsters is suspicious. Traditional ML sees only the user's features; GNNs see the user and their suspicious neighborhood.

Tensor Operations

The Tensor struct is a 1D vector type with optional gradient tracking, used to hold node features.

Creating Tensors

use astraea_gnn::Tensor;

// From data (with gradient tracking)
let features = Tensor::new(vec![1.0, 2.0, 3.0], true);

// Zero-filled
let zeros = Tensor::zeros(4, false);

// Single value
let scalar = Tensor::from_scalar(5.0);

Arithmetic Operations

let a = Tensor::new(vec![1.0, 2.0, 3.0], false);
let b = Tensor::new(vec![4.0, 5.0, 6.0], false);

// Element-wise operations
let sum = a.add(&b);        // [5.0, 7.0, 9.0]
let product = a.mul(&b);    // [4.0, 10.0, 18.0]
let scaled = a.scale(2.5);  // [2.5, 5.0, 7.5]

Reduction Operations

let v = Tensor::new(vec![1.0, 2.0, 3.0], false);
let w = Tensor::new(vec![4.0, 5.0, 6.0], false);

let dot_product = v.dot(&w);  // 1*4 + 2*5 + 3*6 = 32.0
let total = v.sum();          // 6.0
let average = v.mean();       // 2.0
let l2_norm = v.norm();       // sqrt(14) ≈ 3.74

Activation Functions

let v = Tensor::new(vec![-1.0, 0.0, 2.0], false);

let relu = v.relu();      // [0.0, 0.0, 2.0] - max(0, x)
let sig = v.sigmoid();    // [0.27, 0.5, 0.88] - 1/(1+e^-x)

Gradient Management

let t = Tensor::new(vec![1.0, 2.0], true);

// Store computed gradient
t.set_grad(vec![0.1, 0.2]);

// Retrieve gradient
let grad = t.grad();  // Some([0.1, 0.2])

// Clear gradient for next iteration
t.zero_grad();

Complete Tensor API

Method                    | Description                 | Returns
--------------------------|-----------------------------|--------
new(data, requires_grad)  | Create from vector          | Tensor
zeros(len, requires_grad) | Zero-filled tensor          | Tensor
from_scalar(value)        | Single-value tensor         | Tensor
add(&other)               | Element-wise addition       | Tensor
mul(&other)               | Element-wise multiplication | Tensor
scale(s)                  | Scalar multiplication       | Tensor
dot(&other)               | Dot product                 | f32
sum()                     | Sum of all elements         | f32
mean()                    | Mean of elements            | f32
norm()                    | L2 (Euclidean) norm         | f32
relu()                    | ReLU activation             | Tensor
sigmoid()                 | Sigmoid activation          | Tensor

Message Passing

The message_passing function is the core GNN operation. It updates all node features based on their neighborhoods.

Basic Usage

use astraea_gnn::{message_passing, MessagePassingConfig, Tensor};
use std::collections::HashMap;

// Initialize node features
let mut features = HashMap::new();
features.insert(node_a, Tensor::new(vec![1.0, 0.0], false));
features.insert(node_b, Tensor::new(vec![0.0, 1.0], false));
features.insert(node_c, Tensor::new(vec![1.0, 1.0], false));

// Set edge weights
let mut weights = HashMap::new();
weights.insert(edge_ab, 1.0);
weights.insert(edge_bc, 1.0);

// Configure message passing
let config = MessagePassingConfig::default();

// Run one layer
let updated = message_passing(&graph, &features, &weights, &config)?;

How It Works

For each node i in the graph:

1. Collect neighbors (both directions):
       neighbors = graph.neighbors(i, Direction::Both)
2. Compute a message from each neighbor j:
       message_j = edge_weight[i,j] * features[j]
3. Aggregate all messages:
       Sum:  aggregated = Σ message_j
       Mean: aggregated = (Σ message_j) / count
       Max:  aggregated = max(message_j)  [element-wise]
4. Apply the activation:
       ReLU:    output = max(0, aggregated)
       Sigmoid: output = 1 / (1 + exp(-aggregated))
       None:    output = aggregated
5. Optionally L2-normalize:
       if normalize: output = output / ||output||
6. Store the result as node i's new feature vector.
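The steps above can be sketched in plain Rust, independent of the crate, as one mean-aggregation + ReLU layer over an undirected edge list (the graph/edge representation here is illustrative, not the astraea-graph API):

```rust
use std::collections::HashMap;

// One message-passing layer: mean aggregation over both edge directions, then ReLU.
fn layer(
    features: &HashMap<u32, Vec<f32>>,
    edges: &[(u32, u32, f32)], // (source, target, weight)
) -> HashMap<u32, Vec<f32>> {
    let dim = features.values().next().map_or(0, |v| v.len());
    let mut out = HashMap::new();
    for (&node, _) in features {
        let mut agg = vec![0.0f32; dim];
        let mut count = 0u32;
        // Steps 1-2: weighted messages from every neighbor (both directions)
        for &(s, t, w) in edges {
            let j = if s == node { t } else if t == node { s } else { continue };
            for d in 0..dim {
                agg[d] += w * features[&j][d];
            }
            count += 1;
        }
        // Step 3: mean aggregation
        if count > 0 {
            for x in agg.iter_mut() {
                *x /= count as f32;
            }
        }
        // Step 4: ReLU activation
        out.insert(node, agg.iter().map(|&x| x.max(0.0)).collect());
    }
    out
}

fn main() {
    let mut f = HashMap::new();
    f.insert(0, vec![1.0, 0.0]);
    f.insert(1, vec![0.0, 1.0]);
    f.insert(2, vec![1.0, 1.0]);
    let edges = [(0, 1, 1.0), (1, 2, 1.0)];
    let updated = layer(&f, &edges);
    // Node 1 averages nodes 0 and 2: ([1,0] + [1,1]) / 2 = [1.0, 0.5]
    println!("{:?}", updated[&1]);
}
```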

Aggregation Strategies

Choose how to combine messages from multiple neighbors:

Strategy          | Formula          | Best For
------------------|------------------|------------------------------------------------
Aggregation::Sum  | Σ messages       | Counting patterns, variable-importance weighting
Aggregation::Mean | (Σ messages) / n | Scale-invariant learning, GCN-style
Aggregation::Max  | element-wise max | Detecting presence of features, GraphSAGE-style
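To see how the three strategies differ, here is a small standalone sketch (plain Rust, not the crate's implementation) applying each to the same pair of messages:

```rust
// Combine a set of messages with sum, mean, or element-wise max.
fn aggregate(messages: &[Vec<f32>], strategy: &str) -> Vec<f32> {
    let dim = messages[0].len();
    match strategy {
        "sum" | "mean" => {
            let mut out = vec![0.0f32; dim];
            for m in messages {
                for d in 0..dim {
                    out[d] += m[d];
                }
            }
            if strategy == "mean" {
                for x in out.iter_mut() {
                    *x /= messages.len() as f32;
                }
            }
            out
        }
        "max" => {
            let mut out = messages[0].clone();
            for m in &messages[1..] {
                for d in 0..dim {
                    out[d] = out[d].max(m[d]);
                }
            }
            out
        }
        _ => panic!("unknown strategy"),
    }
}

fn main() {
    let msgs = vec![vec![1.0, 4.0], vec![3.0, 2.0]];
    println!("sum  = {:?}", aggregate(&msgs, "sum"));  // [4.0, 6.0]
    println!("mean = {:?}", aggregate(&msgs, "mean")); // [2.0, 3.0]
    println!("max  = {:?}", aggregate(&msgs, "max"));  // [3.0, 4.0]
}
```

Note that Sum scales with neighborhood size (useful for counting), Mean does not, and Max keeps only the strongest signal per dimension.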

Example Configurations

use astraea_gnn::{MessagePassingConfig, Aggregation, Activation};

// GCN-style: Mean aggregation with normalization
let gcn_config = MessagePassingConfig {
    aggregation: Aggregation::Mean,
    activation: Activation::ReLU,
    normalize: true,
};

// GraphSAGE-style: Mean without normalization
let sage_config = MessagePassingConfig {
    aggregation: Aggregation::Mean,
    activation: Activation::ReLU,
    normalize: false,
};

// Max-pooling style
let max_config = MessagePassingConfig {
    aggregation: Aggregation::Max,
    activation: Activation::ReLU,
    normalize: false,
};

Activation Functions

Non-linear functions applied after aggregation to add expressiveness:

Activation          | Formula          | Output Range | Use Case
--------------------|------------------|--------------|-------------------------------------------
Activation::ReLU    | max(0, x)        | [0, ∞)       | Default choice; introduces sparsity
Activation::Sigmoid | 1 / (1 + e^(-x)) | (0, 1)       | Bounded, probability-like outputs
Activation::None    | x (identity)     | (-∞, ∞)      | Linear layers; final layer before softmax

Training Loop

The train_node_classification function provides an end-to-end training pipeline.

Training Configuration

use astraea_gnn::{TrainingConfig, MessagePassingConfig};

let config = TrainingConfig {
    layers: 2,              // Number of message passing layers
    learning_rate: 0.1,     // Gradient descent step size
    epochs: 100,            // Training iterations
    message_passing: MessagePassingConfig::default(),
};

Training Data

use astraea_gnn::TrainingData;
use std::collections::HashMap;

let mut labels = HashMap::new();
labels.insert(node_a, 0);  // Class 0
labels.insert(node_b, 0);  // Class 0
labels.insert(node_c, 1);  // Class 1
labels.insert(node_d, 1);  // Class 1
// node_e is unlabeled (will be predicted)

let training_data = TrainingData {
    labels,
    num_classes: 2,
};

Running Training

use astraea_gnn::train_node_classification;

let result = train_node_classification(&graph, &training_data, &config)?;

// Inspect results
println!("Accuracy: {:.2}%", result.accuracy * 100.0);
println!("Loss history: {:?}", result.epoch_losses);
println!("Predictions: {:?}", result.final_predictions);

Training Pipeline

Training Pipeline (per epoch):

1. INITIALIZATION
   - Load node embeddings as initial features
   - Load edge weights (trainable parameters)
   - Pad/truncate features to match num_classes

2. FORWARD PASS
       for layer in 1..=num_layers:
           features = message_passing(graph, features, weights)

3. LOSS COMPUTATION
   - Treat features as logits
   - Softmax: probabilities = exp(logits) / Σ exp(logits)
   - Cross-entropy: loss = -log(prob[true_class])

4. BACKWARD PASS (numerical gradients)
       for each edge weight w:
           loss_original  = forward(w)
           loss_perturbed = forward(w + ε)
           gradient = (loss_perturbed - loss_original) / ε

5. WEIGHT UPDATE
       w = w - learning_rate × gradient

Repeat for the next epoch.

Loss Functions

AstraeaDB uses Cross-Entropy Loss for node classification:

Softmax: P(class = k) = exp(logit_k) / Σ_j exp(logit_j)

Cross-Entropy: Loss = -log(P(true_class))

Total Loss: mean over all labeled nodes

The softmax converts raw feature values (logits) into probabilities, and cross-entropy penalizes low probability assigned to the correct class.

Interpretation: A loss of 0.0 means perfect confidence in correct predictions. Higher loss means the model is uncertain or wrong.
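These two formulas are easy to verify directly. The sketch below (plain Rust, independent of the crate) computes the softmax and cross-entropy for a 2-class example; with uniform logits the loss is exactly ln(2) ≈ 0.693:

```rust
// Softmax over raw logits, then cross-entropy against the true class.
fn softmax(logits: &[f32]) -> Vec<f32> {
    // Subtract the max for numerical stability; the result is unchanged.
    let m = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|&x| (x - m).exp()).collect();
    let z: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / z).collect()
}

fn cross_entropy(logits: &[f32], true_class: usize) -> f32 {
    -softmax(logits)[true_class].ln()
}

fn main() {
    // Confident and correct -> loss near 0
    println!("{:.4}", cross_entropy(&[5.0, 0.0], 0));
    // Uniform logits over 2 classes -> loss = ln(2) ≈ 0.6931
    println!("{:.4}", cross_entropy(&[1.0, 1.0], 0));
    // Confident and wrong -> large loss
    println!("{:.4}", cross_entropy(&[5.0, 0.0], 1));
}
```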

Optimization

AstraeaDB uses numerical gradient descent with finite differences:

// For each edge weight:
const EPSILON: f32 = 1e-3;

loss_original = compute_loss(weights);
weights[edge] += EPSILON;
loss_perturbed = compute_loss(weights);
weights[edge] -= EPSILON;  // Restore

gradient = (loss_perturbed - loss_original) / EPSILON;
weights[edge] -= learning_rate * gradient;
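The same finite-difference update can be demonstrated end-to-end on a toy loss where the answer is known. This sketch (plain Rust, not the crate's optimizer) minimizes f(w) = (w − 3)², whose true gradient is 2(w − 3), so gradient descent should drive w toward 3:

```rust
// Toy loss with a known minimum at w = 3.
fn loss(w: f32) -> f32 {
    (w - 3.0) * (w - 3.0)
}

// Forward finite-difference estimate of the gradient, as in the epoch loop above.
fn numerical_grad(w: f32, eps: f32) -> f32 {
    (loss(w + eps) - loss(w)) / eps
}

fn descend(mut w: f32, lr: f32, eps: f32, steps: u32) -> f32 {
    for _ in 0..steps {
        w -= lr * numerical_grad(w, eps);
    }
    w
}

fn main() {
    let w = descend(0.0, 0.1, 1e-3, 100);
    // Converges to ~3.0 (the forward difference adds a small O(ε) bias)
    println!("w = {:.3}", w);
}
```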

Advantages & Limitations

Advantages                         | Limitations
-----------------------------------|----------------------------------
Simple to implement                | O(E) forward passes per epoch
Works with any loss function       | Approximate gradients only
No symbolic differentiation needed | Slower than backpropagation
Numerically stable                 | Best for small-to-medium graphs

Node Classification

The primary task supported by AstraeaDB's GNN framework. Given a partially labeled graph, predict labels for unlabeled nodes.

Complete Example

use astraea_gnn::{
    train_node_classification,
    TrainingConfig, TrainingData,
    MessagePassingConfig, Aggregation, Activation,
};
use astraea_graph::{Graph, InMemoryStorage};
use std::collections::HashMap;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // 1. Create graph with nodes and embeddings
    let graph = Graph::new(Box::new(InMemoryStorage::new()));

    let n0 = graph.create_node(
        vec!["User".into()],
        serde_json::json!({"name": "Alice"}),
        Some(vec![0.9, 0.1]),  // Leans toward class 0
    )?;

    let n1 = graph.create_node(
        vec!["User".into()],
        serde_json::json!({"name": "Bob"}),
        Some(vec![0.8, 0.2]),
    )?;

    let n2 = graph.create_node(
        vec!["User".into()],
        serde_json::json!({"name": "Charlie"}),
        Some(vec![0.2, 0.8]),  // Leans toward class 1
    )?;

    let n3 = graph.create_node(
        vec!["User".into()],
        serde_json::json!({"name": "Diana"}),
        Some(vec![0.5, 0.5]),  // Ambiguous - will learn from neighbors
    )?;

    // 2. Add edges (social connections)
    graph.create_edge(n0, n1, "FRIENDS".into(), serde_json::json!({}), 1.0, None, None)?;
    graph.create_edge(n1, n3, "FRIENDS".into(), serde_json::json!({}), 1.0, None, None)?;
    graph.create_edge(n2, n3, "FRIENDS".into(), serde_json::json!({}), 1.0, None, None)?;

    // 3. Prepare training data (partial labels)
    let mut labels = HashMap::new();
    labels.insert(n0, 0);  // Alice is class 0
    labels.insert(n2, 1);  // Charlie is class 1
    // n1 and n3 are unlabeled

    let training_data = TrainingData {
        labels,
        num_classes: 2,
    };

    // 4. Configure training
    let config = TrainingConfig {
        layers: 2,
        learning_rate: 0.1,
        epochs: 50,
        message_passing: MessagePassingConfig {
            aggregation: Aggregation::Mean,
            activation: Activation::ReLU,
            normalize: false,
        },
    };

    // 5. Train
    let result = train_node_classification(&graph, &training_data, &config)?;

    // 6. Results
    println!("Training Accuracy: {:.1}%", result.accuracy * 100.0);

    for (node_id, predicted_class) in &result.final_predictions {
        println!("Node {:?} predicted as class {}", node_id, predicted_class);
    }

    // Diana (n3) should be influenced by both Bob (class 0 neighbor)
    // and Charlie (class 1 neighbor)

    Ok(())
}

GNN Architectures

Configure different GNN variants by changing the message passing settings:

GCN (Graph Convolutional Network)

// Kipf & Welling 2017 style
let gcn = TrainingConfig {
    layers: 2,
    learning_rate: 0.01,
    epochs: 200,
    message_passing: MessagePassingConfig {
        aggregation: Aggregation::Mean,
        activation: Activation::ReLU,
        normalize: true,  // Key: L2 normalization
    },
};

GraphSAGE

// Hamilton et al. 2017 style
let sage = TrainingConfig {
    layers: 2,
    learning_rate: 0.1,
    epochs: 100,
    message_passing: MessagePassingConfig {
        aggregation: Aggregation::Mean,  // or Max
        activation: Activation::ReLU,
        normalize: false,
    },
};

Attention-like (Weighted Sum)

// Edge weights act like attention scores
let attention = TrainingConfig {
    layers: 3,
    learning_rate: 0.05,
    epochs: 150,
    message_passing: MessagePassingConfig {
        aggregation: Aggregation::Sum,  // Weights determine importance
        activation: Activation::Sigmoid,
        normalize: false,
    },
};

Multi-Layer GNNs

The layers parameter controls how many hops of information each node can see:

1 layer: Each node sees immediate neighbors

2 layers: Each node sees 2-hop neighborhood

K layers: Each node sees K-hop neighborhood

Over-smoothing Warning: Too many layers can cause all node representations to become similar ("over-smoothing"). 2-3 layers is usually optimal.
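Over-smoothing is easy to observe directly: repeatedly averaging over neighbors pulls all features toward a common value. The standalone sketch below (plain Rust, not the crate) applies mean aggregation (including the node itself) on a small path graph and tracks how the spread between the most extreme features shrinks:

```rust
// One mean-aggregation layer (node's own feature included) on an undirected edge list.
fn mean_layer(feats: &[f32], edges: &[(usize, usize)]) -> Vec<f32> {
    let mut out = vec![0.0f32; feats.len()];
    for i in 0..feats.len() {
        let mut sum = feats[i]; // include self
        let mut n = 1.0f32;
        for &(a, b) in edges {
            if a == i { sum += feats[b]; n += 1.0; }
            if b == i { sum += feats[a]; n += 1.0; }
        }
        out[i] = sum / n;
    }
    out
}

// Max minus min feature value: how "different" nodes still are.
fn spread(feats: &[f32]) -> f32 {
    let max = feats.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let min = feats.iter().cloned().fold(f32::INFINITY, f32::min);
    max - min
}

fn main() {
    let edges = [(0, 1), (1, 2)]; // path graph 0 - 1 - 2
    let mut feats = vec![1.0, 0.0, -1.0];
    for k in 1..=10 {
        feats = mean_layer(&feats, &edges);
        println!("after {:2} layers, spread = {:.4}", k, spread(&feats));
    }
}
```

On this graph the spread halves with every layer, so after a few layers the three nodes are nearly indistinguishable, which is exactly why 2-3 layers is usually enough.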

Comparison

// Compare different depths
for num_layers in [1, 2, 3, 4] {
    let config = TrainingConfig {
        layers: num_layers,
        learning_rate: 0.1,
        epochs: 50,
        message_passing: MessagePassingConfig::default(),
    };

    let result = train_node_classification(&graph, &data, &config)?;
    println!("{} layers: {:.1}% accuracy", num_layers, result.accuracy * 100.0);
}

Example: Social Network Analysis

Predict user interests based on their social connections.

Scenario

Users are connected in a social graph. Some users have labeled interests (Tech, Sports). Predict interests for unlabeled users.

use astraea_gnn::*;
use std::collections::HashMap;

// Create users with profile embeddings
let alice = graph.create_node(
    vec!["User".into()],
    serde_json::json!({"name": "Alice", "bio": "Software engineer"}),
    Some(vec![0.9, 0.1]),  // Tech-oriented embedding
)?;

let bob = graph.create_node(
    vec!["User".into()],
    serde_json::json!({"name": "Bob", "bio": "Basketball fan"}),
    Some(vec![0.1, 0.9]),  // Sports-oriented embedding
)?;

let charlie = graph.create_node(
    vec!["User".into()],
    serde_json::json!({"name": "Charlie"}),
    Some(vec![0.5, 0.5]),  // Unknown interest
)?;

// Social connections
graph.create_edge(alice, charlie, "FOLLOWS".into(), serde_json::json!({}), 1.0, None, None)?;
graph.create_edge(bob, charlie, "FOLLOWS".into(), serde_json::json!({}), 0.5, None, None)?;

// Labels: 0 = Tech, 1 = Sports
let mut labels = HashMap::new();
labels.insert(alice, 0);  // Tech
labels.insert(bob, 1);    // Sports

let data = TrainingData { labels, num_classes: 2 };

// Train
let result = train_node_classification(&graph, &data, &TrainingConfig {
    layers: 2,
    learning_rate: 0.1,
    epochs: 100,
    message_passing: MessagePassingConfig::default(),
})?;

// Charlie's prediction influenced by both Alice (Tech) and Bob (Sports)
// Edge weights matter: Alice's edge (1.0) vs Bob's edge (0.5)
let charlie_interest = result.final_predictions[&charlie];
println!("Charlie predicted interest: {}",
    if charlie_interest == 0 { "Tech" } else { "Sports" });

Example: Fraud Detection

Detect fraudulent accounts by learning from their connections to known fraudsters.

use astraea_gnn::*;

// Create accounts
let legit1 = graph.create_node(
    vec!["Account".into()],
    serde_json::json!({"age_days": 365, "verified": true}),
    Some(vec![0.9, 0.1]),  // Legit signal
)?;

let fraud1 = graph.create_node(
    vec!["Account".into()],
    serde_json::json!({"age_days": 7, "verified": false}),
    Some(vec![0.1, 0.9]),  // Fraud signal
)?;

let suspicious = graph.create_node(
    vec!["Account".into()],
    serde_json::json!({"age_days": 30, "verified": false}),
    Some(vec![0.5, 0.5]),  // Ambiguous
)?;

// Suspicious account transacts with known fraudster
graph.create_edge(suspicious, fraud1, "TRANSACTED".into(), serde_json::json!({}), 1.0, None, None)?;
graph.create_edge(suspicious, legit1, "TRANSACTED".into(), serde_json::json!({}), 0.1, None, None)?;

// Labels: 0 = Legitimate, 1 = Fraud
let mut labels = HashMap::new();
labels.insert(legit1, 0);
labels.insert(fraud1, 1);

let data = TrainingData { labels, num_classes: 2 };

let result = train_node_classification(&graph, &data, &TrainingConfig {
    layers: 2,
    learning_rate: 0.1,
    epochs: 100,
    message_passing: MessagePassingConfig {
        aggregation: Aggregation::Sum,  // Sum to count fraud connections
        activation: Activation::ReLU,
        normalize: false,
    },
})?;

// Suspicious account is heavily connected to fraud (weight 1.0)
// vs weakly connected to legit (weight 0.1)
let is_fraud = result.final_predictions[&suspicious] == 1;
println!("Suspicious account flagged as fraud: {}", is_fraud);

Key Insight: The GNN learns that accounts connected to fraudsters (especially with high-weight edges) are likely fraudulent themselves. This captures the "guilt by association" pattern that traditional ML misses.

Best Practices

1. Feature Engineering

2. Graph Structure

3. Training

Parameter     | Recommendation
--------------|----------------------------------------------
layers        | Start with 2; increase only if needed
learning_rate | 0.01-0.1; lower for stability
epochs        | 50-200; monitor the loss curve
aggregation   | Mean for scale-invariance; Sum for counting

4. Debugging

// Monitor training progress
let result = train_node_classification(&graph, &data, &config)?;

// Check if loss is decreasing
for (epoch, loss) in result.epoch_losses.iter().enumerate() {
    println!("Epoch {}: loss = {:.4}", epoch, loss);
}

// If loss not decreasing:
// - Try lower learning rate
// - Check feature initialization
// - Verify graph connectivity

5. When to Use GNNs

Good Fit                             | Poor Fit
-------------------------------------|-----------------------------------------
Node labels correlate with neighbors | Labels are independent of structure
Rich graph topology                  | Sparse or disconnected graphs
Semi-supervised (few labels)         | All nodes labeled (use simpler methods)
Homophily (similar nodes connect)    | Random connections

See also: Vignette: Bitcoin AML with GNNs | GraphRAG with Claude Tutorial