Skip to content

semiring

semiring ¤

Semiring = type[SemiringImpl] module-attribute ¤

SemiringT = TypeVar('SemiringT', bound=type['SemiringImpl']) module-attribute ¤

Ts = TypeVarTuple('Ts') module-attribute ¤

ComplexLSESumSemiring ¤

Bases: SemiringImpl

The complex log space computation.

Source code in cirkit/backend/torch/semiring.py
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
@SemiringImpl.register("complex-lse-sum")
@final  # type: ignore[misc]
class ComplexLSESumSemiring(SemiringImpl):
    """The complex log space computation.

    Values are complex tensors holding logarithms: the semiring addition is log-sum-exp and the
    semiring multiplication is the plain addition of the (complex) logarithms.
    """

    @classmethod
    def cast(cls, x: Tensor) -> Tensor:
        """Cast a tensor to a complex data type.

        Complex tensors are returned as they are, real floating-point tensors are converted to the
        complex dtype of the same precision, and any other tensor (e.g., an integer one) is
        converted to the complex counterpart of ``torch.get_default_dtype()``.

        Args:
            x (Tensor): The tensor.

        Returns:
            Tensor: The tensor with a complex data type.
        """
        if x.is_complex():
            return x
        if x.is_floating_point():
            return x.to(x.dtype.to_complex())
        default_float_dtype = torch.get_default_dtype()
        return x.to(default_float_dtype.to_complex())

    @classmethod
    def sum(cls, x: Tensor, /, *, dim: int | None = None, keepdim: bool = False) -> Tensor:
        """Compute the semiring sum (log-sum-exp) of the entries of a tensor.

        Args:
            x (Tensor): The input tensor, in log space.
            dim (int | None, optional): The dimension to reduce along. Defaults to None.
            keepdim (bool, optional): Whether the dim is kept as a size-1 dim. Defaults to False.

        Returns:
            Tensor: The log-sum-exp of x along dim.
        """
        return x.logsumexp(dim=dim, keepdim=keepdim)

    @classmethod
    def add(cls, *xs: Tensor) -> Tensor:
        """Add broadcastable tensors in the semiring, i.e., fold them with log-add-exp.

        Args:
            *xs (Tensor): The input tensors, in log space.

        Returns:
            Tensor: The log-add-exp of all inputs.
        """
        return functools.reduce(torch.logaddexp, xs)

    @classmethod
    def prod(
        cls, x: Tensor, /, *, dim: int | Sequence[int] | None = None, keepdim: bool = False
    ) -> Tensor:
        """Compute the semiring product of the entries of a tensor, i.e., sum the logarithms.

        Args:
            x (Tensor): The input tensor, in log space.
            dim (int | Sequence[int] | None, optional): The dimension(s) to reduce along, \
                None for all dims. Defaults to None.
            keepdim (bool, optional): Whether the dim is kept as a size-1 dim. Defaults to False.

        Returns:
            Tensor: The sum of x along dim.
        """
        # Tensor.sum needs dim to be an int or a concrete tuple, so normalize any other sequence
        # (e.g., a list or a range) to a tuple first.
        dim = tuple(dim) if isinstance(dim, Sequence) else dim
        return x.sum(dim=dim, keepdim=keepdim)

    @classmethod
    def mul(cls, *xs: Tensor) -> Tensor:
        """Multiply broadcastable tensors in the semiring, i.e., add their logarithms.

        Args:
            *xs (Tensor): The input tensors, in log space.

        Returns:
            Tensor: The sum of all inputs.
        """
        return functools.reduce(torch.add, xs)

    @classmethod
    def apply_reduce(
        cls,
        func: Callable[[Unpack[Ts]], Tensor],
        *xs: Unpack[Ts],
        dim: int,
        keepdim: bool,
    ) -> Tensor:
        """Apply a sum-like function to complex log-space inputs using the log-sum-exp trick.

        Each input is shifted by the maximum of its real part along ``dim`` and exponentiated.
        ``func`` is evaluated on these linear-space values, and the result is mapped back with a
        NaN-safe complex logarithm plus the sum of the per-input maxima.

        Args:
            func (Callable[[Unpack[Ts]], Tensor]): The sum-like function, evaluated in linear space.
            *xs (Unpack[Ts]): The complex log-space input tensors.
            dim (int): The dimension reduced by func; the maxima are taken along it.
            keepdim (bool): Whether func keeps dim as a size-1 dim.

        Returns:
            Tensor: The result of func, in complex log space.
        """
        # NOTE: Due to usage of intermediate results, they need to be instantiated in lists but not
        #       generators, because generators can't save much if we want to reuse.
        # CAST: Expected tuple of Tensor but got Ts.
        xs = [cast(Tensor, xi) for xi in xs]
        # The shift is clamped to the finite range of the real dtype so that a slice whose
        # maximum is +-inf does not produce inf - inf = NaN in the subtraction below.
        max_xs = [
            torch.clamp(
                torch.amax(xi.real, dim=dim, keepdim=True),
                min=torch.finfo(xi.real.dtype).min,
                max=torch.finfo(xi.real.dtype).max,
            )
            for xi in xs
        ]
        exp_xs = [torch.exp(xi - max_xi) for xi, max_xi in zip(xs, max_xs)]

        # NOTE: exp_x is not tuple, but list still can be unpacked with *.
        # CAST: Expected Ts but got tuple (actually list) of Tensor.
        func_exp_xs = func(*cast(tuple[Unpack[Ts]], exp_xs))

        # TODO: verify the behavior of reduce under torch.compile
        reduced_max_xs = functools.reduce(torch.add, max_xs)  # Do n-1 add instead of n.
        if not keepdim:
            reduced_max_xs = reduced_max_xs.squeeze(dim)  # To match shape of func_exp_x.

        # Compute log(x) and its gradients safely where x is a complex tensor.
        #
        # The problem is that if x = 0 + 0j, then the complex gradient of log(x) yields NaNs.
        # Note that for real non-monotonic circuits this problem cannot be avoided by simply
        # clipping the parameters of e.g., dense layers. In fact, even if we clipped the parameters
        # to be sufficiently far from zero here, cancellations would still arise from negations, which
        # in turn might result in under-flows. This has been observed in float32 for squared
        # non-monotonic PCs with real parameters.
        #
        # To solve this issue, here we use a 'safe' version of the complex logarithm whose gradients
        # are replaced with zero if NaN and to the largest/lowest representable values if +inf/-inf.
        #
        return csafelog(func_exp_xs) + reduced_max_xs

add(*xs) classmethod ¤

Source code in cirkit/backend/torch/semiring.py
413
414
415
@classmethod
def add(cls, *xs: Tensor) -> Tensor:
    """Fold the given tensors pairwise with log-add-exp (semiring addition in log space)."""
    log_add_exp = torch.logaddexp
    return functools.reduce(lambda acc, nxt: log_add_exp(acc, nxt), xs)

apply_reduce(func, *xs, dim, keepdim) classmethod ¤

Source code in cirkit/backend/torch/semiring.py
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
@classmethod
def apply_reduce(
    cls,
    func: Callable[[Unpack[Ts]], Tensor],
    *xs: Unpack[Ts],
    dim: int,
    keepdim: bool,
) -> Tensor:
    """Apply a sum-like function to complex log-space inputs using the log-sum-exp trick.

    Each input is shifted by the maximum of its real part along ``dim`` and exponentiated.
    ``func`` is evaluated on these linear-space values, and the result is mapped back with
    ``csafelog`` plus the sum of the per-input maxima.

    Args:
        func: The sum-like function, evaluated in linear space on the shifted inputs.
        *xs: The complex log-space input tensors.
        dim (int): The dimension reduced by func; the maxima are taken along it.
        keepdim (bool): Whether func keeps dim as a size-1 dim.

    Returns:
        Tensor: ``csafelog(func(exp(x_1 - m_1), ..., exp(x_n - m_n))) + m_1 + ... + m_n``.
    """
    # NOTE: Due to usage of intermediate results, they need to be instantiated in lists but not
    #       generators, because generators can't save much if we want to reuse.
    # CAST: Expected tuple of Tensor but got Ts.
    xs = [cast(Tensor, xi) for xi in xs]
    # Per-input maximum of the real part along dim (kept as a size-1 dim for broadcasting),
    # clamped to the finite range of the real dtype so that a +-inf maximum does not make
    # `xi - max_xi` evaluate to inf - inf = NaN.
    max_xs = [
        torch.clamp(
            torch.amax(xi.real, dim=dim, keepdim=True),
            min=torch.finfo(xi.real.dtype).min,
            max=torch.finfo(xi.real.dtype).max,
        )
        for xi in xs
    ]
    # Subtracting a real shift from a complex tensor only moves the real part (the magnitude).
    exp_xs = [torch.exp(xi - max_xi) for xi, max_xi in zip(xs, max_xs)]

    # NOTE: exp_x is not tuple, but list still can be unpacked with *.
    # CAST: Expected Ts but got tuple (actually list) of Tensor.
    func_exp_xs = func(*cast(tuple[Unpack[Ts]], exp_xs))

    # TODO: verify the behavior of reduce under torch.compile
    reduced_max_xs = functools.reduce(torch.add, max_xs)  # Do n-1 add instead of n.
    if not keepdim:
        reduced_max_xs = reduced_max_xs.squeeze(dim)  # To match shape of func_exp_x.

    # Compute log(x) and its gradients safely where x is a complex tensor.
    #
    # The problem is that if x = 0 + 0j, then the complex gradient of log(x) yields NaNs.
    # Note that for real non-monotonic circuits this problem cannot be avoided by simply
    # clipping the parameters of e.g., dense layers. In fact, even if we clipped the parameters
    # to be sufficiently far from zero here, cancellations would still arise from negations, which
    # in turn might result in under-flows. This has been observed in float32 for squared
    # non-monotonic PCs with real parameters.
    #
    # To solve this issue, here we use a 'safe' version of the complex logarithm whose gradients
    # are replaced with zero if NaN and to the largest/lowest representable values if +inf/-inf.
    #
    return csafelog(func_exp_xs) + reduced_max_xs

cast(x) classmethod ¤

Source code in cirkit/backend/torch/semiring.py
400
401
402
403
404
405
406
407
@classmethod
def cast(cls, x: Tensor) -> Tensor:
    if x.is_complex():
        return x
    if x.is_floating_point():
        return x.to(x.dtype.to_complex())
    default_float_dtype = torch.get_default_dtype()
    return x.to(default_float_dtype.to_complex())

mul(*xs) classmethod ¤

Source code in cirkit/backend/torch/semiring.py
421
422
423
@classmethod
def mul(cls, *xs: Tensor) -> Tensor:
    """Semiring product of the given tensors: in log space this is a running elementwise sum."""
    return functools.reduce(lambda acc, nxt: torch.add(acc, nxt), xs)

prod(x, /, *, dim=None, keepdim=False) classmethod ¤

Source code in cirkit/backend/torch/semiring.py
417
418
419
@classmethod
def prod(cls, x: Tensor, /, *, dim: int | None = None, keepdim: bool = False) -> Tensor:
    return x.sum(dim=dim, keepdim=keepdim)

sum(x, /, *, dim=None, keepdim=False) classmethod ¤

Source code in cirkit/backend/torch/semiring.py
409
410
411
@classmethod
def sum(cls, x: Tensor, /, *, dim: int | None = None, keepdim: bool = False) -> Tensor:
    return x.logsumexp(dim=dim, keepdim=keepdim)

LSESumSemiring ¤

Bases: SemiringImpl

The log space computation.

Source code in cirkit/backend/torch/semiring.py
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
@SemiringImpl.register("lse-sum")
@final  # type: ignore[misc]
class LSESumSemiring(SemiringImpl):
    """The log space computation.

    Values are real tensors holding logarithms: the semiring addition is log-sum-exp and the
    semiring multiplication is the plain addition of logarithms.
    """

    @classmethod
    def cast(cls, x: Tensor) -> Tensor:
        """Cast a tensor to a real floating-point data type.

        Floating-point tensors are returned as they are, and non-complex tensors (e.g., integer
        ones) are converted to ``torch.get_default_dtype()``.

        Args:
            x (Tensor): The tensor.

        Returns:
            Tensor: The tensor with a real floating-point data type.

        Raises:
            ValueError: If x is a complex tensor.
        """
        if x.is_floating_point():
            return x
        if not x.is_complex():
            default_float_dtype = torch.get_default_dtype()
            return x.to(default_float_dtype)
        raise ValueError(f"Cannot cast a tensor of type '{x.dtype}' to the '{cls.__name__}'")

    @classmethod
    def sum(cls, x: Tensor, /, *, dim: int | None = None, keepdim: bool = False) -> Tensor:
        """Compute the semiring sum (log-sum-exp) of the entries of a tensor.

        Args:
            x (Tensor): The input tensor, in log space.
            dim (int | None, optional): The dimension to reduce along. Defaults to None.
            keepdim (bool, optional): Whether the dim is kept as a size-1 dim. Defaults to False.

        Returns:
            Tensor: The log-sum-exp of x along dim.
        """
        return x.logsumexp(dim=dim, keepdim=keepdim)

    @classmethod
    def add(cls, *xs: Tensor) -> Tensor:
        """Add broadcastable tensors in the semiring, i.e., fold them with log-add-exp.

        Args:
            *xs (Tensor): The input tensors, in log space.

        Returns:
            Tensor: The log-add-exp of all inputs.
        """
        return functools.reduce(torch.logaddexp, xs)

    @classmethod
    def prod(
        cls, x: Tensor, /, *, dim: int | Sequence[int] | None = None, keepdim: bool = False
    ) -> Tensor:
        """Compute the semiring product of the entries of a tensor, i.e., sum the logarithms.

        Args:
            x (Tensor): The input tensor, in log space.
            dim (int | Sequence[int] | None, optional): The dimension(s) to reduce along, \
                None for all dims. Defaults to None.
            keepdim (bool, optional): Whether the dim is kept as a size-1 dim. Defaults to False.

        Returns:
            Tensor: The sum of x along dim.
        """
        dim = tuple(dim) if isinstance(dim, Sequence) else dim  # dim must be concrete type for sum.
        return x.sum(dim=dim, keepdim=keepdim)

    @classmethod
    def mul(cls, *xs: Tensor) -> Tensor:
        """Multiply broadcastable tensors in the semiring, i.e., add their logarithms.

        Args:
            *xs (Tensor): The input tensors, in log space.

        Returns:
            Tensor: The sum of all inputs.
        """
        return functools.reduce(torch.add, xs)

    @classmethod
    def apply_reduce(
        cls,
        func: Callable[[Unpack[Ts]], Tensor],
        *xs: Unpack[Ts],
        dim: int,
        keepdim: bool,
    ) -> Tensor:
        """Apply a sum-like function to log-space inputs using the log-sum-exp trick.

        Each input is shifted by its maximum along ``dim`` and exponentiated. ``func`` is
        evaluated on these linear-space values, and the result is mapped back to log space by
        taking its logarithm and adding the sum of the per-input maxima.

        Args:
            func (Callable[[Unpack[Ts]], Tensor]): The sum-like function, evaluated in linear space.
            *xs (Unpack[Ts]): The log-space input tensors.
            dim (int): The dimension reduced by func; the maxima are taken along it.
            keepdim (bool): Whether func keeps dim as a size-1 dim.

        Returns:
            Tensor: The result of func, in log space.
        """
        # NOTE: Due to usage of intermediate results, they need to be instantiated in lists but not
        #       generators, because generators can't save much if we want to reuse.
        # CAST: Expected tuple of Tensor but got Ts.
        xs = [cast(Tensor, xi) for xi in xs]
        # The shift is clamped to the finite range of the dtype so that a slice whose maximum is
        # +-inf does not produce inf - inf = NaN in the subtraction below.
        max_xs = [
            torch.clamp(
                torch.amax(xi, dim=dim, keepdim=True),
                min=torch.finfo(xi.dtype).min,
                max=torch.finfo(xi.dtype).max,
            )
            for xi in xs
        ]
        exp_xs = [torch.exp(xi - max_xi) for xi, max_xi in zip(xs, max_xs)]

        # NOTE: exp_x is not tuple, but list still can be unpacked with *.
        # CAST: Expected Ts but got tuple (actually list) of Tensor.
        func_exp_xs = func(*cast(tuple[Unpack[Ts]], exp_xs))

        # TODO: verify the behavior of reduce under torch.compile
        reduced_max_xs = functools.reduce(torch.add, max_xs)  # Do n-1 add instead of n.
        if not keepdim:
            reduced_max_xs = reduced_max_xs.squeeze(dim)  # To match shape of func_exp_x.
        return torch.log(func_exp_xs) + reduced_max_xs

add(*xs) classmethod ¤

Source code in cirkit/backend/torch/semiring.py
349
350
351
@classmethod
def add(cls, *xs: Tensor) -> Tensor:
    """Fold the given log-space tensors pairwise with log-add-exp (semiring addition)."""
    log_add_exp = torch.logaddexp
    return functools.reduce(lambda acc, nxt: log_add_exp(acc, nxt), xs)

apply_reduce(func, *xs, dim, keepdim) classmethod ¤

Source code in cirkit/backend/torch/semiring.py
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
@classmethod
def apply_reduce(
    cls,
    func: Callable[[Unpack[Ts]], Tensor],
    *xs: Unpack[Ts],
    dim: int,
    keepdim: bool,
) -> Tensor:
    """Apply a sum-like function to log-space inputs using the log-sum-exp trick.

    Each input is shifted by its maximum along ``dim`` and exponentiated. ``func`` is evaluated
    on these linear-space values, and the result is mapped back to log space by taking its
    logarithm and adding the sum of the per-input maxima.

    Args:
        func: The sum-like function, evaluated in linear space on the shifted inputs.
        *xs: The log-space input tensors.
        dim (int): The dimension reduced by func; the maxima are taken along it.
        keepdim (bool): Whether func keeps dim as a size-1 dim.

    Returns:
        Tensor: ``log(func(exp(x_1 - m_1), ..., exp(x_n - m_n))) + m_1 + ... + m_n``.
    """
    # NOTE: Due to usage of intermediate results, they need to be instantiated in lists but not
    #       generators, because generators can't save much if we want to reuse.
    # CAST: Expected tuple of Tensor but got Ts.
    xs = [cast(Tensor, xi) for xi in xs]
    # Per-input maximum along dim (kept as a size-1 dim for broadcasting), clamped to the finite
    # range of the dtype so that a +-inf maximum does not make `xi - max_xi` evaluate to
    # inf - inf = NaN.
    max_xs = [
        torch.clamp(
            torch.amax(xi, dim=dim, keepdim=True),
            min=torch.finfo(xi.dtype).min,
            max=torch.finfo(xi.dtype).max,
        )
        for xi in xs
    ]
    exp_xs = [torch.exp(xi - max_xi) for xi, max_xi in zip(xs, max_xs)]

    # NOTE: exp_x is not tuple, but list still can be unpacked with *.
    # CAST: Expected Ts but got tuple (actually list) of Tensor.
    func_exp_xs = func(*cast(tuple[Unpack[Ts]], exp_xs))

    # TODO: verify the behavior of reduce under torch.compile
    reduced_max_xs = functools.reduce(torch.add, max_xs)  # Do n-1 add instead of n.
    if not keepdim:
        reduced_max_xs = reduced_max_xs.squeeze(dim)  # To match shape of func_exp_x.
    # Undo the shift: back to log space plus the total of the subtracted maxima.
    return torch.log(func_exp_xs) + reduced_max_xs

cast(x) classmethod ¤

Source code in cirkit/backend/torch/semiring.py
336
337
338
339
340
341
342
343
@classmethod
def cast(cls, x: Tensor) -> Tensor:
    """Return ``x`` with a real floating-point dtype; complex inputs are rejected.

    Floating-point inputs pass through untouched, while other real inputs (e.g., integers) are
    converted to the default float dtype.
    """
    if x.is_complex():
        raise ValueError(f"Cannot cast a tensor of type '{x.dtype}' to the '{cls.__name__}'")
    if x.is_floating_point():
        return x
    return x.to(torch.get_default_dtype())

mul(*xs) classmethod ¤

Source code in cirkit/backend/torch/semiring.py
358
359
360
@classmethod
def mul(cls, *xs: Tensor) -> Tensor:
    """Semiring product of log-space tensors, computed as a running elementwise sum."""
    return functools.reduce(lambda acc, nxt: torch.add(acc, nxt), xs)

prod(x, /, *, dim=None, keepdim=False) classmethod ¤

Source code in cirkit/backend/torch/semiring.py
353
354
355
356
@classmethod
def prod(cls, x: Tensor, /, *, dim: int | None = None, keepdim: bool = False) -> Tensor:
    dim = tuple(dim) if isinstance(dim, Sequence) else dim  # dim must be concrete type for sum.
    return x.sum(dim=dim, keepdim=keepdim)

sum(x, /, *, dim=None, keepdim=False) classmethod ¤

Source code in cirkit/backend/torch/semiring.py
345
346
347
@classmethod
def sum(cls, x: Tensor, /, *, dim: int | None = None, keepdim: bool = False) -> Tensor:
    return x.logsumexp(dim=dim, keepdim=keepdim)

SemiringImpl ¤

Bases: ABC

The abstract base class for semiring implementations.

For numerical precision, the units in the computational graph may hold values in a space other than linear space, e.g., log space. This class therefore provides a unified interface for the computations, so that each computation is carried out in the space suited to the implementation regardless of the global setting.

Source code in cirkit/backend/torch/semiring.py
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
class SemiringImpl(ABC):
    """The abstract base class for semiring implementations.

    For numerical precision, the units in the computational graph may hold values in a space \
    other than linear space, e.g., log space. This class therefore provides a unified interface \
    for the computations, so that each computation is carried out in the space suited to the \
    implementation regardless of the global setting.
    """

    # A registry from semiring string identifiers to semiring class implementations
    _registry: ClassVar[dict[str, type["SemiringImpl"]]] = {}

    # A registry of morphisms between semiring class implementations, keyed by the pair
    # (source semiring, target semiring).
    _registry_morphisms: ClassVar[
        dict[tuple[type["SemiringImpl"], type["SemiringImpl"]], Callable[[Tensor], Tensor]]
    ] = {}

    @final
    @staticmethod
    def register(name: str) -> Callable[[SemiringT], SemiringT]:
        """Register a concrete semiring implementation by its name.

        Registering a class under a name that is already registered replaces the previous entry.

        Args:
            name (str): The name to register.

        Returns:
            Callable[[SemiringT], SemiringT]: The class decorator to register a subclass.
        """

        def _decorator(cls: SemiringT) -> SemiringT:
            """Register a concrete semiring implementation by its name.

            Args:
                cls (SemiringT): The semiring subclass to register.

            Returns:
                SemiringT: The class passed in.

            Raises:
                AssertionError: If cls is not marked as final (only checked when assertions are
                    enabled).
            """
            # CAST: getattr gives Any.
            assert cast(
                bool, getattr(cls, "__final__", False)
            ), "Subclasses of SemiringImpl should be final."
            SemiringImpl._registry[name] = cls
            return cls

        return _decorator

    @final
    @classmethod
    def register_map_from(
        cls, other: SemiringT
    ) -> Callable[[Callable[[Tensor], Tensor]], Callable[[Tensor], Tensor]]:
        """Register a concrete semiring morphism implementation from ``other`` to `this` semiring.

        Args:
            other: The source semiring.

        Returns:
            Callable[[Callable[[Tensor], Tensor]], Callable[[Tensor], Tensor]]: The function decorator to register the morphism.
        """

        def _decorator(func: Callable[[Tensor], Tensor]) -> Callable[[Tensor], Tensor]:
            """Register a concrete semiring morphism implementation.

            Args:
                func (Callable[[Tensor], Tensor]): The morphism between semirings to register.

            Returns:
                Callable[[Tensor], Tensor]: The morphism passed in.
            """
            SemiringImpl._registry_morphisms[(other, cls)] = func
            return func

        return _decorator

    @final
    @staticmethod
    def list() -> Iterable[str]:
        """List all semiring names registered.

        Returns:
            Iterable[str]: An iterable over all names available.
        """
        return iter(SemiringImpl._registry)

    @final
    @staticmethod
    def from_name(name: str) -> "Semiring":
        """Get a semiring by its registered name.

        Args:
            name (str): The name to probe.

        Returns:
            Semiring: The retrieved concrete Semiring.

        Raises:
            IndexError: If no semiring has been registered under the given name.
        """
        if name not in SemiringImpl._registry:
            raise IndexError(
                f"Unknown semiring '{name}'. Use @SemiringImpl.register(<name>) to register a new semiring"
            )
        return SemiringImpl._registry[name]

    @final
    @classmethod
    def map_from(cls, x: Tensor, semiring: SemiringT) -> Tensor:
        """Map a tensor from the given semiring to `this` semiring.

        If the source semiring is `this` semiring, the tensor is returned unchanged. Otherwise
        the morphism registered with ``register_map_from`` for the pair (source, `this`) is used.

        Args:
            x (Tensor): The tensor, with values in the source semiring.
            semiring (SemiringT): The source semiring.

        Returns:
            Tensor: The tensor mapped to `this` semiring.

        Raises:
            NotImplementedError: If no morphism from the source semiring is registered.
        """
        if cls == semiring:
            return x
        func: Callable[[Tensor], Tensor] | None = SemiringImpl._registry_morphisms.get(
            (semiring, cls), None
        )
        if func is None:
            raise NotImplementedError(
                f"Semiring map from '{semiring.__name__}' to '{cls.__name__}' is not implemented"
            )
        return func(x)

    @final
    def __new__(cls) -> "SemiringImpl":
        """Raise an error when this class is instantiated.

        Raises:
            TypeError: When this class is instantiated.

        Returns:
            SemiringImpl: This method never returns.
        """
        raise TypeError("This class cannot be instantiated")

    @classmethod
    def einsum(
        cls,
        equation: str,
        *,
        inputs: tuple[Tensor, ...],
        operands: tuple[Tensor, ...],
        dim: int,
        keepdim: bool,
    ) -> Tensor:
        """Evaluate ``torch.einsum`` over inputs and operands within `this` semiring.

        The operands are cast with ``cast`` and then used as they are. The inputs go through
        ``apply_reduce``, which may transform them (e.g., out of log space) before the einsum
        is evaluated.

        Args:
            equation (str): The einsum equation; the inputs come first, followed by the operands.
            inputs (tuple[Tensor, ...]): The tensors passed through apply_reduce.
            operands (tuple[Tensor, ...]): The extra tensors appended after the inputs.
            dim (int): The dimension reduced by the einsum, forwarded to apply_reduce.
            keepdim (bool): Whether the einsum keeps dim as a size-1 dim, forwarded to apply_reduce.

        Returns:
            Tensor: The result of apply_reduce.
        """
        operands = tuple(cls.cast(opd) for opd in operands)

        def _einsum_func(*xs: Tensor) -> Tensor:
            return torch.einsum(equation, *xs, *operands)

        return cls.apply_reduce(_einsum_func, *inputs, dim=dim, keepdim=keepdim)

    # NOTE: Subclasses should not touch any of the above final static methods but should implement
    #       all the following abstract class methods, and subclasses should be @final.

    @classmethod
    @abstractmethod
    def cast(cls, x: Tensor) -> Tensor:
        """Cast a tensor to the data type required by `this` semiring.

        Args:
            x: The tensor.

        Returns:
            Tensor: The tensor converted to the required data type.
        """

    @classmethod
    @abstractmethod
    def sum(cls, x: Tensor, /, *, dim: int | None = None, keepdim: bool = False) -> Tensor:
        """Do the semiring sum within a tensor on the given dim.

        Args:
            x (Tensor): The input tensor.
            dim (int | None, optional): The dimension to reduce along. Defaults to None.
            keepdim (bool, optional): Whether the dim is kept as a size-1 dim. Defaults to False.

        Returns:
            Tensor: The sum result.
        """

    @classmethod
    @abstractmethod
    def add(cls, *xs: Tensor) -> Tensor:
        """Add tensors in `this` semiring.

        Args:
            *xs (Tensor): The input tensors.

        Returns:
            Tensor: The addition result.
        """

    @classmethod
    @abstractmethod
    def prod(
        cls, x: Tensor, /, *, dim: int | Sequence[int] | None = None, keepdim: bool = False
    ) -> Tensor:
        """Do the product within a tensor on given dim(s).

        Args:
            x (Tensor): The input tensor.
            dim (Optional[Union[int, Sequence[int]]], optional): The dimension(s) to reduce along, \
                None for all dims. Defaults to None.
            keepdim (bool, optional): Whether the dim is kept as a size-1 dim. Defaults to False.

        Returns:
            Tensor: The product result.
        """

    @classmethod
    @abstractmethod
    def mul(cls, *xs: Tensor) -> Tensor:
        """Multiply broadcastable tensors.

        Args:
            *xs (Tensor): The input tensors, should have broadcastable shapes.

        Returns:
            Tensor: The multiply result.
        """

    # TODO: it's difficult to bound a variadic to Tensor with TypeVars, we can only use unbounded
    #       Unpack[TypeVarTuple].

    @classmethod
    @abstractmethod
    def apply_reduce(
        cls,
        func: Callable[[Unpack[Ts]], Tensor],
        *xs: Unpack[Ts],
        dim: int,
        keepdim: bool,
    ) -> Tensor:
        """Apply a sum-like function to the tensor(s).

        The sum units may perform not just a plain sum, but also a weighted sum or even an \
        einsum. In fact, func can be any function that is linear in each input.

        func is expected to always compute in linear space: thanks to numerical tricks, only \
        relatively significant numbers contribute to the final answer, so underflow does not \
        matter much. However, the input/output values may still be in another space and need \
        to be projected here.

        Args:
            func (Callable[[Unpack[Ts]], Tensor]): The sum-like function to be applied.
            *xs (Unpack[Ts]): The input tensors. Type expected to be Tensor.
            dim (int): The dimension along which the values are \
                correlated and must be scaled together, i.e., the dim to sum along. This should \
                match the actual operation done by func. The same dim is shared among all inputs.
            keepdim (bool): Whether the dim is kept as a size-1 dim, should match the actual \
                operation done by func.

        Returns:
            Tensor: The sum result.
        """

_registry = {} class-attribute ¤

_registry_morphisms = {} class-attribute ¤

__new__() ¤

Raise an error when this class is instantiated.

Raises:

Type Description
TypeError

When this class is instantiated.

Returns:

Name Type Description
SemiringImpl SemiringImpl

This method never returns.

Source code in cirkit/backend/torch/semiring.py
141
142
143
144
145
146
147
148
149
150
151
@final
def __new__(cls) -> "SemiringImpl":
    """Refuse construction: semirings are used only through their class-level methods.

    Raises:
        TypeError: Always, whenever an instance is requested.

    Returns:
        SemiringImpl: Never; this method always raises.
    """
    message = "This class cannot be instantiated"
    raise TypeError(message)

add(*xs) abstractmethod classmethod ¤

Parameters:

Name Type Description Default
*xs Tensor
()

Returns:

Source code in cirkit/backend/torch/semiring.py
199
200
201
202
203
204
205
206
207
208
209
@classmethod
@abstractmethod
def add(cls, *xs: Tensor) -> Tensor:
    """Add tensors in `this` semiring.

    Args:
        *xs (Tensor): The input tensors.

    Returns:
        Tensor: The addition result.
    """

apply_reduce(func, *xs, dim, keepdim) abstractmethod classmethod ¤

Apply a sum-like function to the tensor(s).

The sum units may perform not just a plain sum, but also a weighted sum or even an einsum. In fact, func can be any function that is linear in each input.

func is expected to always compute in linear space: thanks to numerical tricks, only relatively significant numbers contribute to the final answer, so underflow does not matter much. However, the input/output values may still be in another space and need to be projected here.

Parameters:

Name Type Description Default
func Callable[[Unpack[Ts]], Tensor]

The sum-like function to be applied.

required
*xs Unpack[Ts]

The input tensors. Type expected to be Tensor.

()
dim int

The dimension along which the values are correlated and must be scaled together, i.e., the dim to sum along. This should match the actual operation done by func. The same dim is shared among all inputs.

required
keepdim bool

Whether the dim is kept as a size-1 dim, should match the actual operation done by func.

required

Returns:

Name Type Description
Tensor Tensor

The sum result.

Source code in cirkit/backend/torch/semiring.py
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
@classmethod
@abstractmethod
def apply_reduce(
    cls,
    func: Callable[[Unpack[Ts]], Tensor],
    *xs: Unpack[Ts],
    dim: int,
    keepdim: bool,
) -> Tensor:
    """Apply a sum-like function to the tensor(s).

    The sum units may perform not just a plain sum, but also a weighted sum or even an \
    einsum. In fact, func can be any function that is linear in each input.

    func is expected to always compute in linear space: thanks to numerical tricks, only \
    relatively significant numbers contribute to the final answer, so underflow does not \
    matter much. However, the input/output values may still be in another space and need \
    to be projected here.

    Args:
        func (Callable[[Unpack[Ts]], Tensor]): The sum-like function to be applied.
        *xs (Unpack[Ts]): The input tensors. Type expected to be Tensor.
        dim (int): The dimension along which the values are \
            correlated and must be scaled together, i.e., the dim to sum along. This should \
            match the actual operation done by func. The same dim is shared among all inputs.
        keepdim (bool): Whether the dim is kept as a size-1 dim, should match the actual \
            operation done by func.

    Returns:
        Tensor: The sum result.
    """

cast(x) abstractmethod classmethod ¤

Cast a tensor to the data type required by this semiring.

Parameters:

Name Type Description Default
x Tensor

The tensor.

required

Returns:

Name Type Description
Tensor Tensor

The tensor converted to the required data type.

Source code in cirkit/backend/torch/semiring.py
173
174
175
176
177
178
179
180
181
182
183
@classmethod
@abstractmethod
def cast(cls, x: Tensor) -> Tensor:
    """Cast a tensor to the data type required by `this` semiring.

    Implementations may raise if the tensor's data type cannot be represented in `this`
    semiring (e.g., a complex tensor in a real log-space semiring).

    Args:
        x: The tensor.

    Returns:
        Tensor: The tensor converted to the required data type.
    """

einsum(equation, *, inputs, operands, dim, keepdim) classmethod ¤

Source code in cirkit/backend/torch/semiring.py
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
@classmethod
def einsum(
    cls,
    equation: str,
    *,
    inputs: tuple[Tensor, ...],
    operands: tuple[Tensor, ...],
    dim: int,
    keepdim: bool,
) -> Tensor:
    """Run ``torch.einsum`` over ``inputs`` followed by ``operands`` inside the semiring.

    The operands are cast once with ``cls.cast`` and captured as they are. The inputs are routed
    through ``cls.apply_reduce``, which decides how they are projected before the contraction.

    Args:
        equation: The einsum equation, listing the inputs before the operands.
        inputs: Tensors handed to apply_reduce.
        operands: Extra tensors appended after the inputs.
        dim: The reduced dimension, forwarded to apply_reduce.
        keepdim: Whether the reduced dimension is kept, forwarded to apply_reduce.

    Returns:
        Tensor: Whatever apply_reduce produces.
    """
    casted = tuple(map(cls.cast, operands))

    def _contract(*tensors: Tensor) -> Tensor:
        return torch.einsum(equation, *tensors, *casted)

    return cls.apply_reduce(_contract, *inputs, dim=dim, keepdim=keepdim)

from_name(name) staticmethod ¤

Get a semiring by its registered name.

Parameters:

Name Type Description Default
name str

The name to probe.

required

Returns:

Name Type Description
Semiring SemiringT

The retrieved concrete Semiring.

Source code in cirkit/backend/torch/semiring.py
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
@final
@staticmethod
def from_name(name: str) -> "Semiring":
    """Get a semiring by its registered name.

    Args:
        name (str): The name to probe.

    Returns:
        Semiring: The retrieved concrete Semiring.

    Raises:
        IndexError: If no semiring has been registered under the given name.
    """
    # The return type is the concrete `Semiring` alias (quoted, as it is defined after
    # SemiringImpl), not the SemiringT TypeVar, which would be unbound here.
    if name not in SemiringImpl._registry:
        raise IndexError(
            f"Unknown semiring '{name}'. Use @SemiringImpl.register(<name>) to register a new semiring"
        )
    return SemiringImpl._registry[name]

list() staticmethod ¤

List all semiring names registered.

Returns:

Type Description
Iterable[str]

Iterable[str]: An iterable over all names available.

Source code in cirkit/backend/torch/semiring.py
91
92
93
94
95
96
97
98
99
@final
@staticmethod
def list() -> Iterable[str]:
    """Enumerate the names under which semirings have been registered.

    Returns:
        Iterable[str]: An iterator over the registered names.
    """
    registered = SemiringImpl._registry
    return iter(registered.keys())

map_from(x, semiring) classmethod ¤

Map a tensor from the given semiring to this semiring.

Parameters:

Name Type Description Default
x Tensor
required
semiring SemiringT
required

Returns:

Source code in cirkit/backend/torch/semiring.py
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
@final
@classmethod
def map_from(cls, x: Tensor, semiring: SemiringT) -> Tensor:
    """Map a tensor from the given semiring to `this` semiring.

    The identity map is used when the source is `this` semiring. Otherwise the morphism registered
    for the pair (source, `this`) is looked up and applied.

    Args:
        x (Tensor): The tensor, with values in the source semiring.
        semiring (SemiringT): The source semiring.

    Returns:
        Tensor: The tensor expressed in `this` semiring.

    Raises:
        NotImplementedError: If no morphism from the source semiring has been registered.
    """
    if cls == semiring:
        return x
    morphism = SemiringImpl._registry_morphisms.get((semiring, cls))
    if morphism is not None:
        return morphism(x)
    raise NotImplementedError(
        f"Semiring map from '{semiring.__name__}' to '{cls.__name__}' is not implemented"
    )

mul(*xs) abstractmethod classmethod ¤

Multiply broadcastable tensors.

Parameters:

Name Type Description Default
*xs Tensor

The input tensors, should have broadcastable shapes.

()

Returns:

Name Type Description
Tensor Tensor

The multiply result.

Source code in cirkit/backend/torch/semiring.py
228
229
230
231
232
233
234
235
236
237
238
@classmethod
@abstractmethod
def mul(cls, *xs: Tensor) -> Tensor:
    """Multiply broadcastable tensors in `this` semiring.

    Args:
        *xs (Tensor): The input tensors, should have broadcastable shapes.

    Returns:
        Tensor: The multiply result.
    """

prod(x, /, *, dim=None, keepdim=False) abstractmethod classmethod ¤

Do the product within a tensor on given dim(s).

Parameters:

Name Type Description Default
x Tensor

The input tensor.

required
dim Optional[Union[int, Sequence[int]]]

The dimension(s) to reduce along, None for all dims. Defaults to None.

None
keepdim bool

Whether the dim is kept as a size-1 dim. Defaults to False.

False

Returns:

Name Type Description
Tensor Tensor

The product result.

Source code in cirkit/backend/torch/semiring.py
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
@classmethod
@abstractmethod
def prod(
    cls, x: Tensor, /, *, dim: int | Sequence[int] | None = None, keepdim: bool = False
) -> Tensor:
    """Compute the semiring product of the entries of a tensor along the given dim(s).

    Args:
        x (Tensor): The input tensor.
        dim (Optional[Union[int, Sequence[int]]], optional): The dimension(s) to reduce along, \
            None for all dims. Defaults to None.
        keepdim (bool, optional): Whether the dim is kept as a size-1 dim. Defaults to False.

    Returns:
        Tensor: The product result.
    """

register(name) staticmethod ¤

Register a concrete semiring implementation by its name.

Parameters:

Name Type Description Default
name str

The name to register.

required

Returns:

Type Description
Callable[[SemiringT], SemiringT]

Callable[[SemiringT], SemiringT]: The class decorator to register a subclass.

Source code in cirkit/backend/torch/semiring.py
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
@final
@staticmethod
def register(name: str) -> Callable[[SemiringT], SemiringT]:
    """Register a concrete semiring implementation by its name.

    Registering a class under a name that is already registered replaces the previous entry.

    Args:
        name (str): The name to register.

    Returns:
        Callable[[SemiringT], SemiringT]: The class decorator to register a subclass.
    """

    # Typed with the SemiringT TypeVar (not the Semiring alias) so that the decorator matches
    # the declared return type and preserves the decorated class's exact type.
    def _decorator(cls: SemiringT) -> SemiringT:
        """Register a concrete semiring implementation by its name.

        Args:
            cls (SemiringT): The semiring subclass to register.

        Returns:
            SemiringT: The class passed in.

        Raises:
            AssertionError: If cls is not marked as final (only checked when assertions are
                enabled).
        """
        # CAST: getattr gives Any.
        assert cast(
            bool, getattr(cls, "__final__", False)
        ), "Subclasses of SemiringImpl should be final."
        SemiringImpl._registry[name] = cls
        return cls

    return _decorator

register_map_from(other) classmethod ¤

Register a concrete semiring morphism implementation.

Parameters:

Name Type Description Default
other SemiringT

The source semiring.

required

Returns:

Type Description
Callable[[Callable[[Tensor], Tensor]], Callable[[Tensor], Tensor]]

The function decorator to register the morphism.

Source code in cirkit/backend/torch/semiring.py
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
@final
@classmethod
def register_map_from(
    cls, other: SemiringT
) -> Callable[[Callable[[Tensor], Tensor]], Callable[[Tensor], Tensor]]:
    """Build a decorator that registers a morphism from ``other`` into this semiring.

    Args:
        other (SemiringT): The source semiring of the morphism.

    Returns:
        Callable[[Callable[[Tensor], Tensor]], Callable[[Tensor], Tensor]]: The function
            decorator to register the morphism.
    """
    # The morphism registry is keyed by (source semiring, target semiring); `cls` is the target.
    key = (other, cls)

    def _decorator(func: Callable[[Tensor], Tensor]) -> Callable[[Tensor], Tensor]:
        """Store ``func`` as the morphism for the (source, target) pair.

        Args:
            func (Callable[[Tensor], Tensor]): The morphism between semirings to register.

        Returns:
            Callable[[Tensor], Tensor]: The same function, unchanged.
        """
        SemiringImpl._registry_morphisms[key] = func
        return func

    return _decorator

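sum(x, /, *, dim=None, keepdim=False) abstractmethod classmethod ¤

Do the semiring sum within a tensor on the given dim.

Parameters:

Name Type Description Default
x Tensor

The input tensor.

required
dim int | None

The dimension to reduce along, None for all dims. Defaults to None.

None
keepdim bool

Whether the dim is kept as a size-1 dim. Defaults to False.

False

Returns:

Name Type Description
Tensor Tensor

The sum result.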

Source code in cirkit/backend/torch/semiring.py
185
186
187
188
189
190
191
192
193
194
195
196
197
@classmethod
@abstractmethod
def sum(cls, x: Tensor, /, *, dim: int | None = None, keepdim: bool = False) -> Tensor:
    """

    Args:
        x:
        dim:
        keepdim:

    Returns:

    """

SumProductSemiring ¤

Bases: SemiringImpl

The linear space computation.

Source code in cirkit/backend/torch/semiring.py
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
@SemiringImpl.register("sum-product")
@final  # type: ignore[misc]
class SumProductSemiring(SemiringImpl):
    """The linear space computation.

    The semiring sum is ordinary addition and the semiring product is ordinary
    multiplication, on real floating-point tensors.
    """

    @classmethod
    def cast(cls, x: Tensor) -> Tensor:
        """Cast a tensor to the data type required by the semiring.

        Floating-point tensors are returned unchanged; other non-complex tensors (e.g. integer
        ones) are converted to the default floating-point dtype.

        Args:
            x (Tensor): The tensor.

        Returns:
            Tensor: The tensor converted to the required data type.

        Raises:
            ValueError: If the tensor is complex.
        """
        if x.is_floating_point():
            return x
        if not x.is_complex():
            default_float_dtype = torch.get_default_dtype()
            return x.to(default_float_dtype)
        raise ValueError(f"Cannot cast a tensor of type '{x.dtype}' to the '{cls.__name__}'")

    @classmethod
    def sum(cls, x: Tensor, /, *, dim: int | None = None, keepdim: bool = False) -> Tensor:
        """Sum the entries of a tensor along ``dim`` (all dims when None)."""
        return x.sum(dim=dim, keepdim=keepdim)

    @classmethod
    def add(cls, *xs: Tensor) -> Tensor:
        """Add broadcastable tensors element-wise.

        Args:
            *xs (Tensor): The input tensors, should have broadcastable shapes.

        Returns:
            Tensor: The sum of all inputs.
        """
        # Must `return` the reduction: `raise`-ing a Tensor fails with a TypeError.
        return functools.reduce(torch.add, xs)

    @classmethod
    def prod(
        cls, x: Tensor, /, *, dim: int | Sequence[int] | None = None, keepdim: bool = False
    ) -> Tensor:
        """Multiply the entries of a tensor along the given dim(s).

        Args:
            x (Tensor): The input tensor.
            dim (Optional[Union[int, Sequence[int]]], optional): The dimension(s) to reduce
                along, None for all dims. Defaults to None.
            keepdim (bool, optional): Whether the reduced dims are kept as size-1 dims.
                Defaults to False.

        Returns:
            Tensor: The product result.
        """
        # torch.prod accepts only a single int dim (no None, range or tuple), so the other
        # cases are handled explicitly.
        if dim is None:
            # Reduce everything at once, then restore size-1 dims if requested.
            result = torch.prod(x)
            return result.reshape((1,) * x.ndim) if keepdim else result
        if isinstance(dim, int):
            return torch.prod(x, dim=dim, keepdim=keepdim)
        # Several dims: normalize negative indices against the ORIGINAL rank, then reduce from
        # the highest dim down so dropping a dim never shifts one still to be reduced.
        ndim = x.ndim
        for d in sorted({d + ndim if d < 0 else d for d in dim}, reverse=True):
            x = torch.prod(x, dim=d, keepdim=keepdim)
        return x

    @classmethod
    def mul(cls, *xs: Tensor) -> Tensor:
        """Multiply broadcastable tensors element-wise."""
        return functools.reduce(torch.mul, xs)

    @classmethod
    def apply_reduce(
        cls,
        func: Callable[[Unpack[Ts]], Tensor],
        *xs: Unpack[Ts],
        dim: int,
        keepdim: bool,
    ) -> Tensor:
        """Apply ``func`` directly to the inputs.

        ``dim`` and ``keepdim`` are accepted for interface compatibility but are not used here.
        """
        return func(*xs)

add(*xs) classmethod ¤

Source code in cirkit/backend/torch/semiring.py
306
307
308
@classmethod
def add(cls, *xs: Tensor) -> Tensor:
    """Add broadcastable tensors element-wise in linear space.

    Args:
        *xs (Tensor): The input tensors, should have broadcastable shapes.

    Returns:
        Tensor: The sum of all inputs.
    """
    # Must `return` the reduction: `raise`-ing a Tensor fails with a TypeError
    # ("exceptions must derive from BaseException") instead of producing the sum.
    return functools.reduce(torch.add, xs)

apply_reduce(func, *xs, dim, keepdim) classmethod ¤

Source code in cirkit/backend/torch/semiring.py
320
321
322
323
324
325
326
327
328
@classmethod
def apply_reduce(
    cls,
    func: Callable[[Unpack[Ts]], Tensor],
    *xs: Unpack[Ts],
    dim: int,
    keepdim: bool,
) -> Tensor:
    """Evaluate ``func`` on the inputs as-is.

    ``dim`` and ``keepdim`` are part of the shared signature but are not used here.

    Returns:
        Tensor: Whatever ``func`` returns.
    """
    result = func(*xs)
    return result

cast(x) classmethod ¤

Cast a tensor to the data type required by the semiring.

Parameters:

Name Type Description Default
x Tensor

The tensor.

required

Returns:

Name Type Description
Tensor Tensor

The tensor converted to the required data type.

Source code in cirkit/backend/torch/semiring.py
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
@classmethod
def cast(cls, x: Tensor) -> Tensor:
    """Convert a tensor to a real floating-point dtype usable by this semiring.

    Args:
        x (Tensor): The tensor.

    Returns:
        Tensor: ``x`` itself when already floating-point, otherwise ``x`` converted to the
            default floating-point dtype.

    Raises:
        ValueError: If ``x`` is complex.
    """
    # Complex tensors are never floating-point, so rejecting them first is equivalent.
    if x.is_complex():
        raise ValueError(f"Cannot cast a tensor of type '{x.dtype}' to the '{cls.__name__}'")
    if not x.is_floating_point():
        return x.to(torch.get_default_dtype())
    return x

mul(*xs) classmethod ¤

Source code in cirkit/backend/torch/semiring.py
316
317
318
@classmethod
def mul(cls, *xs: Tensor) -> Tensor:
    """Multiply broadcastable tensors element-wise in linear space.

    Args:
        *xs (Tensor): The input tensors, should have broadcastable shapes.

    Returns:
        Tensor: The element-wise product of all inputs.
    """
    multiply = torch.mul
    return functools.reduce(multiply, xs)

prod(x, /, *, dim=None, keepdim=False) classmethod ¤

Source code in cirkit/backend/torch/semiring.py
310
311
312
313
314
@classmethod
def prod(cls, x: Tensor, /, *, dim: int | None = None, keepdim: bool = False) -> Tensor:
    # prod only accepts one dim and cannot be None.
    dim = dim if dim is not None else range(x.ndim)
    return torch.prod(x, dim=dim, keepdim=keepdim)

sum(x, /, *, dim=None, keepdim=False) classmethod ¤

Source code in cirkit/backend/torch/semiring.py
302
303
304
@classmethod
def sum(cls, x: Tensor, /, *, dim: int | None = None, keepdim: bool = False) -> Tensor:
    return x.sum(dim=dim, keepdim=keepdim)

_(x) ¤

Source code in cirkit/backend/torch/semiring.py
504
505
506
@ComplexLSESumSemiring.register_map_from(LSESumSemiring)
def _(x: Tensor) -> Tensor:
    """Map a tensor from the LSE-sum semiring into the complex LSE-sum semiring.

    The mapping is only a dtype conversion performed by ``ComplexLSESumSemiring.cast``.
    """
    target_semiring = ComplexLSESumSemiring
    return target_semiring.cast(x)