Skip to content

nodes

nodes ¤

TorchBinaryParameterOp ¤

Bases: TorchParameterOp, ABC

Source code in cirkit/backend/torch/parameters/nodes.py
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
class TorchBinaryParameterOp(TorchParameterOp, ABC):
    """Abstract parameter operation that consumes exactly two input tensors."""

    def __init__(
        self,
        in_shape1: tuple[int, ...],
        in_shape2: tuple[int, ...],
        *,
        num_folds: int = 1,
    ) -> None:
        """Initialize the operation from the shapes of its two inputs.

        Args:
            in_shape1: The shape of the first input.
            in_shape2: The shape of the second input.
            num_folds: The number of folds, forwarded to the parent constructor.
        """
        super().__init__(in_shape1, in_shape2, num_folds=num_folds)

    @property
    def in_shape1(self) -> tuple[int, ...]:
        """The shape of the first input."""
        # Unpacking (rather than indexing) also checks there are exactly two shapes.
        first, _ = self.in_shapes
        return first

    @property
    def in_shape2(self) -> tuple[int, ...]:
        """The shape of the second input."""
        _, second = self.in_shapes
        return second

    @property
    def config(self) -> dict[str, Any]:
        """The configuration dictionary holding the two input shapes."""
        settings: dict[str, Any] = {"in_shape1": self.in_shape1}
        settings["in_shape2"] = self.in_shape2
        return settings

    def __call__(self, x1: Tensor, x2: Tensor) -> Tensor:
        """Get the reparameterized parameters.

        Args:
            x1: The first input tensor.
            x2: The second input tensor.

        Returns:
            Tensor: The parameters after reparameterization.
        """
        # IGNORE: Idiom for nn.Module.__call__.
        return super().__call__(x1, x2)  # type: ignore[no-any-return,misc]

    @abstractmethod
    def forward(self, x1: Tensor, x2: Tensor) -> Tensor:
        """Compute the output from the two inputs; implemented by subclasses."""
        ...

config property ¤

in_shape1 property ¤

in_shape2 property ¤

__call__(x1, x2) ¤

Get the reparameterized parameters.

Returns:

Name Type Description
Tensor Tensor

The parameters after reparameterization.

Source code in cirkit/backend/torch/parameters/nodes.py
363
364
365
366
367
368
369
370
def __call__(self, x1: Tensor, x2: Tensor) -> Tensor:
    """Get the reparameterized parameters.

    Delegates to the parent ``__call__`` with both inputs.

    Args:
        x1: The first input tensor.
        x2: The second input tensor.

    Returns:
        Tensor: The parameters after reparameterization.
    """
    # IGNORE: Idiom for nn.Module.__call__.
    return super().__call__(x1, x2)  # type: ignore[no-any-return,misc]

__init__(in_shape1, in_shape2, *, num_folds=1) ¤

Source code in cirkit/backend/torch/parameters/nodes.py
344
345
346
347
def __init__(
    self,
    in_shape1: tuple[int, ...],
    in_shape2: tuple[int, ...],
    *,
    num_folds: int = 1,
) -> None:
    """Initialize the binary operation from the shapes of its two inputs.

    Args:
        in_shape1: The shape of the first input.
        in_shape2: The shape of the second input.
        num_folds: The number of folds, forwarded to the parent constructor.
    """
    super().__init__(in_shape1, in_shape2, num_folds=num_folds)

forward(x1, x2) abstractmethod ¤

Source code in cirkit/backend/torch/parameters/nodes.py
372
373
374
@abstractmethod
def forward(self, x1: Tensor, x2: Tensor) -> Tensor:
    """Compute the output of the operation from its two inputs.

    Abstract: no implementation is provided here.

    Args:
        x1: The first input tensor.
        x2: The second input tensor.

    Returns:
        Tensor: The output of the operation.
    """
    ...

TorchClampParameter ¤

Bases: TorchEntrywiseParameterOp

Clamp reparameterization.

Source code in cirkit/backend/torch/parameters/nodes.py
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
class TorchClampParameter(TorchEntrywiseParameterOp):
    """Clamp reparameterization.

    Bounds every entry of the input from below by ``vmin`` and/or from above by ``vmax``.
    """

    def __init__(
        self,
        in_shape: tuple[int, ...],
        vmin: float | None = None,
        vmax: float | None = None,
        *,
        num_folds: int = 1,
    ) -> None:
        """Initialize the clamp operation.

        Args:
            in_shape: The shape of the input.
            vmin: The lower bound, or None for no lower bound.
            vmax: The upper bound, or None for no upper bound.
            num_folds: The number of folds, forwarded to the parent constructor.

        Raises:
            AssertionError: If both vmin and vmax are None.
        """
        # At least one of the two bounds has to be given.
        assert not (vmin is None and vmax is None)
        super().__init__(in_shape, num_folds=num_folds)
        self.vmin = vmin
        self.vmax = vmax

    @property
    def config(self) -> dict[str, Any]:
        """The parent configuration extended with whichever bounds are set."""
        settings = super().config
        # Only bounds that were actually given are recorded.
        for key, bound in (("vmin", self.vmin), ("vmax", self.vmax)):
            if bound is not None:
                settings[key] = bound
        return settings

    def forward(self, x: Tensor) -> Tensor:
        """Clamp the entries of x to the configured bounds."""
        return x.clamp(min=self.vmin, max=self.vmax)

config property ¤

vmax = vmax instance-attribute ¤

vmin = vmin instance-attribute ¤

__init__(in_shape, vmin=None, vmax=None, *, num_folds=1) ¤

Source code in cirkit/backend/torch/parameters/nodes.py
656
657
658
659
660
661
662
663
664
665
666
667
def __init__(
    self,
    in_shape: tuple[int, ...],
    vmin: float | None = None,
    vmax: float | None = None,
    *,
    num_folds: int = 1,
) -> None:
    """Initialize the clamp operation.

    Args:
        in_shape: The shape of the input.
        vmin: The lower bound, or None for no lower bound.
        vmax: The upper bound, or None for no upper bound.
        num_folds: The number of folds, forwarded to the parent constructor.

    Raises:
        AssertionError: If both vmin and vmax are None.
    """
    # At least one of the two bounds has to be given.
    assert not (vmin is None and vmax is None)
    super().__init__(in_shape, num_folds=num_folds)
    self.vmin, self.vmax = vmin, vmax

forward(x) ¤

Source code in cirkit/backend/torch/parameters/nodes.py
678
679
def forward(self, x: Tensor) -> Tensor:
    """Clamp the entries of x to the bounds stored in ``vmin`` and ``vmax``."""
    lower, upper = self.vmin, self.vmax
    return x.clamp(min=lower, max=upper)

TorchConjugateParameter ¤

Bases: TorchEntrywiseParameterOp

Conjugate parameterization.

Source code in cirkit/backend/torch/parameters/nodes.py
682
683
684
685
686
687
688
689
class TorchConjugateParameter(TorchEntrywiseParameterOp):
    """Conjugate parameterization.

    Applies the complex conjugate to every entry of the input.
    """

    def __init__(
        self,
        in_shape: tuple[int, ...],
        *,
        num_folds: int = 1,
    ) -> None:
        """Initialize the conjugate operation.

        Args:
            in_shape: The shape of the input.
            num_folds: The number of folds, forwarded to the parent constructor.
        """
        super().__init__(in_shape, num_folds=num_folds)

    def forward(self, x: Tensor) -> Tensor:
        """Return the complex conjugate of x."""
        return x.conj()

__init__(in_shape, *, num_folds=1) ¤

Source code in cirkit/backend/torch/parameters/nodes.py
685
686
def __init__(
    self,
    in_shape: tuple[int, ...],
    *,
    num_folds: int = 1,
) -> None:
    """Initialize the conjugate operation.

    Args:
        in_shape: The shape of the input.
        num_folds: The number of folds, forwarded to the parent constructor.
    """
    super().__init__(in_shape, num_folds=num_folds)

forward(x) ¤

Source code in cirkit/backend/torch/parameters/nodes.py
688
689
def forward(self, x: Tensor) -> Tensor:
    """Return the complex conjugate of x."""
    return x.conj()

TorchEntrywiseParameterOp ¤

Bases: TorchUnaryParameterOp, ABC

Source code in cirkit/backend/torch/parameters/nodes.py
377
378
379
380
class TorchEntrywiseParameterOp(TorchUnaryParameterOp, ABC):
    """Abstract unary operation whose output shape equals its input shape."""

    @property
    def shape(self) -> tuple[int, ...]:
        """The output shape, which is the input shape unchanged."""
        unchanged = self.in_shape
        return unchanged

shape property ¤

TorchEntrywiseReduceParameterOp ¤

Bases: TorchEntrywiseParameterOp, ABC

The base class for normalized reparameterization.

Source code in cirkit/backend/torch/parameters/nodes.py
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
class TorchEntrywiseReduceParameterOp(TorchEntrywiseParameterOp, ABC):
    """The base class for normalized reparameterization."""

    # NOTE: This class only serves as the common base of all normalized reparams and includes
    #       nothing more. It's up to the implementations to define further details.
    def __init__(
        self,
        in_shape: tuple[int, ...],
        *,
        dim: int = -1,
        num_folds: int = 1,
    ) -> None:
        """Initialize the operation and normalize the reduction dimension.

        Args:
            in_shape: The shape of the input.
            dim: The dimension of in_shape to reduce along. Negative values count from
                the end and are stored as their non-negative equivalent.
            num_folds: The number of folds, forwarded to the parent constructor.

        Raises:
            AssertionError: If dim falls outside the dimensions of in_shape.
        """
        rank = len(in_shape)
        if dim < 0:
            dim += rank
        assert 0 <= dim < rank
        super().__init__(in_shape, num_folds=num_folds)
        self.dim = dim

    @property
    def config(self) -> dict[str, Any]:
        """The parent configuration extended with the normalized ``dim``."""
        settings = super().config
        settings["dim"] = self.dim
        return settings

config property ¤

dim = dim instance-attribute ¤

__init__(in_shape, *, dim=-1, num_folds=1) ¤

Source code in cirkit/backend/torch/parameters/nodes.py
416
417
418
419
420
421
422
423
424
425
426
def __init__(
    self,
    in_shape: tuple[int, ...],
    *,
    dim: int = -1,
    num_folds: int = 1,
) -> None:
    """Initialize the operation and normalize the reduction dimension.

    Args:
        in_shape: The shape of the input.
        dim: The dimension of in_shape to reduce along. Negative values count from the
            end and are stored as their non-negative equivalent.
        num_folds: The number of folds, forwarded to the parent constructor.

    Raises:
        AssertionError: If dim falls outside the dimensions of in_shape.
    """
    rank = len(in_shape)
    if dim < 0:
        dim += rank
    assert 0 <= dim < rank
    super().__init__(in_shape, num_folds=num_folds)
    self.dim = dim

TorchExpParameter ¤

Bases: TorchEntrywiseParameterOp

Exp reparameterization.

Source code in cirkit/backend/torch/parameters/nodes.py
607
608
609
610
611
class TorchExpParameter(TorchEntrywiseParameterOp):
    """Exp reparameterization.

    Applies the exponential function to every entry of the input.
    """

    def forward(self, x: Tensor) -> Tensor:
        """Return the entrywise exponential of x."""
        return x.exp()

forward(x) ¤

Source code in cirkit/backend/torch/parameters/nodes.py
610
611
def forward(self, x: Tensor) -> Tensor:
    """Return the entrywise exponential of x."""
    return x.exp()

TorchFlattenParameter ¤

Bases: TorchUnaryParameterOp

Source code in cirkit/backend/torch/parameters/nodes.py
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
class TorchFlattenParameter(TorchUnaryParameterOp):
    """Flatten a contiguous range of input dimensions into a single dimension."""

    def __init__(
        self,
        in_shape: tuple[int, ...],
        num_folds: int = 1,
        start_dim: int = 0,
        end_dim: int = -1,
    ) -> None:
        """Initialize the flatten operation.

        Args:
            in_shape: The shape of the input.
            num_folds: The number of folds, forwarded to the parent constructor.
            start_dim: The first dimension of in_shape to flatten. Negative values count
                from the end.
            end_dim: The last dimension of in_shape to flatten (inclusive). Negative
                values count from the end.

        Raises:
            AssertionError: If either dimension is out of range, or if start_dim is not
                strictly smaller than end_dim after normalization.
        """
        super().__init__(in_shape, num_folds=num_folds)
        rank = len(in_shape)
        start_dim = start_dim if start_dim >= 0 else start_dim + rank
        assert 0 <= start_dim < rank
        end_dim = end_dim if end_dim >= 0 else end_dim + rank
        assert 0 <= end_dim < rank
        assert start_dim < end_dim
        self.start_dim = start_dim
        self.end_dim = end_dim

    @property
    def config(self) -> dict[str, Any]:
        """The parent configuration extended with the normalized dimension range."""
        config = super().config
        config["start_dim"] = self.start_dim
        config["end_dim"] = self.end_dim
        return config

    @cached_property
    def shape(self) -> tuple[int, ...]:
        """The output shape, with dimensions start_dim..end_dim merged into one."""
        in_shape = self.in_shapes[0]
        # np.prod returns a NumPy scalar; cast it so the shape holds plain Python ints,
        # as the return annotation promises.
        flattened_dim = int(np.prod(in_shape[self.start_dim : self.end_dim + 1]))
        return (
            *in_shape[: self.start_dim],
            flattened_dim,
            *in_shape[self.end_dim + 1 :],
        )

    def forward(self, x: Tensor) -> Tensor:
        """Flatten the configured dimension range of x.

        The stored dimensions are shifted by one, which skips the first dimension of x
        (presumably the fold dimension — confirm against the caller).
        """
        return torch.flatten(x, start_dim=self.start_dim + 1, end_dim=self.end_dim + 1)

config property ¤

end_dim = end_dim instance-attribute ¤

shape cached property ¤

start_dim = start_dim instance-attribute ¤

__init__(in_shape, num_folds=1, start_dim=0, end_dim=-1) ¤

Source code in cirkit/backend/torch/parameters/nodes.py
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
def __init__(
    self,
    in_shape: tuple[int, ...],
    num_folds: int = 1,
    start_dim: int = 0,
    end_dim: int = -1,
) -> None:
    """Initialize the flatten operation.

    Args:
        in_shape: The shape of the input.
        num_folds: The number of folds, forwarded to the parent constructor.
        start_dim: The first dimension of in_shape to flatten. Negative values count
            from the end.
        end_dim: The last dimension of in_shape to flatten (inclusive). Negative values
            count from the end.

    Raises:
        AssertionError: If either dimension is out of range, or if start_dim is not
            strictly smaller than end_dim after normalization.
    """
    super().__init__(in_shape, num_folds=num_folds)
    rank = len(in_shape)
    if start_dim < 0:
        start_dim += rank
    assert 0 <= start_dim < rank
    if end_dim < 0:
        end_dim += rank
    assert 0 <= end_dim < rank
    assert start_dim < end_dim
    self.start_dim, self.end_dim = start_dim, end_dim

forward(x) ¤

Source code in cirkit/backend/torch/parameters/nodes.py
782
783
def forward(self, x: Tensor) -> Tensor:
    """Flatten dimensions start_dim..end_dim of x, each shifted by one.

    The shift skips the first dimension of x.
    """
    first = self.start_dim + 1
    last = self.end_dim + 1
    return x.flatten(start_dim=first, end_dim=last)

TorchGaussianProductLogPartition ¤

Bases: TorchParameterOp

Source code in cirkit/backend/torch/parameters/nodes.py
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
class TorchGaussianProductLogPartition(TorchParameterOp):
    """Log-partition term for every pairwise product of two sets of Gaussians."""

    def __init__(
        self,
        in_mean1_shape: tuple[int, ...],
        in_stddev1_shape: tuple[int, ...],
        in_mean2_shape: tuple[int, ...],
        in_stddev2_shape: tuple[int, ...],
        *,
        num_folds: int = 1,
    ) -> None:
        """Initialize the operation from the shapes of its four inputs.

        Args:
            in_mean1_shape: The shape of the first means.
            in_stddev1_shape: The shape of the first standard deviations.
            in_mean2_shape: The shape of the second means.
            in_stddev2_shape: The shape of the second standard deviations.
            num_folds: The number of folds, forwarded to the parent constructor.

        Raises:
            AssertionError: If a mean shape differs from its standard-deviation shape,
                or if the two sides disagree on dimension 1.
        """
        # Each mean must be shaped like its own standard deviation.
        assert in_mean1_shape == in_stddev1_shape
        assert in_mean2_shape == in_stddev2_shape
        # The two sides must agree on dimension 1.
        assert in_mean1_shape[1] == in_mean2_shape[1]
        assert in_stddev1_shape[1] == in_stddev2_shape[1]
        super().__init__(
            in_mean1_shape,
            in_stddev1_shape,
            in_mean2_shape,
            in_stddev2_shape,
            num_folds=num_folds,
        )
        self._log_two_pi = np.log(2.0 * np.pi)

    @property
    def shape(self) -> tuple[int, ...]:
        """The output shape: (K1 * K2, C) for inputs of shape (K1, C) and (K2, C)."""
        mean1_shape, mean2_shape = self.in_shapes[0], self.in_shapes[2]
        return mean1_shape[0] * mean2_shape[0], mean1_shape[1]

    @property
    def config(self) -> dict[str, Any]:
        """The configuration dictionary holding the four input shapes."""
        shapes = self.in_shapes
        return {
            "in_mean1_shape": shapes[0],
            "in_stddev1_shape": shapes[1],
            "in_mean2_shape": shapes[2],
            "in_stddev2_shape": shapes[3],
        }

    def forward(
        self,
        mean1: Tensor,
        stddev1: Tensor,
        mean2: Tensor,
        stddev2: Tensor,
    ) -> Tensor:
        """Compute -0.5 * (log(2 * pi) + log(v) + (mean1 - mean2)^2 / v) for all pairs.

        Here v is the pairwise sum of the two variances.
        """
        # (F, K1, 1, C) + (F, 1, K2, C) -> (F, K1, K2, C)
        var_sum = stddev1.square().unsqueeze(dim=2) + stddev2.square().unsqueeze(dim=1)
        mean_diff = mean1.unsqueeze(dim=2) - mean2.unsqueeze(dim=1)  # (F, K1, K2, C)
        sq_mahalanobis = mean_diff.square() * var_sum.reciprocal()
        log_partition = -0.5 * (self._log_two_pi + var_sum.log() + sq_mahalanobis)
        # Merge the two pair dimensions: (F, K1 * K2, C)
        return log_partition.view(-1, *self.shape)

_log_two_pi = np.log(2.0 * np.pi) instance-attribute ¤

config property ¤

shape property ¤

__init__(in_mean1_shape, in_stddev1_shape, in_mean2_shape, in_stddev2_shape, *, num_folds=1) ¤

Source code in cirkit/backend/torch/parameters/nodes.py
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
def __init__(
    self,
    in_mean1_shape: tuple[int, ...],
    in_stddev1_shape: tuple[int, ...],
    in_mean2_shape: tuple[int, ...],
    in_stddev2_shape: tuple[int, ...],
    *,
    num_folds: int = 1,
) -> None:
    """Initialize the log-partition operation from the shapes of its four inputs.

    Args:
        in_mean1_shape: The shape of the first means.
        in_stddev1_shape: The shape of the first standard deviations.
        in_mean2_shape: The shape of the second means.
        in_stddev2_shape: The shape of the second standard deviations.
        num_folds: The number of folds, forwarded to the parent constructor.

    Raises:
        AssertionError: If a mean shape differs from its standard-deviation shape, or
            if the two sides disagree on dimension 1.
    """
    # Each mean must be shaped like its own standard deviation.
    assert in_mean1_shape == in_stddev1_shape
    assert in_mean2_shape == in_stddev2_shape
    # The two sides must agree on dimension 1.
    assert in_mean1_shape[1] == in_mean2_shape[1]
    assert in_stddev1_shape[1] == in_stddev2_shape[1]
    all_shapes = (in_mean1_shape, in_stddev1_shape, in_mean2_shape, in_stddev2_shape)
    super().__init__(*all_shapes, num_folds=num_folds)
    self._log_two_pi = np.log(2.0 * np.pi)

forward(mean1, stddev1, mean2, stddev2) ¤

Source code in cirkit/backend/torch/parameters/nodes.py
916
917
918
919
920
921
922
923
924
925
926
927
928
929
def forward(
    self,
    mean1: Tensor,
    stddev1: Tensor,
    mean2: Tensor,
    stddev2: Tensor,
) -> Tensor:
    """Compute -0.5 * (log(2 * pi) + log(v) + (mean1 - mean2)^2 / v) for all pairs.

    Here v is the pairwise sum of the two variances.
    """
    # (F, K1, 1, C) + (F, 1, K2, C) -> (F, K1, K2, C)
    var_sum = stddev1.square().unsqueeze(dim=2) + stddev2.square().unsqueeze(dim=1)
    mean_diff = mean1.unsqueeze(dim=2) - mean2.unsqueeze(dim=1)  # (F, K1, K2, C)
    sq_mahalanobis = mean_diff.square() * var_sum.reciprocal()
    log_partition = -0.5 * (self._log_two_pi + var_sum.log() + sq_mahalanobis)
    # Merge the two pair dimensions: (F, K1 * K2, C)
    return log_partition.view(-1, *self.shape)

TorchGaussianProductMean ¤

Bases: TorchParameterOp

Source code in cirkit/backend/torch/parameters/nodes.py
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
class TorchGaussianProductMean(TorchParameterOp):
    """Mean of every pairwise product of two sets of Gaussians."""

    def __init__(
        self,
        in_mean1_shape: tuple[int, ...],
        in_stddev1_shape: tuple[int, ...],
        in_mean2_shape: tuple[int, ...],
        in_stddev2_shape: tuple[int, ...],
        *,
        num_folds: int = 1,
    ) -> None:
        """Initialize the operation from the shapes of its four inputs.

        Args:
            in_mean1_shape: The shape of the first means.
            in_stddev1_shape: The shape of the first standard deviations.
            in_mean2_shape: The shape of the second means.
            in_stddev2_shape: The shape of the second standard deviations.
            num_folds: The number of folds, forwarded to the parent constructor.

        Raises:
            AssertionError: If a mean shape differs from its standard-deviation shape,
                or if the two sides disagree on dimension 1.
        """
        # Each mean must be shaped like its own standard deviation.
        assert in_mean1_shape == in_stddev1_shape
        assert in_mean2_shape == in_stddev2_shape
        # The two sides must agree on dimension 1.
        assert in_mean1_shape[1] == in_mean2_shape[1]
        assert in_stddev1_shape[1] == in_stddev2_shape[1]
        super().__init__(
            in_mean1_shape,
            in_stddev1_shape,
            in_mean2_shape,
            in_stddev2_shape,
            num_folds=num_folds,
        )

    @property
    def shape(self) -> tuple[int, ...]:
        """The output shape: (K1 * K2, C) for inputs of shape (K1, C) and (K2, C)."""
        mean1_shape, mean2_shape = self.in_shapes[0], self.in_shapes[2]
        return mean1_shape[0] * mean2_shape[0], mean1_shape[1]

    @property
    def config(self) -> dict[str, Any]:
        """The configuration dictionary holding the four input shapes."""
        shapes = self.in_shapes
        return {
            "in_mean1_shape": shapes[0],
            "in_stddev1_shape": shapes[1],
            "in_mean2_shape": shapes[2],
            "in_stddev2_shape": shapes[3],
        }

    def forward(self, mean1: Tensor, stddev1: Tensor, mean2: Tensor, stddev2: Tensor) -> Tensor:
        """Compute (mean1 * var2 + mean2 * var1) / (var1 + var2) for all pairs."""
        v1 = stddev1.square().unsqueeze(dim=2)  # (F, K1, 1, C)
        v2 = stddev2.square().unsqueeze(dim=1)  # (F, 1, K2, C)
        inv_var12 = (v1 + v2).reciprocal()  # (F, K1, K2, C)
        # Each mean is weighted by the variance of the other Gaussian.
        weighted = mean1.unsqueeze(dim=2) * v2 + mean2.unsqueeze(dim=1) * v1
        # Merge the two pair dimensions: (F, K1 * K2, C)
        return (weighted * inv_var12).view(-1, *self.shape)

config property ¤

shape property ¤

__init__(in_mean1_shape, in_stddev1_shape, in_mean2_shape, in_stddev2_shape, *, num_folds=1) ¤

Source code in cirkit/backend/torch/parameters/nodes.py
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
def __init__(
    self,
    in_mean1_shape: tuple[int, ...],
    in_stddev1_shape: tuple[int, ...],
    in_mean2_shape: tuple[int, ...],
    in_stddev2_shape: tuple[int, ...],
    *,
    num_folds: int = 1,
) -> None:
    """Initialize the product-mean operation from the shapes of its four inputs.

    Args:
        in_mean1_shape: The shape of the first means.
        in_stddev1_shape: The shape of the first standard deviations.
        in_mean2_shape: The shape of the second means.
        in_stddev2_shape: The shape of the second standard deviations.
        num_folds: The number of folds, forwarded to the parent constructor.

    Raises:
        AssertionError: If a mean shape differs from its standard-deviation shape, or
            if the two sides disagree on dimension 1.
    """
    # Each mean must be shaped like its own standard deviation.
    assert in_mean1_shape == in_stddev1_shape
    assert in_mean2_shape == in_stddev2_shape
    # The two sides must agree on dimension 1.
    assert in_mean1_shape[1] == in_mean2_shape[1]
    assert in_stddev1_shape[1] == in_stddev2_shape[1]
    all_shapes = (in_mean1_shape, in_stddev1_shape, in_mean2_shape, in_stddev2_shape)
    super().__init__(*all_shapes, num_folds=num_folds)

forward(mean1, stddev1, mean2, stddev2) ¤

Source code in cirkit/backend/torch/parameters/nodes.py
838
839
840
841
842
843
844
845
846
847
def forward(self, mean1: Tensor, stddev1: Tensor, mean2: Tensor, stddev2: Tensor) -> Tensor:
    """Compute (mean1 * var2 + mean2 * var1) / (var1 + var2) for all pairs."""
    v1 = stddev1.square().unsqueeze(dim=2)  # (F, K1, 1, C)
    v2 = stddev2.square().unsqueeze(dim=1)  # (F, 1, K2, C)
    inv_var12 = (v1 + v2).reciprocal()  # (F, K1, K2, C)
    # Each mean is weighted by the variance of the other Gaussian.
    weighted = mean1.unsqueeze(dim=2) * v2 + mean2.unsqueeze(dim=1) * v1
    # Merge the two pair dimensions: (F, K1 * K2, C)
    return (weighted * inv_var12).view(-1, *self.shape)

TorchGaussianProductStddev ¤

Bases: TorchBinaryParameterOp

Source code in cirkit/backend/torch/parameters/nodes.py
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
class TorchGaussianProductStddev(TorchBinaryParameterOp):
    """Standard deviation of every pairwise product of two sets of Gaussians."""

    def __init__(
        self,
        in_stddev1_shape: tuple[int, ...],
        in_stddev2_shape: tuple[int, ...],
        *,
        num_folds: int = 1,
    ) -> None:
        """Initialize the operation from the shapes of the two standard deviations.

        Args:
            in_stddev1_shape: The shape of the first standard deviations.
            in_stddev2_shape: The shape of the second standard deviations.
            num_folds: The number of folds, forwarded to the parent constructor.

        Raises:
            AssertionError: If the two shapes disagree on dimension 1.
        """
        assert in_stddev1_shape[1] == in_stddev2_shape[1]
        super().__init__(in_stddev1_shape, in_stddev2_shape, num_folds=num_folds)

    @property
    def shape(self) -> tuple[int, ...]:
        """The output shape: (K1 * K2, C) for inputs of shape (K1, C) and (K2, C)."""
        stddev1_shape, stddev2_shape = self.in_shapes[0], self.in_shapes[1]
        return stddev1_shape[0] * stddev2_shape[0], stddev1_shape[1]

    @property
    def config(self) -> dict[str, Any]:
        """The configuration dictionary holding the two input shapes."""
        shapes = self.in_shapes
        return {"in_stddev1_shape": shapes[0], "in_stddev2_shape": shapes[1]}

    def forward(self, stddev1: Tensor, stddev2: Tensor) -> Tensor:
        """Compute sqrt(1 / (1 / var1 + 1 / var2)) for all pairs."""
        precision1 = stddev1.square().reciprocal().unsqueeze(dim=2)  # (F, K1, 1, C)
        precision2 = stddev2.square().reciprocal().unsqueeze(dim=1)  # (F, 1, K2, C)
        product_var = (precision1 + precision2).reciprocal()  # (F, K1, K2, C)
        # Merge the two pair dimensions: (F, K1 * K2, C)
        return product_var.sqrt().view(-1, *self.shape)

config property ¤

shape property ¤

__init__(in_stddev1_shape, in_stddev2_shape, *, num_folds=1) ¤

Source code in cirkit/backend/torch/parameters/nodes.py
851
852
853
854
855
856
857
858
859
def __init__(
    self,
    in_stddev1_shape: tuple[int, ...],
    in_stddev2_shape: tuple[int, ...],
    *,
    num_folds: int = 1,
) -> None:
    """Initialize the product-stddev operation from its two input shapes.

    Args:
        in_stddev1_shape: The shape of the first standard deviations.
        in_stddev2_shape: The shape of the second standard deviations.
        num_folds: The number of folds, forwarded to the parent constructor.

    Raises:
        AssertionError: If the two shapes disagree on dimension 1.
    """
    # The two inputs must agree on dimension 1.
    dim1_first, dim1_second = in_stddev1_shape[1], in_stddev2_shape[1]
    assert dim1_first == dim1_second
    super().__init__(in_stddev1_shape, in_stddev2_shape, num_folds=num_folds)

forward(stddev1, stddev2) ¤

Source code in cirkit/backend/torch/parameters/nodes.py
872
873
874
875
876
877
878
def forward(self, stddev1: Tensor, stddev2: Tensor) -> Tensor:
    """Compute sqrt(1 / (1 / var1 + 1 / var2)) for all pairs."""
    precision1 = stddev1.square().reciprocal().unsqueeze(dim=2)  # (F, K1, 1, C)
    precision2 = stddev2.square().reciprocal().unsqueeze(dim=1)  # (F, 1, K2, C)
    product_var = (precision1 + precision2).reciprocal()  # (F, K1, K2, C)
    # Merge the two pair dimensions: (F, K1 * K2, C)
    return product_var.sqrt().view(-1, *self.shape)

TorchHadamardParameter ¤

Bases: TorchBinaryParameterOp

Hadamard product reparameterization.

Source code in cirkit/backend/torch/parameters/nodes.py
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
class TorchHadamardParameter(TorchBinaryParameterOp):
    """Hadamard product reparameterization.

    Multiplies two inputs of identical shape entry by entry.
    """

    def __init__(
        self,
        in_shape1: tuple[int, ...],
        in_shape2: tuple[int, ...],
        *,
        num_folds: int = 1,
    ) -> None:
        """Initialize the Hadamard product.

        Args:
            in_shape1: The shape of the first input.
            in_shape2: The shape of the second input; must equal in_shape1.
            num_folds: The number of folds, forwarded to the parent constructor.

        Raises:
            AssertionError: If the two shapes differ.
        """
        assert in_shape1 == in_shape2
        super().__init__(in_shape1, in_shape2, num_folds=num_folds)

    @property
    def shape(self) -> tuple[int, ...]:
        """The output shape, equal to the shape of the first input."""
        common = self.in_shape1
        return common

    def forward(self, x1: Tensor, x2: Tensor) -> Tensor:
        """Return the entrywise product of x1 and x2."""
        return torch.mul(x1, x2)

shape property ¤

__init__(in_shape1, in_shape2, *, num_folds=1) ¤

Source code in cirkit/backend/torch/parameters/nodes.py
493
494
495
496
497
def __init__(
    self,
    in_shape1: tuple[int, ...],
    in_shape2: tuple[int, ...],
    *,
    num_folds: int = 1,
) -> None:
    """Initialize the Hadamard product.

    Args:
        in_shape1: The shape of the first input.
        in_shape2: The shape of the second input; must equal in_shape1.
        num_folds: The number of folds, forwarded to the parent constructor.

    Raises:
        AssertionError: If the two shapes differ.
    """
    # An entrywise product needs identically shaped operands.
    assert not (in_shape1 != in_shape2)
    super().__init__(in_shape1, in_shape2, num_folds=num_folds)

forward(x1, x2) ¤

Source code in cirkit/backend/torch/parameters/nodes.py
503
504
def forward(self, x1: Tensor, x2: Tensor) -> Tensor:
    """Return the entrywise product of x1 and x2."""
    return torch.mul(x1, x2)

TorchIndexParameter ¤

Bases: TorchUnaryParameterOp

Source code in cirkit/backend/torch/parameters/nodes.py
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
class TorchIndexParameter(TorchUnaryParameterOp):
    """Select entries of the input along one dimension by integer index."""

    def __init__(
        self,
        in_shape: tuple[int, ...],
        indices: list[int],
        dim: int = -1,
        *,
        num_folds: int = 1,
    ) -> None:
        """Initialize the index operation.

        Args:
            in_shape: The shape of the input.
            indices: The indices to select along dim; each must lie in
                [0, in_shape[dim]).
            dim: The dimension of in_shape to index. Negative values count from the end
                and are stored as their non-negative equivalent.
            num_folds: The number of folds, forwarded to the parent constructor.

        Raises:
            AssertionError: If dim is out of range or any index is out of bounds.
        """
        # The parent constructor is called exactly once (the original called it twice).
        super().__init__(in_shape, num_folds=num_folds)
        dim = dim if dim >= 0 else dim + len(in_shape)
        assert 0 <= dim < len(in_shape)
        assert all(0 <= i < in_shape[dim] for i in indices)
        self.dim = dim
        # An explicit integer dtype keeps an empty index list usable for indexing;
        # torch.tensor([]) would otherwise default to a floating-point tensor.
        self.register_buffer("_indices", torch.tensor(indices, dtype=torch.long))

    @property
    def indices(self) -> list[int]:
        """The selected indices as a plain Python list."""
        return self._indices.cpu().tolist()

    @property
    def config(self) -> dict[str, Any]:
        """The parent configuration extended with ``indices`` and ``dim``."""
        config = super().config
        config["indices"] = self.indices
        config["dim"] = self.dim
        return config

    @property
    def shape(self) -> tuple[int, ...]:
        """The output shape: in_shape with dimension dim replaced by len(indices)."""
        return (
            *self.in_shape[: self.dim],
            len(self._indices),
            *self.in_shape[self.dim + 1 :],
        )

    def forward(self, x: Tensor) -> Tensor:
        """Select the stored indices along dimension ``dim + 1`` of x.

        The offset of one skips the first dimension of x, so the result agrees with
        ``shape`` for every value of dim. The original indexed tensor dimension 1
        unconditionally, which is only correct when dim is 0.
        """
        return torch.index_select(x, self.dim + 1, self._indices)

config property ¤

dim = dim instance-attribute ¤

indices property ¤

shape property ¤

__init__(in_shape, indices, dim=-1, *, num_folds=1) ¤

Source code in cirkit/backend/torch/parameters/nodes.py
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
def __init__(
    self,
    in_shape: tuple[int, ...],
    indices: list[int],
    dim: int = -1,
    *,
    num_folds: int = 1,
) -> None:
    """Initialize the index operation.

    Args:
        in_shape: The shape of the input.
        indices: The indices to select along dim; each must lie in [0, in_shape[dim]).
        dim: The dimension of in_shape to index. Negative values count from the end and
            are stored as their non-negative equivalent.
        num_folds: The number of folds, forwarded to the parent constructor.

    Raises:
        AssertionError: If dim is out of range or any index is out of bounds.
    """
    # The parent constructor is called exactly once (the original called it twice).
    super().__init__(in_shape, num_folds=num_folds)
    dim = dim if dim >= 0 else dim + len(in_shape)
    assert 0 <= dim < len(in_shape)
    assert all(0 <= i < in_shape[dim] for i in indices)
    self.dim = dim
    # An explicit integer dtype keeps an empty index list usable for indexing;
    # torch.tensor([]) would otherwise default to a floating-point tensor.
    self.register_buffer("_indices", torch.tensor(indices, dtype=torch.long))

forward(x) ¤

Source code in cirkit/backend/torch/parameters/nodes.py
471
472
def forward(self, x: Tensor) -> Tensor:
    """Select the stored indices along dimension ``dim + 1`` of x.

    The offset of one skips the first dimension of x. The original indexed tensor
    dimension 1 unconditionally, which is only correct when dim is 0.
    """
    return torch.index_select(x, self.dim + 1, self._indices)

TorchKroneckerParameter ¤

Bases: TorchBinaryParameterOp

Kronecker product reparameterization.

Source code in cirkit/backend/torch/parameters/nodes.py
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
class TorchKroneckerParameter(TorchBinaryParameterOp):
    """Kronecker product reparameterization."""

    def __init__(
        self,
        in_shape1: tuple[int, ...],
        in_shape2: tuple[int, ...],
        *,
        num_folds: int = 1,
    ) -> None:
        """Initialize the Kronecker product.

        Args:
            in_shape1: The shape of the first input.
            in_shape2: The shape of the second input; must have as many dimensions as
                in_shape1.
            num_folds: The number of folds, forwarded to the parent constructor.

        Raises:
            AssertionError: If the two shapes differ in number of dimensions.
        """
        rank1, rank2 = len(in_shape1), len(in_shape2)
        assert rank1 == rank2
        super().__init__(in_shape1, in_shape2, num_folds=num_folds)
        # torch.kron mapped over the first dimension of both operands.
        self._batched_kron = torch.vmap(torch.kron)

    @cached_property
    def shape(self) -> tuple[int, ...]:
        """The output shape: the dimension-wise product of the two input shapes."""
        sizes = []
        for size1, size2 in zip(self.in_shape1, self.in_shape2):
            sizes.append(size1 * size2)
        return tuple(sizes)

    def forward(self, x1: Tensor, x2: Tensor) -> Tensor:
        """Return the Kronecker product of x1 and x2, batched over their first dimension."""
        kron = self._batched_kron
        return kron(x1, x2)

_batched_kron = torch.vmap(torch.kron) instance-attribute ¤

shape cached property ¤

__init__(in_shape1, in_shape2, *, num_folds=1) ¤

Source code in cirkit/backend/torch/parameters/nodes.py
510
511
512
513
514
515
def __init__(
    self,
    in_shape1: tuple[int, ...],
    in_shape2: tuple[int, ...],
    *,
    num_folds: int = 1,
) -> None:
    """Initialize the Kronecker product.

    Args:
        in_shape1: The shape of the first input.
        in_shape2: The shape of the second input; must have as many dimensions as
            in_shape1.
        num_folds: The number of folds, forwarded to the parent constructor.

    Raises:
        AssertionError: If the two shapes differ in number of dimensions.
    """
    rank1, rank2 = len(in_shape1), len(in_shape2)
    assert rank1 == rank2
    super().__init__(in_shape1, in_shape2, num_folds=num_folds)
    # torch.kron mapped over the first dimension of both operands.
    self._batched_kron = torch.vmap(torch.kron)

forward(x1, x2) ¤

Source code in cirkit/backend/torch/parameters/nodes.py
521
522
def forward(self, x1: Tensor, x2: Tensor) -> Tensor:
    """Apply the stored batched Kronecker product to x1 and x2."""
    kron = self._batched_kron
    return kron(x1, x2)

TorchLogParameter ¤

Bases: TorchEntrywiseParameterOp

Log reparameterization.

Source code in cirkit/backend/torch/parameters/nodes.py
614
615
616
617
618
class TorchLogParameter(TorchEntrywiseParameterOp):
    """Log reparameterization.

    Applies the natural logarithm to every entry of the input.
    """

    def forward(self, x: Tensor) -> Tensor:
        """Return the entrywise natural logarithm of x."""
        return x.log()

forward(x) ¤

Source code in cirkit/backend/torch/parameters/nodes.py
617
618
def forward(self, x: Tensor) -> Tensor:
    """Return the entrywise natural logarithm of x."""
    return x.log()

TorchLogSoftmaxParameter ¤

Bases: TorchEntrywiseReduceParameterOp

Log-Softmax reparameterization.

Range: (-inf, 0). Constraints: logsumexp is 0.

Source code in cirkit/backend/torch/parameters/nodes.py
718
719
720
721
722
723
724
725
726
class TorchLogSoftmaxParameter(TorchEntrywiseReduceParameterOp):
    """Log-Softmax reparameterization.

    Range: (-inf, 0).
    Constraints: logsumexp is 0.
    """

    def forward(self, x: Tensor) -> Tensor:
        """Apply log-softmax along dimension ``dim + 1`` of x.

        The offset of one skips the first dimension of x.
        """
        reduce_dim = self.dim + 1
        return x.log_softmax(dim=reduce_dim)

forward(x) ¤

Source code in cirkit/backend/torch/parameters/nodes.py
725
726
def forward(self, x: Tensor) -> Tensor:
    """Apply log-softmax along dimension ``dim + 1`` of x.

    The offset of one skips the first dimension of x.
    """
    reduce_dim = self.dim + 1
    return x.log_softmax(dim=reduce_dim)

TorchMatMulParameter ¤

Bases: TorchBinaryParameterOp

Source code in cirkit/backend/torch/parameters/nodes.py
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
class TorchMatMulParameter(TorchBinaryParameterOp):
    """Matrix product of two two-dimensional inputs."""

    def __init__(
        self,
        in_shape1: tuple[int, ...],
        in_shape2: tuple[int, ...],
        *,
        num_folds: int = 1,
    ) -> None:
        """Initialize the matrix product.

        Args:
            in_shape1: The shape (d1, d2) of the first input.
            in_shape2: The shape (d2, d3) of the second input.
            num_folds: The number of folds, forwarded to the parent constructor.

        Raises:
            AssertionError: If either shape is not two-dimensional, or if the inner
                dimensions do not match.
        """
        assert len(in_shape1) == 2 and len(in_shape2) == 2
        # The columns of the first matrix must match the rows of the second.
        assert in_shape1[1] == in_shape2[0]
        super().__init__(in_shape1, in_shape2, num_folds=num_folds)

    @property
    def shape(self) -> tuple[int, ...]:
        """The output shape (d1, d3)."""
        rows = self.in_shape1[0]
        cols = self.in_shape2[1]
        return rows, cols

    def forward(self, x1: Tensor, x2: Tensor) -> Tensor:
        """Return the matrix product of x1 (F, d1, d2) and x2 (F, d2, d3)."""
        return x1 @ x2  # (F, d1, d3)

shape property ¤

__init__(in_shape1, in_shape2, *, num_folds=1) ¤

Source code in cirkit/backend/torch/parameters/nodes.py
730
731
732
733
734
735
def __init__(
    self,
    in_shape1: tuple[int, ...],
    in_shape2: tuple[int, ...],
    *,
    num_folds: int = 1,
) -> None:
    """Initialize the matrix product.

    Args:
        in_shape1: The shape (d1, d2) of the first input.
        in_shape2: The shape (d2, d3) of the second input.
        num_folds: The number of folds, forwarded to the parent constructor.

    Raises:
        AssertionError: If either shape is not two-dimensional, or if the inner
            dimensions do not match.
    """
    assert len(in_shape1) == 2 and len(in_shape2) == 2
    # The columns of the first matrix must match the rows of the second.
    assert in_shape1[1] == in_shape2[0]
    super().__init__(in_shape1, in_shape2, num_folds=num_folds)

forward(x1, x2) ¤

Source code in cirkit/backend/torch/parameters/nodes.py
741
742
743
744
def forward(self, x1: Tensor, x2: Tensor) -> Tensor:
    """Return the matrix product of x1 (F, d1, d2) and x2 (F, d2, d3)."""
    return x1 @ x2  # (F, d1, d3)

TorchMixingWeightParameter ¤

Bases: TorchUnaryParameterOp

Source code in cirkit/backend/torch/parameters/nodes.py
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
class TorchMixingWeightParameter(TorchUnaryParameterOp):
    """Expand per-unit mixing weights into concatenated diagonal matrices."""

    def __init__(self, in_shape: tuple[int, ...], *, num_folds: int = 1) -> None:
        """Initialize the operation.

        Args:
            in_shape: The input shape, expected to be (num_units, arity).
            num_folds: The number of folds, forwarded to the parent constructor.

        Raises:
            ValueError: If in_shape is not two-dimensional.
        """
        super().__init__(in_shape, num_folds=num_folds)
        if len(in_shape) != 2:
            raise ValueError(f"Expected shape (num_units, arity), but found {in_shape}")

    @property
    def shape(self) -> tuple[int, ...]:
        """The output shape (num_units, num_units * arity)."""
        return self.in_shape[0], self.in_shape[0] * self.in_shape[1]

    def forward(self, x: Tensor) -> Tensor:
        """Place each weight column on the diagonal of its own square block.

        For x of shape (F, num_units, arity) the result has shape
        (F, num_units, arity * num_units): arity diagonal blocks side by side, where
        block a carries x[:, :, a] on its diagonal.
        """
        # x: (F, num_units, arity) -> (F, arity, num_units)
        # diag_weights: (F, arity, num_units, num_units)
        # diag_embed replaces the nested vmap(vmap(diag)) that was rebuilt on every call.
        diag_weights = torch.diag_embed(x.transpose(1, 2))
        # (F, num_units, arity, num_units) -> (F, num_units, arity * num_units)
        return diag_weights.permute(0, 2, 1, 3).flatten(start_dim=2)

shape property ¤

__init__(in_shape, *, num_folds=1) ¤

Source code in cirkit/backend/torch/parameters/nodes.py
787
788
789
790
def __init__(
    self,
    in_shape: tuple[int, ...],
    *,
    num_folds: int = 1,
) -> None:
    """Initialize the mixing-weight operation.

    Args:
        in_shape: The input shape, expected to be (num_units, arity).
        num_folds: The number of folds, forwarded to the parent constructor.

    Raises:
        ValueError: If in_shape is not two-dimensional.
    """
    super().__init__(in_shape, num_folds=num_folds)
    rank = len(in_shape)
    if rank == 2:
        return
    raise ValueError(f"Expected shape (num_units, arity), but found {in_shape}")

forward(x) ¤

Source code in cirkit/backend/torch/parameters/nodes.py
796
797
798
799
800
801
def forward(self, x: Tensor) -> Tensor:
    """Place each weight column on the diagonal of its own square block.

    For x of shape (F, num_units, arity) the result has shape
    (F, num_units, arity * num_units): arity diagonal blocks side by side, where
    block a carries x[:, :, a] on its diagonal.
    """
    # x: (F, num_units, arity) -> (F, arity, num_units)
    # diag_weights: (F, arity, num_units, num_units)
    # diag_embed replaces the nested vmap(vmap(diag)) that was rebuilt on every call.
    diag_weights = torch.diag_embed(x.transpose(1, 2))
    # (F, num_units, arity, num_units) -> (F, num_units, arity * num_units)
    return diag_weights.permute(0, 2, 1, 3).flatten(start_dim=2)

TorchOuterProductParameter ¤

Bases: TorchBinaryParameterOp

Source code in cirkit/backend/torch/parameters/nodes.py
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
class TorchOuterProductParameter(TorchBinaryParameterOp):
    """Outer product of two inputs along one dimension, merged into that dimension."""

    def __init__(
        self,
        in_shape1: tuple[int, ...],
        in_shape2: tuple[int, ...],
        dim: int = -1,
        *,
        num_folds: int = 1,
    ) -> None:
        """Initialize the outer product.

        Args:
            in_shape1: The shape of the first input.
            in_shape2: The shape of the second input; it may differ from in_shape1 only
                at dimension dim.
            dim: The dimension along which the outer product is taken. Negative values
                count from the end and are stored as their non-negative equivalent.
            num_folds: The number of folds, forwarded to the parent constructor.

        Raises:
            AssertionError: If the shapes differ in number of dimensions, dim is out of
                range, or the shapes disagree on any dimension other than dim.
        """
        rank = len(in_shape1)
        assert rank == len(in_shape2)
        if dim < 0:
            dim += rank
        assert 0 <= dim < rank
        # Every dimension other than `dim` has to agree between the two inputs.
        assert in_shape1[:dim] == in_shape2[:dim]
        assert in_shape1[dim + 1 :] == in_shape2[dim + 1 :]
        super().__init__(in_shape1, in_shape2, num_folds=num_folds)
        self.dim = dim

    @property
    def shape(self) -> tuple[int, ...]:
        """The output shape: in_shape1 with dimension dim scaled by in_shape2[dim]."""
        axis = self.dim
        merged = self.in_shape1[axis] * self.in_shape2[axis]
        return (*self.in_shape1[:axis], merged, *self.in_shape1[axis + 1 :])

    @property
    def config(self) -> dict[str, Any]:
        """The parent configuration extended with the normalized ``dim``."""
        settings = super().config
        settings["dim"] = self.dim
        return settings

    def forward(self, x1: Tensor, x2: Tensor) -> Tensor:
        """Multiply every slice of x1 along dim with every slice of x2 along dim."""
        # x1: (F, d1, d2, ..., dk1, ... dn)
        # x2: (F, d1, d2, ..., dk2, ... dn)
        left = torch.unsqueeze(x1, self.dim + 2)  # (F, d1, d2, ..., dk1, 1, ..., dn)
        right = torch.unsqueeze(x2, self.dim + 1)  # (F, d1, d2, ..., 1, dk2, ..., dn)
        pairs = left * right  # (F, d1, d2, ..., dk1, dk2, ..., dn)
        # Merge the pair dimensions: (F, d1, d2, ..., dk1 * dk2, ..., dn)
        return pairs.view(self.num_folds, *self.shape)

config property ¤

dim = dim instance-attribute ¤

shape property ¤

__init__(in_shape1, in_shape2, dim=-1, *, num_folds=1) ¤

Source code in cirkit/backend/torch/parameters/nodes.py
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
def __init__(
    self,
    in_shape1: tuple[int, ...],
    in_shape2: tuple[int, ...],
    dim: int = -1,
    *,
    num_folds: int = 1,
) -> None:
    """Initialize the outer product.

    Args:
        in_shape1: The shape of the first input.
        in_shape2: The shape of the second input; it may differ from in_shape1 only at
            dimension dim.
        dim: The dimension along which the outer product is taken. Negative values
            count from the end and are stored as their non-negative equivalent.
        num_folds: The number of folds, forwarded to the parent constructor.

    Raises:
        AssertionError: If the shapes differ in number of dimensions, dim is out of
            range, or the shapes disagree on any dimension other than dim.
    """
    rank = len(in_shape1)
    assert rank == len(in_shape2)
    if dim < 0:
        dim += rank
    assert 0 <= dim < rank
    # Every dimension other than `dim` has to agree between the two inputs.
    assert in_shape1[:dim] == in_shape2[:dim]
    assert in_shape1[dim + 1 :] == in_shape2[dim + 1 :]
    super().__init__(in_shape1, in_shape2, num_folds=num_folds)
    self.dim = dim

forward(x1, x2) ¤

Source code in cirkit/backend/torch/parameters/nodes.py
556
557
558
559
560
561
562
563
def forward(self, x1: Tensor, x2: Tensor) -> Tensor:
    """Multiply every slice of x1 along dim with every slice of x2 along dim."""
    # x1: (F, d1, d2, ..., dk1, ... dn)
    # x2: (F, d1, d2, ..., dk2, ... dn)
    left = torch.unsqueeze(x1, self.dim + 2)  # (F, d1, d2, ..., dk1, 1, ..., dn)
    right = torch.unsqueeze(x2, self.dim + 1)  # (F, d1, d2, ..., 1, dk2, ..., dn)
    pairs = left * right  # (F, d1, d2, ..., dk1, dk2, ..., dn)
    # Merge the pair dimensions: (F, d1, d2, ..., dk1 * dk2, ..., dn)
    return pairs.view(self.num_folds, *self.shape)

TorchOuterSumParameter ¤

Bases: TorchBinaryParameterOp

Source code in cirkit/backend/torch/parameters/nodes.py
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
class TorchOuterSumParameter(TorchBinaryParameterOp):
    """Binary operation computing all pairwise sums of two inputs along one axis.

    The inputs must agree on every axis except ``dim``; along ``dim`` the output has
    size ``dk1 * dk2``, where ``dk1`` and ``dk2`` are the input sizes on that axis.
    """

    def __init__(
        self,
        in_shape1: tuple[int, ...],
        in_shape2: tuple[int, ...],
        *,
        num_folds: int = 1,
        dim: int = -1,
    ) -> None:
        """Validate the two input shapes and store the normalized axis.

        Args:
            in_shape1: The shape of the first input (without the fold axis).
            in_shape2: The shape of the second input (without the fold axis).
            num_folds: The number of folds.
            dim: The axis along which the outer sum is taken. Negative values count
                from the last axis.
        """
        rank = len(in_shape1)
        assert rank == len(in_shape2)
        if dim < 0:
            # Turn a negative axis into its non-negative equivalent.
            dim += rank
        assert 0 <= dim < rank
        # Every axis other than `dim` must agree between the two inputs.
        assert in_shape1[:dim] == in_shape2[:dim]
        assert in_shape1[dim + 1 :] == in_shape2[dim + 1 :]
        super().__init__(in_shape1, in_shape2, num_folds=num_folds)
        self.dim = dim

    @property
    def shape(self) -> tuple[int, ...]:
        """The output shape: the first input shape with axis ``dim`` sized ``dk1 * dk2``."""
        leading = self.in_shape1[: self.dim]
        trailing = self.in_shape1[self.dim + 1 :]
        merged = self.in_shape1[self.dim] * self.in_shape2[self.dim]
        return (*leading, merged, *trailing)

    @property
    def config(self) -> dict[str, Any]:
        """The parent configuration extended with the ``dim`` hyperparameter."""
        settings = super().config
        settings["dim"] = self.dim
        return settings

    def forward(self, x1: Tensor, x2: Tensor) -> Tensor:
        """Add every slice of ``x1`` to every slice of ``x2`` along ``self.dim``.

        Args:
            x1: The first operand, shape (F, d1, ..., dk1, ..., dn).
            x2: The second operand, shape (F, d1, ..., dk2, ..., dn).

        Returns:
            Tensor: The pairwise sums, shape (F, d1, ..., dk1 * dk2, ..., dn).
        """
        # The fold axis comes first, so axis self.dim of a fold is axis self.dim + 1 here.
        lhs = x1.unsqueeze(self.dim + 2)  # (F, d1, ..., dk1, 1, ..., dn)
        rhs = x2.unsqueeze(self.dim + 1)  # (F, d1, ..., 1, dk2, ..., dn)
        outer = lhs + rhs  # broadcasts to (F, d1, ..., dk1, dk2, ..., dn)
        return outer.view(self.num_folds, *self.shape)

config property ¤

dim = dim instance-attribute ¤

shape property ¤

__init__(in_shape1, in_shape2, *, num_folds=1, dim=-1) ¤

Source code in cirkit/backend/torch/parameters/nodes.py
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
def __init__(
    self,
    in_shape1: tuple[int, ...],
    in_shape2: tuple[int, ...],
    *,
    num_folds: int = 1,
    dim: int = -1,
) -> None:
    """Validate the two input shapes and store the normalized axis.

    Args:
        in_shape1: The shape of the first input (without the fold axis).
        in_shape2: The shape of the second input (without the fold axis).
        num_folds: The number of folds.
        dim: The axis along which the two inputs may differ. Negative values count
            from the last axis.
    """
    rank = len(in_shape1)
    assert rank == len(in_shape2)
    if dim < 0:
        # Turn a negative axis into its non-negative equivalent.
        dim += rank
    assert 0 <= dim < rank
    # Every axis other than `dim` must agree between the two inputs.
    assert in_shape1[:dim] == in_shape2[:dim]
    assert in_shape1[dim + 1 :] == in_shape2[dim + 1 :]
    super().__init__(in_shape1, in_shape2, num_folds=num_folds)
    self.dim = dim

forward(x1, x2) ¤

Source code in cirkit/backend/torch/parameters/nodes.py
597
598
599
600
601
602
603
604
def forward(self, x1: Tensor, x2: Tensor) -> Tensor:
    """Add every slice of ``x1`` to every slice of ``x2`` along ``self.dim``.

    Args:
        x1: The first operand, shape (F, d1, ..., dk1, ..., dn).
        x2: The second operand, shape (F, d1, ..., dk2, ..., dn).

    Returns:
        Tensor: The pairwise sums, viewed as ``(self.num_folds, *self.shape)``.
    """
    # The fold axis comes first, so axis self.dim of a fold is axis self.dim + 1 here.
    lhs = x1.unsqueeze(self.dim + 2)  # (F, d1, ..., dk1, 1, ..., dn)
    rhs = x2.unsqueeze(self.dim + 1)  # (F, d1, ..., 1, dk2, ..., dn)
    outer = lhs + rhs  # broadcasts to (F, d1, ..., dk1, dk2, ..., dn)
    # Merge the two adjacent axes: (F, d1, ..., dk1 * dk2, ..., dn)
    return outer.view(self.num_folds, *self.shape)

TorchParameterInput ¤

Bases: TorchParameterNode, ABC

The torch parameter input node. A parameter input is a parameter node in the computational graph that computes a parameter and does not have inputs. See TorchParameter for more details.

Source code in cirkit/backend/torch/parameters/nodes.py
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
class TorchParameterInput(TorchParameterNode, ABC):
    """The torch parameter input node. A parameter input is a parameter node in the
    computational graph that computes a parameter and does __not__ have inputs. See
    [TorchParameter][cirkit.backend.torch.parameters.parameter.TorchParameter] for more details.
    """

    def __call__(self) -> Tensor:
        # IGNORE: Idiom for nn.Module.__call__.
        return super().__call__()  # type: ignore[no-any-return,misc]

    def extra_repr(self) -> str:
        """Describe the node by its output shape, fold axis included."""
        folded_shape = (self.num_folds, *self.shape)
        return f"output-shape: {folded_shape}"

    @abstractmethod
    def forward(self) -> Tensor:
        r"""Evaluate a torch parameter input node.

        Returns:
            Tensor: A tensor of shape $(F,K_1,\ldots,K_n)$, where $F$ is the number of folds, and
            $(K_1,\ldots,K_n)$ is the shape of the tensors within each fold.
        """

__call__() ¤

Source code in cirkit/backend/torch/parameters/nodes.py
69
70
71
def __call__(self) -> Tensor:
    """Evaluate the node by delegating to the parent ``__call__`` with no arguments.

    Returns:
        Tensor: The value returned by the parent ``__call__``.
    """
    # IGNORE: Idiom for nn.Module.__call__.
    return super().__call__()  # type: ignore[no-any-return,misc]

extra_repr() ¤

Source code in cirkit/backend/torch/parameters/nodes.py
73
74
def extra_repr(self) -> str:
    """Describe the node by its output shape, fold axis included."""
    folded_shape = (self.num_folds, *self.shape)
    return f"output-shape: {folded_shape}"

forward() abstractmethod ¤

Evaluate a torch parameter input node.

Returns:

Type Description
Tensor

A tensor of shape \((F,K_1,\ldots,K_n)\), where \(F\) is the number of folds, and \((K_1,\ldots,K_n)\) is the shape of the tensors within each fold.

Source code in cirkit/backend/torch/parameters/nodes.py
76
77
78
79
80
81
82
83
@abstractmethod
def forward(self) -> Tensor:
    r"""Evaluate a torch parameter input node.

    This method is abstract: concrete subclasses must override it.

    Returns:
        Tensor: A tensor of shape $(F,K_1,\ldots,K_n)$, where $F$ is the number of folds, and
        $(K_1,\ldots,K_n)$ is the shape of the tensors within each fold.
    """

TorchParameterNode ¤

Bases: AbstractTorchModule, ABC

The abstract parameter node class. A parameter node is a node in the computational graph that computes parameters. See TorchParameter for more details.

Source code in cirkit/backend/torch/parameters/nodes.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
class TorchParameterNode(AbstractTorchModule, ABC):
    """The abstract parameter node class. A parameter node is a node in the computational
    graph that computes parameters.

    See [TorchParameter][cirkit.backend.torch.parameters.parameter.TorchParameter]
    for more details.
    """

    def __init__(self, *, num_folds: int = 1):
        """Initialize a torch parameter node.

        Args:
            num_folds: The number of folds computed by the node.
        """
        super().__init__(num_folds=num_folds)

    @property
    @abstractmethod
    def shape(self) -> tuple[int, ...]:
        r"""The shape of the tensor folds that the node outputs.
        If the shape is $(K_1,\ldots,K_n)$ and the number of folds is $F$, then the node outputs a
        tensor having overall shape $(F,K_1,\ldots,K_n)$.

        Returns:
            The shape of the tensor folds that the node outputs.
        """

    @property
    def config(self) -> dict[str, Any]:
        """Retrieves the configuration of the parameter node, i.e., a dictionary mapping
        hyperparameters of the parameter node to their values. The hyperparameter names must
        match the argument names in the ``__init__`` method.

        Returns:
            Dict[str, Any]: A dictionary from hyperparameter names to their value. The base
                implementation returns an empty dictionary.
        """
        return {}

    @property
    def fold_settings(self) -> tuple[Any, ...]:
        """The ``(name, value)`` pairs of the node configuration, packed in a tuple."""
        return tuple(self.config.items())

    @final
    @property
    def sub_modules(self) -> dict[str, "AbstractTorchModule"]:
        """The sub-modules of the node: always an empty dictionary."""
        return {}

    @torch.no_grad()
    def reset_parameters(self):
        """Reset the parameters of the node. The base implementation does nothing."""
        ...

config property ¤

Retrieves the configuration of the parameter node, i.e., a dictionary mapping hyperparameters of the parameter node to their values. The hyperparameter names must match the argument names in the __init__ method.

Returns:

Type Description
dict[str, Any]

Dict[str, Any]: A dictionary from hyperparameter names to their value.

fold_settings property ¤

shape abstractmethod property ¤

The shape of the tensor folds that the node outputs. If the shape is \((K_1,\ldots,K_n)\) and the number of folds is \(F\), then the node outputs a tensor having overall shape \((F,K_1,\ldots,K_n)\).

Returns:

Type Description
tuple[int, ...]

The shape of the tensor folds that the node outputs.

sub_modules property ¤

__init__(*, num_folds=1) ¤

Initialize a torch parameter node.

Parameters:

Name Type Description Default
num_folds int

The number of folds computed by the node.

1
Source code in cirkit/backend/torch/parameters/nodes.py
19
20
21
22
23
24
25
def __init__(self, *, num_folds: int = 1) -> None:
    """Initialize a torch parameter node.

    Args:
        num_folds: The number of folds computed by the node. It is forwarded
            unchanged to the parent initializer.
    """
    super().__init__(num_folds=num_folds)

reset_parameters() ¤

Source code in cirkit/backend/torch/parameters/nodes.py
58
59
60
@torch.no_grad()
def reset_parameters(self) -> None:
    """Reset the parameters of the node.

    This default implementation does nothing. It runs with gradient tracking disabled.
    """
    ...

TorchParameterOp ¤

Bases: TorchParameterNode, ABC

Source code in cirkit/backend/torch/parameters/nodes.py
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
class TorchParameterOp(TorchParameterNode, ABC):
    """The abstract base class of parameter nodes that take other tensors as inputs.

    It records the shape (without the fold axis) of each input tensor.
    """

    def __init__(self, *in_shapes: tuple[int, ...], num_folds: int = 1):
        """Initialize the operation.

        Args:
            *in_shapes: The shape of each input, without the fold axis.
            num_folds: The number of folds.
        """
        super().__init__(num_folds=num_folds)
        self._in_shapes = in_shapes

    @property
    def in_shapes(self) -> tuple[tuple[int, ...], ...]:
        """The input shapes given at construction, without the fold axis."""
        return self._in_shapes

    @property
    def config(self) -> dict[str, Any]:
        """The configuration of the operation, consisting of its input shapes."""
        return {"in_shapes": self.in_shapes}

    def __call__(self, *xs: Tensor) -> Tensor:
        """Get the reparameterized parameters.

        Args:
            *xs: The input tensors.

        Returns:
            Tensor: The parameters after reparameterization.
        """
        # IGNORE: Idiom for nn.Module.__call__.
        return super().__call__(*xs)  # type: ignore[no-any-return,misc]

    def extra_repr(self) -> str:
        """Describe the operation by its input and output shapes, fold axis included."""
        folded_inputs = [(self.num_folds, *in_shape) for in_shape in self._in_shapes]
        folded_output = (self.num_folds, *self.shape)
        return f"input-shapes: {folded_inputs}\noutput-shape: {folded_output}"

    @abstractmethod
    def forward(self, *xs: Tensor) -> Tensor:
        """Compute the output from the input tensors. Subclasses must override this."""
        ...

_in_shapes = in_shapes instance-attribute ¤

config property ¤

in_shapes property ¤

__call__(*xs) ¤

Get the reparameterized parameters.

Returns:

Name Type Description
Tensor Tensor

The parameters after reparameterization.

Source code in cirkit/backend/torch/parameters/nodes.py
295
296
297
298
299
300
301
302
def __call__(self, *xs: Tensor) -> Tensor:
    """Get the reparameterized parameters.

    Args:
        *xs: The input tensors, forwarded unchanged to the parent ``__call__``.

    Returns:
        Tensor: The parameters after reparameterization.
    """
    # IGNORE: Idiom for nn.Module.__call__.
    return super().__call__(*xs)  # type: ignore[no-any-return,misc]

__init__(*in_shapes, num_folds=1) ¤

Source code in cirkit/backend/torch/parameters/nodes.py
283
284
285
def __init__(self, *in_shapes: tuple[int, ...], num_folds: int = 1) -> None:
    """Initialize the operation and record its input shapes.

    Args:
        *in_shapes: The shape of each input; stored as given in ``_in_shapes``.
        num_folds: The number of folds, forwarded to the parent initializer.
    """
    super().__init__(num_folds=num_folds)
    self._in_shapes = in_shapes

extra_repr() ¤

Source code in cirkit/backend/torch/parameters/nodes.py
304
305
306
307
308
309
def extra_repr(self) -> str:
    """Describe the operation by its input and output shapes, fold axis included."""
    folded_inputs = [(self.num_folds, *in_shape) for in_shape in self._in_shapes]
    folded_output = (self.num_folds, *self.shape)
    return f"input-shapes: {folded_inputs}\noutput-shape: {folded_output}"

forward(*xs) abstractmethod ¤

Source code in cirkit/backend/torch/parameters/nodes.py
311
312
313
@abstractmethod
def forward(self, *xs: Tensor) -> Tensor:
    """Compute the output of the operation from its input tensors.

    This method is abstract: concrete subclasses must override it.

    Args:
        *xs: The input tensors.

    Returns:
        Tensor: The output of the operation.
    """
    ...

TorchPointerParameter ¤

Bases: TorchParameterInput

Source code in cirkit/backend/torch/parameters/nodes.py
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
class TorchPointerParameter(TorchParameterInput):
    """A parameter input that points to a tensor parameter, optionally selecting folds of it."""

    def __init__(
        self, parameter: TorchTensorParameter, *, fold_idx: int | list[int] | None = None
    ) -> None:
        """Initialize the pointer.

        Args:
            parameter: The tensor parameter being pointed to. It must not be a pointer itself.
            fold_idx: The fold(s) of ``parameter`` to select. None selects all folds. An
                index selection that covers every fold in order is stored as None.
        """
        total_folds = parameter.num_folds
        if fold_idx is None:
            num_folds = total_folds
        elif isinstance(fold_idx, int):
            assert 0 <= fold_idx < total_folds
            if fold_idx == 0 and total_folds == 1:
                # Selecting the only fold is the same as selecting everything.
                fold_idx, num_folds = None, total_folds
            else:
                fold_idx, num_folds = [fold_idx], 1
        else:
            assert isinstance(fold_idx, list)
            assert all(0 <= i < total_folds for i in fold_idx)
            if fold_idx == list(range(total_folds)):
                # The identity selection needs no indexing at all.
                fold_idx, num_folds = None, total_folds
            else:
                num_folds = len(fold_idx)
        assert not isinstance(parameter, TorchPointerParameter)
        super().__init__(num_folds=num_folds)
        self._parameter = parameter
        index_buffer = None if fold_idx is None else torch.tensor(fold_idx)
        self.register_buffer("_fold_idx", index_buffer)

    @property
    def shape(self) -> tuple[int, ...]:
        """The shape of the output parameter."""
        return self._parameter.shape

    @property
    def config(self) -> dict[str, Any]:
        """The configuration of the pointer, consisting of the pointed parameter."""
        return {"parameter": self._parameter}

    @property
    def fold_idx(self) -> list[int] | None:
        """The selected fold indices as a list, or None if all folds are selected."""
        if self._fold_idx is not None:
            return self._fold_idx.cpu().tolist()
        return None

    def deref(self) -> TorchTensorParameter:
        """Return the tensor parameter being pointed to."""
        return self._parameter

    def forward(self) -> Tensor:
        """Evaluate the pointed parameter and select the requested folds, if any."""
        folds = self._parameter()
        return folds if self._fold_idx is None else folds[self._fold_idx]

_parameter = parameter instance-attribute ¤

config property ¤

fold_idx property ¤

shape property ¤

The shape of the output parameter.

__init__(parameter, *, fold_idx=None) ¤

Source code in cirkit/backend/torch/parameters/nodes.py
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
def __init__(
    self, parameter: TorchTensorParameter, *, fold_idx: int | list[int] | None = None
) -> None:
    """Initialize the pointer.

    Args:
        parameter: The tensor parameter being pointed to. It must not be a pointer itself.
        fold_idx: The fold(s) of ``parameter`` to select. None selects all folds. An
            index selection that covers every fold in order is stored as None.
    """
    total_folds = parameter.num_folds
    if fold_idx is None:
        num_folds = total_folds
    elif isinstance(fold_idx, int):
        assert 0 <= fold_idx < total_folds
        if fold_idx == 0 and total_folds == 1:
            # Selecting the only fold is the same as selecting everything.
            fold_idx, num_folds = None, total_folds
        else:
            fold_idx, num_folds = [fold_idx], 1
    else:
        assert isinstance(fold_idx, list)
        assert all(0 <= i < total_folds for i in fold_idx)
        if fold_idx == list(range(total_folds)):
            # The identity selection needs no indexing at all.
            fold_idx, num_folds = None, total_folds
        else:
            num_folds = len(fold_idx)
    assert not isinstance(parameter, TorchPointerParameter)
    super().__init__(num_folds=num_folds)
    self._parameter = parameter
    index_buffer = None if fold_idx is None else torch.tensor(fold_idx)
    self.register_buffer("_fold_idx", index_buffer)

deref() ¤

Source code in cirkit/backend/torch/parameters/nodes.py
272
273
def deref(self) -> TorchTensorParameter:
    """Return the tensor parameter this pointer refers to.

    Returns:
        TorchTensorParameter: The object stored in ``_parameter``.
    """
    return self._parameter

forward() ¤

Source code in cirkit/backend/torch/parameters/nodes.py
275
276
277
278
279
def forward(self) -> Tensor:
    """Evaluate the pointed parameter and select the requested folds, if any.

    Returns:
        Tensor: The full output of the pointed parameter when no fold index is stored,
        otherwise that output indexed along its first axis by the stored fold indices.
    """
    folds = self._parameter()
    return folds if self._fold_idx is None else folds[self._fold_idx]

TorchPolynomialDifferential ¤

Bases: TorchUnaryParameterOp

Source code in cirkit/backend/torch/parameters/nodes.py
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
class TorchPolynomialDifferential(TorchUnaryParameterOp):
    """Differentiate polynomials given by their coefficients, ``order`` times.

    The input holds coefficients in increasing-degree order along the last axis, i.e., it
    has shape (F, K, dp1) where dp1 is the degree plus one.
    """

    def __init__(self, in_shape: tuple[int, ...], *, num_folds: int = 1, order: int = 1) -> None:
        """Initialize the differential operation.

        Args:
            in_shape: The shape (K, dp1) of the input coefficients, without the fold axis.
            num_folds: The number of folds.
            order: The order of differentiation.

        Raises:
            ValueError: If the order is not positive.
        """
        if order <= 0:
            raise ValueError("The order of differentiation must be positive.")
        super().__init__(in_shape, num_folds=num_folds)
        self.order = order

    @property
    def shape(self) -> tuple[int, ...]:
        """The output shape (K, dp1 - order), or (K, 1) when the derivative is constant 0."""
        # if dp1>order, i.e., deg>=order, then diff, else const 0.
        return (
            self.in_shapes[0][0],
            self.in_shapes[0][1] - self.order if self.in_shapes[0][1] > self.order else 1,
        )

    @property
    def config(self) -> dict[str, Any]:
        """The parent configuration extended with the ``order`` hyperparameter.

        The order is an ``__init__`` argument that changes both the output shape and the
        computation, so it has to be part of the configuration.
        """
        config = super().config
        config["order"] = self.order
        return config

    @classmethod
    def _diff_once(cls, x: Tensor) -> Tensor:
        """Differentiate once the polynomials whose coefficients lie on the last axis."""
        degp1 = x.shape[-1]  # x shape (F, K, dp1).
        arange = torch.arange(1, degp1).to(x)  # shape (deg,).
        return x[..., 1:] * arange  # a_n x^n -> n a_n x^(n-1), with a_0 disappeared.

    def forward(self, coeff: Tensor) -> Tensor:
        """Compute the coefficients of the ``order``-th derivative.

        Args:
            coeff: The polynomial coefficients, shape (F, K, dp1).

        Returns:
            Tensor: The derivative coefficients, shape (F, K, dp1 - order), or zeros of
            shape (F, K, 1) when dp1 <= order.
        """
        if coeff.shape[-1] <= self.order:
            return torch.zeros_like(coeff[..., :1])  # shape (F, K, 1).

        for _ in range(self.order):
            coeff = self._diff_once(coeff)
        return coeff  # shape (F, K, dp1-ord).

order = order instance-attribute ¤

shape property ¤

__init__(in_shape, *, num_folds=1, order=1) ¤

Source code in cirkit/backend/torch/parameters/nodes.py
967
968
969
970
971
def __init__(self, in_shape: tuple[int, ...], *, num_folds: int = 1, order: int = 1) -> None:
    """Initialize the differential operation.

    Args:
        in_shape: The shape of the input, without the fold axis.
        num_folds: The number of folds.
        order: The order of differentiation.

    Raises:
        ValueError: If the order is not positive.
    """
    # Reject a non-positive order before any state is set up.
    if order <= 0:
        raise ValueError("The order of differentiation must be positive.")
    super().__init__(in_shape, num_folds=num_folds)
    self.order = order

_diff_once(x) classmethod ¤

Source code in cirkit/backend/torch/parameters/nodes.py
981
982
983
984
985
@classmethod
def _diff_once(cls, x: Tensor) -> Tensor:
    """Differentiate once the polynomials whose coefficients lie on the last axis.

    Args:
        x: The coefficients in increasing-degree order, shape (F, K, dp1).

    Returns:
        Tensor: The coefficients of the derivative, with the last axis one entry shorter.
    """
    num_coeffs = x.shape[-1]
    # Exponents 1..deg, converted to the dtype and device of x.
    exponents = torch.arange(1, num_coeffs).to(x)
    higher_terms = x[..., 1:]  # the constant term a_0 vanishes
    return higher_terms * exponents  # a_n x^n -> n a_n x^(n-1)

forward(coeff) ¤

Source code in cirkit/backend/torch/parameters/nodes.py
987
988
989
990
991
992
993
def forward(self, coeff: Tensor) -> Tensor:
    """Compute the coefficients of the ``self.order``-th derivative.

    Args:
        coeff: The polynomial coefficients, shape (F, K, dp1).

    Returns:
        Tensor: The derivative coefficients, shape (F, K, dp1 - order), or zeros of
        shape (F, K, 1) when dp1 <= order.
    """
    if coeff.shape[-1] > self.order:
        derivative = coeff
        for _ in range(self.order):
            derivative = self._diff_once(derivative)
        return derivative  # shape (F, K, dp1 - order)
    # Differentiating more times than the degree allows gives the zero polynomial.
    return torch.zeros_like(coeff[..., :1])  # shape (F, K, 1)

TorchPolynomialProduct ¤

Bases: TorchBinaryParameterOp

Source code in cirkit/backend/torch/parameters/nodes.py
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
class TorchPolynomialProduct(TorchBinaryParameterOp):
    """Multiply every polynomial of the first input with every polynomial of the second.

    Polynomials are given by coefficients along the last axis, and the products are
    computed via FFT.
    """

    # Use default __init__

    @property
    def shape(self) -> tuple[int, ...]:
        """The output shape (K1 * K2, dp1), with dp1 = deg1p1 + deg2p1 - 1."""
        (k1, deg1p1), (k2, deg2p1) = self.in_shapes[0][:2], self.in_shapes[1][:2]
        return (k1 * k2, deg1p1 + deg2p1 - 1)

    def forward(self, coeff1: Tensor, coeff2: Tensor) -> Tensor:
        """Compute the coefficients of all pairwise polynomial products.

        Args:
            coeff1: The first coefficients, shape (F, K1, deg1p1).
            coeff2: The second coefficients, shape (F, K2, deg2p1).

        Returns:
            Tensor: The product coefficients, shape (F, K1 * K2, deg1p1 + deg2p1 - 1).
        """
        # Real inputs can use the cheaper real-valued transforms.
        if coeff1.is_complex() or coeff2.is_complex():
            forward_fft, inverse_fft = torch.fft.fft, torch.fft.ifft
        else:
            forward_fft, inverse_fft = torch.fft.rfft, torch.fft.irfft

        # deg1p1 + deg2p1 - 1 = (deg1 + deg2) + 1.
        out_len = coeff1.shape[-1] + coeff2.shape[-1] - 1

        spectrum1 = forward_fft(coeff1, n=out_len, dim=-1)  # (F, K1, *)
        spectrum2 = forward_fft(coeff2, n=out_len, dim=-1)  # (F, K2, *)

        # (F, K1, 1, *) * (F, 1, K2, *) -> (F, K1, K2, *) -> (F, K1*K2, *).
        pairwise = spectrum1.unsqueeze(dim=2) * spectrum2.unsqueeze(dim=1)
        merged = torch.flatten(pairwise, start_dim=1, end_dim=2)

        return inverse_fft(merged, n=out_len, dim=-1)  # (F, K1*K2, dp1)

shape property ¤

forward(coeff1, coeff2) ¤

Source code in cirkit/backend/torch/parameters/nodes.py
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
def forward(self, coeff1: Tensor, coeff2: Tensor) -> Tensor:
    """Compute the coefficients of all pairwise polynomial products via FFT.

    Args:
        coeff1: The first coefficients, shape (F, K1, deg1p1).
        coeff2: The second coefficients, shape (F, K2, deg2p1).

    Returns:
        Tensor: The product coefficients, shape (F, K1 * K2, deg1p1 + deg2p1 - 1).
    """
    # Real inputs can use the cheaper real-valued transforms.
    if coeff1.is_complex() or coeff2.is_complex():
        forward_fft, inverse_fft = torch.fft.fft, torch.fft.ifft
    else:
        forward_fft, inverse_fft = torch.fft.rfft, torch.fft.irfft

    # deg1p1 + deg2p1 - 1 = (deg1 + deg2) + 1.
    out_len = coeff1.shape[-1] + coeff2.shape[-1] - 1

    spectrum1 = forward_fft(coeff1, n=out_len, dim=-1)  # (F, K1, *)
    spectrum2 = forward_fft(coeff2, n=out_len, dim=-1)  # (F, K2, *)

    # (F, K1, 1, *) * (F, 1, K2, *) -> (F, K1, K2, *) -> (F, K1*K2, *).
    pairwise = spectrum1.unsqueeze(dim=2) * spectrum2.unsqueeze(dim=1)
    merged = torch.flatten(pairwise, start_dim=1, end_dim=2)

    return inverse_fft(merged, n=out_len, dim=-1)  # (F, K1*K2, dp1)

TorchReduceLSEParameter ¤

Bases: TorchReduceParameterOp

Source code in cirkit/backend/torch/parameters/nodes.py
702
703
704
class TorchReduceLSEParameter(TorchReduceParameterOp):
    """Reduce operation applying log-sum-exp along the configured axis."""

    def forward(self, x: Tensor) -> Tensor:
        """Apply log-sum-exp over axis ``self.dim + 1`` (the fold axis comes first)."""
        return x.logsumexp(dim=self.dim + 1)

forward(x) ¤

Source code in cirkit/backend/torch/parameters/nodes.py
703
704
def forward(self, x: Tensor) -> Tensor:
    """Apply log-sum-exp over axis ``self.dim + 1`` (the fold axis comes first)."""
    return x.logsumexp(dim=self.dim + 1)

TorchReduceParameterOp ¤

Bases: TorchUnaryParameterOp, ABC

The base class for parameter operations that reduce their input along one dimension, which is removed from the output shape.

Source code in cirkit/backend/torch/parameters/nodes.py
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
class TorchReduceParameterOp(TorchUnaryParameterOp, ABC):
    """The base class for unary operations that reduce their input along one dimension.

    The reduced dimension is removed from the output shape.
    """

    # NOTE: This class only stores the dimension and derives the output shape. It's up to
    #       the implementations to define the actual reduction in forward().
    def __init__(
        self,
        in_shape: tuple[int, ...],
        dim: int = -1,
        *,
        num_folds: int = 1,
    ) -> None:
        """Initialize the reduce operation.

        Args:
            in_shape: The shape of the input, without the fold axis.
            dim: The dimension to reduce. Negative values count from the last dimension.
            num_folds: The number of folds.
        """
        rank = len(in_shape)
        if dim < 0:
            # Turn a negative dimension into its non-negative equivalent.
            dim += rank
        assert 0 <= dim < rank
        super().__init__(in_shape, num_folds=num_folds)
        self.dim = dim

    @property
    def shape(self) -> tuple[int, ...]:
        """The input shape with the reduced dimension removed."""
        kept_before = self.in_shape[: self.dim]
        kept_after = self.in_shape[self.dim + 1 :]
        return (*kept_before, *kept_after)

    @property
    def config(self) -> dict[str, Any]:
        """The parent configuration extended with the ``dim`` hyperparameter."""
        settings = super().config
        settings["dim"] = self.dim
        return settings

config property ¤

dim = dim instance-attribute ¤

shape property ¤

__init__(in_shape, dim=-1, *, num_folds=1) ¤

Source code in cirkit/backend/torch/parameters/nodes.py
388
389
390
391
392
393
394
395
396
397
398
def __init__(
    self,
    in_shape: tuple[int, ...],
    dim: int = -1,
    *,
    num_folds: int = 1,
) -> None:
    """Initialize the reduce operation.

    Args:
        in_shape: The shape of the input, without the fold axis.
        dim: The dimension to reduce. Negative values count from the last dimension.
        num_folds: The number of folds.
    """
    rank = len(in_shape)
    if dim < 0:
        # Turn a negative dimension into its non-negative equivalent.
        dim += rank
    assert 0 <= dim < rank
    super().__init__(in_shape, num_folds=num_folds)
    self.dim = dim

TorchReduceProductParameter ¤

Bases: TorchReduceParameterOp

Source code in cirkit/backend/torch/parameters/nodes.py
697
698
699
class TorchReduceProductParameter(TorchReduceParameterOp):
    """Reduce operation multiplying the entries along the configured axis."""

    def forward(self, x: Tensor) -> Tensor:
        """Take the product over axis ``self.dim + 1`` (the fold axis comes first)."""
        return x.prod(dim=self.dim + 1)

forward(x) ¤

Source code in cirkit/backend/torch/parameters/nodes.py
698
699
def forward(self, x: Tensor) -> Tensor:
    """Take the product over axis ``self.dim + 1`` (the fold axis comes first)."""
    return x.prod(dim=self.dim + 1)

TorchReduceSumParameter ¤

Bases: TorchReduceParameterOp

Source code in cirkit/backend/torch/parameters/nodes.py
692
693
694
class TorchReduceSumParameter(TorchReduceParameterOp):
    """Reduce operation summing the entries along the configured axis."""

    def forward(self, x: Tensor) -> Tensor:
        """Sum over axis ``self.dim + 1`` (the fold axis comes first)."""
        return x.sum(dim=self.dim + 1)

forward(x) ¤

Source code in cirkit/backend/torch/parameters/nodes.py
693
694
def forward(self, x: Tensor) -> Tensor:
    """Sum over axis ``self.dim + 1`` (the fold axis comes first)."""
    return x.sum(dim=self.dim + 1)

TorchScaledSigmoidParameter ¤

Bases: TorchEntrywiseParameterOp

Source code in cirkit/backend/torch/parameters/nodes.py
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
class TorchScaledSigmoidParameter(TorchEntrywiseParameterOp):
    """Entrywise sigmoid whose output is rescaled from (0, 1) to (vmin, vmax)."""

    def __init__(
        self, in_shape: tuple[int, ...], vmin: float, vmax: float, *, num_folds: int = 1
    ) -> None:
        """Initialize the scaled sigmoid.

        Args:
            in_shape: The shape of the input, without the fold axis.
            vmin: The lower end of the output range. It must satisfy 0 <= vmin < vmax.
            vmax: The upper end of the output range.
            num_folds: The number of folds.
        """
        super().__init__(in_shape, num_folds=num_folds)
        assert 0 <= vmin < vmax, "Must provide 0 <= vmin < vmax."
        self.vmin, self.vmax = vmin, vmax

    @property
    def config(self) -> dict[str, Any]:
        """The parent configuration extended with ``vmin`` and ``vmax``."""
        settings = super().config
        settings["vmin"] = self.vmin
        settings["vmax"] = self.vmax
        return settings

    def forward(self, x: Tensor) -> Tensor:
        """Apply the sigmoid entrywise, then scale by (vmax - vmin) and shift by vmin."""
        span = self.vmax - self.vmin
        return x.sigmoid() * span + self.vmin

config property ¤

vmax = vmax instance-attribute ¤

vmin = vmin instance-attribute ¤

__init__(in_shape, vmin, vmax, *, num_folds=1) ¤

Source code in cirkit/backend/torch/parameters/nodes.py
634
635
636
637
638
639
640
def __init__(
    self, in_shape: tuple[int, ...], vmin: float, vmax: float, *, num_folds: int = 1
) -> None:
    """Initialize the scaled sigmoid.

    Args:
        in_shape: The shape of the input, without the fold axis.
        vmin: The lower end of the output range. It must satisfy 0 <= vmin < vmax.
        vmax: The upper end of the output range.
        num_folds: The number of folds.
    """
    super().__init__(in_shape, num_folds=num_folds)
    # The range is checked only after the parent initializer has run.
    assert 0 <= vmin < vmax, "Must provide 0 <= vmin < vmax."
    self.vmin = vmin
    self.vmax = vmax

forward(x) ¤

Source code in cirkit/backend/torch/parameters/nodes.py
649
650
def forward(self, x: Tensor) -> Tensor:
    """Apply the sigmoid entrywise, then scale by (vmax - vmin) and shift by vmin."""
    span = self.vmax - self.vmin
    return x.sigmoid() * span + self.vmin

TorchSigmoidParameter ¤

Bases: TorchEntrywiseParameterOp

Source code in cirkit/backend/torch/parameters/nodes.py
628
629
630
class TorchSigmoidParameter(TorchEntrywiseParameterOp):
    """Sigmoid reparameterization, applied entrywise."""

    def forward(self, x: Tensor) -> Tensor:
        """Apply the sigmoid function to every entry of the input."""
        return x.sigmoid()

forward(x) ¤

Source code in cirkit/backend/torch/parameters/nodes.py
629
630
def forward(self, x: Tensor) -> Tensor:
    """Apply the sigmoid function to every entry of the input."""
    return x.sigmoid()

TorchSoftmaxParameter ¤

Bases: TorchEntrywiseReduceParameterOp

Softmax reparameterization.

Range: (0, 1), 0 available if input is masked, 1 available when only one element valid. Constraints: sum to 1.

Source code in cirkit/backend/torch/parameters/nodes.py
707
708
709
710
711
712
713
714
715
class TorchSoftmaxParameter(TorchEntrywiseReduceParameterOp):
    """Softmax reparameterization.

    Range: (0, 1), 0 available if input is masked, 1 available when only one element valid.
    Constraints: sum to 1.
    """

    def forward(self, x: Tensor) -> Tensor:
        """Apply the softmax over axis ``self.dim + 1`` (the fold axis comes first)."""
        return x.softmax(dim=self.dim + 1)

forward(x) ¤

Source code in cirkit/backend/torch/parameters/nodes.py
714
715
def forward(self, x: Tensor) -> Tensor:
    """Apply the softmax over axis ``self.dim + 1`` (the fold axis comes first)."""
    return x.softmax(dim=self.dim + 1)

TorchSquareParameter ¤

Bases: TorchEntrywiseParameterOp

Square reparameterization.

Source code in cirkit/backend/torch/parameters/nodes.py
621
622
623
624
625
class TorchSquareParameter(TorchEntrywiseParameterOp):
    """Square reparameterization, applied entrywise."""

    def forward(self, x: Tensor) -> Tensor:
        """Square every entry of the input."""
        return x.square()

forward(x) ¤

Source code in cirkit/backend/torch/parameters/nodes.py
624
625
def forward(self, x: Tensor) -> Tensor:
    """Square every entry of the input."""
    return x.square()

TorchSumParameter ¤

Bases: TorchBinaryParameterOp

Source code in cirkit/backend/torch/parameters/nodes.py
475
476
477
478
479
480
481
482
483
484
485
486
487
class TorchSumParameter(TorchBinaryParameterOp):
    """Binary operation adding two inputs of identical shape entrywise."""

    def __init__(
        self, in_shape1: tuple[int, ...], in_shape2: tuple[int, ...], *, num_folds: int = 1
    ) -> None:
        """Initialize the sum operation.

        Args:
            in_shape1: The shape of the first input, without the fold axis.
            in_shape2: The shape of the second input; it must equal ``in_shape1``.
            num_folds: The number of folds.
        """
        assert in_shape1 == in_shape2
        super().__init__(in_shape1, in_shape2, num_folds=num_folds)

    @property
    def shape(self) -> tuple[int, ...]:
        """The output shape, equal to the shape of the first input."""
        return self.in_shape1

    def forward(self, x1: Tensor, x2: Tensor) -> Tensor:
        """Return the sum of the two inputs."""
        total = x1 + x2
        return total

shape property ¤

__init__(in_shape1, in_shape2, *, num_folds=1) ¤

Source code in cirkit/backend/torch/parameters/nodes.py
476
477
478
479
480
def __init__(
    self, in_shape1: tuple[int, ...], in_shape2: tuple[int, ...], *, num_folds: int = 1
) -> None:
    """Initialize the sum operation.

    Args:
        in_shape1: The shape of the first input, without the fold axis.
        in_shape2: The shape of the second input; it must equal ``in_shape1``.
        num_folds: The number of folds.
    """
    # The two inputs must have exactly the same shape.
    assert in_shape1 == in_shape2
    super().__init__(in_shape1, in_shape2, num_folds=num_folds)

forward(x1, x2) ¤

Source code in cirkit/backend/torch/parameters/nodes.py
486
487
def forward(self, x1: Tensor, x2: Tensor) -> Tensor:
    """Return the sum of the two inputs."""
    total = x1 + x2
    return total

TorchTensorParameter ¤

Bases: TorchParameterInput

A torch tensor parameter is a TorchParameterInput that stores a torch.nn.parameter.Parameter object.

Source code in cirkit/backend/torch/parameters/nodes.py
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
class TorchTensorParameter(TorchParameterInput):
    """A torch tensor parameter is a
    [TorchParameterInput][cirkit.backend.torch.parameters.nodes.TorchParameterInput]
    that stores a [torch.nn.parameter.Parameter][torch.nn.parameter.Parameter] object.

    The underlying torch parameter is allocated lazily, on the first call to
    ``reset_parameters()``.
    """

    def __init__(
        self,
        *shape: int,
        requires_grad: bool = True,
        dtype: torch.dtype | None = None,
        initializer_: Callable[[Tensor], Tensor] | None = None,
        num_folds: int = 1,
    ):
        r"""Initializes a torch tensor parameter. Given a shape $(K_1,\ldots,K_n)$ and a number of
        folds $F$, it eventually materializes a torch parameter of shape $(F,K_1,\ldots,K_n)$.

        Args:
            *shape: The shape of the tensor parameter folds $(K_1,\ldots,K_n)$.
            requires_grad: Whether the parameter requires the computation of gradients.
            dtype: The data type of the parameter.
                If it is None, then it defaults to the current default torch data type, i.e.,
                it is given by [torch.get_default_dtype][torch.get_default_dtype].
            initializer_: The in-place initializer used to initialize the tensor parameter.
                It is a callable with only a tensor as input. If it is None, then it defaults to
                sampling from a standard normal distribution, i.e.,
                [torch.nn.init.normal_][torch.nn.init.normal_].
            num_folds: The number of folds $F$.
        """
        super().__init__(num_folds=num_folds)
        self._shape = shape
        # Stays None until reset_parameters() allocates the tensor.
        self._ptensor: nn.Parameter | None = None
        self._requires_grad = requires_grad
        self._dtype = torch.get_default_dtype() if dtype is None else dtype
        self._initializer_ = initializer_ if initializer_ is not None else nn.init.normal_

    @property
    def shape(self) -> tuple[int, ...]:
        """The shape of the parameter folds, i.e., without the fold axis."""
        return self._shape

    @property
    def dtype(self) -> torch.dtype:
        """Retrieve the data type of the parameter.

        Returns:
            torch.dtype: The parameter data type.
        """
        return self._dtype

    @property
    def device(self) -> torch.device:
        """Retrieve the device of the parameter.

        Returns:
            torch.device: The parameter device.

        Raises:
            ValueError: If the parameter has not been initialized.
                See the [reset_parameters][cirkit.backend.torch.parameters.nodes.TorchTensorParameter.reset_parameters]
                method.
        """
        ptensor = self._ptensor
        if ptensor is None:
            raise ValueError(
                "The tensor parameter has not been initialized. Use reset_parameters() first"
            )
        return ptensor.device

    @property
    def requires_grad(self) -> bool:
        """Retrieve whether the torch parameter requires gradients.

        Returns:
            bool: True if it requires gradients, False otherwise.
        """
        return self._requires_grad

    @requires_grad.setter
    def requires_grad(self, value: bool):
        """Set whether the torch parameter requires gradients.

        Args:
            value: The value to set.
        """
        self._requires_grad = value
        # Keep an already allocated tensor in sync with the stored flag.
        if self._ptensor is not None:
            self._ptensor.requires_grad = value

    @property
    def initializer(self) -> Callable[[Tensor], Tensor]:
        """Retrieve the initializer of the torch tensor parameter.

        Returns:
            Callable[[Tensor], Tensor]: The in-place tensor initializer.
        """
        return self._initializer_

    @property
    def config(self) -> dict[str, Any]:
        """The configuration: shape, gradient flag, data type and initializer."""
        return dict(
            shape=self._shape,
            requires_grad=self._requires_grad,
            dtype=self._dtype,
            initializer_=self._initializer_,
        )

    @property
    def fold_settings(self) -> tuple[Any, ...]:
        """The shape, gradient flag and data type of the parameter (not the initializer)."""
        return (self._shape, self._requires_grad, self._dtype)

    @torch.no_grad()
    def reset_parameters(self) -> None:
        """Allocate and initialize the torch tensor parameter. If the tensor has already been
        allocated, then this function simply call the initializer to reset the parameter values.
        """
        if self._ptensor is None:
            # First call: allocate storage of shape (F, K_1, ..., K_n).
            storage = torch.empty(self.num_folds, *self._shape, dtype=self._dtype)
            self._ptensor = nn.Parameter(storage, requires_grad=self._requires_grad)
        self._initializer_(self._ptensor.data)

    def forward(self) -> Tensor:
        r"""Evaluate a torch parameter input node.

        Returns:
            Tensor: A tensor of shape $(F,K_1,\ldots,K_n)$, where $F$ is the number of folds, and
            $(K_1,\ldots,K_n)$ is the shape of the tensors within each fold.

        Raises:
            ValueError: If the parameter has not been initialized.
                See the [reset_parameters][cirkit.backend.torch.parameters.nodes.TorchTensorParameter.reset_parameters]
                method.
        """
        ptensor = self._ptensor
        if ptensor is None:
            raise ValueError(
                "The tensor parameter has not been initialized. Use reset_parameters() first"
            )
        return ptensor

_dtype = dtype instance-attribute ¤

_initializer_ = nn.init.normal_ if initializer_ is None else initializer_ instance-attribute ¤

_ptensor = None instance-attribute ¤

_requires_grad = requires_grad instance-attribute ¤

_shape = shape instance-attribute ¤

config property ¤

device property ¤

Retrieve the device of the parameter.

Returns:

Type Description
device

torch.device: The parameter device.

Raises:

Type Description
ValueError

If the parameter has not been initialized. See the reset_parameters method.

dtype property ¤

Retrieve the data type of the parameter.

Returns:

Type Description
dtype

torch.dtype: The parameter data type.

fold_settings property ¤

initializer property ¤

Retrieve the initializer of the torch tensor parameter.

Returns:

Type Description
Callable[[Tensor], Tensor]

Callable[[Tensor], Tensor]: The in-place tensor initializer.

requires_grad property writable ¤

Retrieve whether the torch parameter requires gradients.

Returns:

Name Type Description
bool bool

True if it requires gradients, False otherwise.

shape property ¤

__init__(*shape, requires_grad=True, dtype=None, initializer_=None, num_folds=1) ¤

Initializes a torch tensor parameter. Given a shape \((K_1,\ldots,K_n)\) and a number of folds \(F\), it eventually materializes a torch parameter of shape \((F,K_1,\ldots,K_n)\).

Parameters:

Name Type Description Default
*shape int

The shape of the tensor parameter folds \((K_1,\ldots,K_n)\).

()
requires_grad bool

Whether the parameter requires the computation of gradients.

True
dtype dtype | None

The data type of the parameter. If it is None, then it defaults to the current default torch data type, i.e., it is given by torch.get_default_dtype.

None
initializer_ Callable[[Tensor], Tensor] | None

The in-place initializer used to initialize the tensor parameter. It is a callable with only a tensor as input. If it is None, then it defaults to sampling from a standard normal distribution, i.e., torch.nn.init.normal_.

None
num_folds int

The number of folds \(F\).

1
Source code in cirkit/backend/torch/parameters/nodes.py
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
def __init__(
    self,
    *shape: int,
    requires_grad: bool = True,
    dtype: torch.dtype | None = None,
    initializer_: Callable[[Tensor], Tensor] | None = None,
    num_folds: int = 1,
):
    r"""Set up a torch tensor parameter without allocating it yet.

    For a per-fold shape $(K_1,\ldots,K_n)$ and a number of folds $F$, the parameter that is
    eventually materialized has shape $(F,K_1,\ldots,K_n)$.

    Args:
        *shape: The shape of the tensor parameter folds $(K_1,\ldots,K_n)$.
        requires_grad: Whether the parameter requires the computation of gradients.
        dtype: The data type of the parameter. When None, the current default torch data
            type is used, i.e., the value of
            [torch.get_default_dtype][torch.get_default_dtype].
        initializer_: The in-place initializer applied to the tensor parameter. It is a
            callable taking only a tensor. When None, values are sampled from a standard
            normal distribution via [torch.nn.init.normal_][torch.nn.init.normal_].
        num_folds: The number of folds $F$.
    """
    super().__init__(num_folds=num_folds)
    # No storage is created here; the tensor stays None until it is allocated later.
    self._ptensor: nn.Parameter | None = None
    self._shape = shape
    self._requires_grad = requires_grad
    # Resolve the optional arguments to their documented defaults.
    self._dtype = torch.get_default_dtype() if dtype is None else dtype
    self._initializer_ = initializer_ if initializer_ is not None else nn.init.normal_

forward() ¤

Evaluate a torch parameter input node.

Returns:

Name Type Description
Tensor Tensor

A tensor of shape \((F,K_1,\ldots,K_n)\), where \(F\) is the number of folds, and \((K_1,\ldots,K_n)\) is the shape of the tensors within each fold.

Raises:

Type Description
ValueError

If the parameter has not been initialized. See the reset_parameters method.

Source code in cirkit/backend/torch/parameters/nodes.py
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
def forward(self) -> Tensor:
    r"""Return the materialized tensor held by this parameter input node.

    Returns:
        Tensor: The stored parameter, of shape $(F,K_1,\ldots,K_n)$, where $F$ is the
        number of folds and $(K_1,\ldots,K_n)$ is the shape of the tensor in each fold.

    Raises:
        ValueError: If the parameter has not been initialized.
            See the [reset_parameters][cirkit.backend.torch.parameters.nodes.TorchTensorParameter.reset_parameters]
            method.
    """
    ptensor = self._ptensor
    if ptensor is not None:
        return ptensor
    # Nothing has been allocated yet: the tensor only exists after reset_parameters().
    raise ValueError(
        "The tensor parameter has not been initialized. Use reset_parameters() first"
    )

reset_parameters() ¤

Allocate and initialize the torch tensor parameter. If the tensor has already been allocated, then this function simply calls the initializer to reset the parameter values.

Source code in cirkit/backend/torch/parameters/nodes.py
197
198
199
200
201
202
203
204
205
206
207
208
209
@torch.no_grad()
def reset_parameters(self) -> None:
    """Allocate the torch tensor parameter if needed, then (re-)initialize its values.

    On the first call the parameter is created with shape ``(num_folds, *shape)``, using the
    configured data type and gradient flag. On later calls the existing parameter is kept
    and only the initializer is run again on its data.
    """
    if self._ptensor is None:
        folded_shape = (self.num_folds, *self._shape)
        storage = torch.empty(*folded_shape, dtype=self._dtype)
        self._ptensor = nn.Parameter(storage, requires_grad=self._requires_grad)
    # Fresh and pre-existing parameters are both filled in place by the initializer.
    self._initializer_(self._ptensor.data)

TorchUnaryParameterOp ¤

Bases: TorchParameterOp, ABC

Source code in cirkit/backend/torch/parameters/nodes.py
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
class TorchUnaryParameterOp(TorchParameterOp, ABC):
    """Abstract parameter operation that takes exactly one input tensor."""

    def __init__(self, in_shape: tuple[int, ...], *, num_folds: int = 1) -> None:
        """Initialize the operation with the shape of its single input.

        Args:
            in_shape: The shape of the input, passed on as the only input shape.
            num_folds: The number of folds, passed on to the parent initializer.
        """
        super().__init__(in_shape, num_folds=num_folds)

    @property
    def in_shape(self) -> tuple[int, ...]:
        """The shape of the single input of this operation."""
        # The unpacking only succeeds when exactly one input shape is stored.
        (only_shape,) = self.in_shapes
        return only_shape

    @property
    def config(self) -> dict[str, Any]:
        """The configuration of this operation, i.e., its input shape keyed by name."""
        shape = self.in_shape
        return {"in_shape": shape}

    def __call__(self, x: Tensor) -> Tensor:
        """Get the reparameterized parameters.

        Args:
            x: The input tensor, passed on unchanged to the parent ``__call__``.

        Returns:
            Tensor: The parameters after reparameterization.
        """
        # IGNORE: Idiom for nn.Module.__call__.
        return super().__call__(x)  # type: ignore[no-any-return,misc]

    @abstractmethod
    def forward(self, x: Tensor) -> Tensor:
        """Compute the output of the operation; must be implemented by subclasses.

        Args:
            x: The input tensor.

        Returns:
            Tensor: The output tensor of the operation.
        """

config property ¤

in_shape property ¤

__call__(x) ¤

Get the reparameterized parameters.

Returns:

Name Type Description
Tensor Tensor

The parameters after reparameterization.

Source code in cirkit/backend/torch/parameters/nodes.py
329
330
331
332
333
334
335
336
def __call__(self, x: Tensor) -> Tensor:
    """Get the reparameterized parameters.

    This only narrows the signature to a single tensor input; the call itself is
    delegated unchanged to the parent ``__call__``.

    Args:
        x: The input tensor of the unary operation.

    Returns:
        Tensor: The parameters after reparameterization.
    """
    # IGNORE: Idiom for nn.Module.__call__.
    return super().__call__(x)  # type: ignore[no-any-return,misc]

__init__(in_shape, *, num_folds=1) ¤

Source code in cirkit/backend/torch/parameters/nodes.py
317
318
def __init__(self, in_shape: tuple[int, ...], *, num_folds: int = 1) -> None:
    """Initialize a unary parameter operation.

    Args:
        in_shape: The shape of the single input, passed on as the only input shape.
        num_folds: The number of folds, passed on to the parent initializer.
    """
    super().__init__(in_shape, num_folds=num_folds)

forward(x) abstractmethod ¤

Source code in cirkit/backend/torch/parameters/nodes.py
338
339
340
@abstractmethod
def forward(self, x: Tensor) -> Tensor:
    """Compute the output of the unary operation; must be implemented by subclasses.

    Args:
        x: The input tensor.

    Returns:
        Tensor: The output tensor of the operation.
    """
    ...