DecentralisedSimulation

Base class for decentralised Byzantine-resilient learning simulations.

class krum.simulations.decentralised.base.DecentralisedSimulation(*, model: Model, data: Sequence[Iterable[tuple[Tensor, Tensor]]], loss_fn: Callable[[Tensor, Tensor], Tensor], n: int, f: int, attack: type[Attack] | None = None, attack_kwargs: dict[str, Any] | None = None, aggregator: type[Aggregator], aggregator_kwargs: dict[str, Any] | None = None, seed: int | None = None)

Bases: ABC, Generic[StepResultT]

Base for decentralised simulations with per-worker model mixing.

Each honest worker holds its own flat parameter vector, stored as one row of parameters. A round runs a local optimisation step, then a mixing phase in which every worker replaces its model with an aggregate of the models it received this round.

Two things vary between protocols, and each is an abstract seam:

  • how a worker updates locally: local_update() maps the worker gradients to the post-local parameters theta_{t+1/2}. The base is agnostic to the rule and its state, so a momentum-free or optimiser-based protocol keeps none of MoNNA’s momentum machinery.

  • which models a worker receives: gather_received_models() builds each worker’s received set (the communication topology).

The base owns everything else: gradient computation, the Byzantine generation hook, the mixing loop, the generic state commit, and the multi-round run() driver. Subclasses also implement build_step_result() so each protocol’s snapshot can carry its own fields (the base StepResult plus, e.g., momentum).

run() may be called repeatedly to continue training: all state (parameters, step_index, subclass optimiser state, and the worker data streams) lives on the instance and persists across calls. Callers that run more rounds than a finite stream provides should cycle their streams.
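One way to cycle a finite stream is itertools.cycle from the standard library. The sketch below is illustrative only: the stream contents are made up, and plain tuples of floats stand in for the (inputs, targets) tensor pairs that the data argument expects.

```python
from itertools import cycle

# Hypothetical finite per-worker streams: two workers, two batches each.
# Tuples of floats stand in for (inputs, targets) tensor pairs.
finite_streams = [
    [((0.0,), (1.0,)), ((1.0,), (0.0,))],
    [((2.0,), (1.0,)), ((3.0,), (0.0,))],
]

# cycle() turns each finite stream into an endless iterator, so pulling
# more batches than the stream holds wraps around instead of raising
# StopIteration.
data = [cycle(stream) for stream in finite_streams]

# Pull three batches from worker 0: the third repeats the first.
pulled = [next(data[0]) for _ in range(3)]
```

Note that cycle() caches the first pass and replays it in the same order, so a stream that is meant to reshuffle between passes needs a different wrapper.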

aggregate_over_received_nodes(local_parameters: Tensor, byzantine_parameters: Tensor) -> Tensor

Run the model-mixing phase over post-local-update parameter vectors.

For each worker, builds its received set via the protocol-specific gather_received_models(), then aggregates it.

Parameters:
  • local_parameters – Post-local-update honest models, one row per worker.

  • byzantine_parameters – Byzantine models, shape (f, d).

Returns:

The mixed models, one row per honest worker.

aggregate_received_models(candidates: Tensor, *, pivot: Tensor) -> Tensor

Aggregate the set of models one worker received.

NearestNeighborAverage anchors on the worker’s own model via pivot; pivot-free aggregators (e.g. Krum, Median) absorb it through **specialized.

Parameters:
  • candidates – The received models for one worker.

  • pivot – The worker’s own model, used as the distance reference.

Returns:

The single mixed model for the worker, shape (d,).
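To illustrate what anchoring on a pivot means, here is a minimal sketch that assumes a pivot-anchored rule averages the candidates closest to the pivot in Euclidean distance. The function name and the keep parameter are hypothetical, and plain lists stand in for tensors; this is not the library's NearestNeighborAverage.

```python
import math


def nearest_neighbour_average(candidates, pivot, keep):
    """Average the `keep` candidates closest to `pivot` (Euclidean distance)."""
    if not 1 <= keep <= len(candidates):
        raise ValueError("keep must be between 1 and len(candidates)")
    # sorted() is stable, so equidistant candidates keep their input order.
    nearest = sorted(candidates, key=lambda c: math.dist(c, pivot))[:keep]
    dim = len(pivot)
    # Coordinate-wise mean of the retained candidates.
    return [sum(c[i] for c in nearest) / keep for i in range(dim)]


own = [0.0, 0.0]
# The worker's own model leads the received set; the last entry is an outlier.
received = [own, [1.0, 0.0], [0.0, 1.0], [100.0, 100.0]]
mixed = nearest_neighbour_average(received, own, keep=3)
```

With keep=3 the outlier is the farthest candidate from the pivot and is dropped before averaging, which is the intuition behind using the worker's own model as the distance reference.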

abstract build_step_result(*, honest_gradients: Tensor, local_parameters: Tensor, byzantine_parameters: Tensor, mixed_parameters: Tensor, losses: Tensor) -> StepResultT

Build the round snapshot, called after the state has been committed.

Implementations return the protocol’s StepResult subtype, adding any optimiser-state fields (e.g. momentum). The committed step_index and parameters are available via self.

Parameters:
  • honest_gradients – Stacked honest gradients this round.

  • local_parameters – Post-local-update honest models.

  • byzantine_parameters – Byzantine models injected this round.

  • mixed_parameters – Mixed models (equal to the committed parameters).

  • losses – Per-worker losses.

Returns:

The protocol snapshot, with a detached clone of each tensor.

collect_worker_batches() -> list[tuple[Tensor, Tensor]]

Pull one local batch from every honest worker stream.

Returns:

One batch per honest worker, in worker order.

commit_state(parameters: Tensor) -> None

Persist the next parameters and advance the round counter.

Subclasses commit their own optimiser state inside local_update(); this only handles the parameter state shared by every protocol.

Parameters:

parameters – The mixed models to persist as the next parameters.

compute_honest_worker_gradients(batches: Sequence[tuple[Tensor, Tensor]]) -> tuple[Tensor, Tensor]

Compute gradients at each honest worker’s current parameters.

Parameters:

batches – One batch per honest worker, in worker order.

Returns:

A tuple (gradients, losses), each with one entry per honest worker.

copy_parameters_to_model(parameters: Tensor) -> None

Copy one flat parameter vector into the shared model wrapper.

Parameters:

parameters – Flat parameter vector of shape (d,) to load.

abstract gather_received_models(honest_vectors: Tensor, byzantine_parameters: Tensor, *, worker_index: int) -> Tensor

Build the set of models received by one honest worker this round.

This is the communication-topology seam: each decentralised protocol decides which models (honest and Byzantine) land in a worker’s received set. Implementations should lead the set with the worker’s own model so a pivot-anchored aggregator can rely on its position.

Parameters:
  • honest_vectors – Post-local-update honest models, one row per worker.

  • byzantine_parameters – Byzantine models, shape (f, d), as produced by generate_byzantine_models() (may be empty or ignored by protocols that generate Byzantine replies per recipient).

  • worker_index – Index of the receiving honest worker.

Returns:

The received models for the worker, with its own model first.
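As an illustration of the own-model-first contract, the sketch below assumes a fully connected topology in which every worker receives all honest and all Byzantine models. The function name is hypothetical and lists of lists stand in for tensors; a real implementation would be a method on the subclass.

```python
def gather_fully_connected(honest_vectors, byzantine_parameters, *, worker_index):
    """Return the received set for one worker, own model first."""
    own = honest_vectors[worker_index]
    # Every other honest model, in worker order.
    others = [v for i, v in enumerate(honest_vectors) if i != worker_index]
    # Own model leads, then honest peers, then the Byzantine models.
    return [own, *others, *byzantine_parameters]


honest = [[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]]
byzantine = [[9.0, 9.0]]
received = gather_fully_connected(honest, byzantine, worker_index=1)
```

Here received[0] is worker 1's own model, so an aggregator that takes its pivot from the first position sees the right reference.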

generate_byzantine_models(local_parameters: Tensor) -> Tensor

Generate the Byzantine model vectors injected into the mixing phase.

Called once per round, before the per-worker mixing loop. The same Byzantine models are then placed into received sets by gather_received_models(). Protocols whose Byzantine replies depend on the recipient (e.g. recipient-specific attacks) override this method and produce the replies inside gather_received_models() instead.

Parameters:

local_parameters – Post-local-update honest models, one row per worker, passed to the attack.

Returns:

The Byzantine models, shape (f, d).

abstract local_update(gradients: Tensor) -> Tensor

Compute the post-local-update parameters theta_{t+1/2}.

This is the optimisation seam: each protocol defines how a worker turns its gradient into its pre-mixing parameters, and owns any optimiser state (momentum, moments, …). Implementations should update that state in place so build_step_result() can snapshot it.

Parameters:

gradients – Stacked honest gradients, one row per worker.

Returns:

The post-local-update parameters, one row per honest worker.
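The simplest rule that fits this seam is a stateless gradient step, theta_{t+1/2} = theta_t - lr * g_t. The sketch below assumes that rule; the function name and learning_rate are hypothetical, lists stand in for tensors, and parameters are passed explicitly here whereas the real method would read them from the instance.

```python
def sgd_local_update(parameters, gradients, learning_rate):
    """Apply one plain SGD step to every worker's row independently."""
    return [
        [theta - learning_rate * grad for theta, grad in zip(p_row, g_row)]
        for p_row, g_row in zip(parameters, gradients)
    ]


parameters = [[1.0, 2.0], [3.0, 4.0]]   # one row per honest worker
gradients = [[0.5, 0.5], [1.0, -1.0]]   # stacked honest gradients
half_step = sgd_local_update(parameters, gradients, learning_rate=0.5)
```

A momentum-based protocol would additionally update its momentum buffer in place at this point, so that the snapshot can include it.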

run(rounds: int) -> list[StepResultT]

Execute several rounds and collect their snapshots.

State persists on the instance, so successive calls continue training from where the previous call left off.

Parameters:

rounds – Number of rounds to run; must be non-negative.

Returns:

One snapshot per round, in execution order.

Raises:

ValueError – If rounds is negative.
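The documented contract (validate rounds, loop over step(), keep state on the instance) can be sketched with a toy stand-in. The Counter class is hypothetical and tracks only a round counter; whether the library's step field starts at 0 or 1 is not stated here, so the numbering below is an assumption of the toy.

```python
class Counter:
    """Hypothetical stand-in whose only state is the round counter."""

    def __init__(self):
        self.step_index = 0

    def step(self):
        # Advance the persistent counter and return a minimal snapshot.
        self.step_index += 1
        return {"step": self.step_index}

    def run(self, rounds):
        if rounds < 0:
            raise ValueError("rounds must be non-negative")
        # One snapshot per round, in execution order.
        return [self.step() for _ in range(rounds)]


sim = Counter()
first = [r["step"] for r in sim.run(3)]
second = [r["step"] for r in sim.run(2)]  # continues where the first call stopped
```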

step() -> StepResultT

Execute one decentralised training round.

Runs one local optimisation phase (local_update()) followed by one model-mixing phase over the per-worker received sets, then commits the resulting state and builds the snapshot.

Returns:

A snapshot dict of the round, as built by build_step_result().
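The ordering of a round (local update, Byzantine generation, mixing, commit, snapshot) can be sketched with a scalar toy in which each worker's model is a single float. Everything here is a hypothetical stand-in: the class names, the fixed toy attack, the toy gradient, and the use of the median as aggregator are assumptions for illustration, not the library's implementation.

```python
from abc import ABC, abstractmethod
from statistics import median


class ToySimulation(ABC):
    """Hypothetical scalar stand-in: each worker's model is one float."""

    def __init__(self, parameters):
        self.parameters = list(parameters)
        self.step_index = 0

    @abstractmethod
    def local_update(self, gradients): ...

    @abstractmethod
    def gather_received_models(self, honest, byzantine, *, worker_index): ...

    def step(self):
        # Toy gradient: for the loss 0.5 * theta**2 the gradient is theta.
        gradients = list(self.parameters)
        local = self.local_update(gradients)   # optimisation seam
        byzantine = [1000.0]                   # fixed toy attack, generated once
        mixed = [                              # mixing phase, one pass per worker
            median(self.gather_received_models(local, byzantine, worker_index=i))
            for i in range(len(local))
        ]
        self.parameters = mixed                # commit the state first...
        self.step_index += 1
        return {"step": self.step_index, "parameters": list(mixed)}  # ...then snapshot


class ToySGD(ToySimulation):
    def local_update(self, gradients):
        return [p - 0.5 * g for p, g in zip(self.parameters, gradients)]

    def gather_received_models(self, honest, byzantine, *, worker_index):
        own = honest[worker_index]
        others = [v for i, v in enumerate(honest) if i != worker_index]
        return [own, *others, *byzantine]


sim = ToySGD([2.0, 4.0, 6.0])
result = sim.step()
```

In this toy the median discards the single Byzantine value, and all three honest workers end the round on the same mixed model.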

class krum.simulations.decentralised.base.StepResult

Bases: TypedDict

Snapshot returned by DecentralisedSimulation.step() for one round.

Holds only the fields common to every decentralised protocol. Subclasses whose local step keeps extra state (e.g. momentum) extend this TypedDict and bind it as the simulation’s StepResultT.

byzantine_parameters: Tensor
honest_gradients: Tensor
local_parameters: Tensor
losses: Tensor
mixed_parameters: Tensor
parameters: Tensor
step: int
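Extending the snapshot type follows ordinary TypedDict inheritance. The sketch below redeclares a stand-in StepResult with the documented fields (Any replaces Tensor so the example needs only the standard library) and adds a hypothetical momentum field; MomentumStepResult is an illustrative name, not a class from the library.

```python
from typing import Any, TypedDict


class StepResult(TypedDict):
    """Stand-in mirroring the documented base fields (Any replaces Tensor)."""

    byzantine_parameters: Any
    honest_gradients: Any
    local_parameters: Any
    losses: Any
    mixed_parameters: Any
    parameters: Any
    step: int


class MomentumStepResult(StepResult):
    """Hypothetical subtype adding one optimiser-state field."""

    momentum: Any


# Inherited and new keys are both required on the subtype.
required = MomentumStepResult.__required_keys__
```

A simulation subclass would then bind this subtype as its StepResultT, for example by subclassing DecentralisedSimulation[MomentumStepResult], and return it from build_step_result().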