Skip to content

parameters

parameters ¤

DEFAULT_PARAMETER_COMPILATION_RULES = {TensorParameter: compile_tensor_parameter, ConstantParameter: compile_constant_parameter, ReferenceParameter: compile_reference_parameter, IndexParameter: compile_index_parameter, SumParameter: compile_sum_parameter, HadamardParameter: compile_hadamard_parameter, KroneckerParameter: compile_kronecker_parameter, OuterProductParameter: compile_outer_product_parameter, OuterSumParameter: compile_outer_sum_parameter, ExpParameter: compile_exp_parameter, LogParameter: compile_log_parameter, SquareParameter: compile_square_parameter, SigmoidParameter: compile_sigmoid_parameter, ScaledSigmoidParameter: compile_scaled_sigmoid_parameter, ClampParameter: compile_clamp_parameter, ConjugateParameter: compile_conjugate_parameter, ReduceSumParameter: compile_reduce_sum_parameter, ReduceProductParameter: compile_reduce_product_parameter, ReduceLSEParameter: compile_reduce_lse_parameter, SoftmaxParameter: compile_softmax_parameter, LogSoftmaxParameter: compile_log_softmax_parameter, MixingWeightParameter: compile_mixing_weight_parameter, GaussianProductMean: compile_gaussian_product_mean, GaussianProductStddev: compile_gaussian_product_stddev, GaussianProductLogPartition: compile_gaussian_product_log_partition, PolynomialProduct: compile_polynomial_product, PolynomialDifferential: compile_polynomial_differential} module-attribute ¤

_retrieve_dtype(dtype) ¤

Source code in cirkit/backend/torch/rules/parameters.py
69
70
71
72
73
74
75
76
77
78
79
def _retrieve_dtype(dtype: DataType) -> torch.dtype:
    """Map a symbolic data type onto the torch dtype used for compiled tensors.

    Integers always map to ``torch.int64``; real and complex types are derived
    from the torch default floating point dtype at the time of the call.

    Args:
        dtype: The symbolic data type to translate.

    Returns:
        The matching torch dtype.

    Raises:
        ValueError: If ``dtype`` is none of INTEGER, REAL or COMPLEX.
    """
    if dtype == DataType.INTEGER:
        return torch.int64
    # Queried on every call, so a changed torch default is picked up.
    default_float_dtype = torch.get_default_dtype()
    if dtype == DataType.REAL:
        return default_float_dtype
    if dtype != DataType.COMPLEX:
        raise ValueError(
            f"Cannot determine the torch.dtype to use, current default: {default_float_dtype}, given dtype: {dtype}"
        )
    # Complex counterpart of the default floating point dtype.
    return default_float_dtype.to_complex()

compile_clamp_parameter(compiler, p) ¤

Source code in cirkit/backend/torch/rules/parameters.py
178
179
180
def compile_clamp_parameter(compiler: "TorchCompiler", p: ClampParameter) -> TorchClampParameter:
    """Build a ``TorchClampParameter`` from a symbolic clamp parameter.

    Args:
        compiler: The torch compiler invoking this rule (not used here).
        p: The symbolic parameter; it must expose exactly one input shape.

    Returns:
        A ``TorchClampParameter`` over that input shape, carrying the
        ``vmin`` and ``vmax`` of ``p``.
    """
    # Single-element unpacking fails loudly unless there is exactly one input.
    [shape] = p.in_shapes
    lower = p.vmin
    upper = p.vmax
    return TorchClampParameter(shape, vmin=lower, vmax=upper)

compile_conjugate_parameter(compiler, p) ¤

Source code in cirkit/backend/torch/rules/parameters.py
183
184
185
186
187
def compile_conjugate_parameter(
    compiler: "TorchCompiler", p: ConjugateParameter
) -> TorchConjugateParameter:
    """Build a ``TorchConjugateParameter`` from a symbolic conjugate parameter.

    The parameter is annotated as ``ConjugateParameter``, which is the type
    this rule is registered for in ``DEFAULT_PARAMETER_COMPILATION_RULES``.

    Args:
        compiler: The torch compiler invoking this rule (not used here).
        p: The symbolic parameter; it must expose exactly one input shape.

    Returns:
        A ``TorchConjugateParameter`` over that input shape.
    """
    # Single-element unpacking fails loudly unless there is exactly one input.
    (in_shape,) = p.in_shapes
    return TorchConjugateParameter(in_shape)

compile_constant_parameter(compiler, p) ¤

Source code in cirkit/backend/torch/rules/parameters.py
92
93
94
95
96
97
98
def compile_constant_parameter(
    compiler: "TorchCompiler", p: ConstantParameter
) -> TorchTensorParameter:
    """Build a non-trainable ``TorchTensorParameter`` from a symbolic constant.

    The compiled node is also registered in the compiler state under ``p``.

    Args:
        compiler: The torch compiler; used to compile the initializer of ``p``
            and to register the compiled node.
        p: The symbolic constant parameter.

    Returns:
        The ``TorchTensorParameter`` created with ``requires_grad=False``.
    """
    init_fn = compiler.compile_initializer(p.initializer)
    node = TorchTensorParameter(*p.shape, requires_grad=False, initializer_=init_fn)
    # Record the symbolic -> compiled association in the compiler state.
    compiler.state.register_compiled_parameter(p, node)
    return node

compile_exp_parameter(compiler, p) ¤

Source code in cirkit/backend/torch/rules/parameters.py
149
150
151
def compile_exp_parameter(compiler: "TorchCompiler", p: ExpParameter) -> TorchExpParameter:
    """Build a ``TorchExpParameter`` from a symbolic exp parameter.

    Args:
        compiler: The torch compiler invoking this rule (not used here).
        p: The symbolic parameter; it must expose exactly one input shape.

    Returns:
        A ``TorchExpParameter`` over that input shape.
    """
    # Single-element unpacking fails loudly unless there is exactly one input.
    [shape] = p.in_shapes
    return TorchExpParameter(shape)

compile_gaussian_product_log_partition(compiler, p) ¤

Source code in cirkit/backend/torch/rules/parameters.py
244
245
246
247
def compile_gaussian_product_log_partition(
    compiler: "TorchCompiler", p: GaussianProductLogPartition
) -> TorchGaussianProductLogPartition:
    """Build a ``TorchGaussianProductLogPartition`` from its symbolic node.

    Args:
        compiler: The torch compiler invoking this rule (not used here).
        p: The symbolic parameter whose input shapes are forwarded.

    Returns:
        A ``TorchGaussianProductLogPartition`` receiving every input shape of
        ``p`` as a positional argument.
    """
    shapes = p.in_shapes
    return TorchGaussianProductLogPartition(*shapes)

compile_gaussian_product_mean(compiler, p) ¤

Source code in cirkit/backend/torch/rules/parameters.py
232
233
234
235
def compile_gaussian_product_mean(
    compiler: "TorchCompiler", p: GaussianProductMean
) -> TorchGaussianProductMean:
    """Build a ``TorchGaussianProductMean`` from its symbolic node.

    Args:
        compiler: The torch compiler invoking this rule (not used here).
        p: The symbolic parameter whose input shapes are forwarded.

    Returns:
        A ``TorchGaussianProductMean`` receiving every input shape of ``p``
        as a positional argument.
    """
    shapes = p.in_shapes
    return TorchGaussianProductMean(*shapes)

compile_gaussian_product_stddev(compiler, p) ¤

Source code in cirkit/backend/torch/rules/parameters.py
238
239
240
241
def compile_gaussian_product_stddev(
    compiler: "TorchCompiler", p: GaussianProductStddev
) -> TorchGaussianProductStddev:
    """Build a ``TorchGaussianProductStddev`` from its symbolic node.

    Args:
        compiler: The torch compiler invoking this rule (not used here).
        p: The symbolic parameter whose input shapes are forwarded.

    Returns:
        A ``TorchGaussianProductStddev`` receiving every input shape of ``p``
        as a positional argument.
    """
    shapes = p.in_shapes
    return TorchGaussianProductStddev(*shapes)

compile_hadamard_parameter(compiler, p) ¤

Source code in cirkit/backend/torch/rules/parameters.py
120
121
122
123
124
def compile_hadamard_parameter(
    compiler: "TorchCompiler", p: HadamardParameter
) -> TorchHadamardParameter:
    """Build a ``TorchHadamardParameter`` from a symbolic Hadamard parameter.

    Args:
        compiler: The torch compiler invoking this rule (not used here).
        p: The symbolic parameter; it must expose exactly two input shapes.

    Returns:
        A ``TorchHadamardParameter`` over the two input shapes, in order.
    """
    # Two-element unpacking fails loudly unless there are exactly two inputs.
    lhs_shape, rhs_shape = p.in_shapes
    return TorchHadamardParameter(lhs_shape, rhs_shape)

compile_index_parameter(compiler, p) ¤

Source code in cirkit/backend/torch/rules/parameters.py
110
111
112
def compile_index_parameter(compiler: "TorchCompiler", p: IndexParameter) -> TorchIndexParameter:
    """Build a ``TorchIndexParameter`` from a symbolic index parameter.

    Args:
        compiler: The torch compiler invoking this rule (not used here).
        p: The symbolic parameter; it must expose exactly one input shape.

    Returns:
        A ``TorchIndexParameter`` over that input shape, with ``p.indices``
        as its indices and ``p.axis`` as its ``dim``.
    """
    # Single-element unpacking fails loudly unless there is exactly one input.
    [shape] = p.in_shapes
    selected = p.indices
    axis = p.axis
    return TorchIndexParameter(shape, indices=selected, dim=axis)

compile_kronecker_parameter(compiler, p) ¤

Source code in cirkit/backend/torch/rules/parameters.py
127
128
129
130
131
def compile_kronecker_parameter(
    compiler: "TorchCompiler", p: KroneckerParameter
) -> TorchKroneckerParameter:
    """Build a ``TorchKroneckerParameter`` from a symbolic Kronecker parameter.

    Args:
        compiler: The torch compiler invoking this rule (not used here).
        p: The symbolic parameter; it must expose exactly two input shapes.

    Returns:
        A ``TorchKroneckerParameter`` over the two input shapes, in order.
    """
    # Two-element unpacking fails loudly unless there are exactly two inputs.
    lhs_shape, rhs_shape = p.in_shapes
    return TorchKroneckerParameter(lhs_shape, rhs_shape)

compile_log_parameter(compiler, p) ¤

Source code in cirkit/backend/torch/rules/parameters.py
154
155
156
def compile_log_parameter(compiler: "TorchCompiler", p: LogParameter) -> TorchLogParameter:
    """Build a ``TorchLogParameter`` from a symbolic log parameter.

    Args:
        compiler: The torch compiler invoking this rule (not used here).
        p: The symbolic parameter; it must expose exactly one input shape.

    Returns:
        A ``TorchLogParameter`` over that input shape.
    """
    # Single-element unpacking fails loudly unless there is exactly one input.
    [shape] = p.in_shapes
    return TorchLogParameter(shape)

compile_log_softmax_parameter(compiler, p) ¤

Source code in cirkit/backend/torch/rules/parameters.py
218
219
220
221
222
def compile_log_softmax_parameter(
    compiler: "TorchCompiler", p: LogSoftmaxParameter
) -> TorchLogSoftmaxParameter:
    """Build a ``TorchLogSoftmaxParameter`` from a symbolic log-softmax parameter.

    The parameter is annotated as ``LogSoftmaxParameter``, which is the type
    this rule is registered for in ``DEFAULT_PARAMETER_COMPILATION_RULES``.

    Args:
        compiler: The torch compiler invoking this rule (not used here).
        p: The symbolic parameter; it must expose exactly one input shape.

    Returns:
        A ``TorchLogSoftmaxParameter`` over that input shape, with ``p.axis``
        as its ``dim``.
    """
    # Single-element unpacking fails loudly unless there is exactly one input.
    (in_shape,) = p.in_shapes
    return TorchLogSoftmaxParameter(in_shape, dim=p.axis)

compile_mixing_weight_parameter(compiler, p) ¤

Source code in cirkit/backend/torch/rules/parameters.py
225
226
227
228
229
def compile_mixing_weight_parameter(
    compiler: "TorchCompiler", p: MixingWeightParameter
) -> TorchMixingWeightParameter:
    """Build a ``TorchMixingWeightParameter`` from its symbolic node.

    Args:
        compiler: The torch compiler invoking this rule (not used here).
        p: The symbolic parameter; it must expose exactly one input shape.

    Returns:
        A ``TorchMixingWeightParameter`` over that input shape.
    """
    # Single-element unpacking fails loudly unless there is exactly one input.
    [shape] = p.in_shapes
    return TorchMixingWeightParameter(shape)

compile_outer_product_parameter(compiler, p) ¤

Source code in cirkit/backend/torch/rules/parameters.py
134
135
136
137
138
139
def compile_outer_product_parameter(
    compiler: "TorchCompiler",
    p: OuterProductParameter,
) -> TorchOuterProductParameter:
    """Build a ``TorchOuterProductParameter`` from its symbolic node.

    Args:
        compiler: The torch compiler invoking this rule (not used here).
        p: The symbolic parameter; it must expose exactly two input shapes.

    Returns:
        A ``TorchOuterProductParameter`` over the two input shapes, in order,
        with ``p.axis`` as its ``dim``.
    """
    # Two-element unpacking fails loudly unless there are exactly two inputs.
    lhs_shape, rhs_shape = p.in_shapes
    axis = p.axis
    return TorchOuterProductParameter(lhs_shape, rhs_shape, dim=axis)

compile_outer_sum_parameter(compiler, p) ¤

Source code in cirkit/backend/torch/rules/parameters.py
142
143
144
145
146
def compile_outer_sum_parameter(
    compiler: "TorchCompiler", p: OuterSumParameter
) -> TorchOuterSumParameter:
    """Build a ``TorchOuterSumParameter`` from its symbolic node.

    Args:
        compiler: The torch compiler invoking this rule (not used here).
        p: The symbolic parameter; it must expose exactly two input shapes.

    Returns:
        A ``TorchOuterSumParameter`` over the two input shapes, in order,
        with ``p.axis`` as its ``dim``.
    """
    # Two-element unpacking fails loudly unless there are exactly two inputs.
    lhs_shape, rhs_shape = p.in_shapes
    axis = p.axis
    return TorchOuterSumParameter(lhs_shape, rhs_shape, dim=axis)

compile_polynomial_differential(compiler, p) ¤

Source code in cirkit/backend/torch/rules/parameters.py
256
257
258
259
def compile_polynomial_differential(
    compiler: "TorchCompiler", p: PolynomialDifferential
) -> TorchPolynomialDifferential:
    """Build a ``TorchPolynomialDifferential`` from its symbolic node.

    Args:
        compiler: The torch compiler invoking this rule (not used here).
        p: The symbolic parameter whose input shapes and ``order`` are forwarded.

    Returns:
        A ``TorchPolynomialDifferential`` receiving every input shape of ``p``
        as a positional argument and ``p.order`` as ``order``.
    """
    shapes = p.in_shapes
    return TorchPolynomialDifferential(*shapes, order=p.order)

compile_polynomial_product(compiler, p) ¤

Source code in cirkit/backend/torch/rules/parameters.py
250
251
252
253
def compile_polynomial_product(
    compiler: "TorchCompiler", p: PolynomialProduct
) -> TorchPolynomialProduct:
    """Build a ``TorchPolynomialProduct`` from its symbolic node.

    Args:
        compiler: The torch compiler invoking this rule (not used here).
        p: The symbolic parameter whose input shapes are forwarded.

    Returns:
        A ``TorchPolynomialProduct`` receiving every input shape of ``p`` as
        a positional argument.
    """
    shapes = p.in_shapes
    return TorchPolynomialProduct(*shapes)

compile_reduce_lse_parameter(compiler, p) ¤

Source code in cirkit/backend/torch/rules/parameters.py
204
205
206
207
208
def compile_reduce_lse_parameter(
    compiler: "TorchCompiler", p: ReduceLSEParameter
) -> TorchReduceLSEParameter:
    """Build a ``TorchReduceLSEParameter`` from a symbolic reduce-LSE parameter.

    The parameter is annotated as ``ReduceLSEParameter``, which is the type
    this rule is registered for in ``DEFAULT_PARAMETER_COMPILATION_RULES``.

    Args:
        compiler: The torch compiler invoking this rule (not used here).
        p: The symbolic parameter; it must expose exactly one input shape.

    Returns:
        A ``TorchReduceLSEParameter`` over that input shape, with ``p.axis``
        as its ``dim``.
    """
    # Single-element unpacking fails loudly unless there is exactly one input.
    (in_shape,) = p.in_shapes
    return TorchReduceLSEParameter(in_shape, dim=p.axis)

compile_reduce_product_parameter(compiler, p) ¤

Source code in cirkit/backend/torch/rules/parameters.py
197
198
199
200
201
def compile_reduce_product_parameter(
    compiler: "TorchCompiler", p: ReduceProductParameter
) -> TorchReduceProductParameter:
    """Build a ``TorchReduceProductParameter`` from its symbolic node.

    Args:
        compiler: The torch compiler invoking this rule (not used here).
        p: The symbolic parameter; it must expose exactly one input shape.

    Returns:
        A ``TorchReduceProductParameter`` over that input shape, with
        ``p.axis`` as its ``dim``.
    """
    # Single-element unpacking fails loudly unless there is exactly one input.
    [shape] = p.in_shapes
    axis = p.axis
    return TorchReduceProductParameter(shape, dim=axis)

compile_reduce_sum_parameter(compiler, p) ¤

Source code in cirkit/backend/torch/rules/parameters.py
190
191
192
193
194
def compile_reduce_sum_parameter(
    compiler: "TorchCompiler", p: ReduceSumParameter
) -> TorchReduceSumParameter:
    """Build a ``TorchReduceSumParameter`` from its symbolic node.

    Args:
        compiler: The torch compiler invoking this rule (not used here).
        p: The symbolic parameter; it must expose exactly one input shape.

    Returns:
        A ``TorchReduceSumParameter`` over that input shape, with ``p.axis``
        as its ``dim``.
    """
    # Single-element unpacking fails loudly unless there is exactly one input.
    [shape] = p.in_shapes
    axis = p.axis
    return TorchReduceSumParameter(shape, dim=axis)

compile_reference_parameter(compiler, p) ¤

Source code in cirkit/backend/torch/rules/parameters.py
101
102
103
104
105
106
107
def compile_reference_parameter(
    compiler: "TorchCompiler", p: ReferenceParameter
) -> TorchPointerParameter:
    """Build a ``TorchPointerParameter`` for a symbolic reference parameter.

    Args:
        compiler: The torch compiler; its state is queried for the compiled
            counterpart of the parameter that ``p`` refers to.
        p: The symbolic reference parameter.

    Returns:
        A ``TorchPointerParameter`` wrapping the retrieved compiled parameter
        together with its fold index.
    """
    lookup = compiler.state.retrieve_compiled_parameter
    # The lookup returns the compiled parameter and its fold index as a pair.
    target, fold_idx = lookup(p.deref())
    return TorchPointerParameter(target, fold_idx=fold_idx)

compile_scaled_sigmoid_parameter(compiler, p) ¤

Source code in cirkit/backend/torch/rules/parameters.py
171
172
173
174
175
def compile_scaled_sigmoid_parameter(
    compiler: "TorchCompiler", p: ScaledSigmoidParameter
) -> TorchScaledSigmoidParameter:
    """Build a ``TorchScaledSigmoidParameter`` from its symbolic node.

    Args:
        compiler: The torch compiler invoking this rule (not used here).
        p: The symbolic parameter; it must expose exactly one input shape.

    Returns:
        A ``TorchScaledSigmoidParameter`` over that input shape, carrying the
        ``vmin`` and ``vmax`` of ``p``.
    """
    # Single-element unpacking fails loudly unless there is exactly one input.
    [shape] = p.in_shapes
    lower = p.vmin
    upper = p.vmax
    return TorchScaledSigmoidParameter(shape, vmin=lower, vmax=upper)

compile_sigmoid_parameter(compiler, p) ¤

Source code in cirkit/backend/torch/rules/parameters.py
164
165
166
167
168
def compile_sigmoid_parameter(
    compiler: "TorchCompiler", p: SigmoidParameter
) -> TorchSigmoidParameter:
    """Build a ``TorchSigmoidParameter`` from a symbolic sigmoid parameter.

    Args:
        compiler: The torch compiler invoking this rule (not used here).
        p: The symbolic parameter; it must expose exactly one input shape.

    Returns:
        A ``TorchSigmoidParameter`` over that input shape.
    """
    # Single-element unpacking fails loudly unless there is exactly one input.
    [shape] = p.in_shapes
    return TorchSigmoidParameter(shape)

compile_softmax_parameter(compiler, p) ¤

Source code in cirkit/backend/torch/rules/parameters.py
211
212
213
214
215
def compile_softmax_parameter(
    compiler: "TorchCompiler", p: SoftmaxParameter
) -> TorchSoftmaxParameter:
    """Build a ``TorchSoftmaxParameter`` from a symbolic softmax parameter.

    Args:
        compiler: The torch compiler invoking this rule (not used here).
        p: The symbolic parameter; it must expose exactly one input shape.

    Returns:
        A ``TorchSoftmaxParameter`` over that input shape, with ``p.axis`` as
        its ``dim``.
    """
    # Single-element unpacking fails loudly unless there is exactly one input.
    [shape] = p.in_shapes
    axis = p.axis
    return TorchSoftmaxParameter(shape, dim=axis)

compile_square_parameter(compiler, p) ¤

Source code in cirkit/backend/torch/rules/parameters.py
159
160
161
def compile_square_parameter(compiler: "TorchCompiler", p: SquareParameter) -> TorchSquareParameter:
    """Build a ``TorchSquareParameter`` from a symbolic square parameter.

    Args:
        compiler: The torch compiler invoking this rule (not used here).
        p: The symbolic parameter; it must expose exactly one input shape.

    Returns:
        A ``TorchSquareParameter`` over that input shape.
    """
    # Single-element unpacking fails loudly unless there is exactly one input.
    [shape] = p.in_shapes
    return TorchSquareParameter(shape)

compile_sum_parameter(compiler, p) ¤

Source code in cirkit/backend/torch/rules/parameters.py
115
116
117
def compile_sum_parameter(compiler: "TorchCompiler", p: SumParameter) -> TorchSumParameter:
    """Build a ``TorchSumParameter`` from a symbolic sum parameter.

    Args:
        compiler: The torch compiler invoking this rule (not used here).
        p: The symbolic parameter; it must expose exactly two input shapes.

    Returns:
        A ``TorchSumParameter`` over the two input shapes, in order.
    """
    # Two-element unpacking fails loudly unless there are exactly two inputs.
    lhs_shape, rhs_shape = p.in_shapes
    return TorchSumParameter(lhs_shape, rhs_shape)

compile_tensor_parameter(compiler, p) ¤

Source code in cirkit/backend/torch/rules/parameters.py
82
83
84
85
86
87
88
89
def compile_tensor_parameter(compiler: "TorchCompiler", p: TensorParameter) -> TorchTensorParameter:
    """Build a ``TorchTensorParameter`` from a symbolic tensor parameter.

    The compiled node is also registered in the compiler state under ``p``.

    Args:
        compiler: The torch compiler; used to compile the initializer of ``p``
            and to register the compiled node.
        p: The symbolic tensor parameter, providing shape, learnability,
            initializer and data type.

    Returns:
        The ``TorchTensorParameter`` whose ``requires_grad`` follows
        ``p.learnable`` and whose dtype is resolved from ``p.dtype``.
    """
    init_fn = compiler.compile_initializer(p.initializer)
    torch_dtype = _retrieve_dtype(p.dtype)
    node = TorchTensorParameter(
        *p.shape,
        requires_grad=p.learnable,
        initializer_=init_fn,
        dtype=torch_dtype,
    )
    # Record the symbolic -> compiled association in the compiler state.
    compiler.state.register_compiled_parameter(p, node)
    return node