CentralisedSimulation

Parameter-server distributed SGD simulation in the presence of Byzantine workers.

Each synchronous round:
  1. Honest workers compute a gradient on their local data shard.

  2. Byzantine workers craft adversarial gradients.

  3. The aggregator combines all \(n\) gradients into a single update.

  4. The aggregated update is applied via an SGD step.
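The four steps above can be sketched in plain Python. This is a minimal illustration, not the library's code: gradients are lists of floats, the aggregator is a coordinate-wise mean (standing in for an averaging rule), and the Byzantine workers send a fixed huge vector. The names run_round, mean_aggregate, and large_vector_attack are hypothetical; only the keyword-style calls aggregate(..., n=n, f=f) and attack(..., f=f) mirror the calling convention documented below.

```python
from typing import Callable, Sequence

Vector = list[float]


def mean_aggregate(gradients: Sequence[Vector], *, n: int, f: int) -> Vector:
    """Coordinate-wise mean of all n gradients (stand-in for an averaging rule)."""
    assert len(gradients) == n, "expected exactly n gradients"
    return [sum(column) / n for column in zip(*gradients)]


def large_vector_attack(honest: Sequence[Vector], *, f: int) -> list[Vector]:
    """Hypothetical attack: each of the f Byzantine workers sends a huge constant vector."""
    dim = len(honest[0])
    return [[1e6] * dim for _ in range(f)]


def run_round(
    params: Vector,
    honest_grads: Sequence[Vector],
    f: int,
    lr: float,
    aggregate: Callable[..., Vector],
    attack: Callable[..., list[Vector]],
) -> Vector:
    # 1. Honest gradients are assumed to be precomputed on each worker's shard.
    # 2. Byzantine workers craft their gradients (they may inspect the honest ones).
    byzantine = attack(honest_grads, f=f) if f > 0 else []
    all_grads = list(honest_grads) + byzantine
    # 3. The aggregator combines all n = (n - f) + f gradients into one update.
    update = aggregate(all_grads, n=len(all_grads), f=f)
    # 4. Plain SGD step: theta <- theta - lr * update.
    return [p - lr * u for p, u in zip(params, update)]
```

With three identical honest gradients [1.0, 2.0], f = 0 and lr = 0.1, the round moves [0.0, 0.0] to [-0.1, -0.2]. With f = 1, the single Byzantine vector dominates the mean, which is the failure mode robust rules such as Krum are designed to resist.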

CentralisedSimulation implements the full lifecycle (model initialisation, data sharding, training loop, evaluation). Protocol-specific metric reporting is configured through the evaluate_fn constructor parameter rather than by subclassing: evaluation is delegated to a caller-supplied callable, following the composition-over-inheritance principle.

class krum.simulations.centralised.base.CentralisedSimulation(*, model_cls: type[torch.nn.modules.module.Module], train_set: torch.utils.data.dataset.Dataset[typing.Any], test_set: torch.utils.data.dataset.Dataset[typing.Any], aggregator: type[krum.primitives.aggregators.Aggregator] | None = None, aggregator_kwargs: dict[str, typing.Any] | None = None, attack: type[krum.primitives.attacks.Attack] | None = None, attack_kwargs: dict[str, typing.Any] | None = None, n: int, f: int, rounds: int, batch_size: int, lr: float, lr_schedule: typing.Literal['exponential', 'robbins_monro', 'none'] = 'exponential', lr_decay: float | None = 0.99, r_eta: float | None = None, weight_decay: float = 0.0, xavier_init: bool = False, stop_attack_at: int | None = None, aggregator_f: int | None = None, loss_fn: typing.Callable[[...], torch.Tensor] = <function cross_entropy>, device: torch.device | None = None, seed: int = 42, eval_every: int = 10, evaluate_fn: typing.Callable[[krum.simulations.centralised.base.CentralisedSimulation], typing.Any] | None = None)

Bases: object

Parameter-server distributed SGD simulation with Byzantine workers.

One instance corresponds to one (aggregator, attack, dataset, model) configuration, run over the number of synchronous rounds set by the rounds parameter. The training set is IID-sharded across \(n\) workers, of which \(f\) are Byzantine; \(f\) must not exceed the Byzantine tolerance of the chosen aggregator.

Evaluation follows composition over inheritance: the caller supplies an evaluate_fn callable (e.g. one of the built-in evaluate_test_error_and_loss() or evaluate_full()) that receives the simulation instance and returns protocol-specific metrics.
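A minimal sketch of this delegation pattern, assuming nothing about the library beyond what this page documents. ToySimulation and error_and_accuracy are hypothetical stand-ins: the point is only that the evaluator is a plain callable taking the simulation instance, so in this sketch an unbound method such as ToySimulation.evaluate_test_error works as well as a custom function.

```python
from typing import Any, Callable, Optional


class ToySimulation:
    """Hypothetical stand-in showing evaluation by composition."""

    def __init__(
        self,
        error: float,
        evaluate_fn: Optional[Callable[["ToySimulation"], Any]] = None,
    ) -> None:
        self._error = error
        self.evaluate_fn = evaluate_fn

    def evaluate_test_error(self) -> float:
        # A real simulation would run the model on the test set here.
        return self._error

    def evaluate(self) -> Any:
        # This sketch simply refuses to evaluate without an evaluator.
        if self.evaluate_fn is None:
            raise RuntimeError("no evaluate_fn was supplied")
        # Delegate: the evaluator receives the simulation instance itself.
        return self.evaluate_fn(self)


def error_and_accuracy(sim: ToySimulation) -> dict[str, float]:
    """Custom evaluator that derives accuracy from the built-in error metric."""
    err = sim.evaluate_test_error()
    return {"test_error": err, "test_accuracy": 1.0 - err}
```

Swapping metrics then means passing a different callable at construction time, with no subclass per protocol.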

Parameters:
  • model_cls – nn.Module subclass to instantiate for training.

  • train_set – Full training dataset (will be IID-sharded across workers).

  • test_set – Test dataset (evaluated via a full-batch loader).

  • aggregator – Gradient aggregation rule class (e.g. Average, Krum). Pass the class itself, not an instance: step() calls aggregator.aggregate(gradients, n=n, f=f, **aggregator_kwargs).

  • aggregator_kwargs – Extra keyword arguments forwarded to aggregator.aggregate (e.g. {"m": 12} for MultiKrum). n and f are automatically injected by the simulation.

  • attack – Byzantine attack strategy class (e.g. GaussianAttack). Pass the class itself, not an instance: step() calls attack.generate(honest_gradients, f=f, **attack_kwargs).

  • attack_kwargs – Extra keyword arguments forwarded to attack.generate (e.g. {"std": 200.0} for GaussianAttack).

  • n – Total number of workers.

  • f – Number of Byzantine workers. Must be within the aggregator’s Byzantine tolerance.

  • rounds – Number of synchronous training rounds.

  • batch_size – Mini-batch size per honest worker.

  • lr – Initial learning rate for SGD. Used as \(η_0\) by every supported scheduler.

  • lr_schedule – Learning-rate schedule applied after each round. "exponential" uses multiplicative decay lr_decay per round (the default for the ICML 2018 protocol). "robbins_monro" uses the \(η(t) = r_η · η_0 / (t + r_η)\) schedule of El Mhamdi et al. (ICML 2018), with fading rate r_eta. "none" keeps a constant learning rate (the default for the NIPS 2017 protocol).

  • lr_decay – Multiplicative learning-rate decay per round, used when lr_schedule == "exponential". None disables the scheduler in that mode. Default: 0.99.

  • r_eta – Fading rate for the Robbins-Monro schedule \(η(t) = r_η · η_0 / (t + r_η)\). Required when lr_schedule == "robbins_monro".

  • weight_decay – \(ℓ_2\) regularization coefficient applied directly to the flat parameter tensor. Set to 0.0 (default) to disable. Section 5.1 of El Mhamdi et al. (ICML 2018) recommends 1e-4.

  • xavier_init – When True, apply Glorot/Xavier uniform initialization to every weight tensor of the model after construction, and zero-initialize biases. Matches Section 5.1 of El Mhamdi et al. (ICML 2018).

  • stop_attack_at – If set, the Byzantine attack is disabled at round t = stop_attack_at: the f Byzantine workers then send zero gradients. Used by Experiment 1 / Figure 2 of El Mhamdi et al. (ICML 2018), where the attack is maintained only up to round 50.

  • loss_fn – Per-sample loss function. Default: cross_entropy.

  • device – Device for training and evaluation. Auto-detected if None (CUDA → MPS → CPU).

  • seed – Random seed for reproducibility.

  • eval_every – Evaluate on the test set every eval_every rounds (and always on the last round).

  • evaluate_fn – Callable that receives the simulation instance and returns evaluation metrics. Built-in options include evaluate_test_error(), evaluate_test_error_and_loss(), and evaluate_full(). Custom evaluators must accept a single CentralisedSimulation argument.
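To make the aggregator calling convention concrete, here is a sketch of the classic Krum rule (Blanchard et al., NIPS 2017) written against the documented aggregate(gradients, n=n, f=f, **aggregator_kwargs) call. It is an illustration under assumptions, not the library's Aggregator: gradients are lists of floats, aggregate is a staticmethod so that passing the class itself works, and the class name KrumSketch is hypothetical.

```python
from typing import Sequence

Vector = list[float]


def _sq_dist(a: Vector, b: Vector) -> float:
    """Squared Euclidean distance between two flat gradients."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


class KrumSketch:
    """Hypothetical aggregator following the aggregate(gradients, n=..., f=...) convention."""

    @staticmethod
    def aggregate(gradients: Sequence[Vector], *, n: int, f: int) -> Vector:
        if len(gradients) != n:
            raise ValueError("expected exactly n gradients")
        if n <= 2 * f + 2:
            raise ValueError("Krum requires n > 2f + 2")
        k = n - f - 2  # number of nearest neighbours scored per candidate
        best_idx, best_score = 0, float("inf")
        for i, g in enumerate(gradients):
            dists = sorted(_sq_dist(g, h) for j, h in enumerate(gradients) if j != i)
            score = sum(dists[:k])  # sum over the k closest other gradients
            if score < best_score:
                best_idx, best_score = i, score
        # Krum returns the selected worker's gradient unchanged.
        return list(gradients[best_idx])
```

Because n and f are injected by the simulation, only rule-specific options (such as m for MultiKrum) would need to go through aggregator_kwargs.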

Raises:
  • RuntimeError – If run() is called more than once on the same instance.

  • ValueError – If lr_schedule is not one of the supported values, or a required hyperparameter is missing (r_eta when lr_schedule == "robbins_monro").
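The three schedules described under lr_schedule, together with the ValueError conditions above, can be summarised as one pure function. This is a sketch under assumptions: the name lr_at_round is hypothetical, rounds are counted from t = 0, and the exponential schedule is written in closed form \(η_0 · \text{lr\_decay}^t\), which is what applying self._current_lr *= lr_decay once per round would produce if the decay starts after round 0.

```python
from typing import Literal, Optional

Schedule = Literal["exponential", "robbins_monro", "none"]


def lr_at_round(
    t: int,
    lr: float,
    lr_schedule: Schedule = "exponential",
    lr_decay: Optional[float] = 0.99,
    r_eta: Optional[float] = None,
) -> float:
    """Learning rate used at round t (t = 0, 1, ...), following the documented formulas."""
    if lr_schedule == "exponential":
        # lr_decay=None disables decay in this mode.
        return lr if lr_decay is None else lr * lr_decay**t
    if lr_schedule == "robbins_monro":
        if r_eta is None:
            raise ValueError("r_eta is required for the 'robbins_monro' schedule")
        # eta(t) = r_eta * eta_0 / (t + r_eta); equals eta_0 at t = 0.
        return r_eta * lr / (t + r_eta)
    if lr_schedule == "none":
        return lr
    raise ValueError(f"unknown lr_schedule: {lr_schedule!r}")
```

With r_eta = 10, the Robbins-Monro rate halves by round 10, whereas lr_decay = 0.99 shrinks the exponential rate by about 10% over the same span.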

evaluate() → Any

Compute evaluation metrics by delegating to evaluate_fn.

The evaluator callable was supplied at construction time. Built-in options include evaluate_test_error(), evaluate_test_error_and_loss(), and evaluate_full().

Returns:

Scalar, tuple, or dict of evaluation metrics, as defined by the evaluator.

evaluate_full() → dict[str, float]

Compute training loss and test accuracy/loss on full datasets.

Returns:

Dict with keys "train_loss", "test_accuracy", and "test_loss".

Raises:

RuntimeError – If setup() has not been called.

evaluate_test_error() → float

Compute misclassification error rate on the held-out test set.

Returns:

Error rate in [0, 1].

Raises:

RuntimeError – If setup() has not been called.

evaluate_test_error_and_loss() → dict[str, float]

Compute misclassification error rate and cross-entropy loss on the test set.

Returns:

Dict with keys "test_error" and "test_loss".

Raises:

RuntimeError – If setup() has not been called.
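For reference, the two quantities reported by evaluate_test_error_and_loss() can be computed from raw logits as below. This is a dependency-free sketch, not the library's code; it assumes integer class labels and mean-reduced cross-entropy (the default reduction of torch.nn.functional.cross_entropy), and the helper name error_and_loss is hypothetical.

```python
import math
from typing import Sequence


def error_and_loss(logits: Sequence[Sequence[float]], labels: Sequence[int]) -> dict[str, float]:
    """Misclassification rate in [0, 1] and mean cross-entropy over a batch."""
    if len(logits) != len(labels) or not labels:
        raise ValueError("need a non-empty batch with one label per row")
    wrong = 0
    total_loss = 0.0
    for row, y in zip(logits, labels):
        # Prediction = argmax of the logits (first index wins on ties).
        pred = max(range(len(row)), key=lambda c: row[c])
        wrong += int(pred != y)
        # Numerically stable log-sum-exp: subtract the max before exponentiating.
        m = max(row)
        log_z = m + math.log(sum(math.exp(v - m) for v in row))
        total_loss += log_z - row[y]  # -log softmax(row)[y]
    n = len(labels)
    return {"test_error": wrong / n, "test_loss": total_loss / n}
```

Evaluating on the full test set in one batch, as described for test_set, makes both values exact averages over every test sample.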

property model: Model

The encapsulated Model, available after setup() or run().

Returns:

The Model wrapper around the trained nn.Module, exposing zero-copy flat parameter/gradient views.

Raises:

RuntimeError – If the simulation has not been set up yet.

run() → list[tuple[int, Any]]

Run the simulation to completion.

Calls setup(), then loops over step() and evaluate() every eval_every rounds (always evaluating on the last round).

Returns:

List of (round, …) tuples, where the second element of each tuple is the return value of evaluate() (a scalar, tuple, or dict).

Raises:

RuntimeError – If run() has already been called.
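The control flow of run() can be sketched as below. Several details are assumptions not confirmed by this page: rounds are numbered 1 to rounds, a round is evaluated when its number is a multiple of eval_every, and the guard against a second call is a boolean flag. RunLoopSketch is hypothetical, and its step() and evaluate() are placeholders.

```python
from typing import Any


class RunLoopSketch:
    """Hypothetical skeleton of run(): setup once, step every round, evaluate periodically."""

    def __init__(self, rounds: int, eval_every: int = 10) -> None:
        self.rounds = rounds
        self.eval_every = eval_every
        self._has_run = False
        self.t = 0

    def setup(self) -> None:
        self.t = 0  # reset all per-run state

    def step(self) -> None:
        self.t += 1  # placeholder for one synchronous training round

    def evaluate(self) -> Any:
        return f"metrics@{self.t}"  # placeholder for self.evaluate_fn(self)

    def run(self) -> list[tuple[int, Any]]:
        if self._has_run:
            raise RuntimeError("run() has already been called on this instance")
        self._has_run = True
        self.setup()
        history: list[tuple[int, Any]] = []
        for r in range(1, self.rounds + 1):
            self.step()
            # Evaluate every eval_every rounds and always on the final round.
            if r % self.eval_every == 0 or r == self.rounds:
                history.append((r, self.evaluate()))
        return history
```

With rounds = 25 and eval_every = 10, this sketch evaluates at rounds 10, 20, and 25.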

setup() → None

Initialise the model, learning rate, IID data shards, and dataloaders.

The training set is evenly split into n shards (the remaining len(train_set) % n samples are dropped). Each worker receives a dedicated DataLoader with its own random generator, so mini-batch sampling is reproducible across runs.
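The sharding rule above (equal shards, remainder dropped) can be sketched with the standard library. This is an assumption-laden illustration: the indices are shuffled once with a seeded random.Random before splitting, which may differ from how the library permutes the data, and the helper name iid_shards is hypothetical.

```python
import random


def iid_shards(num_samples: int, n: int, seed: int = 42) -> list[list[int]]:
    """Split sample indices into n equal IID shards, dropping num_samples % n leftovers."""
    if n <= 0:
        raise ValueError("n must be positive")
    indices = list(range(num_samples))
    random.Random(seed).shuffle(indices)  # seeded shuffle, so shards are reproducible
    shard_size = num_samples // n
    return [indices[i * shard_size:(i + 1) * shard_size] for i in range(n)]
```

In PyTorch, per-worker reproducible mini-batch order is commonly obtained by passing a seeded torch.Generator to each DataLoader through its generator argument.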

The learning rate is initialised to self.lr. The Robbins-Monro schedule updates it each round inside step(); the exponential schedule decays it after every round via self._current_lr *= lr_decay.

When xavier_init is enabled, every weight tensor of the instantiated model is re-initialized with the Glorot/Xavier uniform rule, and every bias is reset to zero.
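Glorot/Xavier uniform initialization draws each weight from \(U(-a, a)\) with \(a = \text{gain} · \sqrt{6 / (\text{fan\_in} + \text{fan\_out})}\). The dependency-free sketch below, assuming gain = 1 and a single fully connected layer, only shows the bound and the bias reset; its helper names are hypothetical.

```python
import math
import random


def xavier_uniform_bound(fan_in: int, fan_out: int, gain: float = 1.0) -> float:
    """Half-width a of the Glorot/Xavier uniform distribution U(-a, a)."""
    return gain * math.sqrt(6.0 / (fan_in + fan_out))


def init_linear(fan_in: int, fan_out: int, seed: int = 0) -> tuple[list[list[float]], list[float]]:
    """Xavier-uniform weights (fan_out x fan_in, as in a dense layer) and zero biases."""
    rng = random.Random(seed)
    a = xavier_uniform_bound(fan_in, fan_out)
    weight = [[rng.uniform(-a, a) for _ in range(fan_in)] for _ in range(fan_out)]
    bias = [0.0] * fan_out
    return weight, bias
```

In PyTorch, the per-tensor equivalents are torch.nn.init.xavier_uniform_(weight) and torch.nn.init.zeros_(bias).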

Safe to call multiple times; each call resets all internal state.

step() → None

Advance the simulation by one synchronous round.

  1. The learning rate is updated to its value for the current round (Robbins-Monro schedule only; the exponential schedule updates the rate after the SGD update instead).

  2. Each of the \(n - f\) honest workers computes a gradient on its local data shard via _train_one_worker().

  3. If \(f > 0\) and the attack has not been stopped (see stop_attack_at), the \(f\) Byzantine workers generate attack gradients; once the attack is stopped, they send zero gradients instead. For FullGradientNegationAttack, the full-dataset honest gradient is computed first.

  4. The aggregator combines all \(n\) gradients into a single update via self.aggregator.aggregate(...).

  5. The aggregated gradient is written to self._model.gradients and applied via an in-place SGD update on the flat parameter tensor.

  6. If lr_schedule == "exponential" and lr_decay is not None, the learning rate is decayed (self._current_lr *= lr_decay).

Raises:

RuntimeError – If setup() has not been called.
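Steps 3 to 5 can be sketched on flat parameter and gradient vectors as follows. Several details are assumptions rather than documented behaviour: the attack is treated as stopped once t >= stop_attack_at, weight decay is added to the aggregated gradient as weight_decay · θ before the update, and the names sgd_round, aggregate_fn, and attack_fn are hypothetical.

```python
from typing import Callable, Optional, Sequence

Vector = list[float]


def sgd_round(
    t: int,
    params: Vector,
    honest_grads: Sequence[Vector],
    f: int,
    lr: float,
    aggregate_fn: Callable[..., Vector],
    attack_fn: Callable[..., list[Vector]],
    stop_attack_at: Optional[int] = None,
    weight_decay: float = 0.0,
) -> Vector:
    """One round on flat vectors: attack (or zeros), aggregation, l2 term, SGD update."""
    dim = len(params)
    attack_active = stop_attack_at is None or t < stop_attack_at
    if f > 0 and attack_active:
        byzantine = attack_fn(honest_grads, f=f)
    else:
        # Once the attack is stopped, the f Byzantine workers send zero gradients.
        byzantine = [[0.0] * dim for _ in range(f)]
    gradients = list(honest_grads) + byzantine
    update = aggregate_fn(gradients, n=len(gradients), f=f)
    # Assumed l2 handling: add weight_decay * theta to the aggregated gradient.
    update = [u + weight_decay * p for u, p in zip(update, params)]
    # In-place-style SGD update on the flat parameter vector.
    return [p - lr * u for p, u in zip(params, update)]
```

Note that zero Byzantine gradients still count towards \(n\), so with an averaging rule they shrink the update by a factor \((n - f)/n\) rather than disappearing.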