64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
def optimize_graph(
    ordering: Iterable[TorchModule],
    outputs: Iterable[TorchModule],
    patterns: Iterable[GraphOptPattern],
    *,
    incomings_fn: Callable[[TorchModule], Sequence[TorchModule]],
    outcomings_fn: Callable[[TorchModule], Sequence[TorchModule]],
    pattern_matcher_fn: PatternMatcherFunc,
    match_optimizer_fn: MatchOptimizerFunc,
    strategy: OptMatchStrategy = OptMatchStrategy.LARGEST_MATCH,
) -> tuple[list[TorchModule], dict[TorchModule, list[TorchModule]], list[TorchModule]] | None:
    """Optimize a computational graph by rewriting the sub-graphs that match some patterns.

    The optimization patterns are first matched against the graph by
    ``match_optimization_patterns``. Each match is then turned into a sequence of optimized
    modules by ``match_optimizer_fn``. These modules are chained together: the first one
    receives the inputs of the match's entry point, and each following one receives the
    previous optimized module as its only input. The last optimized module is the exit point
    of the match. It replaces the match wherever a module of the match is used as an input
    of an unmatched module or as an output of the graph. Modules that do not belong to any
    match are kept as they are.

    Only patterns with a single entry point and a single exit point are supported. The entry
    point of a match is the first of its modules met in ``ordering``. The optimized modules
    of a match are emitted when its root (``match.entries[0]``) is met in ``ordering``.

    Args:
        ordering: The modules of the graph, in topological ordering.
        outputs: The output modules of the graph.
        patterns: The optimization patterns to match.
        incomings_fn: A function returning the input modules of a module.
        outcomings_fn: A function returning the output modules of a module. It is only
            forwarded to the pattern matching step.
        pattern_matcher_fn: The pattern matching function, forwarded to the pattern
            matching step.
        match_optimizer_fn: A function mapping a match to its optimized modules.
        strategy: The match selection strategy, forwarded to the pattern matching step.

    Returns:
        None if no pattern has been matched. Otherwise, a tuple containing three items:
        (i) the modules of the optimized graph, in the order obtained by following
        ``ordering``; (ii) a map from each of those modules to its input modules; and
        (iii) the output modules of the optimized graph.

    Raises:
        ValueError: If ``match_optimizer_fn`` returns no modules for some match.
    """
    # TODO: generalize this to cover patterns with multiple entry or exit points?
    #  (much more difficult)
    # Materialize one-shot iterators. ordering and outputs are read both by the pattern
    # matching step and below. patterns is materialized as well, so that the pattern
    # matching step can safely iterate over it more than once.
    ordering = list(ordering) if isinstance(ordering, Iterator) else ordering
    outputs = list(outputs) if isinstance(outputs, Iterator) else outputs
    patterns = list(patterns) if isinstance(patterns, Iterator) else patterns

    # Match optimization patterns
    # matches: list of all matched and grounded optimization rules
    # module_matches: a map from modules to the match they belong to, if any
    matches, module_matches = match_optimization_patterns(
        ordering,
        outputs,
        patterns,
        incomings_fn=incomings_fn,
        outcomings_fn=outcomings_fn,
        pattern_matcher_fn=pattern_matcher_fn,
        strategy=strategy,
    )
    # If no matches have been found, then signal it to the caller by returning None
    if not matches:
        return None

    # Run the matched optimization rules and collect the optimized modules
    match_opt_modules: dict[GraphOptMatch, tuple[TorchModule, ...]] = {
        match: match_optimizer_fn(match) for match in matches
    }

    # The modules of the optimized graph, and the inputs of each of them
    modules: list[TorchModule] = []
    in_modules: dict[TorchModule, list[TorchModule]] = {}
    # A map from matches to their entry point (unoptimized) modules
    match_entry_points: dict[GraphOptMatch, TorchModule] = {}
    # A map from matches to their exit point (optimized) modules
    match_exit_points: dict[GraphOptMatch, TorchModule] = {}

    def _resolve(m: TorchModule) -> TorchModule:
        # A module belonging to a match is replaced by the exit point of that match.
        # By the topological ordering, the exit point has been registered already.
        if m in module_matches:
            return match_exit_points[module_matches[m]]
        return m

    # Build the optimized graph by following the topological ordering
    for module in ordering:
        match = module_matches.get(module, None)
        if match is None:
            # The module does not belong to any matched pattern. Keep it as it is, and
            # re-wire its inputs to the exit points of the matches they belong to, if any.
            modules.append(module)
            in_modules[module] = [_resolve(mi) for mi in incomings_fn(module)]
            continue
        # The module belongs to a matched pattern (there can only be a single one by
        # construction). Register it as the entry point of the matched sub-graph, unless
        # another entry point has been registered before.
        match_entry_points.setdefault(match, module)
        # The optimized sub-graph is emitted only once the root of the match is reached
        if module != match.entries[0]:
            continue
        opt_modules = match_opt_modules[match]
        if not opt_modules:
            raise ValueError(
                "The match optimizer returned no modules for the pattern match rooted at "
                f"a module of type {type(module).__name__}"
            )
        modules.extend(opt_modules)
        # The optimized modules form a chain. The first one takes the inputs of the match's
        # entry point, and each of the others takes the previous one as its only input.
        in_modules[opt_modules[0]] = [
            _resolve(mi) for mi in incomings_fn(match_entry_points[match])
        ]
        for prev_om, om in zip(opt_modules, opt_modules[1:]):
            in_modules[om] = [prev_om]
        # The last optimized module is the exit point of the matched pattern
        match_exit_points[match] = opt_modules[-1]

    # Retrieve the sequence of output modules of the optimized computational graph
    opt_outputs = [_resolve(m) for m in outputs]
    return modules, in_modules, opt_outputs
|