Tools

Utility modules for Krum providing infrastructure, tensor operations, and job management.

Overview

Tools

Module | Description | Key Functions
tools | Core utilities (logging, exceptions, parsing) | Context, UserException, parse_keyval
tools.pytorch | PyTorch helpers (tensor operations, gradients) | flatten, relink, compute_avg_dev_max
tools.misc | Miscellaneous utilities (registries, timing) | pairwise, line_maximize, fullqual
tools.jobs | Job management for experiments | Command, Jobs, dict_to_cmdlist

Available Tools

API Reference

Core Utility Module for Krum.

This module provides the fundamental infrastructure utilities used throughout Krum, including logging, error handling, and common operations.

Key Components

Exceptions:

  • UserException: Base exception for user-facing errors

  • Context: Thread-local context for colored logging

Logging:

  • info(), success(), warning(), error(): Colored logging functions

  • fatal(): Print error and exit

I/O:

  • ContextIOWrapper: Wrapper for stdout/stderr with context prefixing

Module Loading:

  • import_directory(): Load all Python modules from a directory

  • import_exported_symbols(): Import symbols from a module

Utilities:

  • parse_keyval(): Parse key:value CLI arguments

  • fullqual(): Get fully qualified name of objects

  • onetime(): Thread-safe one-time flag

class tools.AccumulatedTimedContext(sync: bool | float = False)[source]

Bases: object

Accumulated timed context manager with optional CUDA synchronization.

This context manager measures elapsed time across multiple entries, with optional CUDA synchronization to ensure accurate GPU timing.

Parameters:
  • sync (bool, optional) – Whether to synchronize CUDA before and after timing. Defaults to False.

Example:

>>> import torch
>>> from tools import AccumulatedTimedContext
>>> atc = AccumulatedTimedContext(sync=True)
>>> with atc:
...     # GPU operations here
...     pass
>>> print(atc.current_runtime())

current_runtime() float[source]

Return the accumulated runtime.

Returns:

float

Total accumulated time in seconds.
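The accumulate-across-entries behaviour described above can be illustrated with a minimal CPU-only sketch. `AccumulatedTimerSketch` is a hypothetical stand-in, not the Krum class: it relies on `time.perf_counter` and leaves out the optional CUDA synchronization.

```python
import time


class AccumulatedTimerSketch:
    """Hypothetical stand-in: sums the elapsed time of every `with` block."""

    def __init__(self):
        self._total = 0.0   # accumulated seconds over all entries
        self._start = None  # start of the block currently being timed

    def __enter__(self):
        self._start = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc, tb):
        # Add this block's duration to the running total
        self._total += time.perf_counter() - self._start
        self._start = None
        return False  # never swallow exceptions

    def current_runtime(self):
        return self._total


timer = AccumulatedTimerSketch()
for _ in range(2):
    with timer:
        time.sleep(0.01)
print(f"accumulated: {timer.current_runtime():.3f} s")
```

Reusing one instance across several `with` blocks is what makes the total cumulative; a fresh instance starts again from zero.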

class tools.ClassRegister(singular: str, optplural: str | None = None)[source]

Bases: object

Minimal registry mapping user-visible names to classes.

instantiate(name: str, *args, **kwargs) object[source]

Instantiate the class registered under name.

Parameters:
  • name (str) – Registered class name.

  • *args (object) – Positional arguments forwarded to the class constructor.

  • **kwargs (object) – Keyword arguments forwarded to the class constructor.

Returns:

object

Instance of the registered class.

Raises:

UserException

If name is not registered.

itemize() list[str][source]

Return the registered class names.

register(name: str, cls: type) None[source]

Register a class under a user-visible name.

Parameters:
  • name (str) – Name used to retrieve the class.

  • cls (type) – Class associated with name.
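A registry of this kind is essentially a dictionary from names to classes. The sketch below mirrors the three documented methods; `ClassRegisterSketch`, the local `UserException` stand-in, the default plural, and the error message wording are all assumptions made for illustration.

```python
class UserException(Exception):
    """Local stand-in for tools.UserException."""


class ClassRegisterSketch:
    """Hypothetical minimal registry mapping user-visible names to classes."""

    def __init__(self, singular, optplural=None):
        self._singular = singular
        # Assumption: the plural defaults to the singular plus "s"
        self._plural = optplural if optplural is not None else singular + "s"
        self._register = {}

    def register(self, name, cls):
        self._register[name] = cls

    def itemize(self):
        return list(self._register)

    def instantiate(self, name, *args, **kwargs):
        cls = self._register.get(name)
        if cls is None:
            known = ", ".join(self._register) or "none"
            raise UserException(
                f"Unknown {self._singular} {name!r}; available {self._plural}: {known}"
            )
        return cls(*args, **kwargs)


rules = ClassRegisterSketch("rule")
rules.register("list", list)
rules.register("dict", dict)
print(rules.itemize())
print(rules.instantiate("list", "ab"))
```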

class tools.Command(command: Iterable[str])[source]

Bases: object

Command wrapper that adds standard runtime arguments.

This class wraps a base command and automatically appends seed, device, and result-directory arguments when building the final command line.

Parameters:
  • command (iterable of str) – Base command as an iterable of strings (e.g. ["python", "train.py"]). The iterable is copied on instantiation.

Attributes:
  • _basecmd (list of str) – Internal copy of the base command.

build(seed: int | str, device: str, resdir: Path | str) list[str][source]

Build the final command line.

Parameters:
  • seed (int or str) – Seed to use for the experiment.

  • device (str) – Device on which to run the experiment (e.g. "cuda:0").

  • resdir (pathlib.Path or str) – Target directory path for results.

Returns:

list of str

Final command list ready to be passed to subprocess.run.
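To make the copy-then-append behaviour concrete, here is a small sketch. The flag names `--seed`, `--device`, and `--result-directory` are hypothetical, since this page does not state which flags the real `Command.build` emits, and `CommandSketch` is not the Krum class.

```python
class CommandSketch:
    """Hypothetical stand-in for the documented Command wrapper."""

    def __init__(self, command):
        # Copy the iterable so later changes by the caller do not leak in
        self._basecmd = list(command)

    def build(self, seed, device, resdir):
        cmd = list(self._basecmd)
        # Assumed flag names; the real ones are not given in this reference
        cmd += ["--seed", str(seed)]
        cmd += ["--device", str(device)]
        cmd += ["--result-directory", str(resdir)]
        return cmd


base = ["python", "train.py"]
command = CommandSketch(base)
base.append("--oops")  # does not affect the wrapped copy
print(command.build(1, "cuda:0", "results/run-1"))
```

Because every element is converted with `str`, the resulting list can be handed to `subprocess.run` without further processing.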

class tools.Context(cntxtname: str | None, colorname: str | None)[source]

Bases: object

Per-thread logging context and color manager.

class tools.ContextIOWrapper(output: TextIO, nocolor: bool | None = None)[source]

Bases: object

Context-aware text I/O wrapper.

write(text: str) int[source]

Write text with the active context prefix and color.

Parameters:
  • text (str) – Text to write.

Returns:

int

Return value forwarded from the wrapped stream’s write method.

class tools.Jobs(res_dir: Path | str, devices: list[str] | None = None, devmult: int = 1, seeds: Sequence[int] | None = None)[source]

Bases: object

Job execution manager for parallel experiments.

Manages parallel execution of experiments across multiple devices, with support for result tracking and error handling.

Parameters:
  • res_dir (pathlib.Path or str) – Directory to store results.

  • devices (list of str, optional) – List of device names (e.g. ["cuda:0", "cuda:1"]). Defaults to ["cpu"] if none specified.

  • devmult (int, optional) – Number of parallel jobs per device. Default is 1.

  • seeds (sequence of int, optional) – Seeds to use for repeating experiments. Default is range(1, 6).

Attributes:
  • _res_dir (pathlib.Path) – Resolved result directory.

  • _jobs (list of tuple or None) – Pending job queue as (name, seed, command) tuples, or None when the manager has been closed.

  • _workers (list of threading.Thread) – Worker thread pool, one entry per active slot.

  • _devices (list of str) – Devices used for execution.

  • _seeds (tuple of int) – Seeds used for repeating experiments.

  • _lock (threading.Lock) – Main lock protecting shared state.

  • _cvready (threading.Condition) – Condition variable to signal that new jobs are available or that workers must shut down.

  • _cvdone (threading.Condition) – Condition variable to signal that all submitted jobs have been processed.

close() None[source]

Close and wait for the worker pool, discarding not-yet-started submissions.

get_seeds() tuple[int, ...][source]

Get the seeds used for repeating the experiments.

Returns:

tuple of int

Seeds used by this manager.

submit(name: str, command: Command) None[source]

Submit a job for execution.

The job is repeated for every seed configured in the manager.

Parameters:
  • name (str) – Job identifier.

  • command (Command) – Command builder to execute.

Raises:

RuntimeError

If the manager has already been closed.

wait(predicate: Callable[[], bool] | None = None) None[source]

Wait for all the submitted jobs to be processed.

Parameters:

predicate (callable returning bool, optional) – Optional custom predicate. If provided, waiting stops when the predicate returns True.
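The interplay between the job queue, the lock, and the two condition variables can be sketched as below. This is a toy model under stated assumptions, not the Krum implementation: `MiniJobs`, `results`, and `_pending` are hypothetical names, commands are plain callables taking `(seed, device)` rather than `Command` objects launched as processes, and `wait` takes no predicate.

```python
import threading


class MiniJobs:
    """Toy model of a seed-repeating job manager (hypothetical)."""

    def __init__(self, devices=None, devmult=1, seeds=None):
        self._devices = list(devices) if devices else ["cpu"]
        self._seeds = tuple(seeds) if seeds is not None else tuple(range(1, 6))
        self._jobs = []     # pending (name, seed, command); None once closed
        self._pending = 0   # submitted jobs that have not finished yet
        self.results = []   # (name, seed, outcome) of finished jobs
        self._lock = threading.Lock()
        self._cvready = threading.Condition(self._lock)
        self._cvdone = threading.Condition(self._lock)
        self._workers = []
        for device in self._devices:
            for _ in range(devmult):  # `devmult` worker slots per device
                worker = threading.Thread(target=self._work, args=(device,), daemon=True)
                worker.start()
                self._workers.append(worker)

    def _work(self, device):
        while True:
            with self._lock:
                # Sleep until a job is available or the manager is closed
                while self._jobs is not None and not self._jobs:
                    self._cvready.wait()
                if self._jobs is None:
                    return
                name, seed, command = self._jobs.pop(0)
            outcome = command(seed, device)  # run outside the lock
            with self._lock:
                self.results.append((name, seed, outcome))
                self._pending -= 1
                if self._pending == 0:
                    self._cvdone.notify_all()

    def submit(self, name, command):
        with self._lock:
            if self._jobs is None:
                raise RuntimeError("manager already closed")
            for seed in self._seeds:  # one job per configured seed
                self._jobs.append((name, seed, command))
                self._pending += 1
            self._cvready.notify_all()

    def wait(self):
        with self._lock:
            while self._pending > 0:
                self._cvdone.wait()

    def close(self):
        with self._lock:
            self._jobs = None  # discards submissions that have not started
            self._cvready.notify_all()
        for worker in self._workers:
            worker.join()


jobs = MiniJobs(devices=["cpu"], devmult=2, seeds=[1, 2, 3])
jobs.submit("demo", lambda seed, device: seed * 10)
jobs.wait()
jobs.close()
print(sorted(outcome for _, _, outcome in jobs.results))
```

Sharing one lock between both condition variables keeps the queue and the pending counter consistent: a worker never observes a job that `submit` has only half recorded.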

class tools.MethodCallReplicator(*args: object)[source]

Bases: object

Proxy that replicates method calls across multiple instances.

Accessing an attribute returns a callable that invokes the same named attribute on each bound instance, in order, and returns the list of results.
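The attribute-forwarding behaviour maps naturally onto `__getattr__`. The following is a minimal sketch of that idea; `ReplicatorSketch` is a hypothetical name and the real class may differ in its details.

```python
class ReplicatorSketch:
    """Hypothetical proxy that fans a method call out to several instances."""

    def __init__(self, *instances):
        self._instances = instances

    def __getattr__(self, name):
        # Only reached for attributes not found on the proxy itself
        def call(*args, **kwargs):
            return [getattr(inst, name)(*args, **kwargs) for inst in self._instances]

        return call


both = ReplicatorSketch("ab", "cd")
print(both.upper())       # one result per bound instance, in order
print(both.replace("a", "x"))
```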

class tools.TimedContext(*args, **kwargs)[source]

Bases: Context

Context manager that logs the elapsed runtime of a block.

exception tools.UnavailableException(*args, **kwargs)[source]

Bases: UserException

User-facing exception raised when a selected registry entry is missing.

exception tools.UserException[source]

Bases: Exception

Base exception for user-facing errors.

class tools.WeightedMSELoss[source]

Bases: Module

Weighted MSE loss module.

This module wraps weighted_mse_loss() as a PyTorch module.

forward(input: Tensor, target: Tensor, weight: Tensor) Tensor[source]

Compute weighted MSE loss.

Parameters:
  • input (torch.Tensor) – Input tensor.

  • target (torch.Tensor) – Target tensor.

  • weight (torch.Tensor) – Weight tensor.

Returns:

torch.Tensor

Weighted MSE loss value.

tools.compute_avg_dev_max(samples: list[Tensor]) tuple[Tensor | None, float, float, float][source]

Compute average, average norm, norm deviation, and max absolute value.

Parameters:
  • samples (list of torch.Tensor) – List of tensors to compute statistics on.

Returns:

tuple[torch.Tensor, float, float, float]

Tuple containing: average tensor, average norm, norm deviation, and max absolute value.

Notes:

The returned tensor is newly created and does not alias any input tensor.

tools.deltatime_format(a: int, b: int) tuple[int, str][source]

Compute and format elapsed time between two captured points.

Parameters:
  • a (int) – Earlier point returned by deltatime_point().

  • b (int) – Later point returned by deltatime_point().

Returns:

tuple[int, str]

Tuple (seconds, text) containing elapsed seconds and a human-readable duration string.

tools.deltatime_point() int[source]

Capture an opaque point in monotonic time.

Returns:

int

Monotonic timestamp rounded to seconds. The value is intended for use with deltatime_format().
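The two timing helpers are meant to be used as a pair. The sketch below shows one way they could fit together; the exact wording of the duration string is not specified on this page, so the `"1 h 2 min 5 s"` style is an assumption, as are the `_sketch` function names.

```python
import time


def deltatime_point_sketch():
    # Monotonic clock, rounded to whole seconds
    return int(round(time.monotonic()))


def deltatime_format_sketch(a, b):
    seconds = b - a
    hours, rest = divmod(seconds, 3600)
    minutes, secs = divmod(rest, 60)
    parts = []
    if hours:
        parts.append(f"{hours} h")
    if minutes:
        parts.append(f"{minutes} min")
    if secs or not parts:
        parts.append(f"{secs} s")  # always print something, even for 0 s
    return seconds, " ".join(parts)


begin = deltatime_point_sketch()
end = deltatime_point_sketch()
print(deltatime_format_sketch(begin, end))
print(deltatime_format_sketch(0, 3725))
```

Using a monotonic clock rather than wall-clock time keeps the difference non-negative even if the system clock is adjusted between the two points.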

tools.dict_to_cmdlist(dp: dict[str, Any]) list[str][source]

Convert a dictionary into command-line arguments.

This helper is useful for turning experiment configurations into CLI arguments.

Parameters:
  • dp (dict of str to Any) – Dictionary mapping parameter names to values.

Returns:

list of str

Command-line arguments such as ["--lr", "0.01", "--batch", "32"].

Notes:

  • Boolean values are included only when they are True.

  • Lists and tuples expand to repeated --name value pairs.

Example:

>>> dict_to_cmdlist({"lr": …})
['--lr', '0.01', '--batch', '32', '--debug']
>>> dict_to_cmdlist({"layers": …})
['--layers', '64', '--layers', '128']
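The two rules from the notes (booleans appear only when true, sequences repeat the flag) can be written out as a short sketch. `dict_to_cmdlist_sketch` is a hypothetical re-implementation based only on the behaviour documented here, and the sample dictionaries are illustrative inputs rather than the ones from the original docstring.

```python
def dict_to_cmdlist_sketch(dp):
    """Hypothetical re-implementation of the documented conversion rules."""
    cmd = []
    for name, value in dp.items():
        flag = f"--{name}"
        if isinstance(value, bool):
            if value:  # True becomes a bare flag, False is dropped
                cmd.append(flag)
        elif isinstance(value, (list, tuple)):
            for item in value:  # repeat the flag once per element
                cmd += [flag, str(item)]
        else:
            cmd += [flag, str(value)]
    return cmd


print(dict_to_cmdlist_sketch({"lr": 0.01, "batch": 32, "debug": True, "quiet": False}))
print(dict_to_cmdlist_sketch({"layers": [64, 128]}))
```

The boolean check comes before the generic branch on purpose: without it, `True` would be stringified into a `"True"` argument instead of a bare flag.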

tools.error(*args, context: str | None = None, **kwargs) object[source]

Print inside the configured colored context.

Parameters:
  • *args (object) – Positional arguments forwarded to print().

  • context (str or None, optional) – Context name to use while printing.

  • **kwargs (object) – Keyword arguments forwarded to print().

Returns:

object

Return value forwarded from print().

tools.fatal(*args, with_traceback: bool = False, **kwargs) None[source]

Print an error message and terminate the process with exit code 1.

Parameters:
  • *args (object) – Positional arguments forwarded to error().

  • with_traceback (bool, optional) – Whether to include the current traceback after the message.

  • **kwargs (object) – Keyword arguments forwarded to error().

tools.fatal_unavailable(*args, **kwargs) None[source]

Report an unavailable entry as a fatal user-facing error.

Parameters:
  • *args (str) – Positional arguments forwarded to make_unavailable_exception_text().

  • **kwargs (str) – Keyword arguments forwarded to make_unavailable_exception_text().

tools.flatten(tensors: list[Tensor]) Tensor[source]

Flatten tensors into a single contiguous tensor.

Parameters:
  • tensors (list of torch.Tensor) – Tensors to flatten. All must have the same dtype.

Returns:

torch.Tensor

Flat tensor containing all data from input tensors, stored in a contiguous memory segment.

Notes:

The returned tensor shares memory with the original tensors. Modifications to the flat tensor are reflected in the original tensors.

Example:

>>> import torch
>>> from tools import flatten
>>> t1 = torch.tensor([1., 2.])
>>> t2 = torch.tensor([3., 4., 5.])
>>> flat = flatten([t1, t2])
>>> flat
tensor([1., 2., 3., 4., 5.])

tools.fullqual(obj: object) str[source]

Return a class or instance’s fully qualified name.

Parameters:
  • obj (object) – Class or instance to describe.

Returns:

str

Fully qualified class name. Instances are prefixed with "instance of ".

Example:

>>> fullqual(str)
'builtins.str'
>>> fullqual(pathlib.Path("."))
'instance of pathlib.PosixPath'
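A fully qualified name of this form can be assembled from `__module__` and `__qualname__`. The sketch below reproduces the documented output format; `fullqual_sketch` is a hypothetical name and the real function may treat corner cases differently.

```python
from fractions import Fraction


def fullqual_sketch(obj):
    """Hypothetical re-implementation of the documented naming format."""
    if isinstance(obj, type):
        prefix, cls = "", obj
    else:
        # Instances are described through their class, with a prefix
        prefix, cls = "instance of ", type(obj)
    return f"{prefix}{cls.__module__}.{cls.__qualname__}"


print(fullqual_sketch(str))
print(fullqual_sketch(Fraction(1, 2)))
```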

tools.get_loaded_dependencies() list[tuple[str, str | None, int]][source]

List currently loaded non-built-in root modules.

Returns:

list[tuple[str, str | None, int]]

Tuples of (root_module_name, version, flavor). version is the module’s __version__ attribute when present, otherwise None. flavor is one of IS_STANDARD, IS_SITE, or IS_LOCAL.

Raises:

RuntimeError

If Python’s site-packages locations cannot be discovered on the current platform.

tools.grad_of(tensor: Tensor) Tensor[source]

Get the gradient of a given tensor, creating a zero gradient if it is missing.

Parameters:
  • tensor (torch.Tensor) – A tensor that may have a gradient attached.

Returns:

torch.Tensor

The gradient tensor. If none existed, a zero gradient is created and attached to the tensor.

Example:

>>> import torch
>>> from tools import grad_of
>>> x = torch.randn(3, requires_grad=True)
>>> y = x.sum()
>>> y.backward()
>>> grad = grad_of(x)

tools.grads_of(tensors: list[Tensor])[source]

Generator that gets or creates gradients for multiple tensors.

Parameters:
  • tensors (list of torch.Tensor) – Tensors that may have gradients attached.

Yields:

torch.Tensor

Gradient for each tensor.

Example:

>>> import torch
>>> from tools import grads_of
>>> params = [torch.randn(3, requires_grad=True) for _ in range(2)]
>>> loss = sum(p.sum() for p in params)
>>> loss.backward()
>>> for g in grads_of(params):
...     print(g)
tensor([1., 1., 1.])
tensor([1., 1., 1.])

tools.import_directory(dirpath: ~pathlib.Path, scope: dict, post: ~typing.Callable[[...], ~typing.Any] | None = <function import_exported_symbols>, ignore: list[str] | None = None) None[source]

Import every Python module from a directory into a target scope.

Parameters:
  • dirpath (pathlib.Path) – Directory containing modules to import.

  • scope (dict) – Target scope used for imports and post-processing.

  • post (object, optional) – Post-import callback with signature (name, module, scope) -> None.

  • ignore (list[str], optional) – Module names to ignore.

tools.import_exported_symbols(name: str, module, scope: dict) None[source]

Import a module’s exported symbols into a target scope.

Parameters:
  • name (str) – Source module name.

  • module (module) – Loaded module instance.

  • scope (dict) – Target scope to update with exported symbols.

tools.info(*args, context: str | None = None, **kwargs) object[source]

Print inside the configured colored context.

Parameters:
  • *args (object) – Positional arguments forwarded to print().

  • context (str or None, optional) – Context name to use while printing.

  • **kwargs (object) – Keyword arguments forwarded to print().

Returns:

object

Return value forwarded from print().

tools.interactive(glbs: dict[str, object] | None = None, lcls: dict[str, object] | None = None, prompt: str = '>>> ', cprmpt: str = '... ') None[source]

Run a small interactive Python prompt.

Press Ctrl+D or send an equivalent EOF signal to leave the prompt.

Parameters:
  • glbs (dict[str, object] | None, optional) – Globals dictionary used when evaluating commands. If None, the caller’s globals are used when available.

  • lcls (dict[str, object] | None, optional) – Locals dictionary used when evaluating commands. If None, the caller’s locals are used when available, otherwise glbs is used.

  • prompt (str, optional) – Prompt displayed for a new command.

  • cprmpt (str, optional) – Prompt displayed while continuing a multi-line command.

tools.line_maximize(scape: Callable[[...], Any], evals: int = 16, start: float = 0.0, delta: float = 1.0, ratio: float = 0.8) float[source]

Best-effort argmax search for a scalar function on non-negative inputs.

The search first expands while values improve, then contracts the step size to refine the best point found within the evaluation budget.

Parameters:
  • scape (callable) – Function to maximize. It is called with non-negative float values and must return comparable scores.

  • evals (int, optional) – Maximum number of function evaluations.

  • start (float, optional) – Initial non-negative point to evaluate.

  • delta (float, optional) – Initial positive step size.

  • ratio (float, optional) – Step contraction ratio, expected to be between 0.5 and 1.0 excluded.

Returns:

float

Best point found under the evaluation budget.
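The expand-then-contract strategy can be sketched as follows. This is one plausible reading of the description, not the Krum implementation: `line_maximize_sketch` is a hypothetical name, the expansion phase walks forward with a fixed step, and the contraction phase probes both neighbours of the best point after each shrink.

```python
def line_maximize_sketch(scape, evals=16, start=0.0, delta=1.0, ratio=0.8):
    """Hypothetical expand-then-contract search on non-negative inputs."""
    best_x = start
    best_y = scape(best_x)
    evals -= 1
    # Expansion: step forward by `delta` while the score keeps improving
    while evals > 0:
        x = best_x + delta
        y = scape(x)
        evals -= 1
        if y <= best_y:
            break
        best_x, best_y = x, y
    # Contraction: shrink the step, then probe both sides of the best point
    while evals > 0:
        delta *= ratio
        for x in (best_x - delta, best_x + delta):
            if x < 0.0 or evals <= 0:
                continue  # stay non-negative and within the budget
            y = scape(x)
            evals -= 1
            if y > best_y:
                best_x, best_y = x, y
    return best_x


# A concave score with its maximum at x = 2.5
best = line_maximize_sketch(lambda x: -(x - 2.5) ** 2, evals=64)
print(round(best, 2))
```

In this sketch, for a symmetric single-peaked score, a ratio of at least 0.5 keeps the peak within half a step of the best point after every contraction, which is consistent with the documented range for `ratio`.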

tools.localtime() str[source]

Return the current local time formatted for logs.

Returns:

str

Local time as YYYY/MM/DD HH:MM:SS.
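The stated layout corresponds to a `strftime` pattern. A minimal sketch, assuming the standard `time` module is used (`localtime_sketch` is a hypothetical name):

```python
import time


def localtime_sketch():
    # YYYY/MM/DD HH:MM:SS in the local time zone
    return time.strftime("%Y/%m/%d %H:%M:%S", time.localtime())


print(localtime_sketch())
```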

tools.onetime(name: str | None = None) tuple[Callable[[...], Any], Callable[[...], Any]][source]

Create or retrieve a thread-safe one-shot flag.

Parameters:
  • name (str | None, optional) – Optional global flag name. Reusing the same name returns the same getter/setter pair.

Returns:

tuple[callable, callable]

(getter, setter) pair. getter returns whether the flag has been set, and setter permanently sets it to True.
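A one-shot flag with an optional global name can be sketched with a lock-protected closure and a small registry. The names `onetime_sketch`, `_registry`, and `_registry_lock` are hypothetical, and the real function may organise its state differently.

```python
import threading

_registry = {}                     # name -> (getter, setter)
_registry_lock = threading.Lock()  # protects the registry itself


def _make_flag():
    lock = threading.Lock()
    state = {"set": False}

    def getter():
        with lock:
            return state["set"]

    def setter():
        with lock:
            state["set"] = True  # there is no way to reset the flag

    return getter, setter


def onetime_sketch(name=None):
    if name is None:
        return _make_flag()  # anonymous flag, private to the caller
    with _registry_lock:
        if name not in _registry:
            _registry[name] = _make_flag()
        return _registry[name]


is_warned, set_warned = onetime_sketch("deprecation-warning")
if not is_warned():
    print("shown only once")
    set_warned()
print(is_warned())
```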

tools.pairwise(data: list | tuple)[source]

Yield unordered pairs from an indexable collection.

Parameters:
  • data (list | tuple) – Indexable collection such as a list or tuple.

Yields:

tuple

Tuples (data[i], data[j]) for every i < j.

Example:

>>> list(pairwise([1, 2, 3]))
[(1, 2), (1, 3), (2, 3)]
>>> list(pairwise("ab"))
[('a', 'b')]
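The i < j enumeration can be written as two nested index loops. A minimal sketch (`pairwise_sketch` is a hypothetical name; the standard library's `itertools.combinations(data, 2)` yields the same pairs):

```python
from itertools import combinations


def pairwise_sketch(data):
    """Yield (data[i], data[j]) for every i < j."""
    for i in range(len(data) - 1):
        for j in range(i + 1, len(data)):
            yield data[i], data[j]


print(list(pairwise_sketch([1, 2, 3])))
print(list(pairwise_sketch("ab")) == list(combinations("ab", 2)))
```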

tools.parse_keyval(list_keyval: list[str], defaults: dict[str, object] | None = None) dict[str, object][source]

Parse <key>:<value> strings into a typed dictionary.

This helper is used for command-line options such as --gar-args lr:0.01. Keys present in defaults are converted to the type of their default value; other keys are converted by parse_keyval_auto_convert().

Parameters:
  • list_keyval (list[str]) – Entries formatted as <key>:<value>.

  • defaults (dict[str, object] | None, optional) – Default key/value mappings. These defaults are also used for type inference and are copied into the returned dictionary when the corresponding key is not explicitly provided.

Returns:

dict[str, object]

Parsed key/value pairs with converted values.

Raises:

UserException

If an entry is malformed, a key is provided more than once, or conversion to a default value’s type fails.

Example:

>>> parse_keyval(["lr:…])
{'lr': 0.01, 'batch': 32}
>>> parse_keyval(["debug:…])
{'debug': True, 'workers': 4}
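The parsing rules above (split on the first colon, reject duplicates, convert by default type, fall back to automatic conversion) can be sketched as follows. Everything here is a hypothetical re-implementation: the conversion order in `auto_convert` (bool, int, float, string) and the accepted boolean spellings are assumptions, since `parse_keyval_auto_convert()` is not described on this page.

```python
class UserException(Exception):
    """Local stand-in for tools.UserException."""


def _to_bool(text):
    lowered = text.lower()
    if lowered in ("true", "false"):  # assumed spellings
        return lowered == "true"
    raise ValueError(f"not a boolean: {text!r}")


def auto_convert(text):
    # Assumed order: bool, then int, then float, else keep the string
    for cast in (_to_bool, int, float):
        try:
            return cast(text)
        except ValueError:
            pass
    return text


def parse_keyval_sketch(list_keyval, defaults=None):
    defaults = defaults or {}
    parsed = {}
    for entry in list_keyval:
        key, sep, value = entry.partition(":")
        if not sep or not key:
            raise UserException(f"Expected <key>:<value>, got {entry!r}")
        if key in parsed:
            raise UserException(f"Key {key!r} provided more than once")
        if key in defaults:
            default = defaults[key]
            # bool("false") would be True, so booleans need their own parser
            cast = _to_bool if isinstance(default, bool) else type(default)
            try:
                parsed[key] = cast(value)
            except (TypeError, ValueError) as err:
                raise UserException(f"Cannot convert {value!r} for key {key!r}") from err
        else:
            parsed[key] = auto_convert(value)
    for key, default in defaults.items():
        parsed.setdefault(key, default)  # copy defaults that were not overridden
    return parsed


print(parse_keyval_sketch(["lr:0.01", "batch:32"]))
print(parse_keyval_sketch(["lr:1"], {"lr": 0.1, "momentum": 0.9}))
```

In the second call the value `"1"` is converted to a float because the default for `lr` is a float, which is the type-inference behaviour the description refers to.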

tools.pnm(fd: BufferedWriter, tn: Tensor) None[source]

Export tensor to PGM/PBM format.

Parameters:
  • fd (io.BufferedWriter) – File descriptor to write to.

  • tn (torch.Tensor) – Tensor to export. Supports float32/float64 for grayscale (PGM) or boolean/integer for binary (PBM).

Notes:

  • Grayscale format (PGM)

  • Binary format (PBM)

tools.regression(func: Callable[[Tensor, dict], Tensor], vars, data, loss=None, opt=None, steps=1000) float[source]

Generic optimization for free variables.

Parameters:
  • func (callable) – Function to optimize. Takes variables and data dictionary as arguments.

  • vars (list) – List of variables to optimize.

  • data (dict) – Data dictionary, must contain a "target" key.

  • loss (torch.nn.Module, optional) – Loss function. Defaults to torch.nn.MSELoss.

  • opt (torch.optim.Optimizer, optional) – Optimizer. Defaults to torch.optim.Adam.

  • steps (int, optional) – Number of optimization steps. Defaults to 1000.

Returns:

float

Final loss value after optimization.

tools.relink(tensors: list[Tensor], common: Tensor) Tensor[source]

Relink tensors to share a common contiguous memory storage.

Parameters:
  • tensors (list of torch.Tensor) – Tensors to relink. All must have the same dtype.

  • common (torch.Tensor) – Flat tensor of sufficient size to use as underlying storage. Must have the same dtype as the given tensors.

Returns:

torch.Tensor

The common tensor, with linked_tensors attribute set.

Notes:

The returned tensor has a linked_tensors attribute pointing to the original tensors. This allows updating all tensors simultaneously.

Example:

>>> import torch
>>> from tools import relink
>>> t1 = torch.tensor([1., 2.])
>>> t2 = torch.tensor([3., 4., 5.])
>>> common = torch.zeros(5)
>>> relink([t1, t2], common)
tensor([1., 2., 3., 4., 5.])

tools.success(*args, context: str | None = None, **kwargs) object[source]

Print inside the configured colored context.

Parameters:
  • *args (object) – Positional arguments forwarded to print().

  • context (str or None, optional) – Context name to use while printing.

  • **kwargs (object) – Keyword arguments forwarded to print().

Returns:

object

Return value forwarded from print().

tools.trace(*args, context: str | None = None, **kwargs) object[source]

Print inside the configured colored context.

Parameters:
  • *args (object) – Positional arguments forwarded to print().

  • context (str or None, optional) – Context name to use while printing.

  • **kwargs (object) – Keyword arguments forwarded to print().

Returns:

object

Return value forwarded from print().

tools.uncaught_wrap(hook: Callable[[...], Any]) Callable[[...], Any][source]

Wrap an uncaught exception hook with contextual logging.

Parameters:
  • hook (object) – Uncaught exception hook to wrap.

Returns:

object

Wrapped uncaught exception hook.

tools.warning(*args, context: str | None = None, **kwargs) object[source]

Print inside the configured colored context.

Parameters:
  • *args (object) – Positional arguments forwarded to print().

  • context (str or None, optional) – Context name to use while printing.

  • **kwargs (object) – Keyword arguments forwarded to print().

Returns:

object

Return value forwarded from print().

tools.weighted_mse_loss(input: Tensor, target: Tensor, weight: Tensor) Tensor[source]

Compute weighted mean squared error loss.

Parameters:
  • input (torch.Tensor) – Input tensor.

  • target (torch.Tensor) – Target tensor.

  • weight (torch.Tensor) – Weight tensor for each element.

Returns:

torch.Tensor

Weighted MSE loss value.

Notes:

The returned tensor is newly created and does not alias any input tensor.