Model

Model wrapper with name resolution, initialization, and gradient handling.

This module provides Model, a unified interface that can instantiate torchvision models (by name), custom models from experiments/models/, or arbitrary callables. It also manages parameter flattening, gradient extraction, and data-parallelism automatically.

Example:

>>> from experiments import Model, Configuration
>>> config = Configuration(device="cpu")
>>> model = Model("resnet18", config, num_classes=10)
>>> output = model.run(inputs)
>>> gradient, loss = model.backprop(dataset=dataset, loss=loss, outloss=True)
>>> model.update(gradient, optimizer=optimizer)
class experiments.model.Model(name_build, config=None, init_multi=None, init_multi_args=None, init_mono=None, init_mono_args=None, *args, **kwargs)

Bases: object

Unified model wrapper with parameter and gradient management.

Models are resolved by lower-case name from torchvision.models and from custom modules under experiments/models/. Parameters are automatically flattened into a contiguous vector accessible via get() and set().

Data parallelism (torch.nn.DataParallel) is enabled automatically when the model is placed on a non-indexed CUDA device.
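The "non-indexed device" rule above can be sketched as a predicate (an assumption about the stated behavior, not the library's actual code): a bare "cuda" device string enables torch.nn.DataParallel, while an indexed one such as "cuda:0" or "cpu" does not.

```python
# Sketch of the rule described above (an assumption, not the library's code):
# a "non-indexed" CUDA device string like "cuda" enables DataParallel,
# while "cuda:0" (indexed) or "cpu" leaves the model unwrapped.
def uses_data_parallel(device: str) -> bool:
    return device.startswith("cuda") and ":" not in device
```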

Parameters:
  • name_build (str or callable) – Model name (e.g. "resnet18", "torchvision-resnet18") or a constructor function.

  • config (experiments.Configuration, optional) – Target device and dtype configuration.

  • init_multi (str or callable or None, optional) – Weight initializer for tensors of dimension >= 2.

  • init_multi_args (dict or None, optional) – Keyword arguments for init_multi when it is a string.

  • init_mono (str or callable or None, optional) – Weight initializer for tensors of dimension == 1.

  • init_mono_args (dict or None, optional) – Keyword arguments for init_mono when it is a string.

  • *args (object) – Forwarded to the model constructor.

  • **kwargs (object) – Forwarded to the model constructor.

Raises:
  • tools.UnavailableException – If name_build is an unknown string.

  • tools.UserException – If the built object is not a torch.nn.Module.

backprop(dataset=None, loss=None, outloss=False, **kwargs)

Compute gradient on a batch from the given dataset.

Parameters:
  • dataset (experiments.Dataset or None, optional) – Dataset to sample from. None uses the default trainset.

  • loss (experiments.Loss or None, optional) – Loss function. None uses the default loss.

  • outloss (bool, optional) – Whether to also return the loss value.

  • **kwargs (object) – Forwarded to loss.backward().

Returns:

torch.Tensor or tuple[torch.Tensor, torch.Tensor]

Flat gradient, optionally paired with the loss value.

property config

Return the immutable configuration.

Returns:

experiments.Configuration

Model configuration.

default(name, new=None, erase=False)

Get and/or set a named default.

Parameters:
  • name (str) – Default key (e.g. "trainset", "loss", "optimizer").

  • new (object or None, optional) – New value to set; ignored when None (unless erase is True).

  • erase (bool, optional) – Force the value to None.

Returns:

object

Current (or old) value of the default.

Raises:
  • tools.UnavailableException – If name is not a known default.
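The get-and/or-set behavior documented above can be illustrated with a small standalone function (a hypothetical reimplementation for illustration only; `store` stands in for the model's internal table of defaults, and a plain KeyError stands in for tools.UnavailableException).

```python
# Illustrative sketch of the default() semantics (not the library's code):
# always return the current value, optionally replacing or erasing it first.
def default(store, name, new=None, erase=False):
    if name not in store:
        raise KeyError(name)  # the library raises tools.UnavailableException
    old = store[name]
    if erase:
        store[name] = None        # force the value to None
    elif new is not None:
        store[name] = new         # install the new default
    return old                    # previous value in all cases
```

For example, `default(store, "loss")` just reads the current default, while `default(store, "loss", new=my_loss)` swaps it in and returns the previous one.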

eval(dataset=None, criterion=None)

Evaluate the model on a batch from the given dataset.

Parameters:
  • dataset (experiments.Dataset or None, optional) – Dataset to sample from. None uses the default testset.

  • criterion (experiments.Criterion or None, optional) – Criterion function. None uses the default criterion.

Returns:

torch.Tensor

Mean criterion value over the sampled batch.

get()

Get a reference to the flat parameter vector.

Returns:

torch.Tensor

Flat parameter tensor. Future calls to set() will modify it in place.

get_gradient()

Get (or create) the flat gradient vector.

Returns:

torch.Tensor

Flat gradient tensor. Future calls to set_gradient() will modify it in place.
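The flat-vector mechanics behind get(), set(), get_gradient() and set_gradient() can be sketched in plain Python (an illustration of the concept, not the library's code; real tensors are replaced by lists here): per-parameter storage is exposed as one contiguous vector, and the inverse operation scatters a flat vector back so every parameter sees the new values.

```python
# Conceptual sketch of parameter flattening (not the library's code).
def flatten(params):
    # Concatenate every parameter's values into one flat vector.
    return [x for p in params for x in p]

def scatter(flat, params):
    # In-place inverse of flatten(): copy slices of the flat vector back
    # into each parameter, so the model sees the new values.
    offset = 0
    for p in params:
        p[:] = flat[offset:offset + len(p)]
        offset += len(p)

params = [[1.0, 2.0], [3.0, 4.0, 5.0]]   # two "parameter tensors"
flat = flatten(params)                    # -> [1.0, 2.0, 3.0, 4.0, 5.0]
scatter([0.0] * 5, params)                # zero every parameter at once
```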

loss(dataset=None, loss=None, training=None)

Estimate loss on a batch from the given dataset.

Parameters:
  • dataset (experiments.Dataset or None, optional) – Dataset to sample from. None uses the default trainset.

  • loss (experiments.Loss or None, optional) – Loss function. None uses the default loss.

  • training (bool or None, optional) – Whether this is a training run. None guesses from torch.is_grad_enabled().

Returns:

torch.Tensor

Scalar loss value.

run(data, training=False)

Forward pass through the model.

Parameters:
  • data (torch.Tensor) – Input tensor.

  • training (bool, optional) – Whether to use training mode (enables dropout, batch-norm updates, etc.). Defaults to evaluation mode.

Returns:

torch.Tensor

Model output.

set(params, relink=None)

Overwrite parameters with the given flat vector.

Parameters:
  • params (torch.Tensor) – New flat parameter vector.

  • relink (bool or None, optional) – Whether to relink instead of copying. None uses the configuration default.

set_gradient(gradient, relink=None)

Overwrite the gradient with the given flat vector.

Parameters:
  • gradient (torch.Tensor) – New flat gradient.

  • relink (bool or None, optional) – Whether to relink instead of copying. None uses the configuration default.

update(gradient, optimizer=None, relink=None)

Update parameters using the given gradient and optimizer.

Parameters:
  • gradient (torch.Tensor) – Flat gradient to apply.

  • optimizer (experiments.Optimizer or None, optional) – Optimizer wrapper. None uses the default optimizer.

  • relink (bool or None, optional) – Whether to relink the gradient. None uses the configuration default.
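Conceptually, update() hands the flat gradient to an optimizer, which steps the flat parameter vector in place. A minimal sketch of that step, assuming a plain SGD rule (an illustration only; the real experiments.Optimizer wrapper may implement any update rule):

```python
# Sketch of what an optimizer step over the flat vectors looks like
# (an assumption for illustration, not the library's code).
def sgd_update(params, gradient, lr=0.1):
    # Plain SGD: move each flat parameter against its gradient entry.
    for i, g in enumerate(gradient):
        params[i] -= lr * g

params = [1.0, 2.0, 3.0]
sgd_update(params, [10.0, 10.0, 10.0], lr=0.1)  # params -> [0.0, 1.0, 2.0]
```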

See also

For gradient computation and backprop, see Loss. For parameter updates, see Optimizer. For saving and restoring state, see Checkpoint.