TensorFlow API Reference 📖

Complete API reference for MLPotion's TensorFlow components.

Auto-Generated Documentation

This page is automatically populated with API documentation from the source code.

Extensibility

These components are built using protocol-based design, making MLPotion easy to extend. Want to add new data sources, training methods, or integrations? See the Contributing Guide.

Data Loading

mlpotion.frameworks.tensorflow.data.loaders

TensorFlow data loaders.

Classes

CSVDataLoader

CSVDataLoader(
    file_pattern: str,
    batch_size: int = 32,
    column_names: list[str] | None = None,
    label_name: str | None = None,
    map_fn: Callable[[dict[str, Any]], dict[str, Any]]
    | None = None,
    config: dict[str, Any] | None = None,
) -> None

Bases: DataLoader[tf.data.Dataset]

Load CSV files into TensorFlow datasets.

This class provides a convenient wrapper around tf.data.experimental.make_csv_dataset, adding validation, logging, and configuration management. It handles file pattern matching, column selection, and label separation.

Attributes:

    file_pattern (str): Glob pattern matching the CSV files to load.
    batch_size (int): Number of samples per batch.
    column_names (list[str] | None): Specific columns to load. If None, all columns are loaded.
    label_name (str | None): Name of the column to use as the label. If None, no labels are returned.
    map_fn (Callable | None): Optional function to map over the dataset (e.g., for preprocessing).
    config (dict | None): Additional configuration passed to make_csv_dataset.

Example
from mlpotion.frameworks.tensorflow import CSVDataLoader

# Simple usage
loader = CSVDataLoader(
    file_pattern="data/train_*.csv",
    label_name="target_class",
    batch_size=64,
    config={"num_epochs": 5, "shuffle": True}
)

dataset = loader.load()

# Iterate
for features, labels in dataset:
    print(features['some_column'].shape)
    break
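
The map_fn hook attaches preprocessing directly to the loader. Below is a minimal sketch matching the declared map_fn signature (dict of features in, dict of features out); the column name feature_a is an assumed example, not part of the API.

import tensorflow as tf
from mlpotion.frameworks.tensorflow import CSVDataLoader

# Hypothetical preprocessing step: "feature_a" is an assumed column name.
def scale_features(features):
    features["feature_a"] = tf.cast(features["feature_a"], tf.float32) / 255.0
    return features

loader = CSVDataLoader(
    file_pattern="data/train_*.csv",
    batch_size=64,
    map_fn=scale_features,
)
dataset = loader.load()
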
Source code in mlpotion/frameworks/tensorflow/data/loaders.py
def __init__(
    self,
    file_pattern: str,
    batch_size: int = 32,
    column_names: list[str] | None = None,
    label_name: str | None = None,
    map_fn: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
    config: dict[str, Any] | None = None,
) -> None:
    self.file_pattern = file_pattern
    self.column_names = column_names
    self.label_name = label_name
    self.batch_size = batch_size
    self.map_fn = map_fn

    # set default config
    _default_config = {"ignore_errors": True, "num_epochs": 1}
    self.config: dict[str, Any] = dict(config or _default_config)

    # Extract and validate num_epochs *once* so we don't risk duplicating kwargs
    self.num_epochs = self._extract_and_validate_num_epochs()

    self._validate_files_exist()
    self._validate_finite_dataset()

    logger.info(
        "{class_name} initialized with attrs: {attrs}",
        class_name=self.__class__.__name__,
        attrs=vars(self),
    )
Functions
load
load() -> tf.data.Dataset

Load CSV files into a TensorFlow dataset.

Returns:

    tf.data.Dataset: A tf.data.Dataset yielding tuples of (features, labels) if label_name is provided, or just features (dict) if not.

Raises:

    DataLoadingError: If no files match the pattern, or if num_epochs is invalid.

Source code in mlpotion/frameworks/tensorflow/data/loaders.py
@trycatch(
    error=DataLoadingError,
    success_msg="✅ Successfully loaded dataset",
)
def load(self) -> tf.data.Dataset:
    """Load CSV files into a TensorFlow dataset.

    Returns:
        tf.data.Dataset: A `tf.data.Dataset` yielding tuples of `(features, labels)` if `label_name`
        is provided, or just `features` (dict) if not.

    Raises:
        DataLoadingError: If no files match the pattern, or if `num_epochs` is invalid.
    """
    dataset = tf.data.experimental.make_csv_dataset(
        file_pattern=self.file_pattern,
        batch_size=self.batch_size,
        label_name=self.label_name,
        column_names=self.column_names,
        num_epochs=self.num_epochs,  # extracted and validated
        **self.config,
    )

    if self.map_fn:
        logger.info("Applying mapping function to dataset")
        dataset = dataset.map(self.map_fn)

    # Attach metadata for CSV materializer
    # This allows ZenML to efficiently serialize/deserialize the dataset
    # by storing just the configuration instead of the actual data
    dataset._csv_config = {
        "file_pattern": self.file_pattern,
        "batch_size": self.batch_size,
        "label_name": self.label_name,
        "column_names": self.column_names,
        "num_epochs": self.num_epochs,
        "extra_params": self.config,
        "transformations": [],  # Will be populated by optimizer if used
    }

    return dataset

RecordDataLoader

RecordDataLoader(
    file_pattern: str,
    batch_size: int = 32,
    column_names: list[str] | None = None,
    label_name: str | None = None,
    map_fn: Callable[[tf.Tensor], Any] | None = None,
    element_spec_json: str | dict[str, Any] | None = None,
    config: dict[str, Any] | None = None,
) -> None

Bases: DataLoader[tf.data.Dataset]

Loads TFRecord files into a tf.data.Dataset.

This class facilitates loading data from TFRecord files, which is the recommended format for high-performance TensorFlow pipelines. It supports parsing examples, handling nested structures via element_spec, and applying common dataset optimizations.

Attributes:

    file_pattern (str): Glob pattern matching the TFRecord files.
    batch_size (int): Number of samples per batch.
    column_names (list[str] | None): Specific feature keys to extract.
    label_name (str | None): Key of the label feature.
    map_fn (Callable | None): Optional function to map over the dataset.
    element_spec_json (str | dict | None): JSON or dict describing the data structure (optional).
    config (dict | None): Configuration for reading (e.g., num_parallel_reads, compression_type).

Example
import tensorflow as tf
from mlpotion.frameworks.tensorflow import RecordDataLoader

loader = RecordDataLoader(
    file_pattern="data/records/*.tfrecord",
    batch_size=128,
    label_name="label",
    config={
        "compression_type": "GZIP",
        "num_parallel_reads": tf.data.AUTOTUNE
    }
)

dataset = loader.load()
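
The load() implementation further down also reads several optional config keys (repeat_count, shuffle_buffer_size, drop_remainder, prefetch_buffer_size, buffer_size). A sketch combining them:

import tensorflow as tf
from mlpotion.frameworks.tensorflow import RecordDataLoader

loader = RecordDataLoader(
    file_pattern="data/records/*.tfrecord",
    batch_size=128,
    config={
        "compression_type": "GZIP",
        "num_parallel_reads": tf.data.AUTOTUNE,
        "repeat_count": 2,              # repeat the record files twice
        "shuffle_buffer_size": 10_000,  # shuffle before batching
        "drop_remainder": True,         # drop the final partial batch
        "prefetch_buffer_size": tf.data.AUTOTUNE,
    },
)
dataset = loader.load()
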
Source code in mlpotion/frameworks/tensorflow/data/loaders.py
def __init__(
    self,
    file_pattern: str,
    batch_size: int = 32,
    column_names: list[str] | None = None,
    label_name: str | None = None,
    map_fn: Callable[[tf.Tensor], Any] | None = None,
    element_spec_json: str | dict[str, Any] | None = None,
    config: dict[str, Any] | None = None,
) -> None:
    self.file_pattern = file_pattern
    self.batch_size = batch_size
    self.map_fn = map_fn
    self.element_spec_json = element_spec_json
    self.column_names = column_names
    self.label_name = label_name

    # set config
    self.config = config or {}

    # validate files exist
    self._validate_files_exist()

    logger.info(
        "{class_name} initialized with attrs: {attrs}",
        class_name=self.__class__.__name__,
        attrs=vars(self),
    )
Functions
load
load() -> tf.data.Dataset

Load TFRecord files into a tf.data.Dataset.

Returns:

    tf.data.Dataset: Parsed and optionally mapped dataset of (features, label) or features only.

Raises:

    DataLoadingError: On failure.

Source code in mlpotion/frameworks/tensorflow/data/loaders.py
@trycatch(
    error=DataLoadingError,
    success_msg="✅ Successfully loaded TFRecord dataset",
)
def load(self) -> tf.data.Dataset:
    """Load TFRecord files into a tf.data.Dataset.

    Returns:
        tf.data.Dataset: Parsed and optionally mapped dataset of (features, label) or features only.
    Raises:
        DataLoadingError: on failure.
    """
    filenames = self._get_files_matching_pattern()

    ds = tf.data.TFRecordDataset(
        filenames=filenames,
        compression_type=self.config.get("compression_type", ""),
        buffer_size=self.config.get("buffer_size", None),
        num_parallel_reads=self.config.get("num_parallel_reads", tf.data.AUTOTUNE),
    )

    # Optionally repeat
    if "repeat_count" in self.config:
        ds = ds.repeat(self.config["repeat_count"])

    # Apply column/label selection
    ds = ds.map(
        self._apply_column_label_selection,
        num_parallel_calls=self.config.get("num_parallel_reads", tf.data.AUTOTUNE),
    )

    # Shuffle if requested
    if "shuffle_buffer_size" in self.config:
        ds = ds.shuffle(self.config["shuffle_buffer_size"])

    # Batch
    ds = ds.batch(
        self.batch_size, drop_remainder=self.config.get("drop_remainder", False)
    )

    # Prefetch
    ds = ds.prefetch(self.config.get("prefetch_buffer_size", tf.data.AUTOTUNE))

    # Apply mapping function
    if self.map_fn:
        logger.info("Applying mapping function to dataset")
        ds = ds.map(self.map_fn)

    return ds

mlpotion.frameworks.tensorflow.data.optimizers

TensorFlow dataset optimization.

Classes

DatasetOptimizer

DatasetOptimizer(
    batch_size: int = 32,
    shuffle_buffer_size: int | None = None,
    prefetch: bool = True,
    cache: bool = False,
) -> None

Bases: DatasetOptimizerProtocol[tf.data.Dataset]

Optimize TensorFlow datasets for training performance.

This class applies a standard set of performance optimizations to a tf.data.Dataset: caching, shuffling, batching, and prefetching. These are critical for preventing data loading bottlenecks during training.

Attributes:

    batch_size (int): The number of samples per batch.
    shuffle_buffer_size (int | None): Size of the shuffle buffer. If None, shuffling is disabled.
    prefetch (bool): Whether to prefetch data (uses tf.data.AUTOTUNE).
    cache (bool): Whether to cache the dataset in memory.

Example
from mlpotion.frameworks.tensorflow import DatasetOptimizer

# Create optimizer
optimizer = DatasetOptimizer(
    batch_size=32,
    shuffle_buffer_size=1000,
    cache=True,
    prefetch=True
)

# Apply to a raw dataset
optimized_dataset = optimizer.optimize(raw_dataset)

Initialize dataset optimizer.

Parameters:

    batch_size (int, default 32): Batch size.
    shuffle_buffer_size (int | None, default None): Buffer size for shuffling (None = no shuffle).
    prefetch (bool, default True): Whether to prefetch batches.
    cache (bool, default False): Whether to cache dataset in memory.
Source code in mlpotion/frameworks/tensorflow/data/optimizers.py
def __init__(
    self,
    batch_size: int = 32,
    shuffle_buffer_size: int | None = None,
    prefetch: bool = True,
    cache: bool = False,
) -> None:
    """Initialize dataset optimizer.

    Args:
        batch_size: Batch size
        shuffle_buffer_size: Buffer size for shuffling (None = no shuffle)
        prefetch: Whether to prefetch batches
        cache: Whether to cache dataset in memory
    """
    self.batch_size = batch_size
    self.shuffle_buffer_size = shuffle_buffer_size
    self.prefetch = prefetch
    self.cache = cache
Functions
from_config classmethod
from_config(
    config: DataOptimizationConfig,
) -> DatasetOptimizer

Create optimizer from configuration.

Parameters:

    config (DataOptimizationConfig, required): Optimization configuration.

Returns:

    DatasetOptimizer: Configured optimizer instance.
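
A minimal usage sketch; config is assumed to be a DataOptimizationConfig instance created elsewhere (its import path is not shown on this page), and from_config simply forwards its fields to the constructor.

from mlpotion.frameworks.tensorflow import DatasetOptimizer

# `config` is an existing DataOptimizationConfig with batch_size,
# shuffle_buffer_size, prefetch and cache fields.
optimizer = DatasetOptimizer.from_config(config)
optimized_dataset = optimizer.optimize(raw_dataset)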

Source code in mlpotion/frameworks/tensorflow/data/optimizers.py
@classmethod
def from_config(cls, config: DataOptimizationConfig) -> "DatasetOptimizer":
    """Create optimizer from configuration.

    Args:
        config: Optimization configuration

    Returns:
        Configured optimizer instance
    """
    return cls(
        batch_size=config.batch_size,
        shuffle_buffer_size=config.shuffle_buffer_size,
        prefetch=config.prefetch,
        cache=config.cache,
    )
optimize
optimize(dataset: tf.data.Dataset) -> tf.data.Dataset

Optimize dataset for training.

Applies optimizations in the following order:

    1. Cache: Caches data in memory (if enabled).
    2. Shuffle: Randomizes data order (if shuffle_buffer_size is set).
    3. Batch: Groups data into batches.
    4. Prefetch: Prepares the next batch while the current one is being processed.

Parameters:

    dataset (tf.data.Dataset, required): The input tf.data.Dataset.

Returns:

    tf.data.Dataset: The optimized dataset pipeline.
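
For reference, the ordering above corresponds roughly to the following hand-written tf.data chain (a sketch, not the library's exact implementation):

import tensorflow as tf

def optimize_by_hand(dataset: tf.data.Dataset) -> tf.data.Dataset:
    # Same order DatasetOptimizer uses: cache -> shuffle -> batch -> prefetch.
    dataset = dataset.cache()
    dataset = dataset.shuffle(buffer_size=1000, reshuffle_each_iteration=True)
    dataset = dataset.batch(32)
    return dataset.prefetch(tf.data.AUTOTUNE)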

Source code in mlpotion/frameworks/tensorflow/data/optimizers.py
def optimize(self, dataset: tf.data.Dataset) -> tf.data.Dataset:
    """Optimize dataset for training.

    Applies optimizations in the following order:
    1. **Cache**: Caches data in memory (if enabled).
    2. **Shuffle**: Randomizes data order (if `shuffle_buffer_size` is set).
    3. **Batch**: Groups data into batches.
    4. **Prefetch**: Prepares the next batch while the current one is being processed.

    Args:
        dataset: The input `tf.data.Dataset`.

    Returns:
        tf.data.Dataset: The optimized dataset pipeline.
    """
    logger.info("Applying dataset optimizations...")

    # Track transformations for CSV materializer
    transformations = []
    if hasattr(dataset, "_csv_config"):
        transformations = dataset._csv_config.get("transformations", [])

    # Cache first (before shuffling/batching)
    if self.cache:
        logger.info("Caching dataset in memory")
        dataset = dataset.cache()
        # Note: cache() doesn't need to be recorded as it's a performance optimization

    # Shuffle before batching
    if self.shuffle_buffer_size:
        logger.info(f"Shuffling with buffer size {self.shuffle_buffer_size}")
        dataset = dataset.shuffle(
            buffer_size=self.shuffle_buffer_size,
            reshuffle_each_iteration=True,
        )
        transformations.append(
            {
                "type": "shuffle",
                "params": {"buffer_size": self.shuffle_buffer_size},
            }
        )

    # Batch
    logger.info(f"Batching with size {self.batch_size}")
    dataset = dataset.batch(self.batch_size)
    transformations.append(
        {
            "type": "batch",
            "params": {"batch_size": self.batch_size},
        }
    )

    # Prefetch last for best performance
    if self.prefetch:
        logger.info("Prefetching with AUTOTUNE")
        dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE)
        transformations.append(
            {
                "type": "prefetch",
                "params": {"buffer_size": "AUTOTUNE"},
            }
        )

    # Preserve CSV config if it exists
    if hasattr(dataset, "_csv_config"):
        dataset._csv_config["transformations"] = transformations

    return dataset

Training

mlpotion.frameworks.tensorflow.training.trainers

TensorFlow model trainers.

This module re-exports the Keras ModelTrainer implementation, as TensorFlow 2.x uses Keras as its high-level API.

Classes

ModelTrainer dataclass

Bases: ModelTrainerProtocol[Model, Sequence]

Generic trainer for Keras 3 models.

This class implements the ModelTrainerProtocol for Keras models, providing a standardized interface for training. It wraps the standard model.fit() method but adds flexibility and consistency checks.

It supports:

    - Automatic model compilation if compile_params are provided.
    - Handling of various data formats (tuples, dicts, generators).
    - Standardized return format (dictionary of history metrics).

Example
import keras
import numpy as np
from mlpotion.frameworks.keras import ModelTrainer

# Prepare data
X_train = np.random.rand(100, 10)
y_train = np.random.randint(0, 2, 100)

# Define model
model = keras.Sequential([
    keras.layers.Dense(1, activation='sigmoid')
])

# Initialize trainer
trainer = ModelTrainer()

# Train
history = trainer.train(
    model=model,
    data=(X_train, y_train),
    compile_params={
        "optimizer": "adam",
        "loss": "binary_crossentropy",
        "metrics": ["accuracy"]
    },
    fit_params={
        "epochs": 5,
        "batch_size": 32,
        "verbose": 1
    }
)

print(history['loss'])
Functions
train
train(
    model: Model,
    dataset: Any,
    config: ModelTrainingConfig,
    validation_dataset: Any | None = None,
) -> TrainingResult[Model]

Train a Keras model using the provided dataset and configuration.

Parameters:

    model (Model, required): The Keras model to train.
    dataset (Any, required): The training data. Can be a tuple (x, y), a dictionary, a Sequence, or a generator.
    config (ModelTrainingConfig, required): Configuration object containing training parameters.
    validation_dataset (Any | None, default None): Optional validation data.

Returns:

    TrainingResult[Model]: An object containing the trained model, training history, and metrics.
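
A minimal usage sketch against the signature above. The ModelTrainingConfig field names (epochs, batch_size, optimizer, loss, metrics) are inferred from the implementation below; its exact constructor and import path are not shown on this page, so treat them as assumptions.

import keras
import numpy as np
from mlpotion.frameworks.keras import ModelTrainer

X_train = np.random.rand(100, 10)
y_train = np.random.randint(0, 2, 100)

model = keras.Sequential([keras.layers.Dense(1, activation="sigmoid")])

# Assumed construction of ModelTrainingConfig (import it from wherever your
# MLPotion installation exposes it).
config = ModelTrainingConfig(
    epochs=5,
    batch_size=32,
    optimizer="adam",
    loss="binary_crossentropy",
    metrics=["accuracy"],
)

trainer = ModelTrainer()
result = trainer.train(model=model, dataset=(X_train, y_train), config=config)

print(result.metrics)           # final metric values
print(result.history["loss"])   # full loss history
print(result.training_time)     # seconds spent in fit()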

Source code in mlpotion/frameworks/keras/training/trainers.py
@trycatch(
    error=ModelTrainerError,
    success_msg="✅ Successfully trained Keras model",
)
def train(
    self,
    model: Model,
    dataset: Any,
    config: ModelTrainingConfig,
    validation_dataset: Any | None = None,
) -> TrainingResult[Model]:
    """Train a Keras model using the provided dataset and configuration.

    Args:
        model: The Keras model to train.
        dataset: The training data. Can be a tuple `(x, y)`, a dictionary, a `Sequence`, or a generator.
        config: Configuration object containing training parameters.
        validation_dataset: Optional validation data.

    Returns:
        TrainingResult[Model]: An object containing the trained model, training history, and metrics.
    """
    self._validate_model(model)

    # Prepare compile parameters from config
    compile_params = {
        "optimizer": self._get_optimizer(config),
        "loss": config.loss,
        "metrics": config.metrics,
    }

    # Compile if needed or if forced by config (though we usually respect existing compilation)
    # Here we'll ensure it's compiled. If the user wants to use their own compilation,
    # they should probably compile it before passing it, but our config implies we control it.
    # However, to be safe and flexible:
    if not self._is_compiled(model):
        if not config.optimizer or not config.loss:
            raise RuntimeError(
                "Model is not compiled and config does not provide optimizer and loss. "
                "Either compile the model beforehand or provide optimizer and loss in config."
            )
        logger.info("Compiling model with config parameters.")
        model.compile(**compile_params)
    else:
        logger.info("Model already compiled. Using existing compilation settings.")

    # Prepare fit parameters
    fit_kwargs = {
        "epochs": config.epochs,
        "batch_size": config.batch_size,
        "verbose": config.verbose,
        "shuffle": config.shuffle,
        "validation_split": config.validation_split,
        "callbacks": self._prepare_callbacks(config),
    }

    if validation_dataset is not None:
        fit_kwargs["validation_data"] = validation_dataset

    # Add any framework-specific options
    fit_kwargs.update(config.framework_options)

    logger.info("Starting Keras model training...")
    logger.debug(f"Training data type: {type(dataset)!r}")
    logger.debug(f"Fit parameters: {fit_kwargs}")

    import time

    start_time = time.time()

    history_obj = self._call_fit(model=model, data=dataset, fit_kwargs=fit_kwargs)

    training_time = time.time() - start_time

    # Convert History object to dict[str, list[float]]
    history_dict = self._history_to_dict(history_obj)

    # Extract final metrics
    final_metrics = {}
    for k, v in history_dict.items():
        if v:
            final_metrics[k] = v[-1]

    logger.info("Training completed.")
    logger.debug(f"Training history: {history_dict}")

    return TrainingResult(
        model=model,
        history=history_dict,
        metrics=final_metrics,
        config=config,
        training_time=training_time,
        best_epoch=None,  # Keras history doesn't explicitly track "best" unless using callbacks
    )

Evaluation

mlpotion.frameworks.tensorflow.evaluation.evaluators

TensorFlow model evaluators.

This module re-exports the Keras ModelEvaluator implementation, as TensorFlow 2.x uses Keras as its high-level API.

Classes

ModelEvaluator dataclass

Bases: ModelEvaluatorProtocol[Model, Sequence]

Generic evaluator for Keras 3 models.

This class implements the ModelEvaluatorProtocol for Keras models. It wraps the model.evaluate() method to provide a consistent evaluation interface.

It ensures that the evaluation result is always returned as a dictionary of metric names to values, regardless of how the model was compiled or what arguments were passed.

Example
import keras
import numpy as np
from mlpotion.frameworks.keras import ModelEvaluator

# Prepare data
X_test = np.random.rand(20, 10)
y_test = np.random.randint(0, 2, 20)

# Define model
model = keras.Sequential([
    keras.layers.Dense(1, activation='sigmoid')
])

# Initialize evaluator
evaluator = ModelEvaluator()

# Evaluate
metrics = evaluator.evaluate(
    model=model,
    data=(X_test, y_test),
    compile_params={
        "optimizer": "adam",
        "loss": "binary_crossentropy",
        "metrics": ["accuracy"]
    },
    eval_params={"batch_size": 32}
)

print(metrics)  # {'loss': 0.693..., 'accuracy': 0.5...}
Functions
evaluate
evaluate(
    model: Model,
    dataset: Any,
    config: ModelEvaluationConfig,
) -> EvaluationResult

Evaluate a Keras model on the given data.

Parameters:

    model (Model, required): The Keras model to evaluate.
    dataset (Any, required): The evaluation data. Can be a tuple (x, y), a dictionary, or a Sequence.
    config (ModelEvaluationConfig, required): Configuration object containing evaluation parameters.

Returns:

    EvaluationResult: An object containing the evaluation metrics.
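
A minimal usage sketch against the signature above. The ModelEvaluationConfig fields (batch_size, verbose) are inferred from the implementation below; its exact constructor and import path are not shown here, so treat them as assumptions.

import keras
import numpy as np
from mlpotion.frameworks.keras import ModelEvaluator

X_test = np.random.rand(20, 10)
y_test = np.random.randint(0, 2, 20)

model = keras.Sequential([keras.layers.Dense(1, activation="sigmoid")])
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])

# Assumed construction of ModelEvaluationConfig (import it from wherever your
# MLPotion installation exposes it).
config = ModelEvaluationConfig(batch_size=32, verbose=0)

evaluator = ModelEvaluator()
result = evaluator.evaluate(model=model, dataset=(X_test, y_test), config=config)

print(result.metrics)  # e.g. {'loss': ..., 'accuracy': ...}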

Source code in mlpotion/frameworks/keras/evaluation/evaluators.py
@trycatch(
    error=ModelEvaluatorError,
    success_msg="✅ Successfully evaluated Keras model",
)
def evaluate(
    self,
    model: Model,
    dataset: Any,
    config: ModelEvaluationConfig,
) -> EvaluationResult:
    """Evaluate a Keras model on the given data.

    Args:
        model: The Keras model to evaluate.
        dataset: The evaluation data. Can be a tuple `(x, y)`, a dictionary, or a `Sequence`.
        config: Configuration object containing evaluation parameters.

    Returns:
        EvaluationResult: An object containing the evaluation metrics.
    """
    self._validate_model(model)

    # Prepare eval parameters
    eval_kwargs = {
        "batch_size": config.batch_size,
        "verbose": config.verbose,
        "return_dict": True,
    }

    # Add any framework-specific options
    eval_kwargs.update(config.framework_options)

    # We assume the model is already compiled. If not, Keras will raise an error
    # unless we provide compile params, but EvaluationConfig doesn't typically carry them.
    # The user should ensure the model is compiled (e.g. after loading or training).
    if not self._is_compiled(model):
        logger.warning(
            "Model is not compiled. Evaluation might fail if loss/metrics are not defined."
        )

    logger.info("Evaluating Keras model...")
    logger.debug(f"Evaluation data type: {type(dataset)!r}")
    logger.debug(f"Evaluation parameters: {eval_kwargs}")

    import time

    start_time = time.time()

    result = self._call_evaluate(model=model, data=dataset, eval_kwargs=eval_kwargs)

    evaluation_time = time.time() - start_time

    # At this point, result should be a dict[str, float]
    if not isinstance(result, dict):
        # Defensive fallback if user or Keras changed behavior
        logger.warning(
            f"`model.evaluate` did not return a dict (got {type(result)!r}). "
            "Wrapping into a dict under key 'metric_0'."
        )
        result = {"metric_0": float(result)}

    metrics = {str(k): float(v) for k, v in result.items()}
    logger.info(f"Evaluation result: {metrics}")

    return EvaluationResult(
        metrics=metrics,
        config=config,
        evaluation_time=evaluation_time,
    )

Persistence

mlpotion.frameworks.tensorflow.deployment.persistence

TensorFlow model persistence.

This module re-exports the Keras ModelPersistence implementation, as TensorFlow 2.x uses Keras as its high-level API.

Classes

ModelPersistence

ModelPersistence(
    path: str | Path, model: Model | None = None
) -> None

Bases: ModelPersistenceProtocol[Model]

Persistence helper for Keras models.

This class manages saving and loading of Keras models. It supports standard Keras formats (.keras, .h5) and SavedModel directories. It also integrates with ModelInspector to provide model metadata upon loading.

Attributes:

    path (Path): The file path for the model artifact.
    model (Model | None): The Keras model instance (optional).

Example
import keras
from mlpotion.frameworks.keras import ModelPersistence

# Define model
model = keras.Sequential([keras.layers.Dense(1)])

# Save
saver = ModelPersistence(path="models/my_model.keras", model=model)
saver.save()

# Load
loader = ModelPersistence(path="models/my_model.keras")
loaded_model, metadata = loader.load(inspect=True)
print(metadata['parameters'])
Source code in mlpotion/frameworks/keras/deployment/persistence.py
def __init__(self, path: str | Path, model: Model | None = None) -> None:
    self._path = Path(path)
    self._model = model
Attributes
model property writable
model: Model | None

Currently attached Keras model (may be None before loading).

path property writable
path: Path

Filesystem path where the model is saved/loaded.

Functions
load
load(
    *, inspect: bool = True, **kwargs: Any
) -> tuple[Model, dict[str, Any] | None]

Load a Keras model from disk.

Parameters:

    inspect (bool, default True): Whether to inspect the loaded model and return metadata.
    **kwargs (Any): Additional arguments passed to keras.models.load_model().

Returns:

    tuple[Model, dict[str, Any] | None]: A tuple containing the loaded model and optional inspection metadata.

Raises:

    ModelPersistenceError: If the model file cannot be found or loaded.

Source code in mlpotion/frameworks/keras/deployment/persistence.py
@trycatch(
    error=ModelPersistenceError,
    success_msg="✅ Successfully loaded Keras model",
)
def load(
    self,
    *,
    inspect: bool = True,
    **kwargs: Any,
) -> tuple[Model, dict[str, Any] | None]:
    """Load a Keras model from disk.

    Args:
        inspect: Whether to inspect the loaded model and return metadata.
        **kwargs: Additional arguments passed to `keras.models.load_model()`.

    Returns:
        tuple[Model, dict[str, Any] | None]: A tuple containing the loaded model and
        optional inspection metadata.

    Raises:
        ModelPersistenceError: If the model file cannot be found or loaded.
    """
    path = self._ensure_path_exists()

    logger.info(f"Loading Keras model from: {path!s}")
    model = keras.models.load_model(path.as_posix(), **kwargs)

    self._model = model  # keep instance in sync

    inspection_result: dict[str, Any] | None = None
    if inspect:
        logger.info("Inspecting loaded Keras model with ModelInspector.")
        inspector = ModelInspector()
        inspection_result = inspector.inspect(model)

    return model, inspection_result
save
save(overwrite: bool = True, **kwargs: Any) -> None

Save the attached model to disk.

Parameters:

    overwrite (bool, default True): Whether to overwrite the file if it already exists.
    **kwargs (Any): Additional arguments passed to model.save().

Raises:

    ModelPersistenceError: If no model is attached or if the file exists and overwrite is False.

Source code in mlpotion/frameworks/keras/deployment/persistence.py
@trycatch(
    error=ModelPersistenceError,
    success_msg="✅ Successfully saved Keras model",
)
def save(
    self,
    overwrite: bool = True,
    **kwargs: Any,
) -> None:
    """Save the attached model to disk.

    Args:
        overwrite: Whether to overwrite the file if it already exists.
        **kwargs: Additional arguments passed to `model.save()`.

    Raises:
        ModelPersistenceError: If no model is attached or if the file exists and `overwrite` is False.
    """
    model = self._ensure_model()
    target = self._path

    if target.exists() and not overwrite:
        raise ModelPersistenceError(
            f"Target path already exists and overwrite=False: {target!s}"
        )

    logger.info(f"Saving Keras model to: {target!s}")
    target.parent.mkdir(parents=True, exist_ok=True)

    # Keras 3 generally infers format from the path; `save_format` is
    # deprecated / discouraged in newer APIs, so we do NOT pass it.
    model.save(target.as_posix(), **kwargs)
    logger.info("Keras model saved successfully.")

Export

mlpotion.frameworks.tensorflow.deployment.exporters

TensorFlow model exporters.

This module re-exports the Keras ModelExporter implementation, as TensorFlow 2.x uses Keras as its high-level API.

Classes

ModelExporter

Bases: ModelExporterProtocol[Model]

Generic exporter for Keras 3 models.

This class implements ModelExporterProtocol and supports exporting Keras models to various formats, including native Keras formats (.keras, .h5) and inference formats like TensorFlow SavedModel or ONNX (via model.export).

It also supports creating export archives with custom endpoints using keras.export.ExportArchive.

Example
import keras
from mlpotion.frameworks.keras import ModelExporter

model = keras.Sequential([keras.layers.Dense(1)])
exporter = ModelExporter()

# Export as standard Keras file
exporter.export(model, "models/model.keras")

# Export for serving (TF SavedModel)
exporter.export(model, "models/serving", export_format="tf_saved_model")
Functions
export
export(model: Model, path: str, **kwargs: Any) -> None

Export a Keras model to disk.

Parameters:

    model (Model, required): The Keras model to export.
    path (str, required): The destination path or directory.
    **kwargs (Any): Additional export options:
        - export_format (str): "keras", "h5", "tf_saved_model", "onnx", etc.
        - dataset (Iterable): Optional data for model warmup.
        - endpoint_name (str): Name for custom endpoint (uses ExportArchive).
        - input_specs (list[InputSpec]): Input signatures for custom endpoint.
        - config (dict): Extra arguments for the underlying save/export method.

Raises:

    ModelExporterError: If export fails.
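
A sketch using two of the documented keyword options (export_format and dataset). Exporting to ONNX assumes the relevant ONNX tooling is installed; the input shape below is illustrative.

import keras
import numpy as np
from mlpotion.frameworks.keras import ModelExporter

model = keras.Sequential([keras.layers.Dense(1)])
model(np.random.rand(4, 10))  # build the model before exporting

exporter = ModelExporter()
exporter.export(
    model,
    "models/model.onnx",
    export_format="onnx",             # documented export option
    dataset=[np.random.rand(4, 10)],  # optional warmup data (documented option)
)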

Source code in mlpotion/frameworks/keras/deployment/exporters.py
@trycatch(
    error=ModelExporterError,
    success_msg="✅ Successfully Exported model",
)
def export(self, model: Model, path: str, **kwargs: Any) -> None:
    """Export a Keras model to disk.

    Args:
        model: The Keras model to export.
        path: The destination path or directory.
        **kwargs: Additional export options:
            - `export_format` (str): "keras", "h5", "tf_saved_model", "onnx", etc.
            - `dataset` (Iterable): Optional data for model warmup.
            - `endpoint_name` (str): Name for custom endpoint (uses ExportArchive).
            - `input_specs` (list[InputSpec]): Input signatures for custom endpoint.
            - `config` (dict): Extra arguments for the underlying save/export method.

    Raises:
        ModelExporterError: If export fails.
    """
    export_path = Path(path)

    export_format: str | None = kwargs.pop("export_format", None)
    dataset: Iterable[Any] | None = kwargs.pop("dataset", None)
    endpoint_name: str | None = kwargs.pop("endpoint_name", None)
    input_specs: Sequence[InputSpec] | None = kwargs.pop("input_specs", None)
    config: Mapping[str, Any] | None = kwargs.pop("config", None)

    if kwargs:
        logger.warning(
            "Unused export kwargs passed to ModelExporter: "
            f"{list(kwargs.keys())}"
        )

    self._validate_model(model)
    self._validate_config(config)

    # Determine mode if export_format isn't explicitly set
    if export_format is None:
        export_format = self._infer_export_format_from_path(export_path)

    logger.info(
        f"Exporting Keras model '{model.name}' to {export_path!s} "
        f"with format '{export_format}'"
    )

    # Optional warm-up pass
    self._warmup_if_needed(model=model, dataset=dataset)

    # Choose strategy
    try:
        if self._is_native_keras_format(export_format):
            self._save_native_keras(model=model, path=export_path, config=config)
        elif endpoint_name is not None or input_specs is not None:
            self._export_with_export_archive(
                model=model,
                path=export_path,
                endpoint_name=endpoint_name or self.default_endpoint_name,
                input_specs=input_specs,
                export_format=export_format,
            )
        else:
            self._export_with_model_export(
                model=model,
                path=export_path,
                export_format=export_format,
                config=config,
            )
    except ValueError as err:
        logger.warning(
            f"Export error: {err} "
            "(you may need to build the model by calling it on example data "
            "before exporting)"
        )

    logger.info(f"Model export completed: {export_path!s}")

Model Inspection

mlpotion.frameworks.tensorflow.models.inspection

TensorFlow model inspection.

This module re-exports the Keras ModelInspector implementation, as TensorFlow 2.x uses Keras as its high-level API.

Classes

ModelInspector dataclass

Bases: ModelInspectorProtocol[ModelLike]

Inspector for Keras models.

This class analyzes Keras models to extract metadata such as input/output shapes, parameter counts, layer details, and signatures. It is useful for validating models before training or deployment, and for generating model reports.

Attributes:

    include_layers (bool): Whether to include detailed information about each layer.
    include_signatures (bool): Whether to include model signatures (if available).

Example
import keras
from mlpotion.frameworks.keras import ModelInspector

model = keras.Sequential([keras.layers.Dense(1, input_shape=(10,))])
inspector = ModelInspector()

info = inspector.inspect(model)
print(f"Total params: {info['parameters']['total']}")
print(f"Inputs: {info['inputs']}")
Functions
inspect
inspect(model: ModelLike) -> dict[str, Any]

Inspect a Keras model and return structured metadata.

Parameters:

    model (ModelLike, required): The Keras model to inspect.

Returns:

    dict[str, Any]: A dictionary containing model metadata:
        - name: Model name.
        - backend: Keras backend used.
        - trainable: Whether the model is trainable.
        - inputs: List of input specifications.
        - outputs: List of output specifications.
        - parameters: Dictionary of parameter counts.
        - layers: List of layer details (if include_layers=True).
        - signatures: Model signatures (if include_signatures=True).

Source code in mlpotion/frameworks/keras/models/inspection.py
@trycatch(
    error=ModelInspectorError,
    success_msg="✅ Successfully inspected Keras model",
)
def inspect(self, model: ModelLike) -> dict[str, Any]:
    """Inspect a Keras model and return structured metadata.

    Args:
        model: The Keras model to inspect.

    Returns:
        dict[str, Any]: A dictionary containing model metadata:
            - `name`: Model name.
            - `backend`: Keras backend used.
            - `trainable`: Whether the model is trainable.
            - `inputs`: List of input specifications.
            - `outputs`: List of output specifications.
            - `parameters`: Dictionary of parameter counts.
            - `layers`: List of layer details (if `include_layers=True`).
            - `signatures`: Model signatures (if `include_signatures=True`).
    """
    if not isinstance(model, keras.Model):
        raise TypeError(
            f"ModelInspector expects a keras.Model, got {type(model)!r}"
        )

    logger.info("Inspecting Keras model...")

    backend_name = self._get_backend_name()

    info: dict[str, Any] = {
        "name": model.name,
        "backend": backend_name,
        "trainable": model.trainable,
    }

    info["inputs"] = self._get_inputs(model)
    info["input_names"] = [input["name"] for input in info["inputs"]]
    info["outputs"] = self._get_outputs(model)
    info["output_names"] = [output["name"] for output in info["outputs"]]
    info["parameters"] = self._get_param_counts(model)

    if self.include_signatures:
        info["signatures"] = self._get_signatures(model)

    if self.include_layers:
        info["layers"] = self._get_layers_summary(model)

    logger.debug(f"Keras model inspection result: {info}")
    return info

See the TensorFlow Guide for usage examples.