Skip to content

chow_liu

chow_liu ¤

ChowLiuTree(data, input_type, root=None, chunk_size=None, num_categories=None, num_bins=None, as_region_graph=False) ¤

Learns a Chow-Liu Tree and returns it either as a list of predecessors (Bayesian net) or as region graph (HCLT).

Details in https://arxiv.org/abs/2409.07953.

Parameters:

Name Type Description Default
data Tensor

The input data on which to run the CLT algorithm; it must be in tabular form (i.e. a matrix).

required
input_type str

The type of the input data, e.g. 'categorical', 'gaussian'.

required
root int | None

The index of the variable desired as root.

None
chunk_size int | None

Chunked computation, useful in case of large input data.

None
num_categories int | None

Specifies the number of categories in case of categorical data.

None
num_bins int | None

In case of categorical input, it is used to rescale categories in bins for ordinal features, e.g. [0, 255] -> [0, 7], which is useful for images.

None
as_region_graph Optional[bool]

True to return a region graph, False to return a list of predecessors.

False

Returns:

Type Description
ndarray | RegionGraph

A Chow-Liu Tree, either a list of predecessors or as a region graph.

Raises:

Type Description
ValueError

If the number of categories has not been specified but the number of bins has.

NotImplementedError

If the input type is neither 'categorical' nor 'gaussian'.

Source code in cirkit/templates/region_graph/algorithms/chow_liu.py (lines 11–65)
def ChowLiuTree(
    data: Tensor,
    input_type: str,
    root: int | None = None,
    chunk_size: int | None = None,
    num_categories: int | None = None,
    num_bins: int | None = None,
    as_region_graph: bool = False,
) -> np.ndarray | RegionGraph:
    """Learns a Chow-Liu Tree and returns it either as a
    list of predecessors (Bayesian net) or as region graph (HCLT).

    Details in https://arxiv.org/abs/2409.07953.

    Args:
        data (Tensor): The input data on which to run the CLT algorithm;
            it must be in tabular form (i.e. a matrix).
        input_type (str): The type of the input data, e.g. 'categorical', 'gaussian'.
        root (int | None): The index of the variable desired as root.
        chunk_size (int | None): Chunked computation, useful in case of large input data.
        num_categories (int | None): Specifies the number of categories in case of
            categorical data.
        num_bins (int | None): In case of categorical input, it is used to rescale
            categories in bins for ordinal features, e.g. [0, 255] -> [0, 7],
            which is useful for images. Consecutive categories are grouped into
            bins of width num_categories // num_bins.
        as_region_graph (bool): True to return a region graph,
            False to return a list of predecessors.

    Returns:
        A Chow-Liu Tree, either a list of predecessors or as a region graph.

    Raises:
        ValueError: If the number of categories has not been specified but the number of bins
            has, or if the number of bins is not in [1, num_categories].
        NotImplementedError: If the input type is neither 'categorical' nor 'gaussian'.
    """
    assert data.ndim == 2
    assert root is None or -1 < root < data.size(-1)
    if input_type == "categorical":
        if num_bins is not None:
            if num_categories is None:
                raise ValueError("Number of categories must be known if rescaling in bins")
            if not 0 < num_bins <= num_categories:
                # A non-positive or too large number of bins yields a bin width of zero
                raise ValueError(
                    f"Number of bins must be in [1, {num_categories}], got {num_bins}"
                )
            bin_width = num_categories // num_bins
            data = torch.div(data, bin_width, rounding_mode="floor")
            # After binning, values lie in [0, ceil(num_categories / bin_width) - 1].
            # Passing this reduced count (instead of the original one) keeps the
            # num_categories**2 joint-count tensor small and spreads the Laplace
            # smoothing only over categories that can actually occur.
            num_categories = -(-num_categories // bin_width)
        mutual_info = _categorical_mutual_info(
            data.long(), num_categories=num_categories, chunk_size=chunk_size
        )
    elif input_type == "gaussian":
        # TODO: implement chunked computation
        # Pairwise Gaussian MI from the Pearson correlation rho: -1/2 * log(1 - rho^2)
        mutual_info = -0.5 * torch.log(1 - torch.corrcoef(data.t()) ** 2)
    else:
        raise NotImplementedError(f"MI computation not implemented for {input_type} input units")

    _, tree = _maximum_spanning_tree(adj_matrix=mutual_info, root=root)
    if as_region_graph:
        return tree2rg(tree)
    return tree

_categorical_mutual_info(data, alpha=0.01, num_categories=None, chunk_size=None) ¤

Computes the mutual information (MI) matrix of a matrix of integers.

Parameters:

Name Type Description Default
data Tensor

The input data on which to compute the MI matrix; it must be in tabular form (i.e. a matrix).

required
alpha float

Laplace smoothing factor.

0.01
num_categories int | None

Specifies the number of categories.

None
chunk_size int | None

Chunked computation, useful in case of large input data.

None

Returns:

Type Description
Tensor

The mutual information matrix (main diagonal is 0).

Source code in cirkit/templates/region_graph/algorithms/chow_liu.py (lines 95–141)
def _categorical_mutual_info(
    data: LongTensor,
    alpha: float = 0.01,
    num_categories: int | None = None,
    chunk_size: int | None = None,
) -> Tensor:
    """Computes the mutual information (MI) matrix of a matrix of integers.

    Args:
        data (Tensor): The input data over which computing the MI matrix,
            it must be in tabular form (i.e. a matrix).
        alpha (Tensor): Laplace smoothing factor.
        num_categories (int | None): Specifies the number of categories.
        chunk_size (int | None): Chunked computation, useful in case of large input data.

    Returns:
        The mutual information matrix (main diagonal is 0).
    """
    assert data.dtype == torch.long and data.ndim == 2
    n_samples, n_features = data.size()
    if num_categories is None:
        num_categories = int(data.max().item() + 1)
    if chunk_size is None:
        chunk_size = n_samples

    idx_features = torch.arange(0, n_features)
    idx_categories = torch.arange(0, num_categories)

    joint_counts = torch.zeros(
        n_features, n_features, num_categories**2, dtype=torch.long, device=data.device
    )
    for chunk in data.split(chunk_size):
        joint_values = chunk.t().unsqueeze(1) * num_categories + chunk.t().unsqueeze(0)
        joint_counts.scatter_add_(-1, joint_values.long(), torch.ones_like(joint_values))
    joint_counts = joint_counts.view(n_features, n_features, num_categories, num_categories)
    marginal_counts = joint_counts[idx_features, idx_features][:, idx_categories, idx_categories]

    marginals = (marginal_counts + num_categories * alpha) / (
        n_samples + num_categories**2 * alpha
    )
    joints = (joint_counts + alpha) / (n_samples + num_categories**2 * alpha)
    joints[idx_features, idx_features] = torch.diag_embed(
        marginals
    )  # Correct Laplace's smoothing for the marginals
    outers = torch.einsum("ik,jl->ijkl", marginals, marginals)

    return (joints * (joints.log() - outers.log())).sum(dim=(2, 3)).fill_diagonal_(0)

_maximum_spanning_tree(adj_matrix, root=None) ¤

Computes the maximum spanning tree of a given adjacency matrix, rooted at a given variable.

Parameters:

Name Type Description Default
adj_matrix Tensor

The adjacency matrix.

required
root int | None

The index of the variable desired as root. If None, picks the one that minimizes depth.

None

Returns:

Name Type Description
bfs ndarray

The BFS order of the spanning tree.

tree ndarray

The spanning tree in form of list of predecessors.

Source code in cirkit/templates/region_graph/algorithms/chow_liu.py (lines 68–92)
def _maximum_spanning_tree(
    adj_matrix: Tensor, root: int | None = None
) -> tuple[np.ndarray, np.ndarray]:
    """Runs the maximum spanning tree of a given adjacency matrix rooted at a given variable.

    Args:
        adj_matrix (Tensor): The adjacency matrix.
        root (int | None): The index of the variable desired as root.
            If None, picks the one that minimizes depth.

    Returns:
        bfs: The BFS order of the spanning tree.
        tree: The spanning tree in form of list of predecessors.
    """
    mst = sp.csgraph.minimum_spanning_tree(-(adj_matrix.cpu().numpy() + 1.0), overwrite=True)
    if root is None:
        dist_from_all_nodes = sp.csgraph.dijkstra(
            abs(mst).todense(), directed=False, return_predecessors=False
        )
        root = np.argmin(np.max(dist_from_all_nodes, axis=1))
    bfs, tree = sp.csgraph.breadth_first_order(
        mst, directed=False, i_start=root, return_predecessors=True
    )
    tree[root] = -1
    return bfs, tree