Skip to content

base

base ¤

TorchLayer ¤

Bases: AbstractTorchModule, ABC

The abstract base class for all layers implemented in torch.

Source code in cirkit/backend/torch/layers/base.py (lines 13–119)
class TorchLayer(AbstractTorchModule, ABC):
    """Abstract base class shared by every layer of the torch backend.

    A layer stores its input/output unit counts, its arity and the semiring it is
    evaluated in. Its ``extra_repr`` reports an input shape of
    ``(num_folds, arity, -1, num_input_units)`` and an output shape of
    ``(num_folds, -1, num_output_units)``.
    """

    def __init__(
        self,
        num_input_units: int,
        num_output_units: int,
        arity: int = 1,
        *,
        semiring: Semiring | None = None,
        num_folds: int = 1,
    ) -> None:
        """Validate the layer sizes and store them together with the semiring.

        Args:
            num_input_units: How many units each input carries; zero is allowed.
            num_output_units: How many units the layer outputs; must be positive.
            arity: How many inputs the layer takes; must be positive.
            semiring: Semiring used for evaluation. When omitted,
                [SumProductSemiring][cirkit.backend.torch.semiring.SumProductSemiring] is used.
            num_folds: The number of folds, forwarded to the parent module.

        Raises:
            ValueError: If num_input_units is negative.
            ValueError: If num_output_units is zero or negative.
            ValueError: If arity is zero or negative.
        """
        # The checks short-circuit in a fixed order, so only the first violated
        # constraint is reported.
        if num_input_units < 0:
            problem = "The number of input units must be non-negative"
        elif num_output_units <= 0:
            problem = "The number of output units must be positive"
        elif arity <= 0:
            problem = "The arity must be positive"
        else:
            problem = None
        if problem is not None:
            raise ValueError(problem)
        # AbstractTorchModule appears to be a torch Module (parameters()/buffers()
        # are used below), so it is initialized before any attribute is assigned.
        super().__init__(num_folds=num_folds)
        if semiring is None:
            semiring = SumProductSemiring
        self.num_input_units = num_input_units
        self.num_output_units = num_output_units
        self.arity = arity
        self.semiring = semiring

    @property
    @abstractmethod
    def config(self) -> Mapping[str, Any]:
        """Hyperparameters of the layer, keyed by name.

        Every key must be the name of an argument accepted by ``__init__``, so that
        the mapping can be used to rebuild an equivalent layer.

        Returns:
            Mapping[str, Any]: Hyperparameter names mapped to their values.
        """

    @property
    def params(self) -> Mapping[str, TorchParameter]:
        """Torch parameters of the layer, keyed by name.

        Every key must be the name of an argument accepted by ``__init__``. The
        base implementation declares no parameters.

        Returns:
            Mapping[str, TorchParameter]: Parameter names mapped to their torch
                parameter instances.
        """
        return dict()

    @property
    def sub_modules(self) -> Mapping[str, "TorchLayer"]:
        """Torch sub-module layers, keyed by string identifier.

        These are the layers that must be passed to the ``__init__`` method of the
        top-level layer. The base implementation declares none.

        Returns:
            A dictionary of torch modules.
        """
        return dict()

    @cached_property
    def num_parameters(self) -> int:
        """Total number of scalar entries over ``self.parameters()``.

        A complex-valued parameter contributes twice its element count. The value
        is computed once and cached on first access.

        Returns:
            The number of scalar parameters.
        """
        total = 0
        for param in self.parameters():
            total += param.numel() * (2 if torch.is_complex(param) else 1)
        return total

    @cached_property
    def num_buffers(self) -> int:
        """Total number of scalar entries over ``self.buffers()``.

        A complex-valued buffer contributes twice its element count. The value is
        computed once and cached on first access.

        Returns:
            The number of scalar buffers.
        """
        total = 0
        for buf in self.buffers():
            total += buf.numel() * (2 if torch.is_complex(buf) else 1)
        return total

    def extra_repr(self) -> str:
        """Describe the layer's sizes and folded tensor shapes for ``repr``.

        Returns:
            Three newline-separated lines: the fold/arity/unit counts (separated by
            two spaces), the input shape and the output shape.
        """
        summary = "  ".join(
            (
                f"folds: {self.num_folds}",
                f"arity: {self.arity}",
                f"input-units: {self.num_input_units}",
                f"output-units: {self.num_output_units}",
            )
        )
        # -1 is a placeholder for a size the layer does not fix (presumably the
        # batch dimension).
        input_shape = (self.num_folds, self.arity, -1, self.num_input_units)
        output_shape = (self.num_folds, -1, self.num_output_units)
        return "\n".join(
            (summary, f"input-shape: {input_shape}", f"output-shape: {output_shape}")
        )

arity = arity instance-attribute ¤

config abstractmethod property ¤

Retrieves the configuration of the layer, i.e., a dictionary mapping hyperparameters of the layer to their values. The hyperparameter names must match the argument names in the __init__ method.

Returns:

Type Description
Mapping[str, Any]

A dictionary mapping hyperparameter names to their values.

num_buffers cached property ¤

Retrieve the number of scalar buffers. Note that if a buffer is complex-valued, this will double count them.

Returns:

Type Description
int

The number of scalar buffers.

num_input_units = num_input_units instance-attribute ¤

num_output_units = num_output_units instance-attribute ¤

num_parameters cached property ¤

Retrieve the number of scalar parameters. Note that if a parameter is complex-valued, this will double count them.

Returns:

Type Description
int

The number of scalar parameters.

params property ¤

Retrieve the torch parameters of the layer, i.e., a dictionary mapping the names of the torch parameters to the actual torch parameter instances. The parameter names must match the argument names in the __init__ method.

Returns:

Type Description
Mapping[str, TorchParameter]

A dictionary mapping parameter names to the corresponding torch parameter instances.

semiring = semiring if semiring is not None else SumProductSemiring instance-attribute ¤

sub_modules property ¤

Retrieve a dictionary mapping string identifiers to the torch sub-module layers that must be passed to the __init__ method of the top-level layer.

Returns:

Type Description
Mapping[str, TorchLayer]

A dictionary of torch modules.

__init__(num_input_units, num_output_units, arity=1, *, semiring=None, num_folds=1) ¤

Initialize a layer.

Parameters:

Name Type Description Default
num_input_units int

The number of input units.

required
num_output_units int

The number of output units.

required
arity int

The arity of the layer.

1
semiring Semiring | None

The evaluation semiring. Defaults to SumProductSemiring.

None
num_folds int

The number of folds.

1

Raises:

Type Description
ValueError

If the number of input units is negative.

ValueError

If the number of output units is not positive.

ValueError

If the arity is not positive.

Source code in cirkit/backend/torch/layers/base.py (lines 16–50)
def __init__(
    self,
    num_input_units: int,
    num_output_units: int,
    arity: int = 1,
    *,
    semiring: Semiring | None = None,
    num_folds: int = 1,
) -> None:
    """Validate the layer sizes and store them together with the semiring.

    Args:
        num_input_units: How many units each input carries; zero is allowed.
        num_output_units: How many units the layer outputs; must be positive.
        arity: How many inputs the layer takes; must be positive.
        semiring: Semiring used for evaluation. When omitted,
            [SumProductSemiring][cirkit.backend.torch.semiring.SumProductSemiring] is used.
        num_folds: The number of folds, forwarded to the parent module.

    Raises:
        ValueError: If num_input_units is negative.
        ValueError: If num_output_units is zero or negative.
        ValueError: If arity is zero or negative.
    """
    # The checks short-circuit in a fixed order, so only the first violated
    # constraint is reported.
    if num_input_units < 0:
        problem = "The number of input units must be non-negative"
    elif num_output_units <= 0:
        problem = "The number of output units must be positive"
    elif arity <= 0:
        problem = "The arity must be positive"
    else:
        problem = None
    if problem is not None:
        raise ValueError(problem)
    # The parent module is initialized before any attribute is assigned.
    super().__init__(num_folds=num_folds)
    if semiring is None:
        semiring = SumProductSemiring
    self.num_input_units = num_input_units
    self.num_output_units = num_output_units
    self.arity = arity
    self.semiring = semiring

extra_repr() ¤

Source code in cirkit/backend/torch/layers/base.py (lines 105–119)
def extra_repr(self) -> str:
    """Describe the layer's sizes and folded tensor shapes for ``repr``.

    Returns:
        Three newline-separated lines: the fold/arity/unit counts (separated by two
        spaces), the input shape ``(num_folds, arity, -1, num_input_units)`` and
        the output shape ``(num_folds, -1, num_output_units)``.
    """
    summary = "  ".join(
        (
            f"folds: {self.num_folds}",
            f"arity: {self.arity}",
            f"input-units: {self.num_input_units}",
            f"output-units: {self.num_output_units}",
        )
    )
    # -1 is a placeholder for a size the layer does not fix (presumably the batch
    # dimension).
    input_shape = (self.num_folds, self.arity, -1, self.num_input_units)
    output_shape = (self.num_folds, -1, self.num_output_units)
    return "\n".join(
        (summary, f"input-shape: {input_shape}", f"output-shape: {output_shape}")
    )