InterpretableMultiHeadAttention
Overview
The InterpretableMultiHeadAttention layer is a multi-head attention mechanism designed for interpretability and explainability. Unlike standard attention layers, which discard their attention weights after each forward pass, this layer captures and stores the attention scores so you can see exactly what the model focuses on when it makes a prediction.
This layer is particularly valuable for applications where model interpretability is crucial, such as healthcare, finance, and other high-stakes domains where understanding model decisions is as important as accuracy.
How It Works
The InterpretableMultiHeadAttention layer extends the standard multi-head attention mechanism with interpretability features:
- Multi-Head Processing: Processes input through multiple attention heads in parallel
- Attention Score Storage: Captures and stores attention weights for each head
- Score Accessibility: Provides easy access to attention scores for analysis
- Interpretable Output: Returns the attention output while keeping the per-head attention weights accessible on the layer (see the short sketch after the diagram below)
```mermaid
graph TD
    A[Query, Key, Value] --> B[Multi-Head Attention]
    B --> C[Head 1]
    B --> D[Head 2]
    B --> E[Head N]
    C --> F[Attention Scores 1]
    D --> G[Attention Scores 2]
    E --> H[Attention Scores N]
    F --> I[Concatenate Heads]
    G --> I
    H --> I
    I --> J[Output + Stored Scores]

    style A fill:#e6f3ff,stroke:#4a86e8
    style J fill:#e8f5e9,stroke:#66bb6a
    style F fill:#fff9e6,stroke:#ffb74d
    style G fill:#fff9e6,stroke:#ffb74d
    style H fill:#fff9e6,stroke:#ffb74d
```
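In code, the interpretability hook is a single attribute read after a forward pass. A minimal sketch, assuming the class is importable from `kerasfactory.layers` and that the stored scores follow the usual `(batch, n_head, S, S)` layout of Keras `MultiHeadAttention`:

```python
import numpy as np
from kerasfactory.layers import InterpretableMultiHeadAttention

layer = InterpretableMultiHeadAttention(d_model=32, n_head=4)
x = np.random.rand(2, 6, 32).astype("float32")
_ = layer(x, x, x)  # forward pass; scores are captured as a side effect

scores = layer.attention_scores   # e.g. (2, 4, 6, 6): (batch, n_head, S, S)
head_0 = scores[:, 0, :, :]       # attention pattern of the first head
print(scores.shape, head_0.shape)
```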
Why Use This Layer?
| Challenge | Traditional Approach | InterpretableMultiHeadAttention's Solution |
|---|---|---|
| Model Interpretability | Black-box attention with no visibility | Transparent attention with accessible attention scores |
| Debugging Models | Difficult to understand what the model focuses on | Clear visibility into attention patterns and focus areas |
| Regulatory Compliance | Limited explainability for high-stakes decisions | Full traceability of attention decisions for compliance |
| Model Validation | Hard to validate attention behavior | Easy validation through attention score analysis |
Use Cases
- Healthcare AI: Understanding which medical features drive diagnoses
- Financial Risk: Explaining which factors influence risk assessments
- Regulatory Compliance: Providing interpretable decisions for auditors
- Model Debugging: Identifying attention patterns and potential issues
- Research: Analyzing attention mechanisms in academic studies
- Customer-Facing AI: Explaining decisions to end users
Quick Start
Basic Usage
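The original listing did not survive rendering; the sketch below shows the basic pattern implied by the API reference: construct the layer, run a self-attention forward pass, and read `attention_scores`. The import path and score shape are assumptions.

```python
import numpy as np
from kerasfactory.layers import InterpretableMultiHeadAttention

# Dummy batch: 32 samples, 10 positions, 64 features per position.
x = np.random.rand(32, 10, 64).astype("float32")

# d_model is the per-head size for the query/key/value projections.
layer = InterpretableMultiHeadAttention(d_model=64, n_head=8, dropout_rate=0.1)

# Self-attention: the same tensor is used as query, key, and value.
output = layer(x, x, x)
print(output.shape)  # (32, 10, 64) -- matches the query's feature dimension

# Attention weights captured from the last forward pass.
scores = layer.attention_scores
print(scores.shape)  # e.g. (32, 8, 10, 10): (batch, n_head, query_len, key_len)
```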
In a Sequential Model
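The original listing was lost. Because the layer takes separate query, key, and value tensors, dropping it straight into `keras.Sequential` needs a small self-attention wrapper; `SelfAttentionBlock` below is a hypothetical helper written for this sketch, not part of kerasfactory.

```python
import keras
from keras import layers
from kerasfactory.layers import InterpretableMultiHeadAttention


class SelfAttentionBlock(layers.Layer):
    """Hypothetical wrapper: applies interpretable self-attention to a single input."""

    def __init__(self, d_model: int, n_head: int, **kwargs):
        super().__init__(**kwargs)
        self.attn = InterpretableMultiHeadAttention(d_model=d_model, n_head=n_head)

    def call(self, x, training=None):
        # Use the same tensor as query, key, and value.
        return self.attn(x, x, x, training=training)


model = keras.Sequential([
    keras.Input(shape=(10, 64)),
    SelfAttentionBlock(d_model=64, n_head=8),
    layers.GlobalAveragePooling1D(),
    layers.Dense(1, activation="sigmoid"),
])
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
model.summary()
```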
In a Functional Model
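A reconstruction of the functional-API pattern, using the layer as a self-attention block with a residual connection; the shapes and the classification head are illustrative, not taken from the lost original.

```python
import keras
from keras import layers
from kerasfactory.layers import InterpretableMultiHeadAttention

# Sequence input: 10 positions, 64 features per position.
inputs = keras.Input(shape=(10, 64))

# Self-attention: the same tensor serves as query, key, and value.
attention = InterpretableMultiHeadAttention(d_model=64, n_head=8, dropout_rate=0.1)
x = attention(inputs, inputs, inputs)

# Residual connection + normalization, a common pattern around attention blocks.
x = layers.Add()([inputs, x])
x = layers.LayerNormalization()(x)

x = layers.GlobalAveragePooling1D()(x)
outputs = layers.Dense(1, activation="sigmoid")(x)

model = keras.Model(inputs=inputs, outputs=outputs)
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
model.summary()
```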
Advanced Configuration
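A sketch of passing extra options through `**kwargs`, using only the keyword arguments listed in the API reference below (`use_bias`, `kernel_initializer`, `kernel_regularizer`, `seed`); the specific values are illustrative.

```python
import numpy as np
import keras
from kerasfactory.layers import InterpretableMultiHeadAttention

layer = InterpretableMultiHeadAttention(
    d_model=128,
    n_head=16,
    dropout_rate=0.2,
    # Extra keyword arguments are forwarded to the wrapped Keras MultiHeadAttention:
    use_bias=True,
    kernel_initializer="glorot_uniform",
    kernel_regularizer=keras.regularizers.L2(1e-4),
    seed=42,  # reproducible dropout
)

x = np.random.rand(8, 20, 128).astype("float32")
out = layer(x, x, x, training=True)
print(out.shape)                     # (8, 20, 128)
print(layer.attention_scores.shape)  # e.g. (8, 16, 20, 20)
```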
API Reference
kerasfactory.layers.InterpretableMultiHeadAttention
Interpretable Multi-Head Attention layer implementation.
Classes
InterpretableMultiHeadAttention
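The rendered signature block was lost; judging from the parameter table below, the constructor looks roughly like this (the base class and exact annotations are assumptions, the real definition lives in `kerasfactory/layers/InterpretableMultiHeadAttention.py`):

```python
from typing import Any

import keras


class InterpretableMultiHeadAttention(keras.layers.Layer):
    """Multi-head attention that stores its attention scores for inspection."""

    def __init__(
        self,
        d_model: int,
        n_head: int,
        dropout_rate: float = 0.1,
        **kwargs: Any,
    ) -> None:
        ...
```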
Interpretable Multi-Head Attention layer.
This layer wraps Keras MultiHeadAttention and stores the attention scores
for interpretability purposes. The attention scores can be accessed via
the attention_scores attribute after calling the layer.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `d_model` | `int` | Size of each attention head for query, key, value. | *required* |
| `n_head` | `int` | Number of attention heads. | *required* |
| `dropout_rate` | `float` | Dropout probability. | `0.1` |
| `**kwargs` | `dict[str, Any]` | Additional arguments passed to `MultiHeadAttention`. Supported arguments: `value_dim` (size of each attention head for value), `use_bias` (default `True`), `output_shape` (default `None`), `attention_axes` (default `None`), `kernel_initializer` (default `'glorot_uniform'`), `bias_initializer` (default `'zeros'`), `kernel_regularizer`, `bias_regularizer`, `activity_regularizer`, `kernel_constraint`, `bias_constraint` (all default `None`), and `seed` (random seed for dropout, default `None`). | `{}` |
Call Args
- query: Query tensor of shape (B, S, E), where B is the batch size, S the sequence length, and E the feature dimension.
- key: Key tensor of shape (B, S, E).
- value: Value tensor of shape (B, S, E).
- training: Python boolean indicating whether the layer should run in training mode (dropout applied) or inference mode (no dropout).
Returns:

| Name | Type | Description |
|---|---|---|
| `output` | `Tensor` | Attention output of shape `(B, S, E)`. |
Example
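The docstring example did not survive rendering; a minimal reconstruction with random NumPy inputs:

```python
import numpy as np
from kerasfactory.layers import InterpretableMultiHeadAttention

mha = InterpretableMultiHeadAttention(d_model=64, n_head=8, dropout_rate=0.1)

query = np.random.rand(4, 12, 64).astype("float32")
key = np.random.rand(4, 12, 64).astype("float32")
value = np.random.rand(4, 12, 64).astype("float32")

output = mha(query, key, value)
print(output.shape)  # (4, 12, 64)

# Scores from this call remain available on the layer until the next call.
print(mha.attention_scores.shape)  # e.g. (4, 8, 12, 12)
```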
Initialize the layer.
Source code in kerasfactory/layers/InterpretableMultiHeadAttention.py
Functions
from_config (classmethod)
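The rendered signature was lost; it is presumably the standard Keras `from_config(cls, config)` classmethod. A sketch of the usual config round trip, assuming the matching `get_config` from the base `Layer` API:

```python
from kerasfactory.layers import InterpretableMultiHeadAttention

layer = InterpretableMultiHeadAttention(d_model=64, n_head=8, dropout_rate=0.1)

# Serialize the layer to a plain dict and rebuild an equivalent instance from it.
config = layer.get_config()
restored = InterpretableMultiHeadAttention.from_config(config)
```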
Create layer from configuration.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `config` | `dict[str, Any]` | Layer configuration dictionary. | *required* |

Returns:

| Type | Description |
|---|---|
| `InterpretableMultiHeadAttention` | Layer instance. |
Source code in kerasfactory/layers/InterpretableMultiHeadAttention.py
Parameters Deep Dive
d_model (int)
- Purpose: Size of each attention head for query, key, and value
- Range: 16 to 512+ (typically 64-256)
- Impact: Higher values = richer representations but more parameters
- Recommendation: Start with 64, scale based on data complexity
n_head (int)
- Purpose: Number of attention heads for parallel processing
- Range: 1 to 32+ (typically 4, 8, or 16)
- Impact: More heads = better pattern recognition but higher computational cost
- Recommendation: Start with 8, increase for complex patterns
dropout_rate (float)
- Purpose: Regularization to prevent overfitting
- Range: 0.0 to 0.9
- Impact: Higher values = more regularization but potentially less learning
- Recommendation: Start with 0.1, adjust based on overfitting
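Putting the recommendations above together, a reasonable starting configuration (these are the suggested starting points, not tuned hyperparameters):

```python
from kerasfactory.layers import InterpretableMultiHeadAttention

# Starting point suggested above; scale d_model/n_head up for complex data,
# and raise dropout_rate if the model overfits.
attention = InterpretableMultiHeadAttention(
    d_model=64,
    n_head=8,
    dropout_rate=0.1,
)
```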
Performance Characteristics

- Speed: Fast for small to medium datasets; cost scales with head count
- Memory: Moderate memory usage due to attention score storage
- Accuracy: Excellent for complex patterns, with interpretability included
- Best For: Applications requiring both high performance and interpretability
Examples
Example 1: Medical Diagnosis with Interpretability
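The original listing was lost; the sketch below reconstructs the idea under stated assumptions: synthetic patient-visit sequences, an interpretable self-attention block feeding a diagnosis head, and a post-hoc look at which visits received the most attention. All names, shapes, and data are illustrative.

```python
import numpy as np
import keras
from keras import layers
from kerasfactory.layers import InterpretableMultiHeadAttention

# Synthetic data: 256 patients, 12 visits, 32 clinical features per visit.
x_train = np.random.rand(256, 12, 32).astype("float32")
y_train = np.random.randint(0, 2, size=(256, 1)).astype("float32")

inputs = keras.Input(shape=(12, 32), name="patient_visits")
attention = InterpretableMultiHeadAttention(d_model=32, n_head=4, dropout_rate=0.1)
x = attention(inputs, inputs, inputs)
x = layers.LayerNormalization()(layers.Add()([inputs, x]))
x = layers.GlobalAveragePooling1D()(x)
outputs = layers.Dense(1, activation="sigmoid", name="diagnosis")(x)

model = keras.Model(inputs, outputs)
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
model.fit(x_train, y_train, epochs=2, batch_size=32, verbose=0)

# Interpretability: run one patient through the model, then check which visits
# the heads attended to when producing the diagnosis.
_ = model(x_train[:1])
scores = np.asarray(attention.attention_scores)  # e.g. (1, 4, 12, 12)
visit_weight = scores.mean(axis=(0, 1, 2))       # averaged over heads and query positions
print("Most influential visit:", int(np.argmax(visit_weight)))
```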
Example 2: Financial Risk Assessment
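The original listing was lost; a reconstruction under the same caveats: synthetic applicant data where each row is a sequence of embedded risk factors, and the stored attention scores are averaged into a per-factor weight. The factor names and shapes are illustrative.

```python
import numpy as np
import keras
from keras import layers
from kerasfactory.layers import InterpretableMultiHeadAttention

# Illustrative setup: each applicant is a sequence of 8 embedded risk factors.
factor_names = ["income", "debt", "credit_history", "employment",
                "collateral", "loan_amount", "term", "age"]
x_train = np.random.rand(512, 8, 16).astype("float32")
y_train = np.random.randint(0, 2, size=(512, 1)).astype("float32")

inputs = keras.Input(shape=(8, 16), name="risk_factors")
attention = InterpretableMultiHeadAttention(d_model=16, n_head=4, dropout_rate=0.1)
x = attention(inputs, inputs, inputs)
x = layers.GlobalAveragePooling1D()(x)
outputs = layers.Dense(1, activation="sigmoid", name="default_risk")(x)

model = keras.Model(inputs, outputs)
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
model.fit(x_train, y_train, epochs=2, batch_size=64, verbose=0)

# Explain one assessment: which factors received the most attention?
_ = model(x_train[:1])
scores = np.asarray(attention.attention_scores)  # e.g. (1, 4, 8, 8)
factor_importance = scores.mean(axis=(0, 1, 2))  # average attention each factor received
for name, weight in sorted(zip(factor_names, factor_importance), key=lambda kv: -kv[1]):
    print(f"{name:>15s}: {weight:.3f}")
```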
Example 3: Attention Visualization
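The original listing was lost; a sketch of turning the stored scores into per-head heatmaps with Matplotlib (score layout assumed to be `(batch, n_head, S, S)`):

```python
import numpy as np
import matplotlib.pyplot as plt
from kerasfactory.layers import InterpretableMultiHeadAttention

layer = InterpretableMultiHeadAttention(d_model=32, n_head=4)
x = np.random.rand(1, 10, 32).astype("float32")
_ = layer(x, x, x)

scores = np.asarray(layer.attention_scores)[0]  # (n_head, S, S)
n_head = scores.shape[0]

# One heatmap per head: rows are query positions, columns are key positions.
fig, axes = plt.subplots(1, n_head, figsize=(4 * n_head, 3.5))
for head, ax in enumerate(np.atleast_1d(axes)):
    im = ax.imshow(scores[head], cmap="viridis", aspect="auto")
    ax.set_title(f"Head {head}")
    ax.set_xlabel("Key position")
    ax.set_ylabel("Query position")
    fig.colorbar(im, ax=ax, fraction=0.046)
plt.tight_layout()
plt.show()
```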
Tips & Best Practices
- Interpretability: Access the `attention_scores` attribute after each forward pass
- Head Analysis: Analyze individual heads to understand different attention patterns
- Visualization: Use attention scores for heatmap visualizations
- Regularization: Use appropriate dropout to prevent overfitting
- Head Count: Start with 8 heads, adjust based on complexity
- Memory: Be aware that attention scores increase memory usage
Common Pitfalls
- Memory Usage: Storing attention scores increases memory consumption
- Score Access: Stored scores reflect only the most recent forward pass, so read them before calling the layer again
- Head Interpretation: Different heads may focus on different patterns
- Overfitting: Complex attention can overfit on small datasets
- Performance: More heads = higher computational cost
Related Layers
- TabularAttention - General tabular attention mechanism
- MultiResolutionTabularAttention - Multi-resolution attention
- ColumnAttention - Column-wise attention
- RowAttention - Row-wise attention
Further Reading
- Attention Is All You Need - Original Transformer paper
- The Annotated Transformer - Detailed attention explanation
- Attention Visualization in Deep Learning - Attention visualization techniques
- KerasFactory Layer Explorer - Browse all available layers
- Model Interpretability Tutorial - Complete guide to model interpretability