Skip to content

parameters

parameters ¤

DEFAULT_PARAMETER_OPT_RULES = {LogSoftmaxPattern: apply_log_softmax, ReduceSumOuterProductPattern: apply_sum_outer_prod_einsum} module-attribute ¤

KroneckerOutParameterPattern ¤

Bases: ParameterOptPattern

Source code in cirkit/backend/torch/optimization/parameters.py
26
27
28
29
30
31
32
33
# NOTE(review): the sibling patterns in this module derive from
# ParameterOptPatternDefn; confirm that ParameterOptPattern is the intended base.
class KroneckerOutParameterPattern(ParameterOptPattern):
    """Pattern made of a single Kronecker parameter node, flagged as an output pattern."""

    @classmethod
    def entries(cls) -> list[type[TorchParameterNode]]:
        """Return the node types of the pattern: one ``TorchKroneckerParameter``."""
        return [TorchKroneckerParameter]

    @classmethod
    def is_output(cls) -> bool:
        """Return ``True``: this pattern is flagged as an output pattern.

        NOTE(review): presumably this restricts matching to the output of a
        parameter graph -- confirm against the base pattern definition.
        """
        return True

entries() classmethod ¤

Source code in cirkit/backend/torch/optimization/parameters.py
31
32
33
@classmethod
def entries(cls) -> list[type[TorchParameterNode]]:
    """Return the node types of the pattern: one ``TorchKroneckerParameter``."""
    node_types = [TorchKroneckerParameter]
    return node_types

is_output() classmethod ¤

Source code in cirkit/backend/torch/optimization/parameters.py
27
28
29
@classmethod
def is_output(cls) -> bool:
    """Return ``True``: this pattern is flagged as an output pattern.

    NOTE(review): presumably this restricts matching to the output of a
    parameter graph -- confirm against the base pattern definition.
    """
    return True

LogSoftmaxPattern ¤

Bases: ParameterOptPatternDefn

Source code in cirkit/backend/torch/optimization/parameters.py
36
37
38
39
40
41
42
43
class LogSoftmaxPattern(ParameterOptPatternDefn):
    """Pattern pairing a log parameter node with a softmax parameter node."""

    @classmethod
    def entries(cls) -> list[type[TorchParameterNode]]:
        """Return the node types of the pattern, log first and softmax second.

        The order matters: ``apply_log_softmax`` reads ``match.entries[1]`` as the
        softmax node.
        """
        return [TorchLogParameter, TorchSoftmaxParameter]

    @classmethod
    def is_output(cls) -> bool:
        """Return ``False``: this pattern is not flagged as an output pattern.

        NOTE(review): presumably this means the pattern may match anywhere in a
        parameter graph -- confirm against the base pattern definition.
        """
        return False

entries() classmethod ¤

Source code in cirkit/backend/torch/optimization/parameters.py
41
42
43
@classmethod
def entries(cls) -> list[type[TorchParameterNode]]:
    """Return the node types of the pattern, log first and softmax second."""
    node_types = [TorchLogParameter, TorchSoftmaxParameter]
    return node_types

is_output() classmethod ¤

Source code in cirkit/backend/torch/optimization/parameters.py
37
38
39
@classmethod
def is_output(cls) -> bool:
    """Return ``False``: this pattern is not flagged as an output pattern.

    NOTE(review): presumably this means the pattern may match anywhere in a
    parameter graph -- confirm against the base pattern definition.
    """
    return False

ReduceSumOuterProductPattern ¤

Bases: ParameterOptPatternDefn

Source code in cirkit/backend/torch/optimization/parameters.py
46
47
48
49
50
51
52
53
class ReduceSumOuterProductPattern(ParameterOptPatternDefn):
    """Pattern pairing a reduce-sum parameter node with an outer product parameter node."""

    @classmethod
    def entries(cls) -> list[type[TorchParameterNode]]:
        """Return the node types of the pattern, reduce-sum first and outer product second.

        The order matters: ``apply_sum_outer_prod_einsum`` reads ``match.entries[0]``
        as the reduce-sum node and ``match.entries[1]`` as the outer product node.
        """
        return [TorchReduceSumParameter, TorchOuterProductParameter]

    @classmethod
    def is_output(cls) -> bool:
        """Return ``False``: this pattern is not flagged as an output pattern.

        NOTE(review): presumably this means the pattern may match anywhere in a
        parameter graph -- confirm against the base pattern definition.
        """
        return False

entries() classmethod ¤

Source code in cirkit/backend/torch/optimization/parameters.py
51
52
53
@classmethod
def entries(cls) -> list[type[TorchParameterNode]]:
    """Return the node types of the pattern, reduce-sum first and outer product second."""
    node_types = [TorchReduceSumParameter, TorchOuterProductParameter]
    return node_types

is_output() classmethod ¤

Source code in cirkit/backend/torch/optimization/parameters.py
47
48
49
@classmethod
def is_output(cls) -> bool:
    """Return ``False``: this pattern is not flagged as an output pattern.

    NOTE(review): presumably this means the pattern may match anywhere in a
    parameter graph -- confirm against the base pattern definition.
    """
    return False

_emit_outer_reduce_flatten_parameter(in_shape1, in_shape2, outer_dim, reduce_dim) ¤

Source code in cirkit/backend/torch/optimization/parameters.py
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
def _emit_outer_reduce_flatten_parameter(
    in_shape1: tuple[int, ...], in_shape2: tuple[int, ...], outer_dim: int, reduce_dim: int
) -> tuple[TorchEinsumParameter] | tuple[TorchEinsumParameter, TorchFlattenParameter]:
    """Build the einsum (and, if needed, flatten) nodes for a fused outer product + reduction.

    Args:
        in_shape1: Shape of the first operand; its length N fixes the index range.
        in_shape2: Shape of the second operand.
            NOTE(review): assumed to have the same number of dimensions as
            ``in_shape1`` -- the index tuples are built from ``len(in_shape1)`` only.
        outer_dim: Dimension along which the outer (Kronecker) product is computed.
            NOTE(review): assumed to satisfy ``0 <= outer_dim < N`` -- confirm at callers.
        reduce_dim: Dimension that is reduced away.

    Returns:
        ``(einsum,)`` when ``outer_dim == reduce_dim``; otherwise ``(einsum, flatten)``,
        where the flatten node merges the two adjacent dimensions produced by the
        outer product.
    """
    num_dims = len(in_shape1)
    before_outer = tuple(range(outer_dim))
    after_outer = tuple(range(outer_dim + 1, num_dims))

    # First operand uses indices 0 .. N - 1.
    lhs_idx: tuple[int, ...] = tuple(range(num_dims))
    # Second operand shares every index except at outer_dim, where it gets the fresh index N.
    rhs_idx: tuple[int, ...] = before_outer + (num_dims,) + after_outer

    # One group of output indices per dimension; the outer product dimension owns two
    # indices (outer_dim and the fresh index N). Dropping the reduce_dim group gives
    # the einsum output indices.
    idx_groups: list[tuple[int, ...]] = (
        [(i,) for i in before_outer] + [(outer_dim, num_dims)] + [(i,) for i in after_outer]
    )
    del idx_groups[reduce_dim]
    out_idx: tuple[int, ...] = tuple(i for group in idx_groups for i in group)

    einsum = TorchEinsumParameter((in_shape1, in_shape2), einsum=(lhs_idx, rhs_idx, out_idx))

    # Reducing the outer product dimension itself removes both of its indices,
    # so the einsum alone is enough.
    if outer_dim == reduce_dim:
        return (einsum,)

    # Otherwise the einsum output still has the two outer product indices as adjacent
    # dimensions, which must be merged. They sit one position earlier if the reduced
    # dimension came before them.
    flatten_start = outer_dim - 1 if reduce_dim < outer_dim else outer_dim
    flatten = TorchFlattenParameter(
        einsum.shape, start_dim=flatten_start, end_dim=flatten_start + 1
    )
    return einsum, flatten

apply_log_softmax(compiler, match) ¤

Source code in cirkit/backend/torch/optimization/parameters.py
56
57
58
59
60
61
def apply_log_softmax(
    compiler: "TorchCompiler", match: ParameterOptMatch
) -> tuple[TorchLogSoftmaxParameter]:
    """Replace a matched log + softmax pair with a single log-softmax node.

    Args:
        compiler: The torch compiler; not used by this rule.
        match: The pattern match; ``match.entries[1]`` is read as the softmax node.

    Returns:
        A one-element tuple holding a ``TorchLogSoftmaxParameter`` built from the
        softmax node's first input shape and its ``dim``.
    """
    # Only the softmax node is needed: it carries both the input shape and the dimension.
    softmax_node = cast(TorchSoftmaxParameter, match.entries[1])
    return (TorchLogSoftmaxParameter(softmax_node.in_shapes[0], dim=softmax_node.dim),)

apply_sum_outer_prod_einsum(compiler, match) ¤

Source code in cirkit/backend/torch/optimization/parameters.py
 98
 99
100
101
102
103
104
105
106
107
108
def apply_sum_outer_prod_einsum(
    compiler: "TorchCompiler", match: ParameterOptMatch
) -> tuple[TorchEinsumParameter] | tuple[TorchEinsumParameter, TorchFlattenParameter]:
    """Replace a matched reduce-sum of an outer product with an einsum-based equivalent.

    Args:
        compiler: The torch compiler; not used by this rule.
        match: The pattern match; ``match.entries[0]`` is read as the reduce-sum node
            and ``match.entries[1]`` as the outer product node.

    Returns:
        The nodes built by ``_emit_outer_reduce_flatten_parameter``: either
        ``(einsum,)`` or ``(einsum, flatten)``.

    Raises:
        NotImplementedError: If the first input shape of the outer product has more
            than 4 dimensions.
    """
    outer_prod = cast(TorchOuterProductParameter, match.entries[1])
    reduce_sum = cast(TorchReduceSumParameter, match.entries[0])
    in_shape1, in_shape2 = outer_prod.in_shapes
    if len(in_shape1) > 4:
        # Say which case is unsupported rather than raising a bare, message-less error.
        raise NotImplementedError(
            "Fusing a reduce-sum with an outer product is not implemented for inputs "
            f"with more than 4 dimensions, but found input shape {in_shape1}"
        )
    outer_dim = outer_prod.dim
    reduce_dim = reduce_sum.dim
    return _emit_outer_reduce_flatten_parameter(in_shape1, in_shape2, outer_dim, reduce_dim)