"""This module contains utility functions related to the weighted_tensor module."""
from typing import Tuple, TypeVar, Union
import torch
from ._weighted_tensor import TensorOrWeightedTensor, WeightedTensor
__all__ = [
"expand_left",
"expand_right",
"unsqueeze_right",
"sum_dim",
"wsum_dim",
"wsum_dim_return_weighted_sum_only",
"wsum_dim_return_sum_of_weights_only",
]
S = TypeVar("S")
[docs]
def expand_left(
t: TensorOrWeightedTensor[S], *, shape: Tuple[int, ...]
) -> TensorOrWeightedTensor[S]:
"""Expand shape of tensor at left with provided shape."""
return t.expand(shape + t.shape)
[docs]
def expand_right(
t: TensorOrWeightedTensor[S], *, shape: Tuple[int, ...]
) -> TensorOrWeightedTensor[S]:
"""Expand shape of tensor at right with provided shape."""
return t.expand(t.shape + shape)
[docs]
def unsqueeze_right(
t: TensorOrWeightedTensor[S], *, ndim: int
) -> TensorOrWeightedTensor[S]:
"""
Adds `ndim` dimensions to tensor, from right-side, without
copy (useful for right broadcasting which is non-standard).
"""
# Nota: `unsqueeze_left` is useless since it is automatically done with standard broadcasting
assert isinstance(ndim, int) and ndim >= 0, f"Can not unsqueeze {ndim} dimensions"
if ndim == 0:
return t
return t.view(t.shape + (1,) * ndim)
[docs]
def sum_dim(
x: TensorOrWeightedTensor,
*,
fill_value=0,
dim: Union[None, int, Tuple[int, ...]] = None,
but_dim: Union[None, int, Tuple[int, ...]] = None,
**kws,
) -> torch.Tensor:
"""
Sum dimension(s) of provided tensor (regular or weighted -
filling with `fill_value` aggregates without any summed weighting if any).
"""
dim = _get_dim(x, dim=dim, but_dim=but_dim)
if isinstance(x, WeightedTensor):
return x.sum(fill_value=fill_value, dim=dim, **kws)
return x.sum(dim=dim, **kws)
[docs]
def wsum_dim(
x: WeightedTensor,
*,
fill_value=0,
dim: Union[None, int, Tuple[int, ...]] = None,
but_dim: Union[None, int, Tuple[int, ...]] = None,
**kws,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Sum dimension(s) of provided weighted tensor.
The weighted tensor is filled with `fill_value` if provided.
The function returns the sum of weights as well.
Parameters
----------
x : WeightedTensor
The weighted tensor on which to compute the sum.
fill_value : float
The value to use for filling the weighted tensor.
dim : int or tuple of int, optional
The dimension(s) on which to sum.
but_dim : int or tuple of int, optional
The dimension(s) to omit when summing.
Returns
-------
weighted_sum : torch.Tensor
Weighted sum, with totally un-weighted aggregates filled with `fill_value`.
sum_weights : torch.Tensor
The sum of weights.
"""
dim = _get_dim(x, dim=dim, but_dim=but_dim)
return x.wsum(fill_value=fill_value, dim=dim, **kws)
[docs]
def wsum_dim_return_weighted_sum_only(
x: WeightedTensor,
*,
fill_value=0,
dim: Union[None, int, Tuple[int, ...]] = None,
but_dim: Union[None, int, Tuple[int, ...]] = None,
**kws,
) -> torch.Tensor:
"""
Sum dimension(s) of provided weighted tensor like `wsum_dim` but
only return the weighted sum and not the sum of weights.
Parameters
----------
x : WeightedTensor
The weighted tensor on which to compute the sum.
fill_value : float
The value to use for filling the weighted tensor.
dim : int or tuple of int, optional
The dimension(s) on which to sum.
but_dim : int or tuple of int, optional
The dimension(s) to omit when summing.
Returns
-------
torch.Tensor :
Weighted sum, with totally un-weighted aggregates filled with `fill_value`.
"""
return wsum_dim(x, fill_value=fill_value, dim=dim, but_dim=but_dim, **kws)[0]
[docs]
def wsum_dim_return_sum_of_weights_only(
x: WeightedTensor,
*,
fill_value=0,
dim: Union[None, int, Tuple[int, ...]] = None,
but_dim: Union[None, int, Tuple[int, ...]] = None,
**kws,
) -> torch.Tensor:
"""
Sum dimension(s) of provided weighted tensor like `wsum_dim` but
only return the sum of weights and not the weighted sum.
Parameters
----------
x : WeightedTensor
The weighted tensor on which to compute the sum.
fill_value : float
The value to use for filling the weighted tensor.
dim : int or tuple of int, optional
The dimension(s) on which to sum.
but_dim : int or tuple of int, optional
The dimension(s) to omit when summing.
Returns
-------
torch.Tensor :
The sum of weights.
"""
return wsum_dim(x, fill_value=fill_value, dim=dim, but_dim=but_dim, **kws)[1]
def _get_dim(
x: TensorOrWeightedTensor,
*,
dim: Union[None, int, Tuple[int, ...]] = None,
but_dim: Union[None, int, Tuple[int, ...]] = None,
) -> Union[int, Tuple[int, ...]]:
if (dim is not None) and (but_dim is not None):
raise ValueError("`dim` and `but_dim` should not be both defined.")
if but_dim is not None:
ndim = x.ndim
if isinstance(but_dim, int):
but_dim = {but_dim}
but_dim = {i if i >= 0 else ndim + i for i in but_dim}
assert all(i >= 0 for i in but_dim), but_dim
dim = tuple(i for i in range(ndim) if i not in but_dim)
elif dim is None:
# full summation by default
dim = ()
return dim