Skip to content

initializers

initializers

DEFAULT_INITIALIZER_COMPILATION_RULES (module attribute) — maps each initializer type to the rule that compiles it: {ConstantTensorInitializer: compile_constant_tensor_initializer, UniformInitializer: compile_uniform_initializer, NormalInitializer: compile_normal_initializer, DirichletInitializer: compile_dirichlet_initializer}

compile_constant_tensor_initializer(compiler, init)

Source code in cirkit/backend/torch/rules/initializers.py (lines 21–26):
def compile_constant_tensor_initializer(
    compiler: "TorchCompiler", init: ConstantTensorInitializer
) -> InitializerFunc:
    """Compile a constant-tensor initializer into an in-place Torch initializer.

    Args:
        compiler: The Torch compiler invoking this rule (not used by this rule).
        init: The initializer to compile. Its ``value`` is either a NumPy
            ndarray or, presumably, a scalar.

    Returns:
        A callable that takes a tensor and initializes it in place. An ndarray
        ``value`` is copied into the tensor via ``copy_from_ndarray_``; any
        other ``value`` is written into every entry of the tensor.
    """
    if isinstance(init.value, np.ndarray):
        return functools.partial(copy_from_ndarray_, array=init.value)
    # nn.init.constant_ fills the tensor under torch.no_grad(), matching the
    # nn.init functions used by the uniform/normal rules. A bare torch.fill_
    # raises on leaf tensors that require grad (e.g. nn.Parameter data).
    return functools.partial(nn.init.constant_, val=init.value)

compile_dirichlet_initializer(compiler, init)

Source code in cirkit/backend/torch/rules/initializers.py (lines 41–45):
def compile_dirichlet_initializer(
    compiler: "TorchCompiler", init: DirichletInitializer
) -> InitializerFunc:
    """Compile a Dirichlet initializer into a Torch initializer function.

    Args:
        compiler: The Torch compiler invoking this rule (not used by this rule).
        init: The initializer to compile; its ``alpha`` and ``axis`` are read.

    Returns:
        ``dirichlet_`` with ``alpha`` bound and ``dim`` bound to the adjusted axis.
    """
    # Negative axes are passed through unchanged; non-negative axes are shifted
    # by one. NOTE(review): presumably the compiled tensor has one extra leading
    # dimension that the initializer's axis does not count — confirm.
    if init.axis < 0:
        dim = init.axis
    else:
        dim = init.axis + 1
    return functools.partial(dirichlet_, alpha=init.alpha, dim=dim)

compile_normal_initializer(compiler, init)

Source code in cirkit/backend/torch/rules/initializers.py (lines 35–38):
def compile_normal_initializer(
    compiler: "TorchCompiler", init: NormalInitializer
) -> InitializerFunc:
    """Compile a normal initializer into an in-place Torch initializer.

    Args:
        compiler: The Torch compiler invoking this rule (not used by this rule).
        init: The initializer to compile; its ``mean`` and ``stddev`` are read.

    Returns:
        ``nn.init.normal_`` with the mean and standard deviation bound, so it
        fills a given tensor with samples from N(mean, stddev**2).
    """
    # The initializer names the spread ``stddev``; ``nn.init.normal_`` calls it ``std``.
    mean, std = init.mean, init.stddev
    return functools.partial(nn.init.normal_, mean=mean, std=std)

compile_uniform_initializer(compiler, init)

Source code in cirkit/backend/torch/rules/initializers.py (lines 29–32):
def compile_uniform_initializer(
    compiler: "TorchCompiler", init: UniformInitializer
) -> InitializerFunc:
    """Compile a uniform initializer into an in-place Torch initializer.

    Args:
        compiler: The Torch compiler invoking this rule (not used by this rule).
        init: The initializer to compile; its bounds ``a`` and ``b`` are read.

    Returns:
        ``nn.init.uniform_`` with both bounds bound, so it fills a given tensor
        with samples from U(a, b).
    """
    lower, upper = init.a, init.b
    return functools.partial(nn.init.uniform_, a=lower, b=upper)