Experiments

The experiments module provides dataset and model wrappers that form the foundation of a Krum training loop. It handles model construction, dataset loading, loss composition, optimization, checkpointing, and device/dtype configuration.

Components

  • Configuration (experiments.Configuration) – Device, dtype, and memory-transfer settings.

  • Model (experiments.Model) – Model wrapper with parameter flattening and gradient handling.

  • Dataset (experiments.Dataset) – Infinite-batch wrapper around torchvision and custom datasets.

  • Loss (experiments.Loss) – Derivable loss with regularization and composition support.

  • Criterion (experiments.Criterion) – Non-derivable evaluation metric (top-k, sigmoid).

  • Optimizer (experiments.Optimizer) – Optimizer wrapper with learning-rate control.

  • Checkpoint (experiments.Checkpoint) – Save/restore state dictionaries for models and optimizers.

Custom Models & Datasets

Custom models and datasets can be added under experiments/models/ and experiments/datasets/; they are discovered automatically at import time.

API Reference

Experiment components for model training, dataset loading, and evaluation. This module groups the building blocks of a Krum training loop listed under Components above.

Example:

from experiments import (
    Configuration, Model, Dataset,
    Loss, Criterion, Optimizer,
    make_datasets,
)

config = Configuration(device="cuda:0")
model = Model("resnet18", config, num_classes=10)
trainset, testset = make_datasets("cifar10", train_batch=128)
class experiments.Checkpoint[source]

Bases: object

Collection of state dictionaries with saving/loading helpers.

This class can snapshot any object implementing the state_dict / load_state_dict protocol (e.g. torch.nn.Module, torch.optim.Optimizer). It also knows how to unwrap Model and Optimizer wrappers automatically.

Example:

>>> ckpt = Checkpoint()
>>> ckpt.snapshot(model, deepcopy=True)
>>> ckpt.restore(model)
load(filepath, overwrite=False)[source]

Load checkpoint data from a file.

Parameters:
  • filepath (str or pathlib.Path) – Path to the saved checkpoint.

  • overwrite (bool, optional) – Whether to overwrite any existing snapshots.

Returns:

Checkpoint – Self, for chaining.

Raises:

tools.UserException – If the checkpoint is non-empty and overwrite is False.

restore(instance, nothrow=False)[source]

Restore an instance from its stored snapshot.

Parameters:
  • instance (object) – Instance to restore. Must support load_state_dict().

  • nothrow (bool, optional) – If True, silently skip when no snapshot is available.

Returns:

Checkpoint – Self, for chaining.

Raises:

tools.UserException – If no snapshot exists and nothrow is False.

save(filepath, overwrite=False)[source]

Save the current checkpoint to a file.

Parameters:
  • filepath (str or pathlib.Path) – Destination path.

  • overwrite (bool, optional) – Whether to overwrite an existing file.

Returns:

Checkpoint – Self, for chaining.

Raises:

tools.UserException – If the file exists and overwrite is False.

snapshot(instance, overwrite=False, deepcopy=False, nowarnref=False)[source]

Take (or overwrite) a snapshot of an instance’s state dictionary.

Parameters:
  • instance (object) – Instance to snapshot. Must support state_dict().

  • overwrite (bool, optional) – Whether to overwrite an existing snapshot for the same class.

  • deepcopy (bool, optional) – Whether to deep-copy the state dictionary instead of shallow-copying.

  • nowarnref (bool, optional) – Suppress the debug warning when restoring a reference is the intended behavior.

Returns:

Checkpoint – Self, for chaining.

Raises:

tools.UserException – If a snapshot already exists and overwrite is False.
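The deepcopy flag matters because a shallow snapshot keeps references into the live state dictionary: later training updates also mutate the stored copy. The sketch below illustrates this with a stand-alone toy class following the state_dict protocol (it is not the Checkpoint implementation, just the semantics its deepcopy option guards against):

```python
import copy

class Counter:
    """Toy object following the state_dict / load_state_dict protocol."""
    def __init__(self):
        self.state = {"steps": [0]}
    def state_dict(self):
        return self.state
    def load_state_dict(self, state):
        self.state = state

counter = Counter()

shallow = counter.state_dict()               # like snapshot(..., deepcopy=False)
deep = copy.deepcopy(counter.state_dict())   # like snapshot(..., deepcopy=True)

counter.state["steps"][0] = 100              # training continues...

print(shallow["steps"][0])  # 100: the shallow snapshot tracks the live state
print(deep["steps"][0])     # 0: the deep snapshot preserved the old value
```

This is also why snapshot() emits a debug warning for shallow snapshots unless nowarnref is set: restoring a reference is occasionally intended, but usually a bug.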

class experiments.Configuration(device=None, dtype=None, noblock=False, relink=False)[source]

Bases: Mapping

Immutable tensor configuration holder.

This class bundles device, dtype, and memory-transfer options into a single immutable mapping. It is used throughout experiments to ensure every created or moved tensor uses the same configuration.

Parameters:
  • device (str, torch.device, or None, optional) – Target device. None defaults to "cuda" when available, otherwise "cpu". Strings such as "cuda:0" are resolved automatically.

  • dtype (torch.dtype or None, optional) – Tensor datatype. None uses PyTorch’s current default dtype.

  • noblock (bool, optional) – Whether to use non-blocking host-to-device transfers.

  • relink (bool, optional) – Whether to relink instead of copying during parameter assignments.

Example:

>>> import torch
>>> from experiments import Configuration
>>> config = Configuration(device="cpu", dtype=torch.float32)
>>> config["device"]
device(type='cpu')

default_device = 'cpu'
class experiments.Criterion(name_build, *args, **kwargs)[source]

Bases: object

Non-derivable evaluation metric wrapper.

Available criteria:

  • "top-k" — top-k classification accuracy.

  • "sigmoid" — binary accuracy with sigmoid threshold at 0.5.

All criteria return a 1-D tensor [num_correct, batch_size].

Parameters:
  • name_build (str or callable) – Criterion name or constructor function.

  • *args (object) – Forwarded to the criterion constructor.

  • **kwargs (object) – Forwarded to the criterion constructor.

Raises:

tools.UnavailableException – If name_build is an unknown string.
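The [num_correct, batch_size] convention lets callers sum results across batches of different sizes before dividing, instead of averaging per-batch accuracies. A plain-Python sketch of a top-k criterion under that convention (the top_k_correct helper is illustrative, not the package implementation):

```python
def top_k_correct(scores, labels, k=1):
    """Return [num_correct, batch_size] for top-k classification accuracy.

    scores: one list of per-class scores per sample.
    labels: true class index per sample.
    """
    correct = 0
    for row, label in zip(scores, labels):
        # Indices of the k highest-scoring classes for this sample.
        topk = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]
        if label in topk:
            correct += 1
    return [correct, len(labels)]

scores = [[0.1, 0.7, 0.2], [0.5, 0.2, 0.3], [0.2, 0.2, 0.6]]
labels = [1, 2, 2]
print(top_k_correct(scores, labels, k=1))  # [2, 3]
print(top_k_correct(scores, labels, k=2))  # [3, 3]
```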

class experiments.Dataset(data, name=None, root=None, *args, **kwargs)[source]

Bases: object

Unified dataset wrapper producing infinite batches.

This class can wrap:

  • A torchvision dataset loaded by name.

  • A custom generator yielding batches forever.

  • A single fixed batch repeated forever.

Parameters:
  • data (str, generator, or object) – Dataset name, infinite generator, or single batch.

  • name (str or None, optional) – User-defined name for debugging.

  • root (str or pathlib.Path or None, optional) – Cache root directory. None uses the default.

  • *args (object) – Forwarded to the dataset constructor when data is a string.

  • **kwargs (object) – Forwarded to the dataset constructor when data is a string.

Raises:

tools.UnavailableException – If data is an unknown dataset name.

TypeError – If constructor arguments are invalid.
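All three input kinds end up behind the same interface: every call to sample() yields another batch, forever. A minimal stand-alone sketch of the generator and single-fixed-batch cases (the wrap_batches helper is hypothetical, shown only to illustrate the unification):

```python
import itertools

def wrap_batches(data):
    """Yield batches forever, from an infinite generator or a single fixed batch."""
    if hasattr(data, "__next__"):    # already an (infinite) iterator/generator
        return data
    return itertools.repeat(data)    # single fixed batch, repeated forever

# A single fixed batch, repeated forever.
fixed = wrap_batches(([1, 2, 3], [0, 1, 0]))
print(next(fixed))  # ([1, 2, 3], [0, 1, 0])
print(next(fixed))  # the same batch again

# A generator cycling over two batches forever.
cycling = wrap_batches(itertools.cycle([("a", 0), ("b", 1)]))
print([next(cycling) for _ in range(3)])  # [('a', 0), ('b', 1), ('a', 0)]
```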

epoch(config=None)[source]

Return a finite epoch iterator.

Note

Only works for DataLoader-based datasets.

Parameters:
  • config (experiments.Configuration or None, optional) – Target configuration for tensor placement.

Returns:

generator – Finite iterator over one epoch.

classmethod get_default_root()[source]

Lazily initialize and return the default dataset cache directory.

Returns:

pathlib.Path

Path to the dataset cache. Falls back to the system temp directory if the default does not exist.

sample(config=None)[source]

Sample the next batch.

Parameters:
  • config (experiments.Configuration or None, optional) – Target configuration for tensor placement.

Returns:

tuple – Next batch, optionally moved to the target device.

class experiments.Loss(name_build, *args, **kwargs)[source]

Bases: object

Derivable loss function wrapper with composition support.

Losses can be added (loss1 + loss2) and scaled (0.5 * loss). All standard PyTorch losses are available by lower-case name. Additionally, "l1" and "l2" provide parameter-norm regularization.

Parameters:
  • name_build (str or callable) – Loss name (e.g. "crossentropy", "mse") or a callable with signature (output, target, params) -> tensor.

  • *args (object) – Forwarded to the loss constructor when name_build is a string.

  • **kwargs (object) – Forwarded to the loss constructor when name_build is a string.

Raises:

tools.UnavailableException – If name_build is an unknown string.
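Composition makes it easy to mix a data loss with regularization, e.g. a cross-entropy term plus a scaled "l2" term. The combination rule itself is simple; here is a stand-alone sketch with plain callables (a toy ComposableLoss class, not the package's Loss):

```python
class ComposableLoss:
    """Wrap a loss callable so that + and * build new combined losses."""
    def __init__(self, fn):
        self.fn = fn
    def __call__(self, output, target):
        return self.fn(output, target)
    def __add__(self, other):
        # (loss1 + loss2)(o, t) == loss1(o, t) + loss2(o, t)
        return ComposableLoss(lambda o, t: self(o, t) + other(o, t))
    def __rmul__(self, scale):
        # (0.5 * loss)(o, t) == 0.5 * loss(o, t)
        return ComposableLoss(lambda o, t: scale * self(o, t))

mse = ComposableLoss(lambda o, t: sum((x - y) ** 2 for x, y in zip(o, t)))
mae = ComposableLoss(lambda o, t: sum(abs(x - y) for x, y in zip(o, t)))

combined = mse + 0.5 * mae
print(combined([1.0, 2.0], [0.0, 0.0]))  # (1 + 4) + 0.5 * (1 + 2) = 6.5
```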

class experiments.Model(name_build, config=None, init_multi=None, init_multi_args=None, init_mono=None, init_mono_args=None, *args, **kwargs)[source]

Bases: object

Unified model wrapper with parameter and gradient management.

Models are resolved by lower-case name from torchvision.models and from custom modules under experiments/models/. Parameters are automatically flattened into a contiguous vector accessible via get() and set().

Data parallelism (torch.nn.DataParallel) is enabled automatically when the model is placed on a non-indexed CUDA device.

Parameters:
  • name_build (str or callable) – Model name (e.g. "resnet18", "torchvision-resnet18") or a constructor function.

  • config (experiments.Configuration, optional) – Target device and dtype configuration.

  • init_multi (str or callable or None, optional) – Weight initializer for tensors of dimension >= 2.

  • init_multi_args (dict or None, optional) – Keyword arguments for init_multi when it is a string.

  • init_mono (str or callable or None, optional) – Weight initializer for tensors of dimension == 1.

  • init_mono_args (dict or None, optional) – Keyword arguments for init_mono when it is a string.

  • *args (object) – Forwarded to the model constructor.

  • **kwargs (object) – Forwarded to the model constructor.

Raises:

tools.UnavailableException – If name_build is an unknown string.

tools.UserException – If the built object is not a torch.nn.Module.

backprop(dataset=None, loss=None, outloss=False, **kwargs)[source]

Compute gradient on a batch from the given dataset.

Parameters:
  • dataset (experiments.Dataset or None, optional) – Dataset to sample from. None uses the default trainset.

  • loss (experiments.Loss or None, optional) – Loss function. None uses the default loss.

  • outloss (bool, optional) – Whether to also return the loss value.

  • **kwargs (object) – Forwarded to loss.backward().

Returns:

torch.Tensor or tuple[torch.Tensor, torch.Tensor] – Flat gradient, optionally paired with the loss value.

property config

Return the immutable configuration.

Returns:

experiments.Configuration

Model configuration.

default(name, new=None, erase=False)[source]

Get and/or set a named default.

Parameters:
  • name (str) – Default key (e.g. "trainset", "loss", "optimizer").

  • new (object or None, optional) – New value to set. Ignored unless new is not None or erase is True.

  • erase (bool, optional) – Force the value to None.

Returns:

object – Current (or old) value of the default.

Raises:

tools.UnavailableException – If name is not a known default.

eval(dataset=None, criterion=None)[source]

Evaluate the model on a batch from the given dataset.

Parameters:
  • dataset (experiments.Dataset or None, optional) – Dataset to sample from. None uses the default testset.

  • criterion (experiments.Criterion or None, optional) – Criterion function. None uses the default criterion.

Returns:

torch.Tensor – Mean criterion value over the sampled batch.

get()[source]

Get a reference to the flat parameter vector.

Returns:

torch.Tensor

Flat parameter tensor. Future calls to set() will modify it in place.

get_gradient()[source]

Get (or create) the flat gradient vector.

Returns:

torch.Tensor

Flat gradient tensor. Future calls to set_gradient() will modify it in place.
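Flattening views every per-layer tensor as a slice of one contiguous vector, so Krum-style aggregation can treat the whole model as a single point in R^d. A stand-alone sketch of the flatten / write-back idea using plain lists (toy helpers, not the package's tensor-based implementation):

```python
def flatten(params):
    """Concatenate per-layer parameter lists into one flat vector."""
    return [x for layer in params for x in layer]

def unflatten(flat, params):
    """Write a flat vector back into the per-layer lists, in place."""
    i = 0
    for layer in params:
        layer[:] = flat[i:i + len(layer)]
        i += len(layer)

layers = [[1.0, 2.0], [3.0, 4.0, 5.0]]      # two "layers" of parameters
vec = flatten(layers)
print(vec)                                   # [1.0, 2.0, 3.0, 4.0, 5.0]

vec = [v - 0.5 for v in vec]                 # e.g. a step along a flat gradient
unflatten(vec, layers)
print(layers)                                # [[0.5, 1.5], [2.5, 3.5, 4.5]]
```

In the actual wrapper, get() returns such a flat tensor by reference and set() writes one back, which is why repeated calls modify the same storage in place.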

loss(dataset=None, loss=None, training=None)[source]

Estimate loss on a batch from the given dataset.

Parameters:
  • dataset (experiments.Dataset or None, optional) – Dataset to sample from. None uses the default trainset.

  • loss (experiments.Loss or None, optional) – Loss function. None uses the default loss.

  • training (bool or None, optional) – Whether this is a training run. None guesses from torch.is_grad_enabled().

Returns:

torch.Tensor – Scalar loss value.

run(data, training=False)[source]

Forward pass through the model.

Parameters:
  • data (torch.Tensor) – Input tensor.

  • training (bool, optional) – Whether to use training mode (enables dropout, batch-norm updates, etc.). Defaults to evaluation mode.

Returns:

torch.Tensor – Model output.

set(params, relink=None)[source]

Overwrite parameters with the given flat vector.

Parameters:
  • params (torch.Tensor) – New flat parameter vector.

  • relink (bool or None, optional) – Whether to relink instead of copying. None uses the configuration default.

set_gradient(gradient, relink=None)[source]

Overwrite the gradient with the given flat vector.

Parameters:
  • gradient (torch.Tensor) – New flat gradient.

  • relink (bool or None, optional) – Whether to relink instead of copying. None uses the configuration default.

update(gradient, optimizer=None, relink=None)[source]

Update parameters using the given gradient and optimizer.

Parameters:
  • gradient (torch.Tensor) – Flat gradient to apply.

  • optimizer (experiments.Optimizer or None, optional) – Optimizer wrapper. None uses the default optimizer.

  • relink (bool or None, optional) – Whether to relink the gradient. None uses the configuration default.

class experiments.Optimizer(name_build, model, *args, **kwargs)[source]

Bases: object

Optimizer wrapper with name resolution and LR control.

Parameters:
  • name_build (str or callable) – Optimizer name (e.g. "adam", "sgd") or a constructor function. Names are resolved against torch.optim.

  • model (experiments.Model) – Model whose parameters will be optimized.

  • *args (object) – Additional positional arguments forwarded to the optimizer constructor.

  • **kwargs (object) – Additional keyword arguments forwarded to the optimizer constructor.

Raises:

tools.UnavailableException – If name_build is a string that does not match any known optimizer.

set_lr(lr)[source]

Set the learning rate for all parameter groups.

Parameters:

lr (float) – New learning rate.
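In PyTorch, the learning rate lives in each entry of the optimizer's param_groups list, so setting it for all groups amounts to one loop. A minimal sketch of that loop over a stand-in param_groups structure (no torch dependency; the real method operates on the wrapped torch.optim optimizer):

```python
def set_lr(param_groups, lr):
    """Set the learning rate for every parameter group."""
    for group in param_groups:
        group["lr"] = lr

# Stand-in for optimizer.param_groups: one option dict per group.
param_groups = [{"lr": 0.1, "momentum": 0.9}, {"lr": 0.01, "momentum": 0.9}]
set_lr(param_groups, 0.001)
print([g["lr"] for g in param_groups])  # [0.001, 0.001]
```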

class experiments.Storage[source]

Bases: dict

Plain dictionary that implements the state_dict protocol.

This allows arbitrary key/value data to be snapshotted and restored alongside models and optimizers using Checkpoint.

load_state_dict(state)[source]

Replace contents with the given state.

Parameters:

state (dict) – New dictionary contents.

state_dict()[source]

Return the dictionary itself as state.

Returns:

dict

Self.
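Because Storage is just a dict speaking the state_dict protocol, run metadata such as the current step or best accuracy can be snapshotted and restored alongside the model. A stand-alone re-creation of the idea (a dict subclass mirroring the documented behavior, not the package's own class):

```python
class Storage(dict):
    """Plain dict implementing the state_dict / load_state_dict protocol."""
    def state_dict(self):
        return self
    def load_state_dict(self, state):
        self.clear()
        self.update(state)

run_info = Storage(step=0, best_acc=0.0)
run_info["step"] = 42

saved = dict(run_info.state_dict())        # what a checkpoint would snapshot

fresh = Storage()
fresh.load_state_dict(saved)               # what a restore would do
print(fresh["step"], fresh["best_acc"])    # 42 0.0
```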

experiments.batch_dataset(inputs, labels, train=False, batch_size=None, split=0.75)[source]

Batch a raw tensor dataset into infinite sampler generators.

Parameters:
  • inputs (torch.Tensor) – Input data tensor.

  • labels (torch.Tensor) – Label tensor with the same first-dimension size as inputs.

  • train (bool, optional) – Whether to build a training set (adds shuffling) or a test set.

  • batch_size (int or None, optional) – Batch size. None or 0 uses the full split size.

  • split (float or int, optional) – Fraction of samples for training when < 1, or absolute count when >= 1.

Returns:

generator – Infinite sampler generator.
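The split parameter switches meaning at 1: below 1 it is a training fraction, at or above 1 an absolute training-sample count. A small hypothetical helper capturing just that rule (mirroring the documented behavior, not the package code):

```python
def split_count(num_samples, split=0.75):
    """Training-sample count: fraction of total if split < 1, absolute count otherwise."""
    if split < 1:
        return int(num_samples * split)
    return int(split)

print(split_count(1000))        # 750 (default 0.75 fraction)
print(split_count(1000, 0.5))   # 500
print(split_count(1000, 100))   # 100 (absolute count)
```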

experiments.get_default_transform(dataset, train)[source]

Return the default transform for a torchvision dataset.

Parameters:
  • dataset (str or None) – Case-sensitive dataset name. None returns None.

  • train (bool) – Whether to return the training transform. Ignored when dataset is None.

Returns:

torchvision.transforms.Compose or None – Composed transform, or None if the dataset is unknown.

experiments.make_datasets(dataset, train_batch=None, test_batch=None, train_transforms=None, test_transforms=None, num_workers=1, **custom_args)[source]

Build training and testing dataset wrappers.

Parameters:
  • dataset (str) – Case-sensitive dataset name.

  • train_batch (int or None, optional) – Training batch size. None or 0 for full-batch.

  • test_batch (int or None, optional) – Testing batch size. None or 0 for full-batch.

  • train_transforms (callable or None, optional) – Transform for the training set. None uses the default.

  • test_transforms (callable or None, optional) – Transform for the testing set. None uses the default.

  • num_workers (int or tuple[int, int], optional) – Number of workers for the training and testing loaders. An int applies to both; a tuple specifies (train_workers, test_workers).

  • **custom_args (object) – Additional keyword arguments forwarded to the dataset constructor.

Returns:

tuple[Dataset, Dataset] – Training and testing dataset wrappers.

experiments.make_sampler(loader)[source]

Create an infinite sampler from a DataLoader.

Parameters:
  • loader (torch.utils.data.DataLoader) – Finite data loader.

Yields:

tuple – Batches, transparently restarting the loader when exhausted.
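The restart is transparent to the caller: when the underlying loader is exhausted, iteration simply begins again. A stand-alone sketch of the pattern over any re-iterable sequence (plain Python, no DataLoader; a real DataLoader likewise yields a fresh epoch each time it is iterated):

```python
def make_sampler(loader):
    """Yield batches forever, restarting the finite loader when exhausted."""
    while True:
        yielded = False
        for batch in loader:   # each pass over the loader is one epoch
            yielded = True
            yield batch
        if not yielded:        # guard against an empty loader spinning forever
            raise ValueError("loader yielded no batches")

sampler = make_sampler([("x0", "y0"), ("x1", "y1")])
print([next(sampler) for _ in range(5)])
# [('x0', 'y0'), ('x1', 'y1'), ('x0', 'y0'), ('x1', 'y1'), ('x0', 'y0')]
```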