ZenML Integration API Reference 📖

Complete API reference for MLPotion's ZenML integration.

Auto-Generated Documentation

This page is automatically populated with API documentation from the source code.

TensorFlow Steps

mlpotion.integrations.zenml.tensorflow.steps

Functions

evaluate_model

evaluate_model(
    model: keras.Model,
    dataset: tf.data.Dataset,
    verbose: int = 1,
    metadata: dict[str, Any] | None = None,
) -> Annotated[dict[str, float], EvaluationMetrics]

Evaluate a TensorFlow/Keras model using ModelEvaluator.

This step computes metrics on a given dataset using the provided model.

Parameters:

- model (keras.Model): The Keras model to evaluate. Required.
- dataset (tf.data.Dataset): The evaluation tf.data.Dataset. Required.
- verbose (int): Verbosity mode (0 or 1). Default: 1.
- metadata (dict[str, Any] | None): Optional dictionary of metadata to log to ZenML. Default: None.

Returns:

- Annotated[dict[str, float], EvaluationMetrics]: A dictionary of computed metrics.

Source code in mlpotion/integrations/zenml/tensorflow/steps.py
@step
def evaluate_model(
    model: keras.Model,
    dataset: tf.data.Dataset,
    verbose: int = 1,
    metadata: dict[str, Any] | None = None,
) -> Annotated[dict[str, float], "EvaluationMetrics"]:
    """Evaluate a TensorFlow/Keras model using `ModelEvaluator`.

    This step computes metrics on a given dataset using the provided model.

    Args:
        model: The Keras model to evaluate.
        dataset: The evaluation `tf.data.Dataset`.
        verbose: Verbosity mode (0 or 1).
        metadata: Optional dictionary of metadata to log to ZenML.

    Returns:
        dict[str, float]: A dictionary of computed metrics.
    """
    logger.info("Evaluating model")

    evaluator = ModelEvaluator()

    config = ModelEvaluationConfig(
        verbose=verbose,
    )

    result = evaluator.evaluate(
        model=model,
        dataset=dataset,
        config=config,
    )

    metrics = result.metrics

    if metadata:
        log_step_metadata(metadata={**metadata, "metrics": metrics})

    return metrics
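
Example: a minimal sketch of running this step inside a ZenML pipeline together with the load_data and load_model steps documented below. The pipeline name and file paths are illustrative placeholders, and a recent ZenML release (where steps are composed inside an @pipeline function) is assumed.

from zenml import pipeline

from mlpotion.integrations.zenml.tensorflow.steps import (
    evaluate_model,
    load_data,
    load_model,
)


@pipeline
def tf_evaluation_pipeline() -> None:
    # Illustrative paths; both upstream steps are documented on this page.
    dataset = load_data(file_path="data/test/*.csv", batch_size=64)
    model = load_model(model_path="artifacts/model.keras", inspect=False)
    evaluate_model(model=model, dataset=dataset, verbose=0)


tf_evaluation_pipeline()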

export_model

export_model(
    model: keras.Model,
    export_path: str,
    export_format: str = "keras",
    metadata: dict[str, Any] | None = None,
) -> Annotated[str, ExportPath]

Export a TensorFlow/Keras model to disk using ModelExporter.

This step exports the model to a specified format (e.g., Keras format, SavedModel).

Parameters:

- model (keras.Model): The Keras model to export. Required.
- export_path (str): The destination path for the exported model. Required.
- export_format (str): The format to export to. Default: "keras".
- metadata (dict[str, Any] | None): Optional dictionary of metadata to log to ZenML. Default: None.

Returns:

- Annotated[str, ExportPath]: The path to the exported model artifact.

Source code in mlpotion/integrations/zenml/tensorflow/steps.py
@step
def export_model(
    model: keras.Model,
    export_path: str,
    export_format: str = "keras",
    metadata: dict[str, Any] | None = None,
) -> Annotated[str, "ExportPath"]:
    """Export a TensorFlow/Keras model to disk using `ModelExporter`.

    This step exports the model to a specified format (e.g., Keras format, SavedModel).

    Args:
        model: The Keras model to export.
        export_path: The destination path for the exported model.
        export_format: The format to export to (default: "keras").
        metadata: Optional dictionary of metadata to log to ZenML.

    Returns:
        str: The path to the exported model artifact.
    """
    logger.info(f"Exporting model to: {export_path}")

    exporter = ModelExporter()

    exporter.export(
        model=model,
        path=export_path,
        export_format=export_format,
    )

    if metadata:
        log_step_metadata(metadata={**metadata, "export_path": export_path})

    return export_path
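
Example: a minimal sketch of exporting a previously saved model in the Keras format. The paths and pipeline name are illustrative placeholders.

from zenml import pipeline

from mlpotion.integrations.zenml.tensorflow.steps import export_model, load_model


@pipeline
def tf_export_pipeline() -> None:
    model = load_model(model_path="artifacts/model.keras", inspect=False)
    export_model(
        model=model,
        export_path="artifacts/exported/model.keras",  # illustrative destination
        export_format="keras",
    )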

inspect_model

inspect_model(
    model: keras.Model,
    include_layers: bool = True,
    include_signatures: bool = True,
    metadata: dict[str, Any] | None = None,
) -> Annotated[dict[str, Any], ModelInspection]

Inspect a TensorFlow/Keras model using ModelInspector.

This step extracts metadata about the model, such as layer configuration, input/output shapes, and parameter counts.

Parameters:

- model (keras.Model): The Keras model to inspect. Required.
- include_layers (bool): Whether to include detailed layer information. Default: True.
- include_signatures (bool): Whether to include signature information. Default: True.
- metadata (dict[str, Any] | None): Optional dictionary of metadata to log to ZenML. Default: None.

Returns:

- Annotated[dict[str, Any], ModelInspection]: A dictionary containing the inspection results.

Source code in mlpotion/integrations/zenml/tensorflow/steps.py
@step
def inspect_model(
    model: keras.Model,
    include_layers: bool = True,
    include_signatures: bool = True,
    metadata: dict[str, Any] | None = None,
) -> Annotated[dict[str, Any], "ModelInspection"]:
    """Inspect a TensorFlow/Keras model using `ModelInspector`.

    This step extracts metadata about the model, such as layer configuration,
    input/output shapes, and parameter counts.

    Args:
        model: The Keras model to inspect.
        include_layers: Whether to include detailed layer information.
        include_signatures: Whether to include signature information.
        metadata: Optional dictionary of metadata to log to ZenML.

    Returns:
        dict[str, Any]: A dictionary containing the inspection results.
    """
    logger.info("Inspecting model")

    inspector = ModelInspector(
        include_layers=include_layers,
        include_signatures=include_signatures,
    )
    inspection = inspector.inspect(model)

    if metadata:
        log_step_metadata(metadata={**metadata, "inspection": inspection})

    return inspection

load_data

load_data(
    file_path: str,
    batch_size: int = 32,
    label_name: str = "target",
    column_names: list[str] | None = None,
    metadata: dict[str, Any] | None = None,
) -> Annotated[tf.data.Dataset, TFDataset]

Load data from local CSV files using TensorFlow's efficient loading.

This step uses CSVDataLoader to create a tf.data.Dataset from CSV files matching the specified pattern.

Parameters:

- file_path (str): Glob pattern for CSV files (e.g., "data/*.csv"). Required.
- batch_size (int): Number of samples per batch. Default: 32.
- label_name (str): Name of the column to use as the label. Default: "target".
- column_names (list[str] | None): List of specific columns to load. Default: None.
- metadata (dict[str, Any] | None): Optional dictionary of metadata to log to ZenML. Default: None.

Returns:

- Annotated[tf.data.Dataset, TFDataset]: The loaded TensorFlow dataset.

Source code in mlpotion/integrations/zenml/tensorflow/steps.py
@step
def load_data(
    file_path: str,
    batch_size: int = 32,
    label_name: str = "target",
    column_names: list[str] | None = None,
    metadata: dict[str, Any] | None = None,
) -> Annotated[tf.data.Dataset, "TFDataset"]:
    """Load data from local CSV files using TensorFlow's efficient loading.

    This step uses `CSVDataLoader` to create a `tf.data.Dataset` from CSV files matching
    the specified pattern.

    Args:
        file_path: Glob pattern for CSV files (e.g., "data/*.csv").
        batch_size: Number of samples per batch.
        label_name: Name of the column to use as the label.
        column_names: List of specific columns to load.
        metadata: Optional dictionary of metadata to log to ZenML.

    Returns:
        tf.data.Dataset: The loaded TensorFlow dataset.
    """
    logger.info(f"Loading data from: {file_path}")

    # defining configuration
    config = DataLoadingConfig(
        file_pattern=file_path,
        batch_size=batch_size,
        label_name=label_name,
        column_names=column_names,
    )

    # initializing data loader
    loader = CSVDataLoader(**config.dict())
    # loading data
    dataset = loader.load()

    # adding metadata
    if metadata:
        log_step_metadata(metadata=metadata)

    return dataset
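
Example: a minimal sketch of calling this step on its own. The glob pattern, label column, and column names are illustrative placeholders; the returned tf.data.Dataset artifact can be passed to the other steps on this page.

from zenml import pipeline

from mlpotion.integrations.zenml.tensorflow.steps import load_data


@pipeline
def tf_ingestion_pipeline() -> None:
    # Illustrative glob pattern and column selection.
    load_data(
        file_path="data/train/*.csv",
        batch_size=64,
        label_name="target",
        column_names=["feature_a", "feature_b", "target"],
    )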

load_model

load_model(
    model_path: str,
    inspect: bool = True,
    metadata: dict[str, Any] | None = None,
) -> Annotated[keras.Model, LoadedModel]

Load a TensorFlow/Keras model from disk using ModelPersistence.

This step loads a previously saved model. It can optionally inspect the loaded model to log metadata about its structure.

Parameters:

- model_path (str): The path to the saved model. Required.
- inspect (bool): Whether to inspect the model after loading. Default: True.
- metadata (dict[str, Any] | None): Optional dictionary of metadata to log to ZenML. Default: None.

Returns:

- Annotated[keras.Model, LoadedModel]: The loaded Keras model.

Source code in mlpotion/integrations/zenml/tensorflow/steps.py
@step
def load_model(
    model_path: str,
    inspect: bool = True,
    metadata: dict[str, Any] | None = None,
) -> Annotated[keras.Model, "LoadedModel"]:
    """Load a TensorFlow/Keras model from disk using `ModelPersistence`.

    This step loads a previously saved model. It can optionally inspect the loaded model
    to log metadata about its structure.

    Args:
        model_path: The path to the saved model.
        inspect: Whether to inspect the model after loading.
        metadata: Optional dictionary of metadata to log to ZenML.

    Returns:
        keras.Model: The loaded Keras model.
    """
    logger.info(f"Loading model from: {model_path}")

    persistence = ModelPersistence(path=model_path)
    model, inspection = persistence.load(inspect=inspect)

    if metadata:
        meta = {**metadata}
        if inspection:
            meta["inspection"] = inspection
        log_step_metadata(metadata=meta)

    return model
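
Example: a minimal sketch of loading a saved model and attaching run metadata. The path and metadata values are illustrative placeholders.

from zenml import pipeline

from mlpotion.integrations.zenml.tensorflow.steps import load_model


@pipeline
def tf_model_loading_pipeline() -> None:
    load_model(
        model_path="artifacts/model.keras",  # illustrative path
        inspect=True,
        metadata={"source": "nightly-training"},  # illustrative metadata
    )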

optimize_data

optimize_data(
    dataset: tf.data.Dataset,
    batch_size: int = 32,
    shuffle_buffer_size: int | None = None,
    prefetch: bool = True,
    cache: bool = False,
    metadata: dict[str, Any] | None = None,
) -> Annotated[tf.data.Dataset, TFDataset]

Optimize a TensorFlow dataset for training performance.

This step applies optimizations like caching, shuffling, and prefetching to the dataset using DatasetOptimizer.

Parameters:

- dataset (tf.data.Dataset): The input tf.data.Dataset. Required.
- batch_size (int): Batch size (if re-batching is needed). Default: 32.
- shuffle_buffer_size (int | None): Size of the shuffle buffer. Default: None.
- prefetch (bool): Whether to prefetch data. Default: True.
- cache (bool): Whether to cache data in memory. Default: False.
- metadata (dict[str, Any] | None): Optional dictionary of metadata to log to ZenML. Default: None.

Returns:

- Annotated[tf.data.Dataset, TFDataset]: The optimized TensorFlow dataset.

Source code in mlpotion/integrations/zenml/tensorflow/steps.py
@step
def optimize_data(
    dataset: tf.data.Dataset,
    batch_size: int = 32,
    shuffle_buffer_size: int | None = None,
    prefetch: bool = True,
    cache: bool = False,
    metadata: dict[str, Any] | None = None,
) -> Annotated[tf.data.Dataset, "TFDataset"]:
    """Optimize a TensorFlow dataset for training performance.

    This step applies optimizations like caching, shuffling, and prefetching to the dataset
    using `DatasetOptimizer`.

    Args:
        dataset: The input `tf.data.Dataset`.
        batch_size: Batch size (if re-batching is needed).
        shuffle_buffer_size: Size of the shuffle buffer.
        prefetch: Whether to prefetch data.
        cache: Whether to cache data in memory.
        metadata: Optional dictionary of metadata to log to ZenML.

    Returns:
        tf.data.Dataset: The optimized TensorFlow dataset.
    """
    logger.info("Optimizing dataset for training performance")

    config = DataOptimizationConfig(
        batch_size=batch_size,
        shuffle_buffer_size=shuffle_buffer_size,
        prefetch=prefetch,
        cache=cache,
    )

    optimizer = DatasetOptimizer(**config.dict())
    dataset = optimizer.optimize(dataset)

    # adding metadata
    if metadata:
        log_step_metadata(metadata=metadata)

    return dataset
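
Example: a minimal sketch chaining load_data into optimize_data so the dataset is cached, shuffled, and prefetched before training. The buffer size and path are illustrative placeholders.

from zenml import pipeline

from mlpotion.integrations.zenml.tensorflow.steps import load_data, optimize_data


@pipeline
def tf_data_pipeline() -> None:
    dataset = load_data(file_path="data/train/*.csv", batch_size=64)
    optimize_data(
        dataset=dataset,
        shuffle_buffer_size=10_000,  # illustrative buffer size
        prefetch=True,
        cache=True,
    )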

save_model

save_model(
    model: keras.Model,
    save_path: str,
    metadata: dict[str, Any] | None = None,
) -> Annotated[str, SavePath]

Save a TensorFlow/Keras model to disk using ModelPersistence.

This step saves the model for later reloading.

Parameters:

- model (keras.Model): The Keras model to save. Required.
- save_path (str): The destination path. Required.
- metadata (dict[str, Any] | None): Optional dictionary of metadata to log to ZenML. Default: None.

Returns:

- Annotated[str, SavePath]: The path to the saved model.

Source code in mlpotion/integrations/zenml/tensorflow/steps.py
@step
def save_model(
    model: keras.Model,
    save_path: str,
    metadata: dict[str, Any] | None = None,
) -> Annotated[str, "SavePath"]:
    """Save a TensorFlow/Keras model to disk using `ModelPersistence`.

    This step saves the model for later reloading.

    Args:
        model: The Keras model to save.
        save_path: The destination path.
        metadata: Optional dictionary of metadata to log to ZenML.

    Returns:
        str: The path to the saved model.
    """
    logger.info(f"Saving model to: {save_path}")

    persistence = ModelPersistence(path=save_path, model=model)
    persistence.save()

    if metadata:
        log_step_metadata(metadata={**metadata, "save_path": save_path})

    return save_path

train_model

train_model(
    model: keras.Model,
    dataset: tf.data.Dataset,
    epochs: int = 10,
    validation_dataset: tf.data.Dataset | None = None,
    learning_rate: float = 0.001,
    verbose: int = 1,
    metadata: dict[str, Any] | None = None,
) -> Tuple[
    Annotated[keras.Model, TrainedModel],
    Annotated[dict[str, list[float]], TrainingHistory],
]

Train a TensorFlow/Keras model using ModelTrainer.

This step configures and runs a training session. It supports validation data and logging of training metrics.

Parameters:

- model (keras.Model): The Keras model to train. Required.
- dataset (tf.data.Dataset): The training tf.data.Dataset. Required.
- epochs (int): Number of epochs to train. Default: 10.
- validation_dataset (tf.data.Dataset | None): Optional validation tf.data.Dataset. Default: None.
- learning_rate (float): Learning rate for the Adam optimizer. Default: 0.001.
- verbose (int): Verbosity mode (0, 1, or 2). Default: 1.
- metadata (dict[str, Any] | None): Optional dictionary of metadata to log to ZenML. Default: None.

Returns:

- Tuple[Annotated[keras.Model, TrainedModel], Annotated[dict[str, list[float]], TrainingHistory]]: The trained model and a dictionary of history metrics.

Source code in mlpotion/integrations/zenml/tensorflow/steps.py
@step
def train_model(
    model: keras.Model,
    dataset: tf.data.Dataset,
    epochs: int = 10,
    validation_dataset: tf.data.Dataset | None = None,
    learning_rate: float = 0.001,
    verbose: int = 1,
    metadata: dict[str, Any] | None = None,
) -> Tuple[
    Annotated[keras.Model, "TrainedModel"],
    Annotated[dict[str, list[float]], "TrainingHistory"],
]:
    """Train a TensorFlow/Keras model using `ModelTrainer`.

    This step configures and runs a training session. It supports validation data
    and logging of training metrics.

    Args:
        model: The Keras model to train.
        dataset: The training `tf.data.Dataset`.
        epochs: Number of epochs to train.
        validation_dataset: Optional validation `tf.data.Dataset`.
        learning_rate: Learning rate for the Adam optimizer.
        verbose: Verbosity mode (0, 1, or 2).
        metadata: Optional dictionary of metadata to log to ZenML.

    Returns:
        Tuple[keras.Model, dict[str, list[float]]]: The trained model and a dictionary of history metrics.
    """
    logger.info(f"Training model for {epochs} epochs")

    trainer = ModelTrainer()

    config = ModelTrainingConfig(
        epochs=epochs,
        learning_rate=learning_rate,
        verbose=verbose,
        optimizer="adam",
        loss="mse",
        metrics=["mae"],
    )

    result = trainer.train(
        model=model,
        dataset=dataset,
        config=config,
        validation_dataset=validation_dataset,
    )

    # Result is TrainingResult object
    training_metrics = result.metrics

    if metadata:
        log_step_metadata(metadata={**metadata, "history": result.history})

    return model, training_metrics
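
Example: a minimal end-to-end training sketch. The build_model step is a user-defined placeholder (any compiled keras.Model works), the data path, input shape, and epoch count are illustrative, and a materializer for keras.Model is assumed to be available (e.g., via ZenML's TensorFlow integration).

import keras
from zenml import pipeline, step

from mlpotion.integrations.zenml.tensorflow.steps import (
    evaluate_model,
    load_data,
    train_model,
)


@step
def build_model() -> keras.Model:
    # Illustrative toy regressor; replace with your own architecture.
    model = keras.Sequential([keras.Input(shape=(3,)), keras.layers.Dense(1)])
    model.compile(optimizer="adam", loss="mse", metrics=["mae"])
    return model


@pipeline
def tf_training_pipeline() -> None:
    dataset = load_data(file_path="data/train/*.csv", label_name="target")
    model = build_model()
    trained_model, history = train_model(model=model, dataset=dataset, epochs=5)
    evaluate_model(model=trained_model, dataset=dataset)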

transform_data

transform_data(
    dataset: tf.data.Dataset,
    model: keras.Model,
    data_output_path: str,
    data_output_per_batch: bool = False,
    metadata: dict[str, Any] | None = None,
) -> Annotated[str, OutputPath]

Transform data using a TensorFlow model and save predictions to CSV.

This step uses DataToCSVTransformer to run inference on a dataset using a provided model and saves the results to the specified output path.

Parameters:

- dataset (tf.data.Dataset): The input tf.data.Dataset. Required.
- model (keras.Model): The Keras model to use for transformation. Required.
- data_output_path (str): Path to save the transformed data (CSV). Required.
- data_output_per_batch (bool): Whether to save a separate file per batch. Default: False.
- metadata (dict[str, Any] | None): Optional dictionary of metadata to log to ZenML. Default: None.

Returns:

- Annotated[str, OutputPath]: The path to the saved output file(s).

Source code in mlpotion/integrations/zenml/tensorflow/steps.py
@step
def transform_data(
    dataset: tf.data.Dataset,
    model: keras.Model,
    data_output_path: str,
    data_output_per_batch: bool = False,
    metadata: dict[str, Any] | None = None,
) -> Annotated[str, "OutputPath"]:
    """Transform data using a TensorFlow model and save predictions to CSV.

    This step uses `DataToCSVTransformer` to run inference on a dataset using a provided model
    and saves the results to the specified output path.

    Args:
        dataset: The input `tf.data.Dataset`.
        model: The Keras model to use for transformation.
        data_output_path: Path to save the transformed data (CSV).
        data_output_per_batch: Whether to save a separate file per batch.
        metadata: Optional dictionary of metadata to log to ZenML.

    Returns:
        str: The path to the saved output file(s).
    """
    logger.info(f"Transforming data and saving to: {data_output_path}")

    transformer = DataToCSVTransformer(
        dataset=dataset,
        model=model,
        data_output_path=data_output_path,
        data_output_per_batch=data_output_per_batch,
    )

    # Create minimal config for transform method
    config = DataTransformationConfig(
        file_pattern="",  # Not used since dataset is provided
        model_path="",  # Not used since model is provided
        model_input_signature={},  # Empty dict as model is provided directly
        data_output_path=data_output_path,
        data_output_per_batch=data_output_per_batch,
    )

    transformer.transform(dataset=None, model=None, config=config)

    if metadata:
        log_step_metadata(metadata=metadata)

    return data_output_path
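
Example: a minimal batch-inference sketch that loads raw CSV data, loads a saved model, and writes predictions to CSV. All paths are illustrative placeholders.

from zenml import pipeline

from mlpotion.integrations.zenml.tensorflow.steps import (
    load_data,
    load_model,
    transform_data,
)


@pipeline
def tf_batch_inference_pipeline() -> None:
    dataset = load_data(file_path="data/incoming/*.csv", batch_size=128)
    model = load_model(model_path="artifacts/model.keras", inspect=False)
    transform_data(
        dataset=dataset,
        model=model,
        data_output_path="predictions/output.csv",  # illustrative output path
        data_output_per_batch=False,
    )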

PyTorch Steps

mlpotion.integrations.zenml.pytorch.steps

ZenML steps for PyTorch framework.

Functions

evaluate_model

evaluate_model(
    model: nn.Module,
    dataloader: DataLoader,
    loss_fn: str = "mse",
    device: str = "cpu",
    verbose: int = 1,
    max_batches: int | None = None,
    metadata: dict[str, Any] | None = None,
) -> Annotated[dict[str, float], EvaluationMetrics]

Evaluate a PyTorch model using ModelEvaluator.

This step computes metrics on a given dataset using the provided model.

Parameters:

- model (nn.Module): The PyTorch model to evaluate. Required.
- dataloader (DataLoader): The evaluation DataLoader. Required.
- loss_fn (str): Name of the loss function (e.g., "mse", "cross_entropy"). Default: "mse".
- device (str): Device to evaluate on ("cpu" or "cuda"). Default: "cpu".
- verbose (int): Verbosity mode (0 or 1). Default: 1.
- max_batches (int | None): Limit number of batches to evaluate (useful for debugging). Default: None.
- metadata (dict[str, Any] | None): Optional dictionary of metadata to log to ZenML. Default: None.

Returns:

- Annotated[dict[str, float], EvaluationMetrics]: A dictionary of computed metrics.

Source code in mlpotion/integrations/zenml/pytorch/steps.py
@step
def evaluate_model(
    model: nn.Module,
    dataloader: DataLoader,
    loss_fn: str = "mse",
    device: str = "cpu",
    verbose: int = 1,
    max_batches: int | None = None,
    metadata: dict[str, Any] | None = None,
) -> Annotated[dict[str, float], "EvaluationMetrics"]:
    """Evaluate a PyTorch model using `ModelEvaluator`.

    This step computes metrics on a given dataset using the provided model.

    Args:
        model: The PyTorch model to evaluate.
        dataloader: The evaluation `DataLoader`.
        loss_fn: Name of the loss function (e.g., "mse", "cross_entropy").
        device: Device to evaluate on ("cpu" or "cuda").
        verbose: Verbosity mode (0 or 1).
        max_batches: Limit number of batches to evaluate (useful for debugging).
        metadata: Optional dictionary of metadata to log to ZenML.

    Returns:
        dict[str, float]: A dictionary of computed metrics.
    """
    logger.info(f"Evaluating model on {device}")

    config = ModelEvaluationConfig(
        batch_size=dataloader.batch_size or 32,
        verbose=verbose,
        device=device,
        framework_options={"loss_fn": loss_fn, "max_batches": max_batches},
    )

    evaluator = ModelEvaluator()
    result = evaluator.evaluate(
        model=model,
        dataloader=dataloader,
        config=config,
    )

    # Extract metrics and evaluation time from result
    metrics = {**result.metrics, "evaluation_time": result.evaluation_time}

    if metadata:
        log_step_metadata(metadata={**metadata, "metrics": metrics})

    return metrics
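
Example: a minimal sketch evaluating a model on CSV data, capped at a few batches for a quick smoke test. The build_model step, paths, and batch limit are illustrative placeholders; a materializer for nn.Module is assumed to be available (e.g., via ZenML's PyTorch integration).

import torch.nn as nn
from zenml import pipeline, step

from mlpotion.integrations.zenml.pytorch.steps import evaluate_model, load_csv_data


@step
def build_model() -> nn.Module:
    # Illustrative model; replace with your own architecture.
    return nn.Sequential(nn.Linear(3, 1))


@pipeline
def torch_evaluation_pipeline() -> None:
    dataloader = load_csv_data(
        file_path="data/test/*.csv", label_name="target", shuffle=False
    )
    model = build_model()
    evaluate_model(model=model, dataloader=dataloader, loss_fn="mse", max_batches=10)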

export_model

export_model(
    model: nn.Module,
    export_path: str,
    export_format: str = "state_dict",
    device: str = "cpu",
    example_input: torch.Tensor | None = None,
    jit_mode: str = "script",
    input_names: list[str] | None = None,
    output_names: list[str] | None = None,
    dynamic_axes: dict[str, dict[int, str]] | None = None,
    opset_version: int = 14,
    metadata: dict[str, Any] | None = None,
) -> Annotated[str, ExportPath]

Export a PyTorch model to disk using ModelExporter.

This step exports the model to a specified format (TorchScript, ONNX, or state_dict).

Parameters:

- model (nn.Module): The PyTorch model to export. Required.
- export_path (str): The destination path for the exported model. Required.
- export_format (str): The format to export to ("torchscript", "onnx", "state_dict"). Default: "state_dict".
- device (str): Device to use for export (important for tracing). Default: "cpu".
- example_input (torch.Tensor | None): Example input tensor (required for ONNX and TorchScript trace). Default: None.
- jit_mode (str): TorchScript mode ("script" or "trace"). Default: "script".
- input_names (list[str] | None): List of input names for ONNX export. Default: None.
- output_names (list[str] | None): List of output names for ONNX export. Default: None.
- dynamic_axes (dict[str, dict[int, str]] | None): Dictionary of dynamic axes for ONNX export. Default: None.
- opset_version (int): ONNX opset version. Default: 14.
- metadata (dict[str, Any] | None): Optional dictionary of metadata to log to ZenML. Default: None.

Returns:

- Annotated[str, ExportPath]: The path to the exported model artifact.

Source code in mlpotion/integrations/zenml/pytorch/steps.py
@step
def export_model(
    model: nn.Module,
    export_path: str,
    export_format: str = "state_dict",
    device: str = "cpu",
    example_input: torch.Tensor | None = None,
    jit_mode: str = "script",
    input_names: list[str] | None = None,
    output_names: list[str] | None = None,
    dynamic_axes: dict[str, dict[int, str]] | None = None,
    opset_version: int = 14,
    metadata: dict[str, Any] | None = None,
) -> Annotated[str, "ExportPath"]:
    """Export a PyTorch model to disk using `ModelExporter`.

    This step exports the model to a specified format (TorchScript, ONNX, or state_dict).

    Args:
        model: The PyTorch model to export.
        export_path: The destination path for the exported model.
        export_format: The format to export to ("torchscript", "onnx", "state_dict").
        device: Device to use for export (important for tracing).
        example_input: Example input tensor (required for ONNX and TorchScript trace).
        jit_mode: TorchScript mode ("script" or "trace").
        input_names: List of input names for ONNX export.
        output_names: List of output names for ONNX export.
        dynamic_axes: Dictionary of dynamic axes for ONNX export.
        opset_version: ONNX opset version.
        metadata: Optional dictionary of metadata to log to ZenML.

    Returns:
        str: The path to the exported model artifact.
    """
    logger.info(f"Exporting model to: {export_path} (format: {export_format})")

    config = ModelExportConfig(
        export_path=export_path,
        format=export_format,
        device=device,
        jit_mode=jit_mode,
        example_input=example_input,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes,
        opset_version=opset_version,
    )

    exporter = ModelExporter()
    result = exporter.export(model=model, config=config)

    if metadata:
        log_step_metadata(
            metadata={
                **metadata,
                "export_path": result.export_path,
                "format": result.format,
                "metadata": result.metadata,
            }
        )

    return str(result.export_path)
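
Example: a minimal sketch of an ONNX export. The model, example input shape, tensor names, and paths are illustrative placeholders; the example input is produced by an upstream step so it can be wired through the pipeline.

import torch
import torch.nn as nn
from zenml import pipeline, step

from mlpotion.integrations.zenml.pytorch.steps import export_model


@step
def build_model() -> nn.Module:
    # Illustrative model; replace with your own architecture.
    return nn.Sequential(nn.Linear(3, 1))


@step
def make_example_input() -> torch.Tensor:
    # Dummy batch matching the model's expected input shape.
    return torch.randn(1, 3)


@pipeline
def torch_export_pipeline() -> None:
    model = build_model()
    export_model(
        model=model,
        export_path="artifacts/model.onnx",  # illustrative destination
        export_format="onnx",
        example_input=make_example_input(),
        input_names=["features"],
        output_names=["prediction"],
        dynamic_axes={"features": {0: "batch"}},
        opset_version=14,
    )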

load_csv_data

load_csv_data(
    file_path: str,
    batch_size: int = 32,
    label_name: str | None = None,
    column_names: list[str] | None = None,
    shuffle: bool = True,
    num_workers: int = 0,
    pin_memory: bool = False,
    drop_last: bool = False,
    dtype: str = "float32",
    metadata: dict[str, Any] | None = None,
) -> Annotated[DataLoader, PyTorchDataLoader]

Load data from CSV files into a PyTorch DataLoader.

This step uses CSVDataset and CSVDataLoader to load data matching the specified file pattern. It returns a configured DataLoader ready for training or evaluation.

Parameters:

- file_path (str): Glob pattern for CSV files (e.g., "data/*.csv"). Required.
- batch_size (int): Number of samples per batch. Default: 32.
- label_name (str | None): Name of the column to use as the label. Default: None.
- column_names (list[str] | None): List of specific columns to load. Default: None.
- shuffle (bool): Whether to shuffle the data. Default: True.
- num_workers (int): Number of subprocesses to use for data loading. Default: 0.
- pin_memory (bool): Whether to copy tensors into CUDA pinned memory. Default: False.
- drop_last (bool): Whether to drop the last incomplete batch. Default: False.
- dtype (str): Data type for the features (e.g., "float32"). Default: "float32".
- metadata (dict[str, Any] | None): Optional dictionary of metadata to log to ZenML. Default: None.

Returns:

- Annotated[DataLoader, PyTorchDataLoader]: The configured PyTorch DataLoader.

Source code in mlpotion/integrations/zenml/pytorch/steps.py
@step
def load_csv_data(
    file_path: str,
    batch_size: int = 32,
    label_name: str | None = None,
    column_names: list[str] | None = None,
    shuffle: bool = True,
    num_workers: int = 0,
    pin_memory: bool = False,
    drop_last: bool = False,
    dtype: str = "float32",
    metadata: dict[str, Any] | None = None,
) -> Annotated[DataLoader, "PyTorchDataLoader"]:
    """Load data from CSV files into a PyTorch DataLoader.

    This step uses `CSVDataset` and `CSVDataLoader` to load data matching the specified file pattern.
    It returns a configured `DataLoader` ready for training or evaluation.

    Args:
        file_path: Glob pattern for CSV files (e.g., "data/*.csv").
        batch_size: Number of samples per batch.
        label_name: Name of the column to use as the label.
        column_names: List of specific columns to load.
        shuffle: Whether to shuffle the data.
        num_workers: Number of subprocesses to use for data loading.
        pin_memory: Whether to copy tensors into CUDA pinned memory.
        drop_last: Whether to drop the last incomplete batch.
        dtype: Data type for the features (e.g., "float32").
        metadata: Optional dictionary of metadata to log to ZenML.

    Returns:
        DataLoader: The configured PyTorch DataLoader.
    """
    logger.info(f"Loading data from: {file_path}")

    # Convert dtype string to torch.dtype
    torch_dtype = getattr(torch, dtype)

    # Create dataset
    dataset = CSVDataset(
        file_pattern=file_path,
        column_names=column_names,
        label_name=label_name,
        dtype=torch_dtype,
    )

    # Create DataLoader config
    config = DataLoadingConfig(
        file_pattern=file_path,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=drop_last,
    )

    # Create DataLoader using factory (exclude fields not accepted by CSVDataLoader)
    loader_factory = CSVDataLoader(**config.dict(exclude={"file_pattern", "config"}))
    dataloader = loader_factory.load(dataset)

    if metadata:
        log_step_metadata(metadata=metadata)

    return dataloader
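
Example: a minimal sketch of calling this step on its own. The glob pattern and loader settings are illustrative placeholders; the returned DataLoader artifact can be passed to the training and evaluation steps in this module.

from zenml import pipeline

from mlpotion.integrations.zenml.pytorch.steps import load_csv_data


@pipeline
def torch_ingestion_pipeline() -> None:
    # Illustrative pattern and loader settings.
    load_csv_data(
        file_path="data/train/*.csv",
        batch_size=64,
        label_name="target",
        shuffle=True,
        num_workers=2,
        pin_memory=True,
    )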

load_model

load_model(
    model_path: str,
    model_class: type[nn.Module] | None = None,
    map_location: str = "cpu",
    strict: bool = True,
    metadata: dict[str, Any] | None = None,
) -> Annotated[nn.Module, LoadedModel]

Load a PyTorch model from disk using ModelPersistence.

This step loads a previously saved model. If loading a state dict, model_class must be provided.

Parameters:

- model_path (str): The path to the saved model. Required.
- model_class (type[nn.Module] | None): The class of the model (required for state dict loading). Default: None.
- map_location (str): Device to load the model onto. Default: "cpu".
- strict (bool): Whether to strictly enforce state dict keys match. Default: True.
- metadata (dict[str, Any] | None): Optional dictionary of metadata to log to ZenML. Default: None.

Returns:

- Annotated[nn.Module, LoadedModel]: The loaded PyTorch model.

Source code in mlpotion/integrations/zenml/pytorch/steps.py
@step
def load_model(
    model_path: str,
    model_class: type[nn.Module] | None = None,
    map_location: str = "cpu",
    strict: bool = True,
    metadata: dict[str, Any] | None = None,
) -> Annotated[nn.Module, "LoadedModel"]:
    """Load a PyTorch model from disk using `ModelPersistence`.

    This step loads a previously saved model. If loading a state dict, `model_class`
    must be provided.

    Args:
        model_path: The path to the saved model.
        model_class: The class of the model (required for state dict loading).
        map_location: Device to load the model onto.
        strict: Whether to strictly enforce state dict keys match.
        metadata: Optional dictionary of metadata to log to ZenML.

    Returns:
        nn.Module: The loaded PyTorch model.
    """
    logger.info(f"Loading model from: {model_path}")

    persistence = ModelPersistence(path=model_path)
    model, _ = persistence.load(
        model_class=model_class,
        map_location=map_location,
        strict=strict,
    )

    if metadata:
        log_step_metadata(metadata={**metadata, "model_path": model_path})

    return model
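
Example: a minimal sketch of loading a state-dict checkpoint. The TinyRegressor class and paths are illustrative placeholders; model_class must match the architecture that produced the saved state dict, and it is assumed the step can accept the class as documented.

import torch
import torch.nn as nn
from zenml import pipeline

from mlpotion.integrations.zenml.pytorch.steps import load_model


class TinyRegressor(nn.Module):
    """Illustrative architecture matching the saved state dict."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(3, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


@pipeline
def torch_model_loading_pipeline() -> None:
    load_model(
        model_path="artifacts/model.pt",  # illustrative path
        model_class=TinyRegressor,
        map_location="cpu",
        strict=True,
    )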

load_streaming_csv_data

load_streaming_csv_data(
    file_path: str,
    batch_size: int = 32,
    label_name: str | None = None,
    column_names: list[str] | None = None,
    num_workers: int = 0,
    pin_memory: bool = False,
    chunksize: int = 10000,
    dtype: str = "float32",
    metadata: dict[str, Any] | None = None,
) -> Annotated[DataLoader, PyTorchDataLoader]

Load large CSV files as a streaming PyTorch DataLoader.

This step uses StreamingCSVDataset to load data in chunks, making it suitable for datasets that do not fit in memory. It returns a DataLoader wrapping the iterable dataset.

Parameters:

- file_path (str): Glob pattern for CSV files (e.g., "data/*.csv"). Required.
- batch_size (int): Number of samples per batch. Default: 32.
- label_name (str | None): Name of the column to use as the label. Default: None.
- column_names (list[str] | None): List of specific columns to load. Default: None.
- num_workers (int): Number of subprocesses to use for data loading. Default: 0.
- pin_memory (bool): Whether to copy tensors into CUDA pinned memory. Default: False.
- chunksize (int): Number of rows to read into memory at a time per file. Default: 10000.
- dtype (str): Data type for the features (e.g., "float32"). Default: "float32".
- metadata (dict[str, Any] | None): Optional dictionary of metadata to log to ZenML. Default: None.

Returns:

- Annotated[DataLoader, PyTorchDataLoader]: The configured streaming PyTorch DataLoader.

Source code in mlpotion/integrations/zenml/pytorch/steps.py
@step
def load_streaming_csv_data(
    file_path: str,
    batch_size: int = 32,
    label_name: str | None = None,
    column_names: list[str] | None = None,
    num_workers: int = 0,
    pin_memory: bool = False,
    chunksize: int = 10000,
    dtype: str = "float32",
    metadata: dict[str, Any] | None = None,
) -> Annotated[DataLoader, "PyTorchDataLoader"]:
    """Load large CSV files as a streaming PyTorch DataLoader.

    This step uses `StreamingCSVDataset` to load data in chunks, making it suitable for
    datasets that do not fit in memory. It returns a `DataLoader` wrapping the iterable dataset.

    Args:
        file_path: Glob pattern for CSV files (e.g., "data/*.csv").
        batch_size: Number of samples per batch.
        label_name: Name of the column to use as the label.
        column_names: List of specific columns to load.
        num_workers: Number of subprocesses to use for data loading.
        pin_memory: Whether to copy tensors into CUDA pinned memory.
        chunksize: Number of rows to read into memory at a time per file.
        dtype: Data type for the features (e.g., "float32").
        metadata: Optional dictionary of metadata to log to ZenML.

    Returns:
        DataLoader: The configured streaming PyTorch DataLoader.
    """
    logger.info(f"Loading streaming data from: {file_path}")

    # Convert dtype string to torch.dtype
    torch_dtype = getattr(torch, dtype)

    # Create streaming dataset
    dataset = StreamingCSVDataset(
        file_pattern=file_path,
        column_names=column_names,
        label_name=label_name,
        chunksize=chunksize,
        dtype=torch_dtype,
    )

    # Create DataLoader config (no shuffle for streaming)
    config = DataLoadingConfig(
        file_pattern=file_path,
        batch_size=batch_size,
        shuffle=False,  # Streaming datasets don't support shuffle
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=False,
    )

    # Create DataLoader using factory (exclude fields not accepted by CSVDataLoader)
    loader_factory = CSVDataLoader(**config.dict(exclude={"file_pattern", "config"}))
    dataloader = loader_factory.load(dataset)

    if metadata:
        log_step_metadata(metadata=metadata)

    return dataloader
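
Example: a minimal sketch for a dataset that does not fit in memory. The glob pattern, chunk size, and batch size are illustrative placeholders.

from zenml import pipeline

from mlpotion.integrations.zenml.pytorch.steps import load_streaming_csv_data


@pipeline
def torch_streaming_ingestion_pipeline() -> None:
    # Rows are read in chunks, so only `chunksize` rows per file are held in memory.
    load_streaming_csv_data(
        file_path="data/large/*.csv",
        batch_size=256,
        label_name="target",
        chunksize=50_000,
        num_workers=2,
    )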

save_model

save_model(
    model: nn.Module,
    save_path: str,
    save_full_model: bool = False,
    metadata: dict[str, Any] | None = None,
) -> Annotated[str, SavePath]

Save a PyTorch model to disk using ModelPersistence.

This step saves the model for later reloading. It supports saving just the state dict (recommended) or the full model object.

Parameters:

- model (nn.Module): The PyTorch model to save. Required.
- save_path (str): The destination path. Required.
- save_full_model (bool): Whether to save the full model object (pickle) instead of state dict. Default: False.
- metadata (dict[str, Any] | None): Optional dictionary of metadata to log to ZenML. Default: None.

Returns:

- Annotated[str, SavePath]: The path to the saved model.

Source code in mlpotion/integrations/zenml/pytorch/steps.py
@step
def save_model(
    model: nn.Module,
    save_path: str,
    save_full_model: bool = False,
    metadata: dict[str, Any] | None = None,
) -> Annotated[str, "SavePath"]:
    """Save a PyTorch model to disk using `ModelPersistence`.

    This step saves the model for later reloading. It supports saving just the state dict
    (recommended) or the full model object.

    Args:
        model: The PyTorch model to save.
        save_path: The destination path.
        save_full_model: Whether to save the full model object (pickle) instead of state dict.
        metadata: Optional dictionary of metadata to log to ZenML.

    Returns:
        str: The path to the saved model.
    """
    logger.info(f"Saving model to: {save_path}")

    persistence = ModelPersistence(path=save_path, model=model)
    persistence.save(save_full_model=save_full_model)

    if metadata:
        log_step_metadata(
            metadata={
                **metadata,
                "save_path": save_path,
                "save_full_model": save_full_model,
            }
        )

    return save_path

train_model

train_model(
    model: nn.Module,
    dataloader: DataLoader,
    epochs: int = 10,
    learning_rate: float = 0.001,
    optimizer: str = "adam",
    loss_fn: str = "mse",
    device: str = "cpu",
    validation_dataloader: DataLoader | None = None,
    verbose: int = 1,
    max_batches_per_epoch: int | None = None,
    metadata: dict[str, Any] | None = None,
) -> Tuple[
    Annotated[nn.Module, TrainedModel],
    Annotated[dict[str, float], TrainingMetrics],
]

Train a PyTorch model using ModelTrainer.

This step configures and runs a training session. It supports validation data, custom loss functions, and automatic device management.

Parameters:

- model (nn.Module): The PyTorch model to train. Required.
- dataloader (DataLoader): The training DataLoader. Required.
- epochs (int): Number of epochs to train. Default: 10.
- learning_rate (float): Learning rate for the optimizer. Default: 0.001.
- optimizer (str): Name of the optimizer (e.g., "adam", "sgd"). Default: "adam".
- loss_fn (str): Name of the loss function (e.g., "mse", "cross_entropy"). Default: "mse".
- device (str): Device to train on ("cpu" or "cuda"). Default: "cpu".
- validation_dataloader (DataLoader | None): Optional validation DataLoader. Default: None.
- verbose (int): Verbosity mode (0 or 1). Default: 1.
- max_batches_per_epoch (int | None): Limit number of batches per epoch (useful for debugging). Default: None.
- metadata (dict[str, Any] | None): Optional dictionary of metadata to log to ZenML. Default: None.

Returns:

- Tuple[Annotated[nn.Module, TrainedModel], Annotated[dict[str, float], TrainingMetrics]]: The trained model and a dictionary of final metrics.

Source code in mlpotion/integrations/zenml/pytorch/steps.py
@step
def train_model(
    model: nn.Module,
    dataloader: DataLoader,
    epochs: int = 10,
    learning_rate: float = 0.001,
    optimizer: str = "adam",
    loss_fn: str = "mse",
    device: str = "cpu",
    validation_dataloader: DataLoader | None = None,
    verbose: int = 1,
    max_batches_per_epoch: int | None = None,
    metadata: dict[str, Any] | None = None,
) -> Tuple[
    Annotated[nn.Module, "TrainedModel"], Annotated[dict[str, float], "TrainingMetrics"]
]:
    """Train a PyTorch model using `ModelTrainer`.

    This step configures and runs a training session. It supports validation data,
    custom loss functions, and automatic device management.

    Args:
        model: The PyTorch model to train.
        dataloader: The training `DataLoader`.
        epochs: Number of epochs to train.
        learning_rate: Learning rate for the optimizer.
        optimizer: Name of the optimizer (e.g., "adam", "sgd").
        loss_fn: Name of the loss function (e.g., "mse", "cross_entropy").
        device: Device to train on ("cpu" or "cuda").
        validation_dataloader: Optional validation `DataLoader`.
        verbose: Verbosity mode (0 or 1).
        max_batches_per_epoch: Limit number of batches per epoch (useful for debugging).
        metadata: Optional dictionary of metadata to log to ZenML.

    Returns:
        Tuple[nn.Module, dict[str, float]]: The trained model and a dictionary of final metrics.
    """
    logger.info(f"Training model for {epochs} epochs on {device}")

    config = ModelTrainingConfig(
        epochs=epochs,
        learning_rate=learning_rate,
        optimizer=optimizer,
        loss_fn=loss_fn,
        device=device,
        verbose=verbose,
        max_batches_per_epoch=max_batches_per_epoch,
    )

    trainer = ModelTrainer()
    result = trainer.train(
        model=model,
        dataloader=dataloader,
        config=config,
        validation_dataloader=validation_dataloader,
    )

    if metadata:
        log_step_metadata(
            metadata={
                **metadata,
                "history": result.history,
                "best_epoch": result.best_epoch,
                "final_metrics": result.metrics,
            }
        )
    logger.info(f"{result=}")
    model = result.model
    metrics = result.metrics

    return model, metrics
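
Example: a minimal end-to-end training sketch combining the data-loading, training, evaluation, and persistence steps of this module. The build_model step, paths, and hyperparameters are illustrative placeholders.

import torch.nn as nn
from zenml import pipeline, step

from mlpotion.integrations.zenml.pytorch.steps import (
    evaluate_model,
    load_csv_data,
    save_model,
    train_model,
)


@step
def build_model() -> nn.Module:
    # Illustrative model; replace with your own architecture.
    return nn.Sequential(nn.Linear(3, 16), nn.ReLU(), nn.Linear(16, 1))


@pipeline
def torch_training_pipeline() -> None:
    train_loader = load_csv_data(file_path="data/train/*.csv", label_name="target")
    val_loader = load_csv_data(
        file_path="data/val/*.csv", label_name="target", shuffle=False
    )
    model = build_model()
    trained_model, metrics = train_model(
        model=model,
        dataloader=train_loader,
        epochs=20,
        optimizer="adam",
        loss_fn="mse",
        validation_dataloader=val_loader,
    )
    evaluate_model(model=trained_model, dataloader=val_loader, loss_fn="mse")
    save_model(model=trained_model, save_path="artifacts/model.pt")


torch_training_pipeline()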

Keras Steps

mlpotion.integrations.zenml.keras.steps

ZenML steps for Keras framework.

Functions

evaluate_model

evaluate_model(
    model: keras.Model,
    data: CSVSequence,
    verbose: int = 1,
    metadata: dict[str, Any] | None = None,
) -> Annotated[dict[str, float], EvaluationMetrics]

Evaluate a Keras model using ModelEvaluator.

This step computes metrics on a given dataset using the provided model.

Parameters:

- model (keras.Model): The Keras model to evaluate. Required.
- data (CSVSequence): The evaluation dataset (CSVSequence). Required.
- verbose (int): Verbosity mode (0 or 1). Default: 1.
- metadata (dict[str, Any] | None): Optional dictionary of metadata to log to ZenML. Default: None.

Returns:

- Annotated[dict[str, float], EvaluationMetrics]: A dictionary of computed metrics.

Source code in mlpotion/integrations/zenml/keras/steps.py
@step
def evaluate_model(
    model: keras.Model,
    data: CSVSequence,
    verbose: int = 1,
    metadata: dict[str, Any] | None = None,
) -> Annotated[dict[str, float], "EvaluationMetrics"]:
    """Evaluate a Keras model using `ModelEvaluator`.

    This step computes metrics on a given dataset using the provided model.

    Args:
        model: The Keras model to evaluate.
        data: The evaluation dataset (`CSVSequence`).
        verbose: Verbosity mode (0 or 1).
        metadata: Optional dictionary of metadata to log to ZenML.

    Returns:
        dict[str, float]: A dictionary of computed metrics.
    """
    logger.info("Evaluating model")

    evaluator = ModelEvaluator()

    config = ModelEvaluationConfig(
        verbose=verbose,
    )

    result = evaluator.evaluate(
        model=model,
        dataset=data,
        config=config,
    )

    metrics = result.metrics

    if metadata:
        log_step_metadata(metadata={**metadata, "metrics": metrics})

    return metrics

export_model

export_model(
    model: keras.Model,
    export_path: str,
    export_format: str | None = None,
    metadata: dict[str, Any] | None = None,
) -> Annotated[str, ExportPath]

Export a Keras model to disk using ModelExporter.

This step exports the model to a specified format (e.g., SavedModel, H5, TFLite).

Parameters:

- model (keras.Model): The Keras model to export. Required.
- export_path (str): The destination path for the exported model. Required.
- export_format (str | None): The format to export to (optional). Default: None.
- metadata (dict[str, Any] | None): Optional dictionary of metadata to log to ZenML. Default: None.

Returns:

- Annotated[str, ExportPath]: The path to the exported model artifact.

Source code in mlpotion/integrations/zenml/keras/steps.py
@step
def export_model(
    model: keras.Model,
    export_path: str,
    export_format: str | None = None,
    metadata: dict[str, Any] | None = None,
) -> Annotated[str, "ExportPath"]:
    """Export a Keras model to disk using `ModelExporter`.

    This step exports the model to a specified format (e.g., SavedModel, H5, TFLite).

    Args:
        model: The Keras model to export.
        export_path: The destination path for the exported model.
        export_format: The format to export to (optional).
        metadata: Optional dictionary of metadata to log to ZenML.

    Returns:
        str: The path to the exported model artifact.
    """
    logger.info(f"Exporting model to: {export_path}")

    exporter = ModelExporter()

    config = {}
    if export_format:
        config["export_format"] = export_format

    exporter.export(
        model=model,
        path=export_path,
        **config,
    )

    if metadata:
        log_step_metadata(metadata={**metadata, "export_path": export_path})

    return export_path

inspect_model

inspect_model(
    model: keras.Model,
    include_layers: bool = True,
    include_signatures: bool = True,
    metadata: dict[str, Any] | None = None,
) -> Annotated[dict[str, Any], ModelInspection]

Inspect a Keras model using ModelInspector.

This step extracts metadata about the model, such as layer configuration, input/output shapes, and parameter counts.

Parameters:

- model (keras.Model): The Keras model to inspect. Required.
- include_layers (bool): Whether to include detailed layer information. Default: True.
- include_signatures (bool): Whether to include signature information. Default: True.
- metadata (dict[str, Any] | None): Optional dictionary of metadata to log to ZenML. Default: None.

Returns:

- Annotated[dict[str, Any], ModelInspection]: A dictionary containing the inspection results.

Source code in mlpotion/integrations/zenml/keras/steps.py
@step
def inspect_model(
    model: keras.Model,
    include_layers: bool = True,
    include_signatures: bool = True,
    metadata: dict[str, Any] | None = None,
) -> Annotated[dict[str, Any], "ModelInspection"]:
    """Inspect a Keras model using `ModelInspector`.

    This step extracts metadata about the model, such as layer configuration,
    input/output shapes, and parameter counts.

    Args:
        model: The Keras model to inspect.
        include_layers: Whether to include detailed layer information.
        include_signatures: Whether to include signature information.
        metadata: Optional dictionary of metadata to log to ZenML.

    Returns:
        dict[str, Any]: A dictionary containing the inspection results.
    """
    logger.info("Inspecting model")

    inspector = ModelInspector(
        include_layers=include_layers,
        include_signatures=include_signatures,
    )
    inspection = inspector.inspect(model)

    if metadata:
        log_step_metadata(metadata={**metadata, "inspection": inspection})

    return inspection

load_data

load_data(
    file_path: str,
    batch_size: int = 32,
    label_name: str | None = None,
    column_names: list[str] | None = None,
    shuffle: bool = True,
    dtype: str = "float32",
    metadata: dict[str, Any] | None = None,
) -> Annotated[CSVSequence, CSVSequence]

Load data from CSV files into a Keras Sequence.

This step uses CSVDataLoader to load data matching the specified file pattern. It returns a CSVSequence which can be used for training or evaluation.

Parameters:

- file_path (str): Glob pattern for CSV files (e.g., "data/*.csv"). Required.
- batch_size (int): Number of samples per batch. Default: 32.
- label_name (str | None): Name of the column to use as the label. Default: None.
- column_names (list[str] | None): List of specific columns to load. Default: None.
- shuffle (bool): Whether to shuffle the data. Default: True.
- dtype (str): Data type for the features (e.g., "float32"). Default: "float32".
- metadata (dict[str, Any] | None): Optional dictionary of metadata to log to ZenML. Default: None.

Returns:

- Annotated[CSVSequence, CSVSequence]: The loaded Keras Sequence.

Source code in mlpotion/integrations/zenml/keras/steps.py
@step
def load_data(
    file_path: str,
    batch_size: int = 32,
    label_name: str | None = None,
    column_names: list[str] | None = None,
    shuffle: bool = True,
    dtype: str = "float32",
    metadata: dict[str, Any] | None = None,
) -> Annotated[CSVSequence, "CSVSequence"]:
    """Load data from CSV files into a Keras Sequence.

    This step uses `CSVDataLoader` to load data matching the specified file pattern.
    It returns a `CSVSequence` which can be used for training or evaluation.

    Args:
        file_path: Glob pattern for CSV files (e.g., "data/*.csv").
        batch_size: Number of samples per batch.
        label_name: Name of the column to use as the label.
        column_names: List of specific columns to load.
        shuffle: Whether to shuffle the data.
        dtype: Data type for the features (e.g., "float32").
        metadata: Optional dictionary of metadata to log to ZenML.

    Returns:
        CSVSequence: The loaded Keras Sequence.
    """
    logger.info(f"Loading data from: {file_path}")

    config = DataLoadingConfig(
        file_pattern=file_path,
        batch_size=batch_size,
        column_names=column_names,
        label_name=label_name,
        shuffle=shuffle,
        dtype=dtype,
    )

    loader = CSVDataLoader(**config.dict())
    sequence = loader.load()

    if metadata:
        log_step_metadata(metadata=metadata)

    return sequence
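
Example: a minimal sketch of calling this step on its own. The glob pattern and label column are illustrative placeholders; the returned CSVSequence artifact can be consumed by the training and evaluation steps in this module.

from zenml import pipeline

from mlpotion.integrations.zenml.keras.steps import load_data


@pipeline
def keras_ingestion_pipeline() -> None:
    # Illustrative pattern and label column.
    load_data(
        file_path="data/train/*.csv",
        batch_size=64,
        label_name="target",
        shuffle=True,
    )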

load_model

load_model(
    model_path: str,
    inspect: bool = True,
    metadata: dict[str, Any] | None = None,
) -> Annotated[keras.Model, LoadedModel]

Load a Keras model from disk using ModelPersistence.

This step loads a previously saved model. It can optionally inspect the loaded model to log metadata about its structure.

Parameters:

- model_path (str): The path to the saved model. Required.
- inspect (bool): Whether to inspect the model after loading. Default: True.
- metadata (dict[str, Any] | None): Optional dictionary of metadata to log to ZenML. Default: None.

Returns:

- Annotated[keras.Model, LoadedModel]: The loaded Keras model.

Source code in mlpotion/integrations/zenml/keras/steps.py
@step
def load_model(
    model_path: str,
    inspect: bool = True,
    metadata: dict[str, Any] | None = None,
) -> Annotated[keras.Model, "LoadedModel"]:
    """Load a Keras model from disk using `ModelPersistence`.

    This step loads a previously saved model. It can optionally inspect the loaded model
    to log metadata about its structure.

    Args:
        model_path: The path to the saved model.
        inspect: Whether to inspect the model after loading.
        metadata: Optional dictionary of metadata to log to ZenML.

    Returns:
        keras.Model: The loaded Keras model.
    """
    logger.info(f"Loading model from: {model_path}")

    persistence = ModelPersistence(path=model_path)
    model, inspection = persistence.load(inspect=inspect)

    if metadata:
        meta = {**metadata}
        if inspection:
            meta["inspection"] = inspection
        log_step_metadata(metadata=meta)

    return model

save_model

save_model(
    model: keras.Model,
    save_path: str,
    metadata: dict[str, Any] | None = None,
) -> Annotated[str, SavePath]

Save a Keras model to disk using ModelPersistence.

This step saves the model for later reloading, typically preserving the optimizer state.

Parameters:

- model (keras.Model): The Keras model to save. Required.
- save_path (str): The destination path. Required.
- metadata (dict[str, Any] | None): Optional dictionary of metadata to log to ZenML. Default: None.

Returns:

- Annotated[str, SavePath]: The path to the saved model.

Source code in mlpotion/integrations/zenml/keras/steps.py
@step
def save_model(
    model: keras.Model,
    save_path: str,
    metadata: dict[str, Any] | None = None,
) -> Annotated[str, "SavePath"]:
    """Save a Keras model to disk using `ModelPersistence`.

    This step saves the model for later reloading, typically preserving the optimizer state.

    Args:
        model: The Keras model to save.
        save_path: The destination path.
        metadata: Optional dictionary of metadata to log to ZenML.

    Returns:
        str: The path to the saved model.
    """
    logger.info(f"Saving model to: {save_path}")

    persistence = ModelPersistence(path=save_path, model=model)
    persistence.save()

    if metadata:
        log_step_metadata(metadata={**metadata, "save_path": save_path})

    return save_path
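
Example: a minimal end-to-end sketch combining the load_data, train_model, evaluate_model, and save_model steps of this module. The build_model step, paths, input shape, and epoch count are illustrative placeholders, and a materializer for keras.Model is assumed to be available (e.g., via ZenML's TensorFlow integration).

import keras
from zenml import pipeline, step

from mlpotion.integrations.zenml.keras.steps import (
    evaluate_model,
    load_data,
    save_model,
    train_model,
)


@step
def build_model() -> keras.Model:
    # Illustrative toy regressor; replace with your own architecture.
    model = keras.Sequential([keras.Input(shape=(3,)), keras.layers.Dense(1)])
    model.compile(optimizer="adam", loss="mse", metrics=["mae"])
    return model


@pipeline
def keras_training_pipeline() -> None:
    train_seq = load_data(file_path="data/train/*.csv", label_name="target")
    val_seq = load_data(file_path="data/val/*.csv", label_name="target", shuffle=False)
    model = build_model()
    trained_model, metrics = train_model(
        model=model, data=train_seq, epochs=10, validation_data=val_seq
    )
    evaluate_model(model=trained_model, data=val_seq)
    save_model(model=trained_model, save_path="artifacts/model.keras")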

train_model

train_model(
    model: keras.Model,
    data: CSVSequence,
    epochs: int = 10,
    validation_data: CSVSequence | None = None,
    learning_rate: float = 0.001,
    verbose: int = 1,
    callbacks: list[Any] | None = None,
    metadata: dict[str, Any] | None = None,
) -> Tuple[
    Annotated[keras.Model, TrainedModel],
    Annotated[dict[str, float], TrainingMetrics],
]

Train a Keras model using ModelTrainer.

This step configures and runs a training session. It supports validation data, callbacks, and logging of training metrics.

Parameters:

- model (keras.Model, required): The Keras model to train.
- data (CSVSequence, required): The training dataset (CSVSequence).
- epochs (int, default: 10): Number of epochs to train.
- validation_data (CSVSequence | None, default: None): Optional validation dataset (CSVSequence).
- learning_rate (float, default: 0.001): Learning rate for the Adam optimizer.
- verbose (int, default: 1): Verbosity mode (0, 1, or 2).
- callbacks (list[Any] | None, default: None): List of Keras callbacks to apply during training.
- metadata (dict[str, Any] | None, default: None): Optional dictionary of metadata to log to ZenML.

Returns:

- Tuple[Annotated[keras.Model, TrainedModel], Annotated[dict[str, float], TrainingMetrics]]: The trained model and a dictionary of final metrics.

Source code in mlpotion/integrations/zenml/keras/steps.py (lines 135-204)
@step
def train_model(
    model: keras.Model,
    data: CSVSequence,
    epochs: int = 10,
    validation_data: CSVSequence | None = None,
    learning_rate: float = 0.001,
    verbose: int = 1,
    callbacks: list[Any] | None = None,
    metadata: dict[str, Any] | None = None,
) -> Tuple[
    Annotated[keras.Model, "TrainedModel"],
    Annotated[dict[str, float], "TrainingMetrics"],
]:
    """Train a Keras model using `ModelTrainer`.

    This step configures and runs a training session. It supports validation data,
    callbacks, and logging of training metrics.

    Args:
        model: The Keras model to train.
        data: The training dataset (`CSVSequence`).
        epochs: Number of epochs to train.
        validation_data: Optional validation dataset (`CSVSequence`).
        learning_rate: Learning rate for the Adam optimizer.
        verbose: Verbosity mode (0, 1, or 2).
        callbacks: List of Keras callbacks to apply during training.
        metadata: Optional dictionary of metadata to log to ZenML.

    Returns:
        Tuple[keras.Model, dict[str, float]]: The trained model and a dictionary of final metrics.
    """
    logger.info(f"Training model for {epochs} epochs")

    trainer = ModelTrainer()

    config = ModelTrainingConfig(
        epochs=epochs,
        learning_rate=learning_rate,
        verbose=verbose,
        optimizer="adam",  # Defaulting to adam as per previous logic
        loss="mse",
        metrics=["mae"],
        framework_options={"callbacks": callbacks} if callbacks else {},
    )
    # Optimizer, loss, and metrics are currently fixed to Adam / MSE / MAE in
    # the config above; ModelTrainer builds the optimizer (including the
    # learning rate) from these config fields, so no keras.optimizers.Adam
    # instance is created directly in this step.

    result = trainer.train(
        model=model,
        dataset=data,
        config=config,
        validation_dataset=validation_data,
    )

    # Result is TrainingResult object
    training_metrics = result.metrics

    if metadata:
        log_step_metadata(metadata={**metadata, "history": result.history})

    return model, training_metrics
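
A wiring sketch that chains this step with save_model. The load_sequences step below is a hypothetical placeholder for whatever produces your CSVSequence objects; note that the step's training config fixes the optimizer, loss, and metrics to Adam, MSE, and MAE.

from zenml import pipeline, step
import keras

from mlpotion.integrations.zenml.keras.steps import save_model, train_model


@step
def build_model() -> keras.Model:
    # Small regression model used for illustration only.
    return keras.Sequential(
        [keras.layers.Dense(16, activation="relu"), keras.layers.Dense(1)]
    )


@pipeline
def training_pipeline():
    train_seq, val_seq = load_sequences()  # hypothetical data-loading step

    model, metrics = train_model(
        model=build_model(),
        data=train_seq,
        epochs=20,
        validation_data=val_seq,
        learning_rate=1e-3,
        metadata={"experiment": "baseline"},
    )

    save_model(model=model, save_path="artifacts/model.keras")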

transform_data

transform_data(
    dataset: CSVSequence,
    model: keras.Model,
    data_output_path: str,
    data_output_per_batch: bool = False,
    batch_size: int | None = None,
    feature_names: list[str] | None = None,
    input_columns: list[str] | None = None,
    metadata: dict[str, Any] | None = None,
) -> Annotated[str, OutputPath]

Transform data using a Keras model and save predictions to CSV.

This step uses CSVDataTransformer to run inference on a dataset using a provided model and saves the results to the specified output path.

Parameters:

- dataset (CSVSequence, required): The input dataset (CSVSequence).
- model (keras.Model, required): The Keras model to use for transformation.
- data_output_path (str, required): Path to save the transformed data (CSV).
- data_output_per_batch (bool, default: False): Whether to save a separate file per batch.
- batch_size (int | None, default: None): Batch size for inference (overrides the dataset batch size if provided).
- feature_names (list[str] | None, default: None): Optional list of feature names for the output CSV.
- input_columns (list[str] | None, default: None): Optional list of input columns to pass to the model.
- metadata (dict[str, Any] | None, default: None): Optional dictionary of metadata to log to ZenML.

Returns:

- str (Annotated[str, OutputPath]): The path to the saved output file(s).

Source code in mlpotion/integrations/zenml/keras/steps.py (lines 79-132)
@step
def transform_data(
    dataset: CSVSequence,
    model: keras.Model,
    data_output_path: str,
    data_output_per_batch: bool = False,
    batch_size: int | None = None,
    feature_names: list[str] | None = None,
    input_columns: list[str] | None = None,
    metadata: dict[str, Any] | None = None,
) -> Annotated[str, "OutputPath"]:
    """Transform data using a Keras model and save predictions to CSV.

    This step uses `CSVDataTransformer` to run inference on a dataset using a provided model
    and saves the results to the specified output path.

    Args:
        dataset: The input dataset (`CSVSequence`).
        model: The Keras model to use for transformation.
        data_output_path: Path to save the transformed data (CSV).
        data_output_per_batch: Whether to save a separate file per batch.
        batch_size: Batch size for inference (overrides dataset batch size if provided).
        feature_names: Optional list of feature names for the output CSV.
        input_columns: Optional list of input columns to pass to the model.
        metadata: Optional dictionary of metadata to log to ZenML.

    Returns:
        str: The path to the saved output file(s).
    """
    logger.info(f"Transforming data and saving to: {data_output_path}")

    config = DataTransformationConfig(
        data_output_path=data_output_path,
        data_output_per_batch=data_output_per_batch,
        batch_size=batch_size,
        feature_names=feature_names,
        input_columns=input_columns,
    )

    transformer = CSVDataTransformer(
        dataset=dataset,
        model=model,
        data_output_path=data_output_path,
        data_output_per_batch=data_output_per_batch,
        batch_size=batch_size,
        feature_names=feature_names,
        input_columns=input_columns,
    )
    transformer.transform(dataset=dataset, model=model, config=config)

    if metadata:
        log_step_metadata(metadata=metadata)

    return data_output_path
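
A batch-inference sketch combining load_model with this step; the paths and the load_sequence placeholder are hypothetical:

from zenml import pipeline

from mlpotion.integrations.zenml.keras.steps import load_model, transform_data


@pipeline
def batch_inference_pipeline():
    dataset = load_sequence()  # hypothetical step producing a CSVSequence

    model = load_model(model_path="artifacts/model.keras")

    transform_data(
        dataset=dataset,
        model=model,
        data_output_path="artifacts/predictions.csv",
        feature_names=["prediction"],
    )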

Materializers

mlpotion.integrations.zenml.tensorflow.materializers

Custom materializers for TensorFlow types.

Classes

TFConfigDatasetMaterializer

Bases: BaseMaterializer

Materializer for tf.data.Dataset created from CSV files.

Instead of serializing the entire dataset to TFRecords, this materializer stores only the configuration needed to recreate the dataset using tf.data.experimental.make_csv_dataset. This is much more efficient and avoids shape-related issues during serialization/deserialization.

This materializer works specifically with datasets created via:

- tf.data.experimental.make_csv_dataset
- MLPotion's TFCSVDataLoader

Advantages:

- Lightweight: only stores the config, not the data
- Fast: no TFRecord serialization overhead
- Reliable: recreates the dataset with exactly the same parameters
- Flexible: works with any subsequent transformations (batching, shuffling, etc.)
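
A minimal sketch of attaching this materializer to a step output via ZenML's output_materializers argument; the step below is illustrative and not part of MLPotion:

import tensorflow as tf
from zenml import step

from mlpotion.integrations.zenml.tensorflow.materializers import (
    TFConfigDatasetMaterializer,
)


@step(output_materializers=TFConfigDatasetMaterializer)
def load_csv_dataset(file_pattern: str, batch_size: int = 32) -> tf.data.Dataset:
    # If the CSV config can be extracted from this output, only the config is
    # stored; otherwise save() falls back to the TFRecord materializer below.
    return tf.data.experimental.make_csv_dataset(
        file_pattern, batch_size=batch_size, num_epochs=1
    )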

Functions
load
load(data_type: Type[Any]) -> tf.data.Dataset

Load dataset by recreating it from stored configuration.

Parameters:

- data_type (Type[Any], required): The type of the data to load.

Returns:

- tf.data.Dataset: Recreated tf.data.Dataset with the same configuration.

Source code in mlpotion/integrations/zenml/tensorflow/materializers.py (lines 620-687)
def load(self, data_type: Type[Any]) -> tf.data.Dataset:
    """Load dataset by recreating it from stored configuration.

    Args:
        data_type: The type of the data to load.

    Returns:
        Recreated tf.data.Dataset with the same configuration.
    """
    config_path = Path(self.uri) / "config.json"

    logger.info("Loading CSV dataset config from: %s", config_path)

    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)

    logger.info("Recreating dataset with config: %s", config)

    # Use CSVDataLoader to recreate the dataset
    # This ensures we handle empty lines correctly (unlike make_csv_dataset)
    from mlpotion.frameworks.tensorflow.data.loaders import CSVDataLoader

    # Extract parameters for CSVDataLoader
    loader_config = {
        "file_pattern": config["file_pattern"],
        "batch_size": config["batch_size"],
        "label_name": config.get("label_name"),
        "column_names": config.get("column_names"),
    }

    # Handle num_epochs and other config
    extra_params = config.get("extra_params", {})
    if "num_epochs" in config:
        extra_params["num_epochs"] = config["num_epochs"]
    elif "num_epochs" not in extra_params:
        extra_params["num_epochs"] = 1

    if extra_params:
        loader_config["config"] = extra_params

    # Create loader and load dataset
    loader = CSVDataLoader(**loader_config)
    dataset = loader.load()

    # Apply any transformations that were recorded
    transformations = config.get("transformations", [])
    for transform in transformations:
        transform_type = transform["type"]
        params = transform["params"]

        if transform_type == "batch":
            dataset = dataset.batch(params["batch_size"])
        elif transform_type == "shuffle":
            dataset = dataset.shuffle(params["buffer_size"])
        elif transform_type == "prefetch":
            buffer_size = params["buffer_size"]
            if buffer_size == "AUTOTUNE":
                buffer_size = tf.data.AUTOTUNE
            dataset = dataset.prefetch(buffer_size)
        elif transform_type == "unbatch":
            dataset = dataset.unbatch()
        elif transform_type == "repeat":
            count = params.get("count")
            dataset = dataset.repeat(count)
        # Add more transformation types as needed

    logger.info("✅ Successfully recreated CSV dataset")
    return dataset
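
For reference, a config of the shape this loader consumes might look like the following; the keys mirror those read in the code above, while the values are purely illustrative:

config = {
    "file_pattern": "data/train-*.csv",
    "batch_size": 32,
    "label_name": "target",
    "column_names": None,
    "num_epochs": 1,
    "extra_params": {},
    "transformations": [
        {"type": "shuffle", "params": {"buffer_size": 1000}},
        {"type": "prefetch", "params": {"buffer_size": "AUTOTUNE"}},
    ],
}
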
save
save(data: tf.data.Dataset) -> None

Save dataset configuration instead of actual data.

This method attempts to extract the original CSV loading configuration from the dataset. If the dataset doesn't have this metadata, it falls back to the TFRecord materializer.

Parameters:

- data (tf.data.Dataset, required): The dataset to save configuration for.
Source code in mlpotion/integrations/zenml/tensorflow/materializers.py (lines 689-740)
def save(self, data: tf.data.Dataset) -> None:
    """Save dataset configuration instead of actual data.

    This method attempts to extract the original CSV loading configuration
    from the dataset. If the dataset doesn't have this metadata, it falls
    back to the TFRecord materializer.

    Args:
        data: The dataset to save configuration for.
    """
    config_path = Path(self.uri) / "config.json"
    config_path.parent.mkdir(parents=True, exist_ok=True)

    logger.info("🔵 TFConfigDatasetMaterializer.save() called")
    logger.info("Saving CSV dataset config to: %s", config_path)
    logger.debug("Dataset type: %s", type(data))
    logger.debug("URI: %s", self.uri)

    # Try to extract configuration from the dataset
    # This requires the dataset to have been created with our loader
    # or to have metadata attached
    config = self._extract_config_from_dataset(data)

    if config is None:
        logger.warning(
            "❌ Could not extract CSV config from dataset. "
            "This materializer only works with datasets created from CSV files. "
            "Falling back to TFRecord materializer."
        )
        logger.debug(
            "Dataset attributes: %s",
            [attr for attr in dir(data) if not attr.startswith("__")],
        )
        # Fall back to TFRecord materializer
        from mlpotion.integrations.zenml.tensorflow.materializers import (
            TFRecordDatasetMaterializer,
        )

        logger.info("🔄 Falling back to TFRecordDatasetMaterializer")
        try:
            tfrecord_materializer = TFRecordDatasetMaterializer(self.uri)
            tfrecord_materializer.save(data)
            logger.info("✅ Successfully saved dataset as TFRecord")
        except Exception as e:
            logger.error(f"Failed to save as TFRecord: {e}")
            raise
        return

    with open(config_path, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)

    logger.info("✅ Successfully saved CSV dataset config to: %s", config_path)

TFRecordDatasetMaterializer

Bases: BaseMaterializer

Generic TFRecord materializer for tf.data.Dataset.

This materializer is designed to be robust and round-trip safe for datasets produced by tf.data.experimental.make_csv_dataset, and in general for any dataset whose element_spec is a nested structure of:

- dict / tuple / list containers
- `tf.TensorSpec` leaves

It works as follows:

- Save:
  - Reads dataset.element_spec and serializes it to JSON.
  - For each batch (dataset element), recursively flattens it to a list of tensors in a deterministic order implied by the spec.
  - Writes a single tf.train.Example per batch, with features named "f0", "f1", ... corresponding to each leaf tensor.
- Load:
  - Deserializes element_spec from JSON.
  - Builds a feature_description for tf.io.parse_single_example using the leaf specs.
  - Parses each example into a list of tensors.
  - Recursively unflattens the list back into the same nested structure as element_spec.

This supports all typical make_csv_dataset shapes:

1. label_name=None:
   element: dict[str, Tensor]

2. label_name="target":
   element: (dict[str, Tensor], Tensor)

3. label_name=["t1", "t2"]:
   element: (dict[str, Tensor], dict[str, Tensor])

and also more complex nesting as long as it's composed of dict / tuple / list and TensorSpec leaves.
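
The flatten/unflatten round trip can be pictured with tf.nest; the materializer uses its own spec-aware helpers, but the idea is the same, and the feature columns below are made up for illustration:

import tensorflow as tf

# Structure produced by make_csv_dataset with label_name="target".
element_spec = (
    {"f1": tf.TensorSpec([None], tf.float32), "f2": tf.TensorSpec([None], tf.float32)},
    tf.TensorSpec([None], tf.int64),
)

element = (
    {"f1": tf.constant([1.0, 2.0]), "f2": tf.constant([3.0, 4.0])},
    tf.constant([0, 1], dtype=tf.int64),
)

# Save path: leaves in the deterministic order implied by the spec -> "f0", "f1", ...
leaves = tf.nest.flatten(element)

# Load path: parsed leaves are packed back into the original nesting.
rebuilt = tf.nest.pack_sequence_as(element_spec, leaves)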

Functions
load
load(data_type: Type[Any]) -> tf.data.Dataset

Deserialize a tf.data.Dataset from TFRecord + metadata JSON.

Source code in mlpotion/integrations/zenml/tensorflow/materializers.py (lines 289-385)
def load(self, data_type: Type[Any]) -> tf.data.Dataset:
    """Deserialize a `tf.data.Dataset` from TFRecord + metadata JSON."""
    dataset_dir = Path(self.uri)
    tfrecord_path = str(dataset_dir / "data.tfrecord")
    metadata_path = dataset_dir / "metadata.json"

    logger.info("Loading dataset from TFRecord: %s", tfrecord_path)

    with open(metadata_path, "r", encoding="utf-8") as f:
        metadata = json.load(f)

    element_spec = self._deserialize_element_spec(metadata["element_spec"])
    num_leaves = metadata["num_leaves"]
    concrete_shapes = metadata.get(
        "concrete_shapes", None
    )  # May be None for older versions

    logger.info("Loaded element_spec: %s", element_spec)
    logger.info("Expected number of leaves: %s", num_leaves)
    if concrete_shapes:
        logger.info("Concrete shapes available: %s", concrete_shapes)

    flat_spec_leaves = self._flatten_element_spec(element_spec)
    if len(flat_spec_leaves) != num_leaves:
        raise ValueError(
            f"Metadata num_leaves={num_leaves} but element_spec "
            f"has {len(flat_spec_leaves)} leaves."
        )

    # Build feature description for parsing
    feature_description = self._build_feature_description(flat_spec_leaves)

    def parse_fn(serialized_example: tf.Tensor) -> Any:
        parsed = tf.io.parse_single_example(serialized_example, feature_description)

        flat_tensors: list[tf.Tensor] = []
        for i, (_, leaf_spec) in enumerate(flat_spec_leaves):
            key = f"f{i}"

            if leaf_spec.dtype in (tf.float32, tf.float64, tf.int32, tf.int64):
                # Numeric: stored as VarLenFeature, results in 1D tensor
                dense = tf.sparse.to_dense(parsed[key])
                tensor = tf.cast(dense, leaf_spec.dtype)

                # Use concrete shape if available, otherwise fall back to spec-based logic
                if concrete_shapes and i < len(concrete_shapes):
                    # We have the actual shape from when the data was saved
                    concrete_shape = concrete_shapes[i]
                    # Replace None with -1 for reshape
                    target_shape = [
                        d if d is not None else -1 for d in concrete_shape
                    ]
                    tensor = tf.reshape(tensor, target_shape)
                    # Set the shape with proper None values
                    tensor.set_shape(concrete_shape)
                else:
                    # Fallback to spec-based reshaping (legacy behavior)
                    if leaf_spec.shape.rank is not None:
                        if leaf_spec.shape.rank == 1:
                            # Original was 1D, VarLen already gives us 1D - just set shape
                            tensor.set_shape(leaf_spec.shape)
                        elif leaf_spec.shape.rank > 1:
                            # Original was multi-dimensional - need to reshape from 1D
                            shape_list = leaf_spec.shape.as_list()
                            none_indices = [
                                i for i, d in enumerate(shape_list) if d is None
                            ]

                            if len(none_indices) <= 1:
                                # Safe to reshape with at most one -1
                                target_shape = [
                                    d if d is not None else -1 for d in shape_list
                                ]
                                tensor = tf.reshape(tensor, target_shape)
                                tensor.set_shape(leaf_spec.shape)
                    else:
                        # Unknown rank - just set shape
                        tensor.set_shape(leaf_spec.shape)
            else:
                # Other dtypes: stored as serialized bytes
                serialized = parsed[key]
                tensor = tf.io.parse_tensor(serialized, out_type=leaf_spec.dtype)
                tensor.set_shape(leaf_spec.shape)

            flat_tensors.append(tensor)

        # Rebuild nested structure
        flat_iter = iter(flat_tensors)
        return self._unflatten_data_with_spec(element_spec, flat_iter)

    dataset = tf.data.TFRecordDataset(tfrecord_path)
    dataset = dataset.map(parse_fn, num_parallel_calls=tf.data.AUTOTUNE)

    logger.info("Successfully loaded dataset from TFRecord.")
    logger.info("Dataset cardinality: %s", dataset.cardinality().numpy())

    return dataset
save
save(data: tf.data.Dataset) -> None

Serialize a tf.data.Dataset to TFRecord + metadata JSON.

Source code in mlpotion/integrations/zenml/tensorflow/materializers.py (lines 202-287)
def save(self, data: tf.data.Dataset) -> None:
    """Serialize a `tf.data.Dataset` to TFRecord + metadata JSON."""
    dataset_dir = Path(self.uri)
    dataset_dir.mkdir(parents=True, exist_ok=True)

    tfrecord_path = str(dataset_dir / "data.tfrecord")
    metadata_path = dataset_dir / "metadata.json"

    element_spec = data.element_spec

    logger.info("Saving dataset to TFRecord: %s", tfrecord_path)
    logger.info("Dataset element_spec: %s", element_spec)

    # Handle cardinality
    cardinality = data.cardinality().numpy()
    logger.info("Dataset cardinality: %s", cardinality)

    if cardinality == tf.data.INFINITE_CARDINALITY:
        logger.warning("Infinite dataset detected. Taking first 100000 batches.")
        data = data.take(100_000)
    elif cardinality == tf.data.UNKNOWN_CARDINALITY:
        logger.warning("Unknown dataset cardinality. Taking first 100000 batches.")
        data = data.take(100_000)
    else:
        logger.info("Finite dataset with %s batches.", cardinality)

    # Serialize element_spec so we can restore structure and leaf specs
    serialized_spec = self._serialize_element_spec(element_spec)
    flat_spec_leaves = self._flatten_element_spec(element_spec)
    num_leaves = len(flat_spec_leaves)

    # Get concrete shapes from the first batch element (if available)
    # We store shapes WITHOUT the batch dimension to handle variable batch sizes
    concrete_shapes = None
    try:
        first_batch = next(iter(data.take(1)))
        flat_tensors_sample: list[tf.Tensor] = []
        self._flatten_data_with_spec(first_batch, element_spec, flat_tensors_sample)
        # Store the shape WITHOUT the first (batch) dimension
        # This allows the materializer to work with variable batch sizes
        concrete_shapes = []
        for t in flat_tensors_sample:
            shape_list = list(t.shape.as_list())
            # Remove the first (batch) dimension, keep the rest
            if len(shape_list) > 1:
                shape_without_batch = [None] + shape_list[1:]  # None for batch dim
            else:
                shape_without_batch = [None]  # Just batch dimension
            concrete_shapes.append(shape_without_batch)
    except Exception:
        # If we can't get a sample, proceed without concrete shapes
        pass

    metadata = {
        "format_version": "3.1",  # Increment version for new feature
        "element_spec": serialized_spec,
        "num_leaves": num_leaves,
        "concrete_shapes": concrete_shapes,  # Store actual shapes if available
    }

    with open(metadata_path, "w", encoding="utf-8") as f:
        json.dump(metadata, f, indent=2)

    # Write TFRecord
    writer = tf.io.TFRecordWriter(tfrecord_path)
    batch_count = 0

    for batch in data:
        flat_tensors: list[tf.Tensor] = []
        self._flatten_data_with_spec(batch, element_spec, flat_tensors)

        if len(flat_tensors) != num_leaves:
            raise ValueError(
                f"Flattened batch has {len(flat_tensors)} leaves but "
                f"element_spec indicates {num_leaves}."
            )

        example = self._flat_tensors_to_example(flat_tensors)
        writer.write(example.SerializeToString())
        batch_count += 1

        if batch_count % 100 == 0:
            logger.info("Written %d batches...", batch_count)

    writer.close()
    logger.info("Successfully saved %d batches to TFRecord.", batch_count)

TensorMaterializer

Bases: BaseMaterializer

Materializer for TensorFlow Tensor objects.

This materializer handles the serialization and deserialization of tf.Tensor objects. It saves tensors as binary protobuf files (tensor.pb) using tf.io.serialize_tensor.

Functions
load
load(data_type: type[Any]) -> tf.Tensor

Load a TensorFlow Tensor from the artifact store.

Parameters:

- data_type (type[Any], required): The type of the data to load (should be tf.Tensor).

Returns:

- tf.Tensor: The loaded tensor.

Source code in mlpotion/integrations/zenml/tensorflow/materializers.py (lines 42-59)
def load(self, data_type: type[Any]) -> tf.Tensor:  # noqa: ARG002
    """Load a TensorFlow Tensor from the artifact store.

    Args:
        data_type: The type of the data to load (should be `tf.Tensor`).

    Returns:
        tf.Tensor: The loaded tensor.
    """
    logger.info("Loading TensorFlow tensor...")
    try:
        tensor_path = Path(self.uri) / "tensor.pb"
        return tf.io.parse_tensor(
            tf.io.read_file(str(tensor_path)), out_type=tf.float32
        )
    except Exception as e:
        logger.error(f"Failed to load tensor: {e}")
        raise
save
save(data: tf.Tensor) -> None

Save a TensorFlow Tensor to the artifact store.

Parameters:

- data (tf.Tensor, required): The tensor to save.
Source code in mlpotion/integrations/zenml/tensorflow/materializers.py (lines 61-75)
def save(self, data: tf.Tensor) -> None:
    """Save a TensorFlow Tensor to the artifact store.

    Args:
        data: The tensor to save.
    """
    logger.info("Saving TensorFlow tensor...")
    try:
        Path(self.uri).mkdir(parents=True, exist_ok=True)
        tensor_path = Path(self.uri) / "tensor.pb"
        tf.io.write_file(str(tensor_path), tf.io.serialize_tensor(data))
        logger.info("✅ Successfully saved TensorFlow tensor")
    except Exception as e:
        logger.error(f"Failed to save tensor: {e}")
        raise
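
The underlying round trip uses tf.io.serialize_tensor and tf.io.parse_tensor; a stand-alone sketch (the path is hypothetical, and note that the load() above assumes float32 tensors):

import tensorflow as tf

t = tf.constant([[1.0, 2.0], [3.0, 4.0]])

tf.io.write_file("/tmp/tensor.pb", tf.io.serialize_tensor(t))
restored = tf.io.parse_tensor(tf.io.read_file("/tmp/tensor.pb"), out_type=tf.float32)

assert restored.shape == t.shape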

TensorSpecMaterializer

Bases: BaseMaterializer

Materializer for TensorFlow TensorSpec objects.

This materializer handles the serialization and deserialization of tf.TensorSpec objects. It saves the spec as a JSON file (spec.json) containing shape, dtype, and other metadata.

Functions
load
load(data_type: type[Any]) -> tf.TensorSpec

Load a TensorFlow TensorSpec from the artifact store.

Parameters:

- data_type (type[Any], required): The type of the data to load (should be tf.TensorSpec).

Returns:

- tf.TensorSpec: The loaded tensor spec.

Source code in mlpotion/integrations/zenml/tensorflow/materializers.py (lines 99-117)
def load(self, data_type: type[Any]) -> tf.TensorSpec:  # noqa: ARG002
    """Load a TensorFlow TensorSpec from the artifact store.

    Args:
        data_type: The type of the data to load (should be `tf.TensorSpec`).

    Returns:
        tf.TensorSpec: The loaded tensor spec.
    """
    logger.info("Loading TensorFlow TensorSpec...")
    try:
        spec_path = Path(self.uri) / "spec.json"
        with open(spec_path) as f:
            spec_dict = json.load(f)
        # Reconstruct the TensorSpec from its dict representation
        # (TensorSpec.from_spec expects a TypeSpec, not a dict).
        return tf.TensorSpec(
            shape=spec_dict["shape"],
            dtype=spec_dict["dtype"],
            name=spec_dict.get("name"),
        )
    except Exception as e:
        logger.error(f"Failed to load TensorSpec: {e}")
        raise
save
save(data: tf.TensorSpec) -> None

Save a TensorFlow TensorSpec to the artifact store.

Parameters:

- data (tf.TensorSpec, required): The tensor spec to save.
Source code in mlpotion/integrations/zenml/tensorflow/materializers.py (lines 119-139)
def save(self, data: tf.TensorSpec) -> None:
    """Save a TensorFlow TensorSpec to the artifact store.

    Args:
        data: The tensor spec to save.
    """
    logger.info("Saving TensorFlow TensorSpec...")
    try:
        Path(self.uri).mkdir(parents=True, exist_ok=True)
        spec_path = Path(self.uri) / "spec.json"
        # Convert the TensorSpec to a JSON-serializable dict; store the dtype
        # by name (e.g. "float32") so it round-trips through tf.TensorSpec.
        spec_dict = {
            "shape": data.shape.as_list(),
            "dtype": data.dtype.name,
            "name": data.name,
        }
        with open(spec_path, "w") as f:
            json.dump(spec_dict, f, indent=2)
        logger.info("✅ Successfully saved TensorFlow TensorSpec")
    except Exception as e:
        logger.error(f"Failed to save TensorSpec: {e}")
        raise
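
The resulting spec.json is a small JSON document; loaded back with json.load it looks like the following (values illustrative):

spec_dict = {
    "shape": [None, 16],  # None encodes an unknown (e.g. batch) dimension
    "dtype": "float32",
    "name": None,
}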

mlpotion.integrations.zenml.pytorch.materializers


See the ZenML Integration Guide for usage examples.