Simple button that waits for input?

I’m surprised to see that there is no way to create a button that waits to be clicked. I know st.button exists, but it doesn’t wait to be clicked. I know I may not be “getting” the core control-flow philosophy of Streamlit, but this kind of blocking input should be a priority.

One place where this comes up: say we want an app where someone uploads a dataset, selects (via st.multiselect) which columns should be treated as categorical, and then trains a model based on these settings. I’d like the app to wait until all the categorical columns are selected, which the user signals by clicking a button like “Train Model”…
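As a minimal sketch of the preprocessing step described here, assuming pandas, the columns picked in st.multiselect could be one-hot encoded with pd.get_dummies before training. The helper name encode_selected_categoricals and the toy data are hypothetical, not part of Streamlit or the original post:

```python
import pandas as pd


def encode_selected_categoricals(df: pd.DataFrame, categorical_cols: list[str]) -> pd.DataFrame:
    """Return a new DataFrame with only the chosen columns one-hot encoded.

    Hypothetical helper: columns not listed in categorical_cols are kept as-is.
    """
    # get_dummies replaces each listed column with one indicator column per
    # category, named "<column>_<category>"; unlisted columns come first.
    return pd.get_dummies(df, columns=categorical_cols)


if __name__ == "__main__":
    data = pd.DataFrame({
        "Pclass": [1, 3, 2],
        "Sex": ["male", "female", "female"],
        "Fare": [71.3, 7.9, 13.0],
    })
    # e.g. the user marked "Pclass" and "Sex" as categorical in st.multiselect
    encoded = encode_selected_categoricals(data, ["Pclass", "Sex"])
    print(list(encoded.columns))
    # ['Fare', 'Pclass_1', 'Pclass_2', 'Pclass_3', 'Sex_female', 'Sex_male']
```

Only the listed columns are encoded, so a numeric column such as Pclass becomes indicator columns only when the user explicitly marks it as categorical.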

Hi @pchalasani, welcome to the forum!

Sorry, I’m not sure what you mean by a button that waits to be clicked. How would you like the following toy example to be improved? Assume load_data is a file_uploader instead, and that the conversion of the user-selected columns to categorical is done in train_rf:

import matplotlib.pyplot as plt
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import ConfusionMatrixDisplay  # plot_confusion_matrix was removed in scikit-learn 1.2
from sklearn.model_selection import train_test_split
import streamlit as st

TITANIC_URL = "https://gist.githubusercontent.com/michhar/2dfd2de0d4f8727f873422c5d959fff5/raw/23da2b7680e0c9e1fd831f05f53de3958f0d75fe/titanic.csv"


def main():
    st.header("Titanic: Machine Learning from Disaster")
    df = load_data(TITANIC_URL)

    st.sidebar.header("Configuration")
    cols_to_train = st.sidebar.multiselect(
        "Select categorical cols", [c for c in df.columns.values if c != "Survived"]
    )

    # Changing the multiselect reruns the script, but st.button is True only on
    # the rerun triggered by a click, so training waits for "Train model".
    if st.button("Train model"):
        if not cols_to_train:
            st.warning("Select at least one column before training.")
        else:
            with st.spinner("Training ongoing"):
                clf, confusion_matrix = train_rf(df, cols_to_train)
            st.balloons()
            st.pyplot(confusion_matrix)


@st.cache_data  # replaces the deprecated st.cache (Streamlit >= 1.18)
def load_data(url):
    df = pd.read_csv(url)
    # Drop identifier-like and mostly-missing columns that are not used as features
    df = df.drop(columns=["PassengerId", "Name", "Ticket", "Cabin"])
    # Fill missing numeric values (e.g. Age) with the column mean;
    # numeric_only=True is required on pandas >= 2.0
    df = df.fillna(df.mean(numeric_only=True))
    # One-hot encode the two remaining string columns, then drop the originals
    df = pd.concat(
        [df, pd.get_dummies(df["Sex"].astype("category"), prefix="sex")], axis=1
    )
    df = pd.concat(
        [df, pd.get_dummies(df["Embarked"].astype("category"), prefix="embarked")],
        axis=1,
    )
    df = df.drop(columns=["Sex", "Embarked"])

    return df


def train_rf(df, features, n_estimators=100, max_depth=3):
    target = "Survived"
    X = df[features]
    y = df[target].astype("category")
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.33, random_state=42
    )
    clf = RandomForestClassifier(n_estimators=n_estimators, max_depth=max_depth)
    clf.fit(X_train, y_train)

    # Plot the confusion matrix on the held-out test split
    fig, ax = plt.subplots()
    ConfusionMatrixDisplay.from_estimator(clf, X_test, y_test, ax=ax)

    return clf, fig


if __name__ == "__main__":
    main()


Exactly what I’m looking for. Thank you for a very instructive example with some new tricks that I didn’t know about!
