Skip to content

queries

queries ¤

IntegrateQuery ¤

Bases: Query

The integration query object allows marginalising out variables.

Computes output in two forward passes

a) the normal circuit forward pass for input x; b) the integration forward pass, where all variables are marginalised.

A mask over random variables is computed based on the scopes passed as input. This determines whether the integrated or normal circuit result is returned for each variable.

Source code in cirkit/backend/torch/queries.py
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
class IntegrateQuery(Query):
    """The integration query object allows marginalising out variables.

    Computes output in two forward passes:
        a) The normal circuit forward pass for input x
        b) The integration forward pass where all variables are marginalised

    A mask over random variables is computed based on the scopes passed as
    input. This determines whether the integrated or normal circuit result
    is returned for each variable.
    """

    def __init__(self, circuit: TorchCircuit) -> None:
        """Initialize an integration query object.

        Args:
            circuit: The circuit to integrate over.

        Raises:
            ValueError: If the circuit to integrate is not smooth or not decomposable.
        """
        if not circuit.properties.smooth or not circuit.properties.decomposable:
            raise ValueError(
                f"The circuit to integrate must be smooth and decomposable, "
                f"but found {circuit.properties}"
            )
        super().__init__()
        self._circuit = circuit

    def __call__(self, x: Tensor, *, integrate_vars: Tensor | Scope | Iterable[Scope]) -> Tensor:
        """Solve an integration query, given an input batch and the variables to integrate.

        Args:
            x: An input batch of shape $(B, C, D)$, where $B$ is the batch size, $C$ is the number
                of channels per variable, and $D$ is the number of variables.
            integrate_vars: The variables to integrate. They must be a subset of the variables on
                which the circuit given in the constructor is defined.
                The format can be one of the following three:
                    1. Tensor of shape (B, D) (or (D,), treated as B = 1), where B is the batch
                        size and D = max(circuit scope) + 1. Its dtype must be torch.bool,
                        with True in the positions of random variables that should be
                        marginalised out and False elsewhere.
                    2. Scope, in this case the same integration mask is applied for all entries
                        of the batch.
                    3. Iterable of Scopes, whose length must be either 1 or B. If it has
                        length 1, behaves as above.

        Returns:
            The result of the integration query, given as a tensor of shape $(B, O, K)$,
                where $B$ is the batch size, $O$ is the number of output vectors of the circuit, and
                $K$ is the number of units in each output vector.

        Raises:
            ValueError: If a tensor mask is not boolean, is not of rank 1 or 2, or does not
                span the circuit's variables; if a scope contains variables outside the
                circuit scope; or if the number of masks is neither 1 nor the batch size of x.
        """
        if isinstance(integrate_vars, Tensor):
            if integrate_vars.dtype != torch.bool:
                raise ValueError(
                    f"Expected dtype of tensor to be torch.bool, got {integrate_vars.dtype}"
                )
            # A single mask vector is treated as a batch of size 1 (broadcast over x)
            if integrate_vars.ndim == 1:
                integrate_vars = integrate_vars.unsqueeze(0)
            # Reject other ranks up-front instead of failing obscurely further down
            if integrate_vars.ndim != 2:
                raise ValueError(
                    "Expected integrate_vars to be a tensor of shape (B, D) or (D,), "
                    f"but found shape {tuple(integrate_vars.shape)}"
                )
            # Variables are addressed by id, so the mask must span ids 0..max(scope)
            num_vars = max(self._circuit.scope) + 1
            if integrate_vars.shape[1] != num_vars:
                raise ValueError(
                    f"Circuit scope has {num_vars} variables but integrate_vars "
                    f"was defined over {integrate_vars.shape[1]} != {num_vars} variables"
                )
            integrate_vars_mask = integrate_vars
        else:
            # Convert the scope(s) to a boolean mask of shape (B, N), N = max(scope) + 1
            integrate_vars_mask = IntegrateQuery.scopes_to_mask(self._circuit, integrate_vars)
        # Keep the mask on the same device as the input: a user-provided mask may live
        # elsewhere, and scopes_to_mask allocates on the default device
        integrate_vars_mask = integrate_vars_mask.to(x.device)

        # Check batch sizes of input x and mask are compatible
        if integrate_vars_mask.shape[0] not in (1, x.shape[0]):
            raise ValueError(
                "The number of scopes to integrate over must "
                "either match the batch size of x, or be 1 if you "
                "want to broadcast. Found #inputs = "
                f"{x.shape[0]} != {integrate_vars_mask.shape[0]} = len(integrate_vars)"
            )

        output = self._circuit.evaluate(
            x,
            module_fn=functools.partial(
                IntegrateQuery._layer_fn, integrate_vars_mask=integrate_vars_mask
            ),
        )  # (O, B, K)
        return output.transpose(0, 1)  # (B, O, K)

    @staticmethod
    def _layer_fn(layer: TorchLayer, x: Tensor, *, integrate_vars_mask: Tensor) -> Tensor:
        """Evaluate one layer, replacing masked input-layer outputs by their integrals.

        Non-input layers are evaluated in the usual feed-forward way. For univariate
        input layers, the output of each fold whose variable is marked in
        ``integrate_vars_mask`` is replaced, per batch entry, by ``layer.integrate()``.

        Args:
            layer: The layer to evaluate.
            x: The input to the layer.
            integrate_vars_mask: Boolean mask of shape (B, N), where B is either 1 or the
                batch size and N = max(circuit scope) + 1.

        Returns:
            The (possibly partially integrated) layer output.

        Raises:
            NotImplementedError: If an input layer depends on more than one variable.
        """
        output = layer(x)  # (F, B, Ko)
        if not isinstance(layer, TorchInputLayer):
            return output
        if layer.num_variables > 1:
            raise NotImplementedError("Integration of multivariate input layers is not supported")
        # layer.scope_idx holds, per fold, the ids of the variables the fold depends on;
        # assumed (F, num_variables), i.e. (F, 1) here since multivariate layers were
        # rejected above. Indexing the variable axis of the (B, N) mask with it gathers
        # each fold's mask entry for every batch element: (B, F, 1).
        integration_mask = integrate_vars_mask[:, layer.scope_idx]
        # (F, B, 1): broadcastable against the (F, B, Ko) layer output
        integration_mask = integration_mask.permute(1, 0, 2)

        # Nothing to integrate in this layer: skip computing the integral
        if not torch.any(integration_mask).item():
            return output

        integration_output = layer.integrate()
        # Select, per fold and batch entry, the integrated or the regular output.
        # This works in parallel for all folds, whether the circuit is folded or not.
        return torch.where(integration_mask, integration_output, output)

    @staticmethod
    def scopes_to_mask(
        circuit: TorchCircuit, batch_integrate_vars: Scope | Iterable[Scope]
    ) -> Tensor:
        """Convert a batch of scopes into a boolean integration mask.

        Args:
            circuit: The circuit whose scope bounds the valid variable ids.
            batch_integrate_vars: A single scope (treated as a batch of size 1) or an
                iterable of scopes, one per batch entry. One-shot iterators such as
                generators are accepted.

        Returns:
            A boolean tensor of shape (B, N), where B is the number of scopes and
            N = max(circuit.scope) + 1, that is True at (i, v) iff variable v belongs
            to the i-th scope.

        Raises:
            ValueError: If some scope contains variables outside the circuit scope.
        """
        # If we passed a single scope, assume B = 1
        if isinstance(batch_integrate_vars, Scope):
            batch_integrate_vars = [batch_integrate_vars]
        # Materialise exactly once: the argument may be a one-shot iterator, and it is
        # needed both for the batch size and for the indices below
        scopes = tuple(batch_integrate_vars)

        # The circuit scope may not be contiguous (e.g. after marginalising out X_1 the
        # remaining ids are not shifted), so size the mask by the largest id
        num_rvs = max(circuit.scope) + 1

        # TODO: Maybe consider using a sparse tensor
        mask = torch.zeros((len(scopes), num_rvs), dtype=torch.bool)

        pairs = [(i, idx) for i, scope in enumerate(scopes) for idx in scope]
        # Only empty scopes: zip(*pairs) below would have nothing to unpack
        if not pairs:
            return mask
        batch_idxs, rv_idxs = zip(*pairs)

        # Check that we have not asked to marginalise variables that are not defined
        invalid_idxs = Scope(rv_idxs) - circuit.scope
        if invalid_idxs:
            raise ValueError(
                "The variables to marginalize must be a subset of "
                "the circuit scope. Invalid variables "
                f"not in scope: {list(invalid_idxs)} "
            )

        mask[batch_idxs, rv_idxs] = True
        return mask

_circuit = circuit instance-attribute ¤

__call__(x, *, integrate_vars) ¤

Solve an integration query, given an input batch and the variables to integrate.

Parameters:

Name Type Description Default
x Tensor

An input batch of shape \((B, C, D)\), where \(B\) is the batch size, \(C\) is the number of channels per variable, and \(D\) is the number of variables.

required
integrate_vars Tensor | Scope | Iterable[Scope]

The variables to integrate. They must be a subset of the variables on which the circuit given in the constructor is defined. The format can be one of the following three: (1) a tensor of shape (B, D), where B is the batch size and D = max(circuit scope) + 1 (the number of variables when the scope is contiguous); its dtype must be torch.bool, with True in the positions of random variables that should be marginalised out and False elsewhere; (2) a Scope, in which case the same integration mask is applied to all entries of the batch; (3) a list of Scopes, whose length must be either 1 or B; if the list has length 1, it behaves as in (2).

required

Returns: The result of the integration query, given as a tensor of shape \((B, O, K)\), where \(B\) is the batch size, \(O\) is the number of output vectors of the circuit, and \(K\) is the number of units in each output vector.

Source code in cirkit/backend/torch/queries.py
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
def __call__(self, x: Tensor, *, integrate_vars: Tensor | Scope | Iterable[Scope]) -> Tensor:
    """Solve an integration query, given an input batch and the variables to integrate.

    Args:
        x: An input batch of shape $(B, C, D)$, where $B$ is the batch size, $C$ is the number
            of channels per variable, and $D$ is the number of variables.
        integrate_vars: The variables to integrate. They must be a subset of the variables on
            which the circuit given in the constructor is defined.
            The format can be one of the following three:
                1. Tensor of shape (B, D) (or (D,), treated as B = 1), where B is the batch
                    size and D = max(circuit scope) + 1. Its dtype must be torch.bool,
                    with True in the positions of random variables that should be
                    marginalised out and False elsewhere.
                2. Scope, in this case the same integration mask is applied for all entries
                    of the batch.
                3. Iterable of Scopes, whose length must be either 1 or B. If it has
                    length 1, behaves as above.

    Returns:
        The result of the integration query, given as a tensor of shape $(B, O, K)$,
            where $B$ is the batch size, $O$ is the number of output vectors of the circuit, and
            $K$ is the number of units in each output vector.

    Raises:
        ValueError: If a tensor mask is not boolean, is not of rank 1 or 2, or does not
            span the circuit's variables; if a scope contains variables outside the
            circuit scope; or if the number of masks is neither 1 nor the batch size of x.
    """
    if isinstance(integrate_vars, Tensor):
        if integrate_vars.dtype != torch.bool:
            raise ValueError(
                f"Expected dtype of tensor to be torch.bool, got {integrate_vars.dtype}"
            )
        # A single mask vector is treated as a batch of size 1 (broadcast over x)
        if integrate_vars.ndim == 1:
            integrate_vars = integrate_vars.unsqueeze(0)
        # Reject other ranks up-front instead of failing obscurely further down
        if integrate_vars.ndim != 2:
            raise ValueError(
                "Expected integrate_vars to be a tensor of shape (B, D) or (D,), "
                f"but found shape {tuple(integrate_vars.shape)}"
            )
        # Variables are addressed by id, so the mask must span ids 0..max(scope)
        num_vars = max(self._circuit.scope) + 1
        if integrate_vars.shape[1] != num_vars:
            raise ValueError(
                f"Circuit scope has {num_vars} variables but integrate_vars "
                f"was defined over {integrate_vars.shape[1]} != {num_vars} variables"
            )
        integrate_vars_mask = integrate_vars
    else:
        # Convert the scope(s) to a boolean mask of shape (B, N), N = max(scope) + 1
        integrate_vars_mask = IntegrateQuery.scopes_to_mask(self._circuit, integrate_vars)
    # Keep the mask on the same device as the input: a user-provided mask may live
    # elsewhere, and scopes_to_mask allocates on the default device
    integrate_vars_mask = integrate_vars_mask.to(x.device)

    # Check batch sizes of input x and mask are compatible
    if integrate_vars_mask.shape[0] not in (1, x.shape[0]):
        raise ValueError(
            "The number of scopes to integrate over must "
            "either match the batch size of x, or be 1 if you "
            "want to broadcast. Found #inputs = "
            f"{x.shape[0]} != {integrate_vars_mask.shape[0]} = len(integrate_vars)"
        )

    output = self._circuit.evaluate(
        x,
        module_fn=functools.partial(
            IntegrateQuery._layer_fn, integrate_vars_mask=integrate_vars_mask
        ),
    )  # (O, B, K)
    return output.transpose(0, 1)  # (B, O, K)

__init__(circuit) ¤

Initialize an integration query object.

Parameters:

Name Type Description Default
circuit TorchCircuit

The circuit to integrate over.

required

Raises:

Type Description
ValueError

If the circuit to integrate is not smooth or not decomposable.

Source code in cirkit/backend/torch/queries.py
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
def __init__(self, circuit: TorchCircuit) -> None:
    """Bind an integration query to a circuit.

    Args:
        circuit: The circuit whose variables will be integrated over.

    Raises:
        ValueError: If the circuit is not both smooth and decomposable.
    """
    props = circuit.properties
    # Integration is only supported for smooth and decomposable circuits
    if not (props.smooth and props.decomposable):
        raise ValueError(
            f"The circuit to integrate must be smooth and decomposable, "
            f"but found {circuit.properties}"
        )
    super().__init__()
    self._circuit = circuit

_layer_fn(layer, x, *, integrate_vars_mask) staticmethod ¤

Source code in cirkit/backend/torch/queries.py
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
@staticmethod
def _layer_fn(layer: TorchLayer, x: Tensor, *, integrate_vars_mask: Tensor) -> Tensor:
    """Evaluate one layer, replacing masked input-layer outputs by their integrals.

    Non-input layers are evaluated in the usual feed-forward way. For univariate
    input layers, the output of each fold whose variable is marked in
    ``integrate_vars_mask`` is replaced, per batch entry, by ``layer.integrate()``.

    Args:
        layer: The layer to evaluate.
        x: The input to the layer.
        integrate_vars_mask: Boolean mask of shape (B, N), where B is either 1 or the
            batch size and N = max(circuit scope) + 1.

    Returns:
        The (possibly partially integrated) layer output.

    Raises:
        NotImplementedError: If an input layer depends on more than one variable.
    """
    output = layer(x)  # (F, B, Ko)
    if not isinstance(layer, TorchInputLayer):
        return output
    if layer.num_variables > 1:
        raise NotImplementedError("Integration of multivariate input layers is not supported")
    # layer.scope_idx holds, per fold, the ids of the variables the fold depends on;
    # assumed (F, num_variables), i.e. (F, 1) here since multivariate layers were
    # rejected above. Indexing the variable axis of the (B, N) mask with it gathers
    # each fold's mask entry for every batch element: (B, F, 1).
    integration_mask = integrate_vars_mask[:, layer.scope_idx]
    # (F, B, 1): broadcastable against the (F, B, Ko) layer output
    integration_mask = integration_mask.permute(1, 0, 2)

    # Nothing to integrate in this layer: skip computing the integral
    if not torch.any(integration_mask).item():
        return output

    integration_output = layer.integrate()
    # Select, per fold and batch entry, the integrated or the regular output.
    # This works in parallel for all folds, whether the circuit is folded or not.
    return torch.where(integration_mask, integration_output, output)

scopes_to_mask(circuit, batch_integrate_vars) staticmethod ¤

Accepts a batch of scopes and returns a boolean mask as a tensor with True in positions of specified scope indices and False otherwise.

Source code in cirkit/backend/torch/queries.py
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
@staticmethod
def scopes_to_mask(
    circuit: TorchCircuit, batch_integrate_vars: Scope | Iterable[Scope]
) -> Tensor:
    """Convert a batch of scopes into a boolean integration mask.

    Args:
        circuit: The circuit whose scope bounds the valid variable ids.
        batch_integrate_vars: A single scope (treated as a batch of size 1) or an
            iterable of scopes, one per batch entry. One-shot iterators such as
            generators are accepted.

    Returns:
        A boolean tensor of shape (B, N), where B is the number of scopes and
        N = max(circuit.scope) + 1, that is True at (i, v) iff variable v belongs
        to the i-th scope.

    Raises:
        ValueError: If some scope contains variables outside the circuit scope.
    """
    # If we passed a single scope, assume B = 1
    if isinstance(batch_integrate_vars, Scope):
        batch_integrate_vars = [batch_integrate_vars]
    # Materialise exactly once: the argument may be a one-shot iterator, and it is
    # needed both for the batch size and for the indices below
    scopes = tuple(batch_integrate_vars)

    # The circuit scope may not be contiguous (e.g. after marginalising out X_1 the
    # remaining ids are not shifted), so size the mask by the largest id
    num_rvs = max(circuit.scope) + 1

    # TODO: Maybe consider using a sparse tensor
    mask = torch.zeros((len(scopes), num_rvs), dtype=torch.bool)

    pairs = [(i, idx) for i, scope in enumerate(scopes) for idx in scope]
    # Only empty scopes: zip(*pairs) below would have nothing to unpack
    if not pairs:
        return mask
    batch_idxs, rv_idxs = zip(*pairs)

    # Check that we have not asked to marginalise variables that are not defined
    invalid_idxs = Scope(rv_idxs) - circuit.scope
    if invalid_idxs:
        raise ValueError(
            "The variables to marginalize must be a subset of "
            "the circuit scope. Invalid variables "
            f"not in scope: {list(invalid_idxs)} "
        )

    mask[batch_idxs, rv_idxs] = True
    return mask

Query ¤

Bases: ABC

An object used to run queries of circuits compiled using the torch backend.

Source code in cirkit/backend/torch/queries.py
13
14
15
16
17
class Query(ABC):
    """Base class for queries on circuits compiled with the torch backend.

    The concrete queries in this module (integration and sampling) take a compiled
    circuit in their constructor and are answered by calling the query object.
    """

    def __init__(self) -> None:
        """Initialise the query; the base class keeps no state of its own."""

__init__() ¤

Source code in cirkit/backend/torch/queries.py
16
17
def __init__(self) -> None:
    """Initialise the query; the base class keeps no state of its own."""

SamplingQuery ¤

Bases: Query

The sampling query object.

Source code in cirkit/backend/torch/queries.py
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
class SamplingQuery(Query):
    """The sampling query object, which draws samples from a circuit."""

    def __init__(self, circuit: TorchCircuit) -> None:
        """Initialize a sampling query object.

        Currently, only sampling from the joint distribution is supported, i.e., sampling
        won't work in the case of circuits obtained by marginalization, or by observing
        evidence. Conditional sampling is currently not implemented.

        Args:
            circuit: The circuit to sample from.

        Raises:
            ValueError: If the circuit to sample from is not smooth and decomposable.
        """
        if not circuit.properties.smooth or not circuit.properties.decomposable:
            raise ValueError(
                f"The circuit to sample from must be smooth and decomposable, "
                f"but found {circuit.properties}"
            )
        # TODO: add a check to verify the circuit is monotonic and normalized?
        super().__init__()
        self._circuit = circuit

    def __call__(self, num_samples: int = 1) -> tuple[Tensor, list[Tensor]]:
        """Sample a number of data points.

        Args:
            num_samples: The number of samples to return.

        Returns:
            A pair (samples, mixture_samples), consisting of (i) an assignment to the observed
            variables the circuit is defined on, and (ii) the samples of the finitely-discrete
            latent variables associated to the sum units. The samples (i) are returned as a
            tensor of shape (num_samples, num_channels, num_variables).

        Raises:
            ValueError: if the number of samples is not a positive number.
        """
        if num_samples <= 0:
            raise ValueError("The number of samples must be a positive number")

        mixture_samples: list[Tensor] = []
        # samples: (O, C, K, num_samples, D)
        samples = self._circuit.evaluate(
            module_fn=functools.partial(
                self._layer_fn,
                num_samples=num_samples,
                mixture_samples=mixture_samples,
            ),
        )
        # samples: (num_samples, O, K, C, D)
        samples = samples.permute(3, 0, 2, 1, 4)
        # TODO: fix for the case of multi-output circuits, i.e., O != 1 or K != 1
        samples = samples[:, 0, 0]  # (num_samples, C, D)
        return samples, mixture_samples

    def _layer_fn(
        self, layer: TorchLayer, *inputs: Tensor, num_samples: int, mixture_samples: list[Tensor]
    ) -> Tensor:
        """Sample from one layer during circuit evaluation.

        Input layers (called without inputs) draw ``num_samples`` fresh samples, which are
        padded onto the variable axis and also appended to ``mixture_samples``. Inner layers
        propagate the samples of their inputs and append their latent samples, if any.
        """
        # Sample from an input layer
        if not inputs:
            assert isinstance(layer, TorchInputLayer)
            samples = layer.sample(num_samples)
            samples = self._pad_samples(samples, layer.scope_idx)
            mixture_samples.append(samples)
            return samples

        # Sample through an inner layer
        assert isinstance(layer, TorchInnerLayer)
        samples, mix_samples = layer.sample(*inputs)
        if mix_samples is not None:
            mixture_samples.append(mix_samples)
        return samples

    def _pad_samples(self, samples: Tensor, scope_idx: Tensor) -> Tensor:
        """Pad univariate samples onto a trailing variable axis.

        Each fold's samples are written at the position of the variable that fold depends
        on, with zeros elsewhere, so that all layers produce tensors of a common shape for
        the downstream inner layers.

        Args:
            samples: Samples of shape (F, C, K, num_samples).
            scope_idx: Variable ids of shape (F, 1).

        Returns:
            A tensor of shape (F, C, K, num_samples, D), with D = max(circuit scope) + 1.

        Raises:
            NotImplementedError: If the samples are not univariate.
        """
        if scope_idx.shape[1] != 1:
            raise NotImplementedError("Padding is only implemented for univariate samples")

        # Variables are addressed by id, so size the variable axis by the largest id rather
        # than by the number of variables: the two differ when the scope is not
        # {0, ..., D-1}, and sizing by len(scope) makes the scatter below go out of bounds.
        num_vars = max(self._circuit.scope) + 1
        # padded_samples: (F, C, K, num_samples, D)
        padded_samples = torch.zeros(
            (*samples.shape, num_vars), device=samples.device, dtype=samples.dtype
        )
        fold_idx = torch.arange(samples.shape[0], device=samples.device)
        padded_samples[fold_idx, :, :, :, scope_idx.squeeze(dim=1)] = samples
        return padded_samples

_circuit = circuit instance-attribute ¤

__call__(num_samples=1) ¤

Sample a number of data points.

Parameters:

Name Type Description Default
num_samples int

The number of samples to return.

1
Returns:

A pair (samples, mixture_samples), consisting of (i) an assignment to the observed variables the circuit is defined on, and (ii) the samples of the finitely-discrete latent variables associated to the sum units. The samples (i) are returned as a tensor of shape (num_samples, num_channels, num_variables).

Raises:

Type Description
ValueError

if the number of samples is not a positive number.

Source code in cirkit/backend/torch/queries.py
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
def __call__(self, num_samples: int = 1) -> tuple[Tensor, list[Tensor]]:
    """Draw ``num_samples`` samples from the circuit.

    Args:
        num_samples: How many samples to draw; must be positive.

    Returns:
        A pair (samples, mixture_samples): (i) assignments to the observed variables the
        circuit is defined on, as a tensor of shape (num_samples, num_channels,
        num_variables); and (ii) the samples of the finitely-discrete latent variables
        associated to the sum units, as collected by ``_layer_fn``.

    Raises:
        ValueError: If the number of samples is not a positive number.
    """
    if num_samples <= 0:
        raise ValueError("The number of samples must be a positive number")

    collected: list[Tensor] = []
    layer_fn = functools.partial(
        self._layer_fn, num_samples=num_samples, mixture_samples=collected
    )
    # Circuit output: (O, C, K, num_samples, D)
    raw = self._circuit.evaluate(module_fn=layer_fn)
    # Bring the sample axis to the front, (num_samples, O, K, C, D), then keep the
    # first output and first unit.
    # TODO: fix for the case of multi-output circuits, i.e., O != 1 or K != 1
    first = raw.permute(3, 0, 2, 1, 4)[:, 0, 0]  # (num_samples, C, D)
    return first, collected

__init__(circuit) ¤

Initialize a sampling query object. Currently, only sampling from the joint distribution is supported, i.e., sampling won't work in the case of circuits obtained by marginalization, or by observing evidence. Conditional sampling is currently not implemented.

Parameters:

Name Type Description Default
circuit TorchCircuit

The circuit to sample from.

required

Raises:

Type Description
ValueError

If the circuit to sample from is not smooth and decomposable.

Source code in cirkit/backend/torch/queries.py
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
def __init__(self, circuit: TorchCircuit) -> None:
    """Initialize a sampling query object.

    Currently, only sampling from the joint distribution is supported, i.e., sampling
    won't work in the case of circuits obtained by marginalization, or by observing
    evidence. Conditional sampling is currently not implemented.

    Args:
        circuit: The circuit to sample from.

    Raises:
        ValueError: If the circuit to sample from is not smooth and decomposable.
    """
    if not circuit.properties.smooth or not circuit.properties.decomposable:
        raise ValueError(
            f"The circuit to sample from must be smooth and decomposable, "
            f"but found {circuit.properties}"
        )
    # TODO: add a check to verify the circuit is monotonic and normalized?
    super().__init__()
    self._circuit = circuit

_layer_fn(layer, *inputs, num_samples, mixture_samples) ¤

Source code in cirkit/backend/torch/queries.py
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
def _layer_fn(
    self, layer: TorchLayer, *inputs: Tensor, num_samples: int, mixture_samples: list[Tensor]
) -> Tensor:
    """Sample from one layer during circuit evaluation.

    Inner layers propagate the samples of their inputs and record their latent samples,
    if any, in ``mixture_samples``. Input layers (called without inputs) draw
    ``num_samples`` fresh samples, pad them onto the variable axis, and also record
    them in ``mixture_samples``.
    """
    if inputs:
        # Inner layer: sample through it from the samples of its inputs
        assert isinstance(layer, TorchInnerLayer)
        propagated, latent = layer.sample(*inputs)
        if latent is not None:
            mixture_samples.append(latent)
        return propagated

    # Input layer: draw fresh samples and pad them to the full variable axis
    assert isinstance(layer, TorchInputLayer)
    drawn = self._pad_samples(layer.sample(num_samples), layer.scope_idx)
    mixture_samples.append(drawn)
    return drawn

_pad_samples(samples, scope_idx) ¤

Pads univariate samples along a trailing variable dimension, writing each fold's samples at the position of the variable it depends on, so that all layers produce tensors of a common shape for the downstream inner layers.

Source code in cirkit/backend/torch/queries.py
265
266
267
268
269
270
271
272
273
274
275
276
277
278
def _pad_samples(self, samples: Tensor, scope_idx: Tensor) -> Tensor:
    """Pad univariate samples onto a trailing variable axis.

    Each fold's samples are written at the position of the variable that fold depends
    on, with zeros elsewhere, so that all layers produce tensors of a common shape for
    the downstream inner layers.

    Args:
        samples: Samples of shape (F, C, K, num_samples).
        scope_idx: Variable ids of shape (F, 1).

    Returns:
        A tensor of shape (F, C, K, num_samples, D), with D = max(circuit scope) + 1.

    Raises:
        NotImplementedError: If the samples are not univariate.
    """
    if scope_idx.shape[1] != 1:
        raise NotImplementedError("Padding is only implemented for univariate samples")

    # Variables are addressed by id, so size the variable axis by the largest id rather
    # than by the number of variables: the two differ when the scope is not
    # {0, ..., D-1}, and sizing by len(scope) makes the scatter below go out of bounds.
    num_vars = max(self._circuit.scope) + 1
    # padded_samples: (F, C, K, num_samples, D)
    padded_samples = torch.zeros(
        (*samples.shape, num_vars), device=samples.device, dtype=samples.dtype
    )
    fold_idx = torch.arange(samples.shape[0], device=samples.device)
    padded_samples[fold_idx, :, :, :, scope_idx.squeeze(dim=1)] = samples
    return padded_samples