Skip to content

optimize

optimize ¤

GraphOptPattern = type[GraphOptPatternDefn[TorchModuleT]] module-attribute ¤

GraphOptMatch ¤

Bases: Generic[TorchModuleT]

Class storing data related to a single match:

  • pattern (GraphOptPattern[TorchModuleT]): the pattern of the match.
  • entries (Sequence[TorchModuleT]): Modules, from the graph being searched, matching the entries types of the pattern.
  • sub_entries (Sequence[Mapping[str, Sequence[GraphOptMatch]]]): Modules corresponding to the sub_pattern method of the pattern.
Source code in cirkit/backend/torch/graph/optimize.py
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
class GraphOptMatch(Generic[TorchModuleT]):
    """A single occurrence of a pattern found in a graph.

    The data is exposed through read-only properties:

    - pattern (GraphOptPattern[TorchModuleT]): the pattern that was matched.
    - entries (Sequence[TorchModuleT]): modules of the searched graph that
        correspond, position by position, to the `entries` types of the pattern.
    - sub_entries (Sequence[Mapping[str, Sequence[GraphOptMatch]]]): matches
        corresponding to the `sub_patterns` method of the pattern.
    """

    def __init__(
        self,
        pattern: GraphOptPattern[TorchModuleT],
        entries: Sequence[TorchModuleT],
        sub_entries: Sequence[Mapping[str, Sequence["GraphOptMatch"]]] | None = None,
    ):
        # A missing sub_entries argument is normalized to an empty tuple so that
        # consumers can always iterate over it without a None check.
        self._sub_entries = () if sub_entries is None else sub_entries
        self._entries = entries
        self._pattern = pattern

    @property
    def pattern(self) -> GraphOptPattern[TorchModuleT]:
        """The pattern this match is an instance of."""
        return self._pattern

    @property
    def entries(self) -> Sequence[TorchModuleT]:
        """The graph modules matching the pattern's entry types."""
        return self._entries

    @property
    def sub_entries(self) -> Sequence[Mapping[str, Sequence["GraphOptMatch"]]]:
        """The nested matches attached to the entries (empty if none were given)."""
        return self._sub_entries

    @cached_property
    def size(self) -> int:
        """Total number of modules covered by this match.

        The count is the number of top-level entries plus, recursively, the size
        of every match stored in ``sub_entries``. It is computed once and then
        cached on the instance.

        Returns:
            int: Number of entries and sub_entries
        """
        nested = sum(
            sub_match.size
            for mapping in self.sub_entries
            for sub_matches in mapping.values()
            for sub_match in sub_matches
        )
        return len(self._entries) + nested

entries property ¤

pattern property ¤

size cached property ¤

Count the number of entries and sub_entries in the match

Returns:

Name Type Description
int int

Number of entries and sub_entries

sub_entries property ¤

__init__(pattern, entries, sub_entries=None) ¤

Source code in cirkit/backend/torch/graph/optimize.py
131
132
133
134
135
136
137
138
139
def __init__(
    self,
    pattern: GraphOptPattern[TorchModuleT],
    entries: Sequence[TorchModuleT],
    sub_entries: Sequence[Mapping[str, Sequence["GraphOptMatch"]]] | None = None,
):
    """Store the pattern, the matched entries and the optional nested matches.

    A missing ``sub_entries`` argument is stored as an empty tuple.
    """
    self._sub_entries = () if sub_entries is None else sub_entries
    self._entries = entries
    self._pattern = pattern

GraphOptPatternDefn ¤

Bases: Generic[TorchModuleT]

Class defining a pattern in a graph.

Source code in cirkit/backend/torch/graph/optimize.py
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
class GraphOptPatternDefn(Generic[TorchModuleT]):
    """Class defining a pattern in a graph.

    A pattern is described entirely through class methods: `entries` must be
    implemented by every subclass, while `is_output`, `config_patterns` and
    `sub_patterns` have default implementations that may be overridden.
    """

    @classmethod
    def is_output(cls) -> bool:
        """Define if the pattern should be searched in the graph's outputs only.

        Returns:
            bool: if True, search only in the outputs.
        """
        return False

    @classmethod
    def entries(cls) -> Sequence[type[TorchModuleT]]:
        """Returns an ordered sequence of module types that need to be matched, in
        the order of the sequence, when going through the graph in reverse topological order.

        For example: the entry `[LayerType3, LayerType2]` will match the graph:

        `LayerType1 -> LayerType2 -> LayerType3 -> LayerType4`

        The match will be the graph rooted at `LayerType3`.

        Raises:
            NotImplementedError: This method needs to be implemented by every pattern.

        Returns:
            Sequence[type[TorchModuleT]]: the sequence in reverse topological order.
        """
        raise NotImplementedError

    @classmethod
    def config_patterns(cls) -> Sequence[Mapping[str, Any]]:
        """Returns a list of dictionaries that map a config name to a config value.

        The “config” of a layer / parameter node is simply the dictionary returned by `Layer.config`.

        The dictionary at position x in the list defines the config for the x-th element
        of the `entries` list. The default implementation returns an empty sequence.

        For example, a sum layer with config:

        ```python
        {
        "num_input_units":2,
        "num_output_units":1,
        "arity":1
        }
        ```

        will match the pattern:

        ```python
        class ExamplePattern(LayerOptPatternDefn):
            @classmethod
            def config_patterns(cls):
                return [{"arity": 1}]

            @classmethod
            def entries(cls):
                return [TorchSumLayer]
        ```


        Returns:
            Sequence[Mapping[str, Any]]: List of config name -> config value mappings
        """
        return ()

    @classmethod
    def sub_patterns(cls) -> Sequence[Mapping[str, type["GraphOptPatternDefn"]]]:
        """Returns a list of dictionaries that map layer's parameter names to a `ParameterOptPattern`.

        The dictionary at position x in the list defines the sub-patterns for the x-th
        element of the `entries` list. The default implementation returns an empty sequence.

        For example, you can match the weight parameter of a sum layer to be of a certain `ParameterType`:

        ```python
        class LayerPatternOne(LayerOptPatternDefn):
            @classmethod
            def entries(cls) -> Sequence[type[TorchLayer]]:
                return [TorchSumLayer]

            @classmethod
            def sub_patterns(cls) -> Sequence[dict[str, ParameterOptPattern]]:
                return [{"weight": ParameterPatternOne}]

        class ParameterPatternOne(ParameterOptPatternDefn):
            @classmethod
            def entries(cls) -> list[type[TorchParameterNode]]:
                return [ParameterType]
        ```

        `LayerPatternOne` will match the following layer:

        ```python
        TorchSumLayer(1, 1, 1, weight=ParameterType(...))
        ```

        Returns:
            Sequence[Mapping[str, type[GraphOptPatternDefn]]]: List of dictionaries that
                map layer's parameter names to a `ParameterOptPattern`
        """
        return ()

config_patterns() classmethod ¤

Returns a list of dictionaries that match a config name to a config value.

The “config” of a layer / parameter node is simply the dictionary returned by Layer.config.

The dictionary at position x in the list define the config for the x-th element of the entries list.

For example, the sum layer with config:

{
"num_input_units":2,
"num_output_units":1,
"arity":1
}

Will match the pattern:

class ExamplePattern(LayerOptPatternDefn):
    @classmethod
    def config_patterns(cls):
        return [{"arity":1}]
    @classmethod
    def entries(cls):
        return [TorchSumLayer]

Returns:

Type Description
Sequence[Mapping[str, Any]]

Sequence[Mapping[str, Any]]: List of config name -> config value mappings

Source code in cirkit/backend/torch/graph/optimize.py
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
@classmethod
def config_patterns(cls) -> Sequence[Mapping[str, Any]]:
    """Returns a list of dictionaries that map a config name to a config value.

    The “config” of a layer / parameter node is simply the dictionary returned by `Layer.config`.

    The dictionary at position x in the list defines the config for the x-th element
    of the `entries` list. The default implementation returns an empty sequence.

    For example, a sum layer with config:

    ```python
    {
    "num_input_units":2,
    "num_output_units":1,
    "arity":1
    }
    ```

    will match the pattern:

    ```python
    class ExamplePattern(LayerOptPatternDefn):
        @classmethod
        def config_patterns(cls):
            return [{"arity": 1}]

        @classmethod
        def entries(cls):
            return [TorchSumLayer]
    ```


    Returns:
        Sequence[Mapping[str, Any]]: List of config name -> config value mappings
    """
    return ()

entries() classmethod ¤

Returns an ordered sequence of module types that need to be matched, in the order of the sequence, when going through the graph in reverse topological order.

For example: the entry [LayerType3, LayerType2] will match the graph:

LayerType1 -> LayerType2 -> LayerType3 -> LayerType4

The match will be the graph rooted at LayerType3.

Raises:

Type Description
NotImplementedError

This method needs to be implemented by every pattern.

Returns:

Type Description
Sequence[type[TorchModuleT]]

Sequence[type[TorchModuleT]]: the sequence in reverse topological order.

Source code in cirkit/backend/torch/graph/optimize.py
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
@classmethod
def entries(cls) -> Sequence[type[TorchModuleT]]:
    """Returns an ordered sequence of module types that need to be matched, in
    the order of the sequence, when going through the graph in reverse topological order.

    For example: the entry `[LayerType3, LayerType2]` will match the graph:

    `LayerType1 -> LayerType2 -> LayerType3 -> LayerType4`

    The match will be the graph rooted at `LayerType3`.

    Raises:
        NotImplementedError: This method needs to be implemented by every pattern.

    Returns:
        Sequence[type[TorchModuleT]]: the sequence in reverse topological order.
    """
    raise NotImplementedError

is_output() classmethod ¤

Define if the pattern should be searched in the graph's outputs only.

Returns:

Name Type Description
bool bool

if True, search only in the outputs.

Source code in cirkit/backend/torch/graph/optimize.py
19
20
21
22
23
24
25
26
@classmethod
def is_output(cls) -> bool:
    """Tell whether the pattern is searched among the graph's outputs only.

    The base implementation opts out, so the pattern is searched over every
    module of the graph unless a subclass overrides this method.

    Returns:
        bool: if True, search only in the outputs.
    """
    outputs_only = False
    return outputs_only

sub_patterns() classmethod ¤

Returns a list of dictionaries that map layer's parameter names to a ParameterOptPattern.

The dictionary at position x in the list defines the sub-patterns for the x-th element of the entries list.

For example, you can match the weight parameter of a sum layer to be of a certain ParameterType:

class LayerPatternOne(LayerOptPatternDefn):
    @classmethod
    def entries(cls) -> Sequence[type[TorchLayer]]:
        return [TorchSumLayer]

    @classmethod
    def sub_patterns(cls) -> Sequence[dict[str, ParameterOptPattern]]:
        return [{"weight": ParameterPatternOne}]

class ParameterPatternOne(ParameterOptPatternDefn):
    @classmethod
    def entries(cls) -> list[type[TorchParameterNode]]:
        return [ParameterType]

LayerPatternOne will match the following layer:

TorchSumLayer(1, 1, 1, weight=ParameterType(...))

Returns:

Type Description
Sequence[Mapping[str, type[GraphOptPatternDefn]]]

Sequence[Mapping[str, type[GraphOptPatternDefn]]]: List of dictionaries that map layer's parameter names to a ParameterOptPattern

Source code in cirkit/backend/torch/graph/optimize.py
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
@classmethod
def sub_patterns(cls) -> Sequence[Mapping[str, type["GraphOptPatternDefn"]]]:
    """Returns a list of dictionaries that map layer's parameter names to a `ParameterOptPattern`.

    The dictionary at position x in the list defines the sub-patterns for the x-th
    element of the `entries` list. The default implementation returns an empty sequence.

    For example, you can match the weight parameter of a sum layer to be of a certain `ParameterType`:

    ```python
    class LayerPatternOne(LayerOptPatternDefn):
        @classmethod
        def entries(cls) -> Sequence[type[TorchLayer]]:
            return [TorchSumLayer]

        @classmethod
        def sub_patterns(cls) -> Sequence[dict[str, ParameterOptPattern]]:
            return [{"weight": ParameterPatternOne}]

    class ParameterPatternOne(ParameterOptPatternDefn):
        @classmethod
        def entries(cls) -> list[type[TorchParameterNode]]:
            return [ParameterType]
    ```

    `LayerPatternOne` will match the following layer:

    ```python
    TorchSumLayer(1, 1, 1, weight=ParameterType(...))
    ```

    Returns:
        Sequence[Mapping[str, type[GraphOptPatternDefn]]]: List of dictionaries that
            map layer's parameter names to a `ParameterOptPattern`
    """
    return ()

MatchOptimizerFunc ¤

Bases: Protocol[TorchModuleT]

Defines the signature of a valid match optimizer.

Match optimizers are functions that take a match object and return the tuple of modules that should be put in its place to optimize the graph.

Source code in cirkit/backend/torch/graph/optimize.py
187
188
189
190
191
192
193
194
195
196
197
198
class MatchOptimizerFunc(Protocol[TorchModuleT]):
    """Defines the signature of a valid match optimizer.

    Match optimizers are functions that take a match object and return
    the tuple of modules that should be put in its place to optimize
    the graph.
    """

    def __call__(
        self,
        match: GraphOptMatch[TorchModuleT],
    ) -> tuple[TorchModuleT, ...]:
        """Build the modules replacing ``match`` in the optimized graph.

        When used by `optimize_graph`, the returned modules are wired as a chain:
        the first one receives the inputs of the match, each following module
        receives the previous one, and the last one replaces the match's output.
        The tuple is therefore expected to be non-empty.

        Args:
            match: The match to optimize.

        Returns:
            tuple[TorchModuleT, ...]: The modules replacing the matched ones.
        """
        ...

__call__(match) ¤

Source code in cirkit/backend/torch/graph/optimize.py
195
196
197
198
def __call__(
    self,
    match: GraphOptMatch[TorchModuleT],
) -> tuple[TorchModuleT, ...]:
    """Return the modules that replace ``match`` in the optimized graph."""
    ...

OptMatchStrategy ¤

Bases: IntEnum

Strategy used to sort the matches and determine the one to keep

Source code in cirkit/backend/torch/graph/optimize.py
10
11
12
13
class OptMatchStrategy(IntEnum):
    """Strategy used to sort the matches and determine the one to keep"""

    # Explicit value: identical to what `auto()` yields for the first IntEnum member.
    # NOTE(review): by its name this favors larger matches; the ranking itself is
    # implemented in the prioritization helper, which is not visible here.
    LARGEST_MATCH = 1

LARGEST_MATCH = auto() class-attribute instance-attribute ¤

PatternMatcherFunc ¤

Bases: Protocol[TorchModuleT]

Defines the signature of a valid pattern matching function.

Pattern matching functions attempt to match a pattern in a graph, using a given module as the root of the match.

They return either a match object or None if the match fails.

Source code in cirkit/backend/torch/graph/optimize.py
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
class PatternMatcherFunc(Protocol[TorchModuleT]):
    """Defines the signature of a valid pattern matching function.

    Pattern matching functions attempt to match a pattern in a graph,
    using a given module as the root of the match.

    They return either a match object, or None if the match fails.
    """

    def __call__(
        self,
        module: TorchModuleT,
        pattern: GraphOptPattern[TorchModuleT],
        /,
        *,
        incomings_fn: Callable[[TorchModuleT], Sequence[TorchModuleT]],
        outcomings_fn: Callable[[TorchModuleT], Sequence[TorchModuleT]],
    ) -> GraphOptMatch[TorchModuleT] | None:
        """Try to match ``pattern`` with ``module`` as its root.

        Args:
            module: The candidate root module (positional-only).
            pattern: The pattern to match (positional-only).
            incomings_fn: Function returning the inputs of a module.
            outcomings_fn: Function returning the outputs of a module.

        Returns:
            GraphOptMatch[TorchModuleT] | None: The match, or None if it fails.
        """
        ...

__call__(module, pattern, /, *, incomings_fn, outcomings_fn) ¤

Source code in cirkit/backend/torch/graph/optimize.py
176
177
178
179
180
181
182
183
184
def __call__(
    self,
    module: TorchModuleT,
    pattern: GraphOptPattern[TorchModuleT],
    /,
    *,
    incomings_fn: Callable[[TorchModuleT], Sequence[TorchModuleT]],
    outcomings_fn: Callable[[TorchModuleT], Sequence[TorchModuleT]],
) -> GraphOptMatch[TorchModuleT] | None:
    """Attempt a match of ``pattern`` rooted at ``module``; None on failure."""
    ...

match_optimization_patterns(ordering, outputs, patterns, *, incomings_fn, outcomings_fn, pattern_matcher_fn, strategy=OptMatchStrategy.LARGEST_MATCH) ¤

Find and filter sections of a graph matching patterns. The function works as follows: (1) use pattern_matcher_fn to retrieve all matches for all patterns; (2) filter the matches according to the strategy.

Parameters:

Name Type Description Default
ordering Iterable[TorchModuleT]

Torch modules in the graph to optimize.

required
outputs Iterable[TorchModuleT]

Torch modules acting as the graph's outputs.

required
patterns Iterable[GraphOptPattern[TorchModuleT]]

Iterable of patterns to search for.

required
incomings_fn Callable[[TorchModuleT], Sequence[TorchModuleT]]

Function that returns the inputs of the given module.

required
outcomings_fn Callable[[TorchModuleT], Sequence[TorchModuleT]]

Function that returns the outputs of the given module.

required
pattern_matcher_fn PatternMatcherFunc[TorchModuleT]

Function that tries to match a pattern using the given node as the root.

required
strategy OptMatchStrategy

Optimization strategy used to decide which match should be applied. Defaults to OptMatchStrategy.LARGEST_MATCH.

LARGEST_MATCH

Returns:

Type Description
tuple[list[GraphOptMatch[TorchModuleT]], dict[TorchModuleT, GraphOptMatch[TorchModuleT]]]

tuple[list[GraphOptMatch[TorchModuleT]], dict[TorchModuleT, GraphOptMatch[TorchModuleT]]]: - List of all the matches after priority-based filtering. - Mapping from each module belonging to a retained match to that match.

Source code in cirkit/backend/torch/graph/optimize.py
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
def match_optimization_patterns(
    ordering: Iterable[TorchModuleT],
    outputs: Iterable[TorchModuleT],
    patterns: Iterable[GraphOptPattern[TorchModuleT]],
    *,
    incomings_fn: Callable[[TorchModuleT], Sequence[TorchModuleT]],
    outcomings_fn: Callable[[TorchModuleT], Sequence[TorchModuleT]],
    pattern_matcher_fn: PatternMatcherFunc[TorchModuleT],
    strategy: OptMatchStrategy = OptMatchStrategy.LARGEST_MATCH,
) -> tuple[list[GraphOptMatch[TorchModuleT]], dict[TorchModuleT, GraphOptMatch[TorchModuleT]]]:
    """Find and filter sections of a graph matching patterns.

    The function works as follows:

    1. Use `pattern_matcher_fn` to retrieve all matches for all patterns.
    2. Filter the matches according to the `strategy`.

    Args:
        ordering (Iterable[TorchModuleT]): Torch modules in the graph to optimize.
        outputs (Iterable[TorchModuleT]): Torch modules acting as the graph's outputs.
        patterns (Iterable[GraphOptPattern[TorchModuleT]]): Iterable of patterns to search for.
        incomings_fn (Callable[[TorchModuleT], Sequence[TorchModuleT]]):
            Function that returns the inputs of the given module.
        outcomings_fn (Callable[[TorchModuleT], Sequence[TorchModuleT]]):
            Function that returns the outputs of the given module.
        pattern_matcher_fn (PatternMatcherFunc[TorchModuleT]):
            Function that tries to match a pattern using the given node as the root.
        strategy (OptMatchStrategy, optional): Optimization strategy used to decide
            which match should be applied. Defaults to OptMatchStrategy.LARGEST_MATCH.

    Returns:
        tuple[list[GraphOptMatch[TorchModuleT]], dict[TorchModuleT, GraphOptMatch[TorchModuleT]]]:
            - List of all the matches after priority-based filtering, without duplicates,
              in a deterministic order (first appearance among the mapping's values).
            - Mapping from each module belonging to a retained match to that match.
    """
    # Materialize one-shot iterators: these sequences may be traversed several times
    # (once per pattern, then again during prioritization).
    ordering = list(ordering) if isinstance(ordering, Iterator) else ordering
    outputs = list(outputs) if isinstance(outputs, Iterator) else outputs

    # A map from modules to the list of found matches they belong to
    module_matches: dict[TorchModuleT, list[GraphOptMatch[TorchModuleT]]] = defaultdict(list)

    # For each given pattern, match it on the graph
    for pattern in patterns:
        # Output-only patterns are searched among the outputs; all others over the whole graph
        modules = outputs if pattern.is_output() else ordering
        for match in _match_pattern_graph(
            modules,
            pattern,
            incomings_fn=incomings_fn,
            outcomings_fn=outcomings_fn,
            pattern_matcher_fn=pattern_matcher_fn,
        ):
            # For each module found in a match, update the map from modules to found matches
            for matched_module in match.entries:
                module_matches[matched_module].append(match)

    # Prioritize the matched patterns
    prioritized_module_matches = _prioritize_optimization_strategy(
        ordering, module_matches, strategy=strategy, in_place=True
    )

    # Extract all the matches that are still active, without duplicates.
    # GraphOptMatch hashes by identity, so deduplicating through a set would yield an
    # order that changes from run to run; dict.fromkeys keeps first-appearance order,
    # which keeps downstream calls (e.g. the match optimizer) reproducible.
    prioritized_matches = list(dict.fromkeys(prioritized_module_matches.values()))

    return prioritized_matches, prioritized_module_matches

optimize_graph(ordering, outputs, patterns, *, incomings_fn, outcomings_fn, pattern_matcher_fn, match_optimizer_fn, strategy=OptMatchStrategy.LARGEST_MATCH) ¤

Search and apply the optimization patterns on the graph.

Parameters:

Name Type Description Default
ordering Iterable[TorchModuleT]

Torch modules in the graph to optimize.

required
outputs Iterable[TorchModuleT]

Torch modules acting as the graph's outputs.

required
patterns Iterable[GraphOptPattern[TorchModuleT]]

Iterable of patterns to search for.

required
incomings_fn Callable[[TorchModuleT], Sequence[TorchModuleT]]

Function that returns the inputs of the given module.

required
outcomings_fn Callable[[TorchModuleT], Sequence[TorchModuleT]]

Function that returns the outputs of the given module.

required
pattern_matcher_fn PatternMatcherFunc[TorchModuleT]

Function that tries to match a pattern using the given node as the root.

required
match_optimizer_fn MatchOptimizerFunc[TorchModuleT]

Function that takes a match as parameter and returns the tuple of modules to replace it.

required
strategy OptMatchStrategy

Optimization strategy used to decide which match should be applied. Defaults to OptMatchStrategy.LARGEST_MATCH.

LARGEST_MATCH

Returns:

Type Description
tuple[list[TorchModuleT], dict[TorchModuleT, list[TorchModuleT]], list[TorchModuleT]] | None

tuple[list[TorchModuleT],dict[TorchModuleT, list[TorchModuleT]],list[TorchModuleT]] | None: - The list of all modules in the optimized graph. - The adjacency dictionary of the graph. - The list of modules from the graph that are outputs. If no optimization is applied, the function simply returns None.

Source code in cirkit/backend/torch/graph/optimize.py
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
def optimize_graph(
    ordering: Iterable[TorchModuleT],
    outputs: Iterable[TorchModuleT],
    patterns: Iterable[GraphOptPattern[TorchModuleT]],
    *,
    incomings_fn: Callable[[TorchModuleT], Sequence[TorchModuleT]],
    outcomings_fn: Callable[[TorchModuleT], Sequence[TorchModuleT]],
    pattern_matcher_fn: PatternMatcherFunc[TorchModuleT],
    match_optimizer_fn: MatchOptimizerFunc[TorchModuleT],
    strategy: OptMatchStrategy = OptMatchStrategy.LARGEST_MATCH,
) -> (
    tuple[
        list[TorchModuleT],
        dict[TorchModuleT, list[TorchModuleT]],
        list[TorchModuleT],
    ]
    | None
):
    """Search and apply the optimization patterns on the graph.

    The optimized graph is rebuilt by following `ordering`. Modules that do not
    belong to any retained match are kept as they are. Each retained match is
    replaced, at the position of its root module, by the modules returned by
    `match_optimizer_fn`, chained one after the other.

    Args:
        ordering (Iterable[TorchModuleT]): Torch modules in the graph to optimize.
        outputs (Iterable[TorchModuleT]): Torch modules acting as the graph's outputs.
        patterns (Iterable[GraphOptPattern[TorchModuleT]]): Iterable of patterns to search for.
        incomings_fn (Callable[[TorchModuleT], Sequence[TorchModuleT]]):
            Function that returns the inputs of the given module.
        outcomings_fn (Callable[[TorchModuleT], Sequence[TorchModuleT]]):
            Function that returns the outputs of the given module.
        pattern_matcher_fn (PatternMatcherFunc[TorchModuleT]):
            Function that tries to match a pattern using the given node as the root.
        match_optimizer_fn (MatchOptimizerFunc[TorchModuleT]):
            Function that takes a match as parameter and returns the tuple of modules to replace it.
        strategy (OptMatchStrategy, optional): Optimization strategy used to decide
            which match should be applied. Defaults to OptMatchStrategy.LARGEST_MATCH.

    Returns:
        tuple[list[TorchModuleT],dict[TorchModuleT, list[TorchModuleT]],list[TorchModuleT]] | None:
            - The list of all modules in the optimized graph.
            - The adjacency dictionary of the graph.
            - The list of modules from the graph that are outputs.
            If no optimization is applied, the function simply returns None.

    Raises:
        ValueError: If `match_optimizer_fn` returns no module for a match.
    """
    # TODO: generalize this as to cover patterns with multiple entry or exit points?

    # Materialize one-shot iterators: both are traversed more than once below.
    ordering = list(ordering) if isinstance(ordering, Iterator) else ordering
    outputs = list(outputs) if isinstance(outputs, Iterator) else outputs

    # Match optimization patterns
    # matches: list of all matched and grounded optimization rules
    # module_matches: a map from modules to the matches they belong to, if any
    matches, module_matches = match_optimization_patterns(
        ordering,
        outputs,
        patterns,
        incomings_fn=incomings_fn,
        outcomings_fn=outcomings_fn,
        pattern_matcher_fn=pattern_matcher_fn,
        strategy=strategy,
    )

    # Check if no matches have been found. If so, then just return None
    if not matches:
        return None

    # Run the matched optimization rules and collect the optimized modules
    match_opt_modules: dict[GraphOptMatch[TorchModuleT], tuple[TorchModuleT, ...]] = {
        match: match_optimizer_fn(match) for match in matches
    }

    # The list of optimized layer and the inputs of each optimized module
    modules: list[TorchModuleT] = []
    in_modules: dict[TorchModuleT, list[TorchModuleT]] = {}

    # A map from matches to their entry point unoptimized modules
    match_entry_points: dict[GraphOptMatch[TorchModuleT], TorchModuleT] = {}

    # A map from matches to their exit point (last optimized) modules
    match_exit_points: dict[GraphOptMatch[TorchModuleT], TorchModuleT] = {}

    def _remap(mi: TorchModuleT) -> TorchModuleT:
        # Redirect a module that belongs to a match to that match's exit point;
        # other modules are returned unchanged. This assumes the match's root has
        # already been visited in `ordering` (i.e. `mi` is the root of its match).
        return match_exit_points[module_matches[mi]] if mi in module_matches else mi

    # Build the optimized graph by following the topological ordering
    for module in ordering:
        match = module_matches.get(module)

        # The module does not belong to any matched pattern: keep it as is,
        # redirecting its inputs to the exit points of the matches they belong to
        if match is None:
            modules.append(module)
            in_modules[module] = [_remap(mi) for mi in incomings_fn(module)]
            continue

        # The module belongs to a matched pattern (there can only be a single one by
        # construction). The first module of that match visited in the ordering is
        # registered as the entry point of the matched sub-computational-graph.
        match_entry_points.setdefault(match, module)

        # Only the root of the match emits the optimized modules; the other modules
        # of the match are dropped from the optimized graph
        if module is not match.entries[0]:
            continue

        opt_modules = match_opt_modules[match]
        if not opt_modules:
            raise ValueError(
                "match_optimizer_fn returned no modules for a match of pattern "
                f"{match.pattern.__name__}"
            )
        modules.extend(opt_modules)
        # Chain the optimized modules: the first one takes the (remapped) inputs of the
        # match's entry point, each subsequent one takes the previous module as input
        in_modules[opt_modules[0]] = [
            _remap(mi) for mi in incomings_fn(match_entry_points[match])
        ]
        for prev_module, next_module in zip(opt_modules, opt_modules[1:]):
            in_modules[next_module] = [prev_module]
        # The last optimized module is the exit point of the matched pattern
        match_exit_points[match] = opt_modules[-1]

    # Retrieve the sequence of output modules of the computational graph
    opt_outputs = [_remap(m) for m in outputs]

    return modules, in_modules, opt_outputs