Streamlit Webapp

I am making a web page using Streamlit. It uses connections to SQL databases, it has a couple of functions to do bits and bobs, and then it predicts some components. After all this is done, I want to select the outputs from a drop-down and do some interactions with the page, but every time I click on anything the whole page reloads itself. I tried session state; st.cache won't work because of the database connections. I need some help on how to do this.

Hello @Sibangi_Bhowmick :wave:

Welcome to the community! We’re thrilled to have you join us! :hugs:

Could you provide us with a link to your repo, please? That would allow us to diagnose and advise more accurately.

Best,
Charly

I don't have it in a public repo since it is confidential. I can show the prediction and main functions; there are 3 more functions before this which I cannot share due to confidentiality.

def get_predictions(serial_number):
    """Predict likely repair components for one serial number and render them.

    Fetches the latest failed-test record, one-hot encodes it into the
    training feature layout (the module-level ``dummies`` frame — TODO
    confirm it is defined at module scope), scores it with every
    per-component model listed for the unit's assembly, and writes the top
    suggestions to the Streamlit page. Returns None; all output goes
    through st.* calls, and every failure path exits early after an
    st.warning/st.error message.
    """
    df = fetch_data(serial_number)
    # NOTE(review): ``repair`` is fetched but never used in this function.
    repair = fetch_repair_data(serial_number)

    if df.empty:
        st.warning("We don't have data for this serial number.")
        return

    # Keep only the most recent test row, with strings lower-cased so the
    # one-hot column names below match the training layout.
    df['DtTm'] = pd.to_datetime(df['DtTm'])
    df.sort_values(by='DtTm', ascending=False, inplace=True)
    df = df.head(1)
    df = df.applymap(lambda x: x.lower() if isinstance(x, str) else x)

    # The assembly name selects which model list (.txt) to load; strip "/"
    # to address the f/a naming problem.
    assembly_name = df['ProductGroup'].iloc[0]
    if "/" in assembly_name:
        assembly_name = assembly_name.replace("/", "")
        assembly_name = assembly_name.lower()

    df = df[['MLFB', 'Assembly', 'SystemName', 'TestNumber', 'LineNumber',
             'DescriptionId', 'LowerLimit', 'UpperLimit', 'ActualValue',
             'UnitOfMeasureId']]

    # Cast categorical columns to object so get_dummies treats them uniformly.
    df = df.astype({
        'MLFB': 'object',
        'Assembly': 'object',
        'TestNumber': 'object',
        'LineNumber': 'object',
        'UnitOfMeasureId': 'object',
        'DescriptionId': 'object',
    })

    df = pd.get_dummies(df, columns=['MLFB', 'Assembly', 'SystemName',
                                     'TestNumber', 'LineNumber',
                                     'UnitOfMeasureId', 'DescriptionId'])

    # Derived limit features.
    # NOTE(review): ``.astype(int)`` binds to LowerLimit only; if the whole
    # difference should be integer, use (upper - lower).astype(int) — confirm.
    df['range'] = df['UpperLimit'] - df['LowerLimit'].astype(int)
    df['distance_from_lower'] = df['ActualValue'] - df['LowerLimit']
    df['distance_from_upper'] = df['UpperLimit'] - df['ActualValue']
    df['exceeds_lower'] = np.where(df['ActualValue'] < df['LowerLimit'], 1, 0)
    df['exceeds_upper'] = np.where(df['ActualValue'] > df['UpperLimit'], 1, 0)

    df.reset_index(inplace=True)
    df.drop(columns=['index'], inplace=True, axis=1)

    # Project the single encoded row onto the training column layout; any
    # column the model was not trained on aborts the prediction.
    for column in df.columns:
        if column in dummies.columns:
            # Update dummies with non-NA values from df for matching columns.
            dummies.update(df[column])
        else:
            st.warning(f"Model has not been trained for '{column}'")
            return

    scaled_new_data = scale_new_data(dummies, 'scaler_new.pkl')

    # Directory where the per-component models are saved.
    models_directory = "Repair_Models"
    if not os.path.exists(models_directory):
        st.error(f"Error: Directory '{models_directory}' does not exist.")
        return

    model_predictions = {}  # model name -> {'predictions', 'probabilities'}

    # The assembly's .txt file lists which component models apply to it.
    txt_filename = f"{assembly_name}.txt"
    txt_file_path = os.path.join(models_directory, txt_filename)
    if not os.path.exists(txt_file_path):
        st.error(f"Error: File '{txt_filename}' does not exist.")
        return

    with open(txt_file_path, 'r') as txt_file:
        model_names_from_txt = txt_file.read().splitlines()

    st.success("Getting Predictions")
    # Score the encoded row with every listed model.
    for filename in os.listdir(models_directory):
        if filename.endswith(".pkl"):  # joblib model file
            model_name = os.path.splitext(filename)[0]
            if model_name in model_names_from_txt:
                model_path = os.path.join(models_directory, filename)
                try:
                    model = joblib.load(model_path)
                    predictions = model.predict(scaled_new_data)
                    probabilities = model.predict_proba(scaled_new_data)
                    model_predictions[model_name] = {'predictions': predictions,
                                                     'probabilities': probabilities}
                except Exception as e:
                    st.error(f"Error loading or making predictions for model '{model_name}': {e}")
                    return

    # Split models by predicted class, keeping the probability used to rank.
    models_with_predictions_1 = {}
    models_with_predictions_0 = {}
    for model_name, predictions_info in model_predictions.items():
        predictions = predictions_info['predictions']
        probabilities = predictions_info['probabilities']
        if np.any(predictions == 1):
            # Highest class-1 probability for positive predictions.
            models_with_predictions_1[model_name] = np.max(probabilities[:, 1])
        elif np.any(predictions == 0):
            # Lowest class-0 probability for negative predictions.
            models_with_predictions_0[model_name] = np.min(probabilities[:, 0])

    # Top five positive predictions, most confident first.
    st.subheader("Top Suggestions for Repair Guidance")
    top_5_predictions_1 = sorted(models_with_predictions_1.items(), key=lambda x: x[1], reverse=True)[:5]
    for model_name, probability in top_5_predictions_1:
        st.write(f"{model_name} (Probability: {probability * 100:.2f}%)")

    # Five least-confident negatives, shown as complementary probabilities.
    st.subheader("Suggestions on Other Probable Components")
    sorted_predictions_0 = sorted(models_with_predictions_0.items(), key=lambda x: x[1])
    top_10_predictions_0 = sorted_predictions_0[:5]
    for model_name, probability in top_10_predictions_0:
        st.write(f"{model_name} (Probability: {100 - (probability * 100):.2f}%) ")

# Streamlit app

def main():
    """Streamlit entry point: look up a serial number, show its test/repair
    data and the model's repair-component predictions."""
    st.set_page_config(page_title="Repair Companion", page_icon=":wrench:",
                       layout="wide", initial_sidebar_state="expanded")
    st.title("Repair Companion Project ")

    # Input for Serial Number
    serial_number = st.text_input("Enter Serial Number:")

    if st.button("Get Data and Predictions"):
        if serial_number:
            df = fetch_data(serial_number)             # failed-test data
            data = fetch_pass_data(serial_number)      # passed-test data
            repair = fetch_repair_data(serial_number)  # repair history
            # Fall back to the pass data when there are no failed rows.
            assembly_series = df['Assembly'] if not df.empty else data['Assembly']
            fig, fig2 = past_10_days(assembly_series.iloc[0])  # two graphs

            st.subheader("All Failed Final Test Data for this Serial Number")
            format_df = df[['DtTm', 'MLFB', 'Assembly', 'SystemName', 'SerialNumber',
                            'Description', 'TestNumber', 'LineNumber', 'LowerLimit',
                            'UpperLimit', 'ActualValue']]
            st.dataframe(format_df, width=3000)  # first df

            get_predictions(serial_number)  # loads models and prints outputs

            st.subheader("Results if the unit has passed certain/all tests")
            format_data = data[['DtTm', 'MLFB', 'Assembly', 'SystemName', 'SerialNumber',
                                'Description', 'TestNumber', 'LineNumber', 'LowerLimit',
                                'UpperLimit', 'ActualValue']]
            st.dataframe(format_data, width=3000)  # second df

            st.subheader("All Repair Data for this Serial Number")
            st.dataframe(repair, width=3000)  # repair df

            st.subheader("Circuit References identified in the past")
            st.plotly_chart(fig, use_container_width=True)
            st.subheader("Defect Reasons identified in the past")
            st.plotly_chart(fig2, use_container_width=True)

    # NOTE(review): Streamlit reruns the whole script on every interaction
    # and st.button is only True on the rerun triggered by its click, so
    # toggling this checkbox makes everything in the button branch vanish.
    additional_functionality = st.checkbox("Enable Additional Functionality")
    if additional_functionality:
        st.write("You've enabled additional functionality!")


if __name__ == "__main__":
    main()

When I click the "additional functionality" checkbox, the app reloads.

Thanks for sending that info, @Sibangi_Bhowmick! :raised_hands:

Basically, Streamlit apps run from start to finish every time you interact with the page (like clicking a button or checking a checkbox). This means that everything on the page is recalculated and reloaded with each interaction.

To manage the hindrance you’re mentioning, you can try leveraging Session State to remember data—this will prevent the app from reloading all data every time. You store the data once it’s loaded, and only reload if necessary.

# Initialise the flag once per session; session state survives reruns.
if 'data_loaded' not in st.session_state:
    st.session_state.data_loaded = False  # Initialize the state

if st.button("Get Data and Predictions"):
    # Only hit the database/models on the first click of the session;
    # later reruns reuse what was stored in session state.
    if not st.session_state.data_loaded:
        # Load your data and predictions here
        predictions = get_predictions(serial_number)
        st.session_state.predictions = predictions  # Store predictions in session state
        st.session_state.data_loaded = True

    display_predictions(st.session_state.predictions)

I can’t check your code as I don’t have access to your repo, but please let me know if that addresses your issues. If not, we can explore other ways to optimize your app further. :slightly_smiling_face:

Best,
Charly

def get_predictions(serial_number: str) -> None:
    """Render repair-component predictions for *serial_number* to the page.

    Pulls the most recent failed-test row for the serial number, one-hot
    encodes it into the training column layout held in the module-level
    ``dummies`` frame (assumed defined elsewhere in the file — TODO confirm),
    scores it with every model listed for the unit's assembly, and writes
    ranked suggestions via st.* calls. Returns None; every failure path
    exits early after an st.warning/st.error message.
    """
    df = fetch_data(serial_number)
    # NOTE(review): ``repair`` is fetched but never used in this function.
    repair = fetch_repair_data(serial_number)
    
    if df.empty:
        st.warning("We don't have data for this serial number.")
        return
    
    # Keep only the newest record; lower-case strings so the one-hot column
    # names below match the training layout.
    df['DtTm'] = pd.to_datetime(df['DtTm'])
    df.sort_values(by='DtTm', ascending=False, inplace=True)
    df = df.head(1)
    df = df.applymap(lambda x: x.lower() if isinstance(x, str) else x)
    
    # The assembly name selects which model list (.txt) to load; strip "/"
    # to address the f/a naming problem.
    assembly_name = df['ProductGroup'].iloc[0]
    if "/" in assembly_name:
        assembly_name = assembly_name.replace("/", "")
        assembly_name = assembly_name.lower()
        
    df = df[['Assembly','SystemName','TestNumber','LineNumber','DescriptionId','LowerLimit','UpperLimit',
           'ActualValue','UnitOfMeasureId']]

    # Cast categorical columns to object so get_dummies treats them uniformly.
    df = df.astype({
        
        'Assembly':'object',
        'TestNumber': 'object',
        'LineNumber': 'object',
        'UnitOfMeasureId':'object',
        'DescriptionId':'object',
    })

    df = pd.get_dummies(df, columns=['Assembly', 
           'SystemName', 'TestNumber', 'LineNumber','UnitOfMeasureId','DescriptionId'])

    # Derived limit features.
    # NOTE(review): ``.astype(int)`` binds to LowerLimit only, not to the
    # difference; if the whole range should be integer, this likely meant
    # (df['UpperLimit'] - df['LowerLimit']).astype(int) — confirm.
    df['range'] = df['UpperLimit'] - df['LowerLimit'].astype(int)
    df['distance_from_lower'] = df['ActualValue'] - df['LowerLimit']
    df['distance_from_upper'] = df['UpperLimit'] - df['ActualValue']
    df['exceeds_lower'] = np.where(df['ActualValue'] < df['LowerLimit'], 1, 0)
    df['exceeds_upper'] = np.where(df['ActualValue'] > df['UpperLimit'], 1, 0)

    df.reset_index(inplace=True)
    df.drop(columns=['index'], inplace=True, axis=1)
    #st.write(df.columns)
    #st.write(dummies.columns)
    
    # Project the single encoded row onto the training column layout; any
    # column the model was not trained on aborts the whole prediction.
    for column in df.columns:
        
        # Check if the column name exists in dummies
        if column in dummies.columns:
            # Update dummies with non-NA values from df for matching columns
            dummies.update(df[column])
        else:
            st.warning(f"Model has not been trained for '{column}'")
            return
    
    
    scaled_new_data = scale_new_data(dummies, 'scaler_new.pkl')
    
    # Define the directory where the models are saved
    models_directory = "Repair_Models"

    # Check if the directory exists
    if not os.path.exists(models_directory):
        st.error(f"Error: Directory '{models_directory}' does not exist.")
        return

    model_predictions = {}  # model name -> {'predictions', 'probabilities'}
    #st.write(dummies)
    
    # The assembly's .txt file lists which component models apply to it.
    txt_filename = f"{assembly_name}.txt"
    txt_file_path = os.path.join(models_directory, txt_filename)
    
    if not os.path.exists(txt_file_path):
        st.error(f"Error: File '{txt_filename}' does not exist.")
        return
        
    with open(txt_file_path, 'r') as txt_file:
        model_names_from_txt = txt_file.read().splitlines()
    
    st.success("Getting Predictions")
    # Iterate over files in the models directory
    for filename in os.listdir(models_directory):
        if filename.endswith(".pkl"):  # Check if file is a joblib file
            model_name = os.path.splitext(filename)[0]  # Remove file extension from model name
            # Check if the model name is present in the .txt file's names
            if model_name in model_names_from_txt:
                model_path = os.path.join(models_directory, filename)

                try:
                    # Load the model from the joblib file
                    model = joblib.load(model_path)

                    # Make predictions using the loaded model on new scaled data
                    predictions = model.predict(scaled_new_data)
                    probabilities = model.predict_proba(scaled_new_data)

                    # Store model predictions and probabilities
                    model_predictions[model_name] = {'predictions': predictions, 'probabilities': probabilities}

                except Exception as e:
                    st.error(f"Error loading or making predictions for model '{model_name}': {e}")
                    return

    # Split models by predicted class, keeping the probability used to rank.
    models_with_predictions_1 = {}
    models_with_predictions_0 = {}

    # Iterate over the model predictions
    for model_name, predictions_info in model_predictions.items():
        predictions = predictions_info['predictions']
        probabilities = predictions_info['probabilities']

        # Check if any prediction is 1 or 0
        if np.any(predictions == 1):
            max_probability = np.max(probabilities[:, 1])  # Get the maximum probability of prediction 1
            models_with_predictions_1[model_name] = max_probability
        elif np.any(predictions == 0):
            min_probability = np.min(probabilities[:, 0])  # Get the minimum probability of prediction 0
            models_with_predictions_0[model_name] = min_probability

    # Display the sorted model names with predictions as 1 (most confident first).
    selected_1_model = None  # NOTE(review): assigned but never used.
    st.subheader("Top Suggestions for Repair Guidance")
    top_5_predictions_1 = sorted(models_with_predictions_1.items(), key=lambda x: x[1], reverse=True)[:5]
    for model_name, probability in top_5_predictions_1:
        st.write(f"{model_name} (Probability: {probability * 100:.2f}%)")

    # Display the sorted model names with predictions as 0, shown as
    # complementary probabilities.
    st.subheader("Suggestions on Other Probable Components")
    sorted_predictions_0 = sorted(models_with_predictions_0.items(), key=lambda x: x[1])
    # NOTE(review): name says 10 but the slice keeps only 5 — confirm intent.
    top_10_predictions_0 = sorted_predictions_0[:5]
    for model_name, probability in top_10_predictions_0:
        st.write(f"{model_name} (Probability: {100 - (probability * 100):.2f}%) ")

@st.cache_data
def fetch_all_data(serial_number):
    """Fetch failed-test, passed-test and repair records for one serial number.

    Cached with st.cache_data so the repeated script reruns that Streamlit
    performs on every widget interaction reuse the result instead of
    re-querying the database. Returns a (failed, passed, repair) tuple of
    DataFrames as produced by the three fetch_* helpers.
    """
    return (
        fetch_data(serial_number),
        fetch_pass_data(serial_number),
        fetch_repair_data(serial_number),
    )

 
# Streamlit app
def main():
    """Streamlit entry point.

    Streamlit reruns this whole script on every widget interaction, and
    st.button() is only True on the exact rerun triggered by its click.
    In the original layout the results and the checkbox lived inside the
    button branch, so toggling the checkbox triggered a rerun in which the
    button was False and everything vanished. The fix: the click only arms
    a flag in st.session_state, and rendering is driven by that flag on
    every rerun.
    """
    st.set_page_config(page_title="Repair Companion", page_icon="🔧", layout="wide", initial_sidebar_state="expanded")
    st.title("Repair Companion Project ")

    # One-time initialisation of the persisted state.
    if "show_results" not in st.session_state:
        st.session_state.show_results = False
        st.session_state.serial_number = ""

    # Input for Serial Number
    serial_number = st.text_input("Enter Serial Number:")

    # The click only records the request; rendering happens below on every rerun.
    if st.button("Get Data and Predictions") and serial_number:
        st.session_state.show_results = True
        st.session_state.serial_number = serial_number

    if st.session_state.show_results:
        df, data, repair = fetch_all_data(st.session_state.serial_number)
        # Fall back to the pass data when there are no failed rows.
        assembly_series = df['Assembly'] if not df.empty else data['Assembly']
        fig, fig2 = past_10_days(assembly_series.iloc[0])  # two graphs

        st.subheader("All Failed Final Test Data for this Serial Number")
        format_df = df[['DtTm','MLFB','Assembly','SystemName','SerialNumber','Description','TestNumber',
                        'LineNumber','LowerLimit','UpperLimit','ActualValue']]
        st.dataframe(format_df, width=3000)  # first df

        get_predictions(st.session_state.serial_number)  # loads model and prints outputs

        st.subheader("Results if the unit has passed certain/all tests")
        format_data = data[['DtTm','MLFB','Assembly','SystemName','SerialNumber','Description','TestNumber',
                            'LineNumber','LowerLimit','UpperLimit','ActualValue']]
        st.dataframe(format_data, width=3000)  # second df

        st.subheader("All Repair Data for this Serial Number")
        st.dataframe(repair, width=3000)  # prints repair df

        st.subheader("Circuit References identified in the past")
        st.plotly_chart(fig, use_container_width=True)
        st.subheader("Defect Reasons identified in the past")
        st.plotly_chart(fig2, use_container_width=True)

        # Rendered on every rerun while results are shown, so toggling it
        # no longer wipes the output above.
        additional_functionality = st.checkbox("Enable Additional Functionality")
        if additional_functionality:
            st.write("You've enabled additional functionality!")


if __name__ == "__main__":
    main()

This is my current code, will you be able to help seeing this?

1 Like