Skip to content

modules

modules ¤

TorchModuleT = TypeVar('TorchModuleT', bound=AbstractTorchModule) module-attribute ¤

TypeVar: A torch module type that subclasses AbstractTorchModule.

AbstractTorchModule ¤

Bases: Module, ABC

An abstract class representing a torch.nn.Module that can be folded.

An abstract torch module is used as base class for both the circuit layers and the nodes of the computational graph of the parameters of each layer.

Source code in cirkit/backend/torch/graph/modules.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
class AbstractTorchModule(nn.Module, ABC):
    """Base class for [torch.nn.Module][torch.nn.Module] objects that can be folded.

    Both the circuit layers and the nodes of the computational graph of the parameters
    of each layer derive from this class.
    """

    def __init__(self, *, num_folds: int = 1):
        """Set up the module and record how many folds it computes.

        Args:
            num_folds: The number of folds computed by the module. It is keyword-only
                and defaults to 1.
        """
        super().__init__()
        self._num_folds = num_folds

    @property
    @abstractmethod
    def fold_settings(self) -> tuple[Any, ...]:
        """The attributes on which modules must agree in order to be folded.

        Concrete subclasses have to provide this property.

        Returns:
            A tuple of attributes.
        """

    @property
    def num_folds(self) -> int:
        """The number of folds, as given at construction time.

        Returns:
            The number of folds.
        """
        return self._num_folds

    @property
    def sub_modules(self) -> Mapping[str, "AbstractTorchModule"]:
        """The torch sub-modules, keyed by string identifier, that must be passed
        to the ``__init__`` method of the top-level torch module.

        Returns:
            A dictionary of torch modules. This base implementation returns an empty one.
        """
        no_sub_modules: dict[str, "AbstractTorchModule"] = {}
        return no_sub_modules

fold_settings abstractmethod property ¤

Retrieve a tuple of attributes on which modules must agree in order to be folded.

Returns:

Type Description
tuple[Any, ...]

A tuple of attributes.

num_folds property ¤

Retrieve the number of folds.

Returns:

Type Description
int

The number of folds.

sub_modules property ¤

Retrieve a dictionary mapping string identifiers to torch sub-modules, that must be passed to the __init__ method of the top-level torch module.

Returns:

Type Description
Mapping[str, AbstractTorchModule]

A dictionary of torch modules.

__init__(*, num_folds=1) ¤

Initialize the abstract torch module object.

Parameters:

Name Type Description Default
num_folds int

The number of folds computed by the module.

1
Source code in cirkit/backend/torch/graph/modules.py
19
20
21
22
23
24
25
26
def __init__(self, *, num_folds: int = 1) -> None:
    """Initialize the abstract torch module object.

    The parent initializer is run first, then the number of folds is stored on
    the instance.

    Args:
        num_folds: The number of folds computed by the module. It is keyword-only
            and defaults to 1.
    """
    super().__init__()
    self._num_folds = num_folds

AddressBook ¤

Bases: Module, Generic[TorchModuleT], ABC

The address book data structure, sometimes also known as book-keeping. The address book stores a list of AddressBookEntry, where each entry stores the information needed to gather the inputs to each (possibly folded) torch module.

Source code in cirkit/backend/torch/graph/modules.py
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
class AddressBook(nn.Module, Generic[TorchModuleT], ABC):
    """The address book data structure, sometimes also known as book-keeping.

    It keeps a list of
    [AddressBookEntry][cirkit.backend.torch.graph.modules.AddressBookEntry] objects.
    Each entry holds what is needed to gather the inputs to one (possibly folded)
    torch module.
    """

    def __init__(self, entries: list[AddressBookEntry[TorchModuleT]]) -> None:
        """Build an address book out of a list of entries.

        The last entry is the one used to compute the output of the whole computational
        graph, and it is validated before anything is stored.

        Args:
            entries: The list of address book entries.

        Raises:
            ValueError: If the list of address book entries is empty.
            ValueError: If the last entry has a torch module assigned to it, or
                if it does not have exactly one fold index, or if that fold index
                is not a 1-dimensional tensor.
        """
        if not entries:
            raise ValueError("The list of address book entry must not be empty")
        output_entry = entries[-1]
        if output_entry.module is not None:
            raise ValueError(
                "The last entry of the address book must not have a module associated to it"
            )
        if len(output_entry.in_fold_idx) != 1:
            raise ValueError(
                "The last entry of the address book must have only one fold index tensor"
            )
        output_idx = output_entry.in_fold_idx[0]
        is_vector_index = isinstance(output_idx, Tensor) and len(output_idx.shape) == 1
        if not is_vector_index:
            raise ValueError("The output fold index tensor should be a 1-dimensional tensor")
        super().__init__()
        self._num_outputs = output_idx.shape[0]
        self._entry_modules: list[TorchModuleT | None] = [entry.module for entry in entries]
        self._entry_in_module_ids: list[list[list[int]]] = [
            entry.in_module_ids for entry in entries
        ]
        # The tensor fold indices are registered as buffers, so that they follow the
        # module when it is moved to a device, which avoids CPU-device transfers of the
        # indices. Non-tensor fold indices are stored as plain attributes.
        # Only the attribute names are kept in a list, and they are resolved on access.
        # TODO: Perhaps this can be made more elegant in the future, if someone
        #  decides to introduce a nn.BufferList container in torch
        self._entry_in_fold_idx_targets: list[list[str]] = []
        for entry_pos, entry in enumerate(entries):
            target_names: list[str] = []
            for input_pos, fold_idx in enumerate(entry.in_fold_idx):
                target_name = f"_in_fold_idx_{entry_pos}_{input_pos}"
                if isinstance(fold_idx, Tensor):
                    self.register_buffer(target_name, fold_idx)
                else:
                    setattr(self, target_name, fold_idx)
                target_names.append(target_name)
            self._entry_in_fold_idx_targets.append(target_names)

    def __len__(self) -> int:
        """The length of the address book.

        Returns:
            The number of address book entries.
        """
        return len(self._entry_modules)

    def __iter__(self) -> Iterator[AddressBookEntry[TorchModuleT]]:
        """Iterate over the address book entries.

        Every yielded entry consists of three objects: (i) the torch module to evaluate
        (None for the entry that returns the output of the computational graph); (ii) for
        each input to the module, the list of ids of the outputs of other modules (empty
        if the module is an input module); and (iii) for each input to the module, the
        fold index used to retrieve that input, even from folded modules.

        Returns:
            An iterator over address book entries.
        """
        rows = zip(self._entry_modules, self._entry_in_module_ids, self._entry_in_fold_idx_targets)
        for module, module_ids, target_names in rows:
            # Resolve the stored attribute names into the actual fold indices
            fold_indices = [getattr(self, name) for name in target_names]
            yield AddressBookEntry(module, module_ids, fold_indices)

    @property
    def num_outputs(self) -> int:
        """The number of outputs of the whole computational graph represented
        through the address book.

        For instance, for a circuit with $n$ output layers, this will be equal to $n$.

        Returns:
            The number of outputs.
        """
        return self._num_outputs

    @abstractmethod
    def lookup(
        self, module_outputs: list[Tensor], *, in_graph: Tensor | None = None
    ) -> Iterator[tuple[TorchModuleT | None, tuple]]:
        """Iteratively retrieve each torch module together with the tensor inputs to it.

        Args:
            module_outputs: A list of the outputs of each torch module. This list is expected
                to be iteratively expanded as we continue evaluating the modules of the torch
                computational graph.
            in_graph: An optional tensor input to the whole computational graph. This is used
                as input to the torch modules that do not receive input from other torch
                modules within the torch computational graph.

        Returns:
            An iterator of tuples. The first element is a torch module if we are
            retrieving the inputs to it, and None if we are retrieving the output of the
            whole computational graph (i.e., in the final step of the evaluation).
            The second element is a tuple of arguments that are input to the
            torch module (e.g., some input tensors).
        """

num_outputs property ¤

The number of outputs of the whole computational graph represented through the address book.

For instance, for a circuit with \(n\) output layers, this will be equal to \(n\).

Returns:

Type Description
int

The number of outputs.

__init__(entries) ¤

Initializes an address book.

Parameters:

Name Type Description Default
entries list[AddressBookEntry[TorchModuleT]]

The list of address book entries.

required

Raises:

Type Description
ValueError

If the list of address book entries is empty.

ValueError

If the last entry (i.e., the entry used to compute the output of the whole computational graph) has a torch module assigned to it, or if it has more than one fold index tensor, or if the fold index tensor is not a 1-dimensional tensor.

Source code in cirkit/backend/torch/graph/modules.py
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
def __init__(self, entries: list[AddressBookEntry[TorchModuleT]]) -> None:
    """Build an address book out of a list of entries.

    The last entry is the one used to compute the output of the whole computational
    graph, and it is validated before anything is stored.

    Args:
        entries: The list of address book entries.

    Raises:
        ValueError: If the list of address book entries is empty.
        ValueError: If the last entry has a torch module assigned to it, or
            if it does not have exactly one fold index, or if that fold index
            is not a 1-dimensional tensor.
    """
    if not entries:
        raise ValueError("The list of address book entry must not be empty")
    output_entry = entries[-1]
    if output_entry.module is not None:
        raise ValueError(
            "The last entry of the address book must not have a module associated to it"
        )
    if len(output_entry.in_fold_idx) != 1:
        raise ValueError(
            "The last entry of the address book must have only one fold index tensor"
        )
    output_idx = output_entry.in_fold_idx[0]
    is_vector_index = isinstance(output_idx, Tensor) and len(output_idx.shape) == 1
    if not is_vector_index:
        raise ValueError("The output fold index tensor should be a 1-dimensional tensor")
    super().__init__()
    self._num_outputs = output_idx.shape[0]
    self._entry_modules: list[TorchModuleT | None] = [entry.module for entry in entries]
    self._entry_in_module_ids: list[list[list[int]]] = [
        entry.in_module_ids for entry in entries
    ]
    # The tensor fold indices are registered as buffers, so that they follow the
    # module when it is moved to a device, which avoids CPU-device transfers of the
    # indices. Non-tensor fold indices are stored as plain attributes.
    # Only the attribute names are kept in a list, and they are resolved on access.
    # TODO: Perhaps this can be made more elegant in the future, if someone
    #  decides to introduce a nn.BufferList container in torch
    self._entry_in_fold_idx_targets: list[list[str]] = []
    for entry_pos, entry in enumerate(entries):
        target_names: list[str] = []
        for input_pos, fold_idx in enumerate(entry.in_fold_idx):
            target_name = f"_in_fold_idx_{entry_pos}_{input_pos}"
            if isinstance(fold_idx, Tensor):
                self.register_buffer(target_name, fold_idx)
            else:
                setattr(self, target_name, fold_idx)
            target_names.append(target_name)
        self._entry_in_fold_idx_targets.append(target_names)

__iter__() ¤

Retrieve an iterator over address book entries, i.e., a tuple consisting of three objects: (i) the torch module to evaluate (it can be None if the entry is needed to return the output of the computational graph); (ii) for each input to the module (i.e., depending on the arity) we have the list of ids to the outputs of other modules (it can be empty if the module is an input module); and (iii) for each input to the module we have the fold indexing, which is used to retrieve the inputs to a module, even if they are folded modules.

Returns:

Type Description
Iterator[AddressBookEntry[TorchModuleT]]

An iterator over address book entries.

Source code in cirkit/backend/torch/graph/modules.py
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
def __iter__(self) -> Iterator[AddressBookEntry[TorchModuleT]]:
    """Iterate over the address book entries.

    Every yielded entry consists of three objects: (i) the torch module to evaluate
    (None for the entry that returns the output of the computational graph); (ii) for
    each input to the module, the list of ids of the outputs of other modules (empty
    if the module is an input module); and (iii) for each input to the module, the
    fold index used to retrieve that input, even from folded modules.

    Returns:
        An iterator over address book entries.
    """
    rows = zip(self._entry_modules, self._entry_in_module_ids, self._entry_in_fold_idx_targets)
    for module, module_ids, target_names in rows:
        # Resolve the stored attribute names into the actual fold indices
        fold_indices = [getattr(self, name) for name in target_names]
        yield AddressBookEntry(module, module_ids, fold_indices)

__len__() ¤

Retrieve the length of the address book.

Returns:

Type Description
int

The number of address book entries.

Source code in cirkit/backend/torch/graph/modules.py
160
161
162
163
164
165
166
def __len__(self) -> int:
    """Retrieve the length of the address book.

    The length is taken from the internal list of entry modules.

    Returns:
        The number of address book entries.
    """
    return len(self._entry_modules)

lookup(module_outputs, *, in_graph=None) abstractmethod ¤

Retrieve an iterator that iteratively returns a torch module and the tensor inputs to it.

Parameters:

Name Type Description Default
module_outputs list[Tensor]

A list of the outputs of each torch module. This list is expected to be iteratively expanded as we continue evaluating the modules of the torch computational graph.

required
in_graph Tensor | None

An optional tensor input to the whole computational graph. This is used as input to the torch modules that do not receive input from other torch modules within the torch computational graph.

None

Returns:

Type Description
Iterator[tuple[TorchModuleT | None, tuple]]

An iterator of tuples, where the first element is a torch module if we are retrieving the inputs to it, and None if we are retrieving the output of the whole computational graph (i.e., in the final step of the evaluation). The second element is instead a tuple of arguments that are input to the torch module (e.g., some input tensors).

Source code in cirkit/backend/torch/graph/modules.py
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
@abstractmethod
def lookup(
    self, module_outputs: list[Tensor], *, in_graph: Tensor | None = None
) -> Iterator[tuple[TorchModuleT | None, tuple]]:
    """Retrieve an iterator that iteratively returns a torch module and the tensor inputs to it.

    This method is abstract and has no implementation here.

    Args:
        module_outputs: A list of the outputs of each torch module. This list is expected to
            be iteratively expanded as we continue evaluating the modules of the torch
            computational graph.
        in_graph: An optional tensor input to the whole computational graph. This is used
            as input to the torch modules that do not receive input from other torch
            modules within the torch computational graph.

    Returns:
        An iterator of tuples, where the first element is a torch module if we are
        retrieving the inputs to it, and None if we are retrieving the output of the
        whole computational graph (i.e., in the final step of the evaluation).
        The second element is instead a tuple of arguments that are input to the
        torch module (e.g., some input tensors).
    """

AddressBookEntry dataclass ¤

Bases: Generic[TorchModuleT]

An entry of the address book data structure. It stores (i) the module, unless it is an output entry (i.e., an entry used to compute the output of the whole computational graph); and, for each input module to it, (ii) the unique indices of other modules, and (iii) the (optionally None) fold index tensor to apply in order to recover the input tensors to each fold.

Source code in cirkit/backend/torch/graph/modules.py
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
@dataclass(frozen=True)
class AddressBookEntry(Generic[TorchModuleT]):
    """An entry of the address book data structure.

    It stores (i) the module, which is None for an output entry (i.e., an entry used
    to compute the output of the whole computational graph); and, for each input
    module to it, (ii) the indices of other modules, and (iii) the fold index (either
    a tensor or a tuple of optional slices) to apply in order to recover the input
    tensors to each fold.
    """

    module: TorchModuleT | None
    """The module the entry refers to. It can be None if the entry is then used to
    compute the output of the whole computational graph."""
    in_module_ids: list[list[int]]
    """For each input module, it stores the list of other module indices."""
    in_fold_idx: list[Tensor | tuple[slice | None, ...]]
    """For each input module, it stores the fold index tensor used to gather the
    input tensors to each fold. It is a tuple of optional slices whether there is no need of
    gathering the input tensors, i.e., if the indexing operation would act as an unsqueezing
    operation that can be much more efficient."""

in_fold_idx instance-attribute ¤

For each input module, it stores the fold index tensor used to gather the input tensors to each fold. It is a tuple of optional slices when there is no need to gather the input tensors, i.e., if the indexing operation would act as an unsqueezing operation, which can be much more efficient.

in_module_ids instance-attribute ¤

For each input module, it stores the list of other module indices.

module instance-attribute ¤

The module the entry refers to. It can be None if the entry is then used to compute the output of the whole computational graph.

__init__(module, in_module_ids, in_fold_idx) ¤

FoldIndexInfo dataclass ¤

Bases: Generic[TorchModuleT]

The folding index information of a folded computational graph, i.e., a directed acyclic graph of torch modules.

This data class stores (i) the topological ordering, (ii) the input fold index information for each torch module, and (iii) the output fold index information.

Source code in cirkit/backend/torch/graph/modules.py
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
@dataclass(frozen=True)
class FoldIndexInfo(Generic[TorchModuleT]):
    """The folding index information of a folded computational graph, i.e.,
    a [directed acyclic graph][cirkit.backend.torch.graph.modules.TorchDiAcyclicGraph]
    of [torch modules][cirkit.backend.torch.graph.modules.AbstractTorchModule].

    This data class stores (i) the topological ordering, (ii) the input fold index
    information for each torch module, and (iii) the output fold index information.
    """

    ordering: list[TorchModuleT]
    """The topological ordering of torch modules."""
    in_fold_idx: dict[int, list[list[tuple[int, int]]]]
    """The input fold index information. For each module index, it stores for each output
    fold computed by the module (first list), and for each input to the module (second list
    whose length is the arity), a tuple of (1) the input module index and (2) the fold index
    within that input module."""
    out_fold_idx: list[tuple[int, int]]
    """The output fold index information. For each output (first list), it stores a tuple of
    (1) the output module index and (2) the fold index within that output module."""

in_fold_idx instance-attribute ¤

The input fold index information. For each module index, it stores for each output fold computed by the module (first list), and for each input to the module (second list whose length is the arity), a tuple of (1) the input module index and (2) the fold index within that input module.

ordering instance-attribute ¤

The topological ordering of torch modules.

out_fold_idx instance-attribute ¤

The output fold index information. For each output (first list), it stores a tuple of (1) the output module index and (2) the fold index within that output module.

__init__(ordering, in_fold_idx, out_fold_idx) ¤

ModuleEvalFunctional ¤

Bases: Protocol

The protocol of a function that evaluates a module on some inputs.

Source code in cirkit/backend/torch/graph/modules.py
224
225
226
227
228
229
230
231
232
233
234
235
236
class ModuleEvalFunctional(Protocol):  # pylint: disable=too-few-public-methods
    """The protocol of a function that evaluates a module on some inputs.

    Any callable accepting a module followed by positional tensor inputs, and
    returning a tensor, matches this protocol.
    """

    def __call__(self, module: TorchModuleT, *inputs: Tensor) -> Tensor:
        """Evaluate a module on some inputs.

        Args:
            module: The module to evaluate.
            inputs: The tensor inputs to the module, passed positionally.

        Returns:
            The output of the module as specified by this functional.
        """

__call__(module, *inputs) ¤

Evaluate a module on some inputs.

Parameters:

Name Type Description Default
module TorchModuleT

The module to evaluate.

required
inputs Tensor

The tensor inputs to the module

()

Returns:

Name Type Description
Tensor Tensor

The output of the module as specified by this functional.

Source code in cirkit/backend/torch/graph/modules.py
227
228
229
230
231
232
233
234
235
236
def __call__(self, module: TorchModuleT, *inputs: Tensor) -> Tensor:
    """Evaluate a module on some inputs.

    Args:
        module: The module to evaluate.
        inputs: The tensor inputs to the module, passed positionally.

    Returns:
        The output of the module as specified by this functional.
    """

TorchDiAcyclicGraph ¤

Bases: Module, DiAcyclicGraph[TorchModuleT], ABC

A torch directed acyclic graph module, i.e., a computational graph made of torch modules.

Source code in cirkit/backend/torch/graph/modules.py
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
class TorchDiAcyclicGraph(nn.Module, DiAcyclicGraph[TorchModuleT], ABC):
    """A torch directed acyclic graph module, i.e., a computational graph made of torch modules."""

    def __init__(
        self,
        modules: Sequence[TorchModuleT],
        in_modules: Mapping[TorchModuleT, Sequence[TorchModuleT]],
        outputs: Sequence[TorchModuleT],
        *,
        fold_idx_info: FoldIndexInfo[TorchModuleT] | None = None,
    ):
        """Set up a torch computational graph and its address book.

        Args:
            modules: The module nodes.
            in_modules: A dictionary mapping modules to their input modules, if any.
            outputs: A list of modules that are the output modules in the computational graph.
            fold_idx_info: The folding index information.
                It can be None if the Torch graph is not folded, in which case
                it is built by the graph itself.
        """
        # Wrap the nodes in a torch container before handing them to the graph initializer
        module_list: list[TorchModuleT] = nn.ModuleList(modules)  # type: ignore
        # Initialize the torch module first, then the next class after nn.Module in the MRO
        super().__init__()
        super(nn.Module, self).__init__(module_list, in_modules, outputs)
        if fold_idx_info is None:
            self._is_folded = False
            fold_idx_info = self._build_unfold_index_info()
        else:
            self._is_folded = True
        self._address_book = self._build_address_book(fold_idx_info)

    @property
    def is_folded(self) -> bool:
        """Whether the computational graph is folded, i.e., whether folding index
        information was given at construction time.

        Returns:
            True if it is folded, False otherwise.
        """
        return self._is_folded

    @property
    def address_book(self) -> AddressBook[TorchModuleT]:
        """The address book object of the computational graph.

        Returns:
            The address book.
        """
        return self._address_book

    def subgraph(self, *roots: TorchModuleT) -> Self:
        """Extract the sub-graph having the given root torch modules as output modules.

        This is only supported if the computational graph is not a folded one.

        Args:
            *roots: The root torch modules of the sub-graph to return.

        Returns:
            A new torch computational graph, of the same class as this one, having the
            given roots as the output torch modules.

        Raises:
            ValueError: If the computational graph is folded.
        """
        if self.is_folded:
            raise ValueError("Cannot extract a sub-computational graph from a folded one")
        sub_nodes, sub_in_nodes = subgraph(roots, self.node_inputs)
        return self.__class__(sub_nodes, sub_in_nodes, outputs=roots)

    def evaluate(
        self, x: Tensor | None = None, module_fn: ModuleEvalFunctional | None = None
    ) -> Tensor:
        """Evaluate the Torch graph by walking through the address book, which
        provides each module together with the inputs to it.

        Args:
            x: The input of the Torch computational graph. It can be None.
            module_fn: A functional over modules that overrides the forward method defined
                by a module. It can be None. If it is None, then the ``__call__`` method
                defined by the module itself is used.

        Returns:
            The output tensor of the Torch graph.
            If the Torch graph has multiple outputs, then they will be stacked.

        Raises:
            RuntimeError: If the address book is somehow not well-formed, i.e., if it
                is exhausted without yielding an output entry.
        """
        # The address book reads from this list while we keep appending to it
        computed: list[Tensor] = []
        for module, args in self._address_book.lookup(computed, in_graph=x):
            if module is None:
                # Output entry: its single argument is the output of the whole graph
                (result,) = args
                return result
            computed.append(module(*args) if module_fn is None else module_fn(module, *args))
        raise RuntimeError("The address book is malformed")

    @abstractmethod
    def _build_unfold_index_info(self) -> FoldIndexInfo[TorchModuleT]:
        """Build the fold index information used when none is given at construction."""

    @abstractmethod
    def _build_address_book(
        self, fold_idx_info: FoldIndexInfo[TorchModuleT]
    ) -> AddressBook[TorchModuleT]:
        """Build the address book of the graph from the given fold index information."""

    def __repr__(self) -> str:
        """Build a multi-line representation listing the modules found in the address book.

        Returns:
            The class name, followed by the extra representation (if any) and one
            numbered item per address book entry having a module.
        """

        def indent(text: str) -> str:
            # Shift every line but the first one by two spaces
            return text.replace("\n", "\n  ")

        pieces = [f"{self.__class__.__name__}("]
        extra = self.extra_repr()
        if extra:
            pieces.append("  " + indent(extra))
        pieces.extend(
            f"  ({pos}): {indent(repr(entry.module))}"
            for pos, entry in enumerate(self._address_book)
            if entry.module is not None
        )
        pieces.append(")")
        return "\n".join(pieces)

address_book property ¤

Retrieve the address book object of the computational graph.

Returns:

Type Description
AddressBook[TorchModuleT]

The address book.

is_folded property ¤

Retrieves whether the computational graph is folded or not.

Returns:

Type Description
bool

True if it is folded, False otherwise.

__init__(modules, in_modules, outputs, *, fold_idx_info=None) ¤

Initialize a torch computational graph.

Parameters:

Name Type Description Default
modules Sequence[TorchModuleT]

The module nodes.

required
in_modules Mapping[TorchModuleT, Sequence[TorchModuleT]]

A dictionary mapping modules to their input modules, if any.

required
outputs Sequence[TorchModuleT]

A list of modules that are the output modules in the computational graph.

required
fold_idx_info FoldIndexInfo[TorchModuleT] | None

The folding index information. It can be None if the Torch graph is not folded.

None
Source code in cirkit/backend/torch/graph/modules.py
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
def __init__(
    self,
    modules: Sequence[TorchModuleT],
    in_modules: Mapping[TorchModuleT, Sequence[TorchModuleT]],
    outputs: Sequence[TorchModuleT],
    *,
    fold_idx_info: FoldIndexInfo[TorchModuleT] | None = None,
):
    """Set up a torch computational graph and its address book.

    Args:
        modules: The module nodes.
        in_modules: A dictionary mapping modules to their input modules, if any.
        outputs: A list of modules that are the output modules in the computational graph.
        fold_idx_info: The folding index information.
            It can be None if the Torch graph is not folded, in which case
            it is built by the graph itself.
    """
    # Wrap the nodes in a torch container before handing them to the graph initializer
    module_list: list[TorchModuleT] = nn.ModuleList(modules)  # type: ignore
    # Initialize the torch module first, then the next class after nn.Module in the MRO
    super().__init__()
    super(nn.Module, self).__init__(module_list, in_modules, outputs)
    if fold_idx_info is None:
        self._is_folded = False
        fold_idx_info = self._build_unfold_index_info()
    else:
        self._is_folded = True
    self._address_book = self._build_address_book(fold_idx_info)

__repr__() ¤

Source code in cirkit/backend/torch/graph/modules.py
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
def __repr__(self) -> str:
    """Build a multi-line representation listing the modules found in the address book.

    Returns:
        The class name, followed by the extra representation (if any) and one
        numbered item per address book entry having a module.
    """

    def indent(text: str) -> str:
        # Shift every line but the first one by two spaces
        return text.replace("\n", "\n  ")

    pieces = [f"{self.__class__.__name__}("]
    extra = self.extra_repr()
    if extra:
        pieces.append("  " + indent(extra))
    pieces.extend(
        f"  ({pos}): {indent(repr(entry.module))}"
        for pos, entry in enumerate(self._address_book)
        if entry.module is not None
    )
    pieces.append(")")
    return "\n".join(pieces)

evaluate(x=None, module_fn=None) ¤

Evaluate the Torch graph by following the topological ordering, and by using the address book information to retrieve the inputs to each module.

Parameters:

Name Type Description Default
x Tensor | None

The input of the Torch computational graph. It can be None.

None
module_fn ModuleEvalFunctional | None

A functional over modules that overrides the forward method defined by a module. It can be None. If it is None, then the __call__ method defined by the module itself is used.

None

Returns:

Type Description
Tensor

The output tensor of the Torch graph. If the Torch graph has multiple outputs, then they will be stacked.

Raises:

Type Description
RuntimeError

If the address book is somehow not well-formed.

Source code in cirkit/backend/torch/graph/modules.py
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
def evaluate(
    self, x: Tensor | None = None, module_fn: ModuleEvalFunctional | None = None
) -> Tensor:
    """Evaluate the Torch graph by walking through the address book, which
    provides each module together with the inputs to it.

    Args:
        x: The input of the Torch computational graph. It can be None.
        module_fn: A functional over modules that overrides the forward method defined
            by a module. It can be None. If it is None, then the ``__call__`` method
            defined by the module itself is used.

    Returns:
        The output tensor of the Torch graph.
        If the Torch graph has multiple outputs, then they will be stacked.

    Raises:
        RuntimeError: If the address book is somehow not well-formed, i.e., if it
            is exhausted without yielding an output entry.
    """
    # The address book reads from this list while we keep appending to it
    computed: list[Tensor] = []
    for module, args in self._address_book.lookup(computed, in_graph=x):
        if module is None:
            # Output entry: its single argument is the output of the whole graph
            (result,) = args
            return result
        computed.append(module(*args) if module_fn is None else module_fn(module, *args))
    raise RuntimeError("The address book is malformed")

subgraph(*roots) ¤

Assuming the computational graph is not a folded one, this returns the sub-graph having the given root torch modules as output modules.

Parameters:

Name Type Description Default
*roots TorchModuleT

The root torch modules of the sub-graph to return.

()

Returns:

Type Description
Self

A new torch computational graph having the given roots as the output torch modules.

Raises:

Type Description
ValueError

If the computational graph is folded.

Source code in cirkit/backend/torch/graph/modules.py
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
def subgraph(self, *roots: TorchModuleT) -> Self:
    """Extract the sub-graph having the given root torch modules as output modules.

    This is only supported if the computational graph is not a folded one.

    Args:
        *roots: The root torch modules of the sub-graph to return.

    Returns:
        A new torch computational graph, of the same class as this one, having the
        given roots as the output torch modules.

    Raises:
        ValueError: If the computational graph is folded.
    """
    if self.is_folded:
        raise ValueError("Cannot extract a sub-computational graph from a folded one")
    sub_nodes, sub_in_nodes = subgraph(roots, self.node_inputs)
    return self.__class__(sub_nodes, sub_in_nodes, outputs=roots)