Model
Zero-copy flat-tensor view of a torch.nn.Module.
The Model wrapper relinks a module’s parameters and gradients to
contiguous flat 1-D tensors, so that aggregators and attacks can operate on
a single vector representation of the model state without copying data on
every access. Reading Model.parameters or Model.gradients
returns a view sharing the underlying buffer; assigning
model.gradients = flat unpacks the flat vector back into each
parameter’s .grad in place.
Example:
import torch
from torch.nn import Linear
from krum.primitives import Model
model = Model(Linear(4, 2))
# Read the flat parameter vector (shares memory with the module)
params = model.parameters # shape (d,)
# Simulate a backward pass, then read gradients
loss = model.module(torch.randn(3, 4)).sum()
loss.backward()
grads = model.gradients # shape (d,), lazy-initializes .grad
# Write aggregated gradients back (zero-copy relink)
model.gradients = grads.clone()
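The relinking idea described above can be approximated in plain PyTorch. What follows is a minimal sketch under stated assumptions, not the library's implementation: `flatten_parameters` is a hypothetical helper introduced only for illustration, and it covers parameters only (not gradients).

```python
import torch
from torch import nn


def flatten_parameters(module: nn.Module) -> torch.Tensor:
    """Hypothetical helper: relink every parameter to a slice of one flat buffer."""
    params = list(module.parameters())
    # Copy the current values into one contiguous 1-D buffer (done once).
    flat = torch.cat([p.detach().reshape(-1) for p in params])
    offset = 0
    for p in params:
        n = p.numel()
        # Point the parameter at its slice of the flat buffer (no further copy).
        p.data = flat[offset:offset + n].view_as(p)
        offset += n
    return flat


module = nn.Linear(4, 2)
flat = flatten_parameters(module)

# Writing through the flat vector is visible in the module's weights.
flat.zero_()
```

After the one-time copy in `torch.cat`, the flat vector and the module share storage, so an aggregator can read or overwrite the whole model state through a single tensor.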
- class primitives.model.Model(module: Module)
Bases: object
Model encapsulating a Module with zero-copy flat views of parameters and gradients.
All parameters and gradients are relinked to their respective single contiguous buffer. Reading the parameters or gradients property returns a flat tensor sharing that buffer; modifying it modifies the module parameters and gradients directly. Assigning model.gradients = flat unpacks the flat vector back into each parameter’s .grad, sharing the memory of this flat vector.
The end-user is responsible for tensor copy, swap, and sharing (e.g. calling .clone() before “sending” gradients to a remote worker).
Parameters:
module – The module to encapsulate.
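The reason for cloning before sharing can be seen with ordinary tensors. In this small sketch, `flat` is a stand-in for a flat gradient view and does not come from the library.

```python
import torch

flat = torch.arange(6, dtype=torch.float32)  # stand-in for a flat gradient view

alias = flat              # same buffer: later in-place updates show up here
snapshot = flat.clone()   # independent copy, safe to hand to another worker

flat.mul_(0.0)  # e.g. the next step overwrites the gradient buffer in place
```

Here `alias` is zeroed along with `flat`, while `snapshot` still holds the values that were present when it was taken.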
- property gradients: Tensor
Zero-copy flat view of module gradients.
The returned tensor shares memory with each parameter’s .grad. If a parameter has no gradient yet, a zero-filled gradient is assigned.
This is a simple, fast getter once the flat gradient has been lazy-initialized on the first access. If gradients are replaced or removed externally (e.g. via module.zero_grad(set_to_none=True)), call relink_gradients() to re-synchronise the flat view.
Returns:
Tensor of shape (d,) containing the gradients, in the iteration order of module.parameters().
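One way a flat gradient view of this kind could be built in plain PyTorch is sketched below. The names `flat_grad` and `params` are illustrative and this is not the library's code. The sketch relies on a plain first-order `backward()` accumulating in place into an existing `.grad`.

```python
import torch
from torch import nn

module = nn.Linear(4, 2)
params = list(module.parameters())
d = sum(p.numel() for p in params)

# One contiguous, zero-filled buffer for all gradients.
flat_grad = torch.zeros(d)
offset = 0
for p in params:
    n = p.numel()
    # Each .grad becomes a view into its slice of the flat buffer.
    p.grad = flat_grad[offset:offset + n].view_as(p)
    offset += n

# backward() adds into the existing .grad tensors, so the result
# shows up in the flat buffer without any explicit copy.
module(torch.ones(3, 4)).sum().backward()
```

With three all-ones input rows and a summed output, every gradient entry equals 3, and `flat_grad` holds those values directly.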
- property parameters: Tensor
Zero-copy flat view of module parameters.
The returned tensor shares memory with the model’s weights. Modifying it modifies the model directly. Clone before sending.
This is a simple, fast getter once the flat parameters have been lazy-initialized on the first access. If a parameter’s .data has been replaced externally, call relink_parameters() to re-synchronise the flat view.
Returns:
Tensor of shape (d,) containing the parameters, in the iteration order of module.parameters().
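Because the flat vector follows the iteration order of `module.parameters()`, the position of each parameter inside it can be recovered from the parameter sizes. The sketch below uses only PyTorch. The `slices` dictionary is an illustrative construct, not part of the documented API.

```python
import torch
from torch import nn

module = nn.Linear(4, 2)

# Map each parameter name to its (start, stop) range in a flat vector of
# size d, following the iteration order of module.named_parameters().
slices = {}
offset = 0
for name, p in module.named_parameters():
    slices[name] = (offset, offset + p.numel())
    offset += p.numel()
d = offset
```

For `Linear(4, 2)` the weight occupies the first 8 entries and the bias the last 2, so `d` is 10.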
- relink_gradients() → Self
Re-synchronise the flat gradient view after an external .grad replacement.
This method is necessary when a parameter’s .grad attribute has been replaced or removed since the flat gradient was built, for instance after module.zero_grad(set_to_none=True). It restores the zero-copy link between the cached flat tensor and every .grad so that the fast .gradients getter yields a consistent view again.
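The failure mode this method addresses can be reproduced in plain PyTorch. In the sketch below, `link` is a hypothetical helper, not the library's method. Copying the existing gradient values during relinking is an assumption of this sketch, since the source does not say whether they are preserved.

```python
import torch
from torch import nn

module = nn.Linear(4, 2)
params = list(module.parameters())
flat_grad = torch.zeros(sum(p.numel() for p in params))


def link(copy_existing: bool) -> None:
    """Hypothetical relink step: point every .grad at its slice of flat_grad."""
    offset = 0
    for p in params:
        view = flat_grad[offset:offset + p.numel()].view_as(p)
        if copy_existing and p.grad is not None:
            view.copy_(p.grad)  # keep values computed since the link broke
        p.grad = view
        offset += p.numel()


link(copy_existing=False)
module.zero_grad(set_to_none=True)         # drops every .grad: link broken
module(torch.ones(3, 4)).sum().backward()  # fresh .grad tensors, not views
stale = flat_grad.clone()                  # the flat buffer saw nothing
link(copy_existing=True)                   # restore the shared storage
```

Before relinking, the flat buffer is still all zeros even though the module has gradients. Afterwards the two agree and share storage again.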
- relink_parameters() → Self
Re-synchronise the flat parameter view after an external .data replacement.
This method is necessary when a parameter’s .data has been replaced since the flat parameters were built, restoring the zero-copy link between the cached flat tensor and every parameter so that the fast .parameters getter yields a consistent view again.
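A short plain-PyTorch sketch of how an external `.data` replacement breaks the shared storage is given below. The flat buffer is built by hand, with illustrative names, rather than through the library.

```python
import torch
from torch import nn

module = nn.Linear(4, 2)
params = list(module.parameters())

# Build a flat buffer and relink each parameter to its slice of it.
flat = torch.cat([p.detach().reshape(-1) for p in params])
offset = 0
for p in params:
    p.data = flat[offset:offset + p.numel()].view_as(p)
    offset += p.numel()

linked_before = module.weight.data_ptr() == flat.data_ptr()

# Assigning a new tensor to .data gives the parameter fresh storage,
# so the flat vector no longer tracks it.
module.weight.data = torch.ones(2, 4)
linked_after = module.weight.data_ptr() == flat.data_ptr()
```

In-place updates such as `module.weight.data.copy_(...)` would keep the link intact. Only rebinding `.data` to a different tensor breaks it.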