Help with Streamlit/Shiny comparison

Summary

I’m comparing Streamlit to Shiny using this application. My solution seems incredibly complicated, and I’m wondering if there’s a better way.

Details

I’m trying to replicate the behaviour of a simple Shiny app, which can be viewed and edited using Shinylive.

The key features of this app are that when you change the sample size a new sample is taken, but when you change the plotting scale only the scatter plot changes and a new sample isn’t taken. Shiny does this fairly easily because it only reruns the components which need to change in response to user inputs, but I ran into some problems with Streamlit’s execution model when I tried to implement it.

If I don’t cache the data, then a new sample is taken whenever the plot options change, but if I do cache it then the sample isn’t re-run in response to changing the sample size. The workaround I ended up with is creating a cache-busting callback which updates session state. This works but seems incredibly complicated and awkward for such a simple application.
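To make the workaround concrete, here is a minimal sketch of the cache-busting idea in plain Python. It is not the actual app: functools.lru_cache stands in for st.cache_data, a plain dict stands in for st.session_state, and POPULATION, take_sample and on_sample_size_change are hypothetical names chosen for illustration.

```python
import random
from functools import lru_cache

POPULATION = list(range(1000))  # stand-in for the taxi data (assumption)
session_state = {"cache_buster": 0}  # stand-in for st.session_state


@lru_cache(maxsize=None)
def take_sample(fraction, cache_buster):
    # cache_buster is never used in the body; it only changes the cache key,
    # so bumping it forces a fresh sample even for a fraction seen before.
    k = int(len(POPULATION) * fraction)
    return tuple(random.sample(POPULATION, k))


def on_sample_size_change():
    # Hypothetical callback attached to the sample-size widget.
    session_state["cache_buster"] += 1


first = take_sample(0.1, session_state["cache_buster"])
# A plot option changes: same arguments, so the cached sample is reused.
again = take_sample(0.1, session_state["cache_buster"])
# The sample size widget changes: the callback bumps the counter first.
on_sample_size_change()
fresh = take_sample(0.1, session_state["cache_buster"])
```

The second call is a cache hit and returns the very same object as the first, while the third call misses the cache and draws again, even though the fraction is unchanged.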

Is there a better way to accomplish this?

Hi @Gordon_Shotwell

From what I can think of now, there are 2 possible approaches:

  1. Put all input widgets inside st.form such that any changes to any of the input widgets would not trigger the app to reload. Once the user is ready to proceed, he/she can click on the submit button (st.form_submit_button).
  2. Use Session State to manage the various session variables so that a change in one of the session variables would not trigger the entire app to reload. In your situation, you could add session state to the sample size and possibly the plotting scale.

Hope this helps :slight_smile:

Best regards,
Chanin

Thanks for the response. I don’t want to use the form pattern because I’m trying to replicate the Shiny experience, which doesn’t require a second user action (that isn’t a big deal for this example, but it can get a little complicated).

I’m not quite understanding the session state solution. So I should have the slider have a callback which updates session state? That seems to result in the sample being one behind the current value.

Here’s one way to get the behavior you’re looking for – pass a random_state value to sample, and keep the random_state in session_state, and update it when the sample size is changed, like your cache buster. In this way, the data doesn’t change if you change the log plot, and you actually don’t have to do any extra caching beyond loading the initial data.

import streamlit as st
from pandas import read_csv
from plotnine import (
    aes,
    geom_histogram,
    geom_point,
    ggplot,
    scale_x_log10,
    scale_y_log10,
    theme_bw,
)

if "random_state" not in st.session_state:
    st.session_state.random_state = 0


def update_random_state():
    st.session_state.random_state += 1


with st.sidebar:
    sample_ui = st.number_input(
        "sample", 0.0, 1.0, value=0.1, step=0.01, on_change=update_random_state
    )
    log = st.checkbox("Log Scale")


@st.cache_data
def load_data():
    df = read_csv(
        "https://raw.githubusercontent.com/GShotwell/streamlit-shiny-comp/main/nyc-taxi.csv"
    )
    return df


data = load_data()


def take_sample_uncached(df, fraction):
    return df.copy().sample(frac=fraction, random_state=st.session_state.random_state)


def tip_plot(sample, log=False):
    plot = ggplot(sample, aes("tip_amount", "total_amount")) + geom_point() + theme_bw()
    if log:
        plot = plot + scale_x_log10() + scale_y_log10()
    return plot


def amount_histogram(df):
    plot = ggplot(df, aes(x="total_amount")) + geom_histogram(binwidth=5) + theme_bw()
    return plot


def plot(sampled, log=False):
    st.subheader(f'First taxi id: {sampled["taxi_id"].iloc[0]}')

    tips = tip_plot(sampled, log=log)
    st.pyplot(tips.draw())

    amounts = amount_histogram(sampled)
    st.pyplot(amounts.draw())


sampled = take_sample_uncached(data, sample_ui)
plot(sampled, log=log)

Thanks! That’s a very interesting solution which I definitely hadn’t thought of. The only issue is that that’s not actually a random sample because it will create the same sample each time you run the app. I suppose it’s not complicated to have it use a random number for the seed instead of incrementing a counter. Overall though this solution seems just as awkward as the cache-busting callback, and also a bit more unclear because the sampling behaviour is split into a bunch of pieces.

Sure, switch out 0 for random.randint(0, 10000000) in the first part, and you get a new random sample for each session. I suppose that “awkward” is in the eye of the beholder :slight_smile:


That’s true. I suppose what I find awkward is that the solution kind of turns a random sample into a semi-deterministic sample. I think you get the same results in this case, but I had to think about it, and I’m not sure you’d get the same results for other types of statistical methods which rely on randomness. In the same way, my solution is awkward not because it’s wrong but because you have to exert some brain power to parse the code and figure out the edge cases where it might get the wrong answer.
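The “semi-deterministic” point can be illustrated with a small standard-library sketch. Here random.Random(seed).sample stands in for DataFrame.sample(random_state=...), and POPULATION and take_sample are illustrative names rather than anything from either app.

```python
import random

POPULATION = list(range(1000))  # stand-in for the taxi data (assumption)


def take_sample(fraction, seed):
    # A private generator seeded per call: the result depends only on
    # (fraction, seed), not on any global random state.
    rng = random.Random(seed)
    return rng.sample(POPULATION, int(len(POPULATION) * fraction))


a = take_sample(0.1, seed=42)
b = take_sample(0.1, seed=42)  # same seed, same fraction: identical sample
c = take_sample(0.2, seed=42)  # same seed, different fraction: a different draw
```

Once the seed is fixed, the sample is a pure function of its arguments. Whether that matters depends on the method: anything that assumes independent draws across reruns would need a fresh seed each time.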

Fair enough. I think the piece that makes it tricky is that this behavior requires “make this widget only make some parts of the app get rerun, and this widget make other parts of the app get rerun”, and that’s simply not an easy thing to do with Streamlit. In Streamlit, you get a very straightforward behavior where the app reruns any time anything changes, so that everything gets the very latest version of every variable. That makes some kinds of apps easier, and some kinds of apps harder.

I will say, it’s easier for me to reason about what will happen in the Streamlit code than the equivalent Shiny code, and it’s more intuitive which variables are getting passed around to what. There’s a lot more implicit state management going on in the Shiny example (e.g. how does the value from ui.input_slider("sample", "Sample Size", 0, 1, value = 0.1) end up being passed to frac = input.sample()?).

If it was me, I would probably go with the slightly-less-of-an-exact-match version of the app, which just resamples anytime anything changes :slight_smile: Works great, until you want to say “but I want this widget to only affect these things”, which is when it always gets harder with Streamlit.

import streamlit as st
from pandas import read_csv
from plotnine import (
    aes,
    geom_histogram,
    geom_point,
    ggplot,
    scale_x_log10,
    scale_y_log10,
    theme_bw,
)


@st.cache_data
def load_data():
    df = read_csv(
        "https://raw.githubusercontent.com/GShotwell/streamlit-shiny-comp/main/nyc-taxi.csv"
    )
    return df


with st.sidebar:
    sample_ui = st.number_input("sample", 0.0, 1.0, value=0.1, step=0.01)
    log = st.checkbox("Log Scale")

sample = load_data().sample(frac=sample_ui)

st.subheader(f'First taxi id: {sample["taxi_id"].iloc[0]}')

hist = ggplot(sample, aes(x="total_amount")) + geom_histogram(binwidth=5) + theme_bw()
st.pyplot(hist.draw())

scatter = ggplot(sample, aes("tip_amount", "total_amount")) + geom_point() + theme_bw()
if log:
    scatter = scatter + scale_x_log10() + scale_y_log10()
st.pyplot(scatter.draw())

Right, so the reason you don’t want that app statistically is that you want to take a sample and then interrogate that sample using other inputs. For this use case it’s very important that the sample is held constant while the other inputs change, because otherwise you won’t be able to understand what’s happening with a particular sample.

In terms of implicit state management, I would say the main conceptual difference between Shiny and Streamlit is that Shiny is a declarative framework. You tell the Shiny app what you want to have happen and let the framework move the data around and figure out what needs to update. So you don’t need to think about how input.sample() gets passed around because the framework handles that for you. This is a bit of a mental shift if you’re used to frameworks that require writing callbacks, but it is really great for many use cases because you need to worry less about the wiring between components.
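A toy version of that idea, in plain Python, might look like the sketch below. It is not how Shiny is implemented: real Shiny discovers dependencies automatically by tracking which inputs a calculation reads, whereas this hypothetical Calc class has them declared by hand, and inputs is just a dict standing in for Shiny’s input object.

```python
import random

POPULATION = list(range(1000))  # stand-in for the taxi data (assumption)
inputs = {"sample": 0.1, "log": False}  # stand-in for Shiny's `input`
sample_calls = 0


class Calc:
    """Cache a result and recompute only when a declared input changes."""

    def __init__(self, func, depends_on):
        self.func = func
        self.depends_on = depends_on
        self._key = None
        self._value = None

    def __call__(self):
        key = tuple(inputs[name] for name in self.depends_on)
        if key != self._key:  # a dependency changed since the last call
            self._key = key
            self._value = self.func()
        return self._value


def _take_sample():
    global sample_calls
    sample_calls += 1
    return random.sample(POPULATION, int(len(POPULATION) * inputs["sample"]))


sampled = Calc(_take_sample, depends_on=["sample"])


def render_scatter():
    # Reads both the shared sample and the scale option.
    return sampled(), ("log" if inputs["log"] else "linear")


first, _ = render_scatter()
inputs["log"] = True        # only the plot option changes
second, scale = render_scatter()
inputs["sample"] = 0.2      # now the sample size changes
third, _ = render_scatter()
```

The second render reuses the existing sample because only the log input changed, and the third triggers a new draw, so the sampling function runs twice across three renders.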

Here’s a good write up on this from the R side of things: Chapter 3 Basic reactivity | Mastering Shiny

And a good talk on the different trade-offs between Shiny and Streamlit: Joe Cheng - Shiny for Python: Interactive apps and dashboards made easy-ish | PyData NYC 2022 - YouTube

Thanks again for responding. This app is mostly meant as a very simple way to highlight the differences between Streamlit and Shiny, so I just wanted to check that I wasn’t missing something obvious about how Streamlit handles these types of problems.


Yes, that makes sense.

Alright, here’s my favorite version of the Streamlit one:

import random

import streamlit as st
from pandas import read_csv
from plotnine import (
    aes,
    geom_histogram,
    geom_point,
    ggplot,
    scale_x_log10,
    scale_y_log10,
    theme_bw,
)

if "random_state" not in st.session_state:
    st.session_state["random_state"] = random.randint(0, 1_000_000_000)


@st.cache_data
def load_data():
    df = read_csv(
        "https://raw.githubusercontent.com/GShotwell/streamlit-shiny-comp/main/nyc-taxi.csv"
    )
    return df


def sample_data(sample):
    return load_data().sample(
        frac=sample, random_state=int(st.session_state["random_state"] * sample)
    )


with st.sidebar:
    fraction = st.number_input("sample", 0.0, 1.0, value=0.1, step=0.01)
    log = st.checkbox("Log Scale")
    if st.button("Resample"):
        st.session_state["random_state"] = random.randint(0, 1_000_000_000)

sample = sample_data(fraction)

st.subheader(f'First taxi id: {sample["taxi_id"].iloc[0]}')

plot = ggplot(sample, aes(x="total_amount")) + geom_histogram(binwidth=5) + theme_bw()
st.pyplot(plot.draw())

plot = ggplot(sample, aes("tip_amount", "total_amount")) + geom_point() + theme_bw()
if log:
    plot = plot + scale_x_log10() + scale_y_log10()
st.pyplot(plot.draw())

This version only caches the initial data load. The sampled data doesn’t change unless you change the sample size, but you can explicitly re-run the sample for a given sample size with a button. And no callbacks necessary :slight_smile:
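The seed arithmetic in sample_data can be checked in isolation. The sketch below pulls it out into a hypothetical helper, derived_seed, with a made-up base seed; it does not touch Streamlit or pandas.

```python
def derived_seed(base_seed, fraction):
    # Same expression as in sample_data: scale the per-session base seed
    # by the sample fraction and truncate to an integer.
    return int(base_seed * fraction)


base = 123_456_789  # example value; the app draws this with random.randint

same_a = derived_seed(base, 0.10)
same_b = derived_seed(base, 0.10)   # unchanged fraction: unchanged seed
other = derived_seed(base, 0.11)    # new fraction: new seed, so a new sample
new_session = derived_seed(987_654_321, 0.10)  # "Resample" picks a new base
```

Toggling the log scale leaves both the base seed and the fraction alone, so the derived seed and therefore the sample stay fixed. One caveat is that the mapping is not guaranteed to be unique: because of the truncation, a very small base seed can map nearby fractions to the same derived seed.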

Definitely. You want a new sample only when the slider changes. So put the sample in session_state and update it in the slider’s callback.

In my code I use the pandas plot API instead of plotnine for my convenience.

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import streamlit as st


def main():
    data = read_data()
    if "sample" not in st.session_state:
        st.session_state.sample = data.sample(frac=0.1)

    st.sidebar.number_input(
        label="Sample size",
        min_value=0.0,
        max_value=1.0,
        value=0.1,
        step=0.01,
        on_change=update_sample,
        args=(data,),
        key="sample_size",
    )
    log_scale = st.sidebar.checkbox("Log Scale")

    st.subheader(first_taxi_id(data=st.session_state.sample))
    st.pyplot(
        tip_plot(data=st.session_state.sample, log_scale=log_scale),
        use_container_width=False,
        dpi=90,
    )
    st.pyplot(
        amount_histogram(data=st.session_state.sample),
        use_container_width=False,
        dpi=90,
    )


@st.cache_data
def read_data():
    return pd.read_csv(
        "https://github.com/rstudio/streamlit-shiny-dash/raw/main/nyc-taxi.csv"
    )


def update_sample(data):
    st.session_state.sample = data.sample(frac=st.session_state.sample_size)


def first_taxi_id(data):
    return f'First taxi ID: {data["taxi_id"].iloc[0]}'


def tip_plot(data, log_scale):
    fig = plt.figure()
    data.plot(
        x="tip_amount",
        y="total_amount",
        style=".",
        legend=False,
        ylabel="total_amount",
        loglog="sym" if log_scale else False,
        grid=True,
        ax=fig.add_subplot(),
    )
    return fig


def amount_histogram(data):
    plot_data = data.total_amount
    bin_width = 5

    fig = plt.figure()
    ax = plot_data.hist(
        bins=np.arange(plot_data.min(), plot_data.max() + bin_width, bin_width),
        ax=fig.add_subplot(),
    )
    ax.set_xlabel("total_amount")
    ax.set_ylabel("count")
    return fig


if __name__ == "__main__":
    main()