🧠 TabularAttention

🔥 Popular ✅ Stable 🟡 Intermediate

🎯 Overview

The TabularAttention layer implements a sophisticated dual attention mechanism specifically designed for tabular data. Unlike traditional attention mechanisms that focus on sequential data, this layer captures both inter-feature relationships (how features interact within each sample) and inter-sample relationships (how samples relate to each other across features).

This layer is particularly powerful for tabular datasets where understanding feature interactions and sample similarities is crucial for making accurate predictions. It's especially useful in scenarios where you have complex feature dependencies that traditional neural networks struggle to capture.

🔍 How It Works

The TabularAttention layer processes tabular data through a two-stage attention mechanism (a minimal code sketch of the idea follows the diagram below):

  1. Inter-Feature Attention: Analyzes relationships between different features within each sample
  2. Inter-Sample Attention: Examines relationships between different samples across features

graph TD
    A[Input: batch_size, num_samples, num_features] --> B[Input Projection to d_model]
    B --> C[Inter-Feature Attention]
    C --> D[Feature LayerNorm + Residual]
    D --> E[Feed-Forward Network]
    E --> F[Feature LayerNorm + Residual]
    F --> G[Inter-Sample Attention]
    G --> H[Sample LayerNorm + Residual]
    H --> I[Output Projection]
    I --> J[Output: batch_size, num_samples, d_model]

    style A fill:#e6f3ff,stroke:#4a86e8
    style J fill:#e8f5e9,stroke:#66bb6a
    style C fill:#fff9e6,stroke:#ffb74d
    style G fill:#fff9e6,stroke:#ffb74d
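
To make the two stages concrete, the sketch below applies the standard keras.layers.MultiHeadAttention twice: once with the feature axis as the sequence axis (inter-feature) and once with the sample axis as the sequence axis (inter-sample). This is a minimal illustration of the idea, not the layer's actual implementation, which also projects inputs to d_model and adds residual connections, layer normalization, and a feed-forward block as the diagram shows; the helper name dual_attention_sketch and the key_dim value are assumptions.

import keras
from keras import layers, ops

def dual_attention_sketch(x, num_heads=4, key_dim=16):
    """Illustrative dual attention over a (batch, num_samples, num_features) tensor."""
    # 1) Inter-feature attention: transpose so features become the sequence
    #    axis, letting every feature attend to every other feature.
    feature_mha = layers.MultiHeadAttention(num_heads=num_heads, key_dim=key_dim)
    x_t = ops.transpose(x, axes=(0, 2, 1))    # (batch, num_features, num_samples)
    x_t = feature_mha(x_t, x_t)               # self-attention across features
    x = ops.transpose(x_t, axes=(0, 2, 1))    # back to (batch, num_samples, num_features)

    # 2) Inter-sample attention: samples already form the sequence axis,
    #    so each sample (row) attends to the other samples in the set.
    sample_mha = layers.MultiHeadAttention(num_heads=num_heads, key_dim=key_dim)
    return sample_mha(x, x)

y = dual_attention_sketch(keras.random.normal((8, 100, 20)))
print(y.shape)  # (8, 100, 20)

Unlike this sketch, the real layer works in a learned d_model-sized space, which is why its documented output shape is (batch_size, num_samples, d_model) rather than the raw feature count.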

💡 Why Use This Layer?

| Challenge | Traditional Approach | TabularAttention's Solution |
| --- | --- | --- |
| Feature Interactions | Manual feature engineering or simple concatenation | 🧠 Automatic discovery of complex feature relationships through attention |
| Sample Relationships | Treating samples independently | 🔗 Cross-sample learning to identify similar patterns and outliers |
| High-Dimensional Data | Dimensionality reduction or feature selection | ⚡ Efficient attention that scales to high-dimensional tabular data |
| Interpretability | Black-box models with limited insights | 👁️ Attention weights provide insights into feature and sample importance |

📊 Use Cases

  • Financial Risk Assessment: Understanding how different financial indicators interact and identifying similar risk profiles
  • Medical Diagnosis: Capturing complex relationships between symptoms and patient characteristics
  • Recommendation Systems: Learning user-item interactions and finding similar users/items
  • Anomaly Detection: Identifying unusual patterns by comparing samples across features
  • Feature Engineering: Automatically discovering meaningful feature combinations

🚀 Quick Start

Basic Usage

import keras
from kerasfactory.layers import TabularAttention

# Create sample tabular data
batch_size, num_samples, num_features = 32, 100, 20
x = keras.random.normal((batch_size, num_samples, num_features))

# Apply tabular attention
attention = TabularAttention(num_heads=8, d_model=64, dropout_rate=0.1)
output = attention(x)

print(f"Input shape: {x.shape}")   # (32, 100, 20)
print(f"Output shape: {output.shape}")  # (32, 100, 64)

In a Sequential Model

import keras
from kerasfactory.layers import TabularAttention

model = keras.Sequential([
    # Expects 3D input: (batch_size, num_samples, num_features)
    keras.layers.Dense(64, activation='relu'),
    TabularAttention(num_heads=4, d_model=64, dropout_rate=0.1),
    keras.layers.Dense(32, activation='relu'),
    keras.layers.Dense(1, activation='sigmoid')
])

model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

In a Functional Model

import keras
from kerasfactory.layers import TabularAttention

# Define inputs
inputs = keras.Input(shape=(100, 20))  # 100 samples, 20 features

# Apply attention
x = TabularAttention(num_heads=8, d_model=128, dropout_rate=0.1)(inputs)

# Add more processing
x = keras.layers.Dense(64, activation='relu')(x)
x = keras.layers.Dropout(0.2)(x)
outputs = keras.layers.Dense(1, activation='sigmoid')(x)

# Create model
model = keras.Model(inputs, outputs)

Advanced Configuration

import keras
from kerasfactory.layers import TabularAttention

# Advanced configuration with custom parameters
attention = TabularAttention(
    num_heads=16,           # More attention heads for complex patterns
    d_model=256,            # Higher dimensionality for rich representations
    dropout_rate=0.2,       # Higher dropout for regularization
    name="advanced_attention"
)

# Use in a complex model
inputs = keras.Input(shape=(50, 30))
x = keras.layers.Dense(256)(inputs)
x = attention(x)
x = keras.layers.LayerNormalization()(x)
x = keras.layers.Dense(128, activation='relu')(x)
outputs = keras.layers.Dense(10, activation='softmax')(x)

model = keras.Model(inputs, outputs)

📖 API Reference

kerasfactory.layers.TabularAttention

This module implements a TabularAttention layer that applies inter-feature and inter-sample attention mechanisms for tabular data. It's particularly useful for capturing complex relationships between features and samples in tabular datasets.

Classes

TabularAttention
TabularAttention(
    num_heads: int,
    d_model: int,
    dropout_rate: float = 0.1,
    name: str | None = None,
    **kwargs: Any
)

Custom layer to apply inter-feature and inter-sample attention for tabular data.

This layer implements a dual attention mechanism:

  1. Inter-feature attention: Captures dependencies between features for each sample
  2. Inter-sample attention: Captures dependencies between samples for each feature

The layer uses MultiHeadAttention for both attention mechanisms and includes layer normalization, dropout, and a feed-forward network.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| num_heads | int | Number of attention heads | required |
| d_model | int | Dimensionality of the attention model | required |
| dropout_rate | float | Dropout rate for regularization | 0.1 |
| name | str | Name for the layer | None |

Input shape

Tensor with shape: (batch_size, num_samples, num_features)

Output shape

Tensor with shape: (batch_size, num_samples, d_model)

Example
import keras
from kerasfactory.layers import TabularAttention

# Create sample input data
x = keras.random.normal((32, 100, 20))  # batch of 32, 100 samples, 20 features

# Apply tabular attention
attention = TabularAttention(num_heads=4, d_model=32, dropout_rate=0.1)
y = attention(x)
print("Output shape:", y.shape)  # (32, 100, 32)

Initialize the TabularAttention layer.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| num_heads | int | Number of attention heads. | required |
| d_model | int | Model dimension. | required |
| dropout_rate | float | Dropout rate. | 0.1 |
| name | str \| None | Name of the layer. | None |
| **kwargs | Any | Additional keyword arguments. | {} |

Source code in kerasfactory/layers/TabularAttention.py
def __init__(
    self,
    num_heads: int,
    d_model: int,
    dropout_rate: float = 0.1,
    name: str | None = None,
    **kwargs: Any,
) -> None:
    """Initialize the TabularAttention layer.

    Args:
        num_heads: Number of attention heads.
        d_model: Model dimension.
        dropout_rate: Dropout rate.
        name: Name of the layer.
        **kwargs: Additional keyword arguments.
    """
    # Set private attributes first
    self._num_heads = num_heads
    self._d_model = d_model
    self._dropout_rate = dropout_rate

    # Validate parameters
    self._validate_params()

    # Set public attributes BEFORE calling parent's __init__
    self.num_heads = self._num_heads
    self.d_model = self._d_model
    self.dropout_rate = self._dropout_rate

    # Initialize layers
    self.input_projection: layers.Dense | None = None
    self.feature_attention: layers.MultiHeadAttention | None = None
    self.feature_layernorm: layers.LayerNormalization | None = None
    self.feature_dropout: layers.Dropout | None = None
    self.feature_layernorm2: layers.LayerNormalization | None = None
    self.feature_dropout2: layers.Dropout | None = None
    self.sample_attention: layers.MultiHeadAttention | None = None
    self.sample_layernorm: layers.LayerNormalization | None = None
    self.sample_dropout: layers.Dropout | None = None
    self.sample_layernorm2: layers.LayerNormalization | None = None
    self.sample_dropout2: layers.Dropout | None = None
    self.ffn_dense1: layers.Dense | None = None
    self.ffn_dense2: layers.Dense | None = None
    self.output_projection: layers.Dense | None = None

    # Call parent's __init__ after setting public attributes
    super().__init__(name=name, **kwargs)

Functions

compute_output_shape
compute_output_shape(
    input_shape: tuple[int, ...]
) -> tuple[int, ...]

Compute the output shape of the layer.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| input_shape | tuple[int, ...] | Shape of the input tensor. | required |

Returns:

| Type | Description |
| --- | --- |
| tuple[int, ...] | Shape of the output tensor. |

Source code in kerasfactory/layers/TabularAttention.py
def compute_output_shape(self, input_shape: tuple[int, ...]) -> tuple[int, ...]:
    """Compute the output shape of the layer.

    Args:
        input_shape: Shape of the input tensor.

    Returns:
        Shape of the output tensor.
    """
    return (input_shape[0], input_shape[1], self.d_model)

🔧 Parameters Deep Dive

num_heads (int)

  • Purpose: Number of attention heads for parallel processing
  • Range: 1 to 64+ (typically 4, 8, or 16)
  • Impact: More heads = better pattern recognition but higher computational cost
  • Recommendation: Start with 8, increase if you have complex feature interactions

d_model (int)

  • Purpose: Dimensionality of the attention model
  • Range: 32 to 512+ (must be divisible by num_heads)
  • Impact: Higher values = richer representations but more parameters
  • Recommendation: Start with 64-128 and scale based on your data complexity (see the configuration sketch at the end of this section)

dropout_rate (float)

  • Purpose: Regularization to prevent overfitting
  • Range: 0.0 to 0.9
  • Impact: Higher values = more regularization but potentially less learning
  • Recommendation: Start with 0.1, increase if overfitting occurs
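
As a quick illustration of the guidance above, the configurations below all keep d_model divisible by num_heads; the specific values are illustrative starting points, not tuned defaults.

from kerasfactory.layers import TabularAttention

# Each configuration keeps d_model divisible by num_heads (16 dims per head).
small = TabularAttention(num_heads=4, d_model=64, dropout_rate=0.1)
medium = TabularAttention(num_heads=8, d_model=128, dropout_rate=0.1)
large = TabularAttention(num_heads=16, d_model=256, dropout_rate=0.2)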

📈 Performance Characteristics

  • Speed: ⚡⚡⚡ Fast for small to medium datasets, scales well with parallel processing
  • Memory: 💾💾💾 Moderate memory usage due to attention computations
  • Accuracy: 🎯🎯🎯🎯 Excellent for complex tabular data with feature interactions
  • Best For: Tabular data with complex feature relationships and sample similarities

🎨 Examples

Example 1: Customer Segmentation

import keras
import numpy as np
from kerasfactory.layers import TabularAttention

# Simulate customer data: age, income, spending, credit_score, etc.
num_customers, num_features = 1000, 15
customer_data = keras.random.normal((32, num_customers, num_features))

# Build segmentation model
inputs = keras.Input(shape=(num_customers, num_features))
x = TabularAttention(num_heads=8, d_model=64)(inputs)
x = keras.layers.GlobalAveragePooling1D()(x)  # Pool across samples
x = keras.layers.Dense(32, activation='relu')(x)
segments = keras.layers.Dense(5, activation='softmax')(x)  # 5 customer segments

model = keras.Model(inputs, segments)
model.compile(optimizer='adam', loss='categorical_crossentropy')

Example 2: Time Series Forecasting

import keras
from kerasfactory.layers import TabularAttention

# For time series data where each sample is a time point
time_steps, features = 30, 10
ts_data = keras.random.normal((32, time_steps, features))

# Build forecasting model
inputs = keras.Input(shape=(time_steps, features))
x = TabularAttention(num_heads=4, d_model=32)(inputs)
x = keras.layers.Dense(16, activation='relu')(x)
forecast = keras.layers.Dense(1)(x)  # Predict next value

model = keras.Model(inputs, forecast)
model.compile(optimizer='adam', loss='mse')

Example 3: Multi-Task Learning

import keras
from kerasfactory.layers import TabularAttention

# Shared attention for multiple related tasks
inputs = keras.Input(shape=(100, 20))

# Shared attention layer
shared_attention = TabularAttention(num_heads=8, d_model=128)
x = shared_attention(inputs)

# Task-specific heads
task1_output = keras.layers.Dense(1, activation='sigmoid', name='classification')(x)
task2_output = keras.layers.Dense(1, name='regression')(x)

model = keras.Model(inputs, [task1_output, task2_output])
model.compile(
    optimizer='adam',
    loss={'classification': 'binary_crossentropy', 'regression': 'mse'},
    loss_weights={'classification': 1.0, 'regression': 0.5}
)

💡 Tips & Best Practices

  • Start Simple: Begin with 4-8 attention heads and d_model=64, then scale up
  • Data Preprocessing: Ensure your tabular data is properly normalized before applying attention (a sketch follows this list)
  • Batch Size: Use larger batch sizes (32+) for better attention learning
  • Layer Order: Place TabularAttention after initial feature processing but before final predictions
  • Regularization: Use dropout and layer normalization to prevent overfitting
  • Monitoring: Watch attention weights to understand what the model is learning
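
One way to follow the preprocessing tip above is to adapt a keras.layers.Normalization layer on your training data and place it in front of TabularAttention, so each feature arrives with roughly zero mean and unit variance. The shapes and variable names below are illustrative.

import keras
import numpy as np
from kerasfactory.layers import TabularAttention

# Hypothetical raw data: 32 batch elements, 100 samples, 20 features each.
raw = np.random.rand(32, 100, 20).astype('float32')

# Learn per-feature mean and variance from the training data.
normalizer = keras.layers.Normalization(axis=-1)
normalizer.adapt(raw)

inputs = keras.Input(shape=(100, 20))
x = normalizer(inputs)                            # standardize features
x = TabularAttention(num_heads=8, d_model=64)(x)  # then apply attention
outputs = keras.layers.Dense(1, activation='sigmoid')(x)
model = keras.Model(inputs, outputs)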

⚠️ Common Pitfalls

  • Memory Issues: Large d_model values can cause memory problems; start smaller
  • Overfitting: Too many heads or too high a d_model can lead to overfitting on small datasets
  • Input Shape: Ensure input is 3D: (batch_size, num_samples, num_features)
  • Divisibility: d_model must be divisible by num_heads
  • Gradient Issues: Use gradient clipping if training becomes unstable (see the snippet after this list)
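
For the gradient-clipping tip, clipping can be enabled directly on the optimizer. The threshold below is an assumption to tune, and model stands for any of the models built earlier.

import keras

# Clip the global gradient norm to keep updates bounded when training is unstable.
optimizer = keras.optimizers.Adam(learning_rate=1e-3, global_clipnorm=1.0)
model.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=['accuracy'])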

📚 Further Reading