Model

Zero-copy flat-tensor view of a torch.nn.Module.

The Model wrapper relinks a module’s parameters and gradients to contiguous flat 1-D tensors, so that aggregators and attacks can operate on a single vector representation of the model state without copying data on every access. Reading Model.parameters or Model.gradients returns a view sharing the underlying buffer; assigning model.gradients = flat relinks each parameter’s .grad to the corresponding slice of flat, without copying.

Example:

import torch
from torch.nn import Linear
from krum.primitives import Model

model = Model(Linear(4, 2))

# Read the flat parameter vector (shares memory with the module)
params = model.parameters          # shape (d,)

# Simulate a backward pass, then read gradients
loss = model.module(torch.randn(3, 4)).sum()
loss.backward()
grads = model.gradients            # shape (d,), lazy-initializes .grad

# Write aggregated gradients back (zero-copy relink)
model.gradients = grads.clone()

class primitives.model.Model(module: Module)

Bases: object

Model encapsulating a Module with zero-copy flat views of parameters and gradients.

All parameters are relinked to one contiguous buffer, and all gradients to another. Reading the parameters or gradients property returns a flat tensor sharing the corresponding buffer, so modifying it modifies the module’s parameters or gradients directly. Assigning model.gradients = flat makes each parameter’s .grad a view into flat, so the gradients then share memory with that vector.
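The relinking described above can be sketched in plain PyTorch. This is a minimal illustration of the general technique under stated assumptions, not necessarily how Model is implemented; the helper name flatten_and_relink is hypothetical.

```python
import torch
from torch.nn import Linear


def flatten_and_relink(tensors):
    """Hypothetical helper: copy tensors into one flat buffer, then turn each into a view of it."""
    # Single copy, paid once when the flat buffer is built
    flat = torch.cat([t.data.reshape(-1) for t in tensors])
    offset = 0
    for t in tensors:
        n = t.numel()
        # Point the tensor at its slice of the flat buffer (no further copy)
        t.data = flat[offset:offset + n].view_as(t)
        offset += n
    return flat


module = Linear(4, 2)
flat = flatten_and_relink(list(module.parameters()))

# The flat vector and the module now share memory (d = 4*2 + 2 = 10 here),
# so an in-place write to the vector is visible through the module.
flat.fill_(0.5)
weight_is_linked = bool((module.weight == 0.5).all())
```

Once the link is in place, reading the flat vector costs nothing beyond returning the cached tensor, which is presumably what lets aggregators work on a single vector without per-access copies.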

The end-user is responsible for tensor copy, swap, and sharing (e.g. calling .clone() before “sending” gradients to a remote worker).
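Why the copy matters can be shown with plain tensors. In this sketch, flat merely stands in for a flat view such as model.gradients; no part of the library is used, and the variable names are illustrative.

```python
import torch

flat = torch.zeros(4)      # stands in for a flat view such as model.gradients

alias = flat               # no copy: still the same buffer
snapshot = flat.clone()    # independent copy, safe to hand to another worker

flat += 1.0                # e.g. the next backward pass accumulates in place

print(alias.tolist())      # [1.0, 1.0, 1.0, 1.0]
print(snapshot.tolist())   # [0.0, 0.0, 0.0, 0.0]
```

The alias silently follows every later update, whereas the clone keeps the values it had when it was taken.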

Parameters:

module – The module to encapsulate.

property gradients: Tensor

Zero-copy flat view of module gradients.

The returned tensor shares memory with each parameter’s .grad. If a parameter has no gradient yet, a zero-filled gradient is assigned.

This is a simple, fast getter once the flat gradient has been lazy-initialized on the first access. If gradients are replaced or removed externally (e.g. via module.zero_grad(set_to_none=True)), call relink_gradients() to re-synchronise the flat view.

Returns:

Tensor of shape (d,) holding all gradients, in the iteration order of module.parameters().

property module: Module

The encapsulated module.

Returns:

The encapsulated module.

property parameters: Tensor

Zero-copy flat view of module parameters.

The returned tensor shares memory with the model’s weights. Modifying it modifies the model directly. Clone before sending.

This is a simple, fast getter once the flat parameters have been lazy-initialized on the first access. If a parameter’s .data has been replaced externally, call relink_parameters() to re-synchronise the flat view.

Returns:

Tensor of shape (d,) holding all parameters, in the iteration order of module.parameters().

relink_gradients()

Re-synchronise the flat gradient view after an external .grad replacement.

This method is necessary when a parameter’s .grad attribute has been replaced or removed since the flat gradient was built, for instance after module.zero_grad(set_to_none=True). It restores the zero-copy link between the cached flat tensor and every .grad so that the fast .gradients getter yields a consistent view again.
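The failure mode and its repair can be reproduced in plain PyTorch. The sketch below does not call the library; link_gradients is a hypothetical helper, and carrying the current gradient values over before relinking is an assumption of this example rather than documented behaviour of relink_gradients().

```python
import torch
from torch.nn import Linear


def link_gradients(params, flat_grad):
    """Hypothetical helper: make every .grad a view into flat_grad."""
    offset = 0
    for p in params:
        n = p.numel()
        p.grad = flat_grad[offset:offset + n].view_as(p)
        offset += n


module = Linear(4, 2)
params = list(module.parameters())
flat_grad = torch.zeros(sum(p.numel() for p in params))
link_gradients(params, flat_grad)

# backward() accumulates in place into an existing .grad, so flat_grad sees the result
module(torch.ones(3, 4)).sum().backward()
linked_bias_grad = flat_grad[-2:].tolist()

# Removing .grad drops the views; the next backward pass allocates fresh tensors
module.zero_grad(set_to_none=True)
module(torch.ones(3, 4)).sum().backward()
stale = module.bias.grad.data_ptr() != flat_grad[-2:].data_ptr()

# Re-synchronise: carry the current values over (an assumption of this sketch), then relink
flat_grad.copy_(torch.cat([p.grad.reshape(-1) for p in params]))
link_gradients(params, flat_grad)
relinked = module.bias.grad.data_ptr() == flat_grad[-2:].data_ptr()
```

Until the relink step runs, the cached flat tensor keeps whatever values it last held, which is why a fast getter would return an inconsistent view.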

relink_parameters()

Re-synchronise the flat parameter view after an external .data replacement.

This method is necessary when a parameter’s .data has been replaced since the flat parameters were built, restoring the zero-copy link between the cached flat tensor and every parameter so that the fast .parameters getter yields a consistent view again.
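A plain PyTorch sketch of the same situation for parameters follows. The helper link_parameters is hypothetical, and gathering the current values into the buffer before relinking is an assumption of this example, not a statement about what relink_parameters() does internally.

```python
import torch
from torch.nn import Linear


def link_parameters(params, flat):
    """Hypothetical helper: make every parameter a view into flat."""
    offset = 0
    for p in params:
        n = p.numel()
        p.data = flat[offset:offset + n].view_as(p)
        offset += n


module = Linear(4, 2)
params = list(module.parameters())
flat = torch.cat([p.data.reshape(-1) for p in params])
link_parameters(params, flat)

# Replacing .data externally silently breaks the link for that parameter
module.bias.data = torch.ones(2)
flat.zero_()
bias_after_break = module.bias.detach().tolist()    # still ones: flat no longer reaches it

# Re-synchronise: gather the current values into the buffer, then relink
flat.copy_(torch.cat([p.data.reshape(-1) for p in params]))
link_parameters(params, flat)
flat.zero_()
bias_after_relink = module.bias.detach().tolist()   # zeros: the link is restored
```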