Skip to content

layers

layers ¤

DEFAULT_LAYER_FUSE_OPT_RULES = {TuckerPattern: apply_tucker, CandecompPattern: apply_candecomp} module-attribute ¤

DEFAULT_LAYER_SHATTER_OPT_RULES = {DenseKroneckerPattern: apply_dense_tensordot, TensorDotKroneckerPattern: apply_tensordot_tensordot} module-attribute ¤

CandecompPattern ¤

Bases: LayerOptPattern

Source code in cirkit/backend/torch/optimization/layers.py
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
class CandecompPattern(LayerOptPattern):
    """Layer pattern matching a ``TorchSumLayer`` together with a ``TorchHadamardLayer``.

    Registered with ``apply_candecomp`` in ``DEFAULT_LAYER_FUSE_OPT_RULES``; that rule
    builds a single ``TorchCPTLayer`` from the two matched layers.
    """

    @classmethod
    def is_output(cls) -> bool:
        """Return ``False``: this pattern is not flagged as an output pattern."""
        matches_output = False
        return matches_output

    @classmethod
    def entries(cls) -> list[type[TorchLayer]]:
        """Layer types to match, in order: the sum layer first, the Hadamard layer second."""
        return list((TorchSumLayer, TorchHadamardLayer))

    @classmethod
    def ppatterns(cls) -> list[dict[str, ParameterOptPattern]]:
        """One empty parameter-pattern mapping per entry, i.e. no parameter constraints."""
        return [dict() for _entry in cls.entries()]

    @classmethod
    def cpatterns(cls) -> list[dict[str, Any]]:
        """Per-entry constraints, aligned index-by-index with ``entries()``.

        The sum layer gets ``{"arity": 1}`` (presumably checked against the layer's
        ``arity`` attribute; confirm in ``LayerOptPattern``). The Hadamard layer gets
        no constraints.
        """
        sum_constraints = {"arity": 1}
        hadamard_constraints = {}
        return [sum_constraints, hadamard_constraints]

cpatterns() classmethod ¤

Source code in cirkit/backend/torch/optimization/layers.py
56
57
58
@classmethod
def cpatterns(cls) -> list[dict[str, Any]]:
    """Per-entry constraints: ``{"arity": 1}`` for the first entry, none for the second."""
    first_constraints = {"arity": 1}
    second_constraints = {}
    return [first_constraints, second_constraints]

entries() classmethod ¤

Source code in cirkit/backend/torch/optimization/layers.py
48
49
50
@classmethod
def entries(cls) -> list[type[TorchLayer]]:
    """Layer types to match, in order: ``TorchSumLayer`` then ``TorchHadamardLayer``."""
    return list((TorchSumLayer, TorchHadamardLayer))

is_output() classmethod ¤

Source code in cirkit/backend/torch/optimization/layers.py
44
45
46
@classmethod
def is_output(cls) -> bool:
    """Return ``False``: this pattern is not flagged as an output pattern."""
    matches_output = False
    return matches_output

ppatterns() classmethod ¤

Source code in cirkit/backend/torch/optimization/layers.py
52
53
54
@classmethod
def ppatterns(cls) -> list[dict[str, ParameterOptPattern]]:
    """One empty parameter-pattern mapping per entry returned by ``cls.entries()``."""
    return [dict() for _entry in cls.entries()]

DenseKroneckerPattern ¤

Bases: LayerOptPattern

Source code in cirkit/backend/torch/optimization/layers.py
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
class DenseKroneckerPattern(LayerOptPattern):
    """Layer pattern matching one ``TorchSumLayer`` whose ``weight`` matches ``KroneckerOutParameterPattern``.

    Registered with ``apply_dense_tensordot`` in ``DEFAULT_LAYER_SHATTER_OPT_RULES``.
    """

    @classmethod
    def is_output(cls) -> bool:
        """Return ``False``: this pattern is not flagged as an output pattern."""
        matches_output = False
        return matches_output

    @classmethod
    def entries(cls) -> list[type[TorchLayer]]:
        """The single layer type to match: ``TorchSumLayer``."""
        return list((TorchSumLayer,))

    @classmethod
    def ppatterns(cls) -> list[dict[str, ParameterOptPattern]]:
        """Require the sum layer's ``weight`` to match ``KroneckerOutParameterPattern``."""
        weight_pattern = {"weight": KroneckerOutParameterPattern}
        return [weight_pattern]

    @classmethod
    def cpatterns(cls) -> list[dict[str, Any]]:
        """Constraint on the sole entry: ``{"arity": 1}``."""
        sum_constraints = {"arity": 1}
        return [sum_constraints]

cpatterns() classmethod ¤

Source code in cirkit/backend/torch/optimization/layers.py
74
75
76
@classmethod
def cpatterns(cls) -> list[dict[str, Any]]:
    """Constraint on the sole entry: ``{"arity": 1}``."""
    only_constraints = {"arity": 1}
    return [only_constraints]

entries() classmethod ¤

Source code in cirkit/backend/torch/optimization/layers.py
66
67
68
@classmethod
def entries(cls) -> list[type[TorchLayer]]:
    """The single layer type to match: ``TorchSumLayer``."""
    return list((TorchSumLayer,))

is_output() classmethod ¤

Source code in cirkit/backend/torch/optimization/layers.py
62
63
64
@classmethod
def is_output(cls) -> bool:
    """Return ``False``: this pattern is not flagged as an output pattern."""
    matches_output = False
    return matches_output

ppatterns() classmethod ¤

Source code in cirkit/backend/torch/optimization/layers.py
70
71
72
@classmethod
def ppatterns(cls) -> list[dict[str, ParameterOptPattern]]:
    """Require the layer's ``weight`` to match ``KroneckerOutParameterPattern``."""
    weight_pattern = {"weight": KroneckerOutParameterPattern}
    return [weight_pattern]

TensorDotKroneckerPattern ¤

Bases: LayerOptPattern

Source code in cirkit/backend/torch/optimization/layers.py
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
class TensorDotKroneckerPattern(LayerOptPattern):
    """Layer pattern matching one ``TorchTensorDotLayer`` whose ``weight`` matches ``KroneckerOutParameterPattern``.

    Registered with ``apply_tensordot_tensordot`` in ``DEFAULT_LAYER_SHATTER_OPT_RULES``.
    """

    @classmethod
    def is_output(cls) -> bool:
        """Return ``False``: this pattern is not flagged as an output pattern."""
        matches_output = False
        return matches_output

    @classmethod
    def entries(cls) -> list[type[TorchLayer]]:
        """The single layer type to match: ``TorchTensorDotLayer``."""
        return list((TorchTensorDotLayer,))

    @classmethod
    def ppatterns(cls) -> list[dict[str, ParameterOptPattern]]:
        """Require the tensor-dot layer's ``weight`` to match ``KroneckerOutParameterPattern``."""
        weight_pattern = {"weight": KroneckerOutParameterPattern}
        return [weight_pattern]

    @classmethod
    def cpatterns(cls) -> list[dict[str, Any]]:
        """No structural constraints on the sole entry."""
        return [dict()]

cpatterns() classmethod ¤

Source code in cirkit/backend/torch/optimization/layers.py
92
93
94
@classmethod
def cpatterns(cls) -> list[dict[str, Any]]:
    """No structural constraints on the sole entry."""
    return [dict()]

entries() classmethod ¤

Source code in cirkit/backend/torch/optimization/layers.py
84
85
86
@classmethod
def entries(cls) -> list[type[TorchLayer]]:
    """The single layer type to match: ``TorchTensorDotLayer``."""
    return list((TorchTensorDotLayer,))

is_output() classmethod ¤

Source code in cirkit/backend/torch/optimization/layers.py
80
81
82
@classmethod
def is_output(cls) -> bool:
    """Return ``False``: this pattern is not flagged as an output pattern."""
    matches_output = False
    return matches_output

ppatterns() classmethod ¤

Source code in cirkit/backend/torch/optimization/layers.py
88
89
90
@classmethod
def ppatterns(cls) -> list[dict[str, ParameterOptPattern]]:
    """Require the layer's ``weight`` to match ``KroneckerOutParameterPattern``."""
    weight_pattern = {"weight": KroneckerOutParameterPattern}
    return [weight_pattern]

TuckerPattern ¤

Bases: LayerOptPattern

Source code in cirkit/backend/torch/optimization/layers.py
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
class TuckerPattern(LayerOptPattern):
    """Layer pattern matching a ``TorchSumLayer`` together with a ``TorchKroneckerLayer``.

    Registered with ``apply_tucker`` in ``DEFAULT_LAYER_FUSE_OPT_RULES``; that rule
    builds a single ``TorchTuckerLayer`` from the two matched layers.
    """

    @classmethod
    def is_output(cls) -> bool:
        """Return ``False``: this pattern is not flagged as an output pattern."""
        matches_output = False
        return matches_output

    @classmethod
    def entries(cls) -> list[type[TorchLayer]]:
        """Layer types to match, in order: the sum layer first, the Kronecker layer second."""
        return list((TorchSumLayer, TorchKroneckerLayer))

    @classmethod
    def ppatterns(cls) -> list[dict[str, ParameterOptPattern]]:
        """One empty parameter-pattern mapping per entry, i.e. no parameter constraints."""
        return [dict() for _entry in cls.entries()]

    @classmethod
    def cpatterns(cls) -> list[dict[str, Any]]:
        """Per-entry constraints, aligned index-by-index with ``entries()``.

        The sum layer gets ``{"arity": 1}`` (presumably checked against the layer's
        ``arity`` attribute; confirm in ``LayerOptPattern``). The Kronecker layer gets
        no constraints.
        """
        sum_constraints = {"arity": 1}
        kronecker_constraints = {}
        return [sum_constraints, kronecker_constraints]

cpatterns() classmethod ¤

Source code in cirkit/backend/torch/optimization/layers.py
38
39
40
@classmethod
def cpatterns(cls) -> list[dict[str, Any]]:
    """Per-entry constraints: ``{"arity": 1}`` for the first entry, none for the second."""
    first_constraints = {"arity": 1}
    second_constraints = {}
    return [first_constraints, second_constraints]

entries() classmethod ¤

Source code in cirkit/backend/torch/optimization/layers.py
30
31
32
@classmethod
def entries(cls) -> list[type[TorchLayer]]:
    """Layer types to match, in order: ``TorchSumLayer`` then ``TorchKroneckerLayer``."""
    return list((TorchSumLayer, TorchKroneckerLayer))

is_output() classmethod ¤

Source code in cirkit/backend/torch/optimization/layers.py
26
27
28
@classmethod
def is_output(cls) -> bool:
    """Return ``False``: this pattern is not flagged as an output pattern."""
    matches_output = False
    return matches_output

ppatterns() classmethod ¤

Source code in cirkit/backend/torch/optimization/layers.py
34
35
36
@classmethod
def ppatterns(cls) -> list[dict[str, ParameterOptPattern]]:
    """One empty parameter-pattern mapping per entry returned by ``cls.entries()``."""
    return [dict() for _entry in cls.entries()]

_apply_tensordot_rule(compiler, num_input_units, num_output_units, weight, kronecker) ¤

Source code in cirkit/backend/torch/optimization/layers.py
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
def _apply_tensordot_rule(
    compiler: "TorchCompiler",
    num_input_units: int,
    num_output_units: int,
    weight: TorchParameter,
    kronecker: TorchKroneckerParameter,
) -> tuple[TorchTensorDotLayer, TorchTensorDotLayer]:
    """Split a layer whose weight graph contains a Kronecker node into two tensor-dot layers.

    Args:
        compiler: Compiler whose ``semiring`` is passed to both new layers.
        num_input_units: Input units of the layer being replaced. Used as the input
            size of the first tensor-dot layer.
        num_output_units: Output units of the layer being replaced. Used as the output
            size of the second tensor-dot layer.
        weight: Parameter computational graph of the layer being replaced.
        kronecker: The Kronecker node inside ``weight``. The sub-graphs rooted at its
            two inputs become the weights of the two new layers.

    Returns:
        A pair ``(first, second)``:
        - ``first`` is built with ``(num_input_units, inner)`` and the first Kronecker
          input's sub-graph.
        - ``second`` is built with ``(inner, num_output_units)`` and the second input's
          sub-graph.
    """
    # The two operands of the Kronecker node root the new, smaller parameter graphs.
    left_root, right_root = weight.node_inputs(kronecker)
    first_weight = weight.subgraph(left_root)
    second_weight = weight.subgraph(right_root)

    # Width of the intermediate representation between the two tensor-dot layers:
    # rows of the first factor times how many column-blocks of it fit in the input.
    # Presumably the first factor is shaped (out, in); TODO confirm against TorchTensorDotLayer.
    factor_rows = first_weight.shape[0]
    factor_cols = first_weight.shape[1]
    num_inner_units = factor_rows * (num_input_units // factor_cols)

    first = TorchTensorDotLayer(
        num_input_units, num_inner_units, weight=first_weight, semiring=compiler.semiring
    )
    second = TorchTensorDotLayer(
        num_inner_units, num_output_units, weight=second_weight, semiring=compiler.semiring
    )
    return first, second

apply_candecomp(compiler, match) ¤

Source code in cirkit/backend/torch/optimization/layers.py
110
111
112
113
114
115
116
117
118
119
120
def apply_candecomp(compiler: "TorchCompiler", match: LayerOptMatch) -> tuple[TorchCPTLayer]:
    """Build a ``TorchCPTLayer`` from a ``CandecompPattern`` match.

    Args:
        compiler: Compiler whose ``semiring`` is given to the new layer.
        match: Match whose ``entries`` are ``(TorchSumLayer, TorchHadamardLayer)``.

    Returns:
        A one-element tuple holding the new CP layer. Its input units and arity come
        from the Hadamard layer; its output units and ``weight`` come from the sum layer.
    """
    sum_layer = cast(TorchSumLayer, match.entries[0])
    product_layer = cast(TorchHadamardLayer, match.entries[1])
    fused_layer = TorchCPTLayer(
        product_layer.num_input_units,
        sum_layer.num_output_units,
        product_layer.arity,
        weight=sum_layer.weight,
        semiring=compiler.semiring,
    )
    return tuple([fused_layer])

apply_dense_tensordot(compiler, match) ¤

Source code in cirkit/backend/torch/optimization/layers.py
153
154
155
156
157
158
159
160
161
def apply_dense_tensordot(
    compiler: "TorchCompiler", match: LayerOptMatch
) -> tuple[TorchTensorDotLayer, TorchTensorDotLayer]:
    """Replace a ``DenseKroneckerPattern`` match with two chained tensor-dot layers.

    Args:
        compiler: Compiler forwarded to ``_apply_tensordot_rule``.
        match: Match whose first entry is a ``TorchSumLayer``. Its first parameter
            match for ``"weight"`` starts with the Kronecker parameter node.

    Returns:
        The ``(first, second)`` tensor-dot layers produced by ``_apply_tensordot_rule``.
    """
    sum_layer = cast(TorchSumLayer, match.entries[0])
    first_weight_match = match.pentries[0]["weight"][0]
    kronecker_node = cast(TorchKroneckerParameter, first_weight_match.entries[0])
    return _apply_tensordot_rule(
        compiler,
        sum_layer.num_input_units,
        sum_layer.num_output_units,
        sum_layer.weight,
        kronecker_node,
    )

apply_tensordot_tensordot(compiler, match) ¤

Source code in cirkit/backend/torch/optimization/layers.py
164
165
166
167
168
169
170
171
172
def apply_tensordot_tensordot(
    compiler: "TorchCompiler", match: LayerOptMatch
) -> tuple[TorchTensorDotLayer, TorchTensorDotLayer]:
    """Replace a ``TensorDotKroneckerPattern`` match with two chained tensor-dot layers.

    Args:
        compiler: Compiler forwarded to ``_apply_tensordot_rule``.
        match: Match whose first entry is a ``TorchTensorDotLayer``. Its first parameter
            match for ``"weight"`` starts with the Kronecker parameter node.

    Returns:
        The ``(first, second)`` tensor-dot layers produced by ``_apply_tensordot_rule``.
    """
    tdot_layer = cast(TorchTensorDotLayer, match.entries[0])
    first_weight_match = match.pentries[0]["weight"][0]
    kronecker_node = cast(TorchKroneckerParameter, first_weight_match.entries[0])
    return _apply_tensordot_rule(
        compiler,
        tdot_layer.num_input_units,
        tdot_layer.num_output_units,
        tdot_layer.weight,
        kronecker_node,
    )

apply_tucker(compiler, match) ¤

Source code in cirkit/backend/torch/optimization/layers.py
 97
 98
 99
100
101
102
103
104
105
106
107
def apply_tucker(compiler: "TorchCompiler", match: LayerOptMatch) -> tuple[TorchTuckerLayer]:
    """Build a ``TorchTuckerLayer`` from a ``TuckerPattern`` match.

    Args:
        compiler: Compiler whose ``semiring`` is given to the new layer.
        match: Match whose ``entries`` are ``(TorchSumLayer, TorchKroneckerLayer)``.

    Returns:
        A one-element tuple holding the new Tucker layer. Its input units and arity come
        from the Kronecker layer; its output units and ``weight`` come from the sum layer.
    """
    sum_layer = cast(TorchSumLayer, match.entries[0])
    product_layer = cast(TorchKroneckerLayer, match.entries[1])
    fused_layer = TorchTuckerLayer(
        product_layer.num_input_units,
        sum_layer.num_output_units,
        product_layer.arity,
        weight=sum_layer.weight,
        semiring=compiler.semiring,
    )
    return tuple([fused_layer])