Model Implementation Guide

# 🤖 Model Implementation Guide for KerasFactory

This guide outlines the complete process and best practices for implementing new models in the KerasFactory project. Follow the checklists to ensure your implementation meets all KerasFactory standards.

## 📋 Model Implementation Checklist

Use this checklist when implementing a new model. Check off each item as you complete it.

### Phase 1: Planning & Design
- [ ] Define Purpose: Clearly document what the model does and when to use it
- [ ] Review Architecture: Design the model architecture (layers, connections, data flow)
- [ ] Plan Layers: Identify which layers the model needs
  - [ ] Check if all required layers exist in `kerasfactory/layers/`
  - [ ] Plan to implement missing layers separately first
  - [ ] Prioritize reusability (create standalone layers, not embedded logic)
- [ ] Define Inputs/Outputs: Plan input and output specifications
- [ ] Document Algorithm: Write mathematical description or pseudo-code

### Phase 2: Layer Implementation (if needed)
IMPORTANT: Implement any missing layers FIRST as standalone, reusable components.
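
The Phase 1 check for existing layers can be scripted. The sketch below is only an illustration and not part of KerasFactory: `find_missing_layers` is a hypothetical helper, `FeatureMixing` is a made-up layer name, and the code assumes each layer lives in its own module named after the class (mirroring the model file naming rule in this guide, which may not match the real layout).

```python
from pathlib import Path


def find_missing_layers(layers_dir: Path, required: list[str]) -> list[str]:
    """Return the required layer names that have no matching module file.

    Assumes one module per layer, named ``<LayerName>.py`` (an assumption
    made for this sketch, not a documented KerasFactory convention).
    """
    existing = {path.stem for path in layers_dir.glob("*.py")}
    return [name for name in required if name not in existing]


if __name__ == "__main__":
    # Hypothetical layers planned for a new model.
    planned = ["TemporalMixing", "FeatureMixing"]
    missing = find_missing_layers(Path("kerasfactory/layers"), planned)
    print("Layers to implement first:", missing or "none")
```

Any name the helper reports would be implemented and tested as a standalone layer before work on the model itself begins.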

For each missing layer:
- [ ] Follow the Layer Implementation Checklist from `layers_implementation_guide.md`
- [ ] Implement layer code
- [ ] Write comprehensive tests
- [ ] Create documentation
- [ ] Update API references
- [ ] Run all tests and linting

### Phase 3: Implementation - Core Model Code
- [ ] Create File: Create `kerasfactory/models/YourModelName.py` following naming conventions
- [ ] Add Module Docstring: Document the module's purpose
- [ ] Implement Pure Keras 3: Use only Keras operations (no TensorFlow)
- [ ] Apply `@register_keras_serializable`: Decorate class with `@register_keras_serializable(package="kerasfactory.models")`
- [ ] Inherit from BaseModel: Extend `kerasfactory.models._base.BaseModel`
- [ ] Implement `__init__`:
  - [ ] Set private attributes first (`self._param = param`)
  - [ ] Validate parameters (in `__init__` or `_validate_params`)
  - [ ] Set public attributes (`self.param = self._param`)
  - [ ] Call `super().__init__(name=name, **kwargs)` AFTER setting public attributes
- [ ] Implement `_validate_params`: Add parameter validation logic
- [ ] Implement `build()`: Initialize all layers and sublayers
- [ ] Implement `call()`: Implement forward pass with Keras operations only
- [ ] Implement `get_config()`: Return all constructor parameters
- [ ] Add Type Hints: All methods and parameters have proper type annotations
- [ ] Add Logging: Use loguru for debug messages
- [ ] Add Comprehensive Docstring: Google-style docstring with:
  - [ ] Description
  - [ ] Parameters
  - [ ] Input/output shapes
  - [ ] Usage examples
  - [ ] References (if applicable)

### Phase 4: Unit Tests
- [ ] Create Test File: Create `tests/models/test__YourModelName.py`
- [ ] Test Initialization:
  - [ ] Default parameters
  - [ ] Custom parameters
  - [ ] Invalid parameters (should raise errors)
- [ ] Test Model Building: Build with different input shapes
- [ ] Test Output Shape: Verify output shapes match expected values
- [ ] Test Output Type: Verify output has the correct dtype
- [ ] Test Different Batch Sizes: Test with various batch dimensions
- [ ] Test Forward Pass: Model produces valid outputs
- [ ] Test Training Loop:
  - [ ] Can compile the model
  - [ ] Can train for multiple epochs
  - [ ] Loss decreases over training
- [ ] Test Serialization:
  - [ ] `get_config()` returns correct dict
  - [ ] `from_config()` recreates model correctly
  - [ ] `keras.saving.serialize_keras_object()` works
  - [ ] `keras.saving.deserialize_keras_object()` works
  - [ ] Model can be saved/loaded (`.keras` format)
  - [ ] Weights can be saved/loaded (`.weights.h5` format, the suffix Keras 3 requires for `save_weights()`)
  - [ ] Predictions consistent after loading
- [ ] Test Deterministic Output: Same input produces same output (with same seed)
- [ ] Test Layer Integration: All constituent layers work correctly together
- [ ] Test Prediction: Model can make predictions on new data
- [ ] All Tests Pass: Run `pytest tests/models/test__YourModelName.py -v`

### Phase 5: Documentation
- [ ] Create Documentation File: Create `docs/models/your-model-name.md`
- [ ] Follow Template: Use structure from similar model in `docs/models/`
- [ ] Include Comprehensive Sections:
  - [ ] Overview and problem it solves
  - [ ] Architecture overview with diagram
  - [ ] Key features and innovations
  - [ ] Input/output specifications
  - [ ] Parameters and their impact
  - [ ] Quick start example
  - [ ] Advanced usage (custom training loop, transfer learning, etc.)
  - [ ] Performance characteristics and benchmarks
  - [ ] Comparison with related architectures
  - [ ] Training best practices
  - [ ] Common issues & troubleshooting
  - [ ] Integration with other KerasFactory components
  - [ ] References and citations
- [ ] Add Code Examples: Real, working examples (training, evaluation, prediction)
- [ ] Include Mathematical Details: Equations, loss functions, optimization details
- [ ] Add Visual Aids: Architecture diagrams, Mermaid diagrams, flowcharts
- [ ] Include Reproducibility Info: Random seeds, hardware requirements, etc.

### Phase 6: Jupyter Notebook Example
- [ ] Create Notebook: Create `notebooks/your_model_name_demo.ipynb` or `your_model_name_end_to_end_demo.ipynb`
- [ ] Include Sections:
  - [ ] Title and description
  - [ ] Setup and imports
  - [ ] Data generation/loading
  - [ ] Data exploration/visualization
  - [ ] Model creation and architecture overview
  - [ ] Model training with visualization
  - [ ] Model evaluation
  - [ ] Predictions and visualization
  - [ ] Performance comparison (if applicable)
  - [ ] Model serialization and loading
  - [ ] Best practices and tips
  - [ ] Summary and conclusions
- [ ] Add Visualizations:
  - [ ] Training curves (loss, metrics)
  - [ ] Predictions vs actual
  - [ ] Performance metrics
  - [ ] Model comparisons (if applicable)
- [ ] Include Output: Run all cells to verify they work
- [ ] Use Interactive Plots: Plotly for better interactivity

### Phase 7: Integration & Updates
- [ ] Update Imports: Add to `kerasfactory/models/__init__.py`
  - [ ] Add import statement
  - [ ] Add model name to `__all__` list
- [ ] Update API Documentation: Add entry to `docs/api/models.md`
  - [ ] Add model name and description
  - [ ] Include autodoc reference (`kerasfactory.models.YourModelName`)
  - [ ] List key features
  - [ ] Add use case recommendations
- [ ] Update Models Overview: If exists, add to `docs/models_overview.md` or similar
- [ ] Update Main README: If it's a significant model
  - [ ] Add to feature list
  - [ ] Link to documentation
- [ ] Update Tutorials: If introducing new concepts
- [ ] Update Data Analyzer: If applicable, add to `kerasfactory/utils/data_analyzer.py`

### Phase 8: Quality Assurance
- [ ] Run All Tests:
  - [ ] Model tests pass: `pytest tests/models/test__YourModelName.py -v`
  - [ ] All layer tests pass: `pytest tests/layers/ -v`
  - [ ] No regressions: `pytest tests/ -v`
- [ ] Pre-commit Hooks: Run `pre-commit run --all-files`
  - [ ] Black formatting passes
  - [ ] Ruff linting passes
  - [ ] No unused imports or variables
  - [ ] Proper type hints
  - [ ] Docstring formatting
- [ ] Documentation Build: `mkdocs serve` builds without errors
  - [ ] No broken links
  - [ ] All images load correctly
  - [ ] Code examples render properly
- [ ] Notebook Execution: Run full notebook end-to-end
  - [ ] All cells execute without errors
  - [ ] Visualizations render correctly
  - [ ] No performance issues (reasonable execution time)
- [ ] Code Review: Request code review from team
- [ ] Integration Test: Test model in real-world scenario
- [ ] Performance Test: Verify model meets performance requirements

---

## Key Requirements

### ✅ Keras 3 Only
All model implementations MUST use only Keras 3 operations and layers. NO TensorFlow dependencies are allowed in model implementations.
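
One way to enforce this rule mechanically is to scan model modules for TensorFlow imports. The sketch below uses only the standard-library `ast` module. `find_tensorflow_imports` is a hypothetical helper, not an existing KerasFactory utility, and it only detects static `import` statements (dynamic imports would go unnoticed).

```python
import ast


def find_tensorflow_imports(source: str) -> list[tuple[int, str]]:
    """Return (line number, module name) for each TensorFlow import found."""
    hits: list[tuple[int, str]] = []
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.Import):
            modules = [alias.name for alias in node.names]
        elif isinstance(node, ast.ImportFrom):
            modules = [node.module or ""]  # module is None for "from . import x"
        else:
            continue
        for module in modules:
            if module == "tensorflow" or module.startswith("tensorflow."):
                hits.append((node.lineno, module))
    return sorted(hits)


if __name__ == "__main__":
    sample = "import keras\nimport tensorflow as tf\nfrom keras import ops\n"
    print(find_tensorflow_imports(sample))  # prints [(2, 'tensorflow')]
```

A check like this could run over the files under `kerasfactory/models/` in CI or a pre-commit hook. Test files and notebooks would be left out of the scan, since this guide permits TensorFlow there for validation.
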
- Allowed: `keras.layers`, `keras.ops`, `kerasfactory.layers`, `kerasfactory.models`
- NOT Allowed: `tensorflow.python.*`, `tf.nn.*` (use `keras.ops.*` instead)
- Exception: TensorFlow can ONLY be used in test files and notebooks for validation

### ✅ Reusable Components
Avoid embedding layer logic directly in models. Create standalone, reusable layers first:
- Good: Implement `TemporalMixing` as a layer, use it in the `TSMixer` model
- Bad: Implement temporal mixing logic directly in the model

### ✅ Proper Inheritance
- Models must inherit from `kerasfactory.models._base.BaseModel`
- Layers must inherit from `kerasfactory.layers._base_layer.BaseLayer`

### ✅ Type Annotations (Python 3.12+)
Use modern type hints with the union operator:
```python
param: int | float = 0.1  # Instead of Union[int, float]
```

### ✅ Comprehensive Documentation
Every model needs extensive documentation covering usage, architecture, and best practices.

---

## Implementation Pattern

Follow this pattern for implementing models:

```python
"""
Module docstring describing the model's purpose and functionality.
"""

from typing import Any
from loguru import logger
from keras import layers, ops
from keras import KerasTensor
from keras.saving import register_keras_serializable
from kerasfactory.models._base import BaseModel
from kerasfactory.layers import YourCustomLayer  # Use existing layers

@register_keras_serializable(package="kerasfactory.models")
class YourCustomModel(BaseModel):
    """Comprehensive model description.

    This model implements [algorithm/architecture] for [task].
    It combines multiple layers to [describe what it does].

    Args:
        param1: Description with type and default.
        param2: Description with type and default.
        name: Optional name for the model.

    Input shape:
        `(batch_size, ...)` - Description of input.

    Output shape:
        `(batch_size, ...)` - Description of output.

    Example:
        import keras
        from kerasfactory.models import YourCustomModel

        # Create model
        model = YourCustomModel(param1=value1, param2=value2)
        model.compile(optimizer='adam', loss='mse')

        # Train
        model.fit(X_train, y_train, epochs=10)

        # Predict
        predictions = model.predict(X_test)

    References:
        - Author et al. (Year). "Paper Title". Journal.
    """

    def __init__(
        self,
        param1: int = 32,
        param2: float = 0.1,
        name: str | None = None,
        **kwargs: Any
    ) -> None:
        # Set private attributes
        self._param1 = param1
        self._param2 = param2

        # Validate parameters
        self._validate_params()

        # Set public attributes BEFORE super().__init__()
        self.param1 = self._param1
        self.param2 = self._param2

        # Call parent's __init__
        super().__init__(name=name, **kwargs)

    def _validate_params(self) -> None:
        """Validate model parameters."""
        if self._param1 < 1:
            raise ValueError(f"param1 must be >= 1, got {self._param1}")
        if not (0 <= self._param2 <= 1):
            raise ValueError(f"param2 must be in [0, 1], got {self._param2}")

    def build(self, input_shape: tuple[int, ...] | list[tuple[int, ...]]) -> None:
        """Build model with given input shape(s).

        Args:
            input_shape: Tuple(s) of integers defining input shape(s).
        """
        # Initialize all layers
        self.layer1 = YourCustomLayer(self._param1)
        self.layer2 = layers.Dense(self._param1)
        self.output_layer = layers.Dense(10)  # or task-specific output

        logger.debug(f"Building {self.__class__.__name__} with params: "
                     f"param1={self.param1}, param2={self.param2}")
        super().build(input_shape)

    def call(self, inputs: KerasTensor, training: bool | None = None) -> KerasTensor:
        """Forward pass.

        Args:
            inputs: Input tensor(s).
            training: Whether in training mode.

        Returns:
            Model output tensor.
        """
        # Forward pass through layers
        x = self.layer1(inputs, training=training)
        x = self.layer2(x)
        output = self.output_layer(x)
        return output

    def get_config(self) -> dict[str, Any]:
        """Returns model configuration.

        Returns:
            Dictionary with model configuration.
        """
        config = super().get_config()
        config.update({
            "param1": self.param1,
            "param2": self.param2,
        })
        return config
```

---

## Model Serialization & Loading

Ensure your model can be saved and loaded correctly. Note that Keras 3 requires weight files passed to `save_weights()` to end in `.weights.h5`, and that a freshly constructed model must be built (for example by calling it once) before weights are loaded, because its layers are created in `build()`:

```python
import tempfile

import keras
import numpy as np

from kerasfactory.models import YourCustomModel

# X_train, y_train, and X_test are assumed to be NumPy arrays defined elsewhere.

# Create and train model
model = YourCustomModel(param1=32, param2=0.1)
model.compile(optimizer='adam', loss='mse', metrics=['mae'])
model.fit(X_train, y_train, epochs=10, verbose=0)

# Save full model
with tempfile.TemporaryDirectory() as tmpdir:
    # Save with architecture
    model.save(f'{tmpdir}/model.keras')

    # Load full model
    loaded_model = keras.models.load_model(f'{tmpdir}/model.keras')

    # Verify predictions are identical
    pred1 = model.predict(X_test)
    pred2 = loaded_model.predict(X_test)
    np.testing.assert_allclose(pred1, pred2, rtol=1e-5)

    # Save only weights (the filename must end in .weights.h5)
    model.save_weights(f'{tmpdir}/model.weights.h5')

    # Load weights into new model; call it once first so its layers are built
    new_model = YourCustomModel(param1=32, param2=0.1)
    new_model(X_test[:1])
    new_model.load_weights(f'{tmpdir}/model.weights.h5')
```

---

## Testing Template

Create comprehensive tests following this template:

```python
import tempfile
import unittest

import keras
import numpy as np

from kerasfactory.models import YourCustomModel

class TestYourCustomModel(unittest.TestCase):
    """Test suite for YourCustomModel."""

    def setUp(self) -> None:
        """Set up test fixtures."""
        self.model = YourCustomModel(param1=32, param2=0.1)
        self.model.compile(optimizer='adam', loss='mse', metrics=['mae'])

        # Create sample data
        self.X_train = np.random.randn(100, 20).astype(np.float32)
        self.y_train = np.random.randn(100, 10).astype(np.float32)
        self.X_test = np.random.randn(20, 20).astype(np.float32)
        self.y_test = np.random.randn(20, 10).astype(np.float32)

    def test_initialization(self) -> None:
        """Test model initialization."""
        self.assertEqual(self.model.param1, 32)
        self.assertEqual(self.model.param2, 0.1)

    def test_invalid_parameters(self) -> None:
        """Test invalid parameter handling."""
        with self.assertRaises(ValueError):
            YourCustomModel(param1=-1)

    def test_forward_pass(self) -> None:
        """Test forward pass."""
        output = self.model(self.X_test)
        self.assertEqual(output.shape, (20, 10))

    def test_training(self) -> None:
        """Test model training."""
        history = self.model.fit(
            self.X_train, self.y_train,
            epochs=2, batch_size=32, verbose=0
        )

        # Verify training ran: one loss value is recorded per epoch
        self.assertEqual(len(history.history['loss']), 2)

    def test_serialization(self) -> None:
        """Test model serialization."""
        config = self.model.get_config()
        new_model = YourCustomModel.from_config(config)
        new_model.compile(optimizer='adam', loss='mse')
        self.assertEqual(new_model.param1, self.model.param1)
        self.assertEqual(new_model.param2, self.model.param2)

        # from_config() restores the architecture, not the weights: build both
        # models, then copy the weights before comparing outputs.
        output1 = self.model(self.X_test)
        new_model(self.X_test)
        new_model.set_weights(self.model.get_weights())
        output2 = new_model(self.X_test)

        np.testing.assert_allclose(
            keras.ops.convert_to_numpy(output1),
            keras.ops.convert_to_numpy(output2),
            rtol=1e-5,
        )

    def test_save_load(self) -> None:
        """Test model save and load."""
        with tempfile.TemporaryDirectory() as tmpdir:
            model_path = f'{tmpdir}/model.keras'
            self.model.save(model_path)

            loaded_model = keras.models.load_model(model_path)
            loaded_model.compile(optimizer='adam', loss='mse')

            pred1 = self.model.predict(self.X_test, verbose=0)
            pred2 = loaded_model.predict(self.X_test, verbose=0)

            np.testing.assert_allclose(pred1, pred2, rtol=1e-5)

if __name__ == "__main__":
    unittest.main()
```

---

## Common Pitfalls & Solutions

| Pitfall | Problem | Solution |
|---------|---------|----------|
| Embedded layer logic | Code not reusable | Create standalone layers first |
| TensorFlow dependencies | Using `tf.*` operations | Use `keras.ops.*` and `kerasfactory.layers` |
| Wrong inheritance | Type errors | Inherit from `BaseModel` |
| Incomplete serialization | Cannot save/load | Include all parameters in `get_config()` |
| Missing layer instantiation | Runtime errors | Initialize all layers in `build()` |
| Wrong attribute order | `AttributeError` | Set public attributes BEFORE `super().__init__()` |
| Insufficient tests | Bugs in production | Write comprehensive tests |
| Inadequate documentation | Users confused | Write detailed guide with examples |
| No notebook example | Hard to get started | Create end-to-end demo notebook |

---

## Next Steps

After implementing and testing your model:

1. Submit for Review: Create a pull request with your implementation
2. Address Feedback: Update based on review comments
3. Final Testing: Run full test suite one more time
4. Merge: Once approved, merge to main branch
5. Announce: Notify team about new model availability
6. Update README: Add to main README and features list

---

## Related Resources

- Layer Implementation Guide - Detailed layer implementation guide
- API Reference - Models - Model API documentation
- Contributing Guidelines - Project contribution guidelines
- Keras 3 Documentation - Keras 3 API reference