Skip to content

initializers

initializers ¤

DEFAULT_INITIALIZER_COMPILATION_RULES (module attribute) maps each symbolic initializer type to the rule that compiles it into a torch initializer function: {ConstantTensorInitializer: compile_constant_tensor_initializer, UniformInitializer: compile_uniform_initializer, NormalInitializer: compile_normal_initializer, DirichletInitializer: compile_dirichlet_initializer}

compile_constant_tensor_initializer(compiler, init) ¤

Source code in cirkit/backend/torch/rules/initializers.py
24
25
26
27
28
29
def compile_constant_tensor_initializer(
    compiler: "TorchCompiler", init: ConstantTensorInitializer
) -> InitializerFunc:
    """Compile a constant-tensor initializer into an in-place torch initializer.

    Args:
        compiler: The torch compiler invoking this rule. It is not used here;
            presumably it is kept so that every compilation rule shares the same
            signature.
        init: The symbolic constant initializer to compile.

    Returns:
        A ``functools.partial`` that awaits the target tensor. When
        ``init.value`` is a ``np.ndarray``, it wraps ``copy_from_ndarray_``
        bound with ``array=init.value``. Presumably that helper copies the
        array into the tensor; confirm against its definition. Otherwise it
        wraps ``torch.fill_`` bound with ``value=init.value``, which fills the
        tensor in place with that scalar.
    """
    constant = init.value
    if not isinstance(constant, np.ndarray):
        # Scalar-like constant: broadcast a single value over the whole tensor.
        return functools.partial(torch.fill_, value=constant)
    # Array constant: delegate the element-wise copy to the ndarray helper.
    return functools.partial(copy_from_ndarray_, array=constant)

compile_dirichlet_initializer(compiler, init) ¤

Source code in cirkit/backend/torch/rules/initializers.py
44
45
46
47
48
def compile_dirichlet_initializer(
    compiler: "TorchCompiler", init: DirichletInitializer
) -> InitializerFunc:
    """Compile a Dirichlet initializer into a torch initializer function.

    Args:
        compiler: The torch compiler invoking this rule. It is unused here.
        init: The symbolic Dirichlet initializer. Its ``alpha`` and ``axis``
            attributes are read.

    Returns:
        A ``functools.partial`` of ``dirichlet_`` bound with
        ``alpha=init.alpha`` and ``dim`` set to the remapped axis.
    """
    # Non-negative axes are shifted by one, while negative axes (counted from
    # the end) are kept as-is. NOTE(review): this presumably skips an extra
    # leading dimension that the compiled tensor has but the symbolic shape
    # does not. Confirm against the torch parameter layout.
    dim = init.axis + 1 if init.axis >= 0 else init.axis
    return functools.partial(dirichlet_, alpha=init.alpha, dim=dim)

compile_normal_initializer(compiler, init) ¤

Source code in cirkit/backend/torch/rules/initializers.py
38
39
40
41
def compile_normal_initializer(
    compiler: "TorchCompiler", init: NormalInitializer
) -> InitializerFunc:
    """Compile a normal (Gaussian) initializer into a torch initializer function.

    Args:
        compiler: The torch compiler invoking this rule. It is unused here.
        init: The symbolic normal initializer. Its ``mean`` and ``stddev``
            attributes are read.

    Returns:
        A ``functools.partial`` of ``torch.nn.init.normal_``. The returned
        callable fills a tensor in place with samples drawn from
        ``N(mean, stddev**2)``.
    """
    mean, std = init.mean, init.stddev
    return functools.partial(nn.init.normal_, mean=mean, std=std)

compile_uniform_initializer(compiler, init) ¤

Source code in cirkit/backend/torch/rules/initializers.py
32
33
34
35
def compile_uniform_initializer(
    compiler: "TorchCompiler", init: UniformInitializer
) -> InitializerFunc:
    """Compile a uniform initializer into a torch initializer function.

    Args:
        compiler: The torch compiler invoking this rule. It is unused here.
        init: The symbolic uniform initializer. Its ``a`` (lower bound) and
            ``b`` (upper bound) attributes are read.

    Returns:
        A ``functools.partial`` of ``torch.nn.init.uniform_``. The returned
        callable fills a tensor in place with samples drawn from ``U(a, b)``.
    """
    lower, upper = init.a, init.b
    return functools.partial(nn.init.uniform_, a=lower, b=upper)