Skip to content

modules

modules ¤

TorchModule = TypeVar('TorchModule', bound=AbstractTorchModule) module-attribute ¤

TypeVar: A torch module type that subclasses AbstractTorchModule.

AbstractTorchModule ¤

Bases: Module, ABC

An abstract class representing a torch.nn.Module that can be folded.

An abstract torch module is used as base class for both the circuit layers and the nodes of the computational graph of the parameters of each layer.

Source code in cirkit/backend/torch/graph/modules.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
class AbstractTorchModule(nn.Module, ABC):
    """Base class for [torch.nn.Module][torch.nn.Module] objects that support folding.

    Both the circuit layers and the nodes of the computational graph of the layer
    parameters derive from this class.
    """

    def __init__(self, *, num_folds: int = 1):
        """Set up the abstract torch module.

        Args:
            num_folds: How many folds the module computes. Defaults to 1.
        """
        super().__init__()
        self._num_folds: int = num_folds

    @property
    def num_folds(self) -> int:
        """The number of folds computed by the module.

        Returns:
            The number of folds given at construction time.
        """
        return self._num_folds

    @property
    @abstractmethod
    def fold_settings(self) -> tuple[Any, ...]:
        """The attributes that modules must agree on in order to be folded together.

        Subclasses must implement this property.

        Returns:
            A tuple of attributes.
        """

    @property
    def sub_modules(self) -> Mapping[str, "AbstractTorchModule"]:
        """The torch sub-modules to pass to the ``__init__`` method of the top-level module.

        The base implementation declares no sub-modules and returns an empty dictionary;
        subclasses may override it.

        Returns:
            A dictionary mapping string identifiers to torch modules.
        """
        return {}

_num_folds = num_folds instance-attribute ¤

fold_settings abstractmethod property ¤

Retrieve a tuple of attributes on which modules must agree in order to be folded.

Returns:

Type Description
tuple[Any, ...]

A tuple of attributes.

num_folds property ¤

Retrieve the number of folds.

Returns:

Type Description
int

The number of folds.

sub_modules property ¤

Retrieve a dictionary mapping string identifiers to torch sub-modules, that must be passed to the __init__ method of the top-level torch module.

Returns:

Type Description
Mapping[str, AbstractTorchModule]

A dictionary of torch modules.

__init__(*, num_folds=1) ¤

Initialize the abstract torch module object.

Parameters:

Name Type Description Default
num_folds int

The number of folds computed by the module.

1
Source code in cirkit/backend/torch/graph/modules.py
18
19
20
21
22
23
24
25
def __init__(self, *, num_folds: int = 1):
    """Set up the abstract torch module.

    Args:
        num_folds: How many folds the module computes. Defaults to 1.
    """
    super().__init__()
    self._num_folds: int = num_folds

AddressBook ¤

Bases: Module, ABC

The address book data structure, sometimes also known as book-keeping. The address book stores a list of AddressBookEntry, where each entry stores the information needed to gather the inputs to each (possibly folded) torch module.

Source code in cirkit/backend/torch/graph/modules.py
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
class AddressBook(nn.Module, ABC):
    """The address book data structure, sometimes also known as book-keeping.

    The address book stores a list of
    [AddressBookEntry][cirkit.backend.torch.graph.modules.AddressBookEntry],
    where each entry stores the information needed to gather the inputs to each (possibly
    folded) torch module. The last entry is the output entry: it has no module and exactly
    one 1-dimensional fold index tensor, whose length is the number of outputs.
    """

    def __init__(self, entries: list[AddressBookEntry]) -> None:
        """Initializes an address book.

        Args:
            entries: The list of address book entries.

        Raises:
            ValueError: If the list of address book entries is empty.
            ValueError: If the last entry (i.e., the entry used to compute the output of
                the whole computational graph) has a torch module assigned to it, or
                if it does not have exactly one fold index tensor, or if that fold index
                tensor is None or is not a 1-dimensional tensor.
        """
        if not entries:
            raise ValueError("The list of address book entry must not be empty")
        last_entry = entries[-1]
        if last_entry.module is not None:
            raise ValueError(
                "The last entry of the address book must not have a module associated to it"
            )
        if len(last_entry.in_fold_idx) != 1:
            raise ValueError(
                "The last entry of the address book must have only one fold index tensor"
            )
        (out_fold_idx,) = last_entry.in_fold_idx
        # Fold index tensors of an entry may be None, but the output one is required:
        # its length determines the number of outputs of the computational graph.
        if out_fold_idx is None:
            raise ValueError("The output fold index tensor of the last entry must not be None")
        if len(out_fold_idx.shape) != 1:
            raise ValueError("The output fold index tensor should be a 1-dimensional tensor")
        super().__init__()
        self._entries = entries
        self._num_outputs = out_fold_idx.shape[0]

    def __len__(self) -> int:
        """Retrieve the length of the address book.

        Returns:
            The number of address book entries.
        """
        return len(self._entries)

    def __iter__(self) -> Iterator[AddressBookEntry]:
        """Retrieve an iterator over address book entries.

        Returns:
            An iterator over address book entries.
        """
        return iter(self._entries)

    @property
    def num_outputs(self) -> int:
        """The number of outputs of the whole computational graph represented
        through the address book.

        For instance, for a circuit with $n$ output layers, this will be equal to $n$.

        Returns:
            The number of outputs.
        """
        return self._num_outputs

    @abstractmethod
    def lookup(
        self, module_outputs: list[Tensor], *, in_graph: Tensor | None = None
    ) -> Iterator[tuple[TorchModule | None, tuple]]:
        """Retrieve an iterator that iteratively returns a torch module and the tensor inputs to it.

        Args:
            module_outputs: A list of the outputs of each torch module. This list is expected to
                be iteratively expanded as we continue evaluating the modules of the torch
                computational graph.
            in_graph: An optional tensor input to the whole computational graph. This is used
                as input to the torch modules that do not receive input from other torch
                modules within the torch computational graph.

        Returns:
            An iterator of tuples, where the first element is a torch module if we are
            retrieving the inputs to it, and None if we are retrieving the output of the
            whole computational graph (i.e., in the final step of the evaluation).
            The second element is instead a tuple of arguments that are input to the
            torch module (e.g., some input tensors).
        """

_entries = entries instance-attribute ¤

_num_outputs = out_fold_idx.shape[0] instance-attribute ¤

num_outputs property ¤

The number of outputs of the whole computational graph represented through the address book.

For instance, for a circuit with \(n\) output layers, this will be equal to \(n\).

Returns:

Type Description
int

The number of outputs.

__init__(entries) ¤

Initializes an address book.

Parameters:

Name Type Description Default
entries list[AddressBookEntry]

The list of address book entries.

required

Raises:

Type Description
ValueError

If the list of address book entries is empty.

ValueError

If the last entry (i.e., the entry used to compute the output of the whole computational graph) has a torch module assigned to it, or if it has more than one fold index tensor, or if the fold index tensor is not a 1-dimensional tensor.

Source code in cirkit/backend/torch/graph/modules.py
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
def __init__(self, entries: list[AddressBookEntry]) -> None:
    """Initializes an address book.

    Args:
        entries: The list of address book entries.

    Raises:
        ValueError: If the list of address book entries is empty.
        ValueError: If the last entry (i.e., the entry used to compute the output of
            the whole computational graph) has a torch module assigned to it, or
            if it does not have exactly one fold index tensor, or if that fold index
            tensor is None or is not a 1-dimensional tensor.
    """
    if not entries:
        raise ValueError("The list of address book entry must not be empty")
    last_entry = entries[-1]
    if last_entry.module is not None:
        raise ValueError(
            "The last entry of the address book must not have a module associated to it"
        )
    if len(last_entry.in_fold_idx) != 1:
        raise ValueError(
            "The last entry of the address book must have only one fold index tensor"
        )
    (out_fold_idx,) = last_entry.in_fold_idx
    # Fold index tensors of an entry may be None, but the output one is required:
    # its length determines the number of outputs of the computational graph.
    if out_fold_idx is None:
        raise ValueError("The output fold index tensor of the last entry must not be None")
    if len(out_fold_idx.shape) != 1:
        raise ValueError("The output fold index tensor should be a 1-dimensional tensor")
    super().__init__()
    self._entries = entries
    self._num_outputs = out_fold_idx.shape[0]

__iter__() ¤

Retrieve an iterator over address book entries.

Returns:

Type Description
Iterator[AddressBookEntry]

An iterator over address book entries.

Source code in cirkit/backend/torch/graph/modules.py
150
151
152
153
154
155
156
def __iter__(self) -> Iterator[AddressBookEntry]:
    """Iterate over the entries of the address book, in their stored order.

    Returns:
        An iterator over the address book entries.
    """
    entries = self._entries
    return iter(entries)

__len__() ¤

Retrieve the length of the address book.

Returns:

Type Description
int

The number of address book entries.

Source code in cirkit/backend/torch/graph/modules.py
142
143
144
145
146
147
148
def __len__(self) -> int:
    """Count the entries stored in the address book.

    Returns:
        How many address book entries there are.
    """
    entries = self._entries
    return len(entries)

lookup(module_outputs, *, in_graph=None) abstractmethod ¤

Retrieve an iterator that iteratively returns a torch module and the tensor inputs to it.

Parameters:

Name Type Description Default
module_outputs list[Tensor]

A list of the outputs of each torch module. This list is expected to be iteratively expanded as we continue evaluating the modules of the torch computational graph.

required
in_graph Tensor | None

An optional tensor input to the whole computational graph. This is used as input to the torch modules that do not receive input from other torch modules within the torch computational graph.

None

Returns:

Type Description
Iterator[tuple[TorchModule | None, tuple]]

An iterator of tuples, where the first element is a torch module if we are retrieving the inputs to it, and None if we are retrieving the output of the whole computational graph (i.e., in the final step of the evaluation). The second element is instead a tuple of arguments that are input to the torch module (e.g., some input tensors).

Source code in cirkit/backend/torch/graph/modules.py
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
@abstractmethod
def lookup(
    self, module_outputs: list[Tensor], *, in_graph: Tensor | None = None
) -> Iterator[tuple[TorchModule | None, tuple]]:
    """Iterate over torch modules together with the arguments to feed to each of them.

    Args:
        module_outputs: The outputs of the torch modules evaluated so far. The caller is
            expected to keep appending to this list while it evaluates the modules of the
            torch computational graph.
        in_graph: An optional tensor input to the whole computational graph. It is fed to
            the torch modules that do not receive inputs from other torch modules within
            the torch computational graph.

    Returns:
        An iterator of pairs. The first element of each pair is either the torch module
        whose inputs are being retrieved, or None when the output of the whole
        computational graph is being retrieved (i.e., in the final step of the
        evaluation). The second element is the tuple of arguments to pass to that torch
        module (e.g., some input tensors).
    """

AddressBookEntry dataclass ¤

An entry of the address book data structure, which stores (i) the module, unless it is an output entry (i.e., an entry used to compute the output of the whole computational graph), and, for each input module to it, (ii) the unique indices of other modules and (iii) the (optionally None) fold index tensor to apply in order to recover the input tensors to each fold.

Source code in cirkit/backend/torch/graph/modules.py
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
@dataclass(frozen=True)
class AddressBookEntry:
    """An entry of the address book data structure.

    An entry stores (i) the module it refers to, unless it is an output entry (i.e., an
    entry used to compute the output of the whole computational graph), and, for each
    input to that module, (ii) the unique indices of the other modules it reads from and
    (iii) the (optionally None) fold index tensor to apply in order to recover the input
    tensors to each fold.
    """

    module: TorchModule | None
    """The module this entry refers to, or None if the entry is used to compute the
    output of the whole computational graph."""
    in_module_ids: list[list[int]]
    """For each input to the module, the list of indices of other modules."""
    in_fold_idx: list[Tensor | None]
    """For each input to the module, the fold index tensor used to gather the input
    tensors to each fold. It is None when there is no need to gather the input tensors,
    i.e., when the indexing operation would act as an identity function."""

in_fold_idx instance-attribute ¤

For each input module, it stores the fold index tensor used to gather the input tensors to each fold. It is None when there is no need to gather the input tensors, i.e., when the indexing operation would act as an identity function.

in_module_ids instance-attribute ¤

For each input module, it stores the list of other module indices.

module instance-attribute ¤

The module the entry refers to. It can be None if the entry is then used to compute the output of the whole computational graph.

__init__(module, in_module_ids, in_fold_idx) ¤

FoldIndexInfo dataclass ¤

The folding index information of a folded computational graph, i.e., a directed acyclic graph of torch modules.

This data class stores (i) the topological ordering, (ii) the input fold index information for each torch module, and (iii) the output fold index information.

Source code in cirkit/backend/torch/graph/modules.py
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
@dataclass(frozen=True)
class FoldIndexInfo:
    """Folding index information of a folded computational graph.

    The computational graph is a
    [directed acyclic graph][cirkit.backend.torch.graph.modules.TorchDiAcyclicGraph]
    of [torch modules][cirkit.backend.torch.graph.modules.AbstractTorchModule].
    This data class bundles (i) the topological ordering of the modules, (ii) the input
    fold index information of every torch module, and (iii) the output fold index
    information.
    """

    ordering: list[TorchModule]
    """The torch modules, listed in topological order."""
    in_fold_idx: dict[int, list[list[tuple[int, int]]]]
    """The input fold index information. It maps each module index to a list with one
    item per output fold computed by the module (first list); each item is a list with
    one tuple per input to the module (second list, whose length is the arity), holding
    (1) the input module index and (2) the fold index within that input module."""
    out_fold_idx: list[tuple[int, int]]
    """The output fold index information. It holds one tuple per output (first list),
    made of (1) the output module index and (2) the fold index within that output
    module."""

in_fold_idx instance-attribute ¤

The input fold index information. For each module index, it stores for each output fold computed by the module (first list), and for each input to the module (second list whose length is the arity), a tuple of (1) the input module index and (2) the fold index within that input module.

ordering instance-attribute ¤

The topological ordering of torch modules.

out_fold_idx instance-attribute ¤

The output fold index information. For each output (first list), it stores a tuple of (1) the output module index and (2) the fold index within that output module.

__init__(ordering, in_fold_idx, out_fold_idx) ¤

ModuleEvalFunctional ¤

Bases: Protocol

The protocol of a function that evaluates a module on some inputs.

Source code in cirkit/backend/torch/graph/modules.py
193
194
195
196
197
198
199
200
201
202
203
204
205
class ModuleEvalFunctional(Protocol):  # pylint: disable=too-few-public-methods
    """Protocol of a callable that evaluates a torch module on some tensor inputs."""

    def __call__(self, module: TorchModule, *inputs: Tensor) -> Tensor:
        """Compute the output of a module for the given inputs.

        Args:
            module: The module to evaluate.
            *inputs: The tensor inputs to the module.

        Returns:
            The output of the module, as computed by this functional.
        """

__call__(module, *inputs) ¤

Evaluate a module on some inputs.

Parameters:

Name Type Description Default
module TorchModule

The module to evaluate.

required
inputs Tensor

The tensor inputs to the module

()

Returns:

Type Description
Tensor

The output of the module as specified by this functional.

Source code in cirkit/backend/torch/graph/modules.py
196
197
198
199
200
201
202
203
204
205
def __call__(self, module: TorchModule, *inputs: Tensor) -> Tensor:
    """Compute the output of a module for the given inputs.

    Args:
        module: The module to evaluate.
        *inputs: The tensor inputs to the module.

    Returns:
        The output of the module, as computed by this functional.
    """

TorchDiAcyclicGraph ¤

Bases: Module, DiAcyclicGraph[TorchModule], ABC

A torch directed acyclic graph module, i.e., a computational graph made of torch modules.

Source code in cirkit/backend/torch/graph/modules.py
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
class TorchDiAcyclicGraph(nn.Module, DiAcyclicGraph[TorchModule], ABC):
    """A directed acyclic computational graph whose nodes are torch modules."""

    def __init__(
        self,
        modules: Sequence[TorchModule],
        in_modules: dict[TorchModule, Sequence[TorchModule]],
        outputs: Sequence[TorchModule],
        *,
        fold_idx_info: FoldIndexInfo | None = None,
    ):
        """Build a torch computational graph.

        Args:
            modules: The module nodes of the graph.
            in_modules: A dictionary mapping each module to its input modules, if any.
            outputs: The modules that are the outputs of the computational graph.
            fold_idx_info: The folding index information, or None if the Torch graph is
                not folded. When None, the index information is obtained from
                ``_build_unfold_index_info``.
        """
        # Wrap the nodes in an nn.ModuleList, torch's container of sub-modules.
        module_list: list[TorchModule] = nn.ModuleList(modules)  # type: ignore
        # First initialize nn.Module itself, then the class that follows nn.Module in the
        # method resolution order (the graph base class) with the graph structure.
        super().__init__()
        super(nn.Module, self).__init__(module_list, in_modules, outputs)
        self._is_folded = fold_idx_info is not None
        index_info = self._build_unfold_index_info() if fold_idx_info is None else fold_idx_info
        self._address_book = self._build_address_book(index_info)

    @property
    def is_folded(self) -> bool:
        """Whether the computational graph was built from folding index information.

        Returns:
            True if the graph is folded, False otherwise.
        """
        return self._is_folded

    @property
    def address_book(self) -> AddressBook:
        """The address book used to route tensors between the modules of the graph.

        Returns:
            The address book.
        """
        return self._address_book

    def subgraph(self, *roots: TorchModule) -> "TorchDiAcyclicGraph[TorchModule]":
        """Extract the sub-graph whose output modules are the given roots.

        Only computational graphs that are not folded support this operation.

        Args:
            *roots: The torch modules to use as outputs of the sub-graph.

        Returns:
            A new torch computational graph, of the same class, having the given roots as
            its output torch modules.

        Raises:
            ValueError: If the computational graph is folded.
        """
        if self.is_folded:
            raise ValueError("Cannot extract a sub-computational graph from a folded one")
        sub_nodes, sub_in_nodes = subgraph(roots, self.node_inputs)
        return self.__class__(sub_nodes, sub_in_nodes, outputs=roots)

    def evaluate(
        self, x: Tensor | None = None, module_fn: ModuleEvalFunctional | None = None
    ) -> Tensor:
        """Evaluate the Torch graph.

        The modules are evaluated following the topological ordering, and the inputs to
        each module are retrieved through the address book.

        Args:
            x: The input of the Torch computational graph. It can be None.
            module_fn: A functional over modules that overrides the forward method defined
                by a module. It can be None, in which case the ``__call__`` method defined
                by the module itself is used.

        Returns:
            The output tensor of the Torch graph. If the Torch graph has multiple outputs,
            then they will be stacked.

        Raises:
            RuntimeError: If the address book is somehow not well-formed, i.e., if its
                lookup ends without yielding the output entry.
        """
        # The address book reads the outputs computed so far from this list, so each
        # module output is appended to it before the next lookup step.
        computed: list[Tensor] = []
        for module, args in self._address_book.lookup(computed, in_graph=x):
            if module is None:
                # The output entry carries exactly one argument: the graph output.
                (graph_output,) = args
                return graph_output
            out = module(*args) if module_fn is None else module_fn(module, *args)
            computed.append(out)
        raise RuntimeError("The address book is malformed")

    @abstractmethod
    def _build_unfold_index_info(self) -> FoldIndexInfo:
        """Build the index information used when no folding information is given."""

    @abstractmethod
    def _build_address_book(self, fold_idx_info: FoldIndexInfo) -> AddressBook:
        """Build the address book of the graph from its folding index information."""

    def __repr__(self) -> str:
        def indent(text: str) -> str:
            # Prefix every line except the first one with two spaces.
            return "\n  ".join(text.split("\n"))

        lines = [f"{self.__class__.__name__}("]
        extra = self.extra_repr()
        if extra:
            lines.append("  " + indent(extra))
        # The output entry has no module, hence it is not listed.
        lines.extend(
            f"  ({i}): {indent(repr(entry.module))}"
            for i, entry in enumerate(self._address_book)
            if entry.module is not None
        )
        lines.append(")")
        return "\n".join(lines)

_address_book = self._build_address_book(fold_idx_info) instance-attribute ¤

_is_folded = fold_idx_info is not None instance-attribute ¤

address_book property ¤

Retrieve the address book object of the computational graph.

Returns:

Type Description
AddressBook

The address book.

is_folded property ¤

Retrieves whether the computational graph is folded or not.

Returns:

Type Description
bool

True if it is folded, False otherwise.

__init__(modules, in_modules, outputs, *, fold_idx_info=None) ¤

Initialize a torch computational graph.

Parameters:

Name Type Description Default
modules Sequence[TorchModule]

The module nodes.

required
in_modules dict[TorchModule, Sequence[TorchModule]]

A dictionary mapping modules to their input modules, if any.

required
outputs Sequence[TorchModule]

A list of modules that are the output modules in the computational graph.

required
fold_idx_info FoldIndexInfo | None

The folding index information. It can be None if the Torch graph is not folded.

None
Source code in cirkit/backend/torch/graph/modules.py
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
def __init__(
    self,
    modules: Sequence[TorchModule],
    in_modules: dict[TorchModule, Sequence[TorchModule]],
    outputs: Sequence[TorchModule],
    *,
    fold_idx_info: FoldIndexInfo | None = None,
):
    """Build a torch computational graph.

    Args:
        modules: The module nodes of the graph.
        in_modules: A dictionary mapping each module to its input modules, if any.
        outputs: The modules that are the outputs of the computational graph.
        fold_idx_info: The folding index information, or None if the Torch graph is
            not folded. When None, the index information is obtained from
            ``_build_unfold_index_info``.
    """
    # Wrap the nodes in an nn.ModuleList, torch's container of sub-modules.
    module_list: list[TorchModule] = nn.ModuleList(modules)  # type: ignore
    # First initialize nn.Module itself, then the class that follows nn.Module in the
    # method resolution order (the graph base class) with the graph structure.
    super().__init__()
    super(nn.Module, self).__init__(module_list, in_modules, outputs)
    self._is_folded = fold_idx_info is not None
    index_info = self._build_unfold_index_info() if fold_idx_info is None else fold_idx_info
    self._address_book = self._build_address_book(index_info)

__repr__() ¤

Source code in cirkit/backend/torch/graph/modules.py
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
def __repr__(self) -> str:
    def indent(text: str) -> str:
        # Prefix every line except the first one with two spaces.
        return "\n  ".join(text.split("\n"))

    lines = [f"{self.__class__.__name__}("]
    extra = self.extra_repr()
    if extra:
        lines.append("  " + indent(extra))
    # The output entry has no module, hence it is not listed.
    lines.extend(
        f"  ({i}): {indent(repr(entry.module))}"
        for i, entry in enumerate(self._address_book)
        if entry.module is not None
    )
    lines.append(")")
    return "\n".join(lines)

_build_address_book(fold_idx_info) abstractmethod ¤

Source code in cirkit/backend/torch/graph/modules.py
310
311
312
@abstractmethod
def _build_address_book(self, fold_idx_info: FoldIndexInfo) -> AddressBook:
    """Build the address book of the graph from its folding index information."""

_build_unfold_index_info() abstractmethod ¤

Source code in cirkit/backend/torch/graph/modules.py
306
307
308
@abstractmethod
def _build_unfold_index_info(self) -> FoldIndexInfo:
    """Build the index information used when no folding information is given."""

evaluate(x=None, module_fn=None) ¤

Evaluate the Torch graph by following the topological ordering, and by using the address book information to retrieve the inputs to each module.

Parameters:

Name Type Description Default
x Tensor | None

The input of the Torch computational graph. It can be None.

None
module_fn ModuleEvalFunctional | None

A functional over modules that overrides the forward method defined by a module. It can be None. If it is None, then the __call__ method defined by the module itself is used.

None

Returns:

Type Description
Tensor

The output tensor of the Torch graph. If the Torch graph has multiple outputs, then they will be stacked.

Raises:

Type Description
RuntimeError

If the address book is somehow not well-formed.

Source code in cirkit/backend/torch/graph/modules.py
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
def evaluate(
    self, x: Tensor | None = None, module_fn: ModuleEvalFunctional | None = None
) -> Tensor:
    """Evaluate the Torch graph.

    The modules are evaluated following the topological ordering, and the inputs to
    each module are retrieved through the address book.

    Args:
        x: The input of the Torch computational graph. It can be None.
        module_fn: A functional over modules that overrides the forward method defined
            by a module. It can be None, in which case the ``__call__`` method defined
            by the module itself is used.

    Returns:
        The output tensor of the Torch graph. If the Torch graph has multiple outputs,
        then they will be stacked.

    Raises:
        RuntimeError: If the address book is somehow not well-formed, i.e., if its
            lookup ends without yielding the output entry.
    """
    # The address book reads the outputs computed so far from this list, so each
    # module output is appended to it before the next lookup step.
    computed: list[Tensor] = []
    for module, args in self._address_book.lookup(computed, in_graph=x):
        if module is None:
            # The output entry carries exactly one argument: the graph output.
            (graph_output,) = args
            return graph_output
        out = module(*args) if module_fn is None else module_fn(module, *args)
        computed.append(out)
    raise RuntimeError("The address book is malformed")

subgraph(*roots) ¤

Assuming the computational graph is not a folded one, this returns the sub-graph having the given root torch modules as output modules.

Parameters:

Name Type Description Default
*roots TorchModule

The root torch modules of the sub-graph to return.

()

Returns:

Type Description
TorchDiAcyclicGraph[TorchModule]

A new torch computational graph having the given roots as the output torch modules.

Raises:

Type Description
ValueError

If the computational graph is folded.

Source code in cirkit/backend/torch/graph/modules.py
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
def subgraph(self, *roots: TorchModule) -> "TorchDiAcyclicGraph[TorchModule]":
    """Extract the sub-graph whose output modules are the given roots.

    Only computational graphs that are not folded support this operation.

    Args:
        *roots: The torch modules to use as outputs of the sub-graph.

    Returns:
        A new torch computational graph, of the same class, having the given roots as
        its output torch modules.

    Raises:
        ValueError: If the computational graph is folded.
    """
    if self.is_folded:
        raise ValueError("Cannot extract a sub-computational graph from a folded one")
    sub_nodes, sub_in_nodes = subgraph(roots, self.node_inputs)
    return self.__class__(sub_nodes, sub_in_nodes, outputs=roots)