
PyTorch API Reference 📖

Complete API reference for MLPotion's PyTorch components.

Auto-Generated Documentation

This page is automatically populated with API documentation from the source code.

Extensibility

These components are built using protocol-based design, making MLPotion easy to extend. Want to add new data sources, training methods, or integrations? See the Contributing Guide.

Data Loading

mlpotion.frameworks.pytorch.data.datasets

Classes

CSVDataset dataclass

Bases: Dataset[tuple[torch.Tensor, torch.Tensor] | torch.Tensor]

PyTorch Dataset for CSV files with on-demand tensor conversion.

This class loads CSV data into memory (using Pandas) and provides a map-style PyTorch Dataset. It supports filtering columns, separating labels, and efficient on-demand tensor conversion to minimize memory usage.

Attributes:

Name Type Description
file_pattern str

Glob pattern matching the CSV files to load.

column_names list[str] | None

Specific columns to load. If None, all columns are loaded.

label_name str | None

Name of the column to use as the label. If None, no labels are returned.

dtype torch.dtype

The data type for the features (default: torch.float32).

Example
from mlpotion.frameworks.pytorch import CSVDataset
from torch.utils.data import DataLoader

# Create dataset
dataset = CSVDataset(
    file_pattern="data/train_*.csv",
    label_name="target_class",
    column_names=["feature1", "feature2", "target_class"]
)

# Create DataLoader
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# Iterate
for features, labels in dataloader:
    print(features.shape, labels.shape)
Functions
__getitem__
__getitem__(
    idx: int,
) -> tuple[torch.Tensor, torch.Tensor] | torch.Tensor

Get item at index.

Parameters:

Name Type Description Default
idx int

Global row index.

required

Returns:

Type Description
tuple[torch.Tensor, torch.Tensor] | torch.Tensor

(features, label) tuple if labels exist, else just features.

Source code in mlpotion/frameworks/pytorch/data/datasets.py
def __getitem__(
    self,
    idx: int,
) -> tuple[torch.Tensor, torch.Tensor] | torch.Tensor:
    """Get item at index.

    Args:
        idx: Global row index.

    Returns:
        (features, label) tuple if labels exist, else just features.
    """
    if self._features_df is None:
        raise IndexError("Dataset is empty or not properly initialized.")

    row = self._features_df.iloc[idx]

    # Convert to numpy array and then to tensor
    features_np: np.ndarray = row.to_numpy(dtype="float32", copy=False)
    features = torch.as_tensor(features_np, dtype=self._dtype)

    if self._labels is not None:
        label_val = self._labels[idx]
        label = torch.as_tensor(label_val, dtype=self._dtype)
        return features, label

    return features
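For quick inspection outside a DataLoader, items can also be indexed directly. A minimal sketch, assuming the CSV files from the class example above exist; it illustrates the two documented return shapes:

dataset = CSVDataset(file_pattern="data/train_*.csv", label_name="target_class")

# With a label column configured, indexing returns a (features, label) tuple
features, label = dataset[0]
print(features.shape, label.shape)

# Without label_name, the same index returns only the feature tensor
unlabeled = CSVDataset(file_pattern="data/train_*.csv")
features_only = unlabeled[0]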
__len__
__len__() -> int

Return dataset length.

Source code in mlpotion/frameworks/pytorch/data/datasets.py
def __len__(self) -> int:
    """Return dataset length."""
    if self._features_df is None:
        return 0
    return len(self._features_df)
__post_init__
__post_init__() -> None

Eagerly load CSV files into a DataFrame and validate configuration.

Source code in mlpotion/frameworks/pytorch/data/datasets.py
def __post_init__(self) -> None:
    """Eagerly load CSV files into a DataFrame and validate configuration."""
    try:
        files = self._resolve_files()
        df = self._load_dataframe(files)
        df = self._select_columns(df)
        self._split_features_labels(df)

        logger.info(
            "Initialized CSVDataset with "
            "n_rows={rows}, n_features={features}, labels={labels}",
            rows=len(self._features_df) if self._features_df is not None else 0,
            features=len(self._feature_cols),
            labels="yes" if self._labels is not None else "no",
        )
    except DataLoadingError:
        raise
    except Exception as exc:  # noqa: BLE001
        raise DataLoadingError(f"Failed to load CSV dataset: {exc!s}") from exc

StreamingCSVDataset dataclass

Bases: IterableDataset[tuple[torch.Tensor, torch.Tensor] | torch.Tensor]

Streaming PyTorch IterableDataset for large CSV files.

This dataset is designed for datasets that are too large to fit in memory. It reads CSV files in chunks (using Pandas) and streams samples one by one. It is compatible with PyTorch's IterableDataset interface.

Attributes:

Name Type Description
file_pattern str

Glob pattern matching the CSV files to load.

column_names list[str] | None

Specific columns to load.

label_name str | None

Name of the label column.

chunksize int

Number of rows to read into memory at a time per file.

dtype torch.dtype

The data type for the features.

Example
from mlpotion.frameworks.pytorch import StreamingCSVDataset
from torch.utils.data import DataLoader

# Create streaming dataset
dataset = StreamingCSVDataset(
    file_pattern="data/large_dataset_*.csv",
    label_name="target",
    chunksize=10000
)

# Create DataLoader (shuffle must be False for IterableDataset)
dataloader = DataLoader(dataset, batch_size=64)

for features, labels in dataloader:
    # Train model...
    pass
Functions
__iter__
__iter__() -> (
    Iterator[
        tuple[torch.Tensor, torch.Tensor] | torch.Tensor
    ]
)

Yield samples one by one across all CSV files.

Source code in mlpotion/frameworks/pytorch/data/datasets.py
def __iter__(
    self,
) -> Iterator[tuple[torch.Tensor, torch.Tensor] | torch.Tensor]:
    """Yield samples one by one across all CSV files."""
    for file_path in self.files:
        logger.info(f"Streaming CSV file: {file_path}")
        try:
            # Use pandas chunked reading
            chunk_iter = pd.read_csv(
                file_path,
                usecols=self.column_names,
                chunksize=self.chunksize,
            )
        except TypeError:
            # If usecols=None is not accepted by some pandas version
            chunk_iter = pd.read_csv(
                file_path,
                chunksize=self.chunksize,
            )

        for chunk_df in chunk_iter:
            # Validate label column if needed
            if self.label_name:
                if self.label_name not in chunk_df.columns:
                    raise DataLoadingError(
                        f"Label column '{self.label_name}' not found in "
                        f"file {file_path} (columns: {list(chunk_df.columns)})"
                    )
                labels_np = chunk_df[self.label_name].to_numpy()
                features_df = chunk_df.drop(columns=[self.label_name])
            else:
                labels_np = None
                features_df = chunk_df

            # Convert whole chunk to numpy once
            features_np = features_df.to_numpy(dtype="float32", copy=False)

            if labels_np is not None:
                for row_idx in range(features_np.shape[0]):
                    x = torch.as_tensor(
                        features_np[row_idx],
                        dtype=self.dtype,
                    )
                    y = torch.as_tensor(labels_np[row_idx], dtype=self.dtype)
                    yield x, y
            else:
                for row_idx in range(features_np.shape[0]):
                    x = torch.as_tensor(
                        features_np[row_idx],
                        dtype=self.dtype,
                    )
                    yield x
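Because __iter__ streams samples lazily, the dataset can also be consumed directly without a DataLoader. A short sketch, assuming the files from the class example above:

dataset = StreamingCSVDataset(
    file_pattern="data/large_dataset_*.csv",
    label_name="target",
    chunksize=10000,
)

# Samples are yielded one at a time, file by file, chunk by chunk
for features, label in dataset:
    print(features.shape, label.item())
    break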
__post_init__
__post_init__() -> None

Resolve files eagerly and log basic configuration.

Source code in mlpotion/frameworks/pytorch/data/datasets.py
def __post_init__(self) -> None:
    """Resolve files eagerly and log basic configuration."""
    self.files = self._resolve_files()
    logger.info(
        "Initialized StreamingCSVDataset with {n_files} file(s), "
        "chunksize={chunksize}, label_name={label}",
        n_files=len(self.files),
        chunksize=self.chunksize,
        label=self.label_name,
    )

mlpotion.frameworks.pytorch.data.loaders

Classes

CSVDataLoader dataclass

Bases: Generic[T_co]

Factory for creating configured PyTorch DataLoaders.

This class simplifies the creation of torch.utils.data.DataLoader instances by encapsulating common configuration options and handling differences between map-style and iterable datasets (e.g., automatically disabling shuffling for iterables).

Attributes:

Name Type Description
batch_size int

Number of samples per batch.

shuffle bool

Whether to shuffle the data (ignored for IterableDatasets).

num_workers int

Number of subprocesses to use for data loading.

pin_memory bool

Whether to copy tensors into CUDA pinned memory.

drop_last bool

Whether to drop the last incomplete batch.

persistent_workers bool | None

Whether to keep workers alive between epochs.

prefetch_factor int | None

Number of batches loaded in advance by each worker.

Example
from mlpotion.frameworks.pytorch import CSVDataLoader, CSVDataset

# 1. Create a dataset
dataset = CSVDataset("data.csv", label_name="target")

# 2. Configure the loader factory
loader_factory = CSVDataLoader(
    batch_size=64,
    shuffle=True,
    num_workers=4,
    pin_memory=True
)

# 3. Create the actual DataLoader
train_loader = loader_factory.load(dataset)

# 4. Use it
for X, y in train_loader:
    ...
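The same factory works for streaming datasets. A hedged sketch, reusing StreamingCSVDataset from above, showing that a requested shuffle is disabled (with a warning) when the dataset is an IterableDataset:

from mlpotion.frameworks.pytorch import CSVDataLoader, StreamingCSVDataset

stream = StreamingCSVDataset("data/large_*.csv", label_name="target")

# shuffle=True is requested, but load() ignores it for iterable datasets
loader_factory = CSVDataLoader(batch_size=64, shuffle=True, num_workers=0)
stream_loader = loader_factory.load(stream)

for X, y in stream_loader:
    ...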
Functions
load
load(
    dataset: Dataset[T_co] | IterableDataset[T_co],
) -> DataLoader[T_co]

Load a configured DataLoader from a dataset.

This method is aware of IterableDataset vs map-style Dataset and will:

  • Disable shuffling for iterable datasets (with a warning if shuffle=True was requested).
  • Apply worker-related options only when valid.

Parameters:

Name Type Description Default
dataset Dataset[T_co] | IterableDataset[T_co]

PyTorch Dataset or IterableDataset.

required

Returns:

Type Description
DataLoader[T_co]

Configured torch.utils.data.DataLoader instance.

Source code in mlpotion/frameworks/pytorch/data/loaders.py
@trycatch(
    error=DataLoadingError,
    success_msg="✅ Successfully Loading data",
)
def load(
    self,
    dataset: Dataset[T_co] | IterableDataset[T_co],
) -> DataLoader[T_co]:
    """Load a configured :class:`DataLoader` from a dataset.

    This method is aware of :class:`IterableDataset` vs map-style
    :class:`Dataset` and will:

    - Disable shuffling for iterable datasets (with a warning if
      ``shuffle=True`` was requested).
    - Apply worker-related options only when valid.

    Args:
        dataset: PyTorch :class:`Dataset` or :class:`IterableDataset`.

    Returns:
        Configured :class:`torch.utils.data.DataLoader` instance.
    """
    is_iterable = isinstance(dataset, IterableDataset)
    effective_shuffle = self._resolve_shuffle(is_iterable=is_iterable)

    loader_kwargs = self._build_loader_kwargs(
        dataset=dataset,
        shuffle=effective_shuffle,
        is_iterable=is_iterable,
    )

    logger.info(
        "Creating DataLoader with config: "
        "batch_size={batch_size}, shuffle={shuffle}, "
        "num_workers={num_workers}, pin_memory={pin_memory}, "
        "drop_last={drop_last}, persistent_workers={persistent_workers}, "
        "prefetch_factor={prefetch_factor}, dataset_type={dtype}",
        batch_size=self.batch_size,
        shuffle=effective_shuffle,
        num_workers=self.num_workers,
        pin_memory=self.pin_memory,
        drop_last=self.drop_last,
        persistent_workers=self.persistent_workers,
        prefetch_factor=self.prefetch_factor,
        dtype="IterableDataset" if is_iterable else "Dataset",
    )

    return DataLoader(**loader_kwargs)

CSVDataset dataclass

Bases: Dataset[tuple[torch.Tensor, torch.Tensor] | torch.Tensor]

PyTorch Dataset for CSV files with on-demand tensor conversion.

This class loads CSV data into memory (using Pandas) and provides a map-style PyTorch Dataset. It supports filtering columns, separating labels, and efficient on-demand tensor conversion to minimize memory usage.

Attributes:

Name Type Description
file_pattern str

Glob pattern matching the CSV files to load.

column_names list[str] | None

Specific columns to load. If None, all columns are loaded.

label_name str | None

Name of the column to use as the label. If None, no labels are returned.

dtype torch.dtype

The data type for the features (default: torch.float32).

Example
from mlpotion.frameworks.pytorch import CSVDataset
from torch.utils.data import DataLoader

# Create dataset
dataset = CSVDataset(
    file_pattern="data/train_*.csv",
    label_name="target_class",
    column_names=["feature1", "feature2", "target_class"]
)

# Create DataLoader
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# Iterate
for features, labels in dataloader:
    print(features.shape, labels.shape)
Functions
__getitem__
__getitem__(
    idx: int,
) -> tuple[torch.Tensor, torch.Tensor] | torch.Tensor

Get item at index.

Parameters:

Name Type Description Default
idx int

Global row index.

required

Returns:

Type Description
tuple[torch.Tensor, torch.Tensor] | torch.Tensor

(features, label) tuple if labels exist, else just features.

Source code in mlpotion/frameworks/pytorch/data/loaders.py
def __getitem__(
    self,
    idx: int,
) -> tuple[torch.Tensor, torch.Tensor] | torch.Tensor:
    """Get item at index.

    Args:
        idx: Global row index.

    Returns:
        (features, label) tuple if labels exist, else just features.
    """
    if self._features_df is None:
        raise IndexError("Dataset is empty or not properly initialized.")

    row = self._features_df.iloc[idx]

    # Convert to numpy array and then to tensor
    features_np: np.ndarray = row.to_numpy(dtype="float32", copy=False)
    features = torch.as_tensor(features_np, dtype=self._dtype)

    if self._labels is not None:
        label_val = self._labels[idx]
        label = torch.as_tensor(label_val, dtype=self._dtype)
        return features, label

    return features
__len__
__len__() -> int

Return dataset length.

Source code in mlpotion/frameworks/pytorch/data/loaders.py
def __len__(self) -> int:
    """Return dataset length."""
    if self._features_df is None:
        return 0
    return len(self._features_df)
__post_init__
__post_init__() -> None

Eagerly load CSV files into a DataFrame and validate configuration.

Source code in mlpotion/frameworks/pytorch/data/loaders.py
def __post_init__(self) -> None:
    """Eagerly load CSV files into a DataFrame and validate configuration."""
    try:
        files = self._resolve_files()
        df = self._load_dataframe(files)
        df = self._select_columns(df)
        self._split_features_labels(df)

        logger.info(
            "Initialized CSVDataset with "
            "n_rows={rows}, n_features={features}, labels={labels}",
            rows=len(self._features_df) if self._features_df is not None else 0,
            features=len(self._feature_cols),
            labels="yes" if self._labels is not None else "no",
        )
    except DataLoadingError:
        raise
    except Exception as exc:  # noqa: BLE001
        raise DataLoadingError(f"Failed to load CSV dataset: {exc!s}") from exc

StreamingCSVDataset dataclass

Bases: IterableDataset[tuple[torch.Tensor, torch.Tensor] | torch.Tensor]

Streaming PyTorch IterableDataset for large CSV files.

This dataset is designed for datasets that are too large to fit in memory. It reads CSV files in chunks (using Pandas) and streams samples one by one. It is compatible with PyTorch's IterableDataset interface.

Attributes:

Name Type Description
file_pattern str

Glob pattern matching the CSV files to load.

column_names list[str] | None

Specific columns to load.

label_name str | None

Name of the label column.

chunksize int

Number of rows to read into memory at a time per file.

dtype torch.dtype

The data type for the features.

Example
from mlpotion.frameworks.pytorch import StreamingCSVDataset
from torch.utils.data import DataLoader

# Create streaming dataset
dataset = StreamingCSVDataset(
    file_pattern="data/large_dataset_*.csv",
    label_name="target",
    chunksize=10000
)

# Create DataLoader (shuffle must be False for IterableDataset)
dataloader = DataLoader(dataset, batch_size=64)

for features, labels in dataloader:
    # Train model...
    pass
Functions
__iter__
__iter__() -> (
    Iterator[
        tuple[torch.Tensor, torch.Tensor] | torch.Tensor
    ]
)

Yield samples one by one across all CSV files.

Source code in mlpotion/frameworks/pytorch/data/loaders.py
def __iter__(
    self,
) -> Iterator[tuple[torch.Tensor, torch.Tensor] | torch.Tensor]:
    """Yield samples one by one across all CSV files."""
    for file_path in self.files:
        logger.info("Streaming CSV file: {path}", path=file_path)
        try:
            chunk_iter = pd.read_csv(
                file_path,
                usecols=self.column_names,
                chunksize=self.chunksize,
            )
        except TypeError:
            # If usecols=None is not accepted by some pandas version
            chunk_iter = pd.read_csv(
                file_path,
                chunksize=self.chunksize,
            )

        for chunk_df in chunk_iter:
            if self.label_name:
                if self.label_name not in chunk_df.columns:
                    raise DataLoadingError(
                        f"Label column '{self.label_name}' not found in "
                        f"file {file_path} (columns: {list(chunk_df.columns)})"
                    )
                labels_np = chunk_df[self.label_name].to_numpy()
                features_df = chunk_df.drop(columns=[self.label_name])
            else:
                labels_np = None
                features_df = chunk_df

            features_np = features_df.to_numpy(dtype="float32", copy=False)

            if labels_np is not None:
                for row_idx in range(features_np.shape[0]):
                    x = torch.as_tensor(
                        features_np[row_idx],
                        dtype=self.dtype,
                    )
                    y = torch.as_tensor(labels_np[row_idx], dtype=self.dtype)
                    yield x, y
            else:
                for row_idx in range(features_np.shape[0]):
                    x = torch.as_tensor(
                        features_np[row_idx],
                        dtype=self.dtype,
                    )
                    yield x
__post_init__
__post_init__() -> None

Resolve files eagerly and log basic configuration.

Source code in mlpotion/frameworks/pytorch/data/loaders.py
def __post_init__(self) -> None:
    """Resolve files eagerly and log basic configuration."""
    self.files = self._resolve_files()
    logger.info(
        "Initialized StreamingCSVDataset with {n_files} file(s), "
        "chunksize={chunksize}, label_name={label}",
        n_files=len(self.files),
        chunksize=self.chunksize,
        label=self.label_name,
    )

Training

mlpotion.frameworks.pytorch.training.trainers

PyTorch model training.

Classes

ModelTrainer

Bases: ModelTrainerProtocol[nn.Module, DataLoader]

Generic trainer for PyTorch models.

This class implements the ModelTrainerProtocol for PyTorch models. It handles the training loop, device placement, loss calculation, backpropagation, and validation.

It supports:

  • Supervised learning (batch is (inputs, targets)).
  • Unsupervised/self-supervised learning (batch is inputs only, loss is fn(outputs, inputs)).
  • Custom loss functions (string alias, nn.Module, or callable).
  • Automatic device management (CPU/GPU).

Attributes:

Name Type Description
model nn.Module

The PyTorch model to train.

dataloader DataLoader

The training data loader.

config ModelTrainingConfig

Configuration for training (epochs, optimizer, etc.).

Example
import torch
import torch.nn as nn
from mlpotion.frameworks.pytorch import ModelTrainer
from mlpotion.frameworks.pytorch.config import ModelTrainingConfig

# Define model
model = nn.Linear(10, 1)

# Define config
config = ModelTrainingConfig(
    epochs=5,
    learning_rate=0.01,
    optimizer="adam",
    loss_fn="mse",
    device="cpu"
)

# Initialize trainer
trainer = ModelTrainer()

# Train
result = trainer.train(model, train_loader, config, val_loader)
print(result.metrics)
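When a batch contains inputs only (for example, an autoencoder), the trainer computes the loss against the inputs themselves. A sketch under that assumption; feature_loader is a hypothetical DataLoader that yields feature tensors without targets:

import torch.nn as nn
from mlpotion.frameworks.pytorch import ModelTrainer
from mlpotion.frameworks.pytorch.config import ModelTrainingConfig

# Simple autoencoder: reconstruct the 10-dimensional input
autoencoder = nn.Sequential(nn.Linear(10, 4), nn.ReLU(), nn.Linear(4, 10))

config = ModelTrainingConfig(
    epochs=3,
    learning_rate=0.001,
    optimizer="adam",
    loss_fn="mse",
    device="cpu",
)

# feature_loader (hypothetical) yields feature tensors only, so the trainer
# evaluates loss_fn(outputs, inputs) as a reconstruction loss.
trainer = ModelTrainer()
result = trainer.train(autoencoder, feature_loader, config)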
Functions
train
train(
    model: nn.Module,
    dataloader: DataLoader[Any],
    config: ModelTrainingConfig,
    validation_dataloader: DataLoader[Any] | None = None,
) -> TrainingResult[nn.Module]

Train a PyTorch model.

Parameters:

Name Type Description Default
model nn.Module

The PyTorch model (nn.Module) to train.

required
dataloader DataLoader[Any]

The DataLoader providing training data.

required
config ModelTrainingConfig

A ModelTrainingConfig object containing training parameters.

required
validation_dataloader DataLoader[Any] | None

Optional DataLoader for validation.

None

Returns:

Type Description
TrainingResult[nn.Module]

A dataclass containing the trained model, training history (loss/metrics per epoch), and final metrics.

Raises:

Type Description
TrainingError

If the training loop encounters an error (e.g., NaN loss, device mismatch, empty dataloader).

Source code in mlpotion/frameworks/pytorch/training/trainers.py
@trycatch(
    error=ModelTrainerError,
    success_msg="✅ Successfully trained PyTorch model",
)
def train(
    self,
    model: nn.Module,
    dataloader: DataLoader[Any],
    config: ModelTrainingConfig,
    validation_dataloader: DataLoader[Any] | None = None,
) -> TrainingResult[nn.Module]:
    """Train a PyTorch model.

    Args:
        model: The PyTorch model (`nn.Module`) to train.
        dataloader: The `DataLoader` providing training data.
        config: A `ModelTrainingConfig` object containing training parameters.
        validation_dataloader: Optional `DataLoader` for validation.

    Returns:
        TrainingResult[nn.Module]: A dataclass containing the trained model,
        training history (loss/metrics per epoch), and final metrics.

    Raises:
        TrainingError: If the training loop encounters an error (e.g., NaN loss,
        device mismatch, empty dataloader).
    """
    try:
        logger.info("Starting PyTorch model training...")
        logger.info(
            "Config: epochs={epochs}, lr={lr}, optimizer={opt}, "
            "loss_fn={loss_fn}, device={device}",
            epochs=config.epochs,
            lr=config.learning_rate,
            opt=config.optimizer,
            loss_fn=config.loss_fn,
            device=config.device,
        )

        # Setup device
        device = torch.device(config.device)
        model = model.to(device)

        # Setup optimizer and loss
        optimizer = self._create_optimizer(model, config)
        criterion = self._create_loss_fn(config)

        # Optional limit on batches per epoch
        max_batches_per_epoch = getattr(config, "max_batches_per_epoch", None)
        if max_batches_per_epoch is None:
            max_batches_per_epoch = getattr(config, "max_batches", None)

        history: dict[str, list[float]] = {"loss": []}
        if validation_dataloader is not None:
            history["val_loss"] = []

        # Initialize callbacks and TensorBoard
        callbacks = self._prepare_callbacks(config)
        tensorboard_writer = self._setup_tensorboard(config)

        start_time = time.time()

        # Call on_train_begin callbacks
        for callback in callbacks:
            if hasattr(callback, "on_train_begin"):
                callback.on_train_begin()

        for epoch in range(config.epochs):
            model.train()
            epoch_loss = 0.0
            num_batches = 0

            for batch in dataloader:
                inputs, targets = self._prepare_batch(batch, device=device)

                optimizer.zero_grad()
                outputs = model(inputs)

                # Supervised vs unsupervised / autoencoder
                if targets is not None:
                    loss = criterion(outputs, targets)
                else:
                    loss = criterion(outputs, inputs)

                loss.backward()
                optimizer.step()

                epoch_loss += float(loss.item())
                num_batches += 1

                if (
                    max_batches_per_epoch is not None
                    and num_batches >= max_batches_per_epoch
                ):
                    logger.info(
                        "Reached max_batches_per_epoch={mb}; "
                        "stopping epoch {epoch} early.",
                        mb=max_batches_per_epoch,
                        epoch=epoch + 1,
                    )
                    break

            if num_batches == 0:
                raise TrainingError("Training dataloader yielded no batches.")

            avg_loss = epoch_loss / num_batches
            history["loss"].append(avg_loss)

            # Validation phase
            if validation_dataloader is not None:
                val_loss = self._validate(
                    model=model,
                    dataloader=validation_dataloader,
                    criterion=criterion,
                    device=device,
                )
                history["val_loss"].append(val_loss)
            else:
                val_loss = None

            # Logging
            if getattr(config, "verbose", True):
                msg = f"Epoch {epoch + 1}/{config.epochs} - loss: {avg_loss:.4f}"
                if val_loss is not None:
                    msg += f" - val_loss: {val_loss:.4f}"
                logger.info(msg)

            # TensorBoard logging
            if tensorboard_writer is not None:
                tensorboard_writer.add_scalar("loss", avg_loss, epoch)
                if val_loss is not None:
                    tensorboard_writer.add_scalar("val_loss", val_loss, epoch)

            # Call on_epoch_end callbacks
            for callback in callbacks:
                if hasattr(callback, "on_epoch_end"):
                    callback.on_epoch_end(
                        epoch, {"loss": avg_loss, "val_loss": val_loss}
                    )

        training_time = time.time() - start_time

        # Final metrics
        metrics: dict[str, float] = {"loss": float(history["loss"][-1])}
        if "val_loss" in history and history["val_loss"]:
            metrics["val_loss"] = float(history["val_loss"][-1])

        best_epoch = self._find_best_epoch(history)

        # Call on_train_end callbacks
        for callback in callbacks:
            if hasattr(callback, "on_train_end"):
                callback.on_train_end()

        # Close TensorBoard writer
        if tensorboard_writer is not None:
            tensorboard_writer.close()

        logger.info("Training completed in {t:.2f}s", t=training_time)
        logger.info("Final metrics: {metrics}", metrics=metrics)

        return TrainingResult(
            model=model,
            history=history,
            metrics=metrics,
            config=config,
            training_time=training_time,
            best_epoch=best_epoch,
        )

    except TrainingError:
        raise
    except Exception as exc:  # noqa: BLE001
        raise TrainingError(f"Training failed: {exc!s}") from exc

Evaluation

mlpotion.frameworks.pytorch.evaluation.evaluators

PyTorch model evaluation.

Classes

ModelEvaluator

Bases: ModelEvaluatorProtocol[nn.Module, DataLoader]

Generic evaluator for PyTorch models.

This class implements the ModelEvaluatorProtocol for PyTorch models. It performs a full pass over the evaluation dataset, computing the average loss.

It supports:

  • Supervised and unsupervised evaluation.
  • Custom loss functions.
  • Automatic device management.

Example
from mlpotion.frameworks.pytorch import ModelEvaluator
from mlpotion.frameworks.pytorch.config import ModelEvaluationConfig

evaluator = ModelEvaluator()
config = ModelEvaluationConfig(loss_fn="cross_entropy", device="cuda")

result = evaluator.evaluate(model, test_loader, config)
print(f"Test Loss: {result.metrics['loss']}")
Functions
evaluate
evaluate(
    model: nn.Module,
    dataloader: DataLoader[Any],
    config: ModelEvaluationConfig,
) -> EvaluationResult

Evaluate a PyTorch model.

Parameters:

Name Type Description Default
model nn.Module

The PyTorch model to evaluate.

required
dataloader DataLoader[Any]

The DataLoader providing evaluation data.

required
config ModelEvaluationConfig

A ModelEvaluationConfig object containing evaluation parameters.

required

Returns:

Name Type Description
EvaluationResult EvaluationResult

A dataclass containing the computed metrics (e.g., average loss) and execution time.

Raises:

Type Description
EvaluationError

If evaluation fails.

Source code in mlpotion/frameworks/pytorch/evaluation/evaluators.py
@trycatch(
    error=ModelEvaluatorError,
    success_msg="✅ Successfully evaluated PyTorch model",
)
def evaluate(
    self,
    model: nn.Module,
    dataloader: DataLoader[Any],
    config: ModelEvaluationConfig,
) -> EvaluationResult:
    """Evaluate a PyTorch model.

    Args:
        model: The PyTorch model to evaluate.
        dataloader: The `DataLoader` providing evaluation data.
        config: A `ModelEvaluationConfig` object containing evaluation parameters.

    Returns:
        EvaluationResult: A dataclass containing the computed metrics (e.g., average loss)
        and execution time.

    Raises:
        EvaluationError: If evaluation fails.
    """
    try:
        device_str = getattr(config, "device", "cpu")
        logger.info("Starting PyTorch model evaluation...")
        logger.info(
            f"Config: device={device_str}, loss_fn={getattr(config, 'loss_fn', 'mse')}"
        )

        device = torch.device(device_str)
        model = model.to(device)
        model.eval()

        criterion = self._create_loss_fn(config)

        # Support optional max_batches on the config
        max_batches = getattr(config, "max_batches", None)

        total_loss = 0.0
        num_batches = 0
        start_time = time.time()

        with torch.no_grad():
            for batch in dataloader:
                inputs, targets = self._prepare_batch(batch, device=device)

                outputs = model(inputs)

                if targets is not None:
                    loss = criterion(outputs, targets)
                else:
                    loss = criterion(outputs, inputs)

                total_loss += float(loss.item())
                num_batches += 1

                if max_batches is not None and num_batches >= max_batches:
                    logger.info(
                        f"Reached max_batches={max_batches}; "
                        "stopping evaluation early."
                    )
                    break

        if num_batches == 0:
            raise EvaluationError("Evaluation dataloader yielded no batches.")

        avg_loss = total_loss / num_batches
        evaluation_time = time.time() - start_time

        metrics = {"loss": float(avg_loss)}

        logger.info(f"Evaluation completed in {evaluation_time:.2f}s")
        logger.info(f"Metrics: {metrics}")

        return EvaluationResult(
            metrics=metrics,
            config=config,
            evaluation_time=evaluation_time,
        )

    except EvaluationError:
        raise
    except Exception as exc:  # noqa: BLE001
        raise EvaluationError(f"Evaluation failed: {exc!s}") from exc

Persistence

mlpotion.frameworks.pytorch.deployment.persistence

PyTorch model persistence.

Classes

ModelPersistence dataclass

Bases: ModelPersistenceProtocol[nn.Module]

Persistence helper for PyTorch models.

This class manages saving and loading of PyTorch models. It supports two modes:

  1. State dict (recommended): saves only the model parameters (model.state_dict()). Requires the model class to be available when loading.
  2. Full model: saves the entire model object using pickle. Less portable but easier to load.

Attributes:

Name Type Description
path str | Path

The file path for the model artifact.

model nn.Module | None

The PyTorch model instance.

Example

Saving and Loading State Dict (Recommended):

from mlpotion.frameworks.pytorch import ModelPersistence
import torch.nn as nn

# Define model
class MyModel(nn.Module):
    def __init__(self): super().__init__(); self.l = nn.Linear(1, 1)

model = MyModel()

# Save
saver = ModelPersistence(path="model.pth", model=model)
saver.save(save_full_model=False)

# Load
loader = ModelPersistence(path="model.pth")
# We must provide the model class or an instance for state_dict loading
loaded_model, _ = loader.load(model_class=MyModel)

Example

Saving and Loading Full Model:

# Save
saver.save(save_full_model=True)

# Load (no model class needed)
loader = ModelPersistence(path="model.pth")
loaded_model, _ = loader.load()
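
If a model instance is already attached, load() reuses it for state_dict loading instead of requiring model_class. A short sketch, reusing MyModel from the first example:

# Attach an instance up front; load() then fills its parameters in place
persistence = ModelPersistence(path="model.pth", model=MyModel())
loaded_model, _ = persistence.load(map_location="cpu")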

Attributes
path_obj property writable
path_obj: Path

Return the model path as a Path.

Functions
load
load(
    *,
    model_class: type[nn.Module] | None = None,
    map_location: str | torch.device | None = "cpu",
    strict: bool = True,
    model_kwargs: dict[str, Any] | None = None,
    **torch_load_kwargs: Any
) -> tuple[nn.Module, dict[str, Any] | None]

Load a PyTorch model from disk.

This method automatically detects if the file is a full model checkpoint or a state dict.

Parameters:

Name Type Description Default
model_class type[nn.Module] | None

The model class to instantiate if loading a state dict and no model instance is currently attached.

None
map_location str | torch.device | None

Device to load the model onto (default: "cpu").

'cpu'
strict bool

Whether to strictly enforce state dict keys match the model.

True
model_kwargs dict[str, Any] | None

Arguments to pass to model_class constructor.

None
**torch_load_kwargs Any

Additional arguments passed to torch.load().

{}

Returns:

Type Description
tuple[nn.Module, dict[str, Any] | None]

The loaded PyTorch model, paired with optional checkpoint metadata (None in the current implementation).

Raises:

Type Description
ModelPersistenceError

If loading fails, or if model_class is missing when required.

Source code in mlpotion/frameworks/pytorch/deployment/persistence.py
@trycatch(
    error=ModelPersistenceError,
    success_msg="✅ Successfully loaded PyTorch model",
)
def load(
    self,
    *,
    model_class: type[nn.Module] | None = None,
    map_location: str | torch.device | None = "cpu",
    strict: bool = True,
    model_kwargs: dict[str, Any] | None = None,
    **torch_load_kwargs: Any,
) -> tuple[nn.Module, dict[str, Any] | None]:
    """Load a PyTorch model from disk.

    This method automatically detects if the file is a full model checkpoint or a
    state dict.

    Args:
        model_class: The model class to instantiate if loading a state dict and no
            model instance is currently attached.
        map_location: Device to load the model onto (default: "cpu").
        strict: Whether to strictly enforce state dict keys match the model.
        model_kwargs: Arguments to pass to `model_class` constructor.
        **torch_load_kwargs: Additional arguments passed to `torch.load()`.

    Returns:
        nn.Module: The loaded PyTorch model.

    Raises:
        ModelPersistenceError: If loading fails, or if `model_class` is missing
        when required.
    """
    path = self._ensure_path_exists()

    logger.info("Loading PyTorch model from {path}", path=str(path))

    checkpoint = torch.load(path, map_location=map_location, **torch_load_kwargs)

    # Case 1: full model was saved
    if isinstance(checkpoint, nn.Module):
        logger.info("Detected full-model checkpoint (nn.Module).")
        self.model = checkpoint
        logger.info("PyTorch model loaded successfully from full-model checkpoint.")
        return checkpoint, None

    # Case 2: dict-like checkpoint (state_dict or wrapped)
    if isinstance(checkpoint, dict):
        logger.info("Detected dict-like checkpoint; treating as state_dict.")
        state_dict = self._extract_state_dict(checkpoint)

        # If we already have a model attached, reuse it; otherwise, we need model_class
        if self.model is not None:
            model = self.model
            logger.info(
                "Using attached model instance of type {cls} for state_dict loading.",
                cls=type(model).__name__,
            )
        else:
            if model_class is None:
                raise ModelPersistenceError(
                    "model_class is required when loading from a state_dict "
                    "checkpoint if no model is attached."
                )
            model = self._instantiate_model(model_class, model_kwargs)
            self.model = model

        missing, unexpected = model.load_state_dict(state_dict, strict=strict)

        if strict:
            logger.debug(
                "State_dict loaded with strict=True (no mismatch error raised)."
            )
        else:
            if missing:
                logger.warning(f"Missing keys in state_dict: {missing}")
            if unexpected:
                logger.warning(f"Unexpected keys in state_dict: {unexpected}")

        logger.info("PyTorch model loaded successfully from state_dict checkpoint.")
        return model, None

    # Case 3: unsupported checkpoint structure
    raise ModelPersistenceError(
        f"Unsupported checkpoint type: {type(checkpoint)!r}. "
        "Expected nn.Module or dict-like object."
    )
save
save(
    *,
    save_full_model: bool = False,
    **torch_save_kwargs: Any
) -> None

Save the attached PyTorch model to disk.

Parameters:

Name Type Description Default
save_full_model bool

If True, saves the entire model object (pickle). If False (default), saves only the state_dict.

False
**torch_save_kwargs Any

Additional arguments passed to torch.save().

{}

Raises:

Type Description
ModelPersistenceError

If no model is attached or saving fails.

Source code in mlpotion/frameworks/pytorch/deployment/persistence.py
@trycatch(
    error=ModelPersistenceError,
    success_msg="✅ Successfully saved PyTorch model",
)
def save(
    self,
    *,
    save_full_model: bool = False,
    **torch_save_kwargs: Any,
) -> None:
    """Save the attached PyTorch model to disk.

    Args:
        save_full_model: If True, saves the entire model object (pickle).
            If False (default), saves only the `state_dict`.
        **torch_save_kwargs: Additional arguments passed to `torch.save()`.

    Raises:
        ModelPersistenceError: If no model is attached or saving fails.
    """
    model = self._ensure_model()
    path = self.path_obj

    logger.info(
        "Saving PyTorch model to {path} ({mode})",
        path=str(path),
        mode="full model" if save_full_model else "state_dict",
    )

    path.parent.mkdir(parents=True, exist_ok=True)

    if save_full_model:
        logger.warning(
            "Saving a full model object. This is less portable and may break "
            "if the code structure changes. Prefer saving a state_dict for "
            "long-term storage."
        )
        torch.save(model, path, **torch_save_kwargs)
    else:
        torch.save(model.state_dict(), path, **torch_save_kwargs)

    logger.info("PyTorch model saved successfully.")

Export

mlpotion.frameworks.pytorch.deployment.exporters

Classes

ModelExporter dataclass

Bases: ModelExporterProtocol[nn.Module]

Export PyTorch models to TorchScript, ONNX, or state_dict formats.

This class implements the ModelExporterProtocol for PyTorch. It supports exporting models for deployment or interoperability.

Supported formats:

  • torchscript: exports via torch.jit.script or torch.jit.trace.
  • onnx: exports to ONNX format (requires example_input).
  • state_dict: saves the model parameters.

Example
from mlpotion.frameworks.pytorch import ModelExporter
from mlpotion.frameworks.pytorch.config import ModelExportConfig
import torch

# Prepare model and input
model = ...
example_input = torch.randn(1, 3, 224, 224)

# Export to ONNX
exporter = ModelExporter()
config = ModelExportConfig(
    export_path="models/model.onnx",
    format="onnx",
    example_input=example_input
)

result = exporter.export(model, config)
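
A TorchScript export follows the same pattern. A sketch assuming the same ModelExportConfig fields as the ONNX example (example_input is used when tracing; the torchscript target path shown here is illustrative):

config_ts = ModelExportConfig(
    export_path="models/model_ts.pt",
    format="torchscript",
    example_input=example_input,
)

result_ts = exporter.export(model, config_ts)
print(result_ts.export_path, result_ts.metadata)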
Functions
export
export(
    model: nn.Module, config: ModelExportConfig
) -> ExportResult

Export a PyTorch model to the specified format.

Parameters:

Name Type Description Default
model nn.Module

The PyTorch model to export.

required
config ModelExportConfig

Configuration object specifying format, path, and other options.

required

Returns:

Name Type Description
ExportResult ExportResult

A dataclass containing the path to the exported artifact and metadata.

Raises:

Type Description
ExportError

If the export process fails (e.g., invalid format, missing example input).

Source code in mlpotion/frameworks/pytorch/deployment/exporters.py
@trycatch(
    error=ModelExporterError,
    success_msg="✅ Successfully Exported model",
)
def export(
    self,
    model: nn.Module,
    config: ModelExportConfig,
) -> ExportResult:
    """Export a PyTorch model to the specified format.

    Args:
        model: The PyTorch model to export.
        config: Configuration object specifying format, path, and other options.

    Returns:
        ExportResult: A dataclass containing the path to the exported artifact and metadata.

    Raises:
        ExportError: If the export process fails (e.g., invalid format, missing example input).
    """
    try:
        export_root = Path(config.export_path)
        export_root.parent.mkdir(parents=True, exist_ok=True)

        fmt = config.format.lower()
        device_str = getattr(config, "device", "cpu")
        device = torch.device(device_str)

        logger.info(
            "Exporting PyTorch model "
            f"[format={fmt}, device={device_str}, target={export_root}]"
        )

        model = model.to(device)
        model.eval()

        # Dispatch
        if fmt == "torchscript":
            final_path = self._export_torchscript(
                model=model,
                export_root=export_root,
                config=config,
                device=device,
            )
        elif fmt == "onnx":
            final_path = self._export_onnx(
                model=model,
                export_root=export_root,
                config=config,
                device=device,
            )
        elif fmt == "state_dict":
            final_path = self._export_state_dict(
                model=model,
                export_root=export_root,
            )
        else:
            raise ExportError(f"Unknown export format: {config.format!r}")

        logger.success(f"Model successfully exported → {final_path}")

        metadata: dict[str, Any] = {
            "model_type": "pytorch",
            "format": fmt,
            "device": device_str,
        }

        return ExportResult(
            export_path=str(final_path),
            format=fmt,
            config=config,
            metadata=metadata,
        )

    except ExportError:
        raise
    except Exception as exc:  # noqa: BLE001
        raise ExportError(f"Export failed: {exc!s}") from exc

See the PyTorch Guide for usage examples.