Skip to content

folding

folding

build_address_book_entry(module, in_fold_idx, *, num_folds)

Source code in cirkit/backend/torch/graph/folding.py (lines 246–281)
def build_address_book_entry(
    module: TorchModuleT | None,
    in_fold_idx: list[list[tuple[int, int]]],
    *,
    num_folds: dict[int, int],
) -> AddressBookEntry[TorchModuleT]:
    """Build an address book entry whose operands are gathered independently.

    Each operand position of the module gets its own list of producer module ids and
    its own index, because the inputs of some modules are not stacked together
    (e.g., in the parameter torch graph).

    Args:
        module: The (folded) module the entry refers to, or None.
        in_fold_idx: For each fold, the list of ``(module_id, slice_idx)`` pairs
            pointing to its inputs, one pair per operand.
        num_folds: Map from a module id to its number of folds.

    Returns:
        An address book entry holding, for each operand position, the unique producer
        module ids (in order of first appearance) and an index into the concatenation
        of those producers' outputs. The index is an empty tuple when it would select
        all the folds of a single producer in order, so no indexing is needed.
        An empty ``in_fold_idx`` yields empty per-operand lists.
    """
    in_module_ids: list[list[int]] = []
    cum_fold_idx_t: list[Tensor | tuple] = []

    # zip(*...) transposes the fold-major pairs into operand-major columns
    for column in zip(*in_fold_idx):
        operand = list(column)

        # Producers of this operand, deduplicated while keeping first-seen order
        producer_ids = list(dict.fromkeys(mid for mid, _ in operand))
        in_module_ids.append(producer_ids)

        # Starting offset of each producer within the concatenated producer outputs
        producer_sizes = [num_folds[mid] for mid in producer_ids]
        offsets = dict(zip(producer_ids, itertools.accumulate([0] + producer_sizes)))
        flat_idx: list[int] = [offsets[mid] + slice_idx for mid, slice_idx in operand]

        # Selecting every fold of a single producer, in order, is the identity:
        # record an empty index tuple instead of a gather
        first_id = operand[0][0]
        single_producer = all(mid == first_id for mid, _ in operand)
        if single_producer and flat_idx == list(range(num_folds[first_id])):
            cum_fold_idx_t.append(())
        else:
            cum_fold_idx_t.append(torch.tensor(flat_idx))

    return AddressBookEntry(module, in_module_ids, cum_fold_idx_t)

build_address_book_stacked_entry(module, in_fold_idx, *, num_folds, output=False)

Source code in cirkit/backend/torch/graph/folding.py (lines 202–243)
def build_address_book_stacked_entry(
    module: TorchModuleT | None,
    in_fold_idx: list[list[tuple[int, int]]],
    *,
    num_folds: dict[int, int],
    output: bool = False,
) -> AddressBookEntry[TorchModuleT]:
    """Build an address book entry whose inputs are gathered from a single stacked tensor.

    The index refers to the concatenation, along the fold dimension, of the outputs of
    all producer modules referenced by ``in_fold_idx``, taken in order of first
    appearance of the producer module id. Each ``(module_id, slice_idx)`` pair is
    translated into a flat position within that concatenation.

    Args:
        module: The (folded) module the entry refers to, or None.
        in_fold_idx: For each fold, the list of ``(module_id, slice_idx)`` pairs
            pointing to its inputs.
        num_folds: Map from a module id to its number of folds.
        output: Whether this is the entry that gathers the graph outputs. In that case
            ``in_fold_idx`` must contain exactly one row, and the index is 1-D.

    Returns:
        An address book entry with a single list of producer module ids and a single
        index. The index is a long tensor, except when it would merely add a unit
        dimension. In that case it is the equivalent basic-indexing tuple: ``(None,)``
        (unsqueeze at dim 0) or ``(slice(None), None)`` (unsqueeze at dim 1).
    """
    # Unique producer module ids, in order of first appearance
    in_module_ids = list(dict.fromkeys(idx[0] for fi in in_fold_idx for idx in fi))

    # Offset of each producer's first fold within the concatenated producer outputs
    module_fold_sizes = [num_folds[mid] for mid in in_module_ids]
    cum_module_ids = dict(zip(in_module_ids, itertools.accumulate([0] + module_fold_sizes)))

    # Translate every (module_id, slice_idx) pair into a flat position
    cum_fold_idx = [[cum_module_ids[idx[0]] + idx[1] for idx in fi] for fi in in_fold_idx]

    # The output entry gathers a single row: drop the unit fold dimension
    if output:
        assert len(cum_fold_idx) == 1
        # Explicit dtype so that an empty row still yields an integer (valid) index
        cum_fold_idx_t = torch.tensor(cum_fold_idx[0], dtype=torch.long)
        return AddressBookEntry(module, [in_module_ids], [cum_fold_idx_t])

    # If the index would be equivalent to an 'unsqueeze' on dimension 0 or 1, replace it
    # with the cheaper basic-indexing tuple. cum_fold_idx[0] is read below, so the
    # shortcut is skipped when there are no rows at all (otherwise IndexError).
    fold_size = sum(module_fold_sizes)
    num_rows = len(cum_fold_idx)
    if num_rows > 0 and [i for idx in cum_fold_idx for i in idx] == list(range(fold_size)):
        if num_rows == 1 and len(cum_fold_idx[0]) == fold_size:
            # Equivalent to .unsqueeze(dim=0)
            return AddressBookEntry(module, [in_module_ids], [(None,)])
        if num_rows == fold_size and len(cum_fold_idx[0]) == 1:
            # Equivalent to .unsqueeze(dim=1)
            return AddressBookEntry(module, [in_module_ids], [(slice(None), None)])

    # Explicit dtype: torch.tensor infers float for empty lists, which cannot index
    cum_fold_idx_t = torch.tensor(cum_fold_idx, dtype=torch.long)
    return AddressBookEntry(module, [in_module_ids], [cum_fold_idx_t])

build_folded_graph(ordering, *, outputs, incomings_fn, fold_group_fn)

Find and apply all possible foldings on a graph.

Parameters:

Name Type Description Default
ordering Iterable[list[TorchModuleT]]

Modules in the graph, grouped in layerwise topological order.

required
outputs Iterable[TorchModuleT]

Outputs of the graph.

required
incomings_fn Callable[[TorchModuleT], Sequence[TorchModuleT]]

Function returning the input modules of a given module.

required
fold_group_fn Callable[[list[TorchModuleT]], TorchModuleT]

Function returning a folded module given a group of modules.

required

Returns:

Type Description
tuple[list[TorchModuleT], dict[TorchModuleT, list[TorchModuleT]], list[TorchModuleT], FoldIndexInfo[TorchModuleT]]

A tuple containing: - The final, potentially folded, modules. - The adjacency list updated with the folded modules. - The list of modules that act as outputs of the graph. - A FoldIndexInfo object, which stores the information necessary to locate an unfolded module within the folded circuit. It is essentially a map from a module of the unfolded circuit to a pair (id_folded_module, fold_id).

Source code in cirkit/backend/torch/graph/folding.py (lines 62–166)
def build_folded_graph(
    ordering: Iterable[list[TorchModuleT]],
    *,
    outputs: Iterable[TorchModuleT],
    incomings_fn: Callable[[TorchModuleT], Sequence[TorchModuleT]],
    fold_group_fn: Callable[[list[TorchModuleT]], TorchModuleT],
) -> tuple[
    list[TorchModuleT],
    dict[TorchModuleT, list[TorchModuleT]],
    list[TorchModuleT],
    FoldIndexInfo[TorchModuleT],
]:
    """Find and apply all possible foldings on a graph.

    Args:
        ordering (Iterable[list[TorchModuleT]]):
            Modules in the graph, grouped in layerwise topological order.
        outputs (Iterable[TorchModuleT]): Outputs of the graph.
        incomings_fn (Callable[[TorchModuleT], Sequence[TorchModuleT]]):
            Function returning the input modules of a given module.
        fold_group_fn (Callable[[list[TorchModuleT]], TorchModuleT]):
            Function returning a folded module given a group of modules.

    Returns:
        tuple[
            list[TorchModuleT],
            dict[TorchModuleT, list[TorchModuleT]],
            list[TorchModuleT],
            FoldIndexInfo[TorchModuleT],
        ]:
            - The final, potentially folded, modules.
            - The adjacency list updated with the folded modules. The inputs of each
              folded module are listed in order of first occurrence.
            - The list of (folded) modules that act as outputs of the graph.
            - A `FoldIndexInfo` object storing the information necessary to locate
              an unfolded module within the folded circuit. It is essentially a map
              from a module of the unfolded circuit to a pair
              (id_folded_module, fold_id).
    """
    # Map each unfolded module to
    # (i) a 'fold_id' (a natural number) pointing to the folded module it belongs to; and
    # (ii) a 'slice_idx' (a natural number) within the output of that folded module,
    #      which recovers the output of the unfolded module.
    fold_idx: dict[TorchModuleT, tuple[int, int]] = {}

    # Map each folded module id to F rows (one per module in the fold), each holding
    # H (fold_id, slice_idx) pairs (one per input), pointing to the folded module of
    # id 'fold_id' and to the slice 'slice_idx' within that fold.
    in_fold_idx: dict[int, list[list[tuple[int, int]]]] = {}

    # The list of folded modules and the inputs of each folded module
    modules: list[TorchModuleT] = []
    in_modules: dict[TorchModuleT, list[TorchModuleT]] = {}

    # Fold modules frontier by frontier: first find the groups of foldable modules,
    # then stack each group into a single folded module
    for frontier in ordering:
        foldable_groups = group_foldable_modules(frontier)

        for group in foldable_groups:
            folded_module = fold_group_fn(group)

            # For each module in the group, retrieve the unfolded input modules
            in_group_modules = [incomings_fn(m) for m in group]

            # Deduplicate the folded input modules with dict.fromkeys, keeping first
            # occurrence order. A set would iterate in hash order (identity-based for
            # modules), making the adjacency list differ from run to run.
            in_modules[folded_module] = list(
                dict.fromkeys(
                    modules[fold_idx[mi][0]] for msi in in_group_modules for mi in msi
                )
            )

            # Groups whose first module has no inputs (input modules) get no fold index
            # rows. Modules in a group are assumed to have the same number of inputs.
            in_modules_idx: list[list[tuple[int, int]]]
            if in_group_modules[0]:
                in_modules_idx = [[fold_idx[mi] for mi in msi] for msi in in_group_modules]
            else:
                in_modules_idx = []

            # The new folded module's id is its position in 'modules'
            cur_module_id = len(modules)
            for i, m in enumerate(group):
                fold_idx[m] = (cur_module_id, i)
            in_fold_idx[cur_module_id] = in_modules_idx
            modules.append(folded_module)

    # Information on how to aggregate the outputs into a single tensor
    out_fold_idx = [fold_idx[m] for m in outputs]

    # The folded modules containing the outputs, deduplicated in first occurrence order
    folded_outputs = list(dict.fromkeys(modules[fi[0]] for fi in out_fold_idx))

    return (
        modules,
        in_modules,
        folded_outputs,
        FoldIndexInfo(modules, in_fold_idx, out_fold_idx),
    )

build_unfold_index_info(ordering, *, outputs, incomings_fn)

Source code in cirkit/backend/torch/graph/folding.py (lines 17–59)
def build_unfold_index_info(
    ordering: Iterable[TorchModuleT],
    *,
    outputs: Iterable[TorchModuleT],
    incomings_fn: Callable[[TorchModuleT], Sequence[TorchModuleT]],
) -> FoldIndexInfo[TorchModuleT]:
    """Build the fold index information of a graph whose modules are not folded.

    Every module becomes its own single-fold entry. Module ids are assigned by
    position in ``ordering``, and every slice index is 0.

    Args:
        ordering: The modules in topological order.
        outputs: The output modules of the graph.
        incomings_fn: Function returning the input modules of a given module.

    Returns:
        The fold index information for the unfolded graph.

    Raises:
        ValueError: If any module has more than one fold.
    """
    modules_ls: list[TorchModuleT] = list(ordering)

    # module -> (module id, slice index); the slice index is always 0 here
    locations: dict[AbstractTorchModule, tuple[int, int]] = {}

    # module id -> a single fold row listing the (module id, slice index) of each input
    in_fold_idx: dict[int, list[list[tuple[int, int]]]] = {}

    for module_id, m in enumerate(modules_ls):
        if m.num_folds > 1:
            raise ValueError(
                f"Expected modules with fold dimension equal to one, found {m.num_folds}"
            )
        # Inputs must appear earlier in the ordering, otherwise this lookup raises KeyError.
        # Modules without inputs get a single empty row.
        in_fold_idx[module_id] = [[locations[mi] for mi in incomings_fn(m)]]
        locations[m] = (module_id, 0)

    # Information on how to aggregate the outputs into a single tensor
    out_fold_idx = [locations[m] for m in outputs]
    return FoldIndexInfo(modules_ls, in_fold_idx, out_fold_idx)

group_foldable_modules(modules)

Groups modules that can be folded together.

Parameters:

Name Type Description Default
modules list[TorchModuleT]

Modules from the same level in the graph's layerwise topological ordering.

required

Returns:

Type Description
list[list[TorchModuleT]]

list[list[TorchModuleT]]: List of groups of torch modules that can be folded together.

Source code in cirkit/backend/torch/graph/folding.py (lines 169–199)
def group_foldable_modules(
    modules: list[TorchModuleT],
) -> list[list[TorchModuleT]]:
    """Group modules that can be folded together.

    Two modules end up in the same group when their fold settings are equal. The fold
    settings of a module are its type and its ``fold_settings``, followed by the fold
    settings of each of its sub-modules, collected recursively.

    Args:
        modules (list[TorchModuleT]): Modules from the same level in the graph's
            layerwise topological ordering.

    Returns:
        list[list[TorchModuleT]]: List of groups of torch modules that can be folded
            together. Groups are ordered by first appearance, and modules keep their
            relative order from ``modules``.
    """

    def _gather_fold_settings(module: AbstractTorchModule) -> tuple[Any, ...]:
        # Read the settings of the module being inspected. The previous version read the
        # enclosing loop variable 'm' here, so recursive calls on sub-modules repeated
        # the outer module's settings instead of contributing their own.
        ss = [type(module), *module.fold_settings]
        for sub_module in module.sub_modules.values():
            ss.extend(_gather_fold_settings(sub_module))
        return tuple(ss)

    # Map each fold-settings key, which identifies a group of foldable modules,
    # to the modules in that group (insertion-ordered)
    groups: dict[tuple[Any, ...], list[TorchModuleT]] = defaultdict(list)

    # Each module either starts a new group or joins an existing one
    for m in modules:
        groups[_gather_fold_settings(m)].append(m)

    return list(groups.values())