Understand the fundamental concepts behind KerasFactory and how to effectively use its layers for modeling.
## 🎯 What is KerasFactory?
KerasFactory is a comprehensive collection of specialized layers designed primarily for tabular data (though not exclusively). Unlike traditional neural network layers built for images or sequences, KerasFactory layers are tailored to the unique characteristics of tabular data.
### Key Principles

- Tabular-First Design: Every layer is optimized for tabular data characteristics
- Production Ready: Battle-tested layers used in real-world applications
- Keras 3 Native: Built specifically for Keras 3 with modern best practices
- No TensorFlow Dependencies: Pure Keras implementation for maximum compatibility
## 📊 Understanding Tabular Data

### Characteristics of Tabular Data
```python
# Example tabular dataset
import pandas as pd
import numpy as np

# Sample tabular data
data = {
    'age': [25, 30, 35, 40, 45],
    'income': [50000, 75000, 90000, 110000, 130000],
    'education': ['Bachelor', 'Master', 'PhD', 'Bachelor', 'Master'],
    'city': ['NYC', 'SF', 'LA', 'Chicago', 'Boston']
}
df = pd.DataFrame(data)
print(df)
```
Key Characteristics:
- Mixed Data Types: Numerical and categorical features
- No Spatial Structure: Unlike images, features don't have spatial relationships
- Variable Importance: Some features are more important than others
- Missing Values: Common in real-world datasets
- Feature Interactions: Complex relationships between features
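These characteristics can be seen directly in a small pandas example (a generic illustration, not specific to KerasFactory):

```python
import pandas as pd
import numpy as np

# A mixed-type frame with a missing value, mirroring the characteristics above
df = pd.DataFrame({
    'age': [25, 30, np.nan, 40],             # numerical, with a missing value
    'income': [50000, 75000, 90000, 110000],  # numerical
    'education': ['Bachelor', 'Master', 'PhD', 'Bachelor'],  # categorical
})

# Mixed data types: numeric and object columns coexist
print(df.dtypes)

# Missing values are common and must be handled before training
print(df['age'].isna().sum())  # → 1

# Categorical features need encoding before a neural network can consume them
encoded = pd.get_dummies(df, columns=['education'])
print(encoded.columns.tolist())
```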
## 🏗️ Layer Architecture

### Layer Categories

#### 1. 🧠 Attention Layers

Focus on important features and relationships:
```python
from kerasfactory.layers import TabularAttention, ColumnAttention, RowAttention

# Tabular attention for feature relationships
attention = TabularAttention(num_heads=8, key_dim=64)

# Column attention for feature importance
col_attention = ColumnAttention(hidden_dim=64)

# Row attention for sample relationships
row_attention = RowAttention(hidden_dim=64)
```
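To build intuition, the core idea behind column-wise attention can be sketched in plain NumPy: score each feature, softmax the scores across columns, and reweight the inputs. This is an illustrative sketch only; the actual internals of `ColumnAttention` may differ:

```python
import numpy as np

def column_attention(x, W, b):
    """Score each feature, softmax across columns, and reweight the inputs."""
    scores = x @ W + b                                  # one score per feature
    scores -= scores.max(axis=-1, keepdims=True)        # numerical stability
    weights = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)
    return x * weights                                  # emphasize important columns

rng = np.random.default_rng(0)
x = rng.standard_normal((3, 5))      # batch of 3 samples, 5 features
W = 0.1 * rng.standard_normal((5, 5))
b = np.zeros(5)
out = column_attention(x, W, b)
print(out.shape)  # → (3, 5)
```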
#### 2. ⚙️ Preprocessing Layers

Handle data preparation and missing values:
```python
from kerasfactory.layers import (
    DifferentiableTabularPreprocessor,
    DateParsingLayer,
    DateEncodingLayer
)

# End-to-end preprocessing
preprocessor = DifferentiableTabularPreprocessor(
    imputation_strategy='learnable',
    normalization='learnable'
)

# Date handling
date_parser = DateParsingLayer()
date_encoder = DateEncodingLayer()
```
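Learnable imputation replaces each missing value with a trainable per-feature parameter rather than a fixed statistic such as the mean. A NumPy sketch of the forward pass (assumed behavior; the real layer's internals may differ):

```python
import numpy as np

def impute(x, learned_fill):
    """Replace NaNs with a per-feature fill value (trainable in the real layer)."""
    return np.where(np.isnan(x), learned_fill, x)

x = np.array([[1.0, np.nan],
              [np.nan, 4.0]])
fill = np.array([0.5, 2.0])   # in training, these would be optimized by gradient descent
print(impute(x, fill))        # → [[1.  2. ]
                              #    [0.5 4. ]]
```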
```python
from kerasfactory.layers import (
    GatedResidualNetwork,
    TransformerBlock,
    TabularMoELayer
)

# Gated residual network
grn = GatedResidualNetwork(units=64, dropout_rate=0.2)

# Transformer block
transformer = TransformerBlock(dim_model=64, num_heads=4)

# Mixture of experts
moe = TabularMoELayer(num_experts=4, expert_units=16)
```
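A gated residual network passes its input through a small dense transform, scales the result with a sigmoid gate, and adds a skip connection so the block can fall back to identity when the transform is unhelpful. A minimal NumPy sketch of one such forward pass (illustrative only; the exact architecture of `GatedResidualNetwork` may differ):

```python
import numpy as np

def grn_forward(x, W1, b1, W2, b2, Wg, bg):
    """Minimal gated residual block: dense -> ELU -> dense, sigmoid gate, skip connection."""
    a = x @ W1 + b1
    h = np.where(a > 0, a, np.exp(np.minimum(a, 0)) - 1)   # ELU activation
    h = h @ W2 + b2
    gate = 1.0 / (1.0 + np.exp(-(h @ Wg + bg)))            # element-wise sigmoid gate
    return x + gate * h                                     # gated residual connection

rng = np.random.default_rng(0)
units = 4
x = rng.standard_normal((2, units))
W1, W2, Wg = (0.1 * rng.standard_normal((units, units)) for _ in range(3))
b1 = b2 = bg = np.zeros(units)
y = grn_forward(x, W1, b1, W2, b2, Wg, bg)
print(y.shape)  # → (2, 4)
```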
#### 5. 🛠️ Utility Layers

Essential tools for data processing:
```python
from kerasfactory.layers import (
    CastToFloat32Layer,
    NumericalAnomalyDetection,
    FeatureCutout
)

# Type casting
cast_layer = CastToFloat32Layer()

# Anomaly detection
anomaly_detector = NumericalAnomalyDetection()

# Data augmentation
cutout = FeatureCutout(cutout_prob=0.1)
```
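Feature cutout is a training-time augmentation that randomly zeroes individual features, discouraging the model from over-relying on any single column. A NumPy sketch of the idea (assumed behavior of `FeatureCutout`):

```python
import numpy as np

def feature_cutout(x, cutout_prob=0.1, rng=None):
    """Randomly zero out features with probability `cutout_prob` (training-time only)."""
    if rng is None:
        rng = np.random.default_rng(0)
    mask = rng.random(x.shape) >= cutout_prob   # keep each feature with prob 1 - cutout_prob
    return x * mask

batch = np.ones((4, 8))
out = feature_cutout(batch, cutout_prob=0.5)
print(out.shape)  # → (4, 8); roughly half the entries are zeroed
```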
```python
# Fast configuration
layer = VariableSelection(
    hidden_dim=32,   # Smaller hidden dimension
    dropout=0.1      # Light dropout
)
```
## 🔍 Best Practices

### 1. Start Simple

Begin with basic layers and gradually add complexity:
```python
# Start with preprocessing
x = DifferentiableTabularPreprocessor()(inputs)

# Add feature selection
x = VariableSelection(hidden_dim=64)(x)

# Add attention
x = TabularAttention(num_heads=8, key_dim=64)(x)
```
### 2. Monitor Performance

Track training metrics and adjust accordingly:
```python
# Monitor during training
model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

# Use callbacks for monitoring
callbacks = [
    keras.callbacks.EarlyStopping(patience=10),
    keras.callbacks.ReduceLROnPlateau(factor=0.5)
]
```
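Early stopping halts training once the monitored metric has failed to improve for `patience` consecutive epochs. The logic behind `keras.callbacks.EarlyStopping` (with the default minimize-loss behavior) can be sketched in a few lines:

```python
def early_stop_epoch(val_losses, patience=10):
    """Return the 1-based epoch at which training would stop, or None if it runs to the end."""
    best = float('inf')
    wait = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:          # improvement: reset the patience counter
            best = loss
            wait = 0
        else:                    # no improvement this epoch
            wait += 1
            if wait >= patience:
                return epoch
    return None

# Loss plateaus after epoch 3; with patience=2, training stops at epoch 5
print(early_stop_epoch([1.0, 0.8, 0.7, 0.7, 0.7], patience=2))  # → 5
```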