Skip to content

inner

inner ¤

TorchHadamardLayer ¤

Bases: TorchInnerLayer

The Hadamard product layer, which computes an element-wise (or Hadamard) product of the input vectors it receives as inputs. See the symbolic HadamardLayer for more details.

Source code in cirkit/backend/torch/layers/inner.py
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
class TorchHadamardLayer(TorchInnerLayer):
    """Torch implementation of the Hadamard (element-wise) product layer.

    The layer multiplies, unit by unit, the vectors it receives from its inputs, hence
    the number of output units is always the number of input units.
    See the symbolic [HadamardLayer][cirkit.symbolic.layers.HadamardLayer] for more details.
    """

    def __init__(
        self,
        num_input_units: int,
        arity: int = 2,
        *,
        semiring: Semiring | None = None,
        num_folds: int = 1,
    ):
        """Initialize a Hadamard product layer.

        Args:
            num_input_units: The number of input units, which is also passed to the parent
                class as the number of output units.
            arity: The arity of the layer.
            semiring: The evaluation semiring.
                Defaults to [SumProductSemiring][cirkit.backend.torch.semiring.SumProductSemiring].
            num_folds: The number of folds.

        Raises:
            ValueError: If the arity is not at least 2.
        """
        if arity < 2:
            raise ValueError("The arity should be at least 2")
        # An element-wise product keeps the width unchanged: outputs mirror inputs.
        num_output_units = num_input_units
        super().__init__(
            num_input_units,
            num_output_units,
            arity=arity,
            semiring=semiring,
            num_folds=num_folds,
        )

    @property
    def config(self) -> Mapping[str, Any]:
        """The hyperparameters of the layer: number of input units and arity."""
        settings: dict[str, Any] = {"num_input_units": self.num_input_units}
        settings["arity"] = self.arity
        return settings

    def forward(self, x: Tensor) -> Tensor:
        """Reduce the arity axis of the input with the product of the semiring.

        Args:
            x: The input tensor; per the shape note below, (F, H, B, K).

        Returns:
            Tensor: The result of the semiring product over dimension 1, i.e. (F, B, K).
        """
        # shape (F, H, B, K) -> (F, B, K).
        product = self.semiring.prod(x, dim=1, keepdim=False)
        return product

    def sample(self, x: Tensor) -> tuple[Tensor, None]:
        """Merge the samples coming from the inputs by summing over the arity axis.

        Args:
            x: The input samples; per the shape note below, (F, H, C, K, num_samples, D).

        Returns:
            tuple[Tensor, None]: The summed samples, and None as second element.
        """
        # Samples over disjoint variables are concatenated by means of a sum.
        # x: (F, H, C, K, num_samples, D) -> (F, C, K, num_samples, D)
        return x.sum(dim=1), None

config property ¤

__init__(num_input_units, arity=2, *, semiring=None, num_folds=1) ¤

Initialize a Hadamard product layer.

Parameters:

Name Type Description Default
num_input_units int

The number of input units, which is equal to the number of output units.

required
arity int

The arity of the layer.

2
semiring Semiring | None

The evaluation semiring. Defaults to SumProductSemiring.

None
num_folds int

The number of folds.

1

Raises:

Type Description
ValueError

If the arity is not at least 2.

ValueError

If the number of input units is not the same as the number of output units.

Source code in cirkit/backend/torch/layers/inner.py
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
def __init__(
    self,
    num_input_units: int,
    arity: int = 2,
    *,
    semiring: Semiring | None = None,
    num_folds: int = 1,
):
    """Initialize a Hadamard product layer.

    Args:
        num_input_units: The number of input units, which is also passed to the parent
            class as the number of output units.
        arity: The arity of the layer.
        semiring: The evaluation semiring.
            Defaults to [SumProductSemiring][cirkit.backend.torch.semiring.SumProductSemiring].
        num_folds: The number of folds.

    Raises:
        ValueError: If the arity is not at least 2.
    """
    if arity < 2:
        raise ValueError("The arity should be at least 2")
    # An element-wise product keeps the width unchanged: outputs mirror inputs.
    num_output_units = num_input_units
    super().__init__(
        num_input_units,
        num_output_units,
        arity=arity,
        semiring=semiring,
        num_folds=num_folds,
    )

forward(x) ¤

Source code in cirkit/backend/torch/layers/inner.py
123
124
def forward(self, x: Tensor) -> Tensor:
    """Reduce the arity axis (dimension 1) of the input with the semiring product."""
    # shape (F, H, B, K) -> (F, B, K).
    product = self.semiring.prod(x, dim=1, keepdim=False)
    return product

sample(x) ¤

Source code in cirkit/backend/torch/layers/inner.py
126
127
128
129
130
def sample(self, x: Tensor) -> tuple[Tensor, None]:
    """Merge the samples coming from the inputs by summing over the arity axis.

    Returns the summed tensor together with None as second element.
    """
    # Samples over disjoint variables are concatenated by means of a sum.
    # x: (F, H, C, K, num_samples, D) -> (F, C, K, num_samples, D)
    return x.sum(dim=1), None

TorchInnerLayer ¤

Bases: TorchLayer, ABC

The abstract base class for inner layers, i.e., either sum or product layers.

Source code in cirkit/backend/torch/layers/inner.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
class TorchInnerLayer(TorchLayer, ABC):
    """Abstract base class shared by the inner layers, i.e., sum and product layers."""

    def __init__(
        self,
        num_input_units: int,
        num_output_units: int,
        arity: int = 2,
        *,
        semiring: Semiring | None = None,
        num_folds: int = 1,
    ):
        """Initialize an inner layer by forwarding every argument to the parent class.

        Args:
            num_input_units: The number of input units.
            num_output_units: The number of output units.
            arity: The arity of the layer.
            semiring: The evaluation semiring.
                Defaults to [SumProductSemiring][cirkit.backend.torch.semiring.SumProductSemiring].
            num_folds: The number of folds.
        """
        super().__init__(
            num_input_units,
            num_output_units,
            arity=arity,
            semiring=semiring,
            num_folds=num_folds,
        )

    @property
    def fold_settings(self) -> tuple[Any, ...]:
        """The configuration items of the layer followed by (name, shape) of each parameter."""
        param_shapes = tuple((name, param.shape) for name, param in self.params.items())
        return (*self.config.items(), *param_shapes)

    def __call__(self, x: Tensor) -> Tensor:
        """Call the layer on the input by delegating to the parent ``__call__``."""
        # IGNORE: Idiom for nn.Module.__call__.
        return super().__call__(x)  # type: ignore[no-any-return,misc]

    @abstractmethod
    def forward(self, x: Tensor) -> Tensor:
        """Compute the output of the layer.

        Args:
            x: The tensor input to this layer, having shape $(F, H, B, K_i)$, where $F$
                is the number of folds, $H$ is the arity, $B$ is the batch size, and
                $K_i$ is the number of input units.

        Returns:
            Tensor: The tensor output of this layer, having shape $(F, B, K_o)$, where $K_o$
                is the number of output units.
        """

    def sample(self, x: Tensor) -> tuple[Tensor, Tensor | None]:
        """Perform a forward sampling step.

        This base implementation always raises: layers that support sampling override it.

        Args:
            x: A tensor representing the input variable assignments, having shape
                $(F, H, C, K, N, D)$, where $F$ is the number of folds, $H$ is the arity,
                $C$ is the number of channels, $K$ is the number of input units, $N$ is the
                number of samples, $D$ is the number of variables.

        Returns:
            tuple[Tensor, Tensor | None]: A tensor representing the new variable assignments
                the layer gives as output, and an optional additional tensor.

        Raises:
            TypeError: If sampling is not supported by the layer.
        """
        message = f"Sampling not implemented for {type(self)}"
        raise TypeError(message)

fold_settings property ¤

__call__(x) ¤

Source code in cirkit/backend/torch/layers/inner.py
45
46
47
def __call__(self, x: Tensor) -> Tensor:
    """Call the layer on the input by delegating to the parent ``__call__``.

    The override only narrows the signature to a single tensor input and a tensor output.
    """
    # IGNORE: Idiom for nn.Module.__call__.
    return super().__call__(x)  # type: ignore[no-any-return,misc]

__init__(num_input_units, num_output_units, arity=2, *, semiring=None, num_folds=1) ¤

Initialize an inner layer.

Parameters:

Name Type Description Default
num_input_units int

The number of input units.

required
num_output_units int

The number of output units.

required
arity int

The arity of the layer.

2
semiring Semiring | None

The evaluation semiring. Defaults to SumProductSemiring.

None
num_folds int

The number of folds.

1
Source code in cirkit/backend/torch/layers/inner.py
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
def __init__(
    self,
    num_input_units: int,
    num_output_units: int,
    arity: int = 2,
    *,
    semiring: Semiring | None = None,
    num_folds: int = 1,
):
    """Initialize an inner layer by forwarding every argument to the parent class.

    Args:
        num_input_units: The number of input units.
        num_output_units: The number of output units.
        arity: The arity of the layer.
        semiring: The evaluation semiring.
            Defaults to [SumProductSemiring][cirkit.backend.torch.semiring.SumProductSemiring].
        num_folds: The number of folds.
    """
    super().__init__(
        num_input_units,
        num_output_units,
        arity=arity,
        semiring=semiring,
        num_folds=num_folds,
    )

forward(x) abstractmethod ¤

Invoke the forward function.

Parameters:

Name Type Description Default
x Tensor

The tensor input to this layer, having shape \((F, H, B, K_i)\), where \(F\) is the number of folds, \(H\) is the arity, \(B\) is the batch size, and \(K_i\) is the number of input units.

required

Returns:

Name Type Description
Tensor Tensor

The tensor output of this layer, having shape \((F, B, K_o)\), where \(K_o\) is the number of output units.

Source code in cirkit/backend/torch/layers/inner.py
49
50
51
52
53
54
55
56
57
58
59
60
61
@abstractmethod
def forward(self, x: Tensor) -> Tensor:
    """Compute the output of the layer; concrete inner layers must implement this.

    Args:
        x: The tensor input to this layer, having shape $(F, H, B, K_i)$, where $F$
            is the number of folds, $H$ is the arity, $B$ is the batch size, and
            $K_i$ is the number of input units.

    Returns:
        Tensor: The tensor output of this layer, having shape $(F, B, K_o)$, where $K_o$
            is the number of output units.
    """

sample(x) ¤

Perform a forward sampling step.

Parameters:

Name Type Description Default
x Tensor

A tensor representing the input variable assignments, having shape \((F, H, C, K, N, D)\), where \(F\) is the number of folds, \(H\) is the arity, \(C\) is the number of channels, \(K\) is the number of input units, \(N\) is the number of samples, and \(D\) is the number of variables.

required

Returns:

Name Type Description
Tensor tuple[Tensor, Tensor | None]

A new tensor representing the new variable assignments the layer gives as output.

Raises:

Type Description
TypeError

If sampling is not supported by the layer.

Source code in cirkit/backend/torch/layers/inner.py
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
def sample(self, x: Tensor) -> tuple[Tensor, Tensor | None]:
    """Perform a forward sampling step.

    Args:
        x: A tensor representing the input variable assignments, having shape
            $(F, H, C, K, N, D)$, where $F$ is the number of folds, $H$ is the arity,
            $C$ is the number of channels, $K$ is the numbe rof input units, $N$ is the number
            of samples, $D$ is the number of variables.

    Returns:
        Tensor: A new tensor representing the new variable assignements the layers gives
            as output.

    Raises:
        TypeError: If sampling is not supported by the layer.
    """
    raise TypeError(f"Sampling not implemented for {type(self)}")

TorchKroneckerLayer ¤

Bases: TorchInnerLayer

The Kronecker product layer, which computes the Kronecker product of the input vectors it receives as input. See the symbolic KroneckerLayer for more details.

Source code in cirkit/backend/torch/layers/inner.py
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
class TorchKroneckerLayer(TorchInnerLayer):
    """Torch implementation of the Kronecker product layer.

    The layer computes the Kronecker product of the vectors it receives as input.
    See the symbolic [KroneckerLayer][cirkit.symbolic.layers.KroneckerLayer] for more details.
    """

    def __init__(
        self,
        num_input_units: int,
        arity: int = 2,
        *,
        semiring: Semiring | None = None,
        num_folds: int = 1,
    ):
        """Initialize a Kronecker product layer.

        Args:
            num_input_units: The number of input units. The number of output units is the
                number of input units raised to the power of the arity.
            arity: The arity of the layer. Defaults to 2 (which is the only supported arity).
            semiring: The evaluation semiring.
                Defaults to [SumProductSemiring][cirkit.backend.torch.semiring.SumProductSemiring].
            num_folds: The number of folds.

        Raises:
            NotImplementedError: If the arity is not 2.
        """
        # TODO: generalize kronecker layer as to support a greater arity
        if arity != 2:
            raise NotImplementedError("Kronecker only implemented for binary product units.")
        num_output_units = num_input_units**arity
        super().__init__(
            num_input_units,
            num_output_units,
            arity=arity,
            semiring=semiring,
            num_folds=num_folds,
        )

    @property
    def config(self) -> Mapping[str, Any]:
        """The hyperparameters of the layer: number of input units and arity."""
        settings: dict[str, Any] = {"num_input_units": self.num_input_units}
        settings["arity"] = self.arity
        return settings

    def forward(self, x: Tensor) -> Tensor:
        """Combine every pair of units of the two inputs with the semiring multiplication.

        Args:
            x: The input tensor; per the shape notes below, (F, H, B, Ki) with H = 2.

        Returns:
            Tensor: The pairwise products with the last two axes flattened, (F, B, Ki**2).
        """
        left = x.select(1, 0).unsqueeze(-1)  # shape (F, B, Ki, 1).
        right = x.select(1, 1).unsqueeze(-2)  # shape (F, B, 1, Ki).
        pairwise = self.semiring.mul(left, right)  # shape (F, B, Ki, Ki).
        return pairwise.flatten(start_dim=-2)  # shape (F, B, Ko=Ki**2).

    def sample(self, x: Tensor) -> tuple[Tensor, Tensor | None]:
        """Pair up the samples of the two inputs by summing them for every pair of units.

        Args:
            x: The input samples; per the shape notes below, (F, H, C, K, num_samples, D).

        Returns:
            tuple[Tensor, Tensor | None]: The paired samples with the two unit axes
                flattened into one, and None as second element.
        """
        # x: (F, H, C, K, num_samples, D)
        first = torch.unsqueeze(x[:, 0], dim=3)  # (F, C, Ki, 1, num_samples, D)
        second = torch.unsqueeze(x[:, 1], dim=2)  # (F, C, 1, Ki, num_samples, D)
        # shape (F, C, Ki, Ki, num_samples, D) -> (F, C, Ko=Ki**2, num_samples, D)
        paired = first + second
        return paired.flatten(start_dim=2, end_dim=3), None

config property ¤

__init__(num_input_units, arity=2, *, semiring=None, num_folds=1) ¤

Initialize a Kronecker product layer.

Parameters:

Name Type Description Default
num_input_units int

The number of input units. The number of output units is the power of the number of input units to the arity.

required
arity int

The arity of the layer. Defaults to 2 (which is the only supported arity).

2
semiring Semiring | None

The evaluation semiring. Defaults to SumProductSemiring.

None
num_folds int

The number of folds.

1

Raises:

Type Description
NotImplementedError

If the arity is not 2.

ValueError

If the number of input units is not the same as the number of output units.

Source code in cirkit/backend/torch/layers/inner.py
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
def __init__(
    self,
    num_input_units: int,
    arity: int = 2,
    *,
    semiring: Semiring | None = None,
    num_folds: int = 1,
):
    """Initialize a Kronecker product layer.

    Args:
        num_input_units: The number of input units. The number of output units is the
            number of input units raised to the power of the arity.
        arity: The arity of the layer. Defaults to 2 (which is the only supported arity).
        semiring: The evaluation semiring.
            Defaults to [SumProductSemiring][cirkit.backend.torch.semiring.SumProductSemiring].
        num_folds: The number of folds.

    Raises:
        NotImplementedError: If the arity is not 2.
    """
    # TODO: generalize kronecker layer as to support a greater arity
    if arity != 2:
        raise NotImplementedError("Kronecker only implemented for binary product units.")
    num_output_units = num_input_units**arity
    super().__init__(
        num_input_units,
        num_output_units,
        arity=arity,
        semiring=semiring,
        num_folds=num_folds,
    )

forward(x) ¤

Source code in cirkit/backend/torch/layers/inner.py
179
180
181
182
183
def forward(self, x: Tensor) -> Tensor:
    """Combine every pair of units of the two inputs with the semiring multiplication."""
    left = x.select(1, 0).unsqueeze(-1)  # shape (F, B, Ki, 1).
    right = x.select(1, 1).unsqueeze(-2)  # shape (F, B, 1, Ki).
    pairwise = self.semiring.mul(left, right)  # shape (F, B, Ki, Ki).
    return pairwise.flatten(start_dim=-2)  # shape (F, B, Ko=Ki**2).

sample(x) ¤

Source code in cirkit/backend/torch/layers/inner.py
185
186
187
188
189
190
191
def sample(self, x: Tensor) -> tuple[Tensor, Tensor | None]:
    # x: (F, H, C, K, num_samples, D)
    x0 = x[:, 0].unsqueeze(dim=3)  # (F, C, Ki, 1, num_samples, D)
    x1 = x[:, 1].unsqueeze(dim=2)  # (F, C, 1, Ki, num_samples, D)
    # shape (F, C, Ki, Ki, num_samples, D) -> (F, C, Ko=Ki**2, num_samples, D)
    x = x0 + x1
    return torch.flatten(x, start_dim=2, end_dim=3), None

TorchSumLayer ¤

Bases: TorchInnerLayer

The sum layer torch implementation. See the symbolic SumLayer for more details.

Source code in cirkit/backend/torch/layers/inner.py
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
class TorchSumLayer(TorchInnerLayer):
    """The sum layer torch implementation.
    See the symbolic [SumLayer][cirkit.symbolic.layers.SumLayer] for more details."""

    def __init__(
        self,
        num_input_units: int,
        num_output_units: int,
        arity: int = 1,
        *,
        weight: TorchParameter,
        semiring: Semiring | None = None,
        num_folds: int = 1,
    ):
        r"""Initialize a sum layer.

        Args:
            num_input_units: The number of input units.
            num_output_units: The number of output units.
            arity: The arity of the layer.
            weight: The weight parameter, which must have shape $(F, K_o, K_i\cdot H)$,
                where $F$ is the number of folds, $K_o$ is the number of output units,
                $K_i$ is the number of input units, and $H$ is the arity.
            semiring: The evaluation semiring.
                Defaults to [SumProductSemiring][cirkit.backend.torch.semiring.SumProductSemiring].
            num_folds: The number of folds.

        Raises:
            ValueError: If the arity is not a positive integer.
            ValueError: If the arity, the number of input and output units are incompatible with the
                shape of the weight parameter.
        """
        if arity < 1:
            raise ValueError("The arity must be a positive integer")
        super().__init__(
            num_input_units, num_output_units, arity=arity, semiring=semiring, num_folds=num_folds
        )
        if not self._valid_weight_shape(weight):
            # Keep a trailing space after "found" so that it does not run into the fold count.
            raise ValueError(
                f"Expected number of folds {self.num_folds} "
                f"and shape {self._weight_shape} for 'weight', found "
                f"{weight.num_folds} and {weight.shape}, respectively"
            )
        self.weight = weight

    def _valid_weight_shape(self, w: TorchParameter) -> bool:
        """Check that a parameter has the number of folds and the shape this layer expects."""
        if w.num_folds != self.num_folds:
            return False
        return w.shape == self._weight_shape

    @property
    def _weight_shape(self) -> tuple[int, ...]:
        """The expected weight shape without the fold dimension, i.e. (Ko, Ki * H)."""
        return self.num_output_units, self.num_input_units * self.arity

    @property
    def config(self) -> Mapping[str, Any]:
        """The hyperparameters of the layer: input units, output units and arity."""
        return {
            "num_input_units": self.num_input_units,
            "num_output_units": self.num_output_units,
            "arity": self.arity,
        }

    @property
    def params(self) -> Mapping[str, TorchParameter]:
        """The parameters of the layer, i.e. the weight only."""
        return {"weight": self.weight}

    def forward(self, x: Tensor) -> Tensor:
        """Compute the weighted sums of the inputs through the einsum of the semiring.

        Args:
            x: The input tensor; per the shape notes below, (F, H, B, Ki).

        Returns:
            Tensor: The output of the semiring einsum, annotated as (F, B, Ko).
        """
        # x: (F, H, B, Ki) -> (F, B, H * Ki)
        # weight: (F, Ko, H * Ki)
        x = x.permute(0, 2, 1, 3).flatten(start_dim=2)
        weight = self.weight()
        return self.semiring.einsum(
            "fbi,foi->fbo", inputs=(x,), operands=(weight,), dim=-1, keepdim=True
        )  # shape (F, B, K_o).

    def sample(self, x: Tensor) -> tuple[Tensor, Tensor]:
        """Sample, for every output unit, which input sample to propagate.

        Each row of the weight is used as a categorical distribution over the
        H * Ki inputs; the drawn indices select the input samples to return.

        Args:
            x: The input samples; per the shape notes below, (F, H, C, Ki, num_samples, D).

        Returns:
            tuple[Tensor, Tensor]: The selected samples, (F, C, Ko, num_samples, D), and the
                drawn mixture indices, (F, Ko, num_samples).

        Raises:
            TypeError: If the weights are negative or do not sum to 1 along the last axis.
        """
        weight = self.weight()
        negative = bool(torch.any(weight < 0.0))
        # Compare against ones of the SAME dtype/device as the weights: torch.allclose
        # rejects operands of different dtypes, so a default-dtype constant would make
        # sampling crash for e.g. float64 weights.
        totals = torch.sum(weight, dim=-1)
        normalized = torch.allclose(totals, torch.ones_like(totals))
        if negative or not normalized:
            raise TypeError("Sampling in sum layers only works with positive weights summing to 1")

        # x: (F, H, C, Ki, num_samples, D) -> (F, C, H * Ki, num_samples, D)
        x = x.permute(0, 2, 1, 3, 4, 5).flatten(2, 3)
        c = x.shape[1]
        num_samples = x.shape[3]
        d = x.shape[4]

        # mixing_distribution: (F, Ko, H * Ki)
        mixing_distribution = torch.distributions.Categorical(probs=weight)

        # mixing_samples: (num_samples, F, Ko) -> (F, Ko, num_samples)
        mixing_samples = mixing_distribution.sample((num_samples,)).permute(1, 2, 0)

        # mixing_indices: (F, C, Ko, num_samples, D), built as a broadcast view so that
        # the indices are not physically replicated across channels and variables.
        mixing_indices = mixing_samples[:, None, :, :, None].expand(-1, c, -1, -1, d)

        # x: (F, C, Ko, num_samples, D)
        x = torch.gather(x, dim=2, index=mixing_indices)
        return x, mixing_samples

_weight_shape property ¤

config property ¤

params property ¤

weight = weight instance-attribute ¤

__init__(num_input_units, num_output_units, arity=1, *, weight, semiring=None, num_folds=1) ¤

Initialize a sum layer.

Parameters:

Name Type Description Default
num_input_units int

The number of input units.

required
num_output_units int

The number of output units.

required
arity int

The arity of the layer.

1
weight TorchParameter

The weight parameter, which must have shape \((F, K_o, K_i\cdot H)\), where \(F\) is the number of folds, \(K_o\) is the number of output units, \(K_i\) is the number of input units, and \(H\) is the arity.

required
semiring Semiring | None

The evaluation semiring. Defaults to SumProductSemiring.

None
num_folds int

The number of folds.

1

Raises:

Type Description
ValueError

If the arity is not a positive integer.

ValueError

If the arity, the number of input and output units are incompatible with the shape of the weight parameter.

Source code in cirkit/backend/torch/layers/inner.py
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
def __init__(
    self,
    num_input_units: int,
    num_output_units: int,
    arity: int = 1,
    *,
    weight: TorchParameter,
    semiring: Semiring | None = None,
    num_folds: int = 1,
):
    r"""Initialize a sum layer.

    Args:
        num_input_units: The number of input units.
        num_output_units: The number of output units.
        arity: The arity of the layer.
        weight: The weight parameter, which must have shape $(F, K_o, K_i\cdot H)$,
            where $F$ is the number of folds, $K_o$ is the number of output units,
            $K_i$ is the number of input units, and $H$ is the arity.
        semiring: The evaluation semiring.
            Defaults to [SumProductSemiring][cirkit.backend.torch.semiring.SumProductSemiring].
        num_folds: The number of folds.

    Raises:
        ValueError: If the arity is not a positive integer.
        ValueError: If the arity, the number of input and output units are incompatible with the
            shape of the weight parameter.
    """
    if arity < 1:
        raise ValueError("The arity must be a positive integer")
    super().__init__(
        num_input_units, num_output_units, arity=arity, semiring=semiring, num_folds=num_folds
    )
    if not self._valid_weight_shape(weight):
        # Keep a trailing space after "found" so that it does not run into the fold count.
        raise ValueError(
            f"Expected number of folds {self.num_folds} "
            f"and shape {self._weight_shape} for 'weight', found "
            f"{weight.num_folds} and {weight.shape}, respectively"
        )
    self.weight = weight

_valid_weight_shape(w) ¤

Source code in cirkit/backend/torch/layers/inner.py
239
240
241
242
def _valid_weight_shape(self, w: TorchParameter) -> bool:
    """Check that a parameter has the number of folds and the shape this layer expects."""
    same_folds = w.num_folds == self.num_folds
    # The shape is only inspected once the number of folds is known to match.
    return w.shape == self._weight_shape if same_folds else False

forward(x) ¤

Source code in cirkit/backend/torch/layers/inner.py
260
261
262
263
264
265
266
267
def forward(self, x: Tensor) -> Tensor:
    """Compute the weighted sums of the inputs through the einsum of the semiring."""
    # Move the arity next to the units and merge them: (F, H, B, Ki) -> (F, B, H * Ki)
    flat_inputs = x.permute(0, 2, 1, 3).flatten(start_dim=2)
    # weight: (F, Ko, H * Ki)
    weight = self.weight()
    output = self.semiring.einsum(
        "fbi,foi->fbo",
        inputs=(flat_inputs,),
        operands=(weight,),
        dim=-1,
        keepdim=True,
    )
    return output  # shape (F, B, K_o).

sample(x) ¤

Source code in cirkit/backend/torch/layers/inner.py
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
def sample(self, x: Tensor) -> tuple[Tensor, Tensor]:
    """Sample, for every output unit, which input sample to propagate.

    Each row of the weight is used as a categorical distribution over the
    H * Ki inputs; the drawn indices select the input samples to return.

    Args:
        x: The input samples; per the shape notes below, (F, H, C, Ki, num_samples, D).

    Returns:
        tuple[Tensor, Tensor]: The selected samples, (F, C, Ko, num_samples, D), and the
            drawn mixture indices, (F, Ko, num_samples).

    Raises:
        TypeError: If the weights are negative or do not sum to 1 along the last axis.
    """
    weight = self.weight()
    negative = bool(torch.any(weight < 0.0))
    # Compare against ones of the SAME dtype/device as the weights: torch.allclose
    # rejects operands of different dtypes, so a default-dtype constant would make
    # sampling crash for e.g. float64 weights.
    totals = torch.sum(weight, dim=-1)
    normalized = torch.allclose(totals, torch.ones_like(totals))
    if negative or not normalized:
        raise TypeError("Sampling in sum layers only works with positive weights summing to 1")

    # x: (F, H, C, Ki, num_samples, D) -> (F, C, H * Ki, num_samples, D)
    x = x.permute(0, 2, 1, 3, 4, 5).flatten(2, 3)
    c = x.shape[1]
    num_samples = x.shape[3]
    d = x.shape[4]

    # mixing_distribution: (F, Ko, H * Ki)
    mixing_distribution = torch.distributions.Categorical(probs=weight)

    # mixing_samples: (num_samples, F, Ko) -> (F, Ko, num_samples)
    mixing_samples = mixing_distribution.sample((num_samples,)).permute(1, 2, 0)

    # mixing_indices: (F, C, Ko, num_samples, D), built as a broadcast view so that
    # the indices are not physically replicated across channels and variables.
    mixing_indices = mixing_samples[:, None, :, :, None].expand(-1, c, -1, -1, d)

    # x: (F, C, Ko, num_samples, D)
    x = torch.gather(x, dim=2, index=mixing_indices)
    return x, mixing_samples