RowAttention
Overview
The RowAttention layer implements a row-wise attention mechanism that dynamically weights samples based on their importance and relevance. Unlike traditional attention mechanisms that focus on feature relationships, this layer learns to assign attention weights to each sample (row) in the batch, allowing the model to focus on the most informative samples for each prediction.
This layer is particularly useful for sample weighting, outlier handling, and improving model performance by learning which samples are most important for the current context.
How It Works
The RowAttention layer processes tabular data through a sample-wise attention mechanism:
- Sample Analysis: Analyzes each sample to understand its importance
- Attention Weight Generation: Uses a neural network to compute attention weights for each sample
- Softmax Normalization: Normalizes weights across the batch using softmax
- Dynamic Weighting: Applies learned weights to scale sample importance
```mermaid
graph TD
    A[Input: batch_size, num_features] --> B[Sample Analysis]
    B --> C[Attention Network]
    C --> D[Sigmoid Activation]
    D --> E[Softmax Normalization]
    E --> F[Attention Weights]
    A --> G[Element-wise Multiplication]
    F --> G
    G --> H[Weighted Samples Output]

    style A fill:#e6f3ff,stroke:#4a86e8
    style H fill:#e8f5e9,stroke:#66bb6a
    style C fill:#fff9e6,stroke:#ffb74d
    style E fill:#f3e5f5,stroke:#9c27b0
```
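Conceptually, the forward pass mirrors the steps above: a small two-layer network scores each row, a sigmoid squashes each score to a scalar, and a softmax across the batch turns the scores into weights that scale each sample. The sketch below is a minimal, hand-rolled illustration (assuming Keras 3 and a ReLU hidden activation), not the layer's actual source code:

```python
import keras
from keras import layers, ops

feature_dim = 16
hidden_dim = feature_dim // 2  # documented default

# Two-layer scoring network; the ReLU hidden activation is an assumption.
score_net = keras.Sequential([
    layers.Dense(hidden_dim, activation="relu"),
    layers.Dense(1, activation="sigmoid"),    # scalar score per sample
])

def row_attention_forward(x):
    scores = score_net(x)                     # (batch_size, 1) raw scores
    weights = ops.softmax(scores, axis=0)     # normalize across rows of the batch
    return x * weights                        # (batch_size, feature_dim) weighted samples

x = keras.random.normal((32, feature_dim))
out = row_attention_forward(x)               # same shape as x
```

Because the softmax runs over axis 0 (the batch dimension), the weights sum to 1 across the batch, which is why larger batches give smoother normalization (see Tips & Best Practices below).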
Why Use This Layer?
| Challenge | Traditional Approach | RowAttention's Solution |
|---|---|---|
| Sample Importance | Treat all samples equally | Automatic learning of sample importance per batch |
| Outlier Handling | Outliers can skew predictions | Dynamic weighting to down-weight outliers |
| Data Quality | No distinction between good/bad samples | Quality-aware processing based on sample characteristics |
| Batch Effects | Ignore sample relationships within batch | Context-aware weighting based on batch composition |
Use Cases
- Sample Weighting: Automatically identifying and emphasizing important samples
- Outlier Detection: Down-weighting outliers and noisy samples
- Data Quality: Handling datasets with varying sample quality
- Batch Processing: Learning sample importance within each batch
- Imbalanced Data: Balancing the influence of different sample types
Quick Start
Basic Usage
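A minimal sketch of standalone usage, assuming the layer is importable from `kerasfactory.layers` and using random NumPy data for illustration:

```python
import numpy as np
from kerasfactory.layers import RowAttention

# Dummy batch of 32 samples with 16 features each
x = np.random.rand(32, 16).astype("float32")

# feature_dim must match the number of input features
layer = RowAttention(feature_dim=16)

weighted = layer(x)       # same shape as the input: (32, 16)
print(weighted.shape)
```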
In a Sequential Model
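A sketch of dropping the layer into a Sequential stack; the surrounding Dense layers and their sizes are illustrative, and `feature_dim` must match the width of the preceding layer:

```python
import keras
from keras import layers
from kerasfactory.layers import RowAttention

model = keras.Sequential([
    keras.Input(shape=(16,)),
    layers.Dense(32, activation="relu"),
    RowAttention(feature_dim=32),     # weight samples after feature processing
    layers.Dense(1, activation="sigmoid"),
])

model.compile(optimizer="adam", loss="binary_crossentropy")
model.summary()
```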
In a Functional Model
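The same idea with the functional API; again, the architecture around RowAttention is only an example:

```python
import keras
from keras import layers
from kerasfactory.layers import RowAttention

inputs = keras.Input(shape=(16,))
x = layers.Dense(32, activation="relu")(inputs)
x = RowAttention(feature_dim=32)(x)           # re-weight samples in the batch
x = layers.Dense(16, activation="relu")(x)
outputs = layers.Dense(1, activation="sigmoid")(x)

model = keras.Model(inputs=inputs, outputs=outputs)
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
```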
Advanced Configuration
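A sketch that sets `hidden_dim` explicitly and combines the layer with dropout and batch normalization, as suggested under Tips & Best Practices; all sizes and hyperparameters are illustrative:

```python
import keras
from keras import layers
from kerasfactory.layers import RowAttention

inputs = keras.Input(shape=(64,))

# Initial feature processing
x = layers.Dense(128, activation="relu")(inputs)
x = layers.BatchNormalization()(x)
x = layers.Dropout(0.2)(x)

# Explicit hidden_dim instead of the default feature_dim // 2
x = RowAttention(feature_dim=128, hidden_dim=32)(x)

x = layers.Dense(64, activation="relu")(x)
x = layers.Dropout(0.2)(x)
outputs = layers.Dense(1, activation="sigmoid")(x)

model = keras.Model(inputs, outputs)
model.compile(optimizer=keras.optimizers.Adam(learning_rate=1e-3),
              loss="binary_crossentropy")
```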
API Reference
kerasfactory.layers.RowAttention
Row attention mechanism for weighting samples in a batch.
Classes
RowAttention
```python
RowAttention(
    feature_dim: int,
    hidden_dim: int | None = None,
    **kwargs: dict[str, Any]
)
```
Row attention mechanism to weight samples dynamically.
This layer applies attention weights to each sample (row) in the input tensor. The attention weights are computed using a two-layer neural network that takes each sample as input and outputs a scalar attention weight.
Example
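A short, illustrative usage sketch (the import path and shapes are assumptions consistent with this reference):

```python
import keras
from kerasfactory.layers import RowAttention

# 8 samples with 10 features each
x = keras.random.normal((8, 10))

layer = RowAttention(feature_dim=10)
weighted = layer(x)

print(weighted.shape)  # (8, 10) - same shape, samples re-weighted
```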
Initialize row attention.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `feature_dim` | `int` | Number of input features | *required* |
| `hidden_dim` | `int \| None` | Hidden layer dimension. If `None`, uses `feature_dim // 2` | `None` |
| `**kwargs` | `dict[str, Any]` | Additional layer arguments | `{}` |
Functions
from_config classmethod

```python
from_config(config: dict[str, Any]) -> RowAttention
```
Create layer from configuration.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `config` | `dict[str, Any]` | Layer configuration dictionary | *required* |
Returns:
| Type | Description |
|---|---|
| `RowAttention` | `RowAttention` instance |
Parameters Deep Dive
feature_dim (int)
- Purpose: Number of input features for each sample
- Range: 1 to 1000+ (typically 10-100)
- Impact: Must match the number of features in your input
- Recommendation: Set to the output dimension of your previous layer
hidden_dim (int, optional)
- Purpose: Size of the hidden layer in the attention network
- Range: 1 to feature_dim (default: feature_dim // 2)
- Impact: Larger values = more complex attention patterns but more parameters
- Recommendation: Start with default, increase for complex sample relationships
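For example, the constructor calls below (illustrative only) show the documented default versus an explicit hidden size:

```python
from kerasfactory.layers import RowAttention

# hidden_dim left as None falls back to feature_dim // 2 (here: 32)
default_attn = RowAttention(feature_dim=64)

# A larger hidden layer allows more complex attention patterns
# at the cost of extra parameters
wider_attn = RowAttention(feature_dim=64, hidden_dim=48)
```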
Performance Characteristics
- Speed: Very fast - simple neural network computation
- Memory: Low memory usage - minimal additional parameters
- Accuracy: Good for sample importance and outlier handling
- Best For: Tabular data where sample importance varies by context
Examples
Example 1: Outlier Detection and Handling
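A sketch of this scenario: synthetic regression data in which a small fraction of rows is corrupted by extreme values, so the attention mechanism can learn to down-weight them. The data generation, architecture, and hyperparameters are all illustrative assumptions:

```python
import numpy as np
import keras
from keras import layers
from kerasfactory.layers import RowAttention

# Synthetic regression data with a few injected outliers
rng = np.random.default_rng(42)
X = rng.normal(size=(1000, 20)).astype("float32")
y = (X[:, :5].sum(axis=1) + rng.normal(scale=0.1, size=1000)).astype("float32")

# Corrupt 5% of the samples with extreme feature values
outlier_idx = rng.choice(1000, size=50, replace=False)
X[outlier_idx] += rng.normal(scale=10.0, size=(50, 20)).astype("float32")

# Model with RowAttention to down-weight the corrupted rows
inputs = keras.Input(shape=(20,))
x = layers.Dense(64, activation="relu")(inputs)
x = RowAttention(feature_dim=64)(x)
x = layers.Dense(32, activation="relu")(x)
outputs = layers.Dense(1)(x)

model = keras.Model(inputs, outputs)
model.compile(optimizer="adam", loss="mse")
model.fit(X, y, epochs=10, batch_size=64, validation_split=0.2)
```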
Example 2: Imbalanced Data Handling
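A sketch with a synthetic 9:1 class imbalance; every concrete choice below is illustrative:

```python
import numpy as np
import keras
from keras import layers
from kerasfactory.layers import RowAttention

# Synthetic binary classification data with a 9:1 class imbalance
rng = np.random.default_rng(0)
n_major, n_minor = 900, 100
X = np.vstack([
    rng.normal(loc=0.0, size=(n_major, 16)),
    rng.normal(loc=1.5, size=(n_minor, 16)),
]).astype("float32")
y = np.concatenate([np.zeros(n_major), np.ones(n_minor)]).astype("float32")

inputs = keras.Input(shape=(16,))
x = layers.Dense(32, activation="relu")(inputs)
x = RowAttention(feature_dim=32)(x)   # lets the model emphasize minority-class rows
outputs = layers.Dense(1, activation="sigmoid")(x)

model = keras.Model(inputs, outputs)
model.compile(optimizer="adam", loss="binary_crossentropy",
              metrics=[keras.metrics.AUC(name="auc")])
model.fit(X, y, epochs=10, batch_size=128, validation_split=0.2)
```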
Example 3: Quality-Aware Processing
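A sketch in which half of the rows receive heavy feature noise after the labels are generated, so the model benefits from down-weighting low-quality samples; all values are illustrative:

```python
import numpy as np
import keras
from keras import layers
from kerasfactory.layers import RowAttention

# Clean signal determines the labels; half of the rows then get heavy feature noise
rng = np.random.default_rng(7)
X = rng.normal(size=(1000, 12)).astype("float32")
y = (X[:, 0] + X[:, 1] > 0).astype("float32")
noisy_idx = rng.choice(1000, size=500, replace=False)
X[noisy_idx] += rng.normal(scale=3.0, size=(500, 12)).astype("float32")

inputs = keras.Input(shape=(12,))
x = layers.Dense(48, activation="relu")(inputs)
x = layers.BatchNormalization()(x)
x = RowAttention(feature_dim=48, hidden_dim=16)(x)  # quality-aware sample weighting
x = layers.Dropout(0.1)(x)
outputs = layers.Dense(1, activation="sigmoid")(x)

model = keras.Model(inputs, outputs)
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
model.fit(X, y, epochs=15, batch_size=64, validation_split=0.2)
```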
Tips & Best Practices
- Placement: Use after initial feature processing but before final predictions
- Hidden Dimension: Start with feature_dim // 2, adjust based on complexity
- Batch Size: Works best with larger batch sizes for better softmax normalization
- Regularization: Combine with dropout and batch normalization for better generalization
- Interpretability: Access attention weights to understand sample importance
- Data Quality: Particularly effective with noisy or imbalanced data
Common Pitfalls
- Input Shape: Must be 2D tensor (batch_size, feature_dim)
- Dimension Mismatch: feature_dim must match the number of features
- Small Batches: Softmax normalization works better with larger batches
- Overfitting: Can overfit on small datasets - use regularization
- Memory: Hidden dimension affects memory usage - keep reasonable
Related Layers
- ColumnAttention - Column-wise attention for feature relationships
- TabularAttention - General tabular attention mechanism
- SparseAttentionWeighting - Sparse attention weights
- VariableSelection - Feature selection layer
Further Reading
- Attention Mechanisms in Deep Learning - Understanding attention mechanisms
- Sample Weighting in Machine Learning - Sample weighting concepts
- KerasFactory Layer Explorer - Browse all available layers
- Data Quality Tutorial - Complete guide to data quality handling