Nearest Neighbor Average
Nearest-neighbor averaging aggregation rule.
- Reference:
Sadegh Farhadkhani, Rachid Guerraoui, Nirupam Gupta, Lê Nguyên Hoang, Rafael Pinot, and John Stephan. “Robust Collaborative Learning with Linear Gradient Overhead.” In Proceedings of the 40th International Conference on Machine Learning (ICML 2023).
- class aggregators.nearest_neighbor_average.NearestNeighborAverage
Bases: Aggregator
Nearest-neighbor averaging aggregation rule.
The rule keeps the num_closest vectors with the smallest Euclidean distance to the pivot, then returns their mean. Both num_closest and the pivot belong to a single aggregation call, so they are passed to aggregate() rather than stored on the (stateless) aggregator.
How many vectors to keep is a caller policy (e.g. n - f for plain nearest-neighbor averaging, n - 2f for MoNNA-style model mixing); this rule only needs the resulting count, not n or f.
- classmethod aggregate(gradients: Sequence[Tensor] | Tensor, /, out: Tensor | None = None, *, num_closest: int, pivot: Tensor, **specialized: Any) -> Tensor
Average the num_closest vectors nearest to the pivot.
- Parameters:
gradients – Sequence of m candidate vectors, each of shape (d,).
out – Optional pre-allocated tensor to write the result into.
num_closest – Number of nearest vectors to average.
pivot – Tensor of shape (d,) used as the distance reference.
**specialized – Additional keyword arguments.
- Returns:
Mean of the num_closest closest vectors, shape (d,).
- Raises:
ValueError – If num_closest is not positive, if fewer than num_closest candidates are supplied, or if the pivot does not have the candidates' shape (d,).
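The selection-then-mean behaviour described above can be illustrated with a minimal standalone sketch. It does not call the library: it uses plain Python lists instead of tensors, and the function name nearest_neighbor_average, its error messages, and the tie-breaking by input order are assumptions made for illustration only.

```python
import math
from typing import List, Sequence


def nearest_neighbor_average(
    vectors: Sequence[Sequence[float]],
    num_closest: int,
    pivot: Sequence[float],
) -> List[float]:
    """Hypothetical pure-Python stand-in for the rule (not the library code)."""
    # Mirror the documented ValueError conditions.
    if num_closest <= 0:
        raise ValueError("num_closest must be positive")
    if len(vectors) < num_closest:
        raise ValueError("fewer than num_closest candidates supplied")
    d = len(pivot)
    if any(len(v) != d for v in vectors):
        raise ValueError("pivot shape does not match the candidates")

    # Rank candidates by Euclidean distance to the pivot and keep the closest.
    # sorted() is stable, so ties keep input order (an assumption of this sketch).
    ranked = sorted(vectors, key=lambda v: math.dist(v, pivot))
    kept = ranked[:num_closest]

    # Coordinate-wise mean of the kept vectors.
    return [sum(column) / num_closest for column in zip(*kept)]


candidates = [[0.0, 0.0], [1.0, 1.0], [2.0, 2.0], [100.0, 100.0]]
result = nearest_neighbor_average(candidates, num_closest=3, pivot=[1.0, 1.0])
print(result)  # [1.0, 1.0]
```

The outlier at (100, 100) is the farthest candidate from the pivot, so it is dropped before averaging.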
See also
This rule is used by the MoNNA simulation as the per-worker nearest-neighbor averaging step.
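As a rough illustration of a per-worker step, the sketch below applies nearest-neighbor averaging once per honest worker. It is not the simulation's code: the helper names average_of_closest and mixing_round are hypothetical, plain lists stand in for tensors, and the choice of each worker's own vector as the pivot is an assumption. The count n - 2f follows the MoNNA-style policy mentioned in the class description and is computed by the caller, outside the rule.

```python
import math
from typing import List, Sequence

Vector = Sequence[float]


def average_of_closest(vectors: Sequence[Vector], num_closest: int, pivot: Vector) -> List[float]:
    """Hypothetical helper: mean of the num_closest vectors nearest to the pivot."""
    kept = sorted(vectors, key=lambda v: math.dist(v, pivot))[:num_closest]
    return [sum(column) / num_closest for column in zip(*kept)]


def mixing_round(honest: Sequence[Vector], received: Sequence[Vector], n: int, f: int) -> List[List[float]]:
    """Hypothetical per-worker step: every honest worker averages around a pivot."""
    # Caller policy: the rule only sees the resulting count, never n or f.
    num_closest = n - 2 * f
    # Assumption: each honest worker uses its own vector as the pivot.
    return [average_of_closest(received, num_closest, pivot=own) for own in honest]


honest = [[0.0], [1.0], [2.0], [3.0]]
byzantine = [[50.0]]  # one faulty vector, so f = 1 out of n = 5
mixed = mixing_round(honest, honest + byzantine, n=5, f=1)
print(mixed)  # [[1.0], [1.0], [2.0], [2.0]]
```

In this toy run the faulty vector is never among the three closest to any honest pivot, so it does not influence the mixed values.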