Skip to content

compiler

compiler ¤

TorchCompiler ¤

Bases: AbstractCompiler

Source code in cirkit/backend/torch/compiler.py
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
class TorchCompiler(AbstractCompiler):
    """Compiler that turns symbolic circuits into torch circuits.

    Besides the layer, parameter and initializer compilation registries handed
    to the base class, an instance keeps the semiring chosen at construction
    time, a :class:`TorchCompilerState` and a registry of optimization rules
    keyed by ``"parameter"``, ``"layer_fuse"`` and ``"layer_shatter"``.
    """

    def __init__(self, semiring: str = "sum-product", fold: bool = False, optimize: bool = False):
        """Create the compiler with the default compilation and optimization rules.

        Args:
            semiring: Name of the semiring, resolved with ``SemiringImpl.from_name``.
            fold: Forwarded to the base compiler; presumably what
                ``is_fold_enabled`` reports back (confirm in ``AbstractCompiler``).
            optimize: Forwarded to the base compiler; presumably what
                ``is_optimize_enabled`` reports back (confirm in ``AbstractCompiler``).
        """
        layer_rules = CompilerLayerRegistry(DEFAULT_LAYER_COMPILATION_RULES)
        parameter_rules = CompilerParameterRegistry(DEFAULT_PARAMETER_COMPILATION_RULES)
        initializer_rules = CompilerInitializerRegistry(DEFAULT_INITIALIZER_COMPILATION_RULES)
        super().__init__(
            layer_rules, parameter_rules, initializer_rules, fold=fold, optimize=optimize
        )

        # Semiring used while compiling
        self._semiring: Semiring = SemiringImpl.from_name(semiring)

        # Book-keeping shared by the compilation rules
        self._state = TorchCompilerState()

        # Optimization rules, grouped by the kind of optimization they implement
        self._optimization_registry = dict(
            parameter=ParameterOptRegistry(DEFAULT_PARAMETER_OPT_RULES),
            layer_fuse=LayerOptRegistry(DEFAULT_LAYER_FUSE_OPT_RULES),
            layer_shatter=LayerOptRegistry(DEFAULT_LAYER_SHATTER_OPT_RULES),
        )

    def compile_pipeline(self, sc: Circuit) -> AbstractTorchCircuit:
        """Compile ``sc`` together with every not-yet-compiled circuit of its pipeline.

        Args:
            sc: The symbolic circuit whose compiled counterpart is wanted.

        Returns:
            The compiled circuit registered for ``sc``.
        """
        # Walk the pipeline in topological order, skipping what is already compiled
        for pipeline_circuit in pipeline_topological_ordering([sc]):
            if not self.is_compiled(pipeline_circuit):
                self._compile_circuit(pipeline_circuit)

        # The requested circuit is the output of the pipeline
        return self.get_compiled_circuit(sc)

    @property
    def semiring(self) -> Semiring:
        """The semiring selected when the compiler was constructed."""
        return self._semiring

    @property
    def is_fold_enabled(self) -> bool:
        """Value of the ``"fold"`` compiler flag."""
        return self._flags["fold"]

    @property
    def is_optimize_enabled(self) -> bool:
        """Value of the ``"optimize"`` compiler flag."""
        return self._flags["optimize"]

    @property
    def state(self) -> TorchCompilerState:
        """The compiler state object."""
        return self._state

    def compile_layer(self, layer: Layer) -> TorchLayer:
        """Compile a symbolic layer with the rule registered for its exact type."""
        rule = self.retrieve_layer_rule(type(layer))
        compiled = rule(self, layer)
        return cast(TorchLayer, compiled)

    def compile_parameter(self, parameter: Parameter) -> TorchParameter:
        """Compile a symbolic parameter graph node by node.

        Args:
            parameter: The symbolic parameter computational graph.

        Returns:
            A ``TorchParameter`` built from the compiled nodes, their inputs and
            the compiled counterpart of ``parameter.output``.
        """
        # Symbolic node -> compiled node
        sym_to_torch: dict[ParameterNode, TorchParameterNode] = {}

        # Compiled nodes in visiting order, and the compiled inputs of each one
        torch_nodes: list[TorchParameterNode] = []
        torch_inputs: dict[TorchParameterNode, list[TorchParameterNode]] = {}

        for sym_node in parameter.topological_ordering():
            torch_node = self._compile_parameter_node(sym_node)
            # The lookup relies on the inputs having been compiled earlier in the ordering
            torch_inputs[torch_node] = [
                sym_to_torch[sym_in] for sym_in in parameter.node_inputs(sym_node)
            ]
            sym_to_torch[sym_node] = torch_node
            torch_nodes.append(torch_node)

        # Assemble the compiled computational graph
        return TorchParameter(torch_nodes, torch_inputs, [sym_to_torch[parameter.output]])

    def compile_initializer(self, initializer: Initializer) -> Callable[[Tensor], Tensor]:
        """Compile a symbolic initializer with the rule registered for its exact type."""
        rule = self.retrieve_initializer_rule(type(initializer))
        compiled = rule(self, initializer)
        return cast(Callable[[Tensor], Tensor], compiled)

    def retrieve_optimization_registry(self, kind: str) -> CompilerRegistry:
        """Return the optimization registry stored under ``kind``.

        Raises:
            KeyError: If no registry is stored under ``kind``.
        """
        registry = self._optimization_registry[kind]
        return cast(CompilerRegistry, registry)

    def retrieve_optimization_rule(self, kind: str, pattern: GraphOptPattern) -> Callable:
        """Return the rule that the ``kind`` registry associates with ``pattern``."""
        return self.retrieve_optimization_registry(kind).retrieve_rule(pattern)

    def _compile_parameter_node(self, node: ParameterNode) -> TorchParameterNode:
        """Compile one symbolic parameter node with the rule registered for its exact type."""
        rule = self.retrieve_parameter_rule(type(node))
        compiled = rule(self, node)
        return cast(TorchParameterNode, compiled)

    def _compile_circuit(self, sc: Circuit) -> AbstractTorchCircuit:
        """Compile one symbolic circuit, post-process it and register the result.

        Args:
            sc: The symbolic circuit to compile.

        Returns:
            The compiled (and possibly optimized and/or folded) circuit.
        """
        # Symbolic layer -> compiled layer
        sym_to_torch: dict[Layer, TorchLayer] = {}

        # Compiled layer -> compiled layers feeding it
        torch_inputs: dict[TorchLayer, list[TorchLayer]] = {}

        for sym_layer in sc.topological_ordering():
            # One rule lookup per layer, whatever its type
            torch_layer = self.compile_layer(sym_layer)

            # The lookup relies on the inputs having been compiled earlier in the ordering
            torch_inputs[torch_layer] = [
                sym_to_torch[sym_in] for sym_in in sc.layer_inputs(sym_layer)
            ]
            sym_to_torch[sym_layer] = torch_layer

        # An empty scope yields a 'constant circuit', whose interface takes no inputs
        if sc.scope:
            circuit_cls = TorchCircuit
        else:
            circuit_cls = TorchConstantCircuit

        # Output layers, then the full layer list in compilation order
        out_layers = [sym_to_torch[sym_out] for sym_out in sc.outputs]
        all_layers = list(sym_to_torch.values())
        compiled = circuit_cls(
            sc.scope,
            sc.num_channels,
            layers=all_layers,
            in_layers=torch_inputs,
            outputs=out_layers,
            properties=sc.properties,
        )

        # Optional optimization and folding
        compiled = self._post_process_circuit(compiled)

        # Allocate & initialize the parameters
        compiled.reset_parameters()

        # Make the result retrievable, then tell the state this compilation is over
        self.register_compiled_circuit(sc, compiled)
        self._state.finish_compilation()
        return compiled

    def _post_process_circuit(self, cc: AbstractTorchCircuit) -> AbstractTorchCircuit:
        """Optimize and then fold ``cc``, each step only if its flag is enabled."""
        if self.is_optimize_enabled:
            # Rewrite the computational graph with the optimization rules
            cc = _optimize_circuit(self, cc, max_opt_steps=5)
        if self.is_fold_enabled:
            # Fold the (possibly optimized) circuit
            cc = _fold_circuit(self, cc)
        return cc

_optimization_registry = {'parameter': ParameterOptRegistry(DEFAULT_PARAMETER_OPT_RULES), 'layer_fuse': LayerOptRegistry(DEFAULT_LAYER_FUSE_OPT_RULES), 'layer_shatter': LayerOptRegistry(DEFAULT_LAYER_SHATTER_OPT_RULES)} instance-attribute ¤

_semiring = SemiringImpl.from_name(semiring) instance-attribute ¤

_state = TorchCompilerState() instance-attribute ¤

is_fold_enabled property ¤

is_optimize_enabled property ¤

semiring property ¤

state property ¤

__init__(semiring='sum-product', fold=False, optimize=False) ¤

Source code in cirkit/backend/torch/compiler.py
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
def __init__(self, semiring: str = "sum-product", fold: bool = False, optimize: bool = False):
    """Create the compiler with the default compilation and optimization rules.

    Args:
        semiring: Name of the semiring, resolved with ``SemiringImpl.from_name``.
        fold: Forwarded to the base class initializer as the ``fold`` keyword.
        optimize: Forwarded to the base class initializer as the ``optimize`` keyword.
    """
    layer_rules = CompilerLayerRegistry(DEFAULT_LAYER_COMPILATION_RULES)
    parameter_rules = CompilerParameterRegistry(DEFAULT_PARAMETER_COMPILATION_RULES)
    initializer_rules = CompilerInitializerRegistry(DEFAULT_INITIALIZER_COMPILATION_RULES)
    super().__init__(
        layer_rules, parameter_rules, initializer_rules, fold=fold, optimize=optimize
    )

    # Semiring used while compiling
    self._semiring: Semiring = SemiringImpl.from_name(semiring)

    # Book-keeping shared by the compilation rules
    self._state = TorchCompilerState()

    # Optimization rules, grouped by the kind of optimization they implement
    self._optimization_registry = dict(
        parameter=ParameterOptRegistry(DEFAULT_PARAMETER_OPT_RULES),
        layer_fuse=LayerOptRegistry(DEFAULT_LAYER_FUSE_OPT_RULES),
        layer_shatter=LayerOptRegistry(DEFAULT_LAYER_SHATTER_OPT_RULES),
    )

_compile_circuit(sc) ¤

Source code in cirkit/backend/torch/compiler.py
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
def _compile_circuit(self, sc: Circuit) -> AbstractTorchCircuit:
    """Compile one symbolic circuit, post-process it and register the result.

    Args:
        sc: The symbolic circuit to compile.

    Returns:
        The compiled (and possibly optimized and/or folded) circuit.
    """
    # Symbolic layer -> compiled layer
    sym_to_torch: dict[Layer, TorchLayer] = {}

    # Compiled layer -> compiled layers feeding it
    torch_inputs: dict[TorchLayer, list[TorchLayer]] = {}

    for sym_layer in sc.topological_ordering():
        # One rule lookup per layer, whatever its type
        torch_layer = self.compile_layer(sym_layer)

        # The lookup relies on the inputs having been compiled earlier in the ordering
        torch_inputs[torch_layer] = [
            sym_to_torch[sym_in] for sym_in in sc.layer_inputs(sym_layer)
        ]
        sym_to_torch[sym_layer] = torch_layer

    # An empty scope yields a 'constant circuit', whose interface takes no inputs
    if sc.scope:
        circuit_cls = TorchCircuit
    else:
        circuit_cls = TorchConstantCircuit

    # Output layers, then the full layer list in compilation order
    out_layers = [sym_to_torch[sym_out] for sym_out in sc.outputs]
    all_layers = list(sym_to_torch.values())
    compiled = circuit_cls(
        sc.scope,
        sc.num_channels,
        layers=all_layers,
        in_layers=torch_inputs,
        outputs=out_layers,
        properties=sc.properties,
    )

    # Optional optimization and folding
    compiled = self._post_process_circuit(compiled)

    # Allocate & initialize the parameters
    compiled.reset_parameters()

    # Make the result retrievable, then tell the state this compilation is over
    self.register_compiled_circuit(sc, compiled)
    self._state.finish_compilation()
    return compiled

_compile_parameter_node(node) ¤

Source code in cirkit/backend/torch/compiler.py
198
199
200
201
def _compile_parameter_node(self, node: ParameterNode) -> TorchParameterNode:
    """Compile one symbolic parameter node with the rule registered for its exact type."""
    rule = self.retrieve_parameter_rule(type(node))
    compiled = rule(self, node)
    return cast(TorchParameterNode, compiled)

_post_process_circuit(cc) ¤

Source code in cirkit/backend/torch/compiler.py
252
253
254
255
256
257
258
259
260
261
262
263
def _post_process_circuit(self, cc: AbstractTorchCircuit) -> AbstractTorchCircuit:
    """Optimize and then fold ``cc``, each step only if its flag is enabled.

    Args:
        cc: The freshly compiled circuit.

    Returns:
        The circuit after the enabled steps; ``cc`` itself when both are disabled.
    """
    if self.is_optimize_enabled:
        # Rewrite the computational graph with the optimization rules
        cc = _optimize_circuit(self, cc, max_opt_steps=5)
    if self.is_fold_enabled:
        # Fold the (possibly optimized) circuit
        cc = _fold_circuit(self, cc)
    return cc

compile_initializer(initializer) ¤

Source code in cirkit/backend/torch/compiler.py
185
186
187
188
189
def compile_initializer(self, initializer: Initializer) -> Callable[[Tensor], Tensor]:
    """Compile a symbolic initializer with the rule registered for its exact type."""
    # The rule is looked up by the concrete class of the initializer
    rule = self.retrieve_initializer_rule(type(initializer))
    compiled = rule(self, initializer)
    return cast(Callable[[Tensor], Tensor], compiled)

compile_layer(layer) ¤

Source code in cirkit/backend/torch/compiler.py
159
160
161
162
def compile_layer(self, layer: Layer) -> TorchLayer:
    """Compile a symbolic layer with the rule registered for its exact type."""
    rule = self.retrieve_layer_rule(type(layer))
    compiled = rule(self, layer)
    return cast(TorchLayer, compiled)

compile_parameter(parameter) ¤

Source code in cirkit/backend/torch/compiler.py
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
def compile_parameter(self, parameter: Parameter) -> TorchParameter:
    """Compile a symbolic parameter graph node by node.

    Args:
        parameter: The symbolic parameter computational graph.

    Returns:
        A ``TorchParameter`` built from the compiled nodes, their inputs and
        the compiled counterpart of ``parameter.output``.
    """
    # Symbolic node -> compiled node
    sym_to_torch: dict[ParameterNode, TorchParameterNode] = {}

    # Compiled nodes in visiting order, and the compiled inputs of each one
    torch_nodes: list[TorchParameterNode] = []
    torch_inputs: dict[TorchParameterNode, list[TorchParameterNode]] = {}

    for sym_node in parameter.topological_ordering():
        torch_node = self._compile_parameter_node(sym_node)
        # The lookup relies on the inputs having been compiled earlier in the ordering
        torch_inputs[torch_node] = [
            sym_to_torch[sym_in] for sym_in in parameter.node_inputs(sym_node)
        ]
        sym_to_torch[sym_node] = torch_node
        torch_nodes.append(torch_node)

    # Assemble the compiled computational graph
    return TorchParameter(torch_nodes, torch_inputs, [sym_to_torch[parameter.output]])

compile_pipeline(sc) ¤

Source code in cirkit/backend/torch/compiler.py
130
131
132
133
134
135
136
137
138
139
140
141
def compile_pipeline(self, sc: Circuit) -> AbstractTorchCircuit:
    """Compile ``sc`` together with every not-yet-compiled circuit of its pipeline.

    Args:
        sc: The symbolic circuit whose compiled counterpart is wanted.

    Returns:
        The compiled circuit registered for ``sc``.
    """
    # Walk the pipeline in topological order, skipping what is already compiled
    for pipeline_circuit in pipeline_topological_ordering([sc]):
        if not self.is_compiled(pipeline_circuit):
            self._compile_circuit(pipeline_circuit)

    # The requested circuit is the output of the pipeline
    return self.get_compiled_circuit(sc)

retrieve_optimization_registry(kind) ¤

Source code in cirkit/backend/torch/compiler.py
191
192
def retrieve_optimization_registry(self, kind: str) -> CompilerRegistry:
    """Return the optimization registry stored under ``kind``.

    Raises:
        KeyError: If no registry is stored under ``kind``.
    """
    registry = self._optimization_registry[kind]
    return cast(CompilerRegistry, registry)

retrieve_optimization_rule(kind, pattern) ¤

Source code in cirkit/backend/torch/compiler.py
194
195
196
def retrieve_optimization_rule(self, kind: str, pattern: GraphOptPattern) -> Callable:
    """Return the rule that the ``kind`` registry associates with ``pattern``.

    Args:
        kind: Key passed to ``retrieve_optimization_registry``.
        pattern: The optimization pattern whose rule is wanted.
    """
    return self.retrieve_optimization_registry(kind).retrieve_rule(pattern)

TorchCompilerState ¤

Source code in cirkit/backend/torch/compiler.py
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
class TorchCompilerState:
    """Book-keeping of the links between symbolic and compiled parameter tensors.

    Two maps are kept: a forward map (symbolic -> compiled tensor plus slice
    index) and a reverse map (unfolded compiled tensor -> symbolic), the latter
    being emptied by :meth:`finish_compilation`.
    """

    def __init__(self) -> None:
        """Create a state with both maps empty."""
        # A map from symbolic parameter tensors to a tuple containing the compiled parameter tensor,
        # and the slice index, which is 0 if the compiled parameter tensor is unfolded.
        # If the compiled parameter tensor is folded, then the slice index can be non-zero.
        self._compiled_parameters: dict[TensorParameter, tuple[TorchTensorParameter, int]] = {}

        # We keep a reverse map from compiled and unfolded parameter tensors
        # to the corresponding symbolic parameter tensors.
        # This is useful to update the map from symbolic to compiled parameter tensors above
        # after we fold the tensor parameters within a circuit.
        # Since this is useful only for folding, it will be cleared after each circuit compilation.
        self._symbolic_parameters: dict[TorchTensorParameter, TensorParameter] = {}

    def finish_compilation(self) -> None:
        """Clear the reverse (compiled -> symbolic) map; the forward map is kept."""
        self._symbolic_parameters.clear()

    def has_compiled_parameter(self, p: TensorParameter) -> bool:
        """Return whether a compiled tensor has been registered for ``p``."""
        return p in self._compiled_parameters

    def retrieve_compiled_parameter(self, p: TensorParameter) -> tuple[TorchTensorParameter, int]:
        """Return the compiled tensor registered for ``p`` and its slice index.

        Raises:
            KeyError: If nothing has been registered for ``p``.
        """
        return self._compiled_parameters[p]

    def retrieve_symbolic_parameter(self, p: TorchTensorParameter) -> TensorParameter:
        """Return the symbolic tensor linked to the unfolded compiled tensor ``p``.

        Raises:
            KeyError: If ``p`` is not in the reverse map (e.g. after it was cleared).
        """
        return self._symbolic_parameters[p]

    def register_compiled_parameter(
        self, sp: TensorParameter, cp: TorchTensorParameter, *, fold_idx: int | None = None
    ) -> None:
        """Link the symbolic tensor ``sp`` to the compiled tensor ``cp``.

        Args:
            sp: The symbolic parameter tensor.
            cp: The compiled parameter tensor.
            fold_idx: ``None`` when ``cp`` is unfolded; otherwise the slice of the
                folded tensor ``cp`` that corresponds to ``sp``.
        """
        if fold_idx is None:
            # We are registering an unfolded compiled parameter tensor: its slice index is 0,
            # and we can also register the reverse map (i.e., compiled to symbolic).
            self._compiled_parameters[sp] = (cp, 0)
            self._symbolic_parameters[cp] = sp
            # Stop here: falling through would overwrite the slice index 0 with None
            return

        # We are registering a folded compiled parameter tensor
        # So, we associate the symbolic parameter tensor to a particular slice of the
        # folded compiled parameter tensor, which is specified by the 'fold_idx'.
        self._compiled_parameters[sp] = (cp, fold_idx)

_compiled_parameters = {} instance-attribute ¤

_symbolic_parameters = {} instance-attribute ¤

__init__() ¤

Source code in cirkit/backend/torch/compiler.py
62
63
64
65
66
67
68
69
70
71
72
73
def __init__(self):
    """Create a state whose forward and reverse parameter maps are both empty."""
    # Forward map: symbolic parameter tensor -> (compiled parameter tensor, slice index).
    # The slice index is 0 for an unfolded compiled tensor and may be non-zero
    # when the compiled tensor is folded.
    self._compiled_parameters: dict[TensorParameter, tuple[TorchTensorParameter, int]] = dict()

    # Reverse map: unfolded compiled parameter tensor -> symbolic parameter tensor.
    # Folding uses it to refresh the forward map once tensors have been folded;
    # being needed only for that, it is cleared after each circuit compilation.
    self._symbolic_parameters: dict[TorchTensorParameter, TensorParameter] = dict()

finish_compilation() ¤

Source code in cirkit/backend/torch/compiler.py
75
76
77
def finish_compilation(self) -> None:
    """Empty, in place, the map from unfolded compiled tensors to symbolic ones."""
    reverse_map = self._symbolic_parameters
    reverse_map.clear()

has_compiled_parameter(p) ¤

Source code in cirkit/backend/torch/compiler.py
79
80
81
def has_compiled_parameter(self, p: TensorParameter) -> bool:
    """Return whether a compiled tensor has been registered for the symbolic tensor ``p``."""
    registered = self._compiled_parameters
    return p in registered

register_compiled_parameter(sp, cp, *, fold_idx=None) ¤

Source code in cirkit/backend/torch/compiler.py
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
def register_compiled_parameter(
    self, sp: TensorParameter, cp: TorchTensorParameter, *, fold_idx: int | None = None
) -> None:
    """Link the symbolic tensor ``sp`` to the compiled tensor ``cp``.

    Args:
        sp: The symbolic parameter tensor.
        cp: The compiled parameter tensor.
        fold_idx: ``None`` when ``cp`` is unfolded; otherwise the slice of the
            folded tensor ``cp`` that corresponds to ``sp``.
    """
    if fold_idx is None:
        # We are registering an unfolded compiled parameter tensor: its slice index is 0,
        # and we can also register the reverse map (i.e., compiled to symbolic).
        self._compiled_parameters[sp] = (cp, 0)
        self._symbolic_parameters[cp] = sp
        # Stop here: falling through would overwrite the slice index 0 with None
        return

    # We are registering a folded compiled parameter tensor
    # So, we associate the symbolic parameter tensor to a particular slice of the
    # folded compiled parameter tensor, which is specified by the 'fold_idx'.
    self._compiled_parameters[sp] = (cp, fold_idx)

retrieve_compiled_parameter(p) ¤

Source code in cirkit/backend/torch/compiler.py
83
84
85
def retrieve_compiled_parameter(self, p: TensorParameter) -> tuple[TorchTensorParameter, int]:
    """Return the ``(compiled tensor, fold index)`` pair registered for ``p``.

    Raises:
        KeyError: If nothing has been registered for ``p``.
    """
    entry = self._compiled_parameters[p]
    return entry

retrieve_symbolic_parameter(p) ¤

Source code in cirkit/backend/torch/compiler.py
87
88
89
def retrieve_symbolic_parameter(self, p: TorchTensorParameter) -> TensorParameter:
    """Return the symbolic tensor linked to the (unfolded) compiled tensor ``p``.

    Raises:
        KeyError: If ``p`` is not in the reverse map.
    """
    symbolic = self._symbolic_parameters[p]
    return symbolic

_fold_circuit(compiler, cc) ¤

Source code in cirkit/backend/torch/compiler.py
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
def _fold_circuit(compiler: TorchCompiler, cc: AbstractTorchCircuit) -> AbstractTorchCircuit:
    """Build a folded copy of ``cc``.

    Args:
        compiler: The compiler, handed to the function that folds each group of layers.
        cc: The circuit to fold.

    Returns:
        A new circuit of the same class as ``cc``, made of the folded layers.
    """
    # Groups of layers are folded by _fold_layers_group, bound to this compiler
    fold_group = functools.partial(_fold_layers_group, compiler=compiler)

    # Fold following the layer-wise topological ordering of the circuit
    folded_graph = build_folded_graph(
        cc.layerwise_topological_ordering(),
        outputs=cc.outputs,
        incomings_fn=cc.layer_inputs,
        fold_group_fn=fold_group,
    )
    layers, in_layers, outputs, fold_idx_info = folded_graph

    # Re-instantiate the circuit class on the folded graph
    circuit_cls = type(cc)
    return circuit_cls(
        cc.scope,
        cc.num_channels,
        layers,
        in_layers,
        outputs,
        properties=cc.properties,
        fold_idx_info=fold_idx_info,
    )

_fold_layers_group(layers, *, compiler) ¤

Source code in cirkit/backend/torch/compiler.py
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
def _fold_layers_group(layers: list[TorchLayer], *, compiler: TorchCompiler) -> TorchLayer:
    """Fold a group of layers into a single layer of the class of the first one.

    Args:
        layers: The layers to fold; the class and the configuration of the
            folded layer are taken from ``layers[0]``.
        compiler: The compiler, providing the semiring and used to fold parameters.

    Returns:
        The folded layer.
    """
    first = layers[0]
    folded_cls = type(first)
    folded_conf = first.config

    # Extra constructor arguments depend on the family of the layer being folded
    if not issubclass(folded_cls, TorchInputLayer):
        # Sum or product layers: the folded layer has as many folds as all layers together
        extra_kwargs = {"num_folds": sum(l.num_folds for l in layers)}
    elif issubclass(folded_cls, TorchConstantLayer):
        # Constant input layers get no extra argument
        extra_kwargs = {}
    else:
        # Other input layers: concatenate the variables scope index tensors
        extra_kwargs = {"scope_idx": torch.cat([l.scope_idx for l in layers])}

    # Gather, per name, the parameters and the sub-module layers of every layer
    params_by_name: dict[str, list[TorchParameter]] = defaultdict(list)
    submodules_by_name: dict[str, list[TorchLayer]] = defaultdict(list)
    for l in layers:
        for name, param in l.params.items():
            params_by_name[name].append(param)
        for name, sub_layer in l.sub_modules.items():
            submodules_by_name[name].append(sub_layer)

    # Fold same-named parameters together, if the layers have any
    folded_params: dict[str, TorchParameter] = {}
    for name, params in params_by_name.items():
        folded_params[name] = _fold_parameters(compiler, params)

    # Fold same-named sub-module layers together (recursively), if the layers have any
    folded_submodules: dict[str, TorchLayer] = {}
    for name, sub_layers in submodules_by_name.items():
        folded_submodules[name] = _fold_layers_group(sub_layers, compiler=compiler)

    # Build the folded layer from the configuration, the folded pieces and the extras
    return folded_cls(
        **folded_conf,
        **folded_submodules,
        **folded_params,
        semiring=compiler.semiring,
        **extra_kwargs,
    )

_fold_parameter_nodes_group(group, *, compiler) ¤

Source code in cirkit/backend/torch/compiler.py
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
def _fold_parameter_nodes_group(
    group: list[TorchParameterNode], *, compiler: TorchCompiler
) -> TorchParameterNode:
    """Fold a group of parameter nodes into one node.

    The class of ``group[0]`` selects one of three cases: tensor parameters,
    pointer parameters, or (anything else) parameter operators.

    Args:
        group: The parameter nodes to fold together.
        compiler: The compiler, whose state is updated when folding tensor parameters.

    Returns:
        The folded node; for a group made of a single pointer parameter, that
        very pointer is returned.
    """
    first = group[0]
    node_cls = type(first)

    # Case 1: tensor parameters. The folded tensor takes shape, gradient flag and dtype
    # from the first node, has one fold per node, and initializes each fold with the
    # initializer of the corresponding node.
    if issubclass(node_cls, TorchTensorParameter):
        assert all(isinstance(p, TorchTensorParameter) for p in group)
        per_fold_initializers = [p.initializer for p in group]
        folded_tensor = TorchTensorParameter(
            *first.shape,
            num_folds=len(group),
            requires_grad=first.requires_grad,
            initializer_=functools.partial(
                foldwise_initializer_, initializers=per_fold_initializers
            ),
            dtype=first.dtype,
        )
        # Re-point every symbolic parameter (which is unfolded) to its slice of the
        # folded compiled tensor, so that the registry stays correct after folding.
        state = compiler.state
        for fold_idx, unfolded in enumerate(group):
            symbolic = state.retrieve_symbolic_parameter(unfolded)
            state.register_compiled_parameter(symbolic, folded_tensor, fold_idx=fold_idx)
        return folded_tensor

    # Case 2: parameters obtained via slicing. This regularly fires when doing
    # operations over circuits that are compiled into folded tensorized circuits.
    if issubclass(node_cls, TorchPointerParameter):
        assert all(isinstance(p, TorchPointerParameter) for p in group)
        if len(group) == 1:
            # Nothing to merge: keep the slice itself as the folded parameter
            return first
        # Merge the slicing operations into one pointer to the node the first one refers to
        target = first.deref()
        merged_fold_idx: list[int] = []
        for pointer in group:
            if pointer.fold_idx is None:
                # No explicit index: the pointer covers all of its folds
                merged_fold_idx.extend(list(range(pointer.num_folds)))
            else:
                merged_fold_idx.extend(pointer.fold_idx)
        return TorchPointerParameter(target, fold_idx=merged_fold_idx)

    # Case 3: an operator. Copy the configuration and set the number of folds.
    assert all(isinstance(p, TorchParameterOp) for p in group)
    return node_cls(**first.config, num_folds=len(group))

_fold_parameters(compiler, parameters) ¤

Source code in cirkit/backend/torch/compiler.py
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
def _fold_parameters(compiler: TorchCompiler, parameters: list[TorchParameter]) -> TorchParameter:
    """Fold several parameter computational graphs into a single one.

    Args:
        compiler: The compiler, handed to the function that folds each group of nodes.
        parameters: The parameter graphs to fold together.

    Returns:
        The folded parameter computational graph.
    """
    # Inputs of every node, across all the parameter graphs
    merged_in_nodes: dict[TorchParameterNode, Sequence[TorchParameterNode]] = {}
    for pi in parameters:
        merged_in_nodes.update(pi.nodes_inputs)

    # Merge the layer-wise (aka bottom-up) topological orderings level by level
    ordering: list[list[TorchParameterNode]] = []
    for pi in parameters:
        for i, frontier in enumerate(pi.layerwise_topological_ordering()):
            if i < len(ordering):
                ordering[i].extend(frontier)
            else:
                # Store a copy: this level gets extended in place by later graphs, and
                # the yielded frontier may be a list owned by the parameter graph.
                ordering.append(list(frontier))

    # Fold the nodes in the merged parameter computational graphs,
    # by following the layer-wise topological ordering
    nodes, in_nodes, outputs, fold_idx_info = build_folded_graph(
        ordering,
        outputs=chain.from_iterable(map(lambda pi: pi.outputs, parameters)),
        incomings_fn=merged_in_nodes.get,
        fold_group_fn=functools.partial(_fold_parameter_nodes_group, compiler=compiler),
    )

    # Construct the folded parameter's computational graph
    return TorchParameter(nodes, in_nodes, outputs, fold_idx_info=fold_idx_info)

_match_layer_pattern(layer, pattern, *, incomings_fn, outcomings_fn) ¤

Source code in cirkit/backend/torch/compiler.py
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
def _match_layer_pattern(
    layer: TorchLayer,
    pattern: LayerOptPattern,
    *,
    incomings_fn: Callable[[TorchLayer], Sequence[TorchLayer]],
    outcomings_fn: Callable[[TorchLayer], Sequence[TorchLayer]],
) -> LayerOptMatch | None:
    """Try to match a chain-shaped layer pattern starting from ``layer``.

    Entry 0 of the pattern is matched against ``layer``, entry 1 against its
    only input, and so on. Each entry is checked on the layer class, on the
    configuration values and on the parameter patterns specified for it.

    Args:
        layer: The layer the first pattern entry is matched against.
        pattern: The layer optimization pattern.
        incomings_fn: Returns the inputs of a layer.
        outcomings_fn: Returns the outputs of a layer.

    Returns:
        The match, or None if the pattern does not match. A non-final entry
        whose layer does not have exactly one input is reported as no match.
    """
    ppatterns = pattern.ppatterns()
    cpatterns = pattern.cpatterns()
    pattern_entries = pattern.entries()
    num_entries = len(pattern_entries)
    matched_layers = []
    matched_parameters = []

    # Start matching the pattern from the root
    # TODO: generalize to match DAGs or trees
    for lid in range(num_entries):
        is_last_entry = lid == num_entries - 1

        # First, attempt to match the layer
        if not isinstance(layer, pattern_entries[lid]):
            return None
        in_nodes = incomings_fn(layer)
        # Descending to the next entry needs exactly one input: several inputs do not
        # form a chain, and no input at all means the graph ends before the pattern does.
        if not is_last_entry and len(in_nodes) != 1:
            return None
        out_nodes = outcomings_fn(layer)
        if len(out_nodes) > 1 and lid != 0:
            return None

        # Second, attempt to match the configuration patterns for the layer
        for cname, cvalue in cpatterns[lid].items():
            if layer.config[cname] != cvalue:
                return None

        # Third, attempt to match the patterns specified for its parameters
        lpmatches = {}
        for pname, ppattern in ppatterns[lid].items():
            pgraph = layer.params[pname]
            matches, _ = match_optimization_patterns(
                pgraph.topological_ordering(),
                pgraph.outputs,
                [ppattern],
                incomings_fn=pgraph.node_inputs,
                outcomings_fn=pgraph.node_outputs,
                pattern_matcher_fn=_match_parameter_nodes_pattern,
            )
            if not matches:
                return None
            lpmatches[pname] = matches
        matched_parameters.append(lpmatches)

        # We got a match with the layer and its parameters.
        # Next, try to match its input sub-graph.
        matched_layers.append(layer)
        if not is_last_entry:
            (layer,) = in_nodes

    return LayerOptMatch(pattern, matched_layers, matched_parameters)

_match_parameter_nodes_pattern(node, pattern, *, incomings_fn, outcomings_fn) ¤

Source code in cirkit/backend/torch/compiler.py
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
def _match_parameter_nodes_pattern(
    node: TorchParameterNode,
    pattern: ParameterOptPattern,
    *,
    incomings_fn: Callable[[TorchParameterNode], Sequence[TorchParameterNode]],
    outcomings_fn: Callable[[TorchParameterNode], Sequence[TorchParameterNode]],
) -> ParameterOptMatch | None:
    """Try to match a chain-shaped parameter pattern starting from ``node``.

    Entry 0 of the pattern is matched against ``node``, entry 1 against its
    only input, and so on; entries are matched with ``isinstance``.

    Args:
        node: The parameter node the first pattern entry is matched against.
        pattern: The parameter optimization pattern.
        incomings_fn: Returns the inputs of a node.
        outcomings_fn: Returns the outputs of a node.

    Returns:
        The match, or None if the pattern does not match. A non-final entry
        whose node does not have exactly one input is reported as no match.
    """
    pattern_entries = pattern.entries()
    num_entries = len(pattern_entries)
    matched_nodes = []

    # Start matching the pattern from the root
    # TODO: generalize to match DAGs or binary trees
    for nid in range(num_entries):
        is_last_entry = nid == num_entries - 1
        if not isinstance(node, pattern_entries[nid]):
            return None
        in_nodes = incomings_fn(node)
        # Descending to the next entry needs exactly one input: several inputs do not
        # form a chain, and no input at all means the graph ends before the pattern does.
        if not is_last_entry and len(in_nodes) != 1:
            return None
        out_nodes = outcomings_fn(node)
        if len(out_nodes) > 1 and nid != 0:
            return None
        matched_nodes.append(node)
        if not is_last_entry:
            (node,) = in_nodes

    return ParameterOptMatch(pattern, matched_nodes)

_optimize_circuit(compiler, cc, *, max_opt_steps=5) ¤

Source code in cirkit/backend/torch/compiler.py
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
def _optimize_circuit(
    compiler: TorchCompiler, cc: AbstractTorchCircuit, *, max_opt_steps: int = 5
) -> AbstractTorchCircuit:
    """Repeatedly apply the three kinds of optimization to ``cc``.

    Args:
        compiler: The compiler providing the optimization rules.
        cc: The circuit to optimize.
        max_opt_steps: Upper bound on the number of optimization steps; must be positive.

    Returns:
        The circuit obtained after the last optimization step.
    """
    assert max_opt_steps > 0

    # Keep going until a whole step changes nothing, or the step budget is used up
    for _ in range(max_opt_steps):
        # First: optimize the parameter nodes of the parameter graphs of each layer
        cc, changed_parameters = _optimize_parameter_nodes(compiler, cc)

        # Second: shatter layers in multiple more efficient ones
        cc, changed_by_shatter = _optimize_layers(compiler, cc, shatter=True)

        # Third: fuse multiple layers into a single more efficient one
        cc, changed_by_fuse = _optimize_layers(compiler, cc, shatter=False)

        if not (changed_parameters or changed_by_shatter or changed_by_fuse):
            break

    return cc

_optimize_layers(compiler, cc, *, shatter=False) ¤

Source code in cirkit/backend/torch/compiler.py
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
def _optimize_layers(
    compiler: TorchCompiler, cc: AbstractTorchCircuit, *, shatter: bool = False
) -> tuple[AbstractTorchCircuit, bool]:
    """Apply either the shatter or the fuse layer optimization rules to ``cc``.

    Args:
        compiler: The compiler providing the optimization rules.
        cc: The circuit to optimize.
        shatter: If true use the ``"layer_shatter"`` rules, else the ``"layer_fuse"`` ones.

    Returns:
        ``(cc, False)`` when ``optimize_graph`` returns None; otherwise a new
        circuit of the same class as ``cc`` built from its result, and ``True``.
    """
    # The same registry kind selects both the patterns and the rule applied to a match
    kind = "layer_shatter" if shatter else "layer_fuse"

    def apply_rule(match: LayerOptMatch) -> tuple[TorchLayer, ...]:
        rule = compiler.retrieve_optimization_rule(kind, match.pattern)
        return cast(LayerOptApplyFunc, rule)(compiler, match)

    registry = compiler.retrieve_optimization_registry(kind)
    result = optimize_graph(
        cc.topological_ordering(),
        cc.outputs,
        registry.signatures,
        incomings_fn=cc.layer_inputs,
        outcomings_fn=cc.layer_outputs,
        pattern_matcher_fn=_match_layer_pattern,
        match_optimizer_fn=apply_rule,
    )
    if result is None:
        return cc, False

    # Rebuild the circuit on the optimized graph
    layers, in_layers, outputs = result
    circuit_cls = type(cc)
    optimized = circuit_cls(
        cc.scope, cc.num_channels, layers, in_layers, outputs, properties=cc.properties
    )
    return optimized, True

_optimize_parameter_nodes(compiler, cc) ¤

Source code in cirkit/backend/torch/compiler.py
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
def _optimize_parameter_nodes(
    compiler: TorchCompiler, cc: AbstractTorchCircuit
) -> tuple[AbstractTorchCircuit, bool]:
    """Optimize, in place, the parameter graphs of every layer of ``cc``.

    Args:
        compiler: The compiler providing the ``"parameter"`` optimization rules.
        cc: The circuit whose layers are visited.

    Returns:
        The same circuit object, and whether at least one parameter graph was replaced.
    """

    def apply_rule(match: ParameterOptMatch) -> tuple[TorchParameterNode, ...]:
        rule = compiler.retrieve_optimization_rule("parameter", match.pattern)
        return cast(ParameterOptApplyFunc, rule)(compiler, match)

    any_optimized = False
    patterns = compiler.retrieve_optimization_registry("parameter").signatures
    for layer in cc.layers:
        # Visit each parameter computational graph of the layer
        for pname, pgraph in layer.params.items():
            result = optimize_graph(
                pgraph.topological_ordering(),
                pgraph.outputs,
                patterns,
                incomings_fn=pgraph.node_inputs,
                outcomings_fn=pgraph.node_outputs,
                pattern_matcher_fn=_match_parameter_nodes_pattern,
                match_optimizer_fn=apply_rule,
            )

            # None means that nothing could be optimized in this graph
            if result is None:
                continue
            nodes, in_nodes, outputs = result

            # Rebuild the graph with its own class and assign it back to the layer,
            # under the attribute carrying the parameter's name
            optimized_graph = type(pgraph)(nodes, in_nodes, outputs)
            assert hasattr(layer, pname)
            setattr(layer, pname, optimized_graph)
            any_optimized = True

    # The circuit object is returned unchanged; only its layers may have been updated
    return cc, any_optimized