Skip to content

parameters

parameters ¤

DEFAULT_PARAMETER_COMPILATION_RULES = {TensorParameter: compile_tensor_parameter, ConstantParameter: compile_constant_parameter, ReferenceParameter: compile_reference_parameter, IndexParameter: compile_index_parameter, SumParameter: compile_sum_parameter, HadamardParameter: compile_hadamard_parameter, KroneckerParameter: compile_kronecker_parameter, OuterProductParameter: compile_outer_product_parameter, OuterSumParameter: compile_outer_sum_parameter, ExpParameter: compile_exp_parameter, LogParameter: compile_log_parameter, SquareParameter: compile_square_parameter, SigmoidParameter: compile_sigmoid_parameter, ScaledSigmoidParameter: compile_scaled_sigmoid_parameter, ClampParameter: compile_clamp_parameter, SoftplusParameter: compile_softplus_parameter, ConjugateParameter: compile_conjugate_parameter, ReduceSumParameter: compile_reduce_sum_parameter, ReduceProductParameter: compile_reduce_product_parameter, ReduceLSEParameter: compile_reduce_lse_parameter, SoftmaxParameter: compile_softmax_parameter, LogSoftmaxParameter: compile_log_softmax_parameter, MixingWeightParameter: compile_mixing_weight_parameter, GaussianProductMean: compile_gaussian_product_mean, GaussianProductStddev: compile_gaussian_product_stddev, GaussianProductLogPartition: compile_gaussian_product_log_partition, PolynomialProduct: compile_polynomial_product, PolynomialDifferential: compile_polynomial_differential} module-attribute ¤

compile_clamp_parameter(compiler, p) ¤

Source code in cirkit/backend/torch/rules/parameters.py
188
189
190
def compile_clamp_parameter(compiler: "TorchCompiler", p: ClampParameter) -> TorchClampParameter:
    """Compile a symbolic clamp node into a ``TorchClampParameter``.

    Args:
        compiler: The torch compiler (not used by this rule).
        p: The symbolic clamp node; it must have exactly one input shape.

    Returns:
        A ``TorchClampParameter`` over the node's single input shape, carrying
        the node's ``vmin`` and ``vmax`` bounds.
    """
    [input_shape] = p.in_shapes
    lower_bound, upper_bound = p.vmin, p.vmax
    return TorchClampParameter(input_shape, vmin=lower_bound, vmax=upper_bound)

compile_conjugate_parameter(compiler, p) ¤

Source code in cirkit/backend/torch/rules/parameters.py
200
201
202
203
204
def compile_conjugate_parameter(
    compiler: "TorchCompiler", p: ConjugateParameter
) -> TorchConjugateParameter:
    """Compile a symbolic conjugate node into a ``TorchConjugateParameter``.

    Args:
        compiler: The torch compiler (not used by this rule).
        p: The symbolic conjugate node; it must have exactly one input shape.
            (Previously mis-annotated as ``ClampParameter``; this rule is the
            one registered for ``ConjugateParameter``.)

    Returns:
        A ``TorchConjugateParameter`` over the node's single input shape.
    """
    (in_shape,) = p.in_shapes
    return TorchConjugateParameter(in_shape)

compile_constant_parameter(compiler, p) ¤

Source code in cirkit/backend/torch/rules/parameters.py
 99
100
101
102
103
104
105
106
107
108
def compile_constant_parameter(
    compiler: "TorchCompiler", p: ConstantParameter
) -> TorchTensorParameter:
    """Compile a symbolic constant into a non-trainable ``TorchTensorParameter``.

    The node's initializer is compiled through ``compiler`` and its dtype is
    resolved via ``_retrieve_dtype``. The resulting tensor parameter is created
    with ``requires_grad=False`` and registered in ``compiler.state``.

    Args:
        compiler: The torch compiler used to compile the initializer and
            record the compiled parameter.
        p: The symbolic constant parameter node.

    Returns:
        The newly created (and registered) ``TorchTensorParameter``.
    """
    compiled_init = compiler.compile_initializer(p.initializer)
    torch_dtype = _retrieve_dtype(p.dtype)
    constant_tensor = TorchTensorParameter(
        *p.shape,
        requires_grad=False,
        initializer_=compiled_init,
        dtype=torch_dtype,
    )
    # Keep the symbolic -> compiled mapping in the compiler state; presumably
    # this is what retrieve_compiled_parameter looks up for references.
    compiler.state.register_compiled_parameter(p, constant_tensor)
    return constant_tensor

compile_exp_parameter(compiler, p) ¤

Source code in cirkit/backend/torch/rules/parameters.py
159
160
161
def compile_exp_parameter(compiler: "TorchCompiler", p: ExpParameter) -> TorchExpParameter:
    """Compile a symbolic exponential node into a ``TorchExpParameter``.

    Args:
        compiler: The torch compiler (not used by this rule).
        p: The symbolic exp node; it must have exactly one input shape.

    Returns:
        A ``TorchExpParameter`` over the node's single input shape.
    """
    [input_shape] = p.in_shapes
    return TorchExpParameter(input_shape)

compile_gaussian_product_log_partition(compiler, p) ¤

Source code in cirkit/backend/torch/rules/parameters.py
261
262
263
264
def compile_gaussian_product_log_partition(
    compiler: "TorchCompiler", p: GaussianProductLogPartition
) -> TorchGaussianProductLogPartition:
    """Compile a symbolic Gaussian-product log-partition node.

    Args:
        compiler: The torch compiler (not used by this rule).
        p: The symbolic node; all of its input shapes are forwarded
            positionally, in order.

    Returns:
        A ``TorchGaussianProductLogPartition`` built from ``p.in_shapes``.
    """
    input_shapes = p.in_shapes
    return TorchGaussianProductLogPartition(*input_shapes)

compile_gaussian_product_mean(compiler, p) ¤

Source code in cirkit/backend/torch/rules/parameters.py
249
250
251
252
def compile_gaussian_product_mean(
    compiler: "TorchCompiler", p: GaussianProductMean
) -> TorchGaussianProductMean:
    """Compile a symbolic Gaussian-product mean node.

    Args:
        compiler: The torch compiler (not used by this rule).
        p: The symbolic node; all of its input shapes are forwarded
            positionally, in order.

    Returns:
        A ``TorchGaussianProductMean`` built from ``p.in_shapes``.
    """
    input_shapes = p.in_shapes
    return TorchGaussianProductMean(*input_shapes)

compile_gaussian_product_stddev(compiler, p) ¤

Source code in cirkit/backend/torch/rules/parameters.py
255
256
257
258
def compile_gaussian_product_stddev(
    compiler: "TorchCompiler", p: GaussianProductStddev
) -> TorchGaussianProductStddev:
    """Compile a symbolic Gaussian-product standard-deviation node.

    Args:
        compiler: The torch compiler (not used by this rule).
        p: The symbolic node; all of its input shapes are forwarded
            positionally, in order.

    Returns:
        A ``TorchGaussianProductStddev`` built from ``p.in_shapes``.
    """
    input_shapes = p.in_shapes
    return TorchGaussianProductStddev(*input_shapes)

compile_hadamard_parameter(compiler, p) ¤

Source code in cirkit/backend/torch/rules/parameters.py
130
131
132
133
134
def compile_hadamard_parameter(
    compiler: "TorchCompiler", p: HadamardParameter
) -> TorchHadamardParameter:
    """Compile a symbolic Hadamard (element-wise product) node.

    Args:
        compiler: The torch compiler (not used by this rule).
        p: The symbolic node; it must have exactly two input shapes.

    Returns:
        A ``TorchHadamardParameter`` over the two input shapes, in order.
    """
    lhs_shape, rhs_shape = p.in_shapes
    return TorchHadamardParameter(lhs_shape, rhs_shape)

compile_index_parameter(compiler, p) ¤

Source code in cirkit/backend/torch/rules/parameters.py
120
121
122
def compile_index_parameter(compiler: "TorchCompiler", p: IndexParameter) -> TorchIndexParameter:
    """Compile a symbolic indexing node into a ``TorchIndexParameter``.

    Args:
        compiler: The torch compiler (not used by this rule).
        p: The symbolic index node; it must have exactly one input shape.
            Its ``indices`` are forwarded unchanged and its ``axis`` becomes
            the torch ``dim`` argument.

    Returns:
        A ``TorchIndexParameter`` over the node's single input shape.
    """
    [input_shape] = p.in_shapes
    selected, along = p.indices, p.axis
    return TorchIndexParameter(input_shape, indices=selected, dim=along)

compile_kronecker_parameter(compiler, p) ¤

Source code in cirkit/backend/torch/rules/parameters.py
137
138
139
140
141
def compile_kronecker_parameter(
    compiler: "TorchCompiler", p: KroneckerParameter
) -> TorchKroneckerParameter:
    """Compile a symbolic Kronecker-product node.

    Args:
        compiler: The torch compiler (not used by this rule).
        p: The symbolic node; it must have exactly two input shapes.

    Returns:
        A ``TorchKroneckerParameter`` over the two input shapes, in order.
    """
    lhs_shape, rhs_shape = p.in_shapes
    return TorchKroneckerParameter(lhs_shape, rhs_shape)

compile_log_parameter(compiler, p) ¤

Source code in cirkit/backend/torch/rules/parameters.py
164
165
166
def compile_log_parameter(compiler: "TorchCompiler", p: LogParameter) -> TorchLogParameter:
    """Compile a symbolic logarithm node into a ``TorchLogParameter``.

    Args:
        compiler: The torch compiler (not used by this rule).
        p: The symbolic log node; it must have exactly one input shape.

    Returns:
        A ``TorchLogParameter`` over the node's single input shape.
    """
    [input_shape] = p.in_shapes
    return TorchLogParameter(input_shape)

compile_log_softmax_parameter(compiler, p) ¤

Source code in cirkit/backend/torch/rules/parameters.py
235
236
237
238
239
def compile_log_softmax_parameter(
    compiler: "TorchCompiler", p: LogSoftmaxParameter
) -> TorchLogSoftmaxParameter:
    """Compile a symbolic log-softmax node into a ``TorchLogSoftmaxParameter``.

    Args:
        compiler: The torch compiler (not used by this rule).
        p: The symbolic log-softmax node; it must have exactly one input
            shape. Its ``axis`` becomes the torch ``dim`` argument.
            (Previously mis-annotated as ``SoftmaxParameter``; this rule is
            the one registered for ``LogSoftmaxParameter``.)

    Returns:
        A ``TorchLogSoftmaxParameter`` over the node's single input shape.
    """
    (in_shape,) = p.in_shapes
    return TorchLogSoftmaxParameter(in_shape, dim=p.axis)

compile_mixing_weight_parameter(compiler, p) ¤

Source code in cirkit/backend/torch/rules/parameters.py
242
243
244
245
246
def compile_mixing_weight_parameter(
    compiler: "TorchCompiler", p: MixingWeightParameter
) -> TorchMixingWeightParameter:
    """Compile a symbolic mixing-weight node.

    Args:
        compiler: The torch compiler (not used by this rule).
        p: The symbolic node; it must have exactly one input shape.

    Returns:
        A ``TorchMixingWeightParameter`` over the node's single input shape.
    """
    [input_shape] = p.in_shapes
    return TorchMixingWeightParameter(input_shape)

compile_outer_product_parameter(compiler, p) ¤

Source code in cirkit/backend/torch/rules/parameters.py
144
145
146
147
148
149
def compile_outer_product_parameter(
    compiler: "TorchCompiler",
    p: OuterProductParameter,
) -> TorchOuterProductParameter:
    """Compile a symbolic outer-product node.

    Args:
        compiler: The torch compiler (not used by this rule).
        p: The symbolic node; it must have exactly two input shapes. Its
            ``axis`` becomes the torch ``dim`` argument.

    Returns:
        A ``TorchOuterProductParameter`` over the two input shapes, in order.
    """
    lhs_shape, rhs_shape = p.in_shapes
    along = p.axis
    return TorchOuterProductParameter(lhs_shape, rhs_shape, dim=along)

compile_outer_sum_parameter(compiler, p) ¤

Source code in cirkit/backend/torch/rules/parameters.py
152
153
154
155
156
def compile_outer_sum_parameter(
    compiler: "TorchCompiler", p: OuterSumParameter
) -> TorchOuterSumParameter:
    """Compile a symbolic outer-sum node.

    Args:
        compiler: The torch compiler (not used by this rule).
        p: The symbolic node; it must have exactly two input shapes. Its
            ``axis`` becomes the torch ``dim`` argument.

    Returns:
        A ``TorchOuterSumParameter`` over the two input shapes, in order.
    """
    lhs_shape, rhs_shape = p.in_shapes
    along = p.axis
    return TorchOuterSumParameter(lhs_shape, rhs_shape, dim=along)

compile_polynomial_differential(compiler, p) ¤

Source code in cirkit/backend/torch/rules/parameters.py
273
274
275
276
def compile_polynomial_differential(
    compiler: "TorchCompiler", p: PolynomialDifferential
) -> TorchPolynomialDifferential:
    """Compile a symbolic polynomial-differential node.

    Args:
        compiler: The torch compiler (not used by this rule).
        p: The symbolic node; all of its input shapes are forwarded
            positionally, and its ``order`` is passed through by keyword.

    Returns:
        A ``TorchPolynomialDifferential`` built from ``p.in_shapes``.
    """
    input_shapes = p.in_shapes
    diff_order = p.order
    return TorchPolynomialDifferential(*input_shapes, order=diff_order)

compile_polynomial_product(compiler, p) ¤

Source code in cirkit/backend/torch/rules/parameters.py
267
268
269
270
def compile_polynomial_product(
    compiler: "TorchCompiler", p: PolynomialProduct
) -> TorchPolynomialProduct:
    """Compile a symbolic polynomial-product node.

    Args:
        compiler: The torch compiler (not used by this rule).
        p: The symbolic node; all of its input shapes are forwarded
            positionally, in order.

    Returns:
        A ``TorchPolynomialProduct`` built from ``p.in_shapes``.
    """
    input_shapes = p.in_shapes
    return TorchPolynomialProduct(*input_shapes)

compile_reduce_lse_parameter(compiler, p) ¤

Source code in cirkit/backend/torch/rules/parameters.py
221
222
223
224
225
def compile_reduce_lse_parameter(
    compiler: "TorchCompiler", p: ReduceLSEParameter
) -> TorchReduceLSEParameter:
    """Compile a symbolic log-sum-exp reduction node.

    Args:
        compiler: The torch compiler (not used by this rule).
        p: The symbolic reduce-LSE node; it must have exactly one input
            shape. Its ``axis`` becomes the torch ``dim`` argument.
            (Previously mis-annotated as ``ReduceSumParameter``; this rule is
            the one registered for ``ReduceLSEParameter``.)

    Returns:
        A ``TorchReduceLSEParameter`` over the node's single input shape.
    """
    (in_shape,) = p.in_shapes
    return TorchReduceLSEParameter(in_shape, dim=p.axis)

compile_reduce_product_parameter(compiler, p) ¤

Source code in cirkit/backend/torch/rules/parameters.py
214
215
216
217
218
def compile_reduce_product_parameter(
    compiler: "TorchCompiler", p: ReduceProductParameter
) -> TorchReduceProductParameter:
    """Compile a symbolic product-reduction node.

    Args:
        compiler: The torch compiler (not used by this rule).
        p: The symbolic node; it must have exactly one input shape. Its
            ``axis`` becomes the torch ``dim`` argument.

    Returns:
        A ``TorchReduceProductParameter`` over the node's single input shape.
    """
    [input_shape] = p.in_shapes
    along = p.axis
    return TorchReduceProductParameter(input_shape, dim=along)

compile_reduce_sum_parameter(compiler, p) ¤

Source code in cirkit/backend/torch/rules/parameters.py
207
208
209
210
211
def compile_reduce_sum_parameter(
    compiler: "TorchCompiler", p: ReduceSumParameter
) -> TorchReduceSumParameter:
    """Compile a symbolic sum-reduction node.

    Args:
        compiler: The torch compiler (not used by this rule).
        p: The symbolic node; it must have exactly one input shape. Its
            ``axis`` becomes the torch ``dim`` argument.

    Returns:
        A ``TorchReduceSumParameter`` over the node's single input shape.
    """
    [input_shape] = p.in_shapes
    along = p.axis
    return TorchReduceSumParameter(input_shape, dim=along)

compile_reference_parameter(compiler, p) ¤

Source code in cirkit/backend/torch/rules/parameters.py
111
112
113
114
115
116
117
def compile_reference_parameter(
    compiler: "TorchCompiler", p: ReferenceParameter
) -> TorchPointerParameter:
    """Compile a reference node into a pointer to an already compiled parameter.

    The referenced symbolic parameter (``p.deref()``) is looked up in
    ``compiler.state``. The compiled parameter and fold index found there are
    wrapped in a ``TorchPointerParameter``, so no new tensor is created.

    Args:
        compiler: The torch compiler whose state holds compiled parameters.
        p: The symbolic reference node.

    Returns:
        A ``TorchPointerParameter`` targeting the referenced compiled
        parameter at the retrieved fold index.
    """
    referenced = p.deref()
    target, fold_idx = compiler.state.retrieve_compiled_parameter(referenced)
    return TorchPointerParameter(target, fold_idx=fold_idx)

compile_scaled_sigmoid_parameter(compiler, p) ¤

Source code in cirkit/backend/torch/rules/parameters.py
181
182
183
184
185
def compile_scaled_sigmoid_parameter(
    compiler: "TorchCompiler", p: ScaledSigmoidParameter
) -> TorchScaledSigmoidParameter:
    """Compile a symbolic scaled-sigmoid node.

    Args:
        compiler: The torch compiler (not used by this rule).
        p: The symbolic node; it must have exactly one input shape. Its
            ``vmin`` and ``vmax`` are forwarded unchanged.

    Returns:
        A ``TorchScaledSigmoidParameter`` over the node's single input shape.
    """
    [input_shape] = p.in_shapes
    lower_bound, upper_bound = p.vmin, p.vmax
    return TorchScaledSigmoidParameter(input_shape, vmin=lower_bound, vmax=upper_bound)

compile_sigmoid_parameter(compiler, p) ¤

Source code in cirkit/backend/torch/rules/parameters.py
174
175
176
177
178
def compile_sigmoid_parameter(
    compiler: "TorchCompiler", p: SigmoidParameter
) -> TorchSigmoidParameter:
    """Compile a symbolic sigmoid node into a ``TorchSigmoidParameter``.

    Args:
        compiler: The torch compiler (not used by this rule).
        p: The symbolic sigmoid node; it must have exactly one input shape.

    Returns:
        A ``TorchSigmoidParameter`` over the node's single input shape.
    """
    [input_shape] = p.in_shapes
    return TorchSigmoidParameter(input_shape)

compile_softmax_parameter(compiler, p) ¤

Source code in cirkit/backend/torch/rules/parameters.py
228
229
230
231
232
def compile_softmax_parameter(
    compiler: "TorchCompiler", p: SoftmaxParameter
) -> TorchSoftmaxParameter:
    """Compile a symbolic softmax node into a ``TorchSoftmaxParameter``.

    Args:
        compiler: The torch compiler (not used by this rule).
        p: The symbolic softmax node; it must have exactly one input shape.
            Its ``axis`` becomes the torch ``dim`` argument.

    Returns:
        A ``TorchSoftmaxParameter`` over the node's single input shape.
    """
    [input_shape] = p.in_shapes
    along = p.axis
    return TorchSoftmaxParameter(input_shape, dim=along)

compile_softplus_parameter(compiler, p) ¤

Source code in cirkit/backend/torch/rules/parameters.py
193
194
195
196
197
def compile_softplus_parameter(
    compiler: "TorchCompiler", p: SoftplusParameter
) -> TorchSoftplusParameter:
    """Compile a symbolic softplus node into a ``TorchSoftplusParameter``.

    Args:
        compiler: The torch compiler (not used by this rule).
        p: The symbolic softplus node; it must have exactly one input shape.

    Returns:
        A ``TorchSoftplusParameter`` over the node's single input shape.
    """
    [input_shape] = p.in_shapes
    return TorchSoftplusParameter(input_shape)

compile_square_parameter(compiler, p) ¤

Source code in cirkit/backend/torch/rules/parameters.py
169
170
171
def compile_square_parameter(compiler: "TorchCompiler", p: SquareParameter) -> TorchSquareParameter:
    """Compile a symbolic squaring node into a ``TorchSquareParameter``.

    Args:
        compiler: The torch compiler (not used by this rule).
        p: The symbolic square node; it must have exactly one input shape.

    Returns:
        A ``TorchSquareParameter`` over the node's single input shape.
    """
    [input_shape] = p.in_shapes
    return TorchSquareParameter(input_shape)

compile_sum_parameter(compiler, p) ¤

Source code in cirkit/backend/torch/rules/parameters.py
125
126
127
def compile_sum_parameter(compiler: "TorchCompiler", p: SumParameter) -> TorchSumParameter:
    """Compile a symbolic binary-sum node into a ``TorchSumParameter``.

    Args:
        compiler: The torch compiler (not used by this rule).
        p: The symbolic sum node; it must have exactly two input shapes.

    Returns:
        A ``TorchSumParameter`` over the two input shapes, in order.
    """
    lhs_shape, rhs_shape = p.in_shapes
    return TorchSumParameter(lhs_shape, rhs_shape)

compile_tensor_parameter(compiler, p) ¤

Source code in cirkit/backend/torch/rules/parameters.py
89
90
91
92
93
94
95
96
def compile_tensor_parameter(compiler: "TorchCompiler", p: TensorParameter) -> TorchTensorParameter:
    """Compile a symbolic tensor parameter into a ``TorchTensorParameter``.

    The node's initializer is compiled through ``compiler`` and its dtype is
    resolved via ``_retrieve_dtype``. Gradient tracking follows
    ``p.learnable``, and the result is registered in ``compiler.state``.

    Args:
        compiler: The torch compiler used to compile the initializer and
            record the compiled parameter.
        p: The symbolic tensor parameter node.

    Returns:
        The newly created (and registered) ``TorchTensorParameter``.
    """
    compiled_init = compiler.compile_initializer(p.initializer)
    torch_dtype = _retrieve_dtype(p.dtype)
    tensor_param = TorchTensorParameter(
        *p.shape,
        requires_grad=p.learnable,
        initializer_=compiled_init,
        dtype=torch_dtype,
    )
    # Keep the symbolic -> compiled mapping in the compiler state; presumably
    # this is what retrieve_compiled_parameter looks up for references.
    compiler.state.register_compiled_parameter(p, tensor_param)
    return tensor_param