Learning a Gaussian Mixture Model
In this notebook, we show how to create a symbolic circuit with cirkit that represents a simple Gaussian mixture model, compile it into a regular PyTorch model, and learn its parameters (means, standard deviations, and mixture weights) using Adam.
Note that this is an illustrative example to show how to build symbolic circuits manually, and there are better ways of fitting Gaussian mixture models than with stochastic first-order optimization.
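The classical alternative is expectation-maximization (EM), which alternates between computing each point's posterior over components (E-step) and closed-form parameter updates (M-step), and never decreases the training log-likelihood. The following is a minimal numpy sketch for diagonal-covariance GMMs, independent of cirkit; the function name `em_diag_gmm` and the small variance floor are our own illustrative choices, not part of any library.

```python
import numpy as np

def em_diag_gmm(x, num_components, num_iters=50, seed=0):
    """Fit a diagonal-covariance GMM to x of shape (n, d) with EM (illustrative sketch).

    Returns (weights, means, variances, history), where history holds the average
    training log-likelihood of the parameters at the start of each iteration.
    """
    rng = np.random.default_rng(seed)
    n = x.shape[0]
    means = x[rng.choice(n, num_components, replace=False)]   # init means at random data points
    variances = np.tile(x.var(axis=0), (num_components, 1))   # init variances to the data variance
    log_w = np.full(num_components, -np.log(num_components))  # uniform initial mixture weights
    history = []
    for _ in range(num_iters):
        # E-step: log w_k + log N(x_i | mu_k, diag(var_k)), shape (n, K)
        diff = x[:, None, :] - means[None, :, :]
        log_joint = log_w - 0.5 * (
            (diff ** 2 / variances).sum(-1) + np.log(2 * np.pi * variances).sum(-1)
        )
        log_px = np.logaddexp.reduce(log_joint, axis=1, keepdims=True)
        history.append(float(log_px.mean()))
        resp = np.exp(log_joint - log_px)                    # responsibilities; each row sums to 1
        # M-step: closed-form updates weighted by the responsibilities
        nk = resp.sum(axis=0) + 1e-10                        # effective number of points per component
        log_w = np.log(nk / n)
        means = (resp.T @ x) / nk[:, None]
        diff = x[:, None, :] - means[None, :, :]
        # Small floor on the variances to avoid a component collapsing onto a single point
        variances = (resp[:, :, None] * diff ** 2).sum(axis=0) / nk[:, None] + 1e-6
    return np.exp(log_w), means, variances, history
```

In this notebook's setting, one could call it as `em_diag_gmm(sample_points(10000).numpy(), 8)`; each iteration costs one pass over the data and needs no learning rate.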
Helper functions
First, we write down a simple function to help us visualize 2D densities.
import torch
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import MaxNLocator

def plot_2D(*fns, title=None, xmin=-2.5, xmax=2.5, nbins=15):
    x_min, x_max = xmin, xmax
    y_min, y_max = xmin, xmax
    dx, dy = 0.01, 0.01
    # Generate two 2D grids for the x and y coordinates (rows index y, columns index x)
    y, x = np.mgrid[slice(y_min, y_max + dy, dy),
                    slice(x_min, x_max + dx, dx)]
    # Flatten the grid into a (num_points, 2) tensor whose columns are (x, y)
    xy = torch.from_numpy(np.hstack((x.reshape(-1, 1), y.reshape(-1, 1)))).float()
    # Skip missing densities (e.g., when the true density is unknown)
    fns = [fn for fn in fns if fn is not None]
    ncols = len(fns)
    if ncols == 0:
        return
    fig, axs = plt.subplots(ncols=ncols, figsize=(5*ncols, 5))
    if ncols == 1:
        axs = [axs]
    for fn, ax in zip(fns, axs):
        with torch.no_grad():
            z = fn(xy)
        z = z.reshape(y.shape).numpy()
        z = z[:-1, :-1]
        cmap = plt.colormaps['PiYG']
        levels = MaxNLocator(nbins=nbins).tick_values(z.min(), z.max())
        cf = ax.contourf(
            x[:-1, :-1] + dx/2.,
            y[:-1, :-1] + dy/2.,
            z,
            levels=levels, cmap=cmap
        )
        ax.set_aspect('equal', 'box')
    fig.colorbar(cf, ax=axs)
    if title is not None:
        if ncols == 1:
            axs[-1].set_title(title)
        else:
            fig.suptitle(title)
    plt.show()
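`plot_2D` evaluates each density on a regular grid with spacing 0.01. The same kind of grid gives a cheap sanity check that a function is a normalized density: a midpoint Riemann sum over a large enough box should be close to 1. Below is a numpy sketch on a standard 2D Gaussian; the helper `grid_integral` is hypothetical (not part of cirkit), and applying it to a PyTorch model would require converting between numpy arrays and tensors.

```python
import numpy as np

def grid_integral(density, lo=-2.5, hi=2.5, step=0.01):
    """Approximate the integral of a 2D density over [lo, hi]^2 with a midpoint Riemann sum."""
    centers = np.arange(lo + step / 2, hi, step)     # cell midpoints along one axis
    x, y = np.meshgrid(centers, centers)             # x varies along columns, y along rows
    xy = np.stack([x.ravel(), y.ravel()], axis=-1)   # shape (num_cells, 2), columns (x, y)
    return density(xy).sum() * step * step           # sum of f(midpoint) * cell area

def standard_normal_2d(xy):
    """Density of a standard 2D Gaussian evaluated at each row of xy."""
    return np.exp(-0.5 * (xy ** 2).sum(axis=-1)) / (2 * np.pi)

total = grid_integral(standard_normal_2d, lo=-5.0, hi=5.0)
print(f"Integral over the grid: {total:.4f}")
```

If the box is too small relative to the density's spread (for example, the default [-2.5, 2.5]² for clusters centered at radius 2), the sum will fall noticeably below 1 even for a correctly normalized model.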
Generate a simple dataset
Next, we use PyTorch distributions (torch.distributions) to generate a simple dataset: samples from an equally weighted mixture of eight Gaussians whose centers lie on a circle of radius 2 around the origin.
import torch.distributions as D
import math

radius = 2  # Distance of the centers from the origin
K = 8       # Number of clusters
# Centers evenly spaced on a circle, shape (K, 2)
mus = torch.tensor([
    [math.cos(2*math.pi*n / K) for n in range(K)],
    [math.sin(2*math.pi*n / K) for n in range(K)]
]).T * radius
sigma = .2  # Standard deviation
mix = D.Categorical(torch.ones(K,))            # Uniform mixture weights
comp = D.Independent(D.Normal(mus, sigma), 1)  # K isotropic 2D Gaussians
gmm = D.MixtureSameFamily(mix, comp)

def sample_points(n_points):
    return gmm.sample((n_points,))

plt.scatter(*sample_points(1000).unbind(-1))
plt.gca().set_aspect('equal', 'box')
plt.title('Original samples')
plt.show()

def true_density(xy):
    return gmm.log_prob(xy).exp()

plot_2D(true_density, title='Original density')
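For reference, the density defined above is p(x) = (1/K) Σ_k N(x; μ_k, σ²I), and `gmm.log_prob` returns its logarithm. The numpy sketch below re-implements that log-density with a max-shifted log-sum-exp over components, which stays finite even far away from every cluster; `ring_gmm_log_prob` is a hypothetical helper for illustration only.

```python
import math
import numpy as np

def ring_gmm_log_prob(xy, K=8, radius=2.0, sigma=0.2):
    """log p(x) for an equal-weight mixture of K isotropic 2D Gaussians centered on a circle."""
    angles = 2 * math.pi * np.arange(K) / K
    mus = radius * np.stack([np.cos(angles), np.sin(angles)], axis=-1)  # (K, 2) centers
    sq_dist = ((xy[:, None, :] - mus[None, :, :]) ** 2).sum(-1)        # (N, K) squared distances
    # log N(x; mu_k, sigma^2 I) in 2D, plus the log mixture weight log(1/K)
    log_joint = -sq_dist / (2 * sigma ** 2) - math.log(2 * math.pi * sigma ** 2) - math.log(K)
    # log-sum-exp over components, shifted by the row maximum for numerical stability
    m = log_joint.max(axis=1, keepdims=True)
    return (m + np.log(np.exp(log_joint - m).sum(axis=1, keepdims=True)))[:, 0]

print(ring_gmm_log_prob(np.array([[2.0, 0.0], [0.0, 0.0]])))
```

At a center such as (2, 0), the neighboring components contribute almost nothing, so the value is essentially log(1/8) - log(2πσ²).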
Create a symbolic circuit
Now, we build a simple mixture model using a symbolic circuit. To create one, we need to tell cirkit which operations we want and how they are connected. That is, we create a circuit by defining the graph of layers that compose it.
In the case of a Gaussian mixture model, we need to create:
- Two Gaussian input layers (one per variable), each outputting as many values as the number of clusters we want to use.
- A Hadamard layer, which combines the two Gaussian layers by multiplying their densities cluster-wise (i.e., element-wise).
- A Sum layer whose convex weights define the mixture weights.
Note that, as of now, Gaussian layers in cirkit are defined for univariate distributions. We therefore define each 2D component as the product of two independent univariate Gaussians. Since the two standard deviations are learned separately, each component is an axis-aligned (diagonal-covariance) 2D Gaussian rather than an isotropic one.
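To make the three layers concrete, here is a numpy sketch of the computation they describe, evaluated in log space. It assumes each Gaussian input layer holds one mean and one standard deviation per unit, and that the sum weights are obtained from unconstrained logits via a softmax, as in the parameterization used below; it illustrates the math, not cirkit's implementation, and `gmm_circuit_log_prob` is a hypothetical name.

```python
import numpy as np

def log_normal(x, mean, std):
    """Elementwise log-density of univariate Gaussians."""
    return -0.5 * ((x - mean) / std) ** 2 - np.log(std) - 0.5 * np.log(2 * np.pi)

def gmm_circuit_log_prob(xy, means, stds, logits):
    """Log-density of the Gaussian -> Hadamard -> Sum circuit.

    means, stds: arrays of shape (2, units), one row per input variable
    logits: unconstrained sum weights of shape (units,), mapped to convex weights by softmax
    """
    g0 = log_normal(xy[:, [0]], means[0], stds[0])  # (N, units): input layer over X0
    g1 = log_normal(xy[:, [1]], means[1], stds[1])  # (N, units): input layer over X1
    prod = g0 + g1                                   # Hadamard product of densities = sum of logs
    log_w = logits - np.logaddexp.reduce(logits)     # log-softmax: log of convex mixture weights
    return np.logaddexp.reduce(prod + log_w, axis=1) # weighted sum over units, in log space
```

Because `stds[0]` and `stds[1]` are separate, each unit is a 2D Gaussian with diagonal covariance diag(stds[0, k]², stds[1, k]²).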
from cirkit.symbolic.circuit import Circuit, Scope
from cirkit.symbolic.layers import GaussianLayer, SumLayer, HadamardLayer
from cirkit.templates import utils

def build_symbolic_circuit() -> Circuit:
    # This parametrizes the mixture weights such that they add up to one.
    weight_factory = utils.parameterization_to_factory(utils.Parameterization(
        activation='softmax',     # Parameterize the sum weights by using a softmax activation
        initialization='uniform'  # Initialize the sum weights by sampling from a uniform distribution
    ))
    # We use one more mixture component than the ground-truth model has (K+1 instead of K).
    # Again, SGD/Adam is not the best way to fit a (shallow) Gaussian mixture model
    units = K+1
    g0 = GaussianLayer(Scope((0,)), units)  # One univariate Gaussian per unit over variable 0
    g1 = GaussianLayer(Scope((1,)), units)  # One univariate Gaussian per unit over variable 1
    prod = HadamardLayer(num_input_units=units, arity=2)  # Unit-wise product of g0 and g1
    sl = SumLayer(units, 1, 1, weight_factory=weight_factory)  # Mix the products into a single output
    return Circuit(
        layers=[g0, g1, prod, sl],  # Layers that appear in the circuit (i.e. nodes in the graph)
        in_layers={                 # Connections between layers (i.e. edges in the graph as an adjacency list)
            g0: [],
            g1: [],
            prod: [g0, g1],
            sl: [prod],
        },
        outputs=[sl]  # Nodes that are returned by the circuit
    )

# Build the symbolic circuit representing the Gaussian mixture model
symbolic_circuit = build_symbolic_circuit()

# Print which structural properties the circuit satisfies
print('Structural properties:')
print(f'  - Smoothness: {symbolic_circuit.is_smooth}')
print(f'  - Decomposability: {symbolic_circuit.is_decomposable}')
print(f'  - Structured-decomposability: {symbolic_circuit.is_structured_decomposable}')
Structural properties:
  - Smoothness: True
  - Decomposability: True
  - Structured-decomposability: True
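Smoothness requires the inputs of every sum node to share the same scope, and decomposability requires the inputs of every product node to have pairwise disjoint scopes. The plain-Python sketch below checks both on a hypothetical mirror of the graph above (string node names, our own dictionaries); it only illustrates the definitions and is not how cirkit computes `is_smooth` or `is_decomposable`.

```python
# Hypothetical mirror of the circuit graph: node name -> (kind, scope for input nodes)
layers = {
    "g0": ("input", {0}),
    "g1": ("input", {1}),
    "prod": ("product", None),
    "sl": ("sum", None),
}
in_layers = {"g0": [], "g1": [], "prod": ["g0", "g1"], "sl": ["prod"]}

def scope_of(name):
    """Scope of a node: given for inputs, otherwise the union of its inputs' scopes."""
    kind, scope = layers[name]
    if kind == "input":
        return frozenset(scope)
    return frozenset().union(*(scope_of(child) for child in in_layers[name]))

def is_smooth():
    """Every sum node's inputs have exactly the same scope."""
    return all(
        len({scope_of(child) for child in in_layers[name]}) == 1
        for name, (kind, _) in layers.items() if kind == "sum"
    )

def is_decomposable():
    """Every product node's inputs have pairwise disjoint scopes."""
    for name, (kind, _) in layers.items():
        if kind != "product":
            continue
        scopes = [scope_of(child) for child in in_layers[name]]
        # Disjoint iff no variable is counted twice
        if sum(len(s) for s in scopes) != len(frozenset().union(*scopes)):
            return False
    return True

print(scope_of("sl"), is_smooth(), is_decomposable())
```

With a single product node, structured decomposability also holds trivially, since there is no second product over the same scope that could split it differently.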
We can use the helper function plot_circuit from cirkit.symbolic.io to visualize the circuit (and check that it has the structure we expect).
from cirkit.symbolic.io import plot_circuit
plot_circuit(symbolic_circuit)
Looking at the symbolic circuit representing this GMM, we can see that it is equivalent to a CP layer in tensor factorization, as we will see later in "tensor compression using the CP factorization as a circuit". For example, with 2 clusters the circuit would consist of two Gaussian input layers with 2 units each, a Hadamard layer multiplying them unit by unit, and a sum layer mixing the 2 resulting products.
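To make the CP connection concrete: evaluating the circuit on a grid of values for X0 and X1 gives a matrix P with entries P[i, j] = Σ_k w_k A[i, k] B[j, k], where A and B hold the outputs of the two Gaussian input layers. This is a CP decomposition of an order-2 tensor, so P has rank at most the number of units. The numpy sketch below uses random parameters chosen purely for illustration.

```python
import numpy as np

def normal_pdf(x, mean, std):
    """Univariate Gaussian density, broadcast elementwise."""
    return np.exp(-0.5 * ((x - mean) / std) ** 2) / (std * np.sqrt(2 * np.pi))

rng = np.random.default_rng(0)
units = 3                                      # number of mixture components (CP rank)
means = rng.normal(size=(2, units))            # means[v, k]: mean of unit k for variable v
stds = rng.uniform(0.3, 1.0, size=(2, units))  # stds[v, k]: standard deviation of unit k for variable v
weights = rng.dirichlet(np.ones(units))        # convex mixture weights

grid = np.linspace(-3.0, 3.0, 50)
A = normal_pdf(grid[:, None], means[0], stds[0])  # (50, units): Gaussian layer over X0 on the grid
B = normal_pdf(grid[:, None], means[1], stds[1])  # (50, units): Gaussian layer over X1 on the grid

# P[i, j] = sum_k w_k A[i, k] B[j, k] = p(grid[i], grid[j]), a rank-`units` CP factorization
P = A @ np.diag(weights) @ B.T
print(P.shape, np.linalg.matrix_rank(P))
```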
Learning Gaussian Mixture Models
import random
import numpy as np
from torch import optim
from torch.utils.data import DataLoader, TensorDataset

dataset_size = 10000

# Set some seeds
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed(42)

# Set the torch device to use
device = torch.device('cpu')

# Build the training and test sets by sampling from the ground-truth GMM
data_train = TensorDataset(sample_points(dataset_size))
data_test = TensorDataset(sample_points(dataset_size//10))

# Instantiate the training and testing data loaders
train_dataloader = DataLoader(data_train, shuffle=True, batch_size=256)
test_dataloader = DataLoader(data_test, shuffle=False, batch_size=256)
Compilation Step and Optimization
To compile the circuit, we instantiate a PipelineContext object. We refer the reader to the compilation options notebook for a tutorial on compiling circuits and on the meaning of the different flags.
from cirkit.pipeline import PipelineContext, compile

# Instantiate the pipeline context
ctx = PipelineContext(
    backend='torch',  # Choose PyTorch as compilation backend
    # ---- Use the log-space semiring (R, +, x), where + is the numerically stable LogSumExp and x is the ordinary sum ---- #
    semiring='lse-sum',
    # ------------------------------------------------------------------------------------------------------------------ #
    fold=True,     # Fold the circuit to better exploit GPU parallelism
    optimize=True  # Optimize the layers of the circuit
)

with ctx:  # Compile the symbolic circuit into a PyTorch module computing log c(x)
    circuit = compile(symbolic_circuit)
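With `semiring='lse-sum'`, the compiled circuit propagates log-values: multiplying densities (as the Hadamard layer does) becomes adding log-densities, and summing them (as the sum layer does) becomes a LogSumExp. The small numpy illustration below shows why this matters numerically; the helper names are ours, not cirkit's.

```python
import numpy as np

def lse_sum_add(log_a, log_b):
    """Semiring 'addition': log(exp(a) + exp(b)), computed without underflow."""
    return np.logaddexp(log_a, log_b)

def lse_sum_mul(log_a, log_b):
    """Semiring 'multiplication': log(exp(a) * exp(b)) = a + b."""
    return log_a + log_b

# Two tiny log-densities, e.g. for a point far away from every cluster
log_p1, log_p2 = -800.0, -801.0

with np.errstate(divide="ignore", under="ignore"):
    naive = np.log(np.exp(log_p1) + np.exp(log_p2))  # exp underflows to 0.0, so this gives -inf
stable = lse_sum_add(log_p1, log_p2)                  # finite: about -799.687
product = lse_sum_mul(log_p1, log_p2)                 # log of the product of the two densities
print(naive, stable, product)
```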
Before training, we visualize the model density at initialization, so that we can compare it with the learned density later.
def model_density(xy):
    return circuit(xy).exp()

plot_2D(model_density, title='Model density (before training)', xmin=-3, xmax=3)
Next, we instantiate a PyTorch optimizer, such as Adam.
# Initialize a torch optimizer of your choice,
# e.g., Adam, by passing the parameters of the circuit
optimizer = optim.Adam(circuit.parameters(), lr=0.01)
In the following training loop, we learn the parameters of the mixture model by minimizing the negative log-likelihood computed on training samples.
num_epochs = 30
step_idx = 0
running_loss = 0.0
running_samples = 0

# Move the circuit to chosen device
circuit = circuit.to(device)

for epoch_idx in range(num_epochs):
    for i, (batch,) in enumerate(train_dataloader):
        # The circuit expects an input of shape (batch_dim, num_variables)
        batch = batch.to(device)
        log_likelihoods = circuit(batch)
        # Negative log-likelihood averaged over the mini-batch
        loss = -torch.mean(log_likelihoods)
        # Update the parameters of the circuits, as any other model in PyTorch
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        running_loss += loss.item() * len(batch)
        running_samples += len(batch)
        step_idx += 1
        if step_idx % 50 == 0:
            average_nll = running_loss / running_samples
            print(f"Step {step_idx}: Average NLL: {average_nll:.3f}")
            running_loss = 0.0
            running_samples = 0
Step 50: Average NLL: 4.058
Step 100: Average NLL: 3.072
Step 150: Average NLL: 2.786
Step 200: Average NLL: 2.560
Step 250: Average NLL: 2.328
Step 300: Average NLL: 2.150
Step 350: Average NLL: 2.044
Step 400: Average NLL: 2.021
Step 450: Average NLL: 1.992
Step 500: Average NLL: 1.941
Step 550: Average NLL: 1.869
Step 600: Average NLL: 1.807
Step 650: Average NLL: 1.797
Step 700: Average NLL: 1.769
Step 750: Average NLL: 1.742
Step 800: Average NLL: 1.737
Step 850: Average NLL: 1.729
Step 900: Average NLL: 1.734
Step 950: Average NLL: 1.746
Step 1000: Average NLL: 1.730
Step 1050: Average NLL: 1.733
Step 1100: Average NLL: 1.737
Step 1150: Average NLL: 1.730
Step 1200: Average NLL: 1.735
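The log stops at step 1200, which follows from the data loader settings: with `drop_last` left at its default (`False`), each epoch keeps the final partial mini-batch. A quick check of that arithmetic:

```python
import math

dataset_size, batch_size, num_epochs = 10000, 256, 30

batches_per_epoch = math.ceil(dataset_size / batch_size)          # the last, partial batch is kept
total_steps = batches_per_epoch * num_epochs                      # one optimizer step per batch
last_batch = dataset_size - (batches_per_epoch - 1) * batch_size  # size of the partial batch
print(batches_per_epoch, total_steps, last_batch)  # 40 1200 16
```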
Next, we evaluate the model on the test data.
with torch.no_grad():
    test_lls = 0.0
    for batch, in test_dataloader:
        batch = batch.to(device)
        log_likelihoods = circuit(batch)
        test_lls += log_likelihoods.sum().item()

    # Compute the average test log-likelihood (in nats per sample)
    average_ll = test_lls / len(data_test)
    print(f"Average test LL: {average_ll:.3f}")
Average test LL: -1.726
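To judge this number, it helps to compare it with the best achievable value. In expectation, no model can reach an average LL above the negative differential entropy of the data-generating distribution. Neighboring centers are about 7.7σ apart, so the components barely overlap and that entropy is close to log K + 1 + log(2πσ²) ≈ 1.70 nats. The following numpy Monte Carlo sketch reuses the dataset's settings (K = 8, radius 2, σ = 0.2) but draws its own samples.

```python
import math
import numpy as np

rng = np.random.default_rng(0)
K, radius, sigma, n = 8, 2.0, 0.2, 200_000

angles = 2 * math.pi * np.arange(K) / K
mus = radius * np.stack([np.cos(angles), np.sin(angles)], axis=-1)  # (K, 2) centers

# Sample from the ground-truth mixture: pick a component uniformly, then add isotropic noise
comp = rng.integers(K, size=n)
x = mus[comp] + sigma * rng.standard_normal((n, 2))

# Evaluate log p(x) under the true mixture with a stable log-sum-exp over components
sq = ((x[:, None, :] - mus[None, :, :]) ** 2).sum(-1)
log_joint = -sq / (2 * sigma ** 2) - math.log(2 * math.pi * sigma ** 2) - math.log(K)
log_p = np.logaddexp.reduce(log_joint, axis=1)

mc_nll = -log_p.mean()                                        # Monte Carlo estimate of the entropy
approx = math.log(K) + 1 + math.log(2 * math.pi * sigma ** 2)  # well-separated approximation
print(f"MC estimate: {mc_nll:.3f} nats, approximation: {approx:.3f} nats")
```

Both values come out at roughly 1.70 nats, so the average test LL of -1.726 is within about 0.03 nats of this bound.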
Now, we can visualize the learned mixture model by plotting its samples.
To sample, we can simply use the SamplingQuery from cirkit, available for the torch backend.
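We do not go into how `SamplingQuery` is implemented. For this particular circuit, though, sampling amounts to ancestral sampling: pick a unit of the sum layer according to the mixture weights, then, because the Hadamard layer factorizes over X0 and X1, sample each variable independently from the chosen unit of its Gaussian input layer. Below is a numpy sketch with a hypothetical helper, `sample_gmm_circuit`, that takes the parameters explicitly; it is not cirkit's code.

```python
import numpy as np

def sample_gmm_circuit(num_samples, means, stds, weights, rng):
    """Ancestral sampling from the Gaussian -> Hadamard -> Sum circuit.

    means, stds: arrays of shape (2, units); weights: convex mixture weights of shape (units,)
    Returns (samples of shape (num_samples, 2), chosen unit index per sample).
    """
    units = rng.choice(len(weights), size=num_samples, p=weights)  # pick a mixture component
    x0 = rng.normal(means[0, units], stds[0, units])                # X0 from the chosen unit of g0
    x1 = rng.normal(means[1, units], stds[1, units])                # X1 from the chosen unit of g1
    return np.stack([x0, x1], axis=-1), units
```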
from cirkit.backend.torch.queries import SamplingQuery
num_samples = 1000
query = SamplingQuery(circuit)
samples, _ = query(num_samples=num_samples)
samples.shape
torch.Size([1000, 2])
plt.scatter(*samples.unbind(-1))
plt.title('Circuit samples')
plt.gca().set_aspect('equal', 'box')
plt.show()
Finally, we compare the two models side by side by plotting their densities: the original one on the left and the one learned with cirkit on the right.
plot_2D(true_density, model_density)
As the plots above show, the learned density closely matches the original data distribution.
A more challenging example
The previous example, while instructive, used data generated by eight Gaussians, so a Gaussian mixture model could clearly fit it well. (Note, however, that we had to use 9 clusters, as gradient descent does not fit the data well otherwise.)
What if the data is non-Gaussian? Can we still fit the model to the data well?
Why don't you try to re-run the code above with the functions below? What happens if we change the number of clusters of the GMM?
Give it a try!
from sklearn.datasets import make_s_curve

true_density = None  # We do not know the density

def sample_points(n_points):
    # Keep the first and third coordinates of the 3D S-curve to obtain 2D points
    return torch.from_numpy(make_s_curve(n_points, noise=0.15)[0][:, [0, 2]]).float()

points = sample_points(1000)
plt.scatter(*points.unbind(-1))
plt.title('New samples')
plt.gca().set_aspect('equal', 'box')
plt.show()