Issue with st.experimental_fragment functions

I have a sample app that never triggers a global rerun, except at startup, on an explicit page reload by the user, or on an explicit rerun. Session variables are shared between fragments, and callbacks are kept inside the fragment functions.

The app has three fragment functions: ticker_selector(), display() and chart().

def main():
    cols = st.columns([1, 1], gap='large')
    with cols[0]:
        st.markdown('**1. Ticker Selector**')
        ticker_selector()

    with cols[1]:
        st.markdown('**2. Display**')
        display()

    st.markdown('**3. Plot**')
    chart()

ticker_selector()

ticker_selector() has two controls, both select boxes. The user can select the ticker symbol [aapl, amzn, msft …] and the interval [30m, 1d]. The selected values are stored in session variables, so they can be accessed from anywhere in the app.

@st.experimental_fragment
def ticker_selector():
    def ticker_cb():
        ss.ticker = ss.tss

    def interval_cb():
        ss.interval = ss.int

    st.selectbox('Select ticker', options=TICKERS,
                 key='tss', on_change=ticker_cb, label_visibility='collapsed')
    st.selectbox('Select interval', options=['30m', '1d'], key='int',
                 on_change=interval_cb, label_visibility='collapsed')

display()

Next is display(), which shows the ticker and interval selected in ticker_selector(). It also has its own select box for the ticker symbol.

However, one issue remains: how can the ticker info in display() and the chart be updated when the ticker symbol and/or interval change? This is an issue because both display() and chart() are fragments. A change made in ticker_selector() reruns only that fragment, so the new ticker symbol is stored in the session state but is not yet reflected in the UI of the other fragments.
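To make the problem concrete, here is a minimal plain-Python sketch of it. It is a toy model and not Streamlit itself: the `state` dict stands in for the session state, `rendered` stands in for what each fragment last drew, and a fragment rerun is modeled as calling just one function. All names here are hypothetical.

```python
# Toy model of fragment-scoped reruns (plain Python, no Streamlit involved).
state = {"ticker": "AAPL"}   # stands in for st.session_state
rendered = {}                # what each "fragment" last drew on screen


def ticker_selector(new_ticker=None):
    """Selector fragment: optionally store a new ticker, then redraw itself."""
    if new_ticker is not None:
        state["ticker"] = new_ticker
    rendered["ticker_selector"] = state["ticker"]


def display():
    """Display fragment: draws whatever ticker is in the shared state."""
    rendered["display"] = f"ticker: {state['ticker']}"


# Full script run at startup: every fragment draws once.
ticker_selector()
display()

# The user picks MSFT: only the selector fragment reruns.
ticker_selector("MSFT")
stale = rendered["display"]   # display() has not rerun yet

# Only when display() itself runs again does its output catch up.
display()
fresh = rendered["display"]

print(state["ticker"], "|", stale, "|", fresh)
```

The shared state is already correct after the selector runs; what lags behind is the other fragment's drawn output, which is exactly the gap that needs closing.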

One solution is to use the run_every parameter of the experimental_fragment decorator. Say the value is 2 seconds: display() then reruns every 2 seconds, independently of user interaction. The user can interact with ticker_selector() in the meantime, and on its next run display() executes its body again and picks up the current session values.

@st.experimental_fragment(run_every=2)
def display():
    def ticker_cb():
        ss.ticker = ss.tsd

    st.markdown('<span></span>', unsafe_allow_html=True)  # vertical spacer only
    st.write(f'ticker: {ss.ticker}, interval: {ss.interval}')
    st.selectbox('Select ticker', options=TICKERS, key='tsd',
                 on_change=ticker_cb, label_visibility='collapsed')

The user can also override the value of ticker because display() has a select box for ticker selection.
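One consequence of the run_every approach is that a change made in ticker_selector() shows up in display() only at the next timed rerun. The sketch below estimates that delay under an idealized assumption: reruns fire at exact multiples of run_every and the fragment's own execution time is ignored. The function names are hypothetical and the real scheduling in Streamlit may differ.

```python
import math


def next_poll(change_time: float, run_every: float) -> float:
    """Time of the first scheduled rerun at or after change_time.

    Idealized: reruns fire at exact multiples of run_every, and the
    fragment's own execution time is ignored.
    """
    return math.ceil(change_time / run_every) * run_every


def staleness(change_time: float, run_every: float) -> float:
    """Seconds between a state change and the rerun that shows it."""
    return next_poll(change_time, run_every) - change_time


# A ticker change 0.5 s into the app, with display() polling every 2 s.
delay = staleness(0.5, 2)
print(f"shown after {delay} s")
```

Under this model the delay is always below run_every, so a smaller value makes display() feel more responsive at the cost of more frequent reruns.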

chart()

The same principle used in display() is applied in chart(). Ticker data is pulled through the API every 20 seconds and plotted (the complete code below uses run_every=120 instead). The session variables ss.ticker and ss.interval are available here and determine which data is pulled.

@st.experimental_fragment(run_every=20)
def chart():
    # Plot chart    
    chart = StreamlitChart(width=None, height=500)
    chart.legend(True)

    interval = ss.interval
    chart.topbar.textbox('symbol', ss.ticker)

    ...

    df = get_bar_data(ss.ticker, interval)  # interval or timeframe
    chart.set(df)

    ...

    chart.load()

The user can interact with the chart, zooming, etc. without global reruns.
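Because chart() calls get_bar_data() on every timed rerun, the run_every value also sets how often data is requested, whether or not the ticker changed. A rough per-session estimate, assuming each rerun takes negligible time (the helper name is hypothetical):

```python
def reruns_per_hour(run_every: float) -> float:
    """Timed reruns per hour for one session, ignoring execution time."""
    return 3600 / run_every


# run_every values that appear in this post.
for name, seconds in [("display()", 2),
                      ("chart() in the snippet", 20),
                      ("chart() in the complete code", 120)]:
    print(f"{name}: {reruns_per_hour(seconds):.0f} reruns per hour")
```

For chart(), each of those reruns means one more history() request, so the value is a trade-off between how fresh the plot is and how many requests are made.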

complete code

In case someone wants to experiment, here is the complete code.

streamlit_app.py

import streamlit as st
from streamlit import session_state as ss
import pandas as pd
import yfinance as yf
from lightweight_charts.widgets import StreamlitChart


st.set_page_config(layout='wide')


TICKERS = ['AAPL', 'MSFT', 'AMZN', 'GOOGL']


if 'ticker' not in ss:
    ss.ticker = 'AAPL'

if 'interval' not in ss:
    ss.interval = '30m'


def get_bar_data(symbol, timeframe):
    """Fetch bars for symbol; timeframe is the interval, e.g. '30m' or '1d'."""
    if timeframe.endswith('m'):
        period = '60d'
    else:
        period = '3y'

    df = get_data(symbol, period, timeframe)

    if timeframe.endswith('m'):
        df.set_index('Datetime', inplace=True, drop=True)

    return df


def calculate_price(df, interval: str, period: int = 1):
    if interval.endswith('m'):
        if df.index.name == 'Datetime':
            col = df.index
        else:
            col = df['Datetime']
    else:
        col = df['Date']

    return pd.DataFrame({
        'time': col,
        'Close': df['Close'].rolling(window=period).mean()
    }).dropna()


def convert_to_utc_plus_zero(datetime_str):
    dt_with_tz = pd.to_datetime(datetime_str)
    dt_utc_plus_zero = dt_with_tz.tz_convert('UTC')

    return dt_utc_plus_zero


def get_data(ticker: str, period: str, interval: str):
    """Retrieve data from yf.

    ['Date/Datetime', 'Open', 'High', 'Low', 'Close', 'Volume']
    It outputs Datetime if interval is below 1d
    """
    df = yf.Ticker(ticker).history(
        period=period,
        interval=interval,
        auto_adjust=False
    )[['Open', 'High', 'Low', 'Close', 'Volume']].reset_index()
    col_name = 'Date'
    if 'Datetime' in df.columns:
        col_name = 'Datetime'
    df[col_name] = df[col_name].apply(convert_to_utc_plus_zero)

    return df


@st.experimental_fragment
def ticker_selector():
    def ticker_cb():
        ss.ticker = ss.tss

    def interval_cb():
        ss.interval = ss.int

    st.selectbox('Select ticker', options=TICKERS,
                 key='tss', on_change=ticker_cb, label_visibility='collapsed')
    st.selectbox('Select interval', options=['30m', '1d'], key='int',
                 on_change=interval_cb, label_visibility='collapsed')


@st.experimental_fragment(run_every=2)
def display():
    def ticker_cb():
        ss.ticker = ss.tsd

    st.markdown('<span></span>', unsafe_allow_html=True)  # vertical spacer only
    st.write(f'ticker: {ss.ticker}, interval: {ss.interval}')
    st.selectbox('Select ticker', options=TICKERS, key='tsd',
                 on_change=ticker_cb, label_visibility='collapsed')


@st.experimental_fragment(run_every=120)
def chart():
    # Plot chart    
    chart = StreamlitChart(width=None, height=500)
    chart.legend(True)

    interval = ss.interval
    chart.topbar.textbox('symbol', ss.ticker)
    chart.topbar.textbox('interval', f'interval: {interval}')

    df = get_bar_data(ss.ticker, interval)  # interval or timeframe
    chart.set(df)

    # Price close
    price = chart.create_line(name='Close', color='rgb(204, 235, 255)', width=1)
    price_df = calculate_price(df, interval, period=1)
    price.set(price_df)
    chart.load()


def main():
    cols = st.columns([1, 1], gap='large')
    with cols[0]:
        st.markdown('**1. Ticker Selector**')
        ticker_selector()

    with cols[1]:
        st.markdown('**2. Display**')
        display()

    st.markdown('**3. Plot**')
    chart()


if __name__ == '__main__':
    main()

requirements.txt

streamlit==1.34.0
yfinance==0.2.38
lightweight-charts==1.0.21