Skip to content

semiring

semiring ¤

Semiring = type['SemiringImpl'] module-attribute ¤

ComplexLSESumSemiring ¤

Bases: SemiringImpl

The complex log space computation.

Source code in cirkit/backend/torch/semiring.py
(lines 411-476)
@SemiringImpl.register("complex-lse-sum")
class ComplexLSESumSemiring(SemiringImpl):
    """The complex log space computation.

    Values are held as complex logarithms: the semiring sum is a log-sum-exp and the semiring
    product is an addition of the (complex) log values.
    """

    @classmethod
    def cast(cls, x: Tensor) -> Tensor:
        """Cast a tensor to a complex data type.

        Complex tensors are returned as they are. Real floating-point tensors are promoted to the
        complex type of matching precision, and any other tensor (e.g., integer or boolean) is
        converted to the complex counterpart of ``torch.get_default_dtype()``.

        Args:
            x: The tensor.

        Returns:
            Tensor: The tensor with a complex data type.
        """
        if not x.is_complex():
            real_dtype = x.dtype if x.is_floating_point() else torch.get_default_dtype()
            return x.to(real_dtype.to_complex())
        return x

    @classmethod
    def sum(cls, x: Tensor, dim: int, *, keepdim: bool = False) -> Tensor:
        """Semiring sum along a dimension, i.e., the log-sum-exp of the log values.

        Args:
            x: The input tensor.
            dim: The dimension to reduce along.
            keepdim: Whether the dim is kept as a size-1 dim. Defaults to False.

        Returns:
            Tensor: The log-sum-exp of ``x`` along ``dim``.
        """
        return torch.logsumexp(x, dim=dim, keepdim=keepdim)

    @classmethod
    def add(cls, *xs: Tensor) -> Tensor:
        """Semiring addition of tensors: pairwise log-add-exp, folded from left to right.

        Args:
            *xs: The input tensors.

        Returns:
            Tensor: The log-add-exp of all inputs.
        """
        summed = functools.reduce(torch.logaddexp, xs)
        return summed

    @classmethod
    def prod(cls, x: Tensor, dim: int, *, keepdim: bool = False) -> Tensor:
        """Semiring product along a dimension, i.e., the sum of the log values.

        Args:
            x: The input tensor.
            dim: The dimension to reduce along.
            keepdim: Whether the dim is kept as a size-1 dim. Defaults to False.

        Returns:
            Tensor: The sum of ``x`` along ``dim``.
        """
        return torch.sum(x, dim=dim, keepdim=keepdim)

    @classmethod
    def mul(cls, *xs: Tensor) -> Tensor:
        """Semiring multiplication of tensors, i.e., the element-wise sum of the log values.

        Args:
            *xs: The input tensors, folded from left to right with ``torch.add``.

        Returns:
            Tensor: The sum of all inputs.
        """
        product = functools.reduce(torch.add, xs)
        return product

    @classmethod
    def apply_reduce(
        cls,
        func: EinsumFunc,
        *xs: Tensor,
        dim: int,
        keepdim: bool,
    ) -> Tensor:
        """Evaluate a linear-space function on complex log-space inputs in a stable way.

        Every input is shifted by the maximum of its real part along ``dim`` (clamped to the
        finite range of the real dtype) before being exponentiated. ``func`` is evaluated on the
        shifted exponentials, and the safe complex log of the result is shifted back by the sum
        of all the maxima. When ``func`` is linear in each input, this equals the complex log of
        ``func`` applied to the exponentiated inputs.

        Args:
            func: The sum-like function, computed in linear space.
            *xs: The input tensors, holding complex log-space values.
            dim: The dimension along which ``func`` sums; the maxima are taken along it.
            keepdim: Whether ``func`` keeps ``dim`` as a size-1 dimension.

        Returns:
            Tensor: The result, in complex log space.
        """
        # The shifts are used twice (to rescale the inputs and to restore the result), so they
        # are collected into a list rather than produced lazily.
        shifts: list[Tensor] = []
        scaled: list[Tensor] = []
        for xi in xs:
            peak = torch.amax(xi.real, dim=dim, keepdim=True)
            real_info = torch.finfo(xi.real.dtype)
            # Clamping keeps the shift finite, so slices that are entirely -inf (or +inf) do not
            # produce inf - inf = nan below.
            shift = torch.clamp(peak, min=real_info.min, max=real_info.max)
            shifts.append(shift)
            scaled.append(torch.exp(xi - shift))

        linear_result = func(*scaled)

        total_shift = functools.reduce(torch.add, shifts)  # n-1 additions for n inputs.
        if not keepdim:
            total_shift = total_shift.squeeze(dim)  # Match the shape of linear_result.

        # Compute log(x) and its gradients safely where x is a complex tensor: if x = 0 + 0j,
        # the complex gradient of log(x) yields NaNs. For real non-monotonic circuits this
        # cannot be avoided by clipping parameters (e.g., of dense layers) away from zero, since
        # cancellations from negations can still underflow; this has been observed in float32
        # for squared non-monotonic PCs with real parameters. The 'safe' complex logarithm
        # replaces NaN gradients with zero and +inf/-inf gradients with the largest/lowest
        # representable values.
        return csafelog(linear_result) + total_shift

add(*xs) classmethod ¤

Source code in cirkit/backend/torch/semiring.py
(lines 428-430)
@classmethod
def add(cls, *xs: Tensor) -> Tensor:
    """Semiring addition of tensors: pairwise log-add-exp, folded from left to right.

    Args:
        *xs: The input tensors.

    Returns:
        Tensor: The log-add-exp of all inputs.
    """
    summed = functools.reduce(torch.logaddexp, xs)
    return summed

apply_reduce(func, *xs, dim, keepdim) classmethod ¤

Source code in cirkit/backend/torch/semiring.py
(lines 440-476)
@classmethod
def apply_reduce(
    cls,
    func: EinsumFunc,
    *xs: Tensor,
    dim: int,
    keepdim: bool,
) -> Tensor:
    """Evaluate a linear-space function on complex log-space inputs in a stable way.

    Every input is shifted by the maximum of its real part along ``dim`` (clamped to the finite
    range of the real dtype) before being exponentiated. ``func`` is evaluated on the shifted
    exponentials, and the safe complex log of the result is shifted back by the sum of all the
    maxima. When ``func`` is linear in each input, this equals the complex log of ``func``
    applied to the exponentiated inputs.

    Args:
        func: The sum-like function, computed in linear space.
        *xs: The input tensors, holding complex log-space values.
        dim: The dimension along which ``func`` sums; the maxima are taken along it.
        keepdim: Whether ``func`` keeps ``dim`` as a size-1 dimension.

    Returns:
        Tensor: The result, in complex log space.
    """
    # The shifts are used twice (to rescale the inputs and to restore the result), so they are
    # collected into a list rather than produced lazily.
    shifts: list[Tensor] = []
    scaled: list[Tensor] = []
    for xi in xs:
        peak = torch.amax(xi.real, dim=dim, keepdim=True)
        real_info = torch.finfo(xi.real.dtype)
        # Clamping keeps the shift finite, so slices that are entirely -inf (or +inf) do not
        # produce inf - inf = nan below.
        shift = torch.clamp(peak, min=real_info.min, max=real_info.max)
        shifts.append(shift)
        scaled.append(torch.exp(xi - shift))

    linear_result = func(*scaled)

    total_shift = functools.reduce(torch.add, shifts)  # n-1 additions for n inputs.
    if not keepdim:
        total_shift = total_shift.squeeze(dim)  # Match the shape of linear_result.

    # Compute log(x) and its gradients safely where x is a complex tensor: if x = 0 + 0j, the
    # complex gradient of log(x) yields NaNs. For real non-monotonic circuits this cannot be
    # avoided by clipping parameters (e.g., of dense layers) away from zero, since cancellations
    # from negations can still underflow; this has been observed in float32 for squared
    # non-monotonic PCs with real parameters. The 'safe' complex logarithm replaces NaN
    # gradients with zero and +inf/-inf gradients with the largest/lowest representable values.
    return csafelog(linear_result) + total_shift

cast(x) classmethod ¤

Source code in cirkit/backend/torch/semiring.py
(lines 415-422)
@classmethod
def cast(cls, x: Tensor) -> Tensor:
    if x.is_complex():
        return x
    if x.is_floating_point():
        return x.to(x.dtype.to_complex())
    default_float_dtype = torch.get_default_dtype()
    return x.to(default_float_dtype.to_complex())

mul(*xs) classmethod ¤

Source code in cirkit/backend/torch/semiring.py
(lines 436-438)
@classmethod
def mul(cls, *xs: Tensor) -> Tensor:
    """Semiring multiplication of tensors, i.e., the element-wise sum of the log values.

    Args:
        *xs: The input tensors, folded from left to right with ``torch.add``.

    Returns:
        Tensor: The sum of all inputs.
    """
    product = functools.reduce(torch.add, xs)
    return product

prod(x, dim, *, keepdim=False) classmethod ¤

Source code in cirkit/backend/torch/semiring.py
(lines 432-434)
@classmethod
def prod(cls, x: Tensor, dim: int, *, keepdim: bool = False) -> Tensor:
    """Semiring product along a dimension, i.e., the sum of the log values.

    Args:
        x: The input tensor.
        dim: The dimension to reduce along.
        keepdim: Whether the dim is kept as a size-1 dim. Defaults to False.

    Returns:
        Tensor: The sum of ``x`` along ``dim``.
    """
    return torch.sum(x, dim=dim, keepdim=keepdim)

sum(x, dim, *, keepdim=False) classmethod ¤

Source code in cirkit/backend/torch/semiring.py
(lines 424-426)
@classmethod
def sum(cls, x: Tensor, dim: int, *, keepdim: bool = False) -> Tensor:
    """Semiring sum along a dimension, i.e., the log-sum-exp of the log values.

    Args:
        x: The input tensor.
        dim: The dimension to reduce along.
        keepdim: Whether the dim is kept as a size-1 dim. Defaults to False.

    Returns:
        Tensor: The log-sum-exp of ``x`` along ``dim``.
    """
    return torch.logsumexp(x, dim=dim, keepdim=keepdim)

EinsumFunc ¤

Bases: Protocol

Source code in cirkit/backend/torch/semiring.py
(lines 15-16)
class EinsumFunc(Protocol):
    """Structural type of the ``func`` argument of ``SemiringImpl.apply_reduce``.

    Any callable that takes an arbitrary number of tensors positionally and returns a single
    tensor conforms to this protocol.
    """

    def __call__(self, *xs: Tensor) -> Tensor:
        """Combine the given tensors into a single tensor."""
        ...

__call__(*xs) ¤

Source code in cirkit/backend/torch/semiring.py
(line 16)
def __call__(self, *xs: Tensor) -> Tensor:
    """Combine the given tensors into a single tensor."""
    ...

LSESumSemiring ¤

Bases: SemiringImpl

The log space computation.

Source code in cirkit/backend/torch/semiring.py
(lines 353-408)
@SemiringImpl.register("lse-sum")
class LSESumSemiring(SemiringImpl):
    """The log space computation.

    Values are held as real logarithms: the semiring sum is a log-sum-exp and the semiring
    product is an addition of the log values.
    """

    @classmethod
    def cast(cls, x: Tensor) -> Tensor:
        """Cast a tensor to a real floating-point data type.

        Floating-point tensors are returned as they are; other real tensors (e.g., integer or
        boolean) are converted to ``torch.get_default_dtype()``.

        Args:
            x: The tensor.

        Returns:
            Tensor: The tensor with a real floating-point data type.

        Raises:
            ValueError: If the tensor is complex.
        """
        if x.is_complex():
            raise ValueError(f"Cannot cast a tensor of type '{x.dtype}' to the '{cls.__name__}'")
        if not x.is_floating_point():
            return x.to(torch.get_default_dtype())
        return x

    @classmethod
    def sum(cls, x: Tensor, dim: int, *, keepdim: bool = False) -> Tensor:
        """Semiring sum along a dimension, i.e., the log-sum-exp of the log values.

        Args:
            x: The input tensor.
            dim: The dimension to reduce along.
            keepdim: Whether the dim is kept as a size-1 dim. Defaults to False.

        Returns:
            Tensor: The log-sum-exp of ``x`` along ``dim``.
        """
        return torch.logsumexp(x, dim=dim, keepdim=keepdim)

    @classmethod
    def add(cls, *xs: Tensor) -> Tensor:
        """Semiring addition of tensors: pairwise log-add-exp, folded from left to right.

        Args:
            *xs: The input tensors.

        Returns:
            Tensor: The log-add-exp of all inputs.
        """
        summed = functools.reduce(torch.logaddexp, xs)
        return summed

    @classmethod
    def prod(cls, x: Tensor, dim: int, *, keepdim: bool = False) -> Tensor:
        """Semiring product along a dimension, i.e., the sum of the log values.

        Args:
            x: The input tensor.
            dim: The dimension to reduce along.
            keepdim: Whether the dim is kept as a size-1 dim. Defaults to False.

        Returns:
            Tensor: The sum of ``x`` along ``dim``.
        """
        return torch.sum(x, dim=dim, keepdim=keepdim)

    @classmethod
    def mul(cls, *xs: Tensor) -> Tensor:
        """Semiring multiplication of tensors, i.e., the element-wise sum of the log values.

        Args:
            *xs: The input tensors, folded from left to right with ``torch.add``.

        Returns:
            Tensor: The sum of all inputs.
        """
        product = functools.reduce(torch.add, xs)
        return product

    @classmethod
    def apply_reduce(
        cls,
        func: EinsumFunc,
        *xs: Tensor,
        dim: int,
        keepdim: bool,
    ) -> Tensor:
        """Evaluate a linear-space function on log-space inputs in a stable way.

        Every input is shifted by its maximum along ``dim`` (clamped to the finite range of its
        dtype) before being exponentiated. ``func`` is evaluated on the shifted exponentials,
        and the log of the result is shifted back by the sum of all the maxima. When ``func`` is
        linear in each input, this equals the log of ``func`` applied to the exponentiated
        inputs.

        Args:
            func: The sum-like function, computed in linear space.
            *xs: The input tensors, holding log-space values.
            dim: The dimension along which ``func`` sums; the maxima are taken along it.
            keepdim: Whether ``func`` keeps ``dim`` as a size-1 dimension.

        Returns:
            Tensor: The result, in log space.
        """
        # The shifts are used twice (to rescale the inputs and to restore the result), so they
        # are collected into a list rather than produced lazily.
        shifts: list[Tensor] = []
        scaled: list[Tensor] = []
        for xi in xs:
            peak = torch.amax(xi, dim=dim, keepdim=True)
            info = torch.finfo(xi.dtype)
            # Clamping keeps the shift finite, so slices that are entirely -inf (or +inf) do not
            # produce inf - inf = nan below.
            shift = torch.clamp(peak, min=info.min, max=info.max)
            shifts.append(shift)
            scaled.append(torch.exp(xi - shift))

        linear_result = func(*scaled)

        total_shift = functools.reduce(torch.add, shifts)  # n-1 additions for n inputs.
        if not keepdim:
            total_shift = total_shift.squeeze(dim)  # Match the shape of linear_result.
        return torch.log(linear_result) + total_shift

add(*xs) classmethod ¤

Source code in cirkit/backend/torch/semiring.py
(lines 370-372)
@classmethod
def add(cls, *xs: Tensor) -> Tensor:
    """Semiring addition of tensors: pairwise log-add-exp, folded from left to right.

    Args:
        *xs: The input tensors.

    Returns:
        Tensor: The log-add-exp of all inputs.
    """
    summed = functools.reduce(torch.logaddexp, xs)
    return summed

apply_reduce(func, *xs, dim, keepdim) classmethod ¤

Source code in cirkit/backend/torch/semiring.py
(lines 382-408)
@classmethod
def apply_reduce(
    cls,
    func: EinsumFunc,
    *xs: Tensor,
    dim: int,
    keepdim: bool,
) -> Tensor:
    """Evaluate a linear-space function on log-space inputs in a stable way.

    Every input is shifted by its maximum along ``dim`` (clamped to the finite range of its
    dtype) before being exponentiated. ``func`` is evaluated on the shifted exponentials, and
    the log of the result is shifted back by the sum of all the maxima. When ``func`` is linear
    in each input, this equals the log of ``func`` applied to the exponentiated inputs.

    Args:
        func: The sum-like function, computed in linear space.
        *xs: The input tensors, holding log-space values.
        dim: The dimension along which ``func`` sums; the maxima are taken along it.
        keepdim: Whether ``func`` keeps ``dim`` as a size-1 dimension.

    Returns:
        Tensor: The result, in log space.
    """
    # The shifts are used twice (to rescale the inputs and to restore the result), so they are
    # collected into a list rather than produced lazily.
    shifts: list[Tensor] = []
    scaled: list[Tensor] = []
    for xi in xs:
        peak = torch.amax(xi, dim=dim, keepdim=True)
        info = torch.finfo(xi.dtype)
        # Clamping keeps the shift finite, so slices that are entirely -inf (or +inf) do not
        # produce inf - inf = nan below.
        shift = torch.clamp(peak, min=info.min, max=info.max)
        shifts.append(shift)
        scaled.append(torch.exp(xi - shift))

    linear_result = func(*scaled)

    total_shift = functools.reduce(torch.add, shifts)  # n-1 additions for n inputs.
    if not keepdim:
        total_shift = total_shift.squeeze(dim)  # Match the shape of linear_result.
    return torch.log(linear_result) + total_shift

cast(x) classmethod ¤

Source code in cirkit/backend/torch/semiring.py
(lines 357-364)
@classmethod
def cast(cls, x: Tensor) -> Tensor:
    """Cast a tensor to a real floating-point data type.

    Floating-point tensors are returned as they are; other real tensors (e.g., integer or
    boolean) are converted to ``torch.get_default_dtype()``.

    Args:
        x: The tensor.

    Returns:
        Tensor: The tensor with a real floating-point data type.

    Raises:
        ValueError: If the tensor is complex.
    """
    if x.is_complex():
        raise ValueError(f"Cannot cast a tensor of type '{x.dtype}' to the '{cls.__name__}'")
    if not x.is_floating_point():
        return x.to(torch.get_default_dtype())
    return x

mul(*xs) classmethod ¤

Source code in cirkit/backend/torch/semiring.py
(lines 378-380)
@classmethod
def mul(cls, *xs: Tensor) -> Tensor:
    """Semiring multiplication of tensors, i.e., the element-wise sum of the log values.

    Args:
        *xs: The input tensors, folded from left to right with ``torch.add``.

    Returns:
        Tensor: The sum of all inputs.
    """
    product = functools.reduce(torch.add, xs)
    return product

prod(x, dim, *, keepdim=False) classmethod ¤

Source code in cirkit/backend/torch/semiring.py
(lines 374-376)
@classmethod
def prod(cls, x: Tensor, dim: int, *, keepdim: bool = False) -> Tensor:
    """Semiring product along a dimension, i.e., the sum of the log values.

    Args:
        x: The input tensor.
        dim: The dimension to reduce along.
        keepdim: Whether the dim is kept as a size-1 dim. Defaults to False.

    Returns:
        Tensor: The sum of ``x`` along ``dim``.
    """
    return torch.sum(x, dim=dim, keepdim=keepdim)

sum(x, dim, *, keepdim=False) classmethod ¤

Source code in cirkit/backend/torch/semiring.py
(lines 366-368)
@classmethod
def sum(cls, x: Tensor, dim: int, *, keepdim: bool = False) -> Tensor:
    """Semiring sum along a dimension, i.e., the log-sum-exp of the log values.

    Args:
        x: The input tensor.
        dim: The dimension to reduce along.
        keepdim: Whether the dim is kept as a size-1 dim. Defaults to False.

    Returns:
        Tensor: The log-sum-exp of ``x`` along ``dim``.
    """
    return torch.logsumexp(x, dim=dim, keepdim=keepdim)

SemiringImpl ¤

Bases: ABC

The abstract base class for semiring implementations.

For numerical precision, the units in the computational graph may hold values in, e.g., log space instead of linear space. This class therefore provides a unified interface for the computations, so that they can be done in the space suited to the implementation regardless of the global setting.

Source code in cirkit/backend/torch/semiring.py
(lines 19-302)
class SemiringImpl(ABC):
    """The abstract base class for semiring implementations.

    Due to numerical precision, the actual units in computational graph may hold values in, e.g., \
    log space, instead of linear space. And therefore, this provides a unified interface for the \
    computations so that computation can be done in a space suitable to the implementation \
    regardless of the global setting.
    """

    # A registry from semiring string identifiers to semiring class implementations
    _registry: ClassVar[dict[str, Semiring]] = {}

    # A registry of morphisms between semiring class implementations
    _registry_morphisms: ClassVar[dict[tuple[Semiring, Semiring], Callable[[Tensor], Tensor]]] = {}

    @staticmethod
    def register(name: str) -> Callable[[Semiring], Semiring]:
        """Register a concrete semiring implementation by its name.

        Args:
            name: The name to register.

        Returns:
            Callable[[Semiring], Semiring]: The class decorator to register a subclass.
        """

        def _decorator(cls: Semiring) -> Semiring:
            """Register a concrete semiring implementation by its name.

            Args:
                cls: The semiring subclass to register.

            Returns:
                Semiring: The class passed in.
            """
            SemiringImpl._registry[name] = cls
            return cls

        return _decorator

    @classmethod
    def register_map_from(
        cls, other: Semiring
    ) -> Callable[[Callable[[Tensor], Tensor]], Callable[[Tensor], Tensor]]:
        """Register a concrete semiring morphism implementation.

        Args:
            other: The source semiring.

        Returns:
            Callable[[Callable[[Tensor], Tensor]], Callable[[Tensor], Tensor]]:
            The function decorator to register the morphism.
        """

        def _decorator(func: Callable[[Tensor], Tensor]) -> Callable[[Tensor], Tensor]:
            """Register a concrete semiring morphism implementation.

            Args:
                func: The morphism between semirings to register.

            Returns:
                Callable[[Tensor], Tensor]: The morphism passed in.
            """
            SemiringImpl._registry_morphisms[(other, cls)] = func
            return func

        return _decorator

    @staticmethod
    def list() -> Iterable[str]:
        """List all semiring names registered.

        Returns:
            Iterable[str]: An iterable over all names available.
        """
        return iter(SemiringImpl._registry)

    @staticmethod
    def from_name(name: str) -> Semiring:
        """Get a semiring by its registered name.

        Args:
            name (str): The name to probe.

        Returns:
            Semiring: The retrieved concrete Semiring.
        """
        if name not in SemiringImpl._registry:
            raise IndexError(
                f"Unknown semiring '{name}'."
                f" Use @SemiringImpl.register(<name>) to register a new semiring"
            )
        return SemiringImpl._registry[name]

    @classmethod
    def map_from(cls, x: Tensor, semiring: Semiring) -> Tensor:
        """Map a tensor from the given semiring to `this` semiring.

        Args:
            x:
            semiring:

        Returns:

        """
        if cls == semiring:
            return x
        func: Callable[[Tensor], Tensor] | None = SemiringImpl._registry_morphisms.get(
            (semiring, cls), None
        )
        if func is None:
            raise NotImplementedError(
                f"Semiring map from '{semiring.__name__}' to '{cls.__name__}' is not implemented"
            )
        return func(x)

    def __new__(cls) -> "SemiringImpl":
        """Raise an error when this class is instantiated.

        Raises:
            TypeError: When this class is instantiated.

        Returns:
            SemiringImpl: This method never returns.
        """
        raise TypeError("This class cannot be instantiated")

    @classmethod
    def einsum(
        cls,
        equation: str | Sequence[Sequence[int]],
        *,
        inputs: tuple[Tensor, ...] | None = None,
        operands: tuple[Tensor, ...] | None = None,
        dim: int,
        keepdim: bool,
    ) -> Tensor:
        """Perform an einsum operation where sums and products are specified by the semiring.

        Args:
            equation: The einsum expression.
            inputs:  The inputs of the einsum.
            operands: Additional operands to pass to the einsum, after the inputs in the
                einsum expression.
            dim: The dimension of the inputs that get summed over in the einsum expression.
                This is useful to make the einsum computationally stable in some semirings,
                e.g., the log-sum-exp semiring.
            keepdim: Whether to keep the dimension that get summed over in the einsum
                expression.

        Returns:
            Tensor: the result of the einsum operation over the semiring.
        """
        # TODO: We need to remove this super general yet extremely complicated and hard
        #  to maintain einsum definition, which depends on the semiring. A future version of the
        #  compiler in cirkit will be able to emit pytorch code for every layer at compile time
        if inputs is None:
            inputs = ()
        if operands is None:
            operands = ()
        match equation:
            case str():

                def _einsum_str_func(*xs: Tensor) -> Tensor:
                    opds = tuple(cls.cast(opd) for opd in operands)
                    return torch.einsum(equation, *xs, *opds)

                einsum_func = _einsum_str_func
            case Sequence():

                def _einsum_seq_func(*xs: Tensor) -> Tensor:
                    opds = tuple(cls.cast(opd) for opd in operands)
                    einsum_args = tuple(
                        itertools.chain.from_iterable(zip(xs + opds, equation[:-1]))
                    )
                    return torch.einsum(*einsum_args, equation[-1])

                einsum_func = _einsum_seq_func
            case _:
                raise ValueError(
                    "The einsum expression must be either a string or a sequence of int sequences"
                )

        return cls.apply_reduce(einsum_func, *inputs, dim=dim, keepdim=keepdim)

    # NOTE: Subclasses should not touch any of the above final static methods but should implement
    #       all the following abstract class methods, and subclasses should be @final.

    @classmethod
    @abstractmethod
    def cast(cls, x: Tensor) -> Tensor:
        """Cast a tensor to the data type required by `this` semiring.

        Args:
            x: The tensor.

        Returns:
            Tensor: The tensor converted to the required data type.
        """

    @classmethod
    @abstractmethod
    def sum(cls, x: Tensor, dim: int, *, keepdim: bool = False) -> Tensor:
        """

        Args:
            x:
            dim:
            keepdim:

        Returns:

        """

    @classmethod
    @abstractmethod
    def add(cls, *xs: Tensor) -> Tensor:
        """

        Args:
            *xs:

        Returns:

        """

    @classmethod
    @abstractmethod
    def prod(cls, x: Tensor, dim: int, *, keepdim: bool = False) -> Tensor:
        """Do the product within a tensor on given dim(s).

        Args:
            x: The input tensor.
            dim: The dimension to reduce along.
            keepdim: Whether the dim is kept as a size-1 dim. Defaults to False.

        Returns:
            Tensor: The product result.
        """

    @classmethod
    @abstractmethod
    def mul(cls, *xs: Tensor) -> Tensor:
        """Multiply broadcastable tensors.

        Args:
            *xs (Tensor): The input tensors, should have broadcastable shapes.

        Returns:
            Tensor: The multiply result.
        """

    @classmethod
    @abstractmethod
    def apply_reduce(
        cls,
        func: EinsumFunc,
        *xs: Tensor,
        dim: int,
        keepdim: bool,
    ) -> Tensor:
        """Apply a sum-like functions to the tensor(s).

        The sum units may perform not just plain sum, but also weighted sum or even einsum. In \
        fact, it can possibly be any function that is linear to the each input. All that kind of \
        func can be used here.

        It is expected that func always does computation in the linear space, as with numerical \
        tricks, only relatively significant numbers contribute the final answer, and underflow \
        will not affect much. However, the input/output values may still be in another space, and \
        needs to be projected here.

        Args:
            func (Callable[[Unpack[Ts]], Tensor]): The sum-like function to be applied.
            *xs (Unpack[Ts]): The input tensors. Type expected to be Tensor.
            dim (int): The dimension along which the values are \
                correlated and must be scaled together, i.e., the dim to sum along. This should \
                match the actual operation done by func. The same dim is shared among all inputs.
            keepdim (bool): Whether the dim is kept as a size-1 dim, should match the actual \
                operation done by func.

        Returns:
            Tensor: The sum result.
        """

__new__() ¤

Raise an error when this class is instantiated.

Raises:

Type Description
TypeError

When this class is instantiated.

Returns:

Name Type Description
SemiringImpl SemiringImpl

This method never returns.

Source code in cirkit/backend/torch/semiring.py
(lines 135-144)
def __new__(cls) -> "SemiringImpl":
    """Forbid instantiation; semirings are used only through their class and static methods.

    Raises:
        TypeError: Always, whenever instantiation is attempted.

    Returns:
        SemiringImpl: Never returns; the annotation only mirrors the ``__new__`` protocol.
    """
    message = "This class cannot be instantiated"
    raise TypeError(message)

add(*xs) abstractmethod classmethod ¤

Parameters:

Name Type Description Default
*xs Tensor
()

Returns:

Source code in cirkit/backend/torch/semiring.py
(lines 233-243)
@classmethod
@abstractmethod
def add(cls, *xs: Tensor) -> Tensor:
    """Add tensors using the semiring addition.

    Args:
        *xs (Tensor): The input tensors.

    Returns:
        Tensor: The addition result.
    """

apply_reduce(func, *xs, dim, keepdim) abstractmethod classmethod ¤

Apply a sum-like function to the tensor(s).

The sum units may perform not just a plain sum, but also a weighted sum or even an einsum. In fact, func can be any function that is linear in each input.

func is expected to always compute in linear space: thanks to the rescaling tricks applied here, only relatively significant values contribute to the final answer, so underflow has little effect. However, the input/output values may still live in another space, and need to be projected here.

Parameters:

Name Type Description Default
func Callable[[Unpack[Ts]], Tensor]

The sum-like function to be applied.

required
*xs Unpack[Ts]

The input tensors. Type expected to be Tensor.

()
dim int

The dimension along which the values are correlated and must be scaled together, i.e., the dim to sum along. This should match the actual operation done by func. The same dim is shared among all inputs.

required
keepdim bool

Whether the dim is kept as a size-1 dim, should match the actual operation done by func.

required

Returns:

Name Type Description
Tensor Tensor

The sum result.

Source code in cirkit/backend/torch/semiring.py
(lines 271-302)
@classmethod
@abstractmethod
def apply_reduce(
    cls,
    func: EinsumFunc,
    *xs: Tensor,
    dim: int,
    keepdim: bool,
) -> Tensor:
    """Apply a sum-like function to the tensor(s).

    The sum units may perform not just a plain sum, but also a weighted sum or even an \
    einsum. In fact, func can be any function that is linear in each input.

    func is expected to always compute in linear space: thanks to the rescaling tricks \
    applied here, only relatively significant values contribute to the final answer, so \
    underflow has little effect. However, the input/output values may still live in another \
    space, and need to be projected here.

    Args:
        func (Callable[[Unpack[Ts]], Tensor]): The sum-like function to be applied.
        *xs (Unpack[Ts]): The input tensors. Type expected to be Tensor.
        dim (int): The dimension along which the values are \
            correlated and must be scaled together, i.e., the dim to sum along. This should \
            match the actual operation done by func. The same dim is shared among all inputs.
        keepdim (bool): Whether the dim is kept as a size-1 dim, should match the actual \
            operation done by func.

    Returns:
        Tensor: The sum result.
    """

cast(x) abstractmethod classmethod ¤

Cast a tensor to the data type required by this semiring.

Parameters:

Name Type Description Default
x Tensor

The tensor.

required

Returns:

Name Type Description
Tensor Tensor

The tensor converted to the required data type.

Source code in cirkit/backend/torch/semiring.py
(lines 207-217)
@classmethod
@abstractmethod
def cast(cls, x: Tensor) -> Tensor:
    """Convert a tensor to the data type that `this` semiring operates on.

    Args:
        x: The tensor to convert.

    Returns:
        Tensor: The tensor, converted to the data type required by the semiring.
    """

einsum(equation, *, inputs=None, operands=None, dim, keepdim) classmethod ¤

Perform an einsum operation where sums and products are specified by the semiring.

Parameters:

Name Type Description Default
equation str | Sequence[Sequence[int]]

The einsum expression.

required
inputs tuple[Tensor, ...] | None

The inputs of the einsum.

None
operands tuple[Tensor, ...] | None

Additional operands to pass to the einsum, after the inputs in the einsum expression.

None
dim int

The dimension of the inputs that gets summed over in the einsum expression. This is useful to make the einsum numerically stable in some semirings, e.g., the log-sum-exp semiring.

required
keepdim bool

Whether to keep the dimension that gets summed over in the einsum expression.

required

Returns:

Name Type Description
Tensor Tensor

the result of the einsum operation over the semiring.

Source code in cirkit/backend/torch/semiring.py
(lines 146-202)
@classmethod
def einsum(
    cls,
    equation: str | Sequence[Sequence[int]],
    *,
    inputs: tuple[Tensor, ...] | None = None,
    operands: tuple[Tensor, ...] | None = None,
    dim: int,
    keepdim: bool,
) -> Tensor:
    """Perform an einsum operation where sums and products are specified by the semiring.

    Args:
        equation: The einsum expression.
        inputs:  The inputs of the einsum.
        operands: Additional operands to pass to the einsum, after the inputs in the
            einsum expression.
        dim: The dimension of the inputs that get summed over in the einsum expression.
            This is useful to make the einsum computationally stable in some semirings,
            e.g., the log-sum-exp semiring.
        keepdim: Whether to keep the dimension that get summed over in the einsum
            expression.

    Returns:
        Tensor: the result of the einsum operation over the semiring.
    """
    # TODO: We need to remove this super general yet extremely complicated and hard
    #  to maintain einsum definition, which depends on the semiring. A future version of the
    #  compiler in cirkit will be able to emit pytorch code for every layer at compile time
    if inputs is None:
        inputs = ()
    if operands is None:
        operands = ()
    match equation:
        case str():

            def _einsum_str_func(*xs: Tensor) -> Tensor:
                opds = tuple(cls.cast(opd) for opd in operands)
                return torch.einsum(equation, *xs, *opds)

            einsum_func = _einsum_str_func
        case Sequence():

            def _einsum_seq_func(*xs: Tensor) -> Tensor:
                opds = tuple(cls.cast(opd) for opd in operands)
                einsum_args = tuple(
                    itertools.chain.from_iterable(zip(xs + opds, equation[:-1]))
                )
                return torch.einsum(*einsum_args, equation[-1])

            einsum_func = _einsum_seq_func
        case _:
            raise ValueError(
                "The einsum expression must be either a string or a sequence of int sequences"
            )

    return cls.apply_reduce(einsum_func, *inputs, dim=dim, keepdim=keepdim)

from_name(name) staticmethod ¤

Get a semiring by its registered name.

Parameters:

Name Type Description Default
name str

The name to probe.

required

Returns:

Name Type Description
Semiring Semiring

The retrieved concrete Semiring.

Source code in cirkit/backend/torch/semiring.py
(lines 96-111)
@staticmethod
def from_name(name: str) -> Semiring:
    """Look up a registered semiring class by its name.

    Args:
        name (str): The registered name to look up.

    Returns:
        Semiring: The concrete semiring class registered under ``name``.

    Raises:
        IndexError: If no semiring is registered under ``name``.
    """
    registry = SemiringImpl._registry
    if name in registry:
        return registry[name]
    raise IndexError(
        f"Unknown semiring '{name}'."
        f" Use @SemiringImpl.register(<name>) to register a new semiring"
    )

list() staticmethod ¤

List all semiring names registered.

Returns:

Type Description
Iterable[str]

Iterable[str]: An iterable over all names available.

Source code in cirkit/backend/torch/semiring.py
(lines 87-94)
@staticmethod
def list() -> Iterable[str]:
    """Enumerate the names under which semirings have been registered.

    Returns:
        Iterable[str]: An iterator over the registered names, in registration order.
    """
    registered_names = SemiringImpl._registry.keys()
    return iter(registered_names)

map_from(x, semiring) classmethod ¤

Map a tensor from the given semiring to this semiring.

Parameters:

Name Type Description Default
x Tensor
required
semiring Semiring
required

Returns:

Source code in cirkit/backend/torch/semiring.py
(lines 113-133)
@classmethod
def map_from(cls, x: Tensor, semiring: Semiring) -> Tensor:
    """Map a tensor from the given semiring to `this` semiring.

    If ``semiring`` is ``cls`` itself, the tensor is returned unchanged; otherwise the morphism
    registered for the pair ``(semiring, cls)`` via ``register_map_from`` is applied.

    Args:
        x: The tensor, holding values of the source semiring.
        semiring: The source semiring.

    Returns:
        Tensor: The tensor mapped to `this` semiring.

    Raises:
        NotImplementedError: If no morphism from ``semiring`` to `this` semiring is registered.
    """
    if cls == semiring:
        return x
    morphism = SemiringImpl._registry_morphisms.get((semiring, cls))
    if morphism is not None:
        return morphism(x)
    raise NotImplementedError(
        f"Semiring map from '{semiring.__name__}' to '{cls.__name__}' is not implemented"
    )

mul(*xs) abstractmethod classmethod ¤

Multiply broadcastable tensors.

Parameters:

Name Type Description Default
*xs Tensor

The input tensors, should have broadcastable shapes.

()

Returns:

Name Type Description
Tensor Tensor

The multiply result.

Source code in cirkit/backend/torch/semiring.py
259
260
261
262
263
264
265
266
267
268
269
@classmethod
@abstractmethod
def mul(cls, *xs: Tensor) -> Tensor:
    """Combine broadcastable tensors element-wise with the semiring product.

    Args:
        *xs (Tensor): The operands; their shapes must be mutually broadcastable.

    Returns:
        Tensor: The semiring product of all operands.
    """

prod(x, dim, *, keepdim=False) abstractmethod classmethod ¤

Do the product within a tensor along the given dimension.

Parameters:

Name Type Description Default
x Tensor

The input tensor.

required
dim int

The dimension to reduce along.

required
keepdim bool

Whether the dim is kept as a size-1 dim. Defaults to False.

False

Returns:

Name Type Description
Tensor Tensor

The product result.

Source code in cirkit/backend/torch/semiring.py
245
246
247
248
249
250
251
252
253
254
255
256
257
@classmethod
@abstractmethod
def prod(cls, x: Tensor, dim: int, *, keepdim: bool = False) -> Tensor:
    """Reduce a tensor along a single dimension with the semiring product.

    Args:
        x: The tensor to reduce.
        dim: The dimension to reduce along.
        keepdim: If True, the reduced dimension is kept with size 1.
            Defaults to False.

    Returns:
        Tensor: The semiring product of ``x`` along ``dim``.
    """

register(name) staticmethod ¤

Register a concrete semiring implementation by its name.

Parameters:

Name Type Description Default
name str

The name to register.

required

Returns:

Type Description
Callable[[Semiring], Semiring]

Callable[[Semiring], Semiring]: The class decorator to register a subclass.

Source code in cirkit/backend/torch/semiring.py
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
@staticmethod
def register(name: str) -> Callable[[Semiring], Semiring]:
    """Build a class decorator that registers a semiring under ``name``.

    Args:
        name: The name under which the decorated class is registered.

    Returns:
        Callable[[Semiring], Semiring]: A class decorator that stores the
        decorated semiring class in the registry and returns it unchanged.
    """

    def _register_cls(semiring_cls: Semiring) -> Semiring:
        """Store ``semiring_cls`` in the registry under the enclosing ``name``.

        Args:
            semiring_cls: The semiring subclass being registered.

        Returns:
            Semiring: ``semiring_cls`` itself, so decorating leaves the class intact.
        """
        SemiringImpl._registry[name] = semiring_cls
        return semiring_cls

    return _register_cls

register_map_from(other) classmethod ¤

Register a concrete semiring morphism implementation.

Parameters:

Name Type Description Default
other Semiring

The source semiring.

required

Returns:

Type Description
Callable[[Callable[[Tensor], Tensor]], Callable[[Tensor], Tensor]]

The function decorator to register the morphism.

Source code in cirkit/backend/torch/semiring.py
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
@classmethod
def register_map_from(
    cls, other: Semiring
) -> Callable[[Callable[[Tensor], Tensor]], Callable[[Tensor], Tensor]]:
    """Build a decorator that registers a morphism from ``other`` to this semiring.

    Args:
        other: The source semiring of the morphism.

    Returns:
        Callable[[Callable[[Tensor], Tensor]], Callable[[Tensor], Tensor]]:
        A function decorator that records the decorated function as the morphism
        from ``other`` to this semiring and returns it unchanged.
    """
    # Morphisms are keyed by (source semiring, target semiring).
    key = (other, cls)

    def _register_morphism(
        morphism: Callable[[Tensor], Tensor]
    ) -> Callable[[Tensor], Tensor]:
        """Store ``morphism`` as the map from ``other`` to this semiring.

        Args:
            morphism: The tensor-to-tensor map between the two semirings.

        Returns:
            Callable[[Tensor], Tensor]: ``morphism`` itself, unchanged.
        """
        SemiringImpl._registry_morphisms[key] = morphism
        return morphism

    return _register_morphism

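sum(x, dim, *, keepdim=False) abstractmethod classmethod ¤

Do the sum within a tensor along the given dimension.

Parameters:

Name Type Description Default
x Tensor

The input tensor.

required
dim int

The dimension to reduce along.

required
keepdim bool

Whether the dim is kept as a size-1 dim. Defaults to False.

False

Returns:

Name Type Description
Tensor Tensor

The sum result.
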
Source code in cirkit/backend/torch/semiring.py
219
220
221
222
223
224
225
226
227
228
229
230
231
@classmethod
@abstractmethod
def sum(cls, x: Tensor, dim: int, *, keepdim: bool = False) -> Tensor:
    """Do the sum within a tensor along the given dimension.

    Args:
        x: The input tensor.
        dim: The dimension to reduce along.
        keepdim: Whether the dim is kept as a size-1 dim. Defaults to False.

    Returns:
        Tensor: The semiring sum of ``x`` along ``dim``.
    """

SumProductSemiring ¤

Bases: SemiringImpl

The linear space computation.

Source code in cirkit/backend/torch/semiring.py
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
@SemiringImpl.register("sum-product")
class SumProductSemiring(SemiringImpl):
    """The linear space computation.

    The semiring sum is ordinary addition and the semiring product is ordinary
    multiplication, both delegated to the corresponding torch operations.
    """

    @classmethod
    def cast(cls, x: Tensor) -> Tensor:
        """Convert a tensor to the dtype this semiring operates on.

        Floating-point tensors are returned as-is. Other non-complex tensors
        (e.g. integer or boolean) are converted to torch's default float dtype.

        Args:
            x: The tensor to convert.

        Returns:
            Tensor: ``x`` with a real floating-point dtype.

        Raises:
            ValueError: If ``x`` has a complex dtype.
        """
        if x.is_complex():
            raise ValueError(
                f"Cannot cast a tensor of type '{x.dtype}' to the '{cls.__name__}'"
            )
        if x.is_floating_point():
            return x
        return x.to(torch.get_default_dtype())

    @classmethod
    def sum(cls, x: Tensor, dim: int, *, keepdim: bool = False) -> Tensor:
        """Reduce ``x`` along ``dim`` with ordinary addition."""
        return torch.sum(x, dim=dim, keepdim=keepdim)

    @classmethod
    def add(cls, *xs: Tensor) -> Tensor:
        """Add broadcastable tensors element-wise."""
        return functools.reduce(torch.add, xs)

    @classmethod
    def prod(cls, x: Tensor, dim: int, *, keepdim: bool = False) -> Tensor:
        """Reduce ``x`` along ``dim`` with ordinary multiplication."""
        return x.prod(dim=dim, keepdim=keepdim)

    @classmethod
    def mul(cls, *xs: Tensor) -> Tensor:
        """Multiply broadcastable tensors element-wise."""
        return functools.reduce(torch.mul, xs)

    @classmethod
    def apply_reduce(
        cls,
        func: EinsumFunc,
        *xs: Tensor,
        dim: int,
        keepdim: bool,
    ) -> Tensor:
        """Apply ``func`` directly to the inputs.

        No rescaling is performed here, so ``dim`` and ``keepdim`` are accepted
        to match the shared ``apply_reduce`` signature but are not used.
        """
        return func(*xs)

add(*xs) classmethod ¤

Source code in cirkit/backend/torch/semiring.py
330
331
332
@classmethod
def add(cls, *xs: Tensor) -> Tensor:
    """Add broadcastable tensors element-wise (the linear-space semiring sum).

    Args:
        *xs: The operands; their shapes must be mutually broadcastable.

    Returns:
        Tensor: The element-wise sum of all operands.
    """
    summed = functools.reduce(torch.add, xs)
    return summed

apply_reduce(func, *xs, dim, keepdim) classmethod ¤

Source code in cirkit/backend/torch/semiring.py
342
343
344
345
346
347
348
349
350
@classmethod
def apply_reduce(
    cls,
    func: EinsumFunc,
    *xs: Tensor,
    dim: int,
    keepdim: bool,
) -> Tensor:
    """Evaluate ``func`` on the inputs without any rescaling.

    ``dim`` and ``keepdim`` are part of the shared signature but are not used
    by this semiring.
    """
    result = func(*xs)
    return result

cast(x) classmethod ¤

Cast a tensor to the data type required by the semiring.

Parameters:

Name Type Description Default
x Tensor

The tensor.

required

Returns:

Name Type Description
Tensor Tensor

The tensor converted to the required data type.

Source code in cirkit/backend/torch/semiring.py
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
@classmethod
def cast(cls, x: Tensor) -> Tensor:
    """Convert a tensor to the dtype required by the linear-space semiring.

    Floating-point tensors pass through untouched. Any other non-complex tensor
    (e.g. integer or boolean) is converted to torch's default float dtype.

    Args:
        x: The tensor to convert.

    Returns:
        Tensor: ``x`` with a real floating-point dtype.

    Raises:
        ValueError: If ``x`` has a complex dtype.
    """
    if x.is_complex():
        raise ValueError(
            f"Cannot cast a tensor of type '{x.dtype}' to the '{cls.__name__}'"
        )
    if x.is_floating_point():
        return x
    return x.to(torch.get_default_dtype())

mul(*xs) classmethod ¤

Source code in cirkit/backend/torch/semiring.py
338
339
340
@classmethod
def mul(cls, *xs: Tensor) -> Tensor:
    """Multiply broadcastable tensors element-wise (the linear-space semiring product).

    Args:
        *xs: The operands; their shapes must be mutually broadcastable.

    Returns:
        Tensor: The element-wise product of all operands.
    """
    product = functools.reduce(torch.mul, xs)
    return product

prod(x, dim, *, keepdim=False) classmethod ¤

Source code in cirkit/backend/torch/semiring.py
334
335
336
@classmethod
def prod(cls, x: Tensor, dim: int, *, keepdim: bool = False) -> Tensor:
    """Reduce ``x`` along ``dim`` with ordinary multiplication.

    Args:
        x: The tensor to reduce.
        dim: The dimension to reduce along.
        keepdim: If True, keep the reduced dimension with size 1.

    Returns:
        Tensor: The product of ``x`` along ``dim``.
    """
    return x.prod(dim=dim, keepdim=keepdim)

sum(x, dim, *, keepdim=False) classmethod ¤

Source code in cirkit/backend/torch/semiring.py
326
327
328
@classmethod
def sum(cls, x: Tensor, dim: int, *, keepdim: bool = False) -> Tensor:
    """Reduce ``x`` along ``dim`` with ordinary addition.

    Args:
        x: The tensor to reduce.
        dim: The dimension to reduce along.
        keepdim: If True, keep the reduced dimension with size 1.

    Returns:
        Tensor: The sum of ``x`` along ``dim``.
    """
    return torch.sum(x, dim=dim, keepdim=keepdim)

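_(x) ¤

Morphism from LSESumSemiring to ComplexLSESumSemiring: it casts the tensor to a complex dtype via ComplexLSESumSemiring.cast.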

Source code in cirkit/backend/torch/semiring.py
512
513
514
@ComplexLSESumSemiring.register_map_from(LSESumSemiring)
def _(x: Tensor) -> Tensor:
    """Map a tensor from ``LSESumSemiring`` to ``ComplexLSESumSemiring``.

    The values are carried over unchanged; only the dtype is promoted to a
    complex dtype by ``ComplexLSESumSemiring.cast``.
    """
    target = ComplexLSESumSemiring
    return target.cast(x)