How to create multiple chat_input in same app

I like ur UI style. It prevents chatbot from bloating and gives other elements to add details and then keep them.

My workload is very similar but I am too confused how to get multiple inputs in validation and correction phase.

You already have, in an earlier response, the main validate_input_query() where the bulk of the work happens: get_user_input() — a wrapper around chat_input() — is called wherever we need the user to add a field correction or some details. It's in 4-5 places in that function, as you can see.

My app kick starts with a run_chatbot() function in main script as below:


import streamlit as st
from src.platform_intelligence.language.process_input_query import process_chat_response, tools
from src.utils.display_utility import stream_message

def run_chatbot():
    """Run the Streamlit chatbot interface.

    Renders the archived chat history, accepts a user query, forwards it to
    process_chat_response(), and streams the assistant's reply.
    """
    # Sidebar control for the language the LLM should answer in.
    language = st.sidebar.radio("Choose Response Language", ('English', 'German', 'French'), horizontal=True)

    # Seed the visible history with a greeting on the first run of a session.
    if "messages" not in st.session_state:
        st.session_state.messages = [{"role": "assistant", "content": "Hello👋 , I'm Eliza, your Segmentation Assistant. I can currently help you create segments."}]

    # Replay archived messages so the conversation survives Streamlit reruns.
    for message in st.session_state.messages:
        if message["role"] == "assistant":
            with st.chat_message(message["role"], avatar = "💁"):
                st.markdown(f'<div style="color: black;">{message["content"]}</div>', unsafe_allow_html=True)
        else:
            with st.chat_message(message["role"]):
                st.markdown(f':black[{message["content"]}]')

    if user_input := st.chat_input("Ask something", key = "user_input_main"):
        with st.chat_message("user", avatar = None):
            st.markdown(f':black[{user_input}]')
        st.session_state.messages.append({"role": "user", "content": user_input})
        assistant_response = process_chat_response(user_input, tools = tools, language = language).strip()

        # Stream assistant's response in real-time
        stream_message(assistant_response)
        # Bug fix: persist the assistant reply; without this append it
        # disappears from the transcript on the next Streamlit rerun.
        st.session_state.messages.append({"role": "assistant", "content": assistant_response})

process_chat_response():

def process_chat_response(
    user_query: str, 
    system_prompt: str = SYSTEM_PROMPT, 
    tools: Any = None, 
    schema_model: BaseModel = InputQuery, 
    language: str = "English"
    ) -> Union[BaseModel, str]:
    """
    Processes the chat response from a completion request and outputs function details and arguments.

    This function sends a list of messages to a chat completion request, processes the response to extract 
    function calls and arguments, and prints relevant information. It also merges function arguments and 
    validates them via validate_input_query().

    Args:
        user_query (str): Query entered by user.
        system_prompt (str): Prompt template; must accept {chat_history} and {language} placeholders.
        tools (Any): The tools to be used with the chat completion request.
        schema_model (BaseModel): The Pydantic model class used for validating user query.
        language (str): Language in which the response should be written.

    Returns:
        response_output (Union[BaseModel, str]): Validated model when entities were
        extracted, otherwise the assistant's plain-text reply (or a fallback message).
    """

    # Ensure chat_history is initialized in session state
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = []

    # Convert chat history into a formatted string
    chat_history_str = "\n".join([f"{msg['role'].capitalize()}: {msg['content']}" for msg in st.session_state.chat_history])

    # Format the system prompt with the chat history
    formatted_system_prompt = system_prompt.format(chat_history=chat_history_str, language = language)

    messages = [{"role": "system", "content": formatted_system_prompt}]
    messages.append({"role": "user", "content": user_query})

    print(f'Full Prompt: {messages}')
    print(f'Tools : {tools}')

    response = chat_completion_request(messages, tools=tools, response_format={"type": "text"})
    print(f'\n{response}')

    merged_arguments = defaultdict(lambda: None)

    if response.choices[0].finish_reason == "tool_calls":
        # The model extracted entities: merge all tool-call arguments into one dict.
        for tool_call in response.choices[0].message.tool_calls:
            function_arguments = json.loads(tool_call.function.arguments)
            merged_arguments.update(function_arguments)

        merged_arguments = dict(merged_arguments)
        print(f"Merged Arguments: {merged_arguments}")

        # Convert merged_arguments to a JSON-like string and escape curly braces
        # so a later str.format() over the chat history cannot misread them.
        merged_arguments_str = str(merged_arguments).replace("{", "{{").replace("}", "}}")

        # Append the user's query and the assistant's response to the chat history
        st.session_state.chat_history.append({"role": "user", "content": user_query})
        st.session_state.chat_history.append({"role": "assistant", "content": merged_arguments_str})

        # Validate the InputQuery object with re-prompting if necessary
        final_response = validate_input_query(merged_arguments, schema_model)
        print(f"Process Chat Final Response: {final_response}")

    elif response.choices[0].finish_reason == 'stop' and response.choices[0].message.content is not None:
        # Plain-text reply: the query was irrelevant to the configured tools.
        final_response = response.choices[0].message.content.strip()

        # Append the user's query and the assistant's response to the chat history
        st.session_state.chat_history.append({"role": "user", "content": user_query})
        st.session_state.chat_history.append({"role": "assistant", "content": final_response})

    else:
        # Bug fix: previously final_response was unbound here, raising
        # NameError at the return for any other finish_reason
        # (e.g. "length", "content_filter").
        final_response = "Sorry, I couldn't process that request. Please try again."

    print(f"chat history: {st.session_state.chat_history}")

    return final_response

Flow:

The user first enters a query, which happens in the code above — the kickoff code. As you can see, the user input is passed to process_chat_response(); this function is part of the same script where validate_input_query() lives too.

In that script the input is sent to LLM. If the input is irrelevant to what the LLM prompt is designed for, it returns a text output which is then returned back and shown to the user in chat_message.

If the query is relevant to our task, then it is passed to validate_input_query() function which is where the validation begins. The validation uses a pydantic schema to make sure all entities extracted from user query meet the constraints of schema.

If any entity is missing or fails validation, then we end up asking a bunch of times from users to enter. This is the correction phase.

Its in this part where I need to add multiple chat_input. In ur case I see in the validation or ask phase, u only had one chat_input.

But in my case, I have multiple user input needed (depending upon how many errors) within same phase.

The query could be any random irrelevant query.

If you want to collect a batch of fields all at once vs collecting them one by one, I think you might be better off using st.text_input and a form, or possibly st.data_editor in the chat message, just like the choice buttons. Is it correct to say that your correction phase effectively updates a parsed set of key-value pairs (values in specific fields as you say)? This is as opposed to inline correction like I had in my example.

We want to give an option to user to enter in natural language for any change or choice 2 is when they can enter field by field. For 2nd option is where we can show form. But for 1, we need to enter language in chat_input

Ok, so probably not st.data_editor, but I still think you might want st.text_input if you are trying to batch responses. st.chat_input is a trigger; it’s like a button. It doesn’t have state and won’t retain its info. Therefore, it’s inherently ill-suited to batching responses. Even if you put multiple st.chat_input widgets on a page, you would necessarily be forced to handle each one, one-at-a-time.

Can you draw a picture with a specific example to show what UI you’re trying to accomplish? My example was simple in that it just replaced text, but in principle, it iterated through “things that the user needed to address” and handled that action before moving on to the next item. Do you need to collect those action items before processing them? In that case, rather than directly applying the correction, you can build a list of corrections, much like there was a list of errors. You’d just need a little extra code to keep track of where the user is. For example, build a list of corrections in st.session_state.corrections and always take the next error as:

st.session_state.errors[len(st.session_state.corrections)]

So I think I do that in my code in validate_input_query() function. It looks at all errors, then messages one by one to enter correct value for each field. It does it via chat_input. Problem is this whole thing works as normal python Input() over terminal, but if I put multiple chat_inputs in streamlit it starts to lose its mind.

You need to remove the while True in your validation function.

What happens when you enter your validation function:

  1. You enter the while True loop.
  2. Let’s say you enter the correction stage.
  3. Then you enter a for loop for all your errors. You code loops through all your errors in the blink of eye, faster than any user could ever respond. In so doing you call get_user_input over and over, always returning None (because the user has had no chance to respond yet).
  4. You restart your while True loop.
  5. You re-enter the correction stage.
  6. You loop through your all your errors again and keep calling get_user_input over and over, which is still always returning None since only a fraction of a second has lapsed since you started asking the user for an answer.
  7. Repeat ad nauseum.

If Streamlit didn’t stop in the middle of this because of a duplicate widget, you’d be seeing hundreds and hundreds of repeated prompts on the screen.

If you want to prompt the user for each error in succession, your correction stage needs to ask about a single error. When the user responds, Streamlit will rerun (you can process that response to clear that one off the list). When Streamlit gets to that point in your script again, it will repeat the code to ask about one thing only (which through Session State should be the next thing after crossing the previous one off the list).

That's exactly it. This is exactly what happens, and I had been wondering why it returns None when the user hasn’t entered anything.

This whole thing works on the terminal, but Streamlit has caused me issues here.

Would you have an example snippet, guided by one of the validate_input_query() snippets, to show how it works?

Ok, so here’s a question. I understand ur script now.

The easier part in ur script is 1) its one script 2) it has everything within IF conditions as per stage.

So that somehow makes it work with multiple input scenario.

How will that work in my case, where I effectively have two scripts? How does control transfer from one script to the other if the first script only has the “ask” stage, while the validation and correction stages are actually in the next script? How will control move there if I do st.rerun() in “ask”?

So, I took learnings from ur code, and tried refactoring my code:

Here’s two scripts I have that are modified;

main.py (this is kickstart of streamlit app)


import streamlit as st
from src.platform_intelligence.language.process_input import process_chat_response, tools
from src.utils.display_utility import stream_message, get_user_input

def run_chatbot():
    """Run the Streamlit chatbot interface.

    Stage-aware kickoff: the main chat input is only rendered while the app
    is in the 'ask' stage; other stages are handled downstream.
    """
    # Sidebar control for the language the LLM should answer in.
    language = st.sidebar.radio("Choose Response Language", ('English', 'German', 'French'), horizontal=True)

    if "messages" not in st.session_state:
        st.session_state.messages = [{"role": "assistant", "content": "Hello👋 , I'm Eliza, your Segmentation Assistant. I can currently help you create segments."}]

    if "current_stage" not in st.session_state:
        st.session_state.current_stage = 'ask'

    # Replay archived messages so the conversation survives Streamlit reruns.
    for message in st.session_state.messages:
        if message["role"] == "assistant":
            with st.chat_message(message["role"], avatar = "💁"):
                st.markdown(f'<div style="color: black;">{message["content"]}</div>', unsafe_allow_html=True)
        else:
            with st.chat_message(message["role"]):
                st.markdown(message["content"])

    if st.session_state.current_stage == 'ask':
        # Handle chatbot flow using chat_input
        if user_input := st.chat_input("Ask Something", key = "main_chat_key", disabled = False):
            with st.chat_message("user", avatar=None):
                st.markdown(user_input)
            st.session_state.messages.append({"role": "user", "content": user_input})
            assistant_response = process_chat_response(user_input, tools = tools, language = language).strip()
            # Stream assistant's response in real-time
            stream_message(assistant_response)
            # Bug fix: persist the reply BEFORE rerunning, otherwise the
            # streamed text is lost on the very next script run.
            st.session_state.messages.append({"role": "assistant", "content": assistant_response})
            print(f"stage of code: From process()-> main: {st.session_state.current_stage}")
            # Bug fix: do NOT force the stage back to "ask" here - that
            # clobbered any stage change made inside process_chat_response().
            st.rerun()

The validated_query() where validation and correction are happening.


def validate_input_query(merged_arguments: Dict[str, Any], schema_model: BaseModel = InputQuery) -> BaseModel:
    """
    Validates the entities parsed by the previous LLM agent against schema_model,
    driving three Streamlit stages via st.session_state.current_stage:

      * "validate"       - try to build the Pydantic model; on failure collect
                           every error into st.session_state.errors.
      * "correct"        - ask the user to fix ONE error per rerun. st.chat_input
                           is stateless, so looping over all errors in a single
                           script run would just return None for each prompt.
      * "ask_additional" - let the user add more details or finalize.

    Args:
        merged_arguments (Dict): The parsed arguments from the previous LLM.
        schema_model (BaseModel): InputQuery schema.

    Returns:
        BaseModel: Verified Pydantic model once the user finalizes
        (None on intermediate reruns, as before).
    """

    if "current_stage" not in st.session_state:
        st.session_state.current_stage = "validate"

    # Initialize or retrieve stored errors
    if "errors" not in st.session_state:
        st.session_state.errors = []

    print(f"Current session_state: {st.session_state.current_stage}")
    print("stage of code: Entered Validate_query()")

    # ---------------- Validation stage ----------------
    if st.session_state.current_stage == "validate":
        st.chat_input("...", disabled=True)  # keep the chat bar visible but inert
        try:
            validated_query = schema_model(**merged_arguments)
            # Echo back what was understood, key by key, with bold keys.
            validated_data = validated_query.dict()
            message = "Great! Here's what we learnt from your intent:\n\n"
            for key, value in validated_data.items():
                message += f"**{key}**: {value}\n\n"
            stream_message(message, avatar="💁")

            print(f'validated query in validate_input_query:{validated_query}\n')
            st.session_state.current_stage = "ask_additional"  # validation complete
            st.rerun()

        except ValidationError as e:
            print(f"All Validation Errors: {e.errors()}")

            # Bug fix: store every error in a uniform
            # {"field": ..., "message": ...} shape. Previously field errors
            # used {field: message} while missing-field errors used the keys
            # "field"/"error_message", so error['field'] / error['message']
            # lookups in the correction stage raised KeyError.
            for error in e.errors():
                if error['loc']:
                    field = error.get('loc', [None])[0]  # field where the error occurred
                    error_message = error.get('msg', 'Unknown error')
                    st.session_state.errors.append({"field": field, "message": error_message})
                else:
                    missing_fields_message = error['msg']
                    print(f"Missing Mandatory Fields Error: {missing_fields_message}")
                    if "Missing required fields" in missing_fields_message:
                        missing_fields = missing_fields_message.split(": ")[-1].split(", ")
                        for field in missing_fields:
                            st.session_state.errors.append({
                                "field": field,
                                "message": f"Missing required field: {field}",
                            })

            print(f"All stored errors in session_state: {st.session_state.errors}")
            st.session_state.current_stage = 'correct'
            st.rerun()

    # ---------------- Correction stage ----------------
    if st.session_state.current_stage == "correct":
        if st.session_state.errors:
            stream_message("The following fields need fixing:\n\n")
            for error in st.session_state.errors:
                stream_message(f"- **{error['field']}**: {error['message']}\n\n")

            # Ask about a SINGLE error per rerun; when the user answers,
            # Streamlit reruns and the next error becomes the head of the list.
            error = st.session_state.errors[0]
            stream_message(f"Please provide a valid value for '{error['field']}':\n\n")
            if user_input := st.chat_input(f"Please provide a valid value for '{error['field']}':", key=f"correct_{error['field']}"):
                merged_arguments[error['field']] = user_input
                st.session_state.errors.remove(error)  # bug fix: was .remove[error] (TypeError)
                st.rerun()
        else:
            # All corrections collected: go back and re-validate.
            st.session_state.current_stage = "validate"
            st.rerun()

    # ---------------- Ask-additional stage ----------------
    print("stage of code: Entered Ask_Additional()")
    print(f"Current session_state: {st.session_state.current_stage}")
    if st.session_state.current_stage == "ask_additional":
        if st.session_state.get("ask_additional_choice") == '1':
            # Sub-step: the user already chose "1" on a previous run,
            # so this run shows the details prompt (one chat_input per run).
            stream_message("Please enter additional details:")
            if details := st.chat_input("Please enter additional details:", key="additional_info_key"):
                st.session_state.ask_additional_choice = None
                validated_query = process_chat_response(details, tools=tools)
                if isinstance(validated_query, str):
                    print(f'Yes {validated_query} is string')
        else:
            stream_message("Would you like to 1) Add more details, or 2) Correct specific fields? Enter 1 or 2 (or 'n' to finalize):\n")
            if user_choice := st.chat_input("Enter your choice 1) Add more details, or 2) Correct specific fields? Enter 1 or 2 (or 'n' to finalize)"):
                if user_choice == 'n':
                    st.session_state.current_stage = "ask"  # bug fix: was st.session_stage (AttributeError)
                    # Bug fix: re-validate here. The validated_query local from
                    # the "validate" branch does not survive the intervening
                    # st.rerun(), so referencing it raised NameError.
                    return schema_model(**merged_arguments)
                elif user_choice == '1':
                    # Remember the choice, then rerun so the next script run
                    # renders the details prompt instead of this one.
                    st.session_state.ask_additional_choice = '1'
                    st.rerun()
                else:
                    stream_message("Invalid option. Please enter 1, 2, or 'n'.")

If u see, every part of code is under an IF st.session_stage.current_stage and the chat_input() are placed under those.

However, my code doesn’t run anything post first st.rerun() under validated_query() especially below part:

except ValidationError as e:
            # Collect all validation errors without prompting for input right away
            st.session_state.current_stage = "correct"

            print(f"All Validation Errors: {e.errors()}")
            print()
            
            # Capture each error from ValidationError and append to st.session_state.errors
            for error in e.errors():
                if error['loc']:
                    field = error.get('loc', [None])[0]  # Get the field where error occurred
                    error_message = error.get('msg', 'Unknown error')
                    # Store the error in session state
                    st.session_state.errors.append({field: error_message})
                else:
                    print(f"Missing Mandatory Fields Error: {error['msg']}")
                    missing_fields_message = error['msg']
                    if "Missing required fields" in missing_fields_message:
                        missing_fields = missing_fields_message.split(": ")[-1].split(", ")
                        # Add an error for each missing field
                        for field in missing_fields:
                            # Store each missing field error in the same format as field errors
                            st.session_state.errors.append({
                                "field": field,
                                "error_message": f"Missing required field: {field}"
                            })

            print(f"All stored errors in session_state: {st.session_state.errors}")
            
            # Optional: Log the collected errors
            for err in st.session_state.errors:
                print(f"Collected error: {err}")

            st.session_state.current_stage = 'correct'
            st.rerun()

Basically, it captures all errors correctly and stores them in st.session_state.errors, on the line above st.rerun(). However, after that nothing happens. I was expecting the app to refresh — the stage is changed to “correct” right above that line — and then enter the if condition corresponding to the “correct” stage (which is below st.rerun()). However, it doesn’t do that.

is st.rerun() correct usage here?

If you see all the print statements for “Collected error…”, then as written, the value in st.session_state.current_stage should be updated and the app should rerun with that new stage in place.

I think some of your code is missing, since I only see your “ask” stage. Try adding this as the very first line in your app:

st.write(st.session_state)

Then you’ll always see that as the first line as visual confirmation while you debug. It will be empty when you start a new session, but after your script has run once, you’ll see the values there and you’ll have better visibility to see if something is not going as expected.

So, I believe the code wasn’t running because after st.rerun() the app code reruns, but this time the function isn’t getting called — so even if the stage changes, those stage changes only happen inside the function (validate_input_query()) etc.

However, I modified the code by keeping all stage change logic in main script and keeping only validation and correction logic in process_input.py

The multiple chat_input issue is resolved, and it only shows one at a time. However, the other issue I see is that things are getting overwritten very fast in chat_message(), which makes for bad UX. I don’t want every assistant message to remain in the session_state.messages history, as intermediate steps (corrections etc.) should not be stored in history, to prevent chat bloating.

Here’s below:

main.py

from typing import Dict

import streamlit as st

from src.platform_intelligence.language.process_input import process_chat_response, tools, validate_input_query
from src.utils.display_utility import stream_message

def run_chatbot():
    """Run the Streamlit chatbot interface.

    Drives a four-stage conversation state machine stored in
    st.session_state.current_stage:

      * "ask"            - accept a free-form query and send it to the LLM.
      * "validate"       - check the parsed entities against the schema.
      * "ask_additional" - offer the user a chance to add details or finalize.
      * "correct"        - collect fixes for invalid/missing fields via a form.

    Intermediate assistant messages (validation/correction prompts) are
    deliberately NOT appended to st.session_state.messages, so they don't
    bloat the persistent chat transcript.
    """

    # Response language forwarded into the LLM prompt.
    language = st.sidebar.radio("Choose Response Language", ('English', 'German', 'French'), horizontal=True)

    # Greeting seeded once per session.
    if "messages" not in st.session_state:
        st.session_state.messages = [{"role": "assistant", "content": "Hello👋 , I'm Eliza, your Segmentation Assistant. I can currently help you create segments."}]

    if "current_stage" not in st.session_state:
        st.session_state.current_stage = 'ask'

    # Entities parsed from the user query; corrected values are merged in here.
    if "merged_arguments" not in st.session_state:
        st.session_state.merged_arguments = {}
    
    # for message in st.session_state.messages:
    #     with st.chat_message(message["role"], avatar="💁" if message["role"] == "assistant" else None):
    #         st.markdown(f'{message["content"]}', unsafe_allow_html=True)
    
    # display archived messages from history to chat
    for message in st.session_state.messages:
        if message["role"] == "assistant":
            with st.chat_message(message["role"], avatar = "💁"):
                st.markdown(f'<div style="color: black;">{message["content"]}</div>', unsafe_allow_html=True)
        else:
            with st.chat_message(message["role"]):
                st.markdown(message["content"])

    if st.session_state.current_stage == 'ask':
        user_input = st.chat_input("Enter your Query", key="main_chat_key")
        # Handle chatbot flow using chat_input
        if user_input:
            st.session_state.messages.append({"role": "user", "content": user_input})
            # Output User input in chat Message
            with st.chat_message("user", avatar = None):
                st.markdown(user_input)
            # Returned response of user query
            assistant_response = process_chat_response(user_input, tools=tools, language=language)
            # This part checks if entities captured or irrelevant query and routes accordingly
            if isinstance(assistant_response, Dict):
                # We have merged arguments due to entities in user query, move to validate stage
                st.session_state.merged_arguments = assistant_response
                with st.chat_message("assistant", avatar = "💁"):
                    stream_message("I've processed your request. Let me validate the information.")
                    st.session_state.current_stage = "validate"
                    st.rerun()
            else:
                # We have a string response, display it and stay in 'ask' stage
                with st.chat_message("assistant", avatar = "💁"):
                    stream_message(assistant_response)
                    st.session_state.messages.append({"role": "assistant", "content": assistant_response})

            print(f"stage of code: From process()-> main: {st.session_state.current_stage}")

    # Validate stage
    elif st.session_state.current_stage == "validate":
        # Disabled placeholder keeps the chat bar visible but inert in this stage.
        st.chat_input("...", disabled = True)
        # Call the validate_input_query() function to validate the entities parsed by LLM
        validation_result = validate_input_query(st.session_state.get("merged_arguments", {}))
        # NOTE(review): assumes validate_input_query() returns the sentinel string
        # "validation_failed" after populating st.session_state.errors — confirm.
        if validation_result == "validation_failed":
            st.session_state.current_stage = "correct"
            st.rerun()
        else:
            # Check if a user wants to modify/add info before segment creation
            with st.chat_message("assistant", avatar = "💁"):
                stream_message(validation_result)
            st.session_state.current_stage = "ask_additional"
            st.rerun()

    # If the user wants to enter more additional info- ask_additional stage
    elif st.session_state.current_stage == "ask_additional":
        # Disable the main chat for this stage
        st.chat_input("Choose one of the options from above", disabled = True)
        # Add two buttons for User actions
        with st.chat_message("assistant", avatar = "💁"):
            stream_message("Would you like to add more details or finalize?")
            col1, col2 = st.columns(2)
            with col1: # Add Button for No Addition
                if st.button("No, I'm good"): 
                    st.session_state.current_stage = "ask"
                    stream_message("Alright, let's get segment created!")
                    st.rerun()
            with col2: # Add Button if user wants to add more
                if st.button("Add More"):
                    st.session_state.current_stage = "ask"
                    stream_message("Sure, what else would you like to add?")
                    st.rerun()

    # Correction stage
    elif st.session_state.current_stage == "correct":
        # Disable the main chat for this stage
        st.chat_input("Correct Errors ", disabled = True)
        print("stage of code: From Validation -> Correct stage")
        if "errors" in st.session_state and st.session_state.errors:
            with st.chat_message("assistant", avatar = "💁"):
                stream_message("The following fields need fixing:")  # Display all the errors
                for error in st.session_state.errors:
                    stream_message(f"- **{error['field']}**: {error['message']}")
            # A form batches all field corrections into a single submit, so a
            # rerun only happens once the user has filled everything in.
            with st.form("correction_form"):
                st.write("Please provide valid values for the following fields:")
                corrected_fields = {}
                # Iterating one by one for each field correction
                for error in st.session_state.errors:
                    field = error['field']
                    placeholder = f"Enter a valid value for {field}"
                    if "Missing required field" not in error['message']:
                        placeholder += f" ({error['message']})"
                    corrected_value = st.text_input(
                        f"{field}:", 
                        key=f"correct_{field}",
                        placeholder=placeholder
                    )
                    if corrected_value:
                        corrected_fields[field] = corrected_value
                    
                # User action to submit all corrections
                submit_button = st.form_submit_button("Submit Corrections")
            
            if submit_button:
                # Update merged_arguments with corrected values
                st.session_state.merged_arguments.update(corrected_fields)
                # Clear errors and go back to validate stage
                st.session_state.errors = []
                st.session_state.current_stage = "validate"
                with st.chat_message("assistant", avatar = "💁"):
                    stream_message("Thank you for the corrections. I'm validating your input now.")
                st.rerun()
        else:
            # Nothing left to fix: fall back to the main ask stage.
            with st.chat_message("assistant", avatar = "💁"):
                stream_message("No errors to correct. Returning to main chat.")
            st.session_state.current_stage = "ask"
            st.rerun()

    # Rerun the app once at the end of the function
    #st.rerun()

process_input.py


from src.platform_intelligence.language.input_query_schema import InputQuery
from models.openai.schema_to_OAI_spec import transform_schema_to_OAI_spec
from models.openai.openai_model import chat_completion_request

# Resolve the config directory: four directory levels up from this file,
# then into 'config' (presumably the project root -- TODO confirm layout).
parent_dir= Path(__file__).resolve().parent.parent.parent.parent
config_path = parent_dir/'config'

# Reading the SYSTEM PROMPT config file
with open(f'{config_path}/prompts/segment_system_prompt.yml', 'r') as file:
        SYSTEM_PROMPT = yaml.safe_load(file)['PROMPT_TEMPLATE']

# Reading the VERIFIER LLM PROMPT config file
with open(f'{config_path}/prompts/verifier_agent_prompt.yml', 'r') as file:
        VERIFIER_AGENT_PROMPT= yaml.safe_load(file)['PROMPT_TEMPLATE']

# Define a global chat history variable to hold chat history
# NOTE(review): this module-level list appears unused -- the visible code keeps
# chat history in st.session_state.chat_history instead; consider removing it.
chat_history: List[Dict[str, str]] = []
# Add the tools for the Function Calling
# Transform the schema to OpenAI spec if needed for GPT processing
inputquery_spec = transform_schema_to_OAI_spec(InputQuery)
tools = [inputquery_spec]


def process_chat_response(
    user_query: str, 
    system_prompt: str = SYSTEM_PROMPT, 
    tools: Any = None, 
    schema_model: BaseModel = InputQuery, 
    language: str = "English"
) -> Union[Dict[str, Any], str]:
    """
    Sends the user query (plus formatted chat history) to the LLM and routes the result.

    When the model answers with tool calls, the tool-call arguments are merged into a
    single dict, stored in ``st.session_state.merged_arguments``, and returned so the
    caller can move to the validation stage. When the model answers with plain text,
    that text is returned instead. In both cases the exchange is appended to
    ``st.session_state.chat_history``.

    Args:
        user_query (str): Query entered by the user.
        system_prompt (str): Prompt template; must accept ``chat_history`` and
            ``language`` format fields.
        tools (Any): Tool/function specs forwarded to the chat completion request.
        schema_model (BaseModel): Pydantic model used downstream for validation
            (kept for interface compatibility; not used directly here).
        language (str): Language in which the assistant should respond.

    Returns:
        Union[Dict[str, Any], str]: Merged tool-call arguments (dict) when the model
        requested a function call, otherwise the assistant's text reply. If the
        response carries no usable content (e.g. the completion was truncated), a
        fallback message string is returned so callers can always treat a non-dict
        result as displayable text.
    """
    # Ensure chat_history is initialized in session state
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = []

    # Convert chat history into a formatted string for the prompt template
    chat_history_str = "\n".join([f"{msg['role'].capitalize()}: {msg['content']}" for msg in st.session_state.chat_history])
    # Format the system prompt with the chat history
    formatted_system_prompt = system_prompt.format(chat_history=chat_history_str, language=language)

    messages = [
        {"role": "system", "content": formatted_system_prompt},
        {"role": "user", "content": user_query}
    ]

    print(f'Full Prompt: {messages}')

    response = chat_completion_request(messages, tools=tools, response_format={"type": "text"})

    print(f'\n{response}')

    choice = response.choices[0]

    if choice.finish_reason == "tool_calls":
        # Merge the arguments of every tool call into one dict; later calls win
        # on duplicate keys. A plain dict suffices -- the previous
        # defaultdict(lambda: None) never used its default factory.
        merged_arguments: Dict[str, Any] = {}
        for tool_call in choice.message.tool_calls:
            merged_arguments.update(json.loads(tool_call.function.arguments))

        print()
        print(f"Merged Arguments: {merged_arguments}")

        st.session_state.merged_arguments = merged_arguments  # Store for validation

        # Escape curly braces so the dict text survives a later str.format()
        # when it is interpolated back into the system prompt as chat history.
        merged_arguments_str = str(merged_arguments).replace("{", "{{").replace("}", "}}")
        # Append the user's query and the assistant's response to the chat history
        st.session_state.chat_history.append({"role": "user", "content": user_query})
        st.session_state.chat_history.append({"role": "assistant", "content": merged_arguments_str})

        # Verifying the Output with Verifier LLM Agent
        #verifier_response = verifier_agent_response(user_query, merged_arguments, tools)
        print()
        #print(f"Verifier LLM Agent Response: {verifier_response}")

        print()
        print(f"stage of code: From process()-> Validate_query()")

        return merged_arguments

    if choice.finish_reason == 'stop' and choice.message.content is not None:
        final_response = choice.message.content.strip()
        # Append the user's query and the assistant's response to the chat history
        st.session_state.chat_history.append({"role": "user", "content": user_query})
        st.session_state.chat_history.append({"role": "assistant", "content": final_response})

        return final_response

    # BUGFIX: previously this fell through and implicitly returned None when
    # finish_reason was neither "tool_calls" nor "stop" with content (e.g.
    # "length" or a content filter), which crashed callers that do
    # process_chat_response(...).strip(). Return a displayable fallback instead.
    print(f"Unexpected finish_reason: {choice.finish_reason!r}")
    return "Sorry, I couldn't process that request. Please try again."


def validate_input_query(merged_arguments: Dict[str, Any], schema_model: BaseModel = InputQuery) -> Union[BaseModel, str]:
    """
    Validate the LLM-extracted arguments against the given Pydantic schema.

    On success, returns a markdown summary listing every validated field. On
    failure, records one entry per problem in ``st.session_state.errors``
    (expanding bundled "Missing required fields" messages into one entry per
    field) and returns the sentinel string ``"validation_failed"``.

    Args:
        merged_arguments (Dict): The parsed arguments from the previous LLM call.
        schema_model (BaseModel): The schema class to instantiate (InputQuery).

    Returns:
        Union[BaseModel, str]: Markdown summary string on success, or the
        ``"validation_failed"`` sentinel once errors are stored in session state.
    """

    print(f"Current session_state: {st.session_state.current_stage}")
    print(f"stage of code: Entered Validate_query()")

    try:
        validated_query = schema_model(**merged_arguments)
    except ValidationError as exc:
        # Collect every validation problem into session state for the UI
        st.session_state.errors = []
        print(f"All Validation Errors: {exc.errors()}")
        print()
        for err in exc.errors():
            if not err['loc']:
                # Schema-level error: may bundle several missing mandatory fields
                print(f"Missing Mandatory Fields Error: {err['msg']}")
                bundled_message = err['msg']
                if "Missing required fields" in bundled_message:
                    # Expand "Missing required fields: a, b" into per-field entries
                    for missing in bundled_message.split(": ")[-1].split(", "):
                        st.session_state.errors.append({
                            "field": missing,
                            "message": f"Missing required field: {missing}"
                        })
                continue
            # Field-level error: first element of 'loc' names the offending field
            st.session_state.errors.append({
                "field": err.get('loc', [None])[0],
                "message": err.get('msg', 'Unknown error')
            })
        print(f"All stored errors in session_state: {st.session_state.errors}")
        return "validation_failed"

    # Success path: render one markdown line per validated field
    field_lines = [f"**{key}**: {value}\n\n" for key, value in validated_query.dict().items()]
    return "Great! Here's what we learned from your intent:\n\n" + "".join(field_lines)

You could create another Session State variable (st.session_state.temp_messages) to work with at the end of the permanent messages (and then clear it out when you need to).

I'm not sure I understood this. What I mainly need help with now is making the transition from stage to stage display smoothly on screen.

I have learnt everything from you over the last few days by spending hours dissecting your earlier code to understand what st.rerun() does, and that helped me successfully resolve the chat input issue. I think I have now learnt how to use session_state to control that, and it works perfectly.

The main issue I have now is that the messages and other widgets shown when a stage is active appear jumpy and flash by very quickly on screen, which leads to a somewhat poor UX. Here's my latest modified code, which calls st.rerun() at various stages and uses other widgets in some of them (like st.form).

How do I make these stage transitions render smoothly and gradually in chat_message()?

code:

from typing import Dict

import streamlit as st

from src.platform_intelligence.language.process_input import process_chat_response, tools, validate_input_query
from src.utils.display_utility import stream_message, display_validated_response

def run_chatbot():
    """Run the Streamlit chatbot interface.

    The UI is a per-rerun state machine driven by ``st.session_state.current_stage``:

        'ask'             -- main chat_input; moves to 'validate' when the LLM
                             extracted entities (a dict), else stays here.
        'validate'        -- validates merged_arguments; failure -> 'correct',
                             success -> 'ask_additional'.
        'ask_additional'  -- user chooses to add more details or finalize.
        'complete'        -- final summary plus "create segment" button.
        'correct'         -- st.form asking the user to fix each validation error.
        'segment_details' -- terminal explainability page.

    Every stage other than 'ask' renders a disabled st.chat_input so the chat box
    stays visible (but inert) at the bottom of the page.
    """

    language = st.sidebar.radio("Choose Response Language", ('English', 'German', 'French'), horizontal=True)

    if "messages" not in st.session_state:
        st.session_state.messages = [{"role": "assistant", "content": "Hello👋 , I'm Eliza, your Segmentation Assistant. I can currently help you create segments."}]

    if "current_stage" not in st.session_state:
        st.session_state.current_stage = 'ask'

    if "merged_arguments" not in st.session_state:
        st.session_state.merged_arguments = {}

    # Replay the archived message history on every rerun
    for message in st.session_state.messages:
        if message["role"] == "assistant":
            with st.chat_message(message["role"], avatar = "💁"):
                st.markdown(f'<div style="color: black;">{message["content"]}</div>', unsafe_allow_html=True)
        else:
            with st.chat_message(message["role"]):
                st.markdown(message["content"])

    if st.session_state.current_stage == 'ask':
        user_input = st.chat_input("Enter your Query", key="main_chat_key")
        # Handle chatbot flow using chat_input
        if user_input:
            st.session_state.messages.append({"role": "user", "content": user_input})
            # Echo the user's input in the chat
            with st.chat_message("user", avatar = None):
                st.markdown(user_input)
            # Returned response of user query
            assistant_response = process_chat_response(user_input, tools=tools, language=language)
            # A dict means the LLM extracted entities; a string is a plain reply
            if isinstance(assistant_response, Dict):
                st.session_state.merged_arguments = assistant_response
                st.session_state.current_stage = "validate"
                st.rerun()
            else:
                # Plain string response: display it and stay in the 'ask' stage
                with st.chat_message("assistant", avatar = "💁"):
                    stream_message(assistant_response)
                    st.session_state.messages.append({"role": "assistant", "content": assistant_response})

            print(f"stage of code: From process()-> main: {st.session_state.current_stage}")

    # Validate stage
    elif st.session_state.current_stage == "validate":
        st.chat_input("...", disabled = True)
        # BUGFIX: st.spinner is a context manager -- calling it without `with`
        # renders nothing. Wrap the validation call so the spinner actually shows.
        with st.spinner("Validating..."):
            validation_result = validate_input_query(st.session_state.get("merged_arguments", {}))
        if validation_result == "validation_failed":
            st.session_state.current_stage = "correct"
            st.divider()  # visible only until st.rerun() repaints the frame
            st.rerun()
        else:
            # NOTE(review): validation_result is the summary *string* returned by
            # validate_input_query, so this overwrites the arguments dict stored in
            # session state. The 'complete' stage later passes it to
            # display_validated_response() -- confirm that helper expects a string.
            st.session_state.merged_arguments = validation_result
            # Check if the user wants to modify/add info before segment creation
            with st.chat_message("assistant", avatar = "💁"):
                message = "Great,Here's what we have learnt from your query intent\n\n"
                stream_message(message)
                display_validated_response(validation_result)
            st.session_state.messages.append({"role": "assistant", "content": message})
            st.session_state.current_stage = "ask_additional"
            st.rerun()

    # If the user wants to enter more additional info - ask_additional stage
    elif st.session_state.current_stage == "ask_additional":
        # Disable the main chat for this stage
        st.chat_input("Choose one of the options from above", disabled = True)
        # Two buttons for the user's choice
        with st.chat_message("assistant", avatar = "💁"):
            stream_message("Would you like to add more details or finalize?")
            col1, col2 = st.columns(2)
            with col1:  # User is done adding details
                if st.button("No, I'm good"):
                    st.session_state.current_stage = "complete"
                    st.rerun()
            with col2:  # User wants to add more details
                if st.button("Add More"):
                    st.session_state.current_stage = "ask"
                    # NOTE(review): this streamed message is wiped almost instantly
                    # by st.rerun(); append it to st.session_state.messages if it
                    # should survive the rerun.
                    stream_message("Sure, what else would you like to add?")
                    st.rerun()

    # Completion stage if the user has no more details to add
    elif st.session_state.current_stage == "complete":
        st.chat_input("Almost there, finalize segment...", disabled = True)
        with st.chat_message("assistant", avatar = "💁"):
            message = "Great, here's what we finally have:\n\n"
            stream_message(message)
            display_validated_response(st.session_state.merged_arguments)
            stream_message("\n\nLet's get your segment created!")
            if st.button("Click to Create Segment & See Details"):
                st.session_state.current_stage = "segment_details"
                st.rerun()

    # Segment details stage -- entered when the user clicks the create button
    elif st.session_state.current_stage == "segment_details":
        st.chat_input("Head over to Segment Explainability tab...", disabled = True)
        st.sidebar.success("Segment Details")
        st.title("Segment Explainability Details")
        # Add your segment explainability details here
        st.write("This is where you would display the segment explainability details.")

    # Correction stage
    elif st.session_state.current_stage == "correct":
        # Disable the main chat for this stage
        st.chat_input("Correct Errors ", disabled = True)
        print("stage of code: From Validation -> Correct stage")
        if "errors" in st.session_state and st.session_state.errors:
            with st.chat_message("assistant", avatar = "💁"):
                stream_message("Ok, so Looks like some info is missing to create a good segment:")  # Display all the errors
                for error in st.session_state.errors:
                    stream_message(f"- **{error['field']}**: {error['message']}")
            # Collect all corrections in one form so a single submit fixes everything
            with st.form("correction_form"):
                st.write("Please provide valid values for the following fields:")
                corrected_fields = {}
                # One text input per recorded validation error
                for error in st.session_state.errors:
                    field = error['field']
                    placeholder = f"Enter a valid value for {field}"
                    if "Missing required field" not in error['message']:
                        placeholder += f" ({error['message']})"
                    corrected_value = st.text_input(
                        f"{field}:",
                        key=f"correct_{field}",
                        placeholder=placeholder
                    )
                    if corrected_value:
                        corrected_fields[field] = corrected_value

                # User action to submit all corrections
                submit_button = st.form_submit_button("Submit Corrections")

            if submit_button:
                # Merge the corrected values and send the user back to validation
                st.session_state.merged_arguments.update(corrected_fields)
                st.session_state.errors = []
                st.session_state.current_stage = "validate"
                with st.chat_message("assistant", avatar = "💁"):
                    stream_message("Thank you for the corrections. I'm validating your input now.")
                st.rerun()
        else:
            with st.chat_message("assistant", avatar = "💁"):
                stream_message("No errors to correct. Returning to main chat.")
            st.session_state.current_stage = "ask"
            st.rerun()

Can you share a video that demonstrates the jumpiness?

The general solution for that kind of thing is to use containers so you can push the variation into a container and give the app a more stable view since the containers are like boxes for you to store varying pieces in.

I sent you the video.

Also, I was wondering: do you know how one can add audio input along with chat_input? Is there a way?