How to use a Vega chart in Streamlit

Hi all,

I am using the latest Streamlit version and want to use a really nice Vega chart example in Streamlit.

It is this one: the "Labelled Donut Chart" example from the Vega documentation.

I want to use it, but I am totally stuck at the moment. I have already created a pie chart using this tutorial: "Annotate an Altair chart" (Streamlit Docs).

However, the Vega chart mentioned above is completely unfamiliar to me. If someone could help me or point me in the right direction on how to use it, I would greatly appreciate it.

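Does this help? Streamlit renders Vega-Lite specs (not raw Vega specs) through `st.vega_lite_chart`, so the example below rebuilds the labelled donut as a Vega-Lite spec, alongside a few other chart types for comparison.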

import streamlit as st
import pandas as pd

# Load the iris dataset (columns used below: sepal_length, sepal_width,
# petal_length, petal_width, species).
IRIS_URL = "https://raw.githubusercontent.com/mwaskom/seaborn-data/master/iris.csv"


@st.cache_data
def load_iris(url: str = IRIS_URL) -> pd.DataFrame:
    """Download the iris CSV and cache it across Streamlit reruns.

    Streamlit re-executes the whole script on every widget interaction; without
    caching, every rerun would fetch the CSV over the network again.

    Args:
        url: Location of the CSV file.

    Returns:
        The parsed DataFrame.
    """
    return pd.read_csv(url)


iris = load_iris()

# 1. Donut chart (mean sepal_length by species)
def donut_chart(data=None):
    """Build a labelled Vega-Lite donut chart of mean ``sepal_length`` per species.

    Streamlit's ``st.vega_lite_chart`` renders Vega-Lite, so this is a
    Vega-Lite version of a labelled donut rather than a raw Vega spec.

    Args:
        data: DataFrame with ``species`` and ``sepal_length`` columns.
            Defaults to the module-level ``iris`` dataset.

    Returns:
        A Vega-Lite spec (dict) with an arc layer and a text-label layer.
    """
    # Only touch the module-level dataset when no frame is supplied.
    df = iris if data is None else data
    agg = df.groupby("species")["sepal_length"].mean().reset_index()
    return {
        "$schema": "https://vega.github.io/schema/vega-lite/v5.json",
        "data": {"values": agg.to_dict("records")},
        # Shared by both layers. The arc and the text must stack theta the same
        # way and group by the same color field. Otherwise each label is drawn
        # at its raw, unstacked angle instead of over its own slice.
        "encoding": {
            "theta": {"field": "sepal_length", "type": "quantitative", "stack": True},
            "color": {"field": "species", "type": "nominal"},
        },
        "layer": [
            {
                "mark": {
                    "type": "arc",
                    "innerRadius": 60,
                    "outerRadius": 100,
                    "stroke": "#fff",
                },
            },
            {
                # Labels sit 20px outside the ring (outerRadius 100 + 20).
                "mark": {"type": "text", "radius": 120},
                "encoding": {"text": {"field": "species", "type": "nominal"}},
            },
        ],
    }

# 2. Bar chart (mean petal_length by species)
def bar_chart():
    """Return a Vega-Lite bar-chart spec of mean ``petal_length`` per species.

    Data comes from the module-level ``iris`` frame; bars are colored by species.
    """
    means = iris.groupby("species")["petal_length"].mean().reset_index()
    records = means.to_dict("records")
    encoding = {
        "x": {"field": "species", "type": "nominal"},
        "y": {"field": "petal_length", "type": "quantitative"},
        "color": {"field": "species", "type": "nominal"},
    }
    return {"data": {"values": records}, "mark": "bar", "encoding": encoding}

# 3. Scatter plot (sepal_length vs sepal_width, colored by species)
def scatter_plot():
    """Return a Vega-Lite point-chart spec of sepal width against sepal length.

    Every row of the module-level ``iris`` frame is plotted (no aggregation),
    and points are colored by species.
    """
    def channel(field, kind):
        # One Vega-Lite encoding channel: a data field plus its measurement type.
        return {"field": field, "type": kind}

    return {
        "data": {"values": iris.to_dict("records")},
        "mark": "point",
        "encoding": {
            "x": channel("sepal_length", "quantitative"),
            "y": channel("sepal_width", "quantitative"),
            "color": channel("species", "nominal"),
        },
    }

# 4. Line chart (mean petal_length across species)
def line_chart(data=None):
    """Build a Vega-Lite line chart of mean ``petal_length`` per species.

    All species form a single series on purpose. There is exactly one
    aggregated row per species, so splitting the line by a ``species`` color
    channel would give single-vertex series, and a line with one vertex draws
    nothing. Point markers keep each mean visible.

    Args:
        data: DataFrame with ``species`` and ``petal_length`` columns.
            Defaults to the module-level ``iris`` dataset.

    Returns:
        A Vega-Lite spec (dict).
    """
    # Only touch the module-level dataset when no frame is supplied.
    df = iris if data is None else data
    agg = df.groupby("species")["petal_length"].mean().reset_index()
    return {
        "data": {"values": agg.to_dict("records")},
        "mark": {"type": "line", "point": True},
        "encoding": {
            "x": {"field": "species", "type": "ordinal"},
            "y": {"field": "petal_length", "type": "quantitative"},
        },
    }

# 5. Histogram (distribution of sepal_length)
def histogram():
    """Return a Vega-Lite histogram spec of ``sepal_length``.

    Values from the module-level ``iris`` frame are binned on x and counted on
    y. Bars are colored by species, which Vega-Lite stacks within each bin.
    """
    rows = iris.to_dict("records")
    x_channel = {"bin": True, "field": "sepal_length", "type": "quantitative"}
    y_channel = {"aggregate": "count", "type": "quantitative"}
    color_channel = {"field": "species", "type": "nominal"}
    return {
        "data": {"values": rows},
        "mark": "bar",
        "encoding": {"x": x_channel, "y": y_channel, "color": color_channel},
    }

# 6. Boxplot (petal_width by species)
def boxplot():
    """Return a Vega-Lite boxplot spec of ``petal_width`` grouped by species.

    Uses every row of the module-level ``iris`` frame; boxes are colored by species.
    """
    spec = dict(data={"values": iris.to_dict("records")}, mark="boxplot")
    # (channel, field, measurement type) triples, in the order they are emitted.
    spec["encoding"] = {
        axis: {"field": field, "type": kind}
        for axis, field, kind in (
            ("x", "species", "nominal"),
            ("y", "petal_width", "quantitative"),
            ("color", "species", "nominal"),
        )
    }
    return spec

# --- Layout: 3 rows x 2 columns ---
# Each entry pairs a subheader with the Vega-Lite spec rendered under it.
charts = [
    ("Donut Chart", donut_chart()),
    ("Bar Chart", bar_chart()),
    ("Scatter Plot", scatter_plot()),
    ("Line Chart", line_chart()),
    ("Histogram", histogram()),
    ("Boxplot", boxplot()),
]

# Walk the list two charts at a time; each pair gets a fresh 2-column row.
for row_start in range(0, len(charts), 2):
    row_charts = charts[row_start:row_start + 2]
    # zip stops at the shorter sequence, so an odd trailing chart leaves the
    # right-hand column empty.
    for column, (heading, spec) in zip(st.columns(2), row_charts):
        with column:
            st.subheader(heading)
            st.vega_lite_chart(spec=spec, use_container_width=True)