Skip to content

optimize

optimize ¤

GraphOptPattern = type[GraphOptPatternDefn[TorchModule]] module-attribute ¤

GraphOptMatch ¤

Bases: Generic[TorchModule]

Source code in cirkit/backend/torch/graph/optimize.py
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
class GraphOptMatch(Generic[TorchModule]):
    """A grounded occurrence of an optimization pattern in a computational graph.

    No ``__eq__``/``__hash__`` is defined, so matches compare and hash by identity.
    """

    def __init__(self, pattern: GraphOptPattern[TorchModule], entries: list[TorchModule]):
        """Record the matched pattern and the modules it was grounded on.

        Args:
            pattern: The pattern definition class that was matched.
            entries: The matched modules. Consumers of matches treat ``entries[0]``
                as the root of the match.
        """
        self._entries = entries
        self._pattern = pattern

    @property
    def size(self) -> int:
        """The number of modules covered by this match."""
        return len(self._entries)

    @property
    def entries(self) -> list[TorchModule]:
        """The matched modules (the stored list itself, not a copy)."""
        return self._entries

    @property
    def pattern(self) -> GraphOptPattern[TorchModule]:
        """The pattern definition class this match is an occurrence of."""
        return self._pattern

_entries = entries instance-attribute ¤

_pattern = pattern instance-attribute ¤

entries property ¤

pattern property ¤

size property ¤

__init__(pattern, entries) ¤

Source code in cirkit/backend/torch/graph/optimize.py
27
28
29
def __init__(self, pattern: GraphOptPattern[TorchModule], entries: list[TorchModule]):
    """Store the matched pattern and the list of modules it was grounded on."""
    self._pattern, self._entries = pattern, entries

GraphOptPatternDefn ¤

Bases: Generic[TorchModule]

Source code in cirkit/backend/torch/graph/optimize.py
13
14
15
16
17
18
19
20
class GraphOptPatternDefn(Generic[TorchModule]):
    """Base class for the definition of a graph optimization pattern.

    Patterns are used as classes rather than instances, so both hooks are class
    methods. Concrete patterns must override :meth:`entries`.
    """

    @classmethod
    def is_output(cls) -> bool:
        """Whether the pattern may only be rooted at the output modules of the graph.

        Defaults to ``False``, i.e., the pattern may be rooted at any module.
        """
        return False

    @classmethod
    def entries(cls) -> list[type[TorchModule]]:
        """The module types that make up the pattern.

        Raises:
            NotImplementedError: Always. Concrete patterns must override this, rather
                than silently inheriting a ``None`` result.
        """
        raise NotImplementedError(f"{cls.__name__} must override 'entries()'")

entries() classmethod ¤

Source code in cirkit/backend/torch/graph/optimize.py
18
19
20
@classmethod
def entries(cls) -> list[type[TorchModule]]:
    """The module types that make up the pattern.

    Raises:
        NotImplementedError: Always. Concrete patterns must override this, rather
            than silently inheriting a ``None`` result.
    """
    raise NotImplementedError(f"{cls.__name__} must override 'entries()'")

is_output() classmethod ¤

Source code in cirkit/backend/torch/graph/optimize.py
14
15
16
@classmethod
def is_output(cls) -> bool:
    """Return whether the pattern may only be rooted at the graph's output modules.

    The base definition answers ``False``. Patterns override it to opt in.
    """
    return False

MatchOptimizerFunc ¤

Bases: Protocol

Source code in cirkit/backend/torch/graph/optimize.py
56
57
58
59
60
61
class MatchOptimizerFunc(Protocol):
    """Callable protocol that rewrites a matched sub-graph into optimized modules."""

    def __call__(
        self,
        match: GraphOptMatch[TorchModule],
    ) -> tuple[TorchModule, ...]:
        """Optimize the modules covered by a pattern match.

        Args:
            match: The pattern match to optimize.

        Returns:
            The optimized modules. The graph optimizer chains them in the returned
            order, feeding each one the output of the previous one, and uses the last
            one as the exit point of the match.
        """

__call__(match) ¤

Source code in cirkit/backend/torch/graph/optimize.py
57
58
59
60
61
def __call__(
    self,
    match: GraphOptMatch[TorchModule],
) -> tuple[TorchModule, ...]:
    """Optimize the modules covered by a pattern match.

    Args:
        match: The pattern match to optimize.

    Returns:
        The optimized modules. The graph optimizer chains them in the returned order
        and uses the last one as the exit point of the match.
    """

OptMatchStrategy ¤

Bases: IntEnum

Source code in cirkit/backend/torch/graph/optimize.py
 9
10
class OptMatchStrategy(IntEnum):
    """Strategy used to choose among overlapping pattern matches."""

    # Prefer the match covering the largest number of modules
    LARGEST_MATCH = auto()

LARGEST_MATCH = auto() class-attribute instance-attribute ¤

PatternMatcherFunc ¤

Bases: Protocol

Source code in cirkit/backend/torch/graph/optimize.py
44
45
46
47
48
49
50
51
52
53
class PatternMatcherFunc(Protocol):
    """Callable protocol that tries to match a pattern rooted at a given module."""

    def __call__(
        self,
        module: TorchModule,
        pattern: GraphOptPattern[TorchModule],
        *,
        incomings_fn: Callable[[TorchModule], Sequence[TorchModule]],
        outcomings_fn: Callable[[TorchModule], Sequence[TorchModule]],
    ) -> GraphOptMatch[TorchModule] | None:
        """Try to ground ``pattern`` with ``module`` as its root.

        Args:
            module: The candidate root module.
            pattern: The pattern to match.
            incomings_fn: Returns the input modules of a module.
            outcomings_fn: Presumably returns the modules consuming a module's
                output (the outgoing counterpart of ``incomings_fn``).

        Returns:
            The match if one is found, ``None`` otherwise.
        """

__call__(module, pattern, *, incomings_fn, outcomings_fn) ¤

Source code in cirkit/backend/torch/graph/optimize.py
45
46
47
48
49
50
51
52
53
def __call__(
    self,
    module: TorchModule,
    pattern: GraphOptPattern[TorchModule],
    *,
    incomings_fn: Callable[[TorchModule], Sequence[TorchModule]],
    outcomings_fn: Callable[[TorchModule], Sequence[TorchModule]],
) -> GraphOptMatch[TorchModule] | None:
    """Try to ground ``pattern`` with ``module`` as its root.

    Args:
        module: The candidate root module.
        pattern: The pattern to match.
        incomings_fn: Returns the input modules of a module.
        outcomings_fn: Presumably returns the modules consuming a module's output.

    Returns:
        The match if one is found, ``None`` otherwise.
    """

_match_pattern_graph(modules, pattern, *, incomings_fn, outcomings_fn, pattern_matcher_fn) ¤

Source code in cirkit/backend/torch/graph/optimize.py
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
def _match_pattern_graph(
    modules: Iterable[TorchModule],
    pattern: GraphOptPattern[TorchModule],
    *,
    incomings_fn: Callable[[TorchModule], Sequence[TorchModule]],
    outcomings_fn: Callable[[TorchModule], Sequence[TorchModule]],
    pattern_matcher_fn: PatternMatcherFunc,
) -> Iterator[GraphOptMatch[TorchModule]]:
    """Lazily yield the matches of ``pattern`` rooted at each of the given modules.

    Every module is tried as the root of the pattern via ``pattern_matcher_fn``.
    Modules where the pattern does not match (the matcher returns ``None``) are
    skipped.

    Args:
        modules: The candidate roots.
        pattern: The pattern to match.
        incomings_fn: Returns the input modules of a module.
        outcomings_fn: Forwarded to the matcher.
        pattern_matcher_fn: The function that tries to match a pattern at a root.

    Returns:
        A lazy iterator over the found matches, in the order of ``modules``.
    """
    candidates = (
        pattern_matcher_fn(root, pattern, incomings_fn=incomings_fn, outcomings_fn=outcomings_fn)
        for root in modules
    )
    return (found for found in candidates if found is not None)

_prioritize_optimization_strategy(ordering, module_matches, *, strategy=OptMatchStrategy.LARGEST_MATCH, in_place=True) ¤

Source code in cirkit/backend/torch/graph/optimize.py
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
def _prioritize_optimization_strategy(
    ordering: Iterable[TorchModule],
    module_matches: dict[TorchModule, list[GraphOptMatch[TorchModule]]],
    *,
    strategy: OptMatchStrategy = OptMatchStrategy.LARGEST_MATCH,
    in_place: bool = True,
) -> dict[TorchModule, GraphOptMatch[TorchModule]]:
    """Resolve overlapping pattern matches so that every module keeps at most one.

    Modules are visited following ``ordering``. At each module, its candidate matches
    are ranked by ``strategy``. The best one is kept and every other candidate is
    discarded from all the modules it covers. If a discarded match had already been
    selected for a module visited earlier, that selection is withdrawn too. This
    ensures no discarded match survives in the result and selected matches never
    overlap.

    Args:
        ordering: The modules to visit, in order.
        module_matches: A map from modules to the list of matches they belong to.
            Modules missing from the map are treated as having no matches.
        strategy: The prioritization strategy.
        in_place: If True, discarded matches are removed from the lists stored in
            ``module_matches``. Otherwise neither the map nor its lists are modified.

    Returns:
        A map from modules to the single match selected for them. Modules left
        without a match are absent from the map.
    """
    if in_place:
        candidates = module_matches
    else:
        # Copy the lists as well: a shallow dict copy would still share (and mutate) them
        candidates = {m: list(ms) for m, ms in module_matches.items()}
    prioritized_module_matches: dict[TorchModule, GraphOptMatch[TorchModule]] = {}

    # Follow the ordering of the computational graph and prune
    # pattern matches, according to the given prioritization strategy
    for module in ordering:
        # Use .get() so that modules without matches are not inserted into a defaultdict
        matches = candidates.get(module)
        if not matches:
            continue
        if len(matches) == 1:
            prioritized_module_matches[module] = matches[0]
            continue

        # Sort the matches based on the given strategy, and keep the best one
        best, *discarded = _sort_matches_priority(matches, strategy=strategy)

        # Prune the 'excess' pattern matches from every module they cover
        for match in discarded:
            for m in match.entries:
                candidates[m].remove(match)
                # Withdraw the discarded match from modules that had already selected it,
                # otherwise it would survive and overlap with the selected one
                if prioritized_module_matches.get(m) is match:
                    del prioritized_module_matches[m]
        prioritized_module_matches[module] = best

    return prioritized_module_matches

_sort_matches_priority(matches, *, strategy) ¤

Source code in cirkit/backend/torch/graph/optimize.py
233
234
235
236
237
238
239
240
def _sort_matches_priority(
    matches: list[GraphOptMatch[TorchModule]],
    *,
    strategy: OptMatchStrategy,
) -> list[GraphOptMatch[TorchModule]]:
    """Return a new list of the matches, sorted from highest to lowest priority.

    The sort is stable: matches with equal priority keep their relative order.

    Args:
        matches: The matches to sort. This list is not modified.
        strategy: The prioritization strategy.

    Returns:
        The sorted matches.

    Raises:
        NotImplementedError: If the strategy is not supported.
    """
    if strategy == OptMatchStrategy.LARGEST_MATCH:
        # Larger matches first; ties keep their original (discovery) order
        return sorted(matches, key=lambda m: m.size, reverse=True)
    # An explicit raise, unlike 'assert False', is not stripped under 'python -O'
    raise NotImplementedError(f"Unsupported optimization match strategy: {strategy!r}")

match_optimization_patterns(ordering, outputs, patterns, *, incomings_fn, outcomings_fn, pattern_matcher_fn, strategy=OptMatchStrategy.LARGEST_MATCH) ¤

Source code in cirkit/backend/torch/graph/optimize.py
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
def match_optimization_patterns(
    ordering: Iterable[TorchModule],
    outputs: Iterable[TorchModule],
    patterns: Iterable[GraphOptPattern[TorchModule]],
    *,
    incomings_fn: Callable[[TorchModule], Sequence[TorchModule]],
    outcomings_fn: Callable[[TorchModule], Sequence[TorchModule]],
    pattern_matcher_fn: PatternMatcherFunc,
    strategy: OptMatchStrategy = OptMatchStrategy.LARGEST_MATCH,
) -> tuple[list[GraphOptMatch[TorchModule]], dict[TorchModule, GraphOptMatch[TorchModule]]]:
    """Match the given patterns on a graph and resolve overlapping matches.

    Output patterns (``pattern.is_output()``) are only tried rooted at ``outputs``.
    All other patterns are tried rooted at every module in ``ordering``.

    Args:
        ordering: The modules of the graph, in the order used to prioritize matches.
        outputs: The output modules of the graph.
        patterns: The patterns to match.
        incomings_fn: Returns the input modules of a module.
        outcomings_fn: Forwarded to the pattern matcher.
        pattern_matcher_fn: The function that tries to match a pattern at a root.
        strategy: The strategy used to choose among overlapping matches.

    Returns:
        A tuple ``(matches, module_matches)``. ``matches`` holds the selected matches,
        without duplicates and in a deterministic order (the order in which they were
        selected). ``module_matches`` maps each module to the match it belongs to, if any.
    """
    # Materialize one-shot iterators, as the modules are iterated multiple times
    ordering = list(ordering) if isinstance(ordering, Iterator) else ordering
    outputs = list(outputs) if isinstance(outputs, Iterator) else outputs

    # A map from modules to the list of found matches they belong to
    module_matches: dict[TorchModule, list[GraphOptMatch[TorchModule]]] = defaultdict(list)

    # For each given pattern, match it on the graph
    for pattern in patterns:
        # Get an iterator of matches, for a given pattern
        modules = outputs if pattern.is_output() else ordering
        for match in _match_pattern_graph(
            modules,
            pattern,
            incomings_fn=incomings_fn,
            outcomings_fn=outcomings_fn,
            pattern_matcher_fn=pattern_matcher_fn,
        ):
            # For each module found in a match, update the map from modules to found matches
            for matched_module in match.entries:
                module_matches[matched_module].append(match)

    # Prioritize the matched patterns
    prioritized_module_matches = _prioritize_optimization_strategy(
        ordering, module_matches, strategy=strategy, in_place=True
    )

    # Extract all the matches that are still active. Matches hash by identity, so a set
    # would yield them in a memory-address-dependent (non-reproducible) order; dict.fromkeys
    # de-duplicates while preserving insertion order
    prioritized_matches = list(dict.fromkeys(prioritized_module_matches.values()))

    return prioritized_matches, prioritized_module_matches

optimize_graph(ordering, outputs, patterns, *, incomings_fn, outcomings_fn, pattern_matcher_fn, match_optimizer_fn, strategy=OptMatchStrategy.LARGEST_MATCH) ¤

Source code in cirkit/backend/torch/graph/optimize.py
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
def optimize_graph(
    ordering: Iterable[TorchModule],
    outputs: Iterable[TorchModule],
    patterns: Iterable[GraphOptPattern],
    *,
    incomings_fn: Callable[[TorchModule], Sequence[TorchModule]],
    outcomings_fn: Callable[[TorchModule], Sequence[TorchModule]],
    pattern_matcher_fn: PatternMatcherFunc,
    match_optimizer_fn: MatchOptimizerFunc,
    strategy: OptMatchStrategy = OptMatchStrategy.LARGEST_MATCH,
) -> tuple[list[TorchModule], dict[TorchModule, list[TorchModule]], list[TorchModule]] | None:
    """Optimize a computational graph by rewriting matched sub-graphs.

    The patterns are matched on the graph and overlapping matches are resolved
    according to ``strategy``. Each selected match is then replaced by the modules
    returned by ``match_optimizer_fn`` for it. These optimized modules are chained
    in the returned order:

    - The first one receives the inputs of the match's entry point, i.e., the first
      matched module encountered in ``ordering``.
    - Each following one receives the output of the previous one.
    - The last one becomes the match's exit point.

    Every reference to a matched module, whether among the inputs of other modules
    or among the graph outputs, is redirected to the exit point of its match.

    Args:
        ordering: The modules of the graph, in an order where inputs precede the
            modules consuming them (assumed to be a topological ordering).
        outputs: The output modules of the graph.
        patterns: The optimization patterns to match.
        incomings_fn: Returns the input modules of a module.
        outcomings_fn: Forwarded to the pattern matcher.
        pattern_matcher_fn: The function that tries to match a pattern at a root.
        match_optimizer_fn: The function that optimizes a match into new modules.
        strategy: The strategy used to choose among overlapping matches.

    Returns:
        ``None`` if no pattern matched. Otherwise a tuple
        ``(modules, in_modules, outputs)``, containing:

        - The modules of the optimized graph, following ``ordering``.
        - A map from each of them to its input modules.
        - The output modules of the optimized graph.

    Raises:
        ValueError: If ``match_optimizer_fn`` returns no modules for some match.
    """
    # TODO: generalize this as to cover patterns with multiple entry or exit points? (much more difficult)

    # Materialize one-shot iterators, as the modules are iterated multiple times
    ordering = list(ordering) if isinstance(ordering, Iterator) else ordering
    outputs = list(outputs) if isinstance(outputs, Iterator) else outputs

    # Match optimization patterns
    # matches: list of all matched and grounded optimization rules
    # module_matches: a map from modules to the match they belong to, if any
    matches, module_matches = match_optimization_patterns(
        ordering,
        outputs,
        patterns,
        incomings_fn=incomings_fn,
        outcomings_fn=outcomings_fn,
        pattern_matcher_fn=pattern_matcher_fn,
        strategy=strategy,
    )

    # Check if no matches have been found. If so, then just return None
    if not matches:
        return None

    # Run the matched optimization rules and collect the optimized modules
    match_opt_modules: dict[GraphOptMatch, tuple[TorchModule, ...]] = {}
    for match in matches:
        opt_modules = match_optimizer_fn(match)
        if not opt_modules:
            # A match must be replaced by at least one module, which becomes its exit point
            raise ValueError(
                f"The optimization of a match of pattern {match.pattern!r} returned no modules"
            )
        match_opt_modules[match] = opt_modules

    # The list of optimized modules and the inputs of each optimized module
    modules: list[TorchModule] = []
    in_modules: dict[TorchModule, list[TorchModule]] = {}

    # A map from matches to their entry point unoptimized modules
    match_entry_points: dict[GraphOptMatch, TorchModule] = {}

    # A map from matches to their exit point optimized modules
    match_exit_points: dict[GraphOptMatch, TorchModule] = {}

    def _remap(m: TorchModule) -> TorchModule:
        # A module absorbed into a match is replaced by that match's exit point
        m_match = module_matches.get(m)
        return m if m_match is None else match_exit_points[m_match]

    # Build the optimized graph by following the ordering
    for module in ordering:
        match = module_matches.get(module)

        # A module not belonging to any match is kept as is, with its inputs remapped
        if match is None:
            modules.append(module)
            in_modules[module] = [_remap(mi) for mi in incomings_fn(module)]
            continue

        # The module belongs to exactly one match (by construction). The first module of
        # the match met in the ordering is registered as the match's entry point
        match_entry_points.setdefault(match, module)

        # Non-root modules of a match are subsumed by its optimized modules
        if module != match.entries[0]:
            continue

        # The root has been reached: add the optimized modules, chaining them one after
        # the other, with the first one taking the inputs of the entry point
        opt_modules = match_opt_modules[match]
        modules.extend(opt_modules)
        in_modules[opt_modules[0]] = [
            _remap(mi) for mi in incomings_fn(match_entry_points[match])
        ]
        for prev_om, om in zip(opt_modules, opt_modules[1:]):
            in_modules[om] = [prev_om]
        # The last optimized module is the exit point of the matched pattern
        match_exit_points[match] = opt_modules[-1]

    # Retrieve the sequence of output modules of the computational graph
    opt_outputs = [_remap(m) for m in outputs]

    return modules, in_modules, opt_outputs