Hi @pchalasani, welcome to the forum!
Sorry, I’m not sure what you mean by a button that waits to be clicked. How would you like the following toy example to be improved? Assume that `load_data` is a `file_uploader` instead, and that the conversion of the user-selected columns to categorical is done in the training function (`train_rf`):
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import plot_confusion_matrix
from sklearn.model_selection import train_test_split
import streamlit as st
# CSV of the Titanic passenger data, hosted as a raw GitHub gist file.
# Split across adjacent literals only for line length; Python joins them
# into a single string at compile time.
TITANIC_URL = (
    "https://gist.githubusercontent.com/michhar/"
    "2dfd2de0d4f8727f873422c5d959fff5/raw/"
    "23da2b7680e0c9e1fd831f05f53de3958f0d75fe/titanic.csv"
)
def main():
    """Render the Titanic demo page and train a model when the button is pressed.

    Loads the prepared data, lets the user choose training columns in the
    sidebar, and, once "Train model" is clicked, fits the model via
    ``train_rf`` and displays the confusion-matrix figure it returns.
    """
    st.header("Titanic: Machine Learning from Disaster")
    df = load_data(TITANIC_URL)

    st.sidebar.header("Configuration")
    # Every column except the target is offered as a candidate feature.
    candidate_cols = [c for c in df.columns.values if c != "Survived"]
    cols_to_train = st.sidebar.multiselect("Select categorical cols", candidate_cols)

    if st.button('Train model'):
        # Training on zero columns fails inside scikit-learn; show a friendly
        # message instead of a traceback.
        if len(cols_to_train) == 0:
            st.warning("Select at least one column in the sidebar before training.")
            return
        with st.spinner("Training ongoing"):
            # Only the figure is displayed; the fitted classifier is unused here.
            _, confusion_fig = train_rf(df, cols_to_train)
        st.balloons()
        st.pyplot(confusion_fig)
@st.cache
def load_data(url):
    """Read the Titanic CSV and turn it into a model-ready table.

    Drops the ``PassengerId``, ``Name``, ``Ticket`` and ``Cabin`` columns,
    fills missing values in numeric columns with each column's mean, and
    replaces ``Sex`` and ``Embarked`` with one-hot dummy columns
    (``sex_*`` / ``embarked_*``).

    Args:
        url: Location handed to ``pandas.read_csv`` (URL or file path).

    Returns:
        The cleaned ``pandas.DataFrame``.
    """
    df = pd.read_csv(url)
    df = df.drop(columns=["PassengerId", "Name", "Ticket", "Cabin"])
    # numeric_only=True: Sex/Embarked are still strings at this point, and
    # pandas >= 2.0 raises TypeError when DataFrame.mean() meets non-numeric
    # columns (older versions dropped them with a warning). NaNs in the
    # string columns are left untouched, as before.
    df = df.fillna(df.mean(numeric_only=True))
    dummies = [
        pd.get_dummies(df["Sex"].astype("category"), prefix="sex"),
        pd.get_dummies(df["Embarked"].astype("category"), prefix="embarked"),
    ]
    # Same column order as appending sex_* then embarked_* one at a time.
    df = pd.concat([df, *dummies], axis=1)
    return df.drop(columns=["Sex", "Embarked"])
def train_rf(df, features, n_estimators=100, max_depth=3, test_size=0.33, random_state=42):
    """Fit a random forest on ``features`` and plot its test-set confusion matrix.

    Args:
        df: Table holding the ``features`` columns and the ``"Survived"`` target.
        features: Column names to train on; must not be empty.
        n_estimators: Number of trees in the forest.
        max_depth: Maximum depth of each tree.
        test_size: Fraction of rows held out for the confusion matrix.
        random_state: Seed for the train/test split only; the forest itself
            is not seeded.

    Returns:
        ``(clf, fig)``: the fitted ``RandomForestClassifier`` and a matplotlib
        figure containing the confusion matrix on the held-out rows.

    Raises:
        ValueError: If ``features`` is empty.
    """
    # len() rather than truthiness so a pandas Index is accepted too.
    if len(features) == 0:
        raise ValueError("train_rf needs at least one feature column")
    target = "Survived"
    X = df[features]
    y = df[target].astype("category")
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=random_state
    )
    clf = RandomForestClassifier(n_estimators=n_estimators, max_depth=max_depth)
    clf.fit(X_train, y_train)
    fig, ax = plt.subplots()
    # NOTE(review): plot_confusion_matrix was removed in scikit-learn 1.2;
    # ConfusionMatrixDisplay.from_estimator(clf, X_test, y_test, ax=ax) is the
    # replacement once the file-level import is migrated.
    plot_confusion_matrix(clf, X_test, y_test, ax=ax)
    return clf, fig
# Build the page only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()