Skip to content

utils

utils ¤

csafelog = ComplexSafeLog.apply module-attribute ¤

safelog = SafeLog.apply module-attribute ¤

ComplexSafeLog ¤

Bases: Function

Source code in cirkit/backend/torch/utils.py (lines 28–41)
class ComplexSafeLog(autograd.Function):
    """Natural logarithm for (possibly complex) tensors with a sanitized backward pass.

    The forward pass is plain ``torch.log``. The backward pass returns
    ``grad_output / conj(x)`` with NaN and infinite entries replaced via
    ``torch.nan_to_num``, so an input of zero yields a finite gradient instead of inf/NaN.
    """

    @staticmethod
    def forward(x: Tensor) -> Tensor:
        """Compute the elementwise natural logarithm of ``x``."""
        log_x = torch.log(x)
        return log_x

    @staticmethod
    def setup_context(ctx: Any, inputs: tuple[Tensor, ...], output: Tensor) -> None:
        """Save the single forward input on ``ctx`` for the backward pass (output is unused)."""
        [inp] = inputs
        ctx.save_for_backward(inp)

    @staticmethod
    def backward(ctx: Any, grad_output: Tensor) -> Tensor:
        """Propagate ``grad_output`` through the logarithm.

        The derivative of log(x) is 1/x; following PyTorch's convention for complex
        autograd, the incoming gradient is divided by the conjugate of the saved input.
        """
        [inp] = ctx.saved_tensors
        raw_grad = grad_output / inp.conj()
        # Zero inputs produce inf (or NaN) here; nan_to_num maps those to finite values.
        return torch.nan_to_num(raw_grad)

backward(ctx, grad_output) staticmethod ¤

Source code in cirkit/backend/torch/utils.py (lines 38–41)
@staticmethod
def backward(ctx: Any, grad_output: Tensor) -> Tensor:
    """Divide ``grad_output`` by the conjugate of the saved input, sanitizing non-finite values."""
    [inp] = ctx.saved_tensors
    raw_grad = grad_output / inp.conj()
    return torch.nan_to_num(raw_grad)

forward(x) staticmethod ¤

Source code in cirkit/backend/torch/utils.py (lines 29–31)
@staticmethod
def forward(x: Tensor) -> Tensor:
    """Compute the elementwise natural logarithm of ``x``."""
    log_x = torch.log(x)
    return log_x

setup_context(ctx, inputs, output) staticmethod ¤

Source code in cirkit/backend/torch/utils.py (lines 33–36)
@staticmethod
def setup_context(ctx: Any, inputs: tuple[Tensor, ...], output: Tensor) -> None:
    """Save the single forward input on ``ctx`` for the backward pass (output is unused)."""
    [inp] = inputs
    ctx.save_for_backward(inp)

SafeLog ¤

Bases: Function

Source code in cirkit/backend/torch/utils.py (lines 9–22)
class SafeLog(autograd.Function):
    """Natural logarithm with a sanitized backward pass.

    The forward pass is plain ``torch.log``. The backward pass returns ``grad_output / x``
    with NaN and infinite entries replaced via ``torch.nan_to_num``, so an input of zero
    yields a finite gradient instead of inf/NaN.
    """

    @staticmethod
    def forward(x: Tensor) -> Tensor:
        """Compute the elementwise natural logarithm of ``x``."""
        log_x = torch.log(x)
        return log_x

    @staticmethod
    def setup_context(ctx: Any, inputs: tuple[Tensor, ...], output: Tensor) -> None:
        """Save the single forward input on ``ctx`` for the backward pass (output is unused)."""
        [inp] = inputs
        ctx.save_for_backward(inp)

    @staticmethod
    def backward(ctx: Any, grad_output: Tensor) -> Tensor:
        """Propagate ``grad_output`` through the logarithm, whose derivative is 1/x."""
        [inp] = ctx.saved_tensors
        raw_grad = grad_output / inp
        # Zero inputs produce inf (or NaN) here; nan_to_num maps those to finite values.
        return torch.nan_to_num(raw_grad)

backward(ctx, grad_output) staticmethod ¤

Source code in cirkit/backend/torch/utils.py (lines 19–22)
@staticmethod
def backward(ctx: Any, grad_output: Tensor) -> Tensor:
    """Divide ``grad_output`` by the saved input, sanitizing non-finite values."""
    [inp] = ctx.saved_tensors
    raw_grad = grad_output / inp
    return torch.nan_to_num(raw_grad)

forward(x) staticmethod ¤

Source code in cirkit/backend/torch/utils.py (lines 10–12)
@staticmethod
def forward(x: Tensor) -> Tensor:
    """Compute the elementwise natural logarithm of ``x``."""
    log_x = torch.log(x)
    return log_x

setup_context(ctx, inputs, output) staticmethod ¤

Source code in cirkit/backend/torch/utils.py (lines 14–17)
@staticmethod
def setup_context(ctx: Any, inputs: tuple[Tensor, ...], output: Tensor) -> None:
    """Save the single forward input on ``ctx`` for the backward pass (output is unused)."""
    [inp] = inputs
    ctx.save_for_backward(inp)

flatten_dims(x, /, *, dims) ¤

Flatten the given dims in the input.

If the dims are not contiguous, they will be permuted and flattened to the position of the first element in dims.

Intended to be used as a helper for some torch functions that can only work on one dim.

Parameters:

Name Type Description Default
x Tensor

The tensor to be flattened.

required
dims Sequence[int]

The dimensions to flatten along, expected to be sorted.

required

Returns:

Name Type Description
Tensor Tensor

The flattened tensor.

Source code in cirkit/backend/torch/utils.py (lines 47–67)
def flatten_dims(x: Tensor, /, *, dims: Sequence[int]) -> Tensor:
    """Flatten the given dims in the input into a single dim.

    If the dims are not contiguous, they are first moved next to each other (keeping their \
    relative order) so that the flattened dim lands at the position of the first element in dims.

    Intended to be used as a helper for some torch functions that can only work on one dim.

    Args:
        x (Tensor): The tensor to be flattened.
        dims (Sequence[int]): The dimensions to flatten along, expected to be sorted.

    Returns:
        Tensor: The flattened tensor, or x itself when dims is empty.
    """
    if not dims:  # Nothing to flatten, and dims[0] below would fail.
        return x

    first = dims[0]
    # Consecutive target slots starting at the first requested dim.
    targets = tuple(range(first, first + len(dims)))
    gathered = x.movedim(tuple(dims), targets)
    # torch.flatten's end_dim is inclusive, hence the last target slot.
    return gathered.flatten(first, targets[-1])

unflatten_dims(x, /, *, dims, shape) ¤

Unflatten the first dim in dims in the input to get a given shape.

This is the inverse transformation of flatten_dims, provided a corresponding shape.

Parameters:

Name Type Description Default
x Tensor

The tensor to be unflattened.

required
dims Sequence[int]

The dimensions to unflatten to; should be the same as those passed to flatten_dims.

required
shape Sequence[int]

The shape to unflatten to, can be either the shape for dims, or the whole shape for the output. If the latter, the shape will not be checked for consistency outside dims.

required

Returns:

Name Type Description
Tensor Tensor

The unflattened tensor.

Source code in cirkit/backend/torch/utils.py (lines 70–102)
def unflatten_dims(x: Tensor, /, *, dims: Sequence[int], shape: Sequence[int]) -> Tensor:
    """Unflatten the first dim in dims in the input to get a given shape.

    This is the inverse transformation of flatten_dims, provided a correspondimg shape.

    Args:
        x (Tensor): The tensor to be unflattened.
        dims (Sequence[int]): The dimensions to unflatten to, should be the same as flatten_dims.
        shape (Sequence[int]): The shape to unflatten to, can be either the shape for dims, or the \
            whole shape for the output. If the latter, the shape will not be checked for \
            consistency outside dims.

    Returns:
        Tensor: The unflattened tensor.
    """
    if not dims:  # When dims[0] does not work.
        return x

    # We require dims to be sorted so that there's no ambiguation in how shape is interpreted,
    # unless the shape itself never causes ambiguation.
    assert all(s == 1 for s in shape) or all(
        l < r for l, r in itertools.pairwise(dims)
    ), "dims must be sorted for unflatten_dims."

    if len(shape) == x.ndim - 1 + len(dims):  # The shape is for whole output.
        shape = [shape[d] for d in dims]
    # The shape is now for dims.

    start_dim, end_dim = dims[0], dims[0] + len(dims)
    # TODO: x.unflatten is not typed, must use torch.unflatten for now.
    return torch.unflatten(x, dim=start_dim, sizes=shape).movedim(
        tuple(range(start_dim, end_dim)), tuple(dims)
    )