Checkpoint

Checkpoint management for model, optimizer, and arbitrary stateful objects.

This module provides Checkpoint for saving and restoring state dictionaries, and Storage for plain-dictionary checkpointing.

Example:

>>> from experiments import Checkpoint, Model, Optimizer
>>> # `model` and `optimizer` are assumed to be existing Model and Optimizer instances
>>> ckpt = Checkpoint()
>>> ckpt.snapshot(model).snapshot(optimizer)
>>> ckpt.save("run.pt")
>>> # Later, e.g. in a new process...
>>> ckpt = Checkpoint().load("run.pt")
>>> ckpt.restore(model).restore(optimizer)
class experiments.checkpoint.Checkpoint

Bases: object

Collection of state dictionaries with saving/loading helpers.

This class can snapshot any object implementing the state_dict / load_state_dict protocol (e.g. torch.nn.Module, torch.optim.Optimizer). It also knows how to unwrap Model and Optimizer wrappers automatically.
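
Any object exposing these two methods can take part. The following is a minimal sketch of such an object, assuming only the protocol named above. The StepCounter class is hypothetical and not part of this module.

```python
class StepCounter:
    """Hypothetical stateful object following the state_dict protocol."""

    def __init__(self):
        self.steps = 0

    def state_dict(self):
        # Return everything needed to rebuild this object later.
        return {"steps": self.steps}

    def load_state_dict(self, state):
        # Restore the fields captured by state_dict().
        self.steps = state["steps"]


counter = StepCounter()
counter.steps = 42
saved = counter.state_dict()

# A fresh instance picks up the saved state.
fresh = StepCounter()
fresh.load_state_dict(saved)
```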

Example:

>>> ckpt = Checkpoint()
>>> ckpt.snapshot(model, deepcopy=True)
>>> ckpt.restore(model)
load(filepath, overwrite=False)

Load checkpoint data from a file.

Parameters:
  • filepath (str or pathlib.Path) – Path to the saved checkpoint.

  • overwrite (bool, optional) – Whether to overwrite any existing snapshots.

Returns:

Checkpoint – Self, for chaining.

Raises:

tools.UserException – If the checkpoint is non-empty and overwrite is False.

restore(instance, nothrow=False)

Restore an instance from its stored snapshot.

Parameters:
  • instance (object) – Instance to restore. Must support load_state_dict().

  • nothrow (bool, optional) – If True, silently skip when no snapshot is available.

Returns:

Checkpoint – Self, for chaining.

Raises:

tools.UserException – If no snapshot exists and nothrow is False.
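
The effect of nothrow can be illustrated with a toy stand-in. This is a sketch under stated assumptions, not the actual implementation: MiniCheckpoint and Counter are hypothetical, snapshots are assumed to be keyed by the instance's class, and KeyError stands in for tools.UserException.

```python
class Counter:
    """Hypothetical stateful object following the state_dict protocol."""

    def __init__(self, value=0):
        self.value = value

    def state_dict(self):
        return {"value": self.value}

    def load_state_dict(self, state):
        self.value = state["value"]


class MiniCheckpoint:
    """Hypothetical toy illustrating nothrow; not the real Checkpoint."""

    def __init__(self):
        self._store = {}

    def snapshot(self, instance):
        # Assumption: one snapshot per class, stored as a shallow copy.
        self._store[type(instance)] = dict(instance.state_dict())
        return self

    def restore(self, instance, nothrow=False):
        key = type(instance)
        if key not in self._store:
            if nothrow:
                return self  # silently skip, still chainable
            # The real method is documented to raise tools.UserException.
            raise KeyError(f"no snapshot for {key.__name__}")
        instance.load_state_dict(self._store[key])
        return self


ckpt = MiniCheckpoint()
counter = Counter(5)

skipped = ckpt.restore(counter, nothrow=True)  # no snapshot yet: skipped
try:
    ckpt.restore(counter)                      # no snapshot yet: raises
    raised = False
except KeyError:
    raised = True

ckpt.snapshot(counter)
counter.value = 0
ckpt.restore(counter)                          # value is restored from the snapshot
```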

save(filepath, overwrite=False)

Save the current checkpoint to a file.

Parameters:
  • filepath (str or pathlib.Path) – Destination path.

  • overwrite (bool, optional) – Whether to overwrite an existing file.

Returns:

Checkpoint – Self, for chaining.

Raises:

tools.UserException – If the file exists and overwrite is False.
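
The overwrite check is a common guard pattern. The sketch below uses only the standard library. The save_guarded helper, the use of pickle, and FileExistsError are illustrative assumptions: the actual method raises tools.UserException, and its on-disk format is not specified here.

```python
import pathlib
import pickle
import tempfile


def save_guarded(data, filepath, overwrite=False):
    """Hypothetical helper: refuse to clobber an existing file unless asked."""
    path = pathlib.Path(filepath)
    if path.exists() and not overwrite:
        # The real method is documented to raise tools.UserException here.
        raise FileExistsError(f"{path} exists and overwrite is False")
    with path.open("wb") as fd:
        pickle.dump(data, fd)
    return path


with tempfile.TemporaryDirectory() as tmp:
    target = pathlib.Path(tmp) / "run.pt"
    save_guarded({"step": 1}, target)                  # first write succeeds
    try:
        save_guarded({"step": 2}, target)              # second write is refused
        refused = False
    except FileExistsError:
        refused = True
    save_guarded({"step": 2}, target, overwrite=True)  # explicit overwrite
    with target.open("rb") as fd:
        final = pickle.load(fd)
```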

snapshot(instance, overwrite=False, deepcopy=False, nowarnref=False)

Take (or overwrite) a snapshot of an instance’s state dictionary.

Parameters:
  • instance (object) – Instance to snapshot. Must support state_dict().

  • overwrite (bool, optional) – Whether to overwrite an existing snapshot for the same class.

  • deepcopy (bool, optional) – Whether to deep-copy the state dictionary instead of shallow-copying.

  • nowarnref (bool, optional) – Suppress the debug warning when restoring a reference is the intended behavior.

Returns:

Checkpoint – Self, for chaining.

Raises:

tools.UserException – If a snapshot already exists and overwrite is False.
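
The deepcopy flag matters because a shallow copy shares its inner values with the live object. The sketch below shows the difference with the standard copy module. It assumes the shallow copy behaves like copy.copy on a dict, and plain lists stand in for tensors.

```python
import copy

state = {"weights": [1.0, 2.0], "step": 3}

shallow = copy.copy(state)      # new dict, same inner objects
deep = copy.deepcopy(state)     # new dict, independent inner objects

# Simulate an in-place update, as a training step would perform.
state["weights"][0] = 99.0

shallow_first = shallow["weights"][0]  # follows the live object
deep_first = deep["weights"][0]        # keeps the value at snapshot time
```

With a shallow snapshot, restoring later would hand back the already-modified values, so deepcopy=True is the safer choice when the instance keeps training after the snapshot.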

class experiments.checkpoint.Storage

Bases: dict

Plain dictionary that implements the state_dict protocol.

This allows arbitrary key/value data to be snapshotted and restored alongside models and optimizers using Checkpoint.

load_state_dict(state)

Replace contents with the given state.

Parameters:

state (dict) – New dictionary contents.

state_dict()

Return the dictionary itself as state.

Returns:

dict – The dictionary itself.
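
A dict subclass satisfying this description can be sketched in a few lines. DictStorage below is a hypothetical stand-in, not the actual Storage implementation. It assumes load_state_dict replaces the contents rather than merging them, as the description suggests.

```python
class DictStorage(dict):
    """Hypothetical stand-in for Storage: a dict following the state_dict protocol."""

    def state_dict(self):
        # The dictionary itself is the state.
        return self

    def load_state_dict(self, state):
        # Replace (not merge) the current contents.
        self.clear()
        self.update(state)


store = DictStorage(epoch=3, best_loss=0.25)
state = dict(store.state_dict())   # copy, so later edits do not leak in

store["epoch"] = 4
store["extra"] = True
store.load_state_dict(state)       # back to the copied contents
```

Because state_dict() returns the object itself, a caller that wants an independent copy has to make one, which is what the deepcopy option of Checkpoint.snapshot is for.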

See also

For the model state being checkpointed, see Model. For the optimizer state being checkpointed, see Optimizer.