io

plot_circuit(circuit, out_path=None, orientation='vertical', node_shape='box', label_font='times italic bold', label_size='21pt', label_color='white', sum_label='+', sum_color='#607d8b', product_label=None, product_color='#24a5af', input_label=None, input_color='#ffbd2a')

Plot a symbolic circuit using graphviz. The function returns a graphviz.Digraph object, which can be displayed directly in Jupyter notebooks. If out_path is None, the SVG format is used because it renders well in notebooks; otherwise the format is deduced from the file extension of out_path.
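The format selection described above can be sketched with the standard library alone. This is a hypothetical re-implementation for illustration, not part of cirkit: `infer_format` and `SUPPORTED_FORMATS` are made-up names, and the set below is only a small subset of the graphviz formats listed under out_path (the real function checks against `graphviz.FORMATS`).

```python
from __future__ import annotations

from os import PathLike
from pathlib import Path

# Illustrative subset only; plot_circuit itself checks against graphviz.FORMATS.
SUPPORTED_FORMATS = {"svg", "png", "pdf", "dot", "jpg", "jpeg"}


def infer_format(out_path: str | PathLike[str] | None) -> str:
    """Return the output format implied by out_path, falling back to SVG."""
    if out_path is None:
        # Nothing is written to disk; SVG renders well inline in notebooks.
        return "svg"
    # "plots/circuit.pdf" -> ".pdf" -> "pdf"
    fmt = Path(out_path).suffix.replace(".", "")
    if fmt not in SUPPORTED_FORMATS:
        raise ValueError(f"Supported formats are {sorted(SUPPORTED_FORMATS)}.")
    return fmt


print(infer_format(None))                 # svg
print(infer_format("plots/circuit.pdf"))  # pdf
```

Because only the last suffix is inspected, a path such as "plots/circuit.pdf" selects PDF output regardless of the directory name.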

Parameters:

circuit (Circuit, required): The symbolic circuit to plot.

out_path (str | PathLike[str] | None, default None): The path where the plot is saved. If it is None, the plot is not saved to a file. The output file format is deduced from the file extension of the path. Possible formats are: {'jp2', 'plain-ext', 'sgi', 'x11', 'pic', 'jpeg', 'imap', 'psd', 'pct', 'json', 'jpe', 'tif', 'tga', 'gif', 'tk', 'xlib', 'vmlz', 'json0', 'vrml', 'gd', 'xdot', 'plain', 'cmap', 'canon', 'cgimage', 'fig', 'svg', 'dot_json', 'bmp', 'png', 'cmapx', 'pdf', 'webp', 'ico', 'xdot_json', 'gtk', 'svgz', 'xdot1.4', 'cmapx_np', 'dot', 'tiff', 'ps2', 'gd2', 'gv', 'ps', 'jpg', 'imap_np', 'wbmp', 'vml', 'eps', 'xdot1.2', 'pov', 'pict', 'ismap', 'exr'}. See https://graphviz.org/docs/outputs/ for more.

orientation (str, default 'vertical'): Orientation of the graph. "vertical" puts the root node at the top, "horizontal" puts it at the right.

node_shape (str, default 'box'): Shape used for every node in the graph. See https://graphviz.org/doc/info/shapes.html for the supported shapes.

label_font (str, default 'times italic bold'): Font used to render labels. See https://graphviz.org/faq/font/ for the available fonts.

label_size (str, default '21pt'): Font size of the labels, in points.

label_color (str, default 'white'): Color of the labels in the nodes. See https://graphviz.org/docs/attr-types/color/ for the supported colors.

sum_label (str | Callable[[SumLayer], str], default '+'): Either a string or a function. If a function is provided, it must take a sum layer as input and return the string to use as its label.

sum_color (str | Callable[[SumLayer], str], default '#607d8b'): Either a string or a function. If a function is provided, it must take a sum layer as input and return the string to use as the color of the sum node.

product_label (str | Callable[[ProductLayer], str] | None, default None): Either a string or a function. If a function is provided, it must take a product layer as input and return the string to use as its label. If None, the label is "⊙" for Hadamard layers and "⊗" for Kronecker layers.

product_color (str | Callable[[ProductLayer], str], default '#24a5af'): Either a string or a function. If a function is provided, it must take a product layer as input and return the string to use as the color of the product node.

input_label (str | Callable[[InputLayer], str] | None, default None): Either a string or a function. If a function is provided, it must take an input layer as input and return the string to use as its label. If None, the label lists the variables in the scope of the layer, separated by spaces.

input_color (str | Callable[[InputLayer], str], default '#ffbd2a'): Either a string or a function. If a function is provided, it must take an input layer as input and return the string to use as the color of the input layer node.

Raises:

ValueError: If the format deduced from out_path is not among the supported ones.

ValueError: If orientation is neither 'vertical' nor 'horizontal'.
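Both checks run before any node is drawn. The orientation check, and how the source code maps orientation onto graphviz's `rankdir` graph attribute, can be sketched as follows; `RANKDIR` and `rankdir_for` are hypothetical names, since plot_circuit inlines this logic.

```python
# Hypothetical helper: plot_circuit inlines this logic rather than exposing it.
# Edges run from each input layer to the layer that consumes it, so the
# rankdir value decides where the root (output) layer ends up.
RANKDIR = {
    "vertical": "BT",    # edges point upward: root layer at the top
    "horizontal": "LR",  # edges point rightward: root layer at the right
}


def rankdir_for(orientation: str) -> str:
    if orientation not in RANKDIR:
        raise ValueError("Supported graph directions are only 'vertical' and 'horizontal'.")
    return RANKDIR[orientation]


print(rankdir_for("vertical"))    # BT
print(rankdir_for("horizontal"))  # LR
```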

Returns:

Digraph: A graphviz.Digraph with one node per layer of the circuit and an edge from each layer to every layer that takes it as input.
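To make the structure of the returned graph concrete, the following sketch builds comparable DOT text by hand for a hypothetical four-layer circuit. It uses neither graphviz nor cirkit; `to_dot` and the `layers` dictionary are illustrative assumptions, and the text graphviz actually generates differs in formatting (for example, it uses `str(id(layer))` as node names).

```python
# A hypothetical circuit: two input layers feed a Hadamard product layer,
# which feeds the root sum layer. Each entry maps a node name to
# (label, color, names of the layers it takes as input).
layers: dict[str, tuple[str, str, list[str]]] = {
    "in0": ("0", "#ffbd2a", []),
    "in1": ("1", "#ffbd2a", []),
    "prod": ("⊙", "#24a5af", ["in0", "in1"]),
    "sum": ("+", "#607d8b", ["prod"]),
}


def to_dot(layers: dict[str, tuple[str, str, list[str]]], rankdir: str = "BT") -> str:
    """Emit DOT text: one filled node per layer, one edge per (input, consumer) pair."""
    lines = ["digraph {", f"\trankdir={rankdir}"]
    for name, (label, color, inputs) in layers.items():
        # With style=filled and no fillcolor, graphviz fills the node with `color`.
        lines.append(f'\t{name} [label="{label}" color="{color}" style=filled]')
        for src in inputs:
            lines.append(f"\t{src} -> {name}")
    lines.append("}")
    return "\n".join(lines)


dot_source = to_dot(layers)
edges = [line.strip() for line in dot_source.splitlines() if "->" in line]
print(edges)  # ['in0 -> prod', 'in1 -> prod', 'prod -> sum']
```

Edges always point from an input to its consumer, which is why the orientation setting only has to pick a rank direction to place the root at the top or at the right.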

Source code in cirkit/symbolic/io.py
def plot_circuit(
    circuit: Circuit,
    out_path: str | PathLike[str] | None = None,
    orientation: str = "vertical",
    node_shape: str = "box",
    label_font: str = "times italic bold",
    label_size: str = "21pt",
    label_color: str = "white",
    sum_label: str | Callable[[SumLayer], str] = "+",
    sum_color: str | Callable[[SumLayer], str] = "#607d8b",
    product_label: str | Callable[[ProductLayer], str] | None = None,
    product_color: str | Callable[[ProductLayer], str] = "#24a5af",
    input_label: str | Callable[[InputLayer], str] | None = None,
    input_color: str | Callable[[InputLayer], str] = "#ffbd2a",
) -> graphviz.Digraph:
    """Plot a symbolic circuit using graphviz.

    A graphviz object is returned, which can be visualized in Jupyter notebooks.
    If out_path is None, SVG is used for optimal rendering in notebooks.

    Args:
        circuit: The symbolic circuit to plot.
        out_path: The output path where the plot is saved.
            If it is None, the plot is not saved to a file. Defaults to None.
            The output file format is deduced from the file extension. Possible formats are:
            {'jp2', 'plain-ext', 'sgi', 'x11', 'pic', 'jpeg', 'imap', 'psd', 'pct',
             'json', 'jpe', 'tif', 'tga', 'gif', 'tk', 'xlib', 'vmlz', 'json0', 'vrml',
             'gd', 'xdot', 'plain', 'cmap', 'canon', 'cgimage', 'fig', 'svg', 'dot_json',
             'bmp', 'png', 'cmapx', 'pdf', 'webp', 'ico', 'xdot_json', 'gtk', 'svgz',
             'xdot1.4', 'cmapx_np', 'dot', 'tiff', 'ps2', 'gd2', 'gv', 'ps', 'jpg',
             'imap_np', 'wbmp', 'vml', 'eps', 'xdot1.2', 'pov', 'pict', 'ismap', 'exr'}.
             See https://graphviz.org/docs/outputs/ for more.
        orientation: Orientation of the graph. "vertical" puts the root
            node at the top, "horizontal" at the right. Defaults to "vertical".
        node_shape: Default shape for a node in the graph. Defaults to "box".
            See https://graphviz.org/doc/info/shapes.html for the supported shapes.
        label_font: Font used to render labels. Defaults to "times italic bold".
            See https://graphviz.org/faq/font/ for the available fonts.
        label_size: Size of the font for labels in points. Defaults to 21pt.
        label_color: Color for the labels in the nodes. Defaults to "white".
            See https://graphviz.org/docs/attr-types/color/ for supported colors.
        sum_label: Either a string or a function.
            If a function is provided, it must take a sum layer as input and return the string
            to use as its label. Defaults to "+".
        sum_color: Either a string or a function.
            If a function is provided, it must take a sum layer as input and return the string
            to use as the color of the sum node. Defaults to "#607d8b".
        product_label: Either a string or a function.
            If a function is provided, it must take a product layer as input and return the
            string to use as its label. If None, it defaults to "⊙" for Hadamard layers and
            "⊗" for Kronecker layers.
        product_color: Either a string or a function.
            If a function is provided, it must take a product layer as input and return the
            string to use as the color of the product node. Defaults to "#24a5af".
        input_label: Either a string or a function.
            If a function is provided, it must take an input layer as input and return the
            string to use as its label. If None, it defaults to the variables in the scope
            of the layer, separated by spaces.
        input_color: Either a string or a function.
            If a function is provided, it must take an input layer as input and return the
            string to use as the color of the input layer node. Defaults to "#ffbd2a".

    Raises:
        ValueError: The format deduced from out_path is not among the supported ones.
        ValueError: The orientation is not among the supported ones.

    Returns:
        graphviz.Digraph: The graph, with one node per layer and an edge from each
            layer to every layer that takes it as input.
    """
    fmt: str
    if out_path is None:
        fmt = "svg"
    else:
        fmt = Path(out_path).suffix.replace(".", "")
        if fmt not in graphviz.FORMATS:
            raise ValueError(f"Supported formats are {graphviz.FORMATS}.")

    if orientation not in ["vertical", "horizontal"]:
        raise ValueError("Supported graph directions are only 'vertical' and 'horizontal'.")

    def _default_product_label(sl: ProductLayer) -> str:
        match sl:
            case HadamardLayer():
                return "⊙"
            case KroneckerLayer():
                return "⊗"
            case _:
                raise NotImplementedError(
                    f"No default label for product layer of type {sl.__class__}"
                )

    def _default_input_label(sl: InputLayer) -> str:
        return " ".join(map(str, sl.scope))

    if product_label is None:
        product_label = _default_product_label
    if input_label is None:
        input_label = _default_input_label

    dot: graphviz.Digraph = graphviz.Digraph(
        format=fmt,
        node_attr={
            "shape": node_shape,
            "style": "filled",
            "fontcolor": label_color,
            "fontsize": label_size,
            "fontname": label_font,
        },
        engine="dot",
    )
    dot.graph_attr["rankdir"] = "BT" if orientation == "vertical" else "LR"

    for sl in circuit.layers:
        match sl:
            case ProductLayer():
                dot.node(
                    str(id(sl)),
                    product_label if isinstance(product_label, str) else product_label(sl),
                    color=product_color if isinstance(product_color, str) else product_color(sl),
                )
            case SumLayer():
                dot.node(
                    str(id(sl)),
                    sum_label if isinstance(sum_label, str) else sum_label(sl),
                    color=sum_color if isinstance(sum_color, str) else sum_color(sl),
                )
            case InputLayer():
                dot.node(
                    str(id(sl)),
                    input_label if isinstance(input_label, str) else input_label(sl),
                    color=input_color if isinstance(input_color, str) else input_color(sl),
                )

        for sli in circuit.layer_inputs(sl):
            dot.edge(str(id(sli)), str(id(sl)))

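    if out_path is not None:
        # Path without its suffix: graphviz appends ".{fmt}" to the rendered file itself.
        out_stem: Path = Path(out_path).with_suffix("")

        if fmt == "dot":
            # Only the DOT source is needed; write to the full path so ".dot" is kept.
            with open(out_path, "w", encoding="utf8") as f:
                f.write(dot.source)
        else:
            dot.format = fmt
            dot.render(out_stem, cleanup=True)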

    return dot
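To store the DOT text of a returned graph yourself (for instance the string in `dot.source`), it is enough to write it to the full target path so the `.dot` suffix survives. A minimal standard-library sketch; `save_dot_source` is a hypothetical helper name, not a cirkit function.

```python
from __future__ import annotations

import tempfile
from pathlib import Path


def save_dot_source(source: str, out_path: str | Path) -> Path:
    """Write DOT text to exactly out_path, keeping its ".dot" suffix."""
    path = Path(out_path)
    path.parent.mkdir(parents=True, exist_ok=True)  # create e.g. "plots/" if missing
    path.write_text(source, encoding="utf8")
    return path


# Usage, e.g. with source=dot.source for a Digraph returned by plot_circuit.
with tempfile.TemporaryDirectory() as tmp:
    saved = save_dot_source("digraph {\n\ta -> b\n}", Path(tmp) / "plots" / "circuit.dot")
    print(saved.name)  # circuit.dot
    print(saved.read_text(encoding="utf8").splitlines()[0])  # digraph {
```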