def build_folded_graph(
    ordering: Iterable[list[TorchModule]],
    *,
    outputs: Iterable[TorchModule],
    incomings_fn: Callable[[TorchModule], Sequence[TorchModule]],
    fold_group_fn: Callable[[list[TorchModule]], TorchModule],
) -> tuple[
    list[TorchModule],
    dict[TorchModule, list[TorchModule]],
    list[TorchModule],
    FoldIndexInfo,
]:
    """Fold the modules of a layered ordering into a graph of folded modules.

    Each frontier of ``ordering`` is partitioned into groups by
    ``group_foldable_modules``; every group is turned into a single folded
    module by ``fold_group_fn``. While doing so, the function records where
    each unfolded module ended up, so that inputs and outputs can later be
    addressed as ``(fold_id, slice_idx)`` pairs.

    Args:
        ordering: An iterable of frontiers, each one a list of unfolded
            modules. Every module returned by ``incomings_fn`` for a module
            must belong to a group folded earlier, as its fold index is
            looked up while folding the module that consumes it.
        outputs: The unfolded output modules. Each of them must appear in
            ``ordering``.
        incomings_fn: Returns the (unfolded) input modules of an unfolded
            module.
        fold_group_fn: Builds the folded module out of a group of unfolded
            modules.

    Returns:
        A tuple ``(modules, in_modules, folded_outputs, fold_idx_info)`` with:
        the list of folded modules, in creation order; a dictionary mapping
        each folded module to the distinct folded modules it takes inputs
        from, in first-use order; the distinct folded modules containing the
        given outputs, in first-use order; and a ``FoldIndexInfo`` built from
        the folded modules, the per-fold input indices and the output indices.

    Raises:
        KeyError: If an input of a module, or one of ``outputs``, has not been
            folded yet when its fold index is looked up.
    """
    # Maps each unfolded module to a pair (fold_id, slice_idx), where
    # (i) 'fold_id' is the position in 'modules' of the folded module it belongs to; and
    # (ii) 'slice_idx' is the position of the unfolded module within its group,
    #      i.e., the slice of the folded module that corresponds to it.
    fold_idx: dict[TorchModule, tuple[int, int]] = {}
    # Maps each folded module id to a nested list IDX with one row per module in the
    # fold and one column per input of that module. Each entry IDX[i][j] is a pair
    # (fold_id, slice_idx) pointing to the folded module of id 'fold_id' and to the
    # slice 'slice_idx' within that fold. It is empty for folds of input modules.
    in_fold_idx: dict[int, list[list[tuple[int, int]]]] = {}
    # The list of folded modules and the inputs of each folded module
    modules: list[TorchModule] = []
    in_modules: dict[TorchModule, list[TorchModule]] = {}

    # Fold the modules frontier by frontier: first find the groups of modules
    # that can be folded together, then stack each group into a folded module
    for frontier in ordering:
        for group in group_foldable_modules(frontier):
            # Fold the group of modules
            folded_module = fold_group_fn(group)
            # For each module in the group, retrieve its unfolded input modules
            in_group_modules: list[Sequence[TorchModule]] = [incomings_fn(m) for m in group]
            # Set the folded input modules. Duplicates are removed with dict.fromkeys
            # rather than with a set, so that the resulting order is deterministic
            # (first use), exactly as done below for the folded output modules.
            in_modules[folded_module] = list(
                dict.fromkeys(
                    modules[fold_idx[mi][0]] for msi in in_group_modules for mi in msi
                )
            )
            # Check if we are folding input modules, i.e., modules having no inputs.
            # Only the first module of the group is inspected.
            # NOTE(review): this presumably relies on all the modules of a group
            # having the same number of inputs -- confirm in group_foldable_modules.
            in_modules_idx: list[list[tuple[int, int]]]
            if in_group_modules[0]:
                in_modules_idx = [[fold_idx[mi] for mi in msi] for msi in in_group_modules]
            else:
                in_modules_idx = []
            # Update the data structures: the id of the new folded module is
            # its (upcoming) position in the list of folded modules
            cur_module_id = len(modules)
            for i, m in enumerate(group):
                fold_idx[m] = (cur_module_id, i)
            in_fold_idx[cur_module_id] = in_modules_idx
            # Append the folded module
            modules.append(folded_module)

    # Instantiate the information on how to aggregate the outputs in a single tensor
    out_fold_idx = [fold_idx[m] for m in outputs]
    # Construct the sequence of distinct folded output modules, preserving their order
    folded_outputs = list(dict.fromkeys(modules[fold_id] for fold_id, _ in out_fold_idx))
    return (
        modules,
        in_modules,
        folded_outputs,
        FoldIndexInfo(modules, in_fold_idx, out_fold_idx),
    )