Skip to content

optimized

optimized ¤

TorchCPTLayer ¤

Bases: TorchInnerLayer

The Candecomp transposed (CP-T) layer, which is the fusion of a sum layer and a Hadamard layer.

Source code in cirkit/backend/torch/layers/optimized.py
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
class TorchCPTLayer(TorchInnerLayer):
    """The Candecomp transposed (CP-T) layer, which is the fusion of a sum layer and a Hadamard
    layer.

    The forward pass first takes the semiring product of the inputs across the arity axis
    (the Hadamard part), and then contracts the result with the weight matrix using the
    semiring einsum (the sum part).
    """

    def __init__(
        self,
        num_input_units: int,
        num_output_units: int,
        arity: int = 2,
        *,
        weight: TorchParameter,
        semiring: Semiring | None = None,
        num_folds: int = 1,
    ):
        """Initialize a CP-T layer.

        Args:
            num_input_units: The number of input units.
            num_output_units: The number of output units.
            arity: The arity of the layer, i.e., the number of inputs that are multiplied
                together before the weighted sum. Defaults to 2.
            weight: The weight parameter, which must have shape $(F, K_o, K_i)$,
                where $F$ is the number of folds, $K_o$ is the number of output units,
                and $K_i$ is the number of input units.
            semiring: The semiring the layer computes in; forwarded to the base class.
            num_folds: The number of folds $F$. Defaults to 1.

        Raises:
            ValueError: If the number of folds, input units, or output units is incompatible
                with the weight parameter.
        """
        super().__init__(
            num_input_units,
            num_output_units,
            arity=arity,
            semiring=semiring,
            num_folds=num_folds,
        )
        if not self._valid_weight_shape(weight):
            raise ValueError(
                f"Expected number of folds {self.num_folds} "
                f"and shape {self._weight_shape} for 'weight', found "
                f"{weight.num_folds} and {weight.shape}, respectively"
            )
        self.weight = weight

    def _valid_weight_shape(self, w: TorchParameter) -> bool:
        """Check that ``w`` has this layer's number of folds and shape $(K_o, K_i)$."""
        if w.num_folds != self.num_folds:
            return False
        return w.shape == self._weight_shape

    @property
    def _weight_shape(self) -> tuple[int, ...]:
        # Per-fold weight shape: (Ko, Ki)
        return self.num_output_units, self.num_input_units

    @property
    def config(self) -> Mapping[str, Any]:
        """The constructor arguments needed to rebuild this layer (excluding parameters)."""
        return {
            "num_input_units": self.num_input_units,
            "num_output_units": self.num_output_units,
            "arity": self.arity,
        }

    @property
    def params(self) -> Mapping[str, TorchParameter]:
        """The parameters of this layer, keyed by constructor argument name."""
        return {"weight": self.weight}

    def forward(self, x: Tensor) -> Tensor:
        # x: (F, H, B, Ki) -> (F, B, Ki), semiring product over the H (arity) inputs
        x = self.semiring.prod(x, dim=1, keepdim=False)
        # weight: (F, Ko, Ki)
        weight = self.weight()
        # y: (F, B, Ko)
        return self.semiring.einsum(
            "fbi,foi->fbo", inputs=(x,), operands=(weight,), dim=-1, keepdim=True
        )

    def sample(self, x: Tensor) -> tuple[Tensor, Tensor]:
        """Propagate input samples through the layer by sampling its mixing weights.

        Args:
            x: The input samples, of shape $(F, H, C, K_i, N, D)$, where $N$ is the number
                of samples. The $H$ input branches are summed together before selection.

        Returns:
            A pair ``(samples, mixing_samples)``. ``samples`` has shape $(F, C, K_o, N, D)$
            and holds, for every output unit, the input unit picked by the sampled index.
            ``mixing_samples`` has shape $(F, K_o, N)$ and holds those indices.

        Raises:
            ValueError: If some weight is negative, or if the weights of some output unit
                do not sum to one.
        """
        weight = self.weight()
        if torch.any(weight < 0.0):
            raise ValueError("Sampling only works with positive weights")
        # Build the reference with ones_like so that it matches the dtype and device of the
        # weights: torch.allclose rejects operands of different dtypes, so comparing against a
        # default-dtype (float32) tensor would fail for, e.g., float64 weights.
        weight_sums = torch.sum(weight, dim=-1)
        if not torch.allclose(weight_sums, torch.ones_like(weight_sums)):
            raise ValueError("Sampling only works with a normalized parametrization")

        # x: (F, H, C, K, num_samples, D) -> (F, H=1, C, K, num_samples, D)
        x = torch.sum(x, dim=1, keepdim=True)

        c = x.shape[2]
        d = x.shape[-1]
        num_samples = x.shape[-2]

        # One categorical over the K input units per fold and output unit: (F, O) batch, K cats
        mixing_distribution = torch.distributions.Categorical(probs=weight)

        # (num_samples, F, O) -> (F, O, num_samples)
        mixing_samples = mixing_distribution.sample((num_samples,))
        mixing_samples = E.rearrange(mixing_samples, "n f o -> f o n")
        # Broadcast the indices to the layout of x: (F, 1, C, O, num_samples, D)
        mixing_indices = E.repeat(mixing_samples, "f o n -> f a c o n d", a=1, c=c, d=d)

        # Pick one input unit per output unit along the K axis: (F, 1, C, O, num_samples, D)
        x = torch.gather(x, dim=-3, index=mixing_indices)
        # Drop the singleton H axis: (F, C, O, num_samples, D)
        x = x[:, 0]
        return x, mixing_samples

_weight_shape property ¤

config property ¤

params property ¤

weight = weight instance-attribute ¤

__init__(num_input_units, num_output_units, arity=2, *, weight, semiring=None, num_folds=1) ¤

Initialize a CP-T layer.

Parameters:

Name Type Description Default
num_input_units int

The number of input units.

required
num_output_units int

The number of output units.

required
arity int

The arity of the layer, must be 2. Defaults to 2.

2
weight TorchParameter

The weight parameter, which must have shape \((F, K_o, K_i)\), where \(F\) is the number of folds, \(K_o\) is the number of output units, and \(K_i\) is the number of input units.

required

Raises:

Type Description
ValueError

If the number of input and output units are incompatible with the shape of the weight parameter.

Source code in cirkit/backend/torch/layers/optimized.py
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
def __init__(
    self,
    num_input_units: int,
    num_output_units: int,
    arity: int = 2,
    *,
    weight: TorchParameter,
    semiring: Semiring | None = None,
    num_folds: int = 1,
):
    """Initialize a CP-T layer.

    Args:
        num_input_units: The number of input units.
        num_output_units: The number of output units.
        arity: The arity of the layer, i.e., the number of inputs that are multiplied
            together before the weighted sum. Defaults to 2.
        weight: The weight parameter, which must have shape $(F, K_o, K_i)$,
            where $F$ is the number of folds, $K_o$ is the number of output units,
            and $K_i$ is the number of input units.
        semiring: The semiring the layer computes in; forwarded to the base class.
        num_folds: The number of folds $F$. Defaults to 1.

    Raises:
        ValueError: If the number of folds, input units, or output units is incompatible
            with the weight parameter.
    """
    super().__init__(
        num_input_units,
        num_output_units,
        arity=arity,
        semiring=semiring,
        num_folds=num_folds,
    )
    if not self._valid_weight_shape(weight):
        # NOTE: trailing space after "found" so the values do not run into the word
        raise ValueError(
            f"Expected number of folds {self.num_folds} "
            f"and shape {self._weight_shape} for 'weight', found "
            f"{weight.num_folds} and {weight.shape}, respectively"
        )
    self.weight = weight

_valid_weight_shape(w) ¤

Source code in cirkit/backend/torch/layers/optimized.py
135
136
137
138
def _valid_weight_shape(self, w: TorchParameter) -> bool:
    if w.num_folds != self.num_folds:
        return False
    return w.shape == self._weight_shape

forward(x) ¤

Source code in cirkit/backend/torch/layers/optimized.py
156
157
158
159
160
161
162
163
def forward(self, x: Tensor) -> Tensor:
    # x: (F, H, B, Ki) -> (F, B, Ki), semiring product over the H (arity) inputs
    x = self.semiring.prod(x, dim=1, keepdim=False)
    # weight: (F, Ko, Ki)
    weight = self.weight()
    # y: (F, B, Ko)
    return self.semiring.einsum(
        "fbi,foi->fbo", inputs=(x,), operands=(weight,), dim=-1, keepdim=True
    )

sample(x) ¤

Source code in cirkit/backend/torch/layers/optimized.py
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
def sample(self, x: Tensor) -> tuple[Tensor, Tensor]:
    """Propagate input samples through the layer by sampling its mixing weights.

    Args:
        x: The input samples, of shape $(F, H, C, K_i, N, D)$, where $N$ is the number
            of samples. The $H$ input branches are summed together before selection.

    Returns:
        A pair ``(samples, mixing_samples)``. ``samples`` has shape $(F, C, K_o, N, D)$
        and holds, for every output unit, the input unit picked by the sampled index.
        ``mixing_samples`` has shape $(F, K_o, N)$ and holds those indices.

    Raises:
        ValueError: If some weight is negative, or if the weights of some output unit
            do not sum to one.
    """
    weight = self.weight()
    if torch.any(weight < 0.0):
        raise ValueError("Sampling only works with positive weights")
    # Build the reference with ones_like so that it matches the dtype and device of the
    # weights: torch.allclose rejects operands of different dtypes, so comparing against a
    # default-dtype (float32) tensor would fail for, e.g., float64 weights.
    weight_sums = torch.sum(weight, dim=-1)
    if not torch.allclose(weight_sums, torch.ones_like(weight_sums)):
        raise ValueError("Sampling only works with a normalized parametrization")

    # x: (F, H, C, K, num_samples, D) -> (F, H=1, C, K, num_samples, D)
    x = torch.sum(x, dim=1, keepdim=True)

    c = x.shape[2]
    d = x.shape[-1]
    num_samples = x.shape[-2]

    # One categorical over the K input units per fold and output unit: (F, O) batch, K cats
    mixing_distribution = torch.distributions.Categorical(probs=weight)

    # (num_samples, F, O) -> (F, O, num_samples)
    mixing_samples = mixing_distribution.sample((num_samples,))
    mixing_samples = E.rearrange(mixing_samples, "n f o -> f o n")
    # Broadcast the indices to the layout of x: (F, 1, C, O, num_samples, D)
    mixing_indices = E.repeat(mixing_samples, "f o n -> f a c o n d", a=1, c=c, d=d)

    # Pick one input unit per output unit along the K axis: (F, 1, C, O, num_samples, D)
    x = torch.gather(x, dim=-3, index=mixing_indices)
    # Drop the singleton H axis: (F, C, O, num_samples, D)
    x = x[:, 0]
    return x, mixing_samples

TorchTensorDotLayer ¤

Bases: TorchInnerLayer

The tensor dot layer performs the following operations. Let \(\mathbf{x}\) be an input tensor of shape \((B, K_i)\), where \(B\) is the batch size, and \(K_i\) is the number of input units. The tensor dot layer first reshapes it into the tensor \(\mathcal{Z}\) having shape \((B, K_j, K_q)\), where \(K_i = K_jK_q\). Then, it computes the tensor \(\mathcal{S}\) of shape \((B, K_q, K_k)\) as follows:

\[ \mathcal{S}_{bqk} = \sum_{j=1}^{K_j} w_{kj} z_{bjq} \]

in element-wise notation, where \(\mathbf{W}\) is a tensor of shape \((K_k, K_j)\), where we have that \(K_o = K_qK_k\) is the number of output units. Finally, it returns the output tensor of shape \((B, K_o)\) obtained by flattening the last two dimensions of the tensor \(\mathcal{S}\). Note that the above operations are parallelized w.r.t. the additional fold dimension.

Source code in cirkit/backend/torch/layers/optimized.py
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
class TorchTensorDotLayer(TorchInnerLayer):
    r"""The tensor dot layer performs the following operations.
    Let $\mathbf{x}$ be an input tensor of shape $(B, K_i)$, where $B$ is the batch size,
    and $K_i$ is the number of input units. The tensor dot layer first reshapes it into the
    tensor $\mathcal{Z}$ having shape $(B, K_j, K_q)$, where $K_i = K_jK_q$. Then, it computes
    the tensor $\mathcal{S}$ of shape $(B, K_q, K_k)$ as follows:

    $$
    \mathcal{S}_{bqk} = \sum_{j=1}^{K_j} w_{kj} z_{bjq}
    $$

    in element-wise notation, where $\mathbf{W}$ is a tensor of shape $(K_k, K_j)$,
    where we have that $K_o = K_qK_k$ is the number of output units.
    Finally, it returns the output tensor of shape $(B, K_o)$ obtained by flattening the
    last two dimensions of the tensor $\mathcal{S}$. Note that the above operations are
    parallelized w.r.t. the additional fold dimension.
    """

    def __init__(
        self,
        num_input_units: int,
        num_output_units: int,
        *,
        weight: TorchParameter,
        semiring: Semiring | None = None,
        num_folds: int = 1,
    ):
        """Initialize a tensor dot layer.

        Args:
            num_input_units: The number of input units $K_i$, such that
                $K_i = K_j K_q$ for some $K_j,K_q$.
            num_output_units: The number of output units $K_o$, such that
                $K_o = K_q K_k$ for some $K_k$.
            weight: The weight parameter, which must have shape $(F, K_k, K_j)$,
                where $F$ is the number of folds, and $K_k,K_j$ are defined
                as in the definition of the number of input and output units above.
            semiring: The semiring the layer computes in; forwarded to the base class.
            num_folds: The number of folds $F$. Defaults to 1.

        Raises:
            ValueError: If the number of input and output units are incompatible with the
                shape of the weight parameter.
        """
        super().__init__(
            num_input_units,
            num_output_units,
            arity=1,
            semiring=semiring,
            num_folds=num_folds,
        )
        if not self._valid_weight_shape(weight):
            raise ValueError(
                f"Expected number of folds {self.num_folds} "
                f"and shape (K_k, K_j) for 'weight', where "
                f"{self.num_input_units} = K_jK_q and "
                f"{self.num_output_units} = K_qK_k, "
                f"but found {weight.num_folds} and {weight.shape}, respectively"
            )
        self.weight = weight
        # K_j: the input units contracted with the weight
        self._num_contract_units = weight.shape[1]
        # K_q: the input units carried through unchanged (K_i = K_j * K_q)
        self._num_batch_units = num_input_units // self._num_contract_units

    def _valid_weight_shape(self, w: TorchParameter) -> bool:
        """Check that ``w`` has shape $(K_k, K_j)$ consistent with $K_i = K_jK_q$, $K_o = K_qK_k$."""
        if w.num_folds != self.num_folds:
            return False
        if len(w.shape) != 2:
            return False
        num_kk, num_kj = w.shape
        # Guard the modulo/division below against a zero-sized contraction dimension
        if num_kj <= 0 or self.num_input_units % num_kj:
            return False
        return self.num_output_units == num_kk * (self.num_input_units // num_kj)

    @property
    def config(self) -> Mapping[str, Any]:
        """The constructor arguments needed to rebuild this layer (excluding parameters)."""
        return {"num_input_units": self.num_input_units, "num_output_units": self.num_output_units}

    @property
    def params(self) -> Mapping[str, TorchParameter]:
        """The parameters of this layer, keyed by constructor argument name."""
        return {"weight": self.weight}

    def forward(self, x: Tensor) -> Tensor:
        # x: (F, H=1, B, Ki) -> (F, B, Ki)
        x = x.squeeze(dim=1)
        # x: (F, B, Ki) -> (F, B, Kj, Kq) -> (F, B, Kq, Kj)
        # reshape (rather than view) also accepts non-contiguous inputs
        x = x.reshape(x.shape[0], x.shape[1], self._num_contract_units, self._num_batch_units)
        x = x.permute(0, 1, 3, 2)
        # weight: (F, Kk, Kj)
        weight = self.weight()
        # y: (F, B, Kq, Kk)
        y = self.semiring.einsum(
            "fbqj,fkj->fbqk", inputs=(x,), operands=(weight,), dim=-1, keepdim=True
        )
        # return y: (F, B, Kq * Kk) = (F, B, Ko); reshape tolerates a non-contiguous y
        return y.reshape(y.shape[0], y.shape[1], self.num_output_units)

_num_batch_units = num_input_units // self._num_contract_units instance-attribute ¤

_num_contract_units = weight.shape[1] instance-attribute ¤

config property ¤

params property ¤

weight = weight instance-attribute ¤

__init__(num_input_units, num_output_units, *, weight, semiring=None, num_folds=1) ¤

Initialize a tensor dot layer.

Parameters:

Name Type Description Default
num_input_units int

The number of input units \(K_i\), such that \(K_i = K_j K_q\) for some \(K_j,K_q\).

required
num_output_units int

The number of output units \(K_o\), such that \(K_o = K_q K_k\) for some \(K_k\).

required
weight TorchParameter

The weight parameter, which must have shape \((F, K_k, K_j)\), where \(F\) is the number of folds, and \(K_k,K_j\) are defined as in the definition of the number of input and output units above.

required
Raises:

Type Description
ValueError

If the number of input and output units are incompatible with the shape of the weight parameter.

Source code in cirkit/backend/torch/layers/optimized.py
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
def __init__(
    self,
    num_input_units: int,
    num_output_units: int,
    *,
    weight: TorchParameter,
    semiring: Semiring | None = None,
    num_folds: int = 1,
):
    """Initialize a tensor dot layer.

    Args:
        num_input_units: The number of input units $K_i$, such that
            $K_i = K_j K_q$ for some $K_j,K_q$.
        num_output_units: The number of output units $K_o$, such that
            $K_o = K_q K_k$ for some $K_k$.
        weight: The weight parameter, which must have shape $(F, K_k, K_j)$,
            where $F$ is the number of folds, and $K_k,K_j$ are defined
            as in the definition of the number of input and output units above.
        semiring: The semiring the layer computes in; forwarded to the base class.
        num_folds: The number of folds $F$. Defaults to 1.

    Raises:
        ValueError: If the number of input and output units are incompatible with the
            shape of the weight parameter.
    """
    super().__init__(
        num_input_units,
        num_output_units,
        arity=1,
        semiring=semiring,
        num_folds=num_folds,
    )
    if not self._valid_weight_shape(weight):
        raise ValueError(
            f"Expected number of folds {self.num_folds} "
            f"and shape (K_k, K_j) for 'weight', where "
            f"{self.num_input_units} = K_jK_q and "
            f"{self.num_output_units} = K_qK_k, "
            f"but found {weight.num_folds} and {weight.shape}, respectively"
        )
    self.weight = weight
    # K_j: the input units contracted with the weight
    self._num_contract_units = weight.shape[1]
    # K_q: the input units carried through unchanged (K_i = K_j * K_q)
    self._num_batch_units = num_input_units // self._num_contract_units

_valid_weight_shape(w) ¤

Source code in cirkit/backend/torch/layers/optimized.py
253
254
255
256
257
258
259
260
261
262
def _valid_weight_shape(self, w: TorchParameter) -> bool:
    if w.num_folds != self.num_folds:
        return False
    if len(w.shape) != 2:
        return False
    if self.num_input_units % w.shape[1]:
        return False
    if self.num_output_units != w.shape[0] * (self.num_input_units // w.shape[1]):
        return False
    return True

forward(x) ¤

Source code in cirkit/backend/torch/layers/optimized.py
272
273
274
275
276
277
278
279
280
281
282
283
284
285
def forward(self, x: Tensor) -> Tensor:
    # x: (F, H=1, B, Ki) -> (F, B, Ki)
    x = x.squeeze(dim=1)
    # x: (F, B, Ki) -> (F, B, Kj, Kq) -> (F, B, Kq, Kj)
    # reshape (rather than view) also accepts non-contiguous inputs
    x = x.reshape(x.shape[0], x.shape[1], self._num_contract_units, self._num_batch_units)
    x = x.permute(0, 1, 3, 2)
    # weight: (F, Kk, Kj)
    weight = self.weight()
    # y: (F, B, Kq, Kk)
    y = self.semiring.einsum(
        "fbqj,fkj->fbqk", inputs=(x,), operands=(weight,), dim=-1, keepdim=True
    )
    # return y: (F, B, Kq * Kk) = (F, B, Ko); reshape tolerates a non-contiguous y
    return y.reshape(y.shape[0], y.shape[1], self.num_output_units)

TorchTuckerLayer ¤

Bases: TorchInnerLayer

The Tucker layer optimized implementation, leveraging a torch.einsum operation.

Source code in cirkit/backend/torch/layers/optimized.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
class TorchTuckerLayer(TorchInnerLayer):
    """The Tucker layer optimized implementation, leveraging a single `einsum` operation
    (computed in the layer's semiring) that contracts both inputs with the weight tensor.
    """

    def __init__(
        self,
        num_input_units: int,
        num_output_units: int,
        arity: int = 2,
        *,
        weight: TorchParameter,
        semiring: Semiring | None = None,
        num_folds: int = 1,
    ):
        """Initialize a Tucker layer.

        Args:
            num_input_units: The number of input units.
            num_output_units: The number of output units.
            arity: The arity of the layer, must be 2. Defaults to 2.
            weight: The weight parameter, which must have shape $(F, K_o, K_i^2)$,
                where $F$ is the number of folds, $K_o$ is the number of output units,
                and $K_i$ is the number of input units.
            semiring: The semiring the layer computes in; forwarded to the base class.
            num_folds: The number of folds $F$. Defaults to 1.

        Raises:
            NotImplementedError: If the arity is not equal to 2. Future versions of cirkit
                will support Tucker layers having arity greater than 2.
            ValueError: If the number of input and output units are incompatible with the
                shape of the weight parameter.
        """
        # TODO: Generalize Tucker layer to have any arity greater or equal 2
        if arity != 2:
            raise NotImplementedError("The Tucker layer is only implemented with arity=2")
        super().__init__(
            num_input_units, num_output_units, arity=arity, semiring=semiring, num_folds=num_folds
        )
        if not self._valid_weight_shape(weight):
            raise ValueError(
                f"Expected number of folds {self.num_folds} "
                f"and shape {self._weight_shape} for 'weight', found "
                f"{weight.num_folds} and {weight.shape}, respectively"
            )
        self.weight = weight

    def _valid_weight_shape(self, w: TorchParameter) -> bool:
        """Check that ``w`` has this layer's number of folds and shape $(K_o, K_i^2)$."""
        if w.num_folds != self.num_folds:
            return False
        return w.shape == self._weight_shape

    @property
    def _weight_shape(self) -> tuple[int, ...]:
        # Per-fold weight shape: (Ko, Ki * Ki), i.e. the (Ki, Ki) core flattened
        return self.num_output_units, self.num_input_units * self.num_input_units

    @property
    def config(self) -> Mapping[str, Any]:
        """The constructor arguments needed to rebuild this layer (excluding parameters)."""
        return {
            "num_input_units": self.num_input_units,
            "num_output_units": self.num_output_units,
            "arity": self.arity,
        }

    @property
    def params(self) -> Mapping[str, TorchParameter]:
        """The parameters of this layer, keyed by constructor argument name."""
        return {"weight": self.weight}

    def forward(self, x: Tensor) -> Tensor:
        # x: (F, H=2, B, Ki); x[:, 0] and x[:, 1] are the two inputs, each (F, B, Ki)
        # weight: (F, Ko, Ki * Ki) -> (F, Ko, Ki, Ki)
        weight = self.weight().view(
            -1, self.num_output_units, self.num_input_units, self.num_input_units
        )
        # y: (F, B, Ko)
        return self.semiring.einsum(
            "fbi,fbj,foij->fbo",
            operands=(weight,),
            inputs=(x[:, 0], x[:, 1]),
            dim=-1,
            keepdim=True,
        )

_weight_shape property ¤

config property ¤

params property ¤

weight = weight instance-attribute ¤

__init__(num_input_units, num_output_units, arity=2, *, weight, semiring=None, num_folds=1) ¤

Initialize a Tucker layer.

Parameters:

Name Type Description Default
num_input_units int

The number of input units.

required
num_output_units int

The number of output units.

required
arity int

The arity of the layer, must be 2. Defaults to 2.

2
weight TorchParameter

The weight parameter, which must have shape \((F, K_o, K_i^2)\), where \(F\) is the number of folds, \(K_o\) is the number of output units, and \(K_i\) is the number of input units.

required

Raises:

Type Description
NotImplementedError

If the arity is not equal to 2. Future versions of cirkit will support Tucker layers having arity greater than 2.

ValueError

If the number of input and output units are incompatible with the shape of the weight parameter.

Source code in cirkit/backend/torch/layers/optimized.py
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
def __init__(
    self,
    num_input_units: int,
    num_output_units: int,
    arity: int = 2,
    *,
    weight: TorchParameter,
    semiring: Semiring | None = None,
    num_folds: int = 1,
):
    """Initialize a Tucker layer.

    Args:
        num_input_units: The number of input units.
        num_output_units: The number of output units.
        arity: The arity of the layer, must be 2. Defaults to 2.
        weight: The weight parameter, which must have shape $(F, K_o, K_i^2)$,
            where $F$ is the number of folds, $K_o$ is the number of output units,
            and $K_i$ is the number of input units.
        semiring: The semiring the layer computes in; forwarded to the base class.
        num_folds: The number of folds $F$. Defaults to 1.

    Raises:
        NotImplementedError: If the arity is not equal to 2. Future versions of cirkit
            will support Tucker layers having arity greater than 2.
        ValueError: If the number of input and output units are incompatible with the
            shape of the weight parameter.
    """
    # TODO: Generalize Tucker layer to have any arity greater or equal 2
    if arity != 2:
        raise NotImplementedError("The Tucker layer is only implemented with arity=2")
    super().__init__(
        num_input_units, num_output_units, arity=arity, semiring=semiring, num_folds=num_folds
    )
    if not self._valid_weight_shape(weight):
        # NOTE: trailing space after "found" so the values do not run into the word
        raise ValueError(
            f"Expected number of folds {self.num_folds} "
            f"and shape {self._weight_shape} for 'weight', found "
            f"{weight.num_folds} and {weight.shape}, respectively"
        )
    self.weight = weight

_valid_weight_shape(w) ¤

Source code in cirkit/backend/torch/layers/optimized.py
56
57
58
59
def _valid_weight_shape(self, w: TorchParameter) -> bool:
    if w.num_folds != self.num_folds:
        return False
    return w.shape == self._weight_shape

forward(x) ¤

Source code in cirkit/backend/torch/layers/optimized.py
77
78
79
80
81
82
83
84
85
86
87
88
def forward(self, x: Tensor) -> Tensor:
    ki, ko = self.num_input_units, self.num_output_units
    # Unflatten the weight core: (F, Ko, Ki * Ki) -> (F, Ko, Ki, Ki)
    core = self.weight().view(-1, ko, ki, ki)
    # The two inputs of the arity-2 layer, each of shape (F, B, Ki)
    left, right = x[:, 0], x[:, 1]
    # (F, B, Ki) x (F, B, Ki) x (F, Ko, Ki, Ki) -> (F, B, Ko)
    return self.semiring.einsum(
        "fbi,fbj,foij->fbo",
        inputs=(left, right),
        operands=(core,),
        dim=-1,
        keepdim=True,
    )