πŸ”€ MultiHeadGraphFeaturePreprocessor

🟑 Intermediate βœ… Stable πŸ”₯ Popular

🎯 Overview

The MultiHeadGraphFeaturePreprocessor treats each feature as a node in a graph and applies multi-head self-attention to capture and aggregate complex interactions among those features. Because each head attends independently, the layer learns multiple relational views of the same feature set, which can significantly boost performance on tabular data.

This makes it a useful preprocessing step for tabular problems where feature relationships are too complex to capture with manual feature engineering.

πŸ” How It Works

The MultiHeadGraphFeaturePreprocessor transforms its input in six steps (a code sketch follows the diagram below):

  1. Feature Embedding: Projects each scalar input into an embedding of dimension embed_dim
  2. Multi-Head Split: Splits the embedding into num_heads heads
  3. Query-Key-Value: Computes queries, keys, and values for each head
  4. Scaled Dot-Product Attention: Calculates attention across the feature dimension
  5. Head Concatenation: Concatenates the head outputs
  6. Output Projection: Projects back to the original dimension and adds a residual connection
graph TD
    A[Input Features] --> B[Feature Embedding]
    B --> C[Multi-Head Split]
    C --> D[Query-Key-Value]
    D --> E[Scaled Dot-Product Attention]
    E --> F[Head Concatenation]
    F --> G[Output Projection]
    A --> H[Residual Connection]
    G --> H
    H --> I[Transformed Features]

    style A fill:#e6f3ff,stroke:#4a86e8
    style I fill:#e8f5e9,stroke:#66bb6a
    style B fill:#fff9e6,stroke:#ffb74d
    style C fill:#f3e5f5,stroke:#9c27b0
    style D fill:#e1f5fe,stroke:#03a9f4
    style E fill:#fff3e0,stroke:#ff9800
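
To make these steps concrete, here is a minimal functional sketch of the transformation using keras.ops. It is illustrative only: the helper function and its internals are hypothetical, it creates fresh Dense layers on every call, and the actual kerasfactory implementation may differ in details such as dropout placement and the final projection.

import keras
from keras import ops

def multi_head_graph_preprocess(x, embed_dim=16, num_heads=4):
    # Hypothetical sketch of the six steps; not the library's actual code.
    batch_size, num_features = x.shape
    depth = embed_dim // num_heads  # per-head dimension

    # 1. Feature embedding: each scalar feature -> embed_dim vector.
    h = keras.layers.Dense(embed_dim)(ops.expand_dims(x, -1))    # (B, F, E)

    # 2.-3. Compute Q, K, V and split into heads: (B, H, F, D).
    def split_heads(t):
        t = ops.reshape(t, (batch_size, num_features, num_heads, depth))
        return ops.transpose(t, (0, 2, 1, 3))

    q = split_heads(keras.layers.Dense(embed_dim)(h))
    k = split_heads(keras.layers.Dense(embed_dim)(h))
    v = split_heads(keras.layers.Dense(embed_dim)(h))

    # 4. Scaled dot-product attention across the feature axis.
    scores = ops.matmul(q, ops.transpose(k, (0, 1, 3, 2))) / depth**0.5
    context = ops.matmul(ops.softmax(scores, axis=-1), v)        # (B, H, F, D)

    # 5. Concatenate heads: back to (B, F, E).
    context = ops.reshape(ops.transpose(context, (0, 2, 1, 3)),
                          (batch_size, num_features, embed_dim))

    # 6. Project each feature back to a scalar and add the residual input.
    return x + ops.squeeze(keras.layers.Dense(1)(context), -1)   # (B, F)

y = multi_head_graph_preprocess(keras.random.normal((32, 10)))
print(y.shape)  # (32, 10)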

πŸ’‘ Why Use This Layer?

| Challenge | Traditional Approach | MultiHeadGraphFeaturePreprocessor's Solution |
| --- | --- | --- |
| Feature Interactions | Manual feature engineering | 🎯 Automatic learning of complex feature interactions |
| Multiple Views | Single perspective | ⚑ Multi-head attention for multiple relational views |
| Graph Structure | No graph structure | 🧠 Graph-based feature preprocessing |
| Complex Relationships | Limited relationship modeling | πŸ”— Sophisticated relationship learning |

πŸ“Š Use Cases

  • Tabular Data: Complex feature relationship preprocessing
  • Graph Neural Networks: Graph-based preprocessing for tabular data
  • Feature Engineering: Automatic feature interaction learning
  • Multi-Head Attention: Multiple relational views of features
  • Complex Patterns: Capturing complex feature relationships

πŸš€ Quick Start

Basic Usage

import keras
from kerasfactory.layers import MultiHeadGraphFeaturePreprocessor

# Create sample input data
batch_size, num_features = 32, 10
x = keras.random.normal((batch_size, num_features))

# Apply multi-head graph feature preprocessor
graph_preproc = MultiHeadGraphFeaturePreprocessor(embed_dim=16, num_heads=4)
output = graph_preproc(x, training=True)

print(f"Input shape: {x.shape}")           # (32, 10)
print(f"Output shape: {output.shape}")     # (32, 10)

In a Sequential Model

import keras
from kerasfactory.layers import MultiHeadGraphFeaturePreprocessor

model = keras.Sequential([
    keras.layers.Dense(32, activation='relu'),
    MultiHeadGraphFeaturePreprocessor(embed_dim=16, num_heads=4),
    keras.layers.Dense(16, activation='relu'),
    MultiHeadGraphFeaturePreprocessor(embed_dim=8, num_heads=2),
    keras.layers.Dense(1, activation='sigmoid')
])

model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

In a Functional Model

import keras
from kerasfactory.layers import MultiHeadGraphFeaturePreprocessor

# Define inputs
inputs = keras.Input(shape=(20,))  # 20 features

# Apply multi-head graph feature preprocessor
x = MultiHeadGraphFeaturePreprocessor(embed_dim=16, num_heads=4)(inputs)

# Continue processing
x = keras.layers.Dense(32, activation='relu')(x)
x = MultiHeadGraphFeaturePreprocessor(embed_dim=16, num_heads=4)(x)
x = keras.layers.Dense(16, activation='relu')(x)
outputs = keras.layers.Dense(1, activation='sigmoid')(x)

model = keras.Model(inputs, outputs)

Advanced Configuration

# Advanced configuration with multiple graph preprocessors
import keras
from kerasfactory.layers import MultiHeadGraphFeaturePreprocessor

def create_multi_head_graph_network():
    inputs = keras.Input(shape=(25,))  # 25 features

    # Multiple graph preprocessors with different configurations
    x = MultiHeadGraphFeaturePreprocessor(
        embed_dim=24,
        num_heads=6,
        dropout_rate=0.1
    )(inputs)

    x = keras.layers.Dense(48, activation='relu')(x)
    x = keras.layers.BatchNormalization()(x)

    x = MultiHeadGraphFeaturePreprocessor(
        embed_dim=20,
        num_heads=5,
        dropout_rate=0.1
    )(x)

    x = keras.layers.Dense(32, activation='relu')(x)
    x = keras.layers.Dropout(0.2)(x)

    # Multi-task output
    classification = keras.layers.Dense(3, activation='softmax', name='classification')(x)
    regression = keras.layers.Dense(1, name='regression')(x)

    return keras.Model(inputs, [classification, regression])

model = create_multi_head_graph_network()
model.compile(
    optimizer='adam',
    loss={'classification': 'categorical_crossentropy', 'regression': 'mse'},
    loss_weights={'classification': 1.0, 'regression': 0.5}
)

πŸ“– API Reference

kerasfactory.layers.MultiHeadGraphFeaturePreprocessor

This module implements a MultiHeadGraphFeaturePreprocessor layer that treats features as nodes in a graph and learns multiple "views" (heads) of the feature interactions via self-attention. This approach is useful for tabular data where complex feature relationships need to be captured.

Classes

MultiHeadGraphFeaturePreprocessor
MultiHeadGraphFeaturePreprocessor(
    embed_dim: int = 16,
    num_heads: int = 4,
    dropout_rate: float = 0.0,
    name: str | None = None,
    **kwargs: Any
)

Multi-head graph-based feature preprocessor for tabular data.

This layer treats each feature as a node and applies multi-head self-attention to capture and aggregate complex interactions among features. The process is:

  1. Project each scalar input into an embedding of dimension embed_dim.
  2. Split the embedding into num_heads heads.
  3. For each head, compute queries, keys, and values and calculate scaled dot-product attention across the feature dimension.
  4. Concatenate the head outputs, project back to the original feature dimension, and add a residual connection.

This mechanism allows the network to learn multiple relational views among features, which can significantly boost performance on tabular data.
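
For reference, each head in step 3 applies standard scaled dot-product attention over the feature axis, where $d$ is the per-head depth (embed_dim / num_heads):

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d}}\right)V$$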

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| embed_dim | int | Dimension of the feature embeddings. | 16 |
| num_heads | int | Number of attention heads. | 4 |
| dropout_rate | float | Dropout rate applied to attention weights. | 0.0 |
| name | str \| None | Optional name for the layer. | None |
Input shape

2D tensor with shape: (batch_size, num_features)

Output shape

2D tensor with shape: (batch_size, num_features) (same as input)

Example
import keras
from kerasfactory.layers import MultiHeadGraphFeaturePreprocessor

# Tabular data with 10 features
x = keras.random.normal((32, 10))

# Create the layer with 16-dim embeddings and 4 attention heads
graph_preproc = MultiHeadGraphFeaturePreprocessor(embed_dim=16, num_heads=4)
y = graph_preproc(x, training=True)
print("Output shape:", y.shape)  # Expected: (32, 10)

Initialize the MultiHeadGraphFeaturePreprocessor.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| embed_dim | int | Embedding dimension. | 16 |
| num_heads | int | Number of attention heads. | 4 |
| dropout_rate | float | Dropout rate. | 0.0 |
| name | str \| None | Name of the layer. | None |
| **kwargs | Any | Additional keyword arguments. | {} |
Source code in kerasfactory/layers/MultiHeadGraphFeaturePreprocessor.py
def __init__(
    self,
    embed_dim: int = 16,
    num_heads: int = 4,
    dropout_rate: float = 0.0,
    name: str | None = None,
    **kwargs: Any,
) -> None:
    """Initialize the MultiHeadGraphFeaturePreprocessor.

    Args:
        embed_dim: Embedding dimension.
        num_heads: Number of attention heads.
        dropout_rate: Dropout rate.
        name: Name of the layer.
        **kwargs: Additional keyword arguments.
    """
    # Set public attributes
    self.embed_dim = embed_dim
    self.num_heads = num_heads
    self.dropout_rate = dropout_rate

    # Initialize instance variables
    self.projection: layers.Dense | None = None
    self.q_dense: layers.Dense | None = None
    self.k_dense: layers.Dense | None = None
    self.v_dense: layers.Dense | None = None
    self.out_proj: layers.Dense | None = None
    self.final_dense: layers.Dense | None = None
    self.dropout_layer: layers.Dropout | None = None
    self.num_features: int | None = None
    self.depth: int | None = None

    # Validate parameters
    self._validate_params()

    # Call parent's __init__
    super().__init__(name=name, **kwargs)
Functions
split_heads
split_heads(
    x: KerasTensor, batch_size: KerasTensor
) -> KerasTensor

Split the last dimension into (num_heads, depth) and transpose.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | KerasTensor | Input tensor with shape (batch_size, num_features, embed_dim). | required |
| batch_size | KerasTensor | Batch size tensor. | required |

Returns:

| Type | Description |
| --- | --- |
| KerasTensor | Tensor with shape (batch_size, num_heads, num_features, depth). |

Source code in kerasfactory/layers/MultiHeadGraphFeaturePreprocessor.py
def split_heads(self, x: KerasTensor, batch_size: KerasTensor) -> KerasTensor:
    """Split the last dimension into (num_heads, depth) and transpose.

    Args:
        x: Input tensor with shape (batch_size, num_features, embed_dim).
        batch_size: Batch size tensor.

    Returns:
        Tensor with shape (batch_size, num_heads, num_features, depth).
    """
    # Get the actual number of features from the input tensor
    actual_num_features = ops.shape(x)[1]

    x = ops.reshape(
        x,
        (batch_size, actual_num_features, self.num_heads, self.depth),
    )
    return ops.transpose(x, (0, 2, 1, 3))
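
The reshape-then-transpose pattern can be verified standalone with keras.ops. This small check (not part of the library) mirrors what split_heads does for embed_dim=16 and num_heads=4:

import keras
from keras import ops

# Standalone check of the split_heads pattern: depth = 16 // 4 = 4.
x = keras.random.normal((32, 10, 16))  # (batch, features, embed_dim)
h = ops.reshape(x, (32, 10, 4, 4))     # (batch, features, heads, depth)
h = ops.transpose(h, (0, 2, 1, 3))     # (batch, heads, features, depth)
print(h.shape)  # (32, 4, 10, 4)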

πŸ”§ Parameters Deep Dive

embed_dim (int)

  • Purpose: Dimension of the feature embeddings
  • Range: 8 to 128+ (typically 16-64)
  • Impact: Larger values = more expressive embeddings but more parameters
  • Recommendation: Start with 16-32, scale based on data complexity

num_heads (int)

  • Purpose: Number of attention heads
  • Range: 1 to 16+ (typically 4-8)
  • Impact: More heads = more diverse attention patterns
  • Recommendation: Use 4-8 heads for most applications

dropout_rate (float)

  • Purpose: Dropout rate applied to attention weights
  • Range: 0.0 to 0.5 (typically 0.1-0.2)
  • Impact: Higher values = more regularization
  • Recommendation: Use 0.1-0.2 for regularization
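
Note that embed_dim and num_heads interact: embed_dim must be divisible by num_heads (see Common Pitfalls), since each head operates on a slice of depth = embed_dim / num_heads. A quick pre-flight check:

# Sanity check for a candidate configuration: each head gets an
# embed_dim // num_heads slice, so the division must be exact.
embed_dim, num_heads = 24, 6
assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
print(f"Per-head depth: {embed_dim // num_heads}")  # Per-head depth: 4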

πŸ“ˆ Performance Characteristics

  • Speed: ⚑⚑⚑ Fast for small to medium models; attention cost grows with the number of heads and quadratically with the number of features
  • Memory: πŸ’ΎπŸ’ΎπŸ’Ύ Moderate; each head materializes a num_features × num_features attention matrix
  • Accuracy: 🎯🎯🎯🎯 Excellent for complex feature relationship learning
  • Best For: Tabular data with complex feature relationships

🎨 Examples

Example 1: Complex Feature Relationships

import keras
from kerasfactory.layers import MultiHeadGraphFeaturePreprocessor

# Create a model for complex feature relationships
def create_complex_relationship_model():
    inputs = keras.Input(shape=(20,))  # 20 features

    # Multiple graph preprocessors for different relationship levels
    x = MultiHeadGraphFeaturePreprocessor(
        embed_dim=24,
        num_heads=6,
        dropout_rate=0.1
    )(inputs)

    x = keras.layers.Dense(48, activation='relu')(x)
    x = keras.layers.BatchNormalization()(x)

    x = MultiHeadGraphFeaturePreprocessor(
        embed_dim=20,
        num_heads=5,
        dropout_rate=0.1
    )(x)

    x = keras.layers.Dense(32, activation='relu')(x)
    x = keras.layers.Dropout(0.2)(x)

    # Output
    outputs = keras.layers.Dense(1, activation='sigmoid')(x)

    return keras.Model(inputs, outputs)

model = create_complex_relationship_model()
model.compile(optimizer='adam', loss='binary_crossentropy')

# Test with sample data
sample_data = keras.random.normal((100, 20))
predictions = model(sample_data)
print(f"Complex relationship predictions shape: {predictions.shape}")

Example 2: Multi-Head Analysis

# Analyze multi-head behavior
import keras
from kerasfactory.layers import MultiHeadGraphFeaturePreprocessor

def analyze_multi_head_behavior():
    # Create model with multi-head graph preprocessor
    inputs = keras.Input(shape=(15,))
    x = MultiHeadGraphFeaturePreprocessor(embed_dim=16, num_heads=4)(inputs)
    outputs = keras.layers.Dense(1, activation='sigmoid')(x)

    model = keras.Model(inputs, outputs)

    # Test with different input patterns
    test_inputs = [
        keras.random.normal((10, 15)),  # Random data
        keras.random.normal((10, 15)) * 2,  # Scaled data
        keras.random.normal((10, 15)) + 1,  # Shifted data
    ]

    print("Multi-Head Behavior Analysis:")
    print("=" * 40)

    for i, test_input in enumerate(test_inputs):
        prediction = model(test_input)
        print(f"Test {i+1}: Prediction mean = {keras.ops.mean(prediction):.4f}")

    return model

# Analyze multi-head behavior
# model = analyze_multi_head_behavior()

Example 3: Attention Head Analysis

# Analyze attention head patterns
import keras
from kerasfactory.layers import MultiHeadGraphFeaturePreprocessor

def analyze_attention_heads():
    # Create model with multi-head graph preprocessor
    inputs = keras.Input(shape=(12,))
    x = MultiHeadGraphFeaturePreprocessor(embed_dim=16, num_heads=4)(inputs)
    outputs = keras.layers.Dense(1, activation='sigmoid')(x)

    model = keras.Model(inputs, outputs)

    # Test with sample data
    sample_data = keras.random.normal((50, 12))
    predictions = model(sample_data)

    print("Attention Head Analysis:")
    print("=" * 40)
    print(f"Input shape: {sample_data.shape}")
    print(f"Output shape: {predictions.shape}")
    print(f"Model parameters: {model.count_params()}")

    return model

# Analyze attention heads
# model = analyze_attention_heads()

πŸ’‘ Tips & Best Practices

  • Embedding Dimension: Start with 16-32, scale based on data complexity
  • Number of Heads: Use 4-8 heads for most applications
  • Dropout Rate: Use 0.1-0.2 for regularization
  • Feature Relationships: Works best when features have complex relationships
  • Residual Connections: Built-in residual connections for gradient flow
  • Attention Patterns: Monitor attention patterns for interpretability

⚠️ Common Pitfalls

  • Embedding Dimension: embed_dim must be divisible by num_heads
  • Number of Heads: Must be a positive integer
  • Dropout Rate: Must be between 0 and 1
  • Memory Usage: Scales with the number of heads and quadratically with the number of features (see the sketch below)
  • Overfitting: Monitor validation metrics when using large embeddings or many heads
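
To gauge the size impact of a configuration before training, one rough check is to compare parameter counts across settings. This sketch uses only the documented constructor; exact counts are implementation-specific:

import keras
from kerasfactory.layers import MultiHeadGraphFeaturePreprocessor

# Compare layer size across configurations to gauge the cost of
# larger embeddings and more heads.
for embed_dim, num_heads in [(16, 4), (32, 8), (64, 8)]:
    inputs = keras.Input(shape=(20,))
    outputs = MultiHeadGraphFeaturePreprocessor(
        embed_dim=embed_dim, num_heads=num_heads
    )(inputs)
    model = keras.Model(inputs, outputs)
    print(f"embed_dim={embed_dim}, num_heads={num_heads}: "
          f"{model.count_params()} parameters")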

πŸ“š Further Reading