import keras
import numpy as np
from loguru import logger
from typing import Tuple, Dict, Any, Optional
from kerasfactory.layers import TabularAttention
def create_tabular_attention_model(
    input_dim: int,
    num_classes: int,
    *,
    num_heads: int = 8,
    key_dim: int = 64,
    dropout: float = 0.1,
    learning_rate: float = 0.001,
) -> keras.Model:
    """Build and compile a softmax classifier on top of a TabularAttention layer.

    The network is ``Input -> TabularAttention -> Dense(softmax)``. It is
    compiled with the Adam optimizer, categorical cross-entropy loss and an
    accuracy metric, so training targets are expected to be one-hot encoded.

    Args:
        input_dim: Number of input features (must be >= 1).
        num_classes: Number of output classes (must be >= 1).
        num_heads: Value forwarded as ``num_heads`` to ``TabularAttention``.
        key_dim: Value forwarded as ``key_dim`` to ``TabularAttention``.
        dropout: Value forwarded as ``dropout`` to ``TabularAttention``.
        learning_rate: Learning rate handed to the Adam optimizer.

    Returns:
        keras.Model: Compiled model ready for training.

    Raises:
        ValueError: If ``input_dim`` or ``num_classes`` is smaller than 1.

    Example:
        ```python
        model = create_tabular_attention_model(input_dim=20, num_classes=3)
        ```
    """
    # Fail early with a clear message instead of a shape error raised from
    # somewhere inside Keras.
    if input_dim < 1:
        raise ValueError(f"input_dim must be a positive integer, got {input_dim!r}")
    if num_classes < 1:
        raise ValueError(f"num_classes must be a positive integer, got {num_classes!r}")

    inputs = keras.Input(shape=(input_dim,), name='tabular_input')

    # NOTE(review): the keyword names below are carried over unchanged from the
    # original call; confirm them (and the rank of the layer's input/output)
    # against the installed kerasfactory ``TabularAttention`` signature.
    attention_layer = TabularAttention(
        num_heads=num_heads,
        key_dim=key_dim,
        dropout=dropout,
        use_attention_weights=True,
        attention_activation='softmax',
        name='tabular_attention',
    )
    x = attention_layer(inputs)

    # Softmax head producing one probability per class.
    outputs = keras.layers.Dense(
        num_classes,
        activation='softmax',
        name='predictions',
    )(x)

    model = keras.Model(inputs, outputs, name='tabular_attention_model')
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )
    return model
# Usage example
def demonstrate_tabular_attention() -> Tuple[keras.Model, keras.callbacks.History]:
    """Train a TabularAttention classifier on synthetic data as a smoke test.

    Generates random features and random integer labels (one-hot encoded),
    builds a model via ``create_tabular_attention_model``, fits it for 10
    epochs with 20% of the data used for validation, and logs the accuracy
    measured on the full training array.

    Because the labels are drawn independently of the features, the reported
    accuracy only shows that the pipeline runs end to end; it is not a
    meaningful quality estimate.

    Returns:
        Tuple[keras.Model, keras.callbacks.History]: Trained model and training history.

    Example:
        ```python
        model, history = demonstrate_tabular_attention()
        ```
    """
    # Single source of truth for the synthetic data set's dimensions, so the
    # data shapes and the model configuration cannot drift apart.
    n_samples, n_features, n_classes = 1000, 20, 3

    # Synthetic inputs and one-hot targets (the model's loss is categorical
    # cross-entropy, which takes one-hot labels).
    X_train = np.random.random((n_samples, n_features))
    labels = np.random.randint(0, n_classes, (n_samples,))
    y_train = keras.utils.to_categorical(labels, n_classes)

    model = create_tabular_attention_model(input_dim=n_features, num_classes=n_classes)

    history = model.fit(
        X_train, y_train,
        validation_split=0.2,
        epochs=10,
        batch_size=32,
        verbose=1,
    )

    # Metrics are computed on the same array that was passed to ``fit``, so
    # this is a training-set figure, not a held-out one. The loss is not used.
    _, train_accuracy = model.evaluate(X_train, y_train, verbose=0)
    logger.info(f"Model accuracy: {train_accuracy:.4f}")
    return model, history
# Run demonstration
# model, history = demonstrate_tabular_attention()