Chapter 12: Graph Neural Networks

Train neural networks directly on graph structure inside AstraeaDB. Learn how message passing, differentiable edge weights, and in-database training eliminate the need to export data to external ML frameworks.

Who Is This Chapter For? GNN training is an advanced feature. For most users, the combination of vector search + graph traversals + GraphRAG (Chapters 8, 6, and 11) will be sufficient. GNN training is for users who want to learn relationship patterns directly from labeled data -- for example, predicting fraud, classifying documents, or discovering hidden communities based on structural features.

12.1 What Are GNNs?

Neural Networks on Graphs

Traditional neural networks operate on fixed-size inputs: images have a grid of pixels, text has a sequence of tokens. But graphs have no fixed structure -- each node can have a different number of neighbors, and the topology varies throughout the graph. Graph Neural Networks (GNNs) are a family of neural network architectures designed specifically for this irregular, relational data.

The Key Idea: Message Passing

GNNs work through a process called message passing. In each layer of the network, every node:

  1. Gathers feature vectors from its neighbors (and their edge weights).
  2. Aggregates those features into a single summary (e.g., sum, mean, or max).
  3. Updates its own feature vector based on the aggregated neighborhood information.

After multiple rounds of message passing, each node's feature vector encodes information not just about itself but about its local neighborhood: after k layers, a node's features can depend on every node within k hops. A 3-layer GNN therefore lets each node effectively "see" all nodes within 3 hops.
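Here is a minimal plain-Python sketch of the gather, aggregate, and update steps. It assumes mean aggregation and a simple update that averages a node's own features with the neighborhood summary. AstraeaDB's real layers also apply learned weights and an activation. On a 4-node path graph where only node 3 starts non-zero, node 0 stays at 0.0 after layers 1 and 2. It first picks up the signal after layer 3, which is the 3-hop receptive field described above.

```python
from typing import Dict, List

Graph = Dict[str, List[str]]        # node -> neighbor nodes
Features = Dict[str, List[float]]   # node -> feature vector


def mean_aggregate(vectors: List[List[float]]) -> List[float]:
    """Element-wise mean of a non-empty list of equal-length vectors."""
    return [sum(column) / len(vectors) for column in zip(*vectors)]


def message_passing_layer(graph: Graph, features: Features) -> Features:
    """One synchronous round of gather -> aggregate -> update for every node."""
    updated: Features = {}
    for node, own in features.items():
        neighbors = graph.get(node, [])
        if not neighbors:
            updated[node] = list(own)  # isolated node: nothing to gather
            continue
        gathered = [features[nb] for nb in neighbors]   # 1. gather
        aggregated = mean_aggregate(gathered)           # 2. aggregate
        updated[node] = [(a + b) / 2                    # 3. update
                         for a, b in zip(own, aggregated)]
    return updated


# Path graph 0 - 1 - 2 - 3: only node 3 starts with a non-zero feature.
path: Graph = {"0": ["1"], "1": ["0", "2"], "2": ["1", "3"], "3": ["2"]}
feats: Features = {"0": [0.0], "1": [0.0], "2": [0.0], "3": [1.0]}

history: List[float] = []
for layer in range(1, 4):
    feats = message_passing_layer(path, feats)
    history.append(feats["0"][0])
    print(f"after layer {layer}: node 0 = {feats['0'][0]}")
```

The update reads only from the previous layer's features and writes into a new dictionary, so every node in a layer sees the same snapshot.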

Message Passing Illustrated

Consider a simple graph where we want to classify node C:

Round 0 (initial features):
  A [0.1, 0.9]    B [0.8, 0.2]    D [0.3, 0.7]
         \              |              /
          \             |             /
           ------  C [0.5, 0.5]  ------

Round 1 (after 1 layer of message passing):
  C gathers from neighbors: A, B, D
  C aggregates: mean([0.1, 0.9], [0.8, 0.2], [0.3, 0.7]) = [0.4, 0.6]
  C updates:    combine([0.5, 0.5], [0.4, 0.6]) = [0.45, 0.55] (new features)
                (combine here is an element-wise average; a trained GNN uses learned weights)

Round 2 (after 2 layers):
  C now incorporates 2-hop neighborhood information
  Its features encode the structure of the local graph
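You can check the Round 1 arithmetic for C in a few lines of Python. "Combine" is taken here to be an element-wise average of C's old features and the aggregate. That is an assumption made for illustration, not AstraeaDB's learned update.

```python
# Round-0 features from the illustration.
features = {
    "A": [0.1, 0.9],
    "B": [0.8, 0.2],
    "C": [0.5, 0.5],
    "D": [0.3, 0.7],
}
neighbors_of_c = ["A", "B", "D"]

# Aggregate: element-wise mean over C's neighbors.
gathered = [features[n] for n in neighbors_of_c]
aggregated = [sum(col) / len(gathered) for col in zip(*gathered)]

# Update ("combine"): element-wise average of C's own features and the aggregate.
updated = [(own + agg) / 2 for own, agg in zip(features["C"], aggregated)]

print([round(x, 2) for x in aggregated])  # [0.4, 0.6]
print([round(x, 2) for x in updated])     # [0.45, 0.55]
```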

Node Classification

The most common GNN task is node classification: predicting a label for each node from its own features and its neighborhood structure. Examples include flagging fraudulent accounts, classifying documents, and assigning nodes to communities.

12.2 Setting Up Training Data

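To train a GNN in AstraeaDB, your nodes need two things:

  1. An embedding: the feature vector that serves as the node's initial representation (Round 0 in the illustration above).
  2. A ground-truth label, stored as a node property (named "label" in this chapter), on every node used for training. Nodes without this property are the ones the trained model classifies.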

Creating Labeled Nodes

Python:

from astraeadb import AstraeaClient

with AstraeaClient("127.0.0.1", 7687) as client:
    # Create nodes with embeddings and known labels
    alice = client.create_node(
        labels=["User"],
        properties={"name": "Alice", "label": "legitimate"},
        embedding=[0.1, 0.2, 0.05, 0.3, 0.15, 0.1, 0.25, 0.08]
    )

    bob = client.create_node(
        labels=["User"],
        properties={"name": "Bob", "label": "legitimate"},
        embedding=[0.12, 0.18, 0.07, 0.28, 0.14, 0.11, 0.22, 0.09]
    )

    bad_bot = client.create_node(
        labels=["User"],
        properties={"name": "BadBot", "label": "fraudulent"},
        embedding=[0.9, 0.8, 0.85, 0.7, 0.95, 0.88, 0.75, 0.92]
    )

    shill = client.create_node(
        labels=["User"],
        properties={"name": "Shill42", "label": "fraudulent"},
        embedding=[0.88, 0.82, 0.9, 0.72, 0.91, 0.85, 0.78, 0.89]
    )

    # A node with no label -- the GNN will predict this
    unknown = client.create_node(
        labels=["User"],
        properties={"name": "NewUser99"},  # no "label" property
        embedding=[0.7, 0.6, 0.75, 0.55, 0.8, 0.65, 0.58, 0.72]
    )

    # Create edges (transaction relationships)
    client.create_edge(alice["node_id"], bob["node_id"], "TRANSACTED",
                       {"amount": 50.0})
    client.create_edge(bad_bot["node_id"], shill["node_id"], "TRANSACTED",
                       {"amount": 10000.0})
    client.create_edge(shill["node_id"], bad_bot["node_id"], "TRANSACTED",
                       {"amount": 9500.0})
    client.create_edge(unknown["node_id"], bad_bot["node_id"], "TRANSACTED",
                       {"amount": 8000.0})
    client.create_edge(unknown["node_id"], shill["node_id"], "TRANSACTED",
                       {"amount": 7500.0})

R:

source("r_client.R")

client <- AstraeaClient$new("127.0.0.1", 7687)
client$connect()

# Create labeled nodes with embeddings
alice <- client$create_node(
  labels = list("User"),
  properties = list(name = "Alice", label = "legitimate"),
  embedding = c(0.1, 0.2, 0.05, 0.3, 0.15, 0.1, 0.25, 0.08)
)

bob <- client$create_node(
  labels = list("User"),
  properties = list(name = "Bob", label = "legitimate"),
  embedding = c(0.12, 0.18, 0.07, 0.28, 0.14, 0.11, 0.22, 0.09)
)

bad_bot <- client$create_node(
  labels = list("User"),
  properties = list(name = "BadBot", label = "fraudulent"),
  embedding = c(0.9, 0.8, 0.85, 0.7, 0.95, 0.88, 0.75, 0.92)
)

shill <- client$create_node(
  labels = list("User"),
  properties = list(name = "Shill42", label = "fraudulent"),
  embedding = c(0.88, 0.82, 0.9, 0.72, 0.91, 0.85, 0.78, 0.89)
)

# Unlabeled node for prediction
unknown <- client$create_node(
  labels = list("User"),
  properties = list(name = "NewUser99"),
  embedding = c(0.7, 0.6, 0.75, 0.55, 0.8, 0.65, 0.58, 0.72)
)

# Create edges
client$create_edge(alice$node_id, bob$node_id, "TRANSACTED",
  list(amount = 50.0))
client$create_edge(bad_bot$node_id, shill$node_id, "TRANSACTED",
  list(amount = 10000.0))
client$create_edge(shill$node_id, bad_bot$node_id, "TRANSACTED",
  list(amount = 9500.0))
client$create_edge(unknown$node_id, bad_bot$node_id, "TRANSACTED",
  list(amount = 8000.0))
client$create_edge(unknown$node_id, shill$node_id, "TRANSACTED",
  list(amount = 7500.0))

client$close()

Go:

package main

import (
    "context"

    astraeadb "github.com/AstraeaDB/AstraeaDB-Official"
)

func main() {
    client := astraeadb.NewClient(astraeadb.WithAddress("127.0.0.1", 7687))
    ctx := context.Background()
    client.Connect(ctx)
    defer client.Close()

    // Create labeled nodes
    alice, _ := client.CreateNode(ctx, []string{"User"},
        map[string]any{"name": "Alice", "label": "legitimate"},
        []float32{0.1, 0.2, 0.05, 0.3, 0.15, 0.1, 0.25, 0.08})

    bob, _ := client.CreateNode(ctx, []string{"User"},
        map[string]any{"name": "Bob", "label": "legitimate"},
        []float32{0.12, 0.18, 0.07, 0.28, 0.14, 0.11, 0.22, 0.09})

    badBot, _ := client.CreateNode(ctx, []string{"User"},
        map[string]any{"name": "BadBot", "label": "fraudulent"},
        []float32{0.9, 0.8, 0.85, 0.7, 0.95, 0.88, 0.75, 0.92})

    shill, _ := client.CreateNode(ctx, []string{"User"},
        map[string]any{"name": "Shill42", "label": "fraudulent"},
        []float32{0.88, 0.82, 0.9, 0.72, 0.91, 0.85, 0.78, 0.89})

    unknown, _ := client.CreateNode(ctx, []string{"User"},
        map[string]any{"name": "NewUser99"},
        []float32{0.7, 0.6, 0.75, 0.55, 0.8, 0.65, 0.58, 0.72})

    // Create edges
    client.CreateEdge(ctx, alice.NodeID, bob.NodeID, "TRANSACTED",
        map[string]any{"amount": 50.0})
    client.CreateEdge(ctx, badBot.NodeID, shill.NodeID, "TRANSACTED",
        map[string]any{"amount": 10000.0})
    client.CreateEdge(ctx, shill.NodeID, badBot.NodeID, "TRANSACTED",
        map[string]any{"amount": 9500.0})
    client.CreateEdge(ctx, unknown.NodeID, badBot.NodeID, "TRANSACTED",
        map[string]any{"amount": 8000.0})
    client.CreateEdge(ctx, unknown.NodeID, shill.NodeID, "TRANSACTED",
        map[string]any{"amount": 7500.0})
}

Java:

import com.astraeadb.unified.UnifiedClient;
import java.util.List;
import java.util.Map;

try (var client = UnifiedClient.builder()
        .host("127.0.0.1").port(7687).build()) {
    client.connect();

    // Create labeled nodes with embeddings
    var alice = client.createNode(
        List.of("User"),
        Map.of("name", "Alice", "label", "legitimate"),
        new float[]{0.1f, 0.2f, 0.05f, 0.3f, 0.15f, 0.1f, 0.25f, 0.08f});

    var bob = client.createNode(
        List.of("User"),
        Map.of("name", "Bob", "label", "legitimate"),
        new float[]{0.12f, 0.18f, 0.07f, 0.28f, 0.14f, 0.11f, 0.22f, 0.09f});

    var badBot = client.createNode(
        List.of("User"),
        Map.of("name", "BadBot", "label", "fraudulent"),
        new float[]{0.9f, 0.8f, 0.85f, 0.7f, 0.95f, 0.88f, 0.75f, 0.92f});

    var shill = client.createNode(
        List.of("User"),
        Map.of("name", "Shill42", "label", "fraudulent"),
        new float[]{0.88f, 0.82f, 0.9f, 0.72f, 0.91f, 0.85f, 0.78f, 0.89f});

    var unknown = client.createNode(
        List.of("User"),
        Map.of("name", "NewUser99"),
        new float[]{0.7f, 0.6f, 0.75f, 0.55f, 0.8f, 0.65f, 0.58f, 0.72f});

    // Create edges
    client.createEdge(alice.getNodeId(), bob.getNodeId(),
        "TRANSACTED", Map.of("amount", 50.0));
    client.createEdge(badBot.getNodeId(), shill.getNodeId(),
        "TRANSACTED", Map.of("amount", 10000.0));
    client.createEdge(shill.getNodeId(), badBot.getNodeId(),
        "TRANSACTED", Map.of("amount", 9500.0));
    client.createEdge(unknown.getNodeId(), badBot.getNodeId(),
        "TRANSACTED", Map.of("amount", 8000.0));
    client.createEdge(unknown.getNodeId(), shill.getNodeId(),
        "TRANSACTED", Map.of("amount", 7500.0));
}

Understanding the Data

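In this fraud detection scenario:

  - Alice and Bob are labeled legitimate. Their embeddings have low values, and they share a single small transaction (50.0).
  - BadBot and Shill42 are labeled fraudulent. Their embeddings have high values, and they send large amounts back and forth (10000.0 and 9500.0), forming a circular pattern.
  - NewUser99 has no label. Its embedding is much closer to the fraudulent accounts, and it sends large amounts (8000.0 and 7500.0) to both of them. This is the node the GNN must classify.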

Message Passing Configuration

AstraeaDB supports several aggregation functions and activation functions for the message passing layers:

Setting     | Options        | Description
------------|----------------|------------
Aggregation | Sum, Mean, Max | How neighbor features are combined. Sum preserves magnitude; Mean normalizes by degree; Max takes the strongest signal.
Activation  | ReLU, Sigmoid  | Non-linearity applied after aggregation. ReLU is the default for hidden layers; Sigmoid is used for the output layer in binary classification.
Layers      | 1-5 (typical)  | Number of message-passing rounds. More layers give a larger receptive field but increase the risk of over-smoothing.
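Each option in the table is a simple element-wise function. The sketch below is illustrative plain Python, not AstraeaDB's internal code. It applies the three aggregations to node C's neighbors from the illustration in 12.1. Sum grows with the number of neighbors, Mean does not, and Max keeps the largest value in each dimension.

```python
import math
from typing import Callable, Dict, List

Vector = List[float]


def agg_sum(vectors: List[Vector]) -> Vector:
    return [sum(col) for col in zip(*vectors)]                 # grows with degree


def agg_mean(vectors: List[Vector]) -> Vector:
    return [sum(col) / len(vectors) for col in zip(*vectors)]  # degree-normalized


def agg_max(vectors: List[Vector]) -> Vector:
    return [max(col) for col in zip(*vectors)]                 # strongest value per dimension


def relu(x: float) -> float:
    return max(0.0, x)


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


AGGREGATIONS: Dict[str, Callable[[List[Vector]], Vector]] = {
    "Sum": agg_sum,
    "Mean": agg_mean,
    "Max": agg_max,
}

# Features of node C's neighbors (A, B, D) from the illustration in 12.1.
neighbors = [[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]]

for name, aggregate in AGGREGATIONS.items():
    print(name, [round(v, 2) for v in aggregate(neighbors)])
# Sum [1.2, 1.8]
# Mean [0.4, 0.6]
# Max [0.8, 0.9]

print(relu(-0.3), relu(0.7), sigmoid(0.0))  # 0.0 0.7 0.5
```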
Note: Over-Smoothing. With too many message-passing layers, all node features converge toward similar values, a phenomenon called over-smoothing, and nodes become hard to tell apart. For most graphs, 2-3 layers work best: start with 2 and add layers only if accuracy on held-out labeled nodes improves.
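Over-smoothing is easy to reproduce. This sketch assumes the simplest possible layer: each node takes the unweighted mean of itself and its neighbors, with no learned weights or activation. It applies that layer 30 times to a 4-node path graph. The spread between the largest and smallest feature never grows and shrinks toward zero. Every node converges to the same value (0.2 here), so the nodes become indistinguishable.

```python
from typing import Dict, List

# Path graph 0 - 1 - 2 - 3 with one scalar feature per node.
neighbors: Dict[int, List[int]] = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2]}
x: Dict[int, float] = {0: 0.0, 1: 0.0, 2: 0.0, 3: 1.0}


def smooth(x: Dict[int, float]) -> Dict[int, float]:
    """One unweighted layer: each node takes the mean of itself and its neighbors."""
    return {
        node: (x[node] + sum(x[nb] for nb in nbs)) / (1 + len(nbs))
        for node, nbs in neighbors.items()
    }


spreads = [max(x.values()) - min(x.values())]
for layer in range(30):
    x = smooth(x)
    spreads.append(max(x.values()) - min(x.values()))

print(f"spread after  2 layers: {spreads[2]:.3f}")
print(f"spread after 30 layers: {spreads[30]:.6f}")
print({node: round(v, 3) for node, v in x.items()})
```

Each new value is an average of old values, so the largest feature can only fall and the smallest can only rise. That is why the spread never increases.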

12.3 Training a Model

AstraeaDB's GNN training runs entirely inside the database. There is no need to export your graph to PyTorch Geometric, DGL, or any external framework. This eliminates the data engineering overhead of ETL pipelines and ensures the model always trains on the most current data.

The Training Loop

Each training epoch follows these four steps:

  1. Forward pass (message passing): For each layer, every node aggregates its neighbors' features (weighted by edge weights), applies the activation function, and produces updated features. After all layers, each node's features encode its neighborhood structure.
  2. Loss computation: For nodes with known labels, the model's predictions are compared against the ground truth. AstraeaDB uses cross-entropy loss for classification tasks.
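  3. Gradient estimation: AstraeaDB uses numerical gradients (central finite differences) rather than automatic differentiation. For each trainable parameter, the loss is evaluated at param + epsilon and param - epsilon, and the gradient is approximated as (loss+ - loss-) / (2 * epsilon). This requires two loss evaluations (full forward passes) per trainable parameter per epoch, so training cost grows with the number of edges and layer weights.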
  4. Weight update: All trainable parameters (edge weights, layer weights) are updated using gradient descent: param = param - learning_rate * gradient.
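Steps 3 and 4 can be sketched in plain Python. The helper names below are hypothetical. A toy quadratic loss with a known minimum at [3.0, -1.0] stands in for the GNN's cross-entropy loss. Note that numerical_gradient evaluates the loss twice per parameter, which is why finite differences get expensive as the number of trainable edge weights grows.

```python
from typing import Callable, List

Params = List[float]


def numerical_gradient(loss_fn: Callable[[Params], float], params: Params,
                       epsilon: float = 1e-5) -> Params:
    """Central finite differences: (loss(p + eps) - loss(p - eps)) / (2 * eps)."""
    grads: Params = []
    for i in range(len(params)):
        plus, minus = list(params), list(params)
        plus[i] += epsilon
        minus[i] -= epsilon
        grads.append((loss_fn(plus) - loss_fn(minus)) / (2 * epsilon))
    return grads


def gradient_descent_step(params: Params, grads: Params,
                          learning_rate: float) -> Params:
    """param = param - learning_rate * gradient, for every parameter."""
    return [p - learning_rate * g for p, g in zip(params, grads)]


def toy_loss(p: Params) -> float:
    """Stand-in loss with its minimum at [3.0, -1.0]."""
    return (p[0] - 3.0) ** 2 + (p[1] + 1.0) ** 2


params: Params = [0.0, 0.0]
for epoch in range(200):
    grads = numerical_gradient(toy_loss, params)
    params = gradient_descent_step(params, grads, learning_rate=0.1)

print([round(p, 3) for p in params])  # [3.0, -1.0]
```

For a quadratic loss, the central difference is exact up to floating-point rounding, so the parameters converge to the minimum. On a real GNN loss the estimate is approximate, and epsilon trades truncation error (too large) against rounding error (too small).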

Conceptual Training Configuration

GNN training is currently configured via the Rust API, with client wrapper support planned for future releases.

use astraeadb::gnn::{GnnConfig, Aggregation, Activation};

let config = GnnConfig {
    // Network architecture
    num_layers:    2,
    hidden_dim:    16,
    output_classes: 2,          // "legitimate" vs "fraudulent"
    aggregation:   Aggregation::Mean,
    activation:    Activation::ReLU,

    // Training parameters
    epochs:        50,
    learning_rate: 0.01,
    label_key:     "label".to_string(),    // property containing ground truth

    // Label mapping
    label_map: vec![
        ("legitimate".to_string(), 0),
        ("fraudulent".to_string(), 1),
    ],

    // Gradient estimation
    epsilon: 1e-5,
};

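// Run training. `graph` is the AstraeaDB graph instance; the `?` operator
// assumes this code runs inside a function that returns a Result.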
let result = graph.train_gnn(config)?;

// Inspect training progress
for (epoch, loss) in result.epoch_losses.iter().enumerate() {
    println!("Epoch {}: loss = {:.4}", epoch, loss);
}
println!("Final accuracy: {:.1}%", result.accuracy * 100.0);
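The label_key and label_map fields determine which nodes supply ground truth and how each string label becomes a class index. The plain-Python sketch below illustrates that split under an assumption the chapter implies but does not spell out: nodes without the label_key property are left out of the loss and are the ones predicted afterwards. The node IDs nd-1 through nd-4 are hypothetical; only nd-5 appears in this chapter's output.

```python
from typing import Dict, List, Tuple

LABEL_KEY = "label"                               # mirrors label_key
LABEL_MAP = {"legitimate": 0, "fraudulent": 1}    # mirrors label_map

# Hypothetical stand-ins for the five User nodes from 12.2.
nodes: List[Dict[str, str]] = [
    {"node_id": "nd-1", "name": "Alice", "label": "legitimate"},
    {"node_id": "nd-2", "name": "Bob", "label": "legitimate"},
    {"node_id": "nd-3", "name": "BadBot", "label": "fraudulent"},
    {"node_id": "nd-4", "name": "Shill42", "label": "fraudulent"},
    {"node_id": "nd-5", "name": "NewUser99"},      # no label property
]


def split_by_label(nodes: List[Dict[str, str]]) -> Tuple[Dict[str, int], List[str]]:
    """Return (node_id -> class_id for training, node_ids to predict)."""
    targets: Dict[str, int] = {}
    to_predict: List[str] = []
    for node in nodes:
        raw_label = node.get(LABEL_KEY)
        if raw_label is None:
            to_predict.append(node["node_id"])
        else:
            targets[node["node_id"]] = LABEL_MAP[raw_label]
    return targets, to_predict


targets, to_predict = split_by_label(nodes)
print(targets)      # {'nd-1': 0, 'nd-2': 0, 'nd-3': 1, 'nd-4': 1}
print(to_predict)   # ['nd-5']
```

An unexpected label string (one not in LABEL_MAP) raises a KeyError here. That is usually preferable to silently training on a wrong class.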

Expected Training Output

During training, the loss should decrease as the model learns. An abridged log looks like this:

Epoch  0: loss = 0.6931  (random initialization, ~coin flip)
Epoch  5: loss = 0.5412
Epoch 10: loss = 0.3876
Epoch 20: loss = 0.1923
Epoch 30: loss = 0.0847
Epoch 50: loss = 0.0234
Final accuracy: 100.0% (on labeled nodes)

The loss starts near 0.693, which is ln 2. That is the cross-entropy of a uniform 50/50 prediction, the same as random guessing between the two classes. As training proceeds, the loss drops as the model learns to distinguish legitimate from fraudulent patterns.
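Cross-entropy for a single node is the negative log of the probability the model assigns to the true class. Averaging over all labeled nodes is omitted here. A short Python check of the numbers above:

```python
import math
from typing import List


def cross_entropy(probs: List[float], true_class: int) -> float:
    """Cross-entropy for one node: negative log-probability of the true class."""
    return -math.log(probs[true_class])


# Untrained model: a 50/50 guess costs exactly ln 2, whichever class is true.
uniform_loss = cross_entropy([0.5, 0.5], true_class=1)
print(f"{uniform_loss:.4f}  {math.log(2):.4f}")            # 0.6931  0.6931

# A confident, correct prediction pushes the loss toward 0 ...
print(f"{cross_entropy([0.02, 0.98], true_class=1):.4f}")  # 0.0202
# ... while a confident, wrong prediction is penalized heavily.
print(f"{cross_entropy([0.98, 0.02], true_class=1):.4f}")  # 3.9120
```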

Making Predictions

After training, use the trained model to classify unlabeled nodes:

// Get predictions for all unlabeled nodes
let predictions = graph.predict_gnn()?;

for (node_id, prediction) in predictions {
    let label = if prediction.class_id == 0 { "legitimate" } else { "fraudulent" };
    println!("Node {}: {} (confidence: {:.1}%)",
        node_id, label, prediction.confidence * 100.0);
}

// Output (illustrative; nd-5 is the node ID of NewUser99):
// Node nd-5: fraudulent (confidence: 94.2%)
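This chapter does not specify how confidence is computed. A common convention, assumed here, is to apply a softmax to the output layer's raw scores and report the probability of the winning class. The helper names and scores below are hypothetical. CLASS_NAMES is simply the inverse of the label_map from the training configuration.

```python
import math
from typing import Dict, List, Tuple

# Inverse of the label_map in the training configuration.
CLASS_NAMES: Dict[int, str] = {0: "legitimate", 1: "fraudulent"}


def softmax(scores: List[float]) -> List[float]:
    """Turn raw output scores into probabilities that sum to 1."""
    top = max(scores)                  # subtract the max for numerical stability
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def to_prediction(scores: List[float]) -> Tuple[str, float]:
    """Pick the most probable class and report its probability as confidence."""
    probs = softmax(scores)
    class_id = max(range(len(probs)), key=lambda i: probs[i])
    return CLASS_NAMES[class_id], probs[class_id]


# Hypothetical output-layer scores for one unlabeled node.
label, confidence = to_prediction([0.0, 2.0])
print(f"{label} (confidence: {confidence:.1%})")
```

With scores [0.0, 2.0] this prints fraudulent (confidence: 88.1%). Subtracting the maximum score before exponentiating does not change the probabilities, but it prevents overflow for large scores.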
Warning: Training Data Quality. GNN accuracy depends heavily on the quality and quantity of labeled data, and a few mislabeled nodes can significantly degrade performance. Aim for at least 20-30 labeled examples per class for meaningful results. The toy example in this chapter has only two per class, so its 100% accuracy on labeled nodes says little about how well the model generalizes. For production fraud detection, hundreds or thousands of labeled examples are recommended.

12.4 Differentiable Edge Weights

AstraeaDB's Unique Innovation

In most graph databases, edge weights are static values set at creation time. In AstraeaDB, edge weights are differentiable tensors -- they are trainable parameters that the GNN can update during the learning process.

This is a fundamental architectural decision rooted in AstraeaDB's "Vector-Property Graph" data model (described in the project architecture). Edges do not just carry metadata; they carry learnable parameters that encode the strength and nature of relationships.

How It Works

  1. Initialization: All edges start with a uniform weight (typically 1.0). At this point, the GNN treats all connections equally.
  2. Training: During each epoch, the gradient of the loss with respect to each edge weight is computed, and each weight is moved in the direction that reduces the loss. In effect, edges whose neighbors help the model predict the correct label are strengthened, and edges that push predictions toward the wrong label are weakened.
  3. After training: Edge weights encode learned relationship strengths. An edge with a high weight (e.g., 2.3) indicates a strong, predictive connection. An edge with a low weight (e.g., 0.1) indicates a weak or irrelevant connection.
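Here is a minimal sketch of how an edge weight moves during training. It rests on simplifying assumptions that do not describe AstraeaDB's actual model: a single labeled node T (fraudulent), two incoming edges, weighted-sum aggregation, a sigmoid output, binary cross-entropy, and the central finite-difference gradient from 12.3. Every name here is hypothetical. Both weights start at 1.0. The edge from the fraud-like neighbor F ends above 1.0, and the edge from the legitimate-like neighbor L ends below it.

```python
import math
from typing import Dict, List

# Hypothetical toy graph: node T is labeled fraudulent (target 1.0) and has
# two incoming edges, from a fraud-like neighbor F and a legit-like neighbor L.
neighbor_feature: Dict[str, float] = {"F": 1.0, "L": -1.0}
edge_sources: List[str] = ["F", "L"]   # edge i runs edge_sources[i] -> T
TARGET = 1.0


def loss(weights: List[float]) -> float:
    """Binary cross-entropy for T after one edge-weighted aggregation."""
    score = sum(w * neighbor_feature[src] for w, src in zip(weights, edge_sources))
    p_fraud = 1.0 / (1.0 + math.exp(-score))
    return -(TARGET * math.log(p_fraud) + (1 - TARGET) * math.log(1 - p_fraud))


weights = [1.0, 1.0]                   # every edge starts at 1.0
epsilon, learning_rate = 1e-5, 0.05
initial_loss = loss(weights)

for epoch in range(20):
    grads = []
    for i in range(len(weights)):      # central finite difference per edge
        plus, minus = list(weights), list(weights)
        plus[i] += epsilon
        minus[i] -= epsilon
        grads.append((loss(plus) - loss(minus)) / (2 * epsilon))
    weights = [w - learning_rate * g for w, g in zip(weights, grads)]

print(f"loss: {initial_loss:.4f} -> {loss(weights):.4f}")
for src, w in zip(edge_sources, weights):
    print(f"edge {src} -> T: weight {w:.3f}")
```

With more epochs or a larger learning rate the two weights keep drifting apart. Nothing in this toy model stops the L edge from eventually going negative.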

What This Means in Practice

Consider the fraud detection example. Before training, all TRANSACTED edges have a weight of 1.0. After training:

Edge weights after training:
Alice     --TRANSACTED--> Bob      weight: 0.15  (low -- normal transaction, not predictive)
BadBot    --TRANSACTED--> Shill42  weight: 2.41  (high -- circular fraud pattern)
Shill42   --TRANSACTED--> BadBot   weight: 2.38  (high -- circular fraud pattern)
NewUser99 --TRANSACTED--> BadBot   weight: 1.87  (high -- connected to known fraud)
NewUser99 --TRANSACTED--> Shill42  weight: 1.92  (high -- connected to known fraud)

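The model has learned that:

  - Circular transactions between accounts, like those between BadBot and Shill42, are a strong fraud signal.
  - Connections to known fraudulent accounts are predictive, which is why both of NewUser99's edges gained weight.
  - An ordinary transaction between legitimate users, like the one between Alice and Bob, carries little signal, so its weight dropped.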

Querying Learned Weights

After training, the learned edge weights persist in the graph and can be queried like any other edge property:

Python:

with AstraeaClient("127.0.0.1", 7687) as client:
    # Query edges with high learned weights (suspicious connections)
    edges = client.query("""
        MATCH (a:User)-[t:TRANSACTED]->(b:User)
        WHERE t.weight > 1.5
        RETURN a.name, b.name, t.weight, t.amount
        ORDER BY t.weight DESC
    """)

    print("High-weight edges (suspicious connections):")
    for row in edges:
        print(f"  {row['a.name']} -> {row['b.name']}: "
              f"weight={row['t.weight']:.2f}, amount={row['t.amount']}")

R:

edges <- client$query("
  MATCH (a:User)-[t:TRANSACTED]->(b:User)
  WHERE t.weight > 1.5
  RETURN a.name, b.name, t.weight, t.amount
  ORDER BY t.weight DESC
")

cat("High-weight edges (suspicious connections):\n")
for (row in edges) {
  cat(sprintf("  %s -> %s: weight=%.2f, amount=%s\n",
    row$"a.name", row$"b.name", row$"t.weight", row$"t.amount"))
}

Go:

edges, _ := client.Query(ctx, `
    MATCH (a:User)-[t:TRANSACTED]->(b:User)
    WHERE t.weight > 1.5
    RETURN a.name, b.name, t.weight, t.amount
    ORDER BY t.weight DESC
`)

fmt.Println("High-weight edges (suspicious connections):")
for _, row := range edges {
    fmt.Printf("  %s -> %s: weight=%.2f, amount=%v\n",
        row["a.name"], row["b.name"], row["t.weight"], row["t.amount"])
}

Java:

var edges = client.query("""
    MATCH (a:User)-[t:TRANSACTED]->(b:User)
    WHERE t.weight > 1.5
    RETURN a.name, b.name, t.weight, t.amount
    ORDER BY t.weight DESC
""");

System.out.println("High-weight edges (suspicious connections):");
for (var row : edges) {
    System.out.printf("  %s -> %s: weight=%.2f, amount=%s%n",
        row.get("a.name"), row.get("b.name"),
        row.get("t.weight"), row.get("t.amount"));
}

Use Cases for Differentiable Edges

Domain           | Initial State                       | After Training
-----------------|-------------------------------------|---------------
Fraud Detection  | All transaction edges weight = 1.0  | Suspicious circular transactions get high weights; normal transactions get low weights
Recommendation   | All user-item edges weight = 1.0    | Strong preferences amplified; weak or incidental interactions dampened
Knowledge Graphs | All "related_to" edges weight = 1.0 | Causally important relationships get high weights; coincidental co-occurrences get low weights
Network Security | All connection edges weight = 1.0   | Lateral movement paths get high weights; routine traffic gets low weights
Note: Persistence. Learned edge weights are persisted to disk as part of the normal graph storage. They survive server restarts and can be used in subsequent queries, GraphRAG pipelines, or additional training rounds without re-training from scratch.

The Bigger Picture

Differentiable edge weights transform AstraeaDB from a passive data store into an active learning system. The database does not just store your graph -- it learns from it. By running the training loop inside the database, you eliminate the traditional ML pipeline of "export data, train model, import predictions" and replace it with a single train_gnn call that updates the graph in place.

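This is the realization of the "Differentiable Traversal" concept from AstraeaDB's architecture: the query execution plan itself is treated as differentiable, and gradient-based updates (currently estimated with finite differences rather than backpropagation, as described in 12.3) adjust edge weights directly inside the database.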
