def dirichlet_(tensor: Tensor, alpha: float | list[float], *, dim: int = -1) -> Tensor:
shape = tensor.shape
if not shape:
raise ValueError(
"Cannot initialize a tensor with no dimensions by sampling from a Dirichlet"
)
dim = dim if dim >= 0 else dim + len(shape)
if isinstance(alpha, float):
concentration = torch.full([shape[dim]], fill_value=alpha)
else:
if shape[dim] != len(alpha):
raise ValueError(
"The selected dim of the tensor and the size of concentration parameters "
"do not match"
)
concentration = Tensor(alpha)
dirichlet = torch.distributions.Dirichlet(concentration)
samples = dirichlet.sample(torch.Size([d for i, d in enumerate(shape) if i != dim]))
tensor.copy_(torch.transpose(samples, dim, -1))
return tensor