How to use Circuit Operators to learn Sum of Squares Circuits
By the end of this notebook, you will know how to compose the circuit operators implemented in Cirkit so as to represent the exact computation of quantities of interest as circuits. To motivate and showcase circuit operators, we build and learn a Sum of Squares (SOS) probabilistic circuit (PC) for distribution estimation, as introduced in the paper Sum of Squares Circuits. PCs are typically learned by constraining their parameters and input functions to be non-negative. SOS PCs, by contrast, can be learned with real or complex parameters, which makes them much more expressive models. This notebook therefore also shows how Cirkit lets you build circuits with real or complex parameters seamlessly.
Background: From Squared Subtractive Mixture Models to Sum of Squares
We can build more expressive models by:
- Allowing for negative parameters in mixture models: Squared Subtractive Mixture Models (SMMs)
- Combining these mixtures into further mixtures: Sum of Squares (SOS) SMMs
Before digging into the code, we provide some background on subtractive and SOS mixtures. We use Gaussian Mixture Models for the examples below.
Squared Subtractive Mixture Models (Squared SMM)
Mixture models are typically learned by assuming their parameters and input functions to be non-negative. For this reason, mixture models are restricted to adding probability density, thus requiring a high number of mixture components in order to represent some complex probability distributions. For example, in the following GIF we show a ring-shaped 2D target density (ground truth), and the density learned by Gaussian mixture models (GMMs) with either $K=2$ or $K=16$ components.
As can be seen, a GMM with $K=2$ components struggles to represent the ground truth density, and we need to increase the number of components to $K=16$ before we begin to get a good fit. However, we can do better! If we allow the GMM to also subtract probability density by allowing negative mixture weights, we can capture such a distribution with just two Gaussian components. This can be seen on the right where the squared subtractive Gaussian mixture model (squared SGMM) fits the ground truth very well by subtracting an inner Gaussian component from an outer one.
However, if we allow for negative mixture weights, how can we ensure we are still modelling a non-negative function and therefore a density function? To ensure non-negativity, we square a subtractive mixture model.
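To make the squaring concrete, here is a minimal 1D sketch using numpy only. The parameter values are hypothetical and the code is not part of Cirkit. It defines a subtractive mixture $c(x) = w_1 \mathcal{N}(x; 0, \sigma_1^2) + w_2 \mathcal{N}(x; 0, \sigma_2^2)$ with a negative weight $w_2$, which mirrors "outer minus inner Gaussian" in one dimension, and then squares it. The product of two Gaussian densities integrates to $\mathcal{N}(\mu_i; \mu_j, \sigma_i^2 + \sigma_j^2)$, so the normalization constant of $c^2$ has a closed form. The sketch checks that closed form against numerical integration.

```python
import numpy as np

def gauss(x, mu, var):
    # Univariate Gaussian density N(x; mu, var)
    return np.exp(-0.5 * (x - mu) ** 2 / var) / np.sqrt(2.0 * np.pi * var)

# Hypothetical parameters: a wide outer component minus a narrow inner one
w = np.array([1.0, -0.6])   # the second mixture weight is negative
mu = np.array([0.0, 0.0])
var = np.array([4.0, 0.5])

def c(x):
    # Subtractive mixture: it can take negative values
    return sum(w[k] * gauss(x, mu[k], var[k]) for k in range(len(w)))

# Closed form: Z = sum_ij w_i w_j * integral of N_i N_j = sum_ij w_i w_j N(mu_i; mu_j, var_i + var_j)
Z = sum(w[i] * w[j] * gauss(mu[i], mu[j], var[i] + var[j])
        for i in range(2) for j in range(2))

xs = np.linspace(-20.0, 20.0, 200001)
p = c(xs) ** 2 / Z  # squared SMM: non-negative everywhere
dx = xs[1] - xs[0]
print(c(xs).min() < 0, p.min() >= 0, np.isclose(p.sum() * dx, 1.0))
```

The unsquared mixture dips below zero at the origin. Its square is a valid density that integrates to one.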
Sum of Squares (SOS) SMMs
However, a single squared subtractive mixture model might not be enough for more complex distributions. For example, consider a 2D target density consisting of three rings, as shown in the GIF below (ground truth). A GMM with enough components can capture the "holes" of the target distribution, while a single squared SGMM struggles to do the same. The reason is that squaring alone introduces a limitation that makes the model less flexible, and GMMs do not suffer from this limitation. To overcome this issue, we build a sum of many squared SGMMs, namely a Sum of Squares (SOS) SGMM. An SOS SGMM provably overcomes this limitation and therefore captures the target distribution better with fewer components (12 for the GMM versus just 6 for the SOS SGMM).
SOS circuits generalize this idea to model high-dimensional distributions by squaring and summing deep subtractive mixture models represented as circuits. In the rest of the notebook, we show how the circuit operators implemented in Cirkit allow us to construct and learn an SOS circuit in a few lines of code. In particular, we learn a simple SOS circuit that estimates the distribution of MNIST images. This circuit is a sum of just two squares: the squared real and imaginary parts of the output of a circuit with complex parameters.
Building and Learning SOS PCs: the case of Squared Complex Circuits
To build an SOS circuit encoding the sum of just two squares, we can start from a circuit $c$ with complex parameters computing a complex function, and then square its magnitude. Formally, we model a probability distribution $p$ over variables $\mathbf{X}$ as $$p(\mathbf{X}) = Z^{-1} |c(\mathbf{X})|^2 = Z^{-1} c(\mathbf{X}) c(\mathbf{X})^\dagger$$ where $(\ \cdot\ )^\dagger$ denotes the complex conjugation operation, i.e., $(a+b\mathbf{i})^\dagger = a-b\mathbf{i}$, and $Z$ is the renormalization constant or partition function. Equivalently, we can write $p$ as proportional to the sum of two squares, i.e., $$p(\mathbf{X}) \propto \Re(c(\mathbf{X}))^2 + \Im(c(\mathbf{X}))^2$$ where $\Re,\Im$ denote real and imaginary part, respectively.
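The following is a quick numerical check of the two equivalent forms above. It uses numpy and a hypothetical stand-in for $c$: a random complex value for each state of a single discrete variable, rather than an actual circuit.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical stand-in for c: one complex value per state of a discrete variable X (10 states)
c = rng.normal(size=10) + 1j * rng.normal(size=10)

sq_mod = (c * np.conj(c)).real      # c(x) c(x)^dagger: real and non-negative
sum_sq = c.real ** 2 + c.imag ** 2  # Re(c(x))^2 + Im(c(x))^2
Z = sq_mod.sum()                    # partition function by explicit enumeration
p = sq_mod / Z                      # normalized distribution

print(np.allclose(sq_mod, np.abs(c) ** 2), np.allclose(sq_mod, sum_sq), np.isclose(p.sum(), 1.0))
```

Enumerating the domain is only feasible here because $X$ has 10 states. For images, $Z$ must be computed with circuit operators instead.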
We start by constructing $c$ using the Cirkit library. Since we aim to estimate the distribution of MNIST images, we use the image_data utility from cirkit.templates.data_modalities, as we did in the learning a probabilistic circuit notebook.
from cirkit.templates import data_modalities, utils
from cirkit.symbolic.circuit import Circuit

def build_symbolic_complex_circuit(region_graph: str) -> Circuit:
    return data_modalities.image_data(
        (1, 28, 28),  # The shape of MNIST images, i.e., (num_channels, image_height, image_width)
        region_graph=region_graph,
        # ----------- Input layers hyperparameters ----------- #
        input_layer='embedding',  # Use Embedding maps for the pixel values (0-255) as input layers
        num_input_units=32,  # Each input layer consists of 32 input units that output Embedding entries
        input_params={  # Set how to parameterize the input layer parameters
            # In this case we parameterize the 'weight' parameter of Embedding layers as complex-valued,
            # with real and imaginary parts sampled uniformly in [0, 1)
            'weight': utils.Parameterization(dtype='complex', initialization='uniform'),
        },
        # -------- Sum-product layers hyperparameters -------- #
        sum_product_layer='cp-t',  # Use CP-T sum-product layers, i.e., alternate Hadamard product layers and dense layers
        num_sum_units=32,  # Each dense sum layer consists of 32 sum units
        # Set how to parameterize the sum layer parameters:
        # complex-valued, with real and imaginary parts sampled uniformly in [0, 1)
        sum_weight_param=utils.Parameterization(dtype='complex', initialization='uniform')
    )
For SOS PCs we do not use the categorical encoding of pixels from earlier notebooks. Instead, we increase the expressiveness of our circuit by using complex-valued parameters, whose real and imaginary parts can be negative. We therefore encode inputs as complex embeddings. Each input unit is a look-up table that maps a pixel value in $\{0,1,\ldots,255\}$ to the corresponding complex entry of an embedding vector in $\mathbb{C}^{256}$. We overparameterize the input layer with $K$ output units ($K=32$ above), so our input layer resembles an embedding layer in deep learning: each pixel value is mapped to a learnable complex vector in $\mathbb{C}^{K}$.
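The embedding input layer can be pictured as a look-up table indexed by pixel values. The sketch below is a hypothetical numpy stand-in, not Cirkit's implementation, and it uses toy sizes. It assumes one complex table per pixel position, with real and imaginary parts drawn uniformly in [0, 1) to mirror the 'uniform' initialization above.

```python
import numpy as np

rng = np.random.default_rng(0)
num_vars, num_states, K = 16, 256, 8  # toy sizes: pixels, pixel values 0..255, input units

# Assumption of this sketch: one complex table per pixel position,
# where table[d, v] is the vector in C^K assigned to value v of pixel d
table = (rng.uniform(size=(num_vars, num_states, K))
         + 1j * rng.uniform(size=(num_vars, num_states, K)))

def embed(pixels):
    # pixels: integer array of shape (batch, num_vars) with entries in {0, ..., 255}
    # returns a complex array of shape (batch, num_vars, K)
    return table[np.arange(num_vars), pixels]

batch = rng.integers(0, num_states, size=(4, num_vars))
out = embed(batch)
print(out.shape, out.dtype)  # (4, 16, 8) complex128
```

In a trainable model, the table entries would be the learnable parameters updated by the optimizer.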
In addition, we make use of CP-T as sum-product layers, where sum layers are parameterized with complex weights. For more details about this and other layers, see the region graphs notebook.
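A rough picture of a single CP-T block follows. This is a hypothetical numpy sketch, not Cirkit's layer implementation. As the code comment above suggests, it assumes that a Hadamard product layer is followed by a dense sum layer. Two K-dimensional complex vectors, computed over disjoint sets of variables, are multiplied element-wise and then mixed by a complex weight matrix.

```python
import numpy as np

rng = np.random.default_rng(0)
K_in, K_out = 32, 32  # units of the inputs and of the dense sum layer (num_sum_units above)

def rand_complex(*shape):
    # Real and imaginary parts uniform in [0, 1), as in the 'uniform' initialization above
    return rng.uniform(size=shape) + 1j * rng.uniform(size=shape)

W = rand_complex(K_out, K_in)  # complex weights of the dense sum layer

def cp_t_block(left, right):
    # left, right: complex arrays of shape (batch, K_in) over disjoint sets of variables
    prod = left * right  # Hadamard product layer
    return prod @ W.T    # dense sum layer: out[b, i] = sum_j W[i, j] * prod[b, j]

left, right = rand_complex(4, K_in), rand_complex(4, K_in)
out = cp_t_block(left, right)
print(out.shape)  # (4, 32)
```

Because the weights are complex, the outputs of the sum layer can cancel each other out. This cancellation is the source of the extra expressiveness of subtractive and complex circuits.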
Computing the Partition Function by Composing Circuit Operators
To enable the exact and efficient computation of probabilities, we need to renormalize $p$, i.e., compute the partition function $Z$ exactly¹. To do so, we use the circuit operators in the cirkit.symbolic.functional module to automatically construct the symbolic circuit that computes $Z$. In short, all we need to do is compose the operators so as to encode the formula
$$Z = \sum_{\mathbf{x}\in\mathrm{dom}(\mathbf{X})} |c(\mathbf{x})|^2 $$
as yet another circuit, where $\mathrm{dom}(\mathbf{X})$ denotes the domain of variables $\mathbf{X}$, i.e., the space of all images. Namely, we need three circuit operators: conjugate, multiply and integrate². In pseudocode, we can write Z = integrate(multiply(c, conjugate(c))).
¹ Note that in earlier notebooks we did not need to compute the partition function because our circuit was already normalized: we had normalized distributions at the inputs (e.g., categoricals) and a softmax as activation for the sum weights.
² In Cirkit, the circuit operator integrate refers to summation for discrete variables and to integration for continuous ones. In this example, integrate therefore sums over all images.
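Composing the three operators gives $Z$ exactly without enumerating $\mathrm{dom}(\mathbf{X})$. To see why, consider a hypothetical toy complex circuit over two discrete variables $X_1, X_2$, each with $V$ states. Its output is a sum unit with complex weights $w_k$ over $K$ product units. Each product unit multiplies an input unit $a_k(X_1)$ with an input unit $b_k(X_2)$. Multiplying $c$ by its conjugate pairs up units $(k, l)$. Integrating the result then reduces to independent sums over each input, which decomposability makes valid. The numpy sketch below compares this route with brute-force enumeration. It illustrates the idea only and is not how Cirkit implements the operators.

```python
import numpy as np

rng = np.random.default_rng(0)
V, K = 5, 3  # toy sizes: states per variable, number of product units

def rand_complex(*shape):
    # Real and imaginary parts uniform in [0, 1), as in the 'uniform' initialization above
    return rng.uniform(size=shape) + 1j * rng.uniform(size=shape)

w = rand_complex(K)     # complex weights of the output sum unit
a = rand_complex(K, V)  # input units over X1: a[k, x1]
b = rand_complex(K, V)  # input units over X2: b[k, x2]

# Brute force: evaluate c(x1, x2) = sum_k w_k a_k(x1) b_k(x2) on every assignment
c = np.einsum('k,ki,kj->ij', w, a, b)
Z_brute = np.sum(np.abs(c) ** 2)

# Circuit route, mirroring integrate(multiply(c, conjugate(c))):
# the product pairs units (k, l), and integration sums each input pair separately
A = a @ a.conj().T  # A[k, l] = sum_x1 a_k(x1) conj(a_l(x1))
B = b @ b.conj().T  # B[k, l] = sum_x2 b_k(x2) conj(b_l(x2))
Z_circuit = np.einsum('k,l,kl,kl->', w, w.conj(), A, B)

print(np.isclose(Z_brute, Z_circuit.real), abs(Z_circuit.imag) < 1e-8)
```

With $n$ variables, enumeration needs $V^n$ evaluations of $c$. The factorized route needs only on the order of $K^2 V$ operations per input layer in this toy.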
Circuit Structure Needed for Tractably Computing the Partition Function
To ensure tractability, each of these operators has pre-conditions on structural properties that the circuit operands must satisfy, and post-conditions which are the resulting properties and semantics of the output circuit:
- Circuit multiplication c' = multiply(c1, c2)
  - Pre-condition: c1 and c2 are compatible, i.e., they share the same partitionings of variables at the products.
  - Post-condition: c' is smooth and decomposable and encodes the product of c1 and c2.
- Circuit conjugation c' = conjugate(c)
  - Pre-condition: c is a circuit with complex parameters.
  - Post-condition: c' has the same structure as c and computes the complex conjugate of c.
- Circuit integration c' = integrate(c)
  - Pre-condition: c is a smooth and decomposable circuit.
  - Post-condition: c' exactly encodes the integral of c over the whole domain of its variables.
By matching pre-conditions and post-conditions, we can compose these operators to compute mathematical formulae exactly with circuits.
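Matching pre-conditions to post-conditions works much like type checking. The following minimal sketch uses a hypothetical record of structural properties. CircuitProps and the three toy functions are illustrative only and are not Cirkit's API. For simplicity, compatibility is reduced to sharing the same variable partitioning identifier.

```python
from dataclasses import dataclass, replace
from typing import Optional


@dataclass(frozen=True)
class CircuitProps:
    # Toy record of a symbolic circuit's structural properties (hypothetical, not Cirkit's API)
    smooth: bool
    decomposable: bool
    partitioning: Optional[str]  # id of the variable partitioning; None if not tracked
    complex_params: bool
    integrated: bool = False


def conjugate(c: CircuitProps) -> CircuitProps:
    if not c.complex_params:
        raise ValueError("conjugate: pre-condition violated (expects complex parameters)")
    return replace(c)  # same structure as c


def multiply(c1: CircuitProps, c2: CircuitProps) -> CircuitProps:
    if c1.partitioning is None or c1.partitioning != c2.partitioning:
        raise ValueError("multiply: pre-condition violated (operands are not compatible)")
    # Post-condition stated above: smooth and decomposable (nothing else is tracked here)
    return CircuitProps(smooth=True, decomposable=True, partitioning=None,
                        complex_params=c1.complex_params or c2.complex_params)


def integrate(c: CircuitProps) -> CircuitProps:
    if not (c.smooth and c.decomposable):
        raise ValueError("integrate: pre-condition violated (expects smooth and decomposable)")
    return replace(c, integrated=True)


c = CircuitProps(smooth=True, decomposable=True, partitioning="quad-tree-4", complex_params=True)
Z = integrate(multiply(c, conjugate(c)))
print(Z.integrated)  # True
```

Passing two circuits with different partitionings to multiply raises an error, just as the pre-condition requires.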
Needed Structure: Region Graphs are Structured-Decomposable
We now explain how to satisfy the above pre-conditions.
The pre-condition of conjugate is satisfied because our input embeddings and sum weights are complex-valued. The multiply and integrate operators, however, place stronger requirements on the structure of the circuit.
We construct a circuit from a region graph that is structured-decomposable. This yields a circuit that is compatible with itself, so we can apply the multiply operator to the circuit and itself to square it. By the post-condition of multiply, the resulting circuit is smooth and decomposable. It therefore satisfies the pre-condition of the integrate operator. We build the symbolic circuit below and show its structural properties.
# Build a symbolic complex circuit by overparameterizing a Quad-Tree (4) region graph, which is structured-decomposable
symbolic_circuit = build_symbolic_complex_circuit('quad-tree-4')
# Print which structural properties the circuit satisfies
print('Structural properties:')
print(f' - Smoothness: {symbolic_circuit.is_smooth}')
print(f' - Decomposability: {symbolic_circuit.is_decomposable}')
print(f' - Structured-decomposability: {symbolic_circuit.is_structured_decomposable}')
Structural properties:
 - Smoothness: True
 - Decomposability: True
 - Structured-decomposability: True
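To build intuition for these properties, the sketch below constructs a simplified quad-tree over a small grid of pixels. It recursively splits each rectangular region into up to four quadrants and checks two conditions. First, every split divides its region into disjoint parts that cover it, which gives decomposable products. Second, every region is split in exactly one way, which is structured-decomposability. This is a hypothetical helper written for illustration. It is not Cirkit's 'quad-tree-4' construction, which may differ in its details.

```python
from typing import Dict, FrozenSet, List, Tuple

Region = FrozenSet[Tuple[int, int]]            # a set of pixel coordinates (row, col)
Partitions = Dict[Region, List[List[Region]]]  # region -> list of ways it is split


def quad_tree(rows: range, cols: range, partitions: Partitions) -> Region:
    # Recursively split the rectangle rows x cols into (up to) four quadrants
    region = frozenset((r, c) for r in rows for c in cols)
    if len(region) == 1:
        return region  # a single pixel is a leaf region
    rh, ch = (len(rows) + 1) // 2, (len(cols) + 1) // 2
    quads = [(rs, cs) for rs in (rows[:rh], rows[rh:]) for cs in (cols[:ch], cols[ch:])
             if len(rs) > 0 and len(cs) > 0]
    children = [quad_tree(rs, cs, partitions) for rs, cs in quads]
    partitions.setdefault(region, []).append(children)
    return region


def check(partitions: Partitions) -> Tuple[bool, bool]:
    # Decomposable: every split uses disjoint parts that exactly cover the region
    decomposable = all(
        sum(len(p) for p in parts) == len(region) and frozenset().union(*parts) == region
        for region, splits in partitions.items() for parts in splits)
    # Structured-decomposable: every region is split in exactly one way
    structured = all(len(splits) == 1 for splits in partitions.values())
    return decomposable, structured


partitions: Partitions = {}
root = quad_tree(range(4), range(4), partitions)
before = check(partitions)
print(before)  # (True, True)

# An alternative split of the root (left half | right half) keeps decomposability
# but breaks structured-decomposability: the root is now split in two different ways
left = frozenset((r, c) for r in range(4) for c in range(2))
partitions[root].append([left, root - left])
after = check(partitions)
print(after)  # (True, False)
```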
Next, we compose the circuit operators mentioned above to construct the symbolic circuit that exactly encodes $Z$.
import cirkit.symbolic.functional as SF
# Construct the circuit computing Z, i.e., the integral of |c(X)|^2 over the complete domain of X
symbolic_circuit_partition_func = SF.integrate(
# Construct the circuit computing |c(X)|^2 = c(X) c(X)^* = Re(c(X))^2 + Im(c(X))^2
SF.multiply(symbolic_circuit, SF.conjugate(symbolic_circuit))
)
Note that the above just required a single line of code using Cirkit! The only things we needed to care about are the structural properties of the circuits and the pre- and post-conditions of the circuit operators.
Learning Complex Squared Circuits
The loss we use to learn our model is the negative log-likelihood (NLL), which can be written as $$ \mathrm{NLL} := -\sum_{\mathbf{x}\in\mathcal{D}} \log p(\mathbf{x}) = |\mathcal{D}| \log Z - \sum_{\mathbf{x}\in\mathcal{D}} 2 \log |c(\mathbf{x})|, $$ where $\mathcal{D}$ denotes the set of training data points. For this reason, we need to compile two circuits: (i) the circuit $c$ and (ii) the circuit computing the partition function $Z$. We then use both to compute the NLL loss. In practice, we average the NLL over each mini-batch, which divides the whole expression by the batch size.
We start by importing torch, setting the seed, the device, and by loading the MNIST dataset.
import random
import numpy as np
import torch
from torch import optim
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
# Set some seeds
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed(42)
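# Set the torch device to use (falls back to the CPU if no GPU is available)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Load the MNIST data set and data loaders
transform = transforms.Compose([
    transforms.ToTensor(),
    # Flatten the images and map pixel values back to integers in {0, ..., 255};
    # rounding guards against floating-point error in 255 * (v / 255)
    transforms.Lambda(lambda x: torch.round(255 * x.view(-1)).long())
])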
data_train = datasets.MNIST('datasets', train=True, download=True, transform=transform)
data_test = datasets.MNIST('datasets', train=False, download=True, transform=transform)
# Instantiate the training and testing data loaders
train_dataloader = DataLoader(data_train, shuffle=True, batch_size=256)
test_dataloader = DataLoader(data_test, shuffle=False, batch_size=256)
Compilation Step and Optimization
To compile the circuits, we instantiate a PipelineContext object. See the compilation options notebook for a tutorial on compiling circuits and on the meaning of the different flags. One important flag here is the evaluation semiring. To ensure numerical stability, we evaluate circuits in log-space: sums and products are computed as the operations of a semiring whose addition is LogSumExp and whose multiplication is ordinary addition. Since our circuit can have negative real or complex parameters, we choose a generalization of this semiring to the complex plane.
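The sketch below illustrates why log-space evaluation stays stable even with complex values. It shows the sum operation only, through a hypothetical helper that is not Cirkit's implementation. Each value $z$ is stored as a complex logarithm $\log|z| + i\arg(z)$, so products become sums of logs. Sums are computed with a LogSumExp that subtracts the largest real part, i.e. the largest log-magnitude, before exponentiating.

```python
import numpy as np

def complex_lse(log_zs: np.ndarray) -> complex:
    # log(sum_i exp(log_zs[i])) for complex logs, shifted by the largest log-magnitude (real part)
    m = np.max(log_zs.real)
    return m + np.log(np.sum(np.exp(log_zs - m)))

# Three summands stored as complex logs log|z| + i*arg(z): tiny magnitudes,
# with a positive, a negative (phase pi) and a purely imaginary (phase pi/2) value
log_mags = np.array([-1000.0, -1000.5, -1001.0])
phases = np.array([0.0, np.pi, np.pi / 2])
log_zs = log_mags + 1j * phases

with np.errstate(divide='ignore', invalid='ignore'):
    naive = np.log(np.sum(np.exp(log_zs)))  # exp underflows to 0, so the log is -inf

stable = complex_lse(log_zs)
# Analytic reference: the sum equals exp(-1000) * (1 - exp(-0.5) + 1j * exp(-1))
reference = -1000.0 + np.log(1.0 - np.exp(-0.5) + 1j * np.exp(-1.0))
print(np.isinf(naive.real), np.isclose(stable, reference))
```

The real part of the result is $\log$ of the magnitude of the sum, which is the quantity the training loop below extracts with .real.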
from cirkit.pipeline import PipelineContext, compile
# Instantiate the pipeline context
ctx = PipelineContext(
    backend='torch',  # Choose PyTorch as compilation backend
    # ---- Use the evaluation semiring (C, +, x), where + is the numerically stable LogSumExp and x is the sum ---- #
    semiring='complex-lse-sum',
    # ------------------------------------------------------------------------------------------------------------- #
    fold=True,  # Fold the circuit to better exploit GPU parallelism
    optimize=True  # Optimize the layers of the circuit
)
with ctx:  # Compile the circuits computing log |c(X)| and log |Z|
    circuit = compile(symbolic_circuit)
    circuit_partition_func = compile(symbolic_circuit_partition_func)
Since we have chosen the complex-lse-sum semiring, the compiled circuits output values in log-space. More precisely, circuit returns complex values whose real part is $\log |c(\mathbf{x})|$. Likewise, circuit_partition_func returns the logarithm of the partition function, $\log Z$, as the real part of its output. This is why we take the real part of both outputs below.
Next, we instantiate a PyTorch optimizer, such as Adam. Note that PyTorch supports automatic differentiation with complex tensors.
# Initialize a torch optimizer of your choice,
# e.g., Adam, by passing the parameters of the circuit
optimizer = optim.Adam(circuit.parameters(), lr=0.01)
In the following training loop, we learn the parameters of the complex squared circuit by minimizing the negative log-likelihood computed on MNIST images.
num_epochs = 15
step_idx = 0
running_loss = 0.0
running_samples = 0
# Move both circuits to the chosen device
circuit = circuit.to(device)
circuit_partition_func = circuit_partition_func.to(device)
for epoch_idx in range(num_epochs):
    for i, (batch, _) in enumerate(train_dataloader):
        # The circuit expects an input of shape (batch_dim, num_variables)
        batch = batch.to(device)
        # -------- Computation of the negated log-likelihoods loss -------- #
        # Compute the logarithm of the squared scores of the batch, by evaluating the circuit
        log_scores = circuit(batch)  # the real part is log |c(x)|
        log_squared_scores = 2.0 * log_scores.real  # 2 * log |c(x)|, i.e., equivalent to log |c(x)|^2
        # Compute the log-partition function
        log_partition_func = circuit_partition_func().real  # log Z
        # Compute the log-likelihoods, log p(x) = 2 * log |c(x)| - log Z
        log_likelihoods = log_squared_scores - log_partition_func
        # We take the negated average log-likelihood as loss
        loss = -torch.mean(log_likelihoods)
        # ------------------------------------------------------------------ #
        # Update the parameters of the circuits, as for any other model in PyTorch
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        running_loss += loss.detach() * len(batch)
        running_samples += len(batch)
        step_idx += 1
        if step_idx % 300 == 0:
            average_nll = running_loss / running_samples
            print(f"Step {step_idx}: Average NLL: {average_nll:.3f}")
            running_loss = 0.0
            running_samples = 0
Step 300: Average NLL: 1280.521
Step 600: Average NLL: 760.516
Step 900: Average NLL: 712.537
Step 1200: Average NLL: 686.120
Step 1500: Average NLL: 669.953
Step 1800: Average NLL: 660.865
Step 2100: Average NLL: 654.583
Step 2400: Average NLL: 647.114
Step 2700: Average NLL: 643.236
Step 3000: Average NLL: 641.632
Step 3300: Average NLL: 639.239
Next, we evaluate the model on the test MNIST images, and show the bits-per-dimension metric.
with torch.no_grad():
    # -------- Compute the log-partition function -------- #
    # We need to do it just once, since we are not updating the parameters here
    log_partition_func = circuit_partition_func().real
    # ---------------------------------------------------- #
    test_lls = 0.0
    for batch, _ in test_dataloader:
        batch = batch.to(device)
        # -------- Compute the log-likelihoods of the unseen samples -------- #
        # Compute the logarithm of the squared scores of the batch, by evaluating the circuit
        log_scores = circuit(batch)
        log_squared_scores = 2.0 * log_scores.real
        # Compute the log-likelihoods
        log_likelihoods = log_squared_scores - log_partition_func
        # ------------------------------------------------------------------- #
        test_lls += log_likelihoods.sum().item()
    # Compute the average test log-likelihood and bits per dimension
    average_ll = test_lls / len(data_test)
    bpd = -average_ll / (28 * 28 * np.log(2.0))
    print(f"Average test LL: {average_ll:.3f}")
    print(f"Bits per dimension: {bpd:.3f}")
Average test LL: -680.173
Bits per dimension: 1.252
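As a sanity check of the conversion from nats to bits per dimension, we can plug the reported average test log-likelihood into the same formula used in the evaluation code. The image has $D = 28 \cdot 28 = 784$ pixels. The helper name below is ours, chosen for illustration.

```python
import math

def bits_per_dimension(avg_log_likelihood_nats: float, num_dims: int) -> float:
    # Negated average log-likelihood, converted from nats to bits, divided by the number of dimensions
    return -avg_log_likelihood_nats / (num_dims * math.log(2.0))

bpd = bits_per_dimension(-680.173, 28 * 28)
print(f"{bpd:.3f}")  # 1.252
```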
Conclusion and Further Reading
- Circuit operators can be composed to represent the exact computation of complicated mathematical formulae as symbolic circuits. The library supports many other circuit operators: see the documentation of the module cirkit.symbolic.functional for more.
- We can seamlessly build and learn circuits with real and complex parameters, such as the SOS PCs and squared complex circuits in this notebook. Allowing real/complex parameters in PCs is crucial to increase their expressiveness, as noted in Sum of Squares Circuits and Subtractive Mixture Models via Squaring: Representation and Learning.
- The easy integration with PyTorch allows you to abstract away from all the implementation details of the circuit operators and the chosen parameterization.