PyTorch Utilities

Helper functions for PyTorch tensor manipulation and gradient handling.

This module collects tensor-manipulation helpers, gradient utilities, and other common operations used throughout Krum.

Functions

Memory Management:

  • relink: Make tensors point to a contiguous memory segment

  • flatten: Flatten tensors into a single contiguous tensor

Gradient Operations:

  • grad_of: Get or create gradient for a tensor

  • grads_of: Generator version for multiple tensors

Statistics:

  • compute_avg_dev_max: Compute mean, std, norm stats

Time Measurement:

  • AccumulatedTimedContext: Accumulated timing with optional CUDA sync

Utilities:

  • weighted_mse_loss: Weighted MSE loss for experiments

  • regression: Generic optimization for free variables

  • pnm: Export tensor to PGM/PBM format

Example:

import torch
from tools import flatten, grads_of

# Flatten model parameters into a single contiguous tensor
params = list(model.parameters())
flat_params = flatten(params)

# Gather (or create) the gradients and flatten them; the flat tensor
# shares memory with the original gradient tensors
grads = list(grads_of(params))
flat_grads = flatten(grads)
class tools.pytorch.AccumulatedTimedContext(sync: bool | float = False)[source]

Bases: object

Accumulated timed context manager with optional CUDA synchronization.

This context manager measures elapsed time across multiple entries, with optional CUDA synchronization to ensure accurate GPU timing.

Parameters:
  • sync (bool or float, optional) – Whether to synchronize CUDA before and after timing. Defaults to False.

Example:

>>> import torch
>>> from tools import AccumulatedTimedContext
>>> atc = AccumulatedTimedContext(sync=True)
>>> with atc:
...     # GPU operations here
...     pass
>>> print(atc.current_runtime())
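
Since the context accumulates across entries, re-entering it adds to the running total; a minimal sketch:

>>> atc = AccumulatedTimedContext()
>>> for _ in range(3):
...     with atc:
...         pass  # each entry adds its elapsed time to the total
>>> total = atc.current_runtime()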

current_runtime() float[source]

Return the accumulated runtime.

Returns:

float

Total accumulated time in seconds.

class tools.pytorch.WeightedMSELoss[source]

Bases: Module

Weighted MSE loss module.

This module wraps weighted_mse_loss() as a PyTorch module.

forward(input: Tensor, target: Tensor, weight: Tensor) Tensor[source]

Compute weighted MSE loss.

Parameters:
  • input (torch.Tensor) – Input tensor.

  • target (torch.Tensor) – Target tensor.

  • weight (torch.Tensor) – Weight tensor.

Returns:

torch.Tensor

Weighted MSE loss value.
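
A minimal usage sketch; the module is called like any torch.nn.Module, and the import path follows the qualified name above:

>>> import torch
>>> from tools.pytorch import WeightedMSELoss
>>> criterion = WeightedMSELoss()
>>> loss = criterion(torch.randn(4), torch.zeros(4), torch.ones(4))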

tools.pytorch.compute_avg_dev_max(samples: list[Tensor]) tuple[Tensor | None, float, float, float][source]

Compute average, average norm, norm deviation, and max absolute value.

Parameters:
  • samples (list of torch.Tensor) – List of tensors to compute statistics on.

Returns:

tuple[torch.Tensor | None, float, float, float]

Tuple containing: the average tensor, the average norm, the norm deviation, and the max absolute value.

Notes:

The returned tensor is newly created and does not alias any input tensor.
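
A minimal usage sketch (import path assumed to match the other examples in this module; the returned values depend on the random samples):

>>> import torch
>>> from tools import compute_avg_dev_max
>>> samples = [torch.randn(10) for _ in range(5)]
>>> avg, avg_norm, norm_dev, max_abs = compute_avg_dev_max(samples)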

tools.pytorch.flatten(tensors: list[Tensor]) Tensor[source]

Flatten tensors into a single contiguous tensor.

Parameters:
  • tensors (list of torch.Tensor) – Tensors to flatten. All must have the same dtype.

Returns:

torch.Tensor

Flat tensor containing all data from the input tensors, stored in a contiguous memory segment.

Notes:

The returned tensor shares memory with the original tensors. Modifications to the flat tensor will reflect in the original tensors.

Example:

>>> import torch
>>> from tools import flatten
>>> t1 = torch.tensor([1., 2.])
>>> t2 = torch.tensor([3., 4., 5.])
>>> flat = flatten([t1, t2])
>>> flat
tensor([1., 2., 3., 4., 5.])

tools.pytorch.grad_of(tensor: Tensor) Tensor[source]

Get the gradient of a given tensor, create zero gradient if missing.

Parameters:
  • tensor (torch.Tensor) – A tensor that may have a gradient attached.

Returns:

torch.Tensor

The gradient tensor. If none existed, a zero gradient is created and attached to the tensor.

Example:

>>> import torch
>>> from tools import grad_of
>>> x = torch.randn(3, requires_grad=True)
>>> y = x.sum()
>>> y.backward()
>>> grad = grad_of(x)

tools.pytorch.grads_of(tensors: list[Tensor])[source]

Generator that gets or creates gradients for multiple tensors.

Parameters:
  • tensors (list of torch.Tensor) – Tensors that may have gradients attached.

Yields:

torch.Tensor

Gradient for each tensor.

Example:

>>> import torch
>>> from tools import grads_of
>>> params = [torch.randn(3, requires_grad=True) for _ in range(2)]
>>> loss = sum(p.sum() for p in params)
>>> loss.backward()
>>> for g in grads_of(params):
...     print(g)
tensor([1., 1., 1.])
tensor([1., 1., 1.])

tools.pytorch.pnm(fd: BufferedWriter, tn: Tensor) None[source]

Export tensor to PGM/PBM format.

Parameters:
  • fd (io.BufferedWriter) – File descriptor to write to.

  • tn (torch.Tensor) – Tensor to export. Supports float32/float64 for grayscale (PGM) or boolean/integer for binary (PBM).

Notes:

  • Grayscale format (PGM): used for float32/float64 tensors.

  • Binary format (PBM): used for boolean/integer tensors.
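
A minimal usage sketch, assuming pnm is importable from tools like the helpers above; opening the file in binary mode provides the io.BufferedWriter the function expects:

>>> import torch
>>> from tools import pnm
>>> image = torch.rand(64, 64)  # float tensor, exported as grayscale PGM
>>> with open("image.pgm", "wb") as fd:
...     pnm(fd, image)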

tools.pytorch.regression(func: Callable[[Tensor, dict], Tensor], vars, data, loss=None, opt=None, steps=1000) float[source]

Generic optimization for free variables.

Parameters:
  • func (callable) – Function to optimize. Takes variables and data dictionary as arguments.

  • vars (list) – List of variables to optimize.

  • data (dict) – Data dictionary, must contain a "target" key.

  • loss (torch.nn.Module, optional) – Loss function. Defaults to torch.nn.MSELoss.

  • opt (torch.optim.Optimizer, optional) – Optimizer. Defaults to torch.optim.Adam.

  • steps (int, optional) – Number of optimization steps. Defaults to 1000.

Returns:

float

Final loss value after optimization.
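
A minimal sketch fitting a single free variable to a linear target. It assumes func is called as func(var, data), as the type hint suggests, and that regression is importable from tools; check the implementation for the exact calling convention:

>>> import torch
>>> from tools import regression
>>> x = torch.linspace(0., 1., 10)
>>> data = {"x": x, "target": 3. * x}
>>> a = torch.zeros(1, requires_grad=True)
>>> final_loss = regression(lambda var, data: var * data["x"], [a], data)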

tools.pytorch.relink(tensors: list[Tensor], common: Tensor) Tensor[source]

Relink tensors to share a common contiguous memory storage.

Parameters:
  • tensors (list of torch.Tensor) – Tensors to relink. All must have the same dtype.

  • common (torch.Tensor) – Flat tensor of sufficient size to use as underlying storage. Must have the same dtype as the given tensors.

Returns:

torch.Tensor

The common tensor, with the linked_tensors attribute set.

Notes:

The returned tensor has a linked_tensors attribute pointing to the original tensors. This allows updating all tensors simultaneously.

Example:

>>> import torch
>>> from tools import relink
>>> t1 = torch.tensor([1., 2.])
>>> t2 = torch.tensor([3., 4., 5.])
>>> common = torch.zeros(5)
>>> relink([t1, t2], common)
tensor([1., 2., 3., 4., 5.])

tools.pytorch.weighted_mse_loss(input: Tensor, target: Tensor, weight: Tensor) Tensor[source]

Compute weighted mean squared error loss.

Parameters:
  • input (torch.Tensor) – Input tensor.

  • target (torch.Tensor) – Target tensor.

  • weight (torch.Tensor) – Weight tensor for each element.

Returns:

torch.Tensor

Weighted MSE loss value.

Notes:

The returned tensor is newly created and does not alias any input tensor.
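
A minimal usage sketch, assuming elementwise weighting of the squared errors and the same import style as the examples above; the exact reduction is defined by the implementation:

>>> import torch
>>> from tools import weighted_mse_loss
>>> input = torch.tensor([1., 2., 3.])
>>> target = torch.tensor([1., 1., 1.])
>>> weight = torch.tensor([1.0, 0.5, 0.25])
>>> loss = weighted_mse_loss(input, target, weight)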

See also

For general utilities (parsing, timing, registries), see Miscellaneous Utilities. For experiment job management, see Jobs Utilities.