Skip to content

parameters

parameters ¤

DEFAULT_PARAMETER_OPT_RULES = {LogSoftmaxPattern: apply_log_softmax, ReduceSumOuterProductPattern: apply_sum_outer_prod_einsum} module-attribute ¤

KroneckerOutParameterPattern ¤

Bases: ParameterOptPatternDefn

This pattern detects Kronecker parameters that are outputs of the graph.

It is used when performing the tensor dot trick on sum or dot layers whose weights come from such a node.

See DenseKroneckerPattern.

Source code in cirkit/backend/torch/optimization/parameters.py
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
class KroneckerOutParameterPattern(ParameterOptPatternDefn):
    """Pattern matching a Kronecker parameter node that is an output of the graph.

    The pattern is relied upon by the tensor dot trick, which is applied to sum or dot
    layers whose weights come from such a node.

    See [DenseKroneckerPattern][cirkit.backend.torch.optimization.layers.DenseKroneckerPattern].
    """

    @classmethod
    def entries(cls) -> list[type[TorchParameterNode]]:
        """Return the node types forming the pattern: a single Kronecker parameter node."""
        return [TorchKroneckerParameter]

    @classmethod
    def is_output(cls) -> bool:
        """Return True, as this pattern targets nodes that are outputs of the graph."""
        return True

entries() classmethod ¤

Source code in cirkit/backend/torch/optimization/parameters.py
40
41
42
@classmethod
def entries(cls) -> list[type[TorchParameterNode]]:
    """Return the node types forming the pattern: a single Kronecker parameter node."""
    return [TorchKroneckerParameter]

is_output() classmethod ¤

Source code in cirkit/backend/torch/optimization/parameters.py
36
37
38
@classmethod
def is_output(cls) -> bool:
    """Return True, as this pattern targets nodes that are outputs of the graph."""
    return True

LogSoftmaxPattern ¤

Bases: ParameterOptPatternDefn

Detect a sequence of Softmax node -> Log node

Source code in cirkit/backend/torch/optimization/parameters.py
45
46
47
48
49
50
51
52
53
54
class LogSoftmaxPattern(ParameterOptPatternDefn):
    """Pattern matching a softmax node whose result is fed to a log node.

    The sequence detected is Softmax node -> Log node.
    """

    @classmethod
    def entries(cls) -> list[type[TorchParameterNode]]:
        """Return the node types forming the pattern, listed as log first, softmax second."""
        return [TorchLogParameter, TorchSoftmaxParameter]

    @classmethod
    def is_output(cls) -> bool:
        """Return False: this pattern is not declared as an output pattern."""
        return False

entries() classmethod ¤

Source code in cirkit/backend/torch/optimization/parameters.py
52
53
54
@classmethod
def entries(cls) -> list[type[TorchParameterNode]]:
    """Return the node types forming the pattern, listed as log first, softmax second."""
    return [TorchLogParameter, TorchSoftmaxParameter]

is_output() classmethod ¤

Source code in cirkit/backend/torch/optimization/parameters.py
48
49
50
@classmethod
def is_output(cls) -> bool:
    """Return False: this pattern is not declared as an output pattern."""
    return False

ReduceSumOuterProductPattern ¤

Bases: ParameterOptPatternDefn

Source code in cirkit/backend/torch/optimization/parameters.py
57
58
59
60
61
62
63
64
class ReduceSumOuterProductPattern(ParameterOptPatternDefn):
    """Pattern matching an outer product node whose result is fed to a reduce-sum node.

    The sequence detected is OuterProduct node -> ReduceSum node.
    """

    @classmethod
    def entries(cls) -> list[type[TorchParameterNode]]:
        """Return the node types forming the pattern: reduce-sum first, outer product second."""
        return [TorchReduceSumParameter, TorchOuterProductParameter]

    @classmethod
    def is_output(cls) -> bool:
        """Return False: this pattern is not declared as an output pattern."""
        return False

entries() classmethod ¤

Source code in cirkit/backend/torch/optimization/parameters.py
62
63
64
@classmethod
def entries(cls) -> list[type[TorchParameterNode]]:
    """Return the node types forming the pattern: reduce-sum first, outer product second."""
    return [TorchReduceSumParameter, TorchOuterProductParameter]

is_output() classmethod ¤

Source code in cirkit/backend/torch/optimization/parameters.py
58
59
60
@classmethod
def is_output(cls) -> bool:
    """Return False: this pattern is not declared as an output pattern."""
    return False

apply_log_softmax(compiler, match) ¤

Fuse the log and softmax in one logsoftmax node.

Parameters:

Name Type Description Default
compiler TorchCompiler

The current compiler.

required
match ParameterOptMatch

The match object containing the modules to optimize.

required

Returns:

Type Description
tuple[TorchLogSoftmaxParameter]

A one-element tuple containing the corresponding logsoftmax node.

Source code in cirkit/backend/torch/optimization/parameters.py
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
def apply_log_softmax(  # pylint: disable=unused-argument
    compiler: "TorchCompiler", match: ParameterOptMatch
) -> tuple[TorchLogSoftmaxParameter]:
    """Replace a matched softmax-then-log pair with a single logsoftmax node.

    Args:
        compiler (TorchCompiler): The current compiler. It is not used by this rule.
        match (ParameterOptMatch): The match object containing the modules to optimize.
            Its second entry is read as the softmax node.

    Returns:
        tuple[TorchLogSoftmaxParameter]: A one-element tuple holding the fused
            logsoftmax node.
    """
    # The fused node reuses the first input shape and the dimension of the softmax node.
    softmax_node = cast(TorchSoftmaxParameter, match.entries[1])
    in_shape = softmax_node.in_shapes[0]
    fused = TorchLogSoftmaxParameter(in_shape, dim=softmax_node.dim)
    return (fused,)

apply_sum_outer_prod_einsum(compiler, match) ¤

Transform the sum on an outer product into a single einsum to reduce memory usage.

Parameters:

Name Type Description Default
compiler TorchCompiler

Current torch compiler.

required
match ParameterOptMatch

Match containing the module to fuse.

required

Returns:

Type Description
tuple[TorchEinsumParameter] | tuple[TorchEinsumParameter, TorchFlattenParameter]

The einsum (optionally followed by a flatten node) corresponding to the matched modules.

Raises:

Type Description
NotImplementedError

The function is not implemented for more than 4 dimensions.

Source code in cirkit/backend/torch/optimization/parameters.py
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
def apply_sum_outer_prod_einsum(  # pylint: disable=unused-argument
    compiler: "TorchCompiler", match: ParameterOptMatch
) -> tuple[TorchEinsumParameter] | tuple[TorchEinsumParameter, TorchFlattenParameter]:
    """Transform the sum on an outer product into a single einsum to reduce memory usage.

    Args:
        compiler (TorchCompiler): Current torch compiler.
        match (ParameterOptMatch): Match containing the module to fuse.

    Returns:
        tuple[TorchEinsumParameter] | tuple[TorchEinsumParameter,TorchFlattenParameter]:
            returns the einsum corresponding to the matched modules.

    Raises:
        NotImplementedError: The function is not implemented for more than 4 dimensions.
    """
    outer_prod = cast(TorchOuterProductParameter, match.entries[1])
    reduce_sum = cast(TorchReduceSumParameter, match.entries[0])
    in_shape1, in_shape2 = outer_prod.in_shapes
    if len(in_shape1) > 4:
        raise NotImplementedError()
    outer_dim = outer_prod.dim
    reduce_dim = reduce_sum.dim
    return _emit_outer_reduce_flatten_parameter(in_shape1, in_shape2, outer_dim, reduce_dim)