PyTorch Utilities¶
Helper functions for PyTorch tensor manipulation and gradient handling.
This module provides helper functions for PyTorch tensor manipulation, gradient handling, and common operations used throughout Krum.
Functions¶
Memory Management:
- relink: Make tensors point to a contiguous memory segment
- flatten: Flatten tensors into a single contiguous tensor
Gradient Operations:
- grad_of: Get or create gradient for a tensor
- grads_of: Generator version for multiple tensors
Statistics:
- compute_avg_dev_max: Compute mean, std, and norm statistics
Time Measurement:
- AccumulatedTimedContext: Accumulated timing with optional CUDA sync
Utilities:
- weighted_mse_loss: Weighted MSE loss for experiments
- regression: Generic optimization for free variables
- pnm: Export tensor to PGM/PBM format
Example:¶
import torch
from tools import flatten, relink
# Flatten model parameters
params = list(model.parameters())
flat_params = flatten(params)
# Flatten gradients into a single contiguous tensor
grads = [p.grad for p in params]
flat_grads = flatten(grads)
- class tools.pytorch.AccumulatedTimedContext(sync: bool | float = False)[source]¶
Bases: object

Accumulated timed context manager with optional CUDA synchronization.
This context manager measures elapsed time across multiple entries, with optional CUDA synchronization to ensure accurate GPU timing.
- Parameters:
sync (bool, optional) – Whether to synchronize CUDA before and after timing. Defaults to False.

Example:
>>> import torch
>>> from tools import AccumulatedTimedContext
>>> atc = AccumulatedTimedContext(sync=True)
>>> with atc:
...     # GPU operations here
...     pass
>>> print(atc.current_runtime())
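The accumulation behavior can be sketched as a plain context manager. The class and method names below (`AccumulatedTimer`, `current_runtime`) mirror the documented API but this is a simplified illustration, not the library's source; the real class can additionally synchronize CUDA around each timed region when `sync=True`.

```python
import time

class AccumulatedTimer:
    """Simplified sketch of an accumulating timed context.

    The real AccumulatedTimedContext can also call torch.cuda.synchronize()
    before and after each region when sync=True, so that queued GPU kernels
    are included in the measurement.
    """

    def __init__(self):
        self._total = 0.0  # accumulated seconds across all entries

    def __enter__(self):
        self._start = time.perf_counter()
        return self

    def __exit__(self, *exc):
        # Add this region's elapsed time to the running total
        self._total += time.perf_counter() - self._start
        return False  # do not suppress exceptions

    def current_runtime(self):
        return self._total
```

Entering the same instance several times keeps adding to the total, which is what makes the context "accumulated" rather than a one-shot stopwatch.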
- class tools.pytorch.WeightedMSELoss[source]¶
Bases: Module

Weighted MSE loss module.

This module wraps weighted_mse_loss() as a PyTorch module.

- forward(input: Tensor, target: Tensor, weight: Tensor) Tensor[source]¶
Compute weighted MSE loss.
- Parameters:
input (torch.Tensor) – Input tensor.
target (torch.Tensor) – Target tensor.
weight (torch.Tensor) – Weight tensor.
- Returns:
torch.Tensor – Weighted MSE loss value.
- tools.pytorch.compute_avg_dev_max(samples: list[Tensor]) tuple[Tensor | None, float, float, float][source]¶
Compute average, average norm, norm deviation, and max absolute value.
- Parameters:
samples (list of torch.Tensor) – List of tensors to compute statistics on.
- Returns:
tuple[torch.Tensor | None, float, float, float] – Tuple containing: the average tensor, the average norm, the norm deviation, and the max absolute value.
Notes:
The returned tensor is newly created and does not alias any input tensor.
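The four statistics can be sketched as an illustrative reimplementation under assumed definitions: the element-wise average of the samples, the mean L2 norm, the standard deviation of the norms, and the maximum absolute entry. The helper name `avg_dev_max_sketch` is hypothetical and the real function may define the statistics slightly differently.

```python
import torch

def avg_dev_max_sketch(samples):
    """Hypothetical sketch of compute_avg_dev_max's statistics."""
    stacked = torch.stack(samples)          # shape: (n, *sample_shape)
    average = stacked.mean(dim=0)           # element-wise average tensor
    norms = torch.stack([s.norm() for s in samples])
    avg_norm = norms.mean().item()          # average L2 norm
    norm_dev = norms.std().item()           # deviation of the norms
    max_abs = stacked.abs().max().item()    # max absolute value
    return average, avg_norm, norm_dev, max_abs
```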
- tools.pytorch.flatten(tensors: list[Tensor]) Tensor[source]¶
Flatten tensors into a single contiguous tensor.
- Parameters:
tensors (list of torch.Tensor) – Tensors to flatten. All must have the same dtype.
- Returns:
torch.Tensor – Flat tensor containing all data from input tensors, stored in a contiguous memory segment.
Notes:
The returned tensor shares memory with the original tensors. Modifications to the flat tensor will reflect in the original tensors.
Example:
>>> import torch
>>> from tools import flatten
>>> t1 = torch.tensor([1., 2.])
>>> t2 = torch.tensor([3., 4., 5.])
>>> flat = flatten([t1, t2])
>>> flat
tensor([1., 2., 3., 4., 5.])
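A copy-based approximation of the flattening step is a single `torch.cat` over flattened views. This is a sketch only: unlike the documented flatten, `torch.cat` copies its inputs, so the result does not share memory with the original tensors until they are relinked to it.

```python
import torch

def flatten_sketch(tensors):
    """Concatenate flattened views into one contiguous tensor (copies data)."""
    return torch.cat([t.reshape(-1) for t in tensors])
```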
- tools.pytorch.grad_of(tensor: Tensor) Tensor[source]¶
Get the gradient of a given tensor, create zero gradient if missing.
- Parameters:
tensor (torch.Tensor) – A tensor that may have a gradient attached.
- Returns:
torch.Tensor – The gradient tensor. If none existed, a zero gradient is created and attached to the tensor.
Example:
>>> import torch
>>> from tools import grad_of
>>> x = torch.randn(3, requires_grad=True)
>>> y = x.sum()
>>> y.backward()
>>> grad = grad_of(x)
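The get-or-create behavior can be sketched in a few lines (an illustrative reimplementation under the documented contract, not the module's source):

```python
import torch

def grad_of_sketch(tensor):
    """Return tensor.grad, creating and attaching a zero gradient if absent."""
    if tensor.grad is None:
        tensor.grad = torch.zeros_like(tensor)
    return tensor.grad
```

Creating the zero gradient eagerly means callers can always read or accumulate into `.grad` without checking for `None` first.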
- tools.pytorch.grads_of(tensors: list[Tensor])[source]¶
Generator that gets or creates gradients for multiple tensors.
- Parameters:
tensors (list of torch.Tensor) – Tensors that may have gradients attached.
- Yields:
torch.Tensor – Gradient for each tensor.
Example:
>>> import torch
>>> from tools import grads_of
>>> params = [torch.randn(3, requires_grad=True) for _ in range(2)]
>>> loss = sum(p.sum() for p in params)
>>> loss.backward()
>>> for g in grads_of(params):
...     print(g)
tensor([1., 1., 1.])
tensor([1., 1., 1.])
- tools.pytorch.pnm(fd: BufferedWriter, tn: Tensor) None[source]¶
Export tensor to PGM/PBM format.
- Parameters:
fd (io.BufferedWriter) – File descriptor to write to.
tn (torch.Tensor) – Tensor to export. Supports float32/float64 for grayscale (PGM) or boolean/integer for binary (PBM).
Notes:
- Grayscale format (PGM): used for float32/float64 tensors.
- Binary format (PBM): used for boolean/integer tensors.
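As an illustration of the grayscale path, a minimal binary PGM (P5) writer for a 2-D float tensor might look like the following. This is a sketch under assumed conventions (values clamped to [0, 1] and scaled to 8 bits); the real pnm may use a different scaling and also handles the PBM case, and `pgm_sketch` is a hypothetical name.

```python
import torch

def pgm_sketch(fd, tn):
    """Write a 2-D float tensor as a binary PGM (P5) image."""
    h, w = tn.shape
    # Assumed convention: clamp to [0, 1], scale to 8-bit grayscale
    data = (tn.clamp(0.0, 1.0) * 255).round().to(torch.uint8)
    fd.write(b"P5\n%d %d\n255\n" % (w, h))       # PGM header: magic, size, maxval
    fd.write(bytes(data.flatten().tolist()))     # raster, row-major
```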
- tools.pytorch.regression(func: Callable[[Tensor, dict], Tensor], vars, data, loss=None, opt=None, steps=1000) float[source]¶
Generic optimization for free variables.
- Parameters:
func (callable) – Function to optimize. Takes variables and data dictionary as arguments.
vars (list) – List of variables to optimize.
data (dict) – Data dictionary, must contain a "target" key.
loss (torch.nn.Module, optional) – Loss function. Defaults to torch.nn.MSELoss.
opt (torch.optim.Optimizer, optional) – Optimizer. Defaults to torch.optim.Adam.
steps (int, optional) – Number of optimization steps. Defaults to 1000.
- Returns:
float – Final loss value after optimization.
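The optimization loop being described can be sketched as follows, using the documented defaults (MSELoss and Adam). The calling convention `func(vars, data)` and the helper name `regression_sketch` are assumptions for illustration; the real function's argument handling may differ.

```python
import torch

def regression_sketch(func, vars, data, loss=None, opt=None, steps=1000):
    """Fit free variables so func(vars, data) matches data["target"]."""
    loss_fn = loss if loss is not None else torch.nn.MSELoss()
    optimizer = opt if opt is not None else torch.optim.Adam(vars)
    for _ in range(steps):
        optimizer.zero_grad()
        value = loss_fn(func(vars, data), data["target"])
        value.backward()
        optimizer.step()
    return value.item()  # final loss after the last step
```

For instance, fitting a single scalar `a` so that `a * x` matches a target drives the loss toward zero and `a` toward the true coefficient.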
- tools.pytorch.relink(tensors: list[Tensor], common: Tensor) Tensor[source]¶
Relink tensors to share a common contiguous memory storage.
- Parameters:
tensors (list of torch.Tensor) – Tensors to relink. All must have the same dtype.
common (torch.Tensor) – Flat tensor of sufficient size to use as underlying storage. Must have the same dtype as the given tensors.
- Returns:
torch.Tensor – The common tensor, with linked_tensors attribute set.

Notes:
The returned tensor has a linked_tensors attribute pointing to the original tensors. This allows updating all tensors simultaneously.
Example:
>>> import torch
>>> from tools import relink
>>> t1 = torch.tensor([1., 2.])
>>> t2 = torch.tensor([3., 4., 5.])
>>> common = torch.zeros(5)
>>> relink([t1, t2], common)
tensor([1., 2., 3., 4., 5.])
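The relinking itself can be sketched by pointing each tensor's data at a slice of the common buffer. This is a simplified sketch under assumed semantics (values are copied into the buffer first, then each tensor is rebound to view it); `relink_sketch` is a hypothetical name.

```python
import torch

def relink_sketch(tensors, common):
    """Copy each tensor into `common` and make it a view of that storage."""
    offset = 0
    for t in tensors:
        n = t.numel()
        chunk = common[offset:offset + n]
        chunk.copy_(t.detach().reshape(-1))  # preserve current values
        t.data = chunk.view_as(t)            # t now views common's storage
        offset += n
    common.linked_tensors = tensors          # mirror the documented attribute
    return common
```

After relinking, writing through the flat buffer updates every linked tensor at once, which is what makes bulk operations like averaging or clipping cheap.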
- tools.pytorch.weighted_mse_loss(input: Tensor, target: Tensor, weight: Tensor) Tensor[source]¶
Compute weighted mean squared error loss.
- Parameters:
input (torch.Tensor) – Input tensor.
target (torch.Tensor) – Target tensor.
weight (torch.Tensor) – Weight tensor for each element.
- Returns:
torch.Tensor – Weighted MSE loss value.
Notes:
The returned tensor is newly created and does not alias any input tensor.
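The weighting being described reduces to scaling each squared error before averaging. The one-liner below is a sketch under the assumed definition mean(weight * (input - target)^2); the actual reduction used by weighted_mse_loss may differ, and `weighted_mse_sketch` is a hypothetical name.

```python
import torch

def weighted_mse_sketch(input, target, weight):
    """Mean of element-wise weighted squared errors (new tensor, no aliasing)."""
    return (weight * (input - target) ** 2).mean()
```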
See also
For general utilities (parsing, timing, registries), see Miscellaneous Utilities. For experiment job management, see Jobs Utilities.