How to create multiple chat_input widgets in the same app

Hello

I am trying to use chat_input in multiple places across my two-script app, wherever I need the user to enter a value.

However, it automatically returns None whenever another chat_input is created after the first one, even though I am using a unique key for every widget.

How do I handle the case where a user needs to enter input via the chatbot multiple times, depending on what they are prompted for in the chat box?

Are your chat inputs in a container? If you call st.chat_input without a container, the widget will pin to the bottom of the app (which doesn't work for multiple inputs). If you put st.chat_input in a container, you can have multiples:

import streamlit as st

cols = st.columns(4)

with cols[0]:
    st.write(st.session_state.get("chat1", None))
    st.chat_input("Chat 1", key="chat1")

with cols[1]:
    st.write(st.session_state.get("chat2", None))
    st.chat_input("Chat 2", key="chat2")

with cols[2]:
    st.write(st.session_state.get("chat3", None))
    st.chat_input("Chat 3", key="chat3")

with cols[3]:
    st.write(st.session_state.get("chat4", None))
    st.chat_input("Chat 4", key="chat4")

No, I am using the main page chat input.

The problem is I have two .py scripts. One is the main script, which first creates the chatbot with an initial user input.

The second script takes that input and validates a few things; during validation, if there is an error, it asks the user to enter a correct value. It does this in multiple places in that script.

To get user input again each time they are prompted, I was adding another chat_input call, each with a unique key. But it automatically returns None, which breaks the entire flow. How do we build a chatbot where a user can enter input multiple times when prompted?

I need only one chat interface. My use case doesn't require multiple chat inputs in containers. It's one chatbot talking to the user, which may need multiple inputs from the user as the code prompts for them.

How do we do that with chat_input? I can't imagine this use case hasn't been addressed before; it's a basic one.

I'm not sure I precisely understand the workflow you want, but here is an example that gathers a list of strings and prompts the user to provide replacements for any "x" in the string.

import streamlit as st

if "history" not in st.session_state:
    st.session_state.history = []
    st.session_state.pending = None

def save_new():
    st.session_state.pending = st.session_state.chat

def update_pending(before, after):
    st.session_state.pending = before + st.session_state.chat + after

st.header("History")
for message in st.session_state.history:
    st.chat_message("user").write(message)

if st.session_state.pending:
    st.markdown(f"Validating {st.session_state.pending}")
    x = st.session_state.pending.find("x")
    if x < 0:
        st.session_state.history.append(st.session_state.pending)
        st.session_state.pending = None
        st.rerun()
    before = st.session_state.pending[:x]
    after = st.session_state.pending[x + 1:] if x < len(st.session_state.pending) - 1 else ""
    st.chat_input(
        f"Replace the x after '{before}' and '{after}'.", 
        key="chat", 
        on_submit=update_pending, 
        args=[before, after]
    )

else:
    st.chat_input("Type a string without an 'x' in it.", key="chat", on_submit=save_new)

main.py

import streamlit as st
from src.platform_intelligence.language.process_input_query import process_chat_response, tools
from src.utils.display_utility import stream_message

def run_chatbot():
    # Modularize chatbot UI elements
       
    """Run the Streamlit chatbot interface."""
    language = st.sidebar.radio("Choose Response Language", ('English', 'German', 'French'), horizontal=True)

    if "messages" not in st.session_state:
        st.session_state.messages = [{"role": "assistant", "content": "Hello 👋, I'm Eliza, your Segmentation Assistant. I can currently help you create segments."}]
    
    for message in st.session_state.messages:
        if message["role"] == "assistant":
            with st.chat_message(message["role"], avatar="💁"):
                st.markdown(f'<div style="color: black;">{message["content"]}</div>', unsafe_allow_html=True)
        else:
            with st.chat_message(message["role"]):
                st.markdown(f':black[{message["content"]}]')

    if user_input := st.chat_input("Ask something", key = "user_input_main"):
        with st.chat_message("user", avatar = None):
            st.markdown(user_input)
        st.session_state.messages.append({"role": "user", "content": user_input})
        assistant_response = process_chat_response(user_input, tools = tools, language = language).strip()
        
        # Stream assistant's response in real-time
        stream_message(assistant_response)

validate.py


import time
from typing import Any, Dict

import streamlit as st
from pydantic import BaseModel

# (InputQuery, tools, and process_chat_response come from other project modules)


# Utility to stream messages one by one in the chat interface
def stream_message(response_text: str, avatar: str = "💁", delay: float = 0.07, max_words: int = 25) -> None:
    """
    Streams response text in the chat interface with a typing effect, simulating ChatGPT's output.

    Args:
        response_text (str): The full response to be displayed.
        avatar (str): Avatar for the assistant in the chat.
        delay (float): Time delay between streaming each word.
        max_words (int): Maximum words before inserting a line break.
    """
    
    # Store the assistant's message in session state
    st.session_state.messages.append({"role": "assistant", "content": response_text})

    with st.chat_message("assistant", avatar=avatar):
        message_placeholder = st.empty()
        full_response = ""
        words_count = 0

        # Split response_text while preserving the markdown (** for bold text)
        words = response_text.split(' ')
        for i, chunk in enumerate(words):
            # Add space after each word
            full_response += chunk + " "
            words_count += 1
            
            # Insert line break after max_words or when punctuation ends a sentence
            if words_count >= max_words and chunk.endswith(".") and i < len(words) - 1:
                full_response += "\n\n" # Insert line break for readability
                words_count = 0
            
            # Sleep to simulate delay in typing
            time.sleep(delay)

            # Render the message with markdown (including bold text)
            # message_placeholder.markdown(
            #     f'<div style="color: black;">{full_response} ▌</div>',
            #     unsafe_allow_html=True
            # )

            message_placeholder.markdown(f"{full_response} ▌")
            

        # After streaming is done, render the final message without the typing cursor
        # message_placeholder.markdown(
        #     f'<div style="color: black;">{full_response}</div>', 
        #     unsafe_allow_html=True
        # )

        message_placeholder.markdown(full_response)


def get_user_input(message: str, key: str) -> str:
    """
    Handles user input via the chat interface, displays it, and stores it in session state.

    Args:
        prompt (str): The message shown to the user as a prompt for input.

    Returns:
        str: The user input.
    """
    # Stream the prompt message
    stream_message(message)

    print(f'key: {key}')

    # Accept user input via chat input
    user_input = st.chat_input(message, key = key)

    print(f"user_input : {user_input}")

    # Display the user input back in the chat
    with st.chat_message("user", avatar=None):
        st.markdown(user_input)

    # Store the user input in session state for persistence
    st.session_state.messages.append({"role": "user", "content": user_input})

    return user_input




def get_unique_key(prefix: str) -> str:
    """
    Generates a unique key by using a session-based counter.
    
    Args:
        prefix (str): A prefix string to identify the type of input.
    
    Returns:
        str: A unique key combining the prefix and the current step counter.
    """

    if "input_counter" not in st.session_state:
        st.session_state.input_counter = 0  # Initialize the counter
    print(st.session_state.input_counter)
    st.session_state.input_counter += 1  # Increment the counter for each input field
    return f"{prefix}_{st.session_state.input_counter}"


def validate_input_query(merged_arguments: Dict[str, Any], schema_model: BaseModel = InputQuery) -> BaseModel:
    """
    Processes the chat response from a previous LLM Agent and validate it as per schema_model

    Args:
        merged_arguments (Dict): The parsed arguments from previous LLM
        schema_model (BaseModel): InputQuery Schema
        
    Returns:
        BaseModel: Verified Pydantic Model
    """ 

    while True:
        errors = {}
        try:
            validated_query = schema_model(**merged_arguments)
            # Convert the validated Pydantic model into a dictionary for easier iteration
            validated_data = validated_query.dict()
            message = "Great! Here's what we learnt from your intent:\n\n"
            # Stream each key-value pair one by one with bold keys
            for key, value in validated_data.items():
                message += f"**{key}**: {value}\n\n"  # Key is bolded

            stream_message(message, avatar="💁")  # Stream the full message to the chat

            print(f'validated query in validate_input_query:{validated_query}\n')
            break  # Exit loop if validation is successful
        except ValueError as e:
            # Collect all field errors
            print(f'e value: {e}')
            for error in e.errors():
                print(f'error: {error}')
                if error['loc']:
                    field = error['loc'][0]
                    error_message = error['msg']
                    errors[field] = error_message
                else:
                    # Handle general errors (like missing required fields)
                    print(f"General validation error: {error['msg']}")
                    missing_fields_message = error['msg']
                    if "Missing required fields" in missing_fields_message:
                        missing_fields = missing_fields_message.split(": ")[-1].split(", ")
                        for field in missing_fields:
                            # Prompt the user for missing mandatory fields using get_user_input()
                            user_input = get_user_input(f"**{field} is mandatory. Please provide a value for '{field}':**", key = get_unique_key(f"mandatory_field_{field}"))
                            if user_input is None:
                                continue
                            else:
                                merged_arguments[field] = user_input
                                st.session_state.chat_history.append({"role": "user", "content": user_input})

            if errors:
                # Show all specific errors to the user at once
                #print("The following fields needs fixing:")
                stream_message("The following fields need fixing:\n\n")
                for field, error_message in errors.items():
                    #print(f"- {field}: {error_message}")
                    stream_message(f"- **{field}**: {error_message}\n\n")

                # Prompt user to correct all specific errors together
                for field , error_message in errors.items():
                    # Ask for input for the field
                    stream_message(f"Please provide a valid value for '{field}':\n\n")
                    user_input = get_user_input(f"Please provide a valid value for '{field}':", key = get_unique_key(f"correct_field_{field}"))
                    if user_input is not None:
                        merged_arguments[field] = user_input
                        st.session_state.chat_history.append({"role": "user", "content": user_input})

    while True:
        user_choice = get_user_input("Would you like to 1) Add more details, or 2) Correct specific fields? Enter 1 or 2 (or 'n' to finalize): ", 
        key = get_unique_key(f"user_choice_key"))

        if user_choice is None:
            continue  # Wait for valid input
        
        if user_choice == 'n':
            return validated_query  

        elif user_choice == '1':
            additional_info = get_user_input("Please enter additional details:", key = get_unique_key(f"additional_info_key"))
            if additional_info:
                validated_query = process_chat_response(additional_info, tools = tools)
                if(isinstance(validated_query, str)):
                    print(f'Yes {validated_query} is string')
                    continue
            break

        elif user_choice == '2':
            field_to_correct = get_user_input("Enter the field name you want to correct:", key = get_unique_key(f"field_to_correct_key"))

            if field_to_correct in merged_arguments:
                new_value = get_user_input(f"Enter the new value for '{field_to_correct}':", key = get_unique_key(f"new_value_{field_to_correct}"))
                if new_value is not None:
                    merged_arguments[field_to_correct] = new_value
                    validated_query = schema_model(**merged_arguments)  # Re-validate with the updated field
                    break
            else:
                stream_message(f"Field '{field_to_correct}' is not present in the current query.")
        else:
            stream_message("Invalid option. Please enter 1, 2, or 'n'.")

Here are my two scripts. main.py runs first, takes an input, and creates the chatbot.

The input is then passed to process_chat_response(), which calls the validate_input_query() method, within which get_user_input() is used in multiple places to get input from the user when prompted.

If you look at my validate_input_query() function, it calls get_user_input() in multiple places, which, as defined above, creates a new chat_input() for the user when prompted.

This part needs to change, as a chat_input() has already been added in main.py when the script kickstarts the whole flow.

Can you put together a minimal, executable script that highlights the problem? (E.g., I created a toy example with a simple "replace x" validation.)

Chat input always initializes to None and reverts to None after only returning a value once (much like a button is a trigger and isn't stateful). In my example, I stored a chat reply in Session State to loop through processing before clearing it and submitting it to the history.

You might want to look at st.dialog as well, to run this validation and additional prompts in a modal dialog.

I have tried to put together a script which has all the functions that are called. The script itself might not run, as it needs a few more global variables and some LLM API calls, but it will show you how the bottom-most function run_chatbot() is run first to start the chat_input() and chat UI.

That then passes the first input text to process_chat_response(), within which is the validate_input_query() function, which is where the other chat_input() calls happen via the get_user_input() function defined at the top. Each get_user_input() call in validate_input_query() uses st.chat_input() to ask for user input after prompting via a stream_message() statement (which is just a common wrapper around chat_message()).

# Utility to stream messages one by one in the chat interface
import json
import time
from collections import defaultdict
from typing import Any, Dict, Union

import streamlit as st
from pydantic import BaseModel

# (SYSTEM_PROMPT, InputQuery, tools, and chat_completion_request still need to be defined)


def stream_message(response_text: str, avatar: str = "💁", delay: float = 0.07, max_words: int = 25) -> None:
    """
    Streams response text in the chat interface with a typing effect, simulating ChatGPT's output.

    Args:
        response_text (str): The full response to be displayed.
        avatar (str): Avatar for the assistant in the chat.
        delay (float): Time delay between streaming each word.
        max_words (int): Maximum words before inserting a line break.
    """
    
    # Store the assistant's message in session state
    st.session_state.messages.append({"role": "assistant", "content": response_text})

    with st.chat_message("assistant", avatar=avatar):
        message_placeholder = st.empty()
        full_response = ""
        words_count = 0

        # Split response_text while preserving the markdown (** for bold text)
        words = response_text.split(' ')
        for i, chunk in enumerate(words):
            # Add space after each word
            full_response += chunk + " "
            words_count += 1
            
            # Insert line break after max_words or when punctuation ends a sentence
            if words_count >= max_words and chunk.endswith(".") and i < len(words) - 1:
                full_response += "\n\n" # Insert line break for readability
                words_count = 0
            
            # Sleep to simulate delay in typing
            time.sleep(delay)

            message_placeholder.markdown(f"{full_response} ▌")
            

        message_placeholder.markdown(full_response)


def get_user_input(message: str, key: str) -> str:
    """
    Handles user input via the chat interface, displays it, and stores it in session state.

    Args:
        prompt (str): The message shown to the user as a prompt for input.

    Returns:
        str: The user input.
    """
    # Stream the prompt message
    stream_message(message)

    print(f'key: {key}')

    # Accept user input via chat input
    user_input = st.chat_input(message, key = key)

    print(f"user_input : {user_input}")

    # Display the user input back in the chat
    with st.chat_message("user", avatar=None):
        st.markdown(user_input)

    # Store the user input in session state for persistence
    st.session_state.messages.append({"role": "user", "content": user_input})

    return user_input


def get_unique_key(prefix: str) -> str:
    """
    Generates a unique key by using a session-based counter.
    
    Args:
        prefix (str): A prefix string to identify the type of input.
    
    Returns:
        str: A unique key combining the prefix and the current step counter.
    """

    if "input_counter" not in st.session_state:
        st.session_state.input_counter = 0  # Initialize the counter
    print(st.session_state.input_counter)
    st.session_state.input_counter += 1  # Increment the counter for each input field
    return f"{prefix}_{st.session_state.input_counter}"

def validate_input_query(merged_arguments: Dict[str, Any], schema_model: BaseModel = InputQuery) -> BaseModel:
    """
    Processes the chat response from a previous LLM Agent and validate it as per schema_model

    Args:
        merged_arguments (Dict): The parsed arguments from previous LLM
        schema_model (BaseModel): InputQuery Schema
        
    Returns:
        BaseModel: Verified Pydantic Model
    """ 

    while True:
        errors = {}
        try:
            validated_query = schema_model(**merged_arguments)
            # Convert the validated Pydantic model into a dictionary for easier iteration
            validated_data = validated_query.dict()
            message = "Great! Here's what we learnt from your intent:\n\n"
            # Stream each key-value pair one by one with bold keys
            for key, value in validated_data.items():
                message += f"**{key}**: {value}\n\n"  # Key is bolded

            stream_message(message, avatar="💁")  # Stream the full message to the chat

            print(f'validated query in validate_input_query:{validated_query}\n')
            break  # Exit loop if validation is successful
        except ValueError as e:
            # Collect all field errors
            print(f'e value: {e}')
            for error in e.errors():
                print(f'error: {error}')
                if error['loc']:
                    field = error['loc'][0]
                    error_message = error['msg']
                    errors[field] = error_message
                else:
                    # Handle general errors (like missing required fields)
                    print(f"General validation error: {error['msg']}")
                    missing_fields_message = error['msg']
                    if "Missing required fields" in missing_fields_message:
                        missing_fields = missing_fields_message.split(": ")[-1].split(", ")
                        for field in missing_fields:
                            # Prompt the user for missing mandatory fields using get_user_input()
                            user_input = get_user_input(f"**{field} is mandatory. Please provide a value for '{field}':**", key = get_unique_key(f"mandatory_field_{field}"))
                            if user_input is None:
                                continue
                            else:
                                merged_arguments[field] = user_input
                                st.session_state.chat_history.append({"role": "user", "content": user_input})

            if errors:
                # Show all specific errors to the user at once
                #print("The following fields needs fixing:")
                stream_message("The following fields need fixing:\n\n")
                for field, error_message in errors.items():
                    #print(f"- {field}: {error_message}")
                    stream_message(f"- **{field}**: {error_message}\n\n")

                # Prompt user to correct all specific errors together
                for field , error_message in errors.items():
                    # Ask for input for the field
                    stream_message(f"Please provide a valid value for '{field}':\n\n")
                    user_input = get_user_input(f"Please provide a valid value for '{field}':", key = get_unique_key(f"correct_field_{field}"))
                    if user_input is not None:
                        merged_arguments[field] = user_input
                        st.session_state.chat_history.append({"role": "user", "content": user_input})

    while True:
        user_choice = get_user_input("Would you like to 1) Add more details, or 2) Correct specific fields? Enter 1 or 2 (or 'n' to finalize): ", 
        key = get_unique_key(f"user_choice_key"))

        if user_choice is None:
            continue  # Wait for valid input
        
        if user_choice == 'n':
            return validated_query  

        elif user_choice == '1':
            additional_info = get_user_input("Please enter additional details:", key = get_unique_key(f"additional_info_key"))
            if additional_info:
                validated_query = process_chat_response(additional_info, tools = tools)
                if(isinstance(validated_query, str)):
                    print(f'Yes {validated_query} is string')
                    continue
            break

        elif user_choice == '2':
            field_to_correct = get_user_input("Enter the field name you want to correct:", key = get_unique_key(f"field_to_correct_key"))

            if field_to_correct in merged_arguments:
                new_value = get_user_input(f"Enter the new value for '{field_to_correct}':", key = get_unique_key(f"new_value_{field_to_correct}"))
                if new_value is not None:
                    merged_arguments[field_to_correct] = new_value
                    validated_query = schema_model(**merged_arguments)  # Re-validate with the updated field
                    break
            else:
                stream_message(f"Field '{field_to_correct}' is not present in the current query.")
        else:
            stream_message("Invalid option. Please enter 1, 2, or 'n'.")


def process_chat_response(
    user_query: str, 
    system_prompt: str = SYSTEM_PROMPT, 
    tools: Any = None, 
    schema_model: BaseModel = InputQuery, 
    language: str = "English"
    ) -> Union[BaseModel, str]:
    """
    Processes the chat response from a completion request and outputs function details and arguments.

    This function sends a list of messages to a chat completion request, processes the response to extract 
    function calls and arguments, and prints relevant information. It also merges function arguments and 
    initializes an `InputQuery` object using the merged arguments.

    Args:
        user_query (str): Query entered by user
        system_prompt (str): If a user has any specific prompt to enter. 
        tools (Any): The tools to be used with the chat completion request.
        schema_model (BaseModel): The Pydantic model class used for validating user query.
        language (str): Language in which you want to a response

    Returns:
        response_output (Union[BaseModel, str]): Returns the response for the query.
    """

    # Ensure chat_history is initialized in session state
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = []
    
     # Convert chat history into a formatted string
    chat_history_str = "\n".join([f"{msg['role'].capitalize()}: {msg['content']}" for msg in st.session_state.chat_history])
    
    # Format the system prompt with the chat history
    formatted_system_prompt = system_prompt.format(chat_history=chat_history_str, language = language)

    messages = [{"role": "system", "content": formatted_system_prompt}]
    messages.append({"role": "user", "content": user_query})

    print(f'Full Prompt: {messages}')

    print(f'Tools : {tools}')

    response = chat_completion_request(messages, tools=tools, response_format={"type": "text"})

    print(f'\n{response}')

    merged_arguments = defaultdict(lambda: None)

    if response.choices[0].finish_reason == "tool_calls":
        for tool_call in response.choices[0].message.tool_calls:
            function_arguments = json.loads(tool_call.function.arguments)
            merged_arguments.update(function_arguments)

        merged_arguments = dict(merged_arguments)
        
        print()
        print(f'function call arguments: {function_arguments}')
        print(f"Merged Arguments: {merged_arguments}")

        # Convert merged_arguments to a JSON-like string and escape curly braces
        merged_arguments_str = str(merged_arguments).replace("{", "{{").replace("}", "}}")
        
        # Append the user's query and the assistant's response to the chat history
        st.session_state.chat_history.append({"role": "user", "content": user_query})
        st.session_state.chat_history.append({"role": "assistant", "content": merged_arguments_str})
        
        # Verifying the Output with Verifier LLM Agent
        #verifier_response = verifier_agent_response(user_query, merged_arguments, tools)
        print()
        #print(f"Verifier LLM Agent Response: {verifier_response}")

        # Validate the InputQuery object with re-prompting if necessary
        final_response = validate_input_query(merged_arguments, schema_model)
        print(f"Process Chat Final Response: {final_response}")

    elif response.choices[0].finish_reason == 'stop' and response.choices[0].message.content is not None:
        final_response = response.choices[0].message.content.strip()

         # Verifying the Output with Verifier LLM Agent
        #verifier_response = verifier_agent_response(user_query, final_response, tools)
        print()
        #print(f"Verifier LLM Agent Response: {verifier_response}")

        # Append the user's query and the assistant's response to the chat history
        st.session_state.chat_history.append({"role": "user", "content": user_query})
        st.session_state.chat_history.append({"role": "assistant", "content": final_response})

        #print(f'\n{final_response}')

    print(f"chat history: {st.session_state.chat_history}")
    
    return final_response


def run_chatbot():
    # Modularize chatbot UI elements
       
    """Run the Streamlit chatbot interface."""
    language = st.sidebar.radio("Choose Response Language", ('English', 'German', 'French'), horizontal=True)

    if "messages" not in st.session_state:
        st.session_state.messages = [{"role": "assistant", "content": "Hello 👋, I'm Eliza, your Segmentation Assistant. I can currently help you create segments."}]
    
    for message in st.session_state.messages:
        if message["role"] == "assistant":
            with st.chat_message(message["role"], avatar="💁"):
                st.markdown(f'<div style="color: black;">{message["content"]}</div>', unsafe_allow_html=True)
        else:
            with st.chat_message(message["role"]):
                st.markdown(f':black[{message["content"]}]')

    if user_input := st.chat_input("Ask something", key = "user_input_main"):
        with st.chat_message("user", avatar = None):
            st.markdown(user_input)
        st.session_state.messages.append({"role": "user", "content": user_input})
        assistant_response = process_chat_response(user_input, tools = tools, language = language).strip()
        
        # Stream assistant's response in real-time
        stream_message(assistant_response)

To put it simply:

All I had that works well with a terminal-driven input() is the code below. This is where you can see I am prompting the user multiple times.

All I want is to replace input() with a Streamlit chat_input-driven UI.

The first chat_input() call is in the main script, which later sends its output here:

def validate_input_query(merged_arguments: Dict[str, Any], schema_model: BaseModel = InputQuery) -> BaseModel:
    """
    Processes the chat response from a previous LLM Agent and validate it as per schema_model

    Args:
        merged_arguments (Dict): The parsed arguments from previous LLM
        schema_model (BaseModel): InputQuery Schema
        
    Returns:
        BaseModel: Verified Pydantic Model
    """ 

    while True:
        errors = {}
        try:
            validated_query = schema_model(**merged_arguments)
            print(f'validated query in validate_input_query:{validated_query}\n')
            break  # Exit loop if validation is successful
        except ValueError as e:
            # Collect all field errors
            print(f'e value: {e}')
            for error in e.errors():
                print(f'error: {error}')
                if error['loc']:
                    field = error['loc'][0]
                    error_message = error['msg']
                    errors[field] = error_message
                else:
                    # Handle general errors (like missing required fields)
                    print(f"General validation error: {error['msg']}")
                    missing_fields_message = error['msg']
                    if "Missing required fields" in missing_fields_message:
                        missing_fields = missing_fields_message.split(": ")[-1].split(", ")
                        for field in missing_fields:
                            user_input = input(f" {field} is mandatory. Please provide a value for '{field}': ")
                            merged_arguments[field] = user_input
                            chat_history.append({"role": "user", "content": user_input})

            if errors:
                # Show all specific errors to the user at once
                print("The following fields need fixing:")
                for field, error_message in errors.items():
                    print(f"- {field}: {error_message}")

                # Prompt user to correct all specific errors together
                for field , error_message in errors.items():
                    user_input = input(f"{error}- Please provide a valid value for '{field}': ")
                    merged_arguments[field] = user_input
                    chat_history.append({"role": "user", "content": user_input})

    while True:
        user_choice = input("Would you like to 1) Add more details, or 2) Correct specific fields? Enter 1 or 2 (or 'n' to finalize): ")
        
        if user_choice == 'n':
        # Finalize and break out of the loop
          break

        elif user_choice == '1':
            additional_info = input("Please enter additional details: ")
            validated_query = process_chat_response(additional_info, tools = tools)
            # merged_arguments.update(additional_entities)
            # validated_query = schema_model(**merged_arguments)  # Re-validate with the new data
            #print(f"Updated Validated Response: {response}")
            if(isinstance(validated_query, str)):
                print(f'Yes {validated_query} is string')
                continue
            break

        elif user_choice == '2':
            field_to_correct = input("Enter the field name you want to correct: ")
            if field_to_correct in merged_arguments:
                new_value = input(f"Enter the new value for '{field_to_correct}': ")
                merged_arguments[field_to_correct] = new_value
                validated_query = schema_model(**merged_arguments)  # Re-validate with the updated field
                #print(f"Updated Validated Response: {validated_query}")
                break
            else:
                print(f"Field '{field_to_correct}' is not present in the current query.")
        
        # elif user_choice.lower() == 'n':
        #     #validated_query = schema_model(**merged_arguments)
        #     break
        else:
            print("Invalid option. Please enter 1, 2, or 'n'.")
    
    return validated_query

Here's how I took a stab at it based on your example:

  • Replaced all Python input() calls with get_user_input(), defined below. All stdio print() calls for user-facing messages were replaced by a utility function stream_message(), which basically wraps chat_message() with streamed markdown.
  • The issue I still face is that the moment it enters the first validation loop and moves to the second while loop asking for a choice (the second part of the validate_input_query() code), it again throws the same error about building multiple chat_input widgets, even though in get_user_input() I am using session_state and conditional flow.

YOUR HELP IS VERY MUCH APPRECIATED HERE.

I've been stuck on this for days and feel there's some nuance of session_state I'm missing for routing multiple inputs through chat_input().

validate_input_query():


def validate_input_query(merged_arguments: Dict[str, Any], schema_model: BaseModel = InputQuery) -> BaseModel:
    """
    Processes the chat response from a previous LLM agent and validates it against schema_model

    Args:
        merged_arguments (Dict): The parsed arguments from previous LLM
        schema_model (BaseModel): InputQuery Schema
        
    Returns:
        BaseModel: Verified Pydantic Model
    """ 

    if "current_stage" not in st.session_state:
        st.session_state.current_stage = "validation" 

    # Initialize or retrieve stored errors
    if "errors" not in st.session_state:
        st.session_state.errors = {}

    
    while True:
        errors = {}
        print(f"stage of code: Entered Validate_query()")
        # Validation Stage
        if st.session_state.current_stage == "validation":
            try:
                validated_query = schema_model(**merged_arguments)
                # Convert the validated Pydantic model into a dictionary for easier iteration
                validated_data = validated_query.dict()
                message = "Great! Here's what we learnt from your intent:\n\n"
                # Stream each key-value pair one by one with bold keys
                for key, value in validated_data.items():
                    message += f"**{key}**: {value}\n\n"  # Key is bolded

                stream_message(message, avatar="šŸ’")  # Stream the full message to the chat

                st.session_state.current_stage = "ask_additional"  # Mark validation as complete

                print(f'validated query in validate_input_query:{validated_query}\n')

                print()
                print(f"stage of code: From Validate_query() to Ask_Additional()")

                break

            except ValueError as e:
                # Collect all field errors
                print(f'e value: {e}')
                for error in e.errors():
                    print(f'error: {error}')
                    if error['loc']:
                        field = error['loc'][0]
                        error_message = error['msg']
                        errors[field] = error_message
                    else:
                        # Handle general errors (like missing required fields)
                        print(f"General validation error: {error['msg']}")
                        missing_fields_message = error['msg']
                        if "Missing required fields" in missing_fields_message:
                            missing_fields = missing_fields_message.split(": ")[-1].split(", ")
                            for field in missing_fields:
                                # Prompt the user for missing mandatory fields using get_user_input()
                                user_input = get_user_input(f"**{field} is mandatory. Please provide a value for '{field}':**", key = get_unique_key(f"mandatory_field_{field}"),
                                                            stage = "correction")
                                if user_input is None:
                                    continue
                                else:
                                    merged_arguments[field] = user_input
                                    st.session_state.chat_history.append({"role": "user", "content": user_input})

                if errors:
                    # Show all specific errors to the user at once
                    #print("The following fields needs fixing:")
                    stream_message("The following fields need fixing:\n\n")
                    for field, error_message in errors.items():
                        #print(f"- {field}: {error_message}")
                        stream_message(f"- **{field}**: {error_message}\n\n")

                    # Store the errors for the correction stage
                    st.session_state.errors = errors
                    st.session_state.current_stage = "correction"  # Move to correction stage

        # Correction Stage
        if st.session_state.current_stage == "correction":

            errors = st.session_state.errors  # Retrieve stored errors
            # Prompt user to correct all specific errors together
            for field, error_message in errors.items():
                # Ask for input for the field
                #stream_message(f"Please provide a valid value for '{field}':\n\n")
                user_input = get_user_input(f"Please provide a valid value for '{field}':", key = get_unique_key(f"correct_field_{field}"), stage = "correction")
                if user_input:
                    merged_arguments[field] = user_input
                    st.session_state.chat_history.append({"role": "user", "content": user_input})
                    del st.session_state.errors[field]

             # If all errors are corrected, go back to validation
            if not st.session_state.errors:
                st.session_state.current_stage = "validation"
                continue

        # Additional details or corrections
        #while st.session_state.current_stage == "validation" and not st.session_state.errors:
    while True:
            print()
            print(f"stage of code: Entered Ask_Additional()")
            if st.session_state.current_stage == "ask_additional":
                user_choice = get_user_input("Would you like to 1) Add more details, or 2) Correct specific fields? Enter 1 or 2 (or 'n' to finalize): ", 
                key = get_unique_key(f"user_choice_key"), stage = "ask_additional")

                if user_choice is None:
                        continue  # Wait for valid input
                
                if user_choice == 'n':
                    st.session_state.current_stage = "completed"
                    return validated_query  

                elif user_choice == '1':
                    additional_info = get_user_input("Please enter additional details:", key = get_unique_key(f"additional_info_key"), stage = "ask_additional")
                    if additional_info:
                        validated_query = process_chat_response(additional_info, tools = tools)
                        if(isinstance(validated_query, str)):
                            print(f'Yes {validated_query} is string')
                            continue
                    break

                elif user_choice == '2':
                    field_to_correct = get_user_input("Enter the field name you want to correct:", key = get_unique_key(f"field_to_correct_key"), stage = "ask_additional")

                    if field_to_correct in merged_arguments:
                        new_value = get_user_input(f"Enter the new value for '{field_to_correct}':", key = get_unique_key(f"new_value_{field_to_correct}"), stage = "ask_additional")
                        if new_value is not None:
                            merged_arguments[field_to_correct] = new_value
                            st.session_state.current_stage = "validation"
                            validated_query = schema_model(**merged_arguments)  # Re-validate with the updated field
                            break
                    else:
                        stream_message(f"Field '{field_to_correct}' is not present in the current query.")
                else:
                    stream_message("Invalid option. Please enter 1, 2, or 'n'.")

get_user_input():



def get_user_input(message: str, key: str, stage: str) -> str:
    """
    Handles user input via the chat interface, displays it, and stores it in session state.

    Args:
        message (str): The message shown to the user as a prompt for input.
        key (str): A unique key to ensure a single input widget.
        stage (str): The current stage of the chatbot process, e.g., 'main_chat', 'validation', 'correction'

    Returns:
        str: The user input.
    """

    # Stream the prompt message
    if stage != "main_chat":
        stream_message(message)

    print(f'key: {key}')

    print(f'stage: {stage}')

    if "current_stage" not in st.session_state:
        st.session_state.current_stage = stage

    user_input = None  # default so the function returns safely when no input has arrived yet
    if st.session_state.current_stage == stage:
        if user_input := st.chat_input(message, key=key):
            print(f"user input: {user_input}")
            # Display the user input back in the chat
            with st.chat_message("user", avatar=None):
                st.markdown(user_input)
            st.session_state.messages.append({"role": "user", "content": user_input})

    return user_input


There will be very few cases where you'll want a while True loop in a Streamlit app to collect input. Streamlit itself already wraps your script in a loop.

For example, suppose you had a simple Python command line program to collect a list, one item at a time:

my_list = []
while True:
    print("Add item:")
    item = input()
    my_list.append(item)
    print(my_list)

The equivalent Streamlit app would be:

import streamlit as st

if "my_list" not in st.session_state:
    st.session_state.my_list = []

item = st.chat_input("Add item")
if item:
    st.session_state.my_list.append(item)
st.write(st.session_state.my_list)

The critical thing to understand when moving from a command-line app to a Streamlit app is that Streamlit does not "wait" at its inputs. All inputs have a default value, and when you call those commands, that default value (like None or "") is returned immediately, faster than any user could ever respond. The pattern you want is a logical stop after rendering that input on screen. Streamlit will handle the rerun (loop) as soon as the user does something.

In the case of chat inputs, you'll need a conditional on the output to check whether it exists. If it's None, the user hasn't responded yet, so you don't want to keep going and re-prompt them; your script needs to respond correctly to that. Then when the user responds, the whole script reruns, and when it gets to that input and sees a non-None value, it can proceed.
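To make that rerun model concrete, here is a minimal plain-Python simulation (not real Streamlit code): a dict stands in for st.session_state, each call to script_run() represents one top-to-bottom script run, and chat_value plays the role of st.chat_input's return value.

```python
def script_run(state, chat_value):
    """One simulated 'rerun'. chat_value is None until the user submits."""
    history = state.setdefault("history", [])
    if chat_value is None:
        # User hasn't responded yet: render what we have and do nothing else.
        return list(history)
    # User submitted something on this run: record it, then render.
    history.append(chat_value)
    return list(history)

state = {}
print(script_run(state, None))     # first load: []
print(script_run(state, "hello"))  # user submits: ['hello']
print(script_run(state, None))     # idle rerun: ['hello'] persists
```

Notice the script never blocks: a run with no input simply does less work, and state carried between runs is what makes the conversation accumulate.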

I was trying to distill your case into something very simple like:

  1. Get a response
  2. Iterate through the response to modify it
  3. Submit the completed response

The critical pattern in my example is that I use Session State to hold the "pending answer." On every rerun, the script sees something "pending" and works with that. When the pending thing has been fully resolved, it's committed to the history and cleared, so the script falls back to asking for a new item input.
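As a hedged sketch of that pending-answer pattern (again simulated in plain Python, with a dict in place of st.session_state), each rerun routes the single input either to starting a new pending item or to resolving it:

```python
def rerun(state, user_input):
    """One simulated script run: route the single input by what's pending."""
    state.setdefault("history", [])
    state.setdefault("pending", None)
    if state["pending"] is None:
        if user_input is not None:
            state["pending"] = user_input  # hold until resolved
    elif user_input == "accept":
        state["history"].append(state["pending"])  # commit and clear
        state["pending"] = None

s = {}
rerun(s, "draft answer")  # becomes pending
rerun(s, "accept")        # committed to history, pending cleared
print(s["history"], s["pending"])  # ['draft answer'] None
```

The one input widget never changes; only the meaning of its value changes, driven by what is currently pending.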

I'll try to extend the example to use fragments or dialogs to be a little more sophisticated.


I am trying to make sense of this

In my case, While True is at two places:

  • First to make sure validate_input_query() keeps running until the response is completely validated before returning output from it.
  • 2nd, to make sure a user enters a value to decide the end of the whole process.

My issue is that I am not sure how to keep asking the user for input during validation when I can't use chat_input again. This is boggling my mind, and I have spent days trying to figure it out to no avail. Should I put get_user_input() under a fragment, so that every time it is called, only the underlying chat_input() for that piece runs?

It shouldn't be this hard to understand how to keep asking the user for multiple inputs via chat_input(). It's just confusing as of now.

I have 7-8 places where input is needed from the user, and depending on successful validation, more can be needed.

Are you saying that every time chat_input() runs due to a script rerun (when a user enters something), you assign the input value to a session_state variable and use that wherever input from the user is needed in your code?

The question I have is: if input is needed 7-8 times and those inputs would typically be stored in an equal number of variables, are you saying we now create 7-8 session_state variables so that each time chat_input is called, its value gets saved into one of them?

The problem is how I would keep track of how many session_state input variables I have, and how I would know which chat_input() output needs to be saved to which session_state variable.

Here is a more meaningful example. In this example:

  1. The user enters a prompt
  2. A fake LLM generates a response
  3. Validation highlights possible errors and asks the user to choose:
    • Accept as-is
    • Correct the highlighted areas
    • Rewrite the entire response
  4. After choosing, the user is prompted accordingly until they finally accept the answer.
  5. Repeat

It's important to observe that on any run of the script, only one thing is being done. A single script run will not collect multiple responses from the user. Instead, the script run only processes the one thing the user did last.

The question I have is: if input is needed 7-8 times and those inputs would typically be stored in an equal number of variables, are you saying we now create 7-8 session_state variables so that each time chat_input is called, its value gets saved into one of them?

We don't need separate variables in Session State for every possible thing. The key things in the following example are:

  • There is a variable keeping track of what stage we're in (starting with a new prompt, validating a response, iterating through corrections, or just writing a custom answer).
  • In the validation, correction, and rewrite stages, the tentative (pending) answer is stored in Session State.
  • Furthermore, during the correction phase, a list of errors is stored, but we pop them off the list as they are resolved. When there are none left, the user can accept, or just rewrite the whole thing if they still aren't happy.

Sorry it's a bit scrappy, but hopefully it shows one possible logical flow.

import streamlit as st
import random
import string
import time

if "initialized" not in st.session_state:
    st.session_state.stage = "ask"
    st.session_state.history = []
    st.session_state.errors = []
    st.session_state.pending = None
    st.session_state.initialized = True

def fake_LLM():
    """Returns a fake generated response"""
    for i in range(20):
        time.sleep(.2)
        yield "".join(random.choices(string.ascii_lowercase, k=10)) + " "

def locate_errors(response_segments):
    """Locates three random "errors" in the response"""
    return random.sample(range(len(response_segments)), k=3)  # sample avoids duplicate indices

def add_highlights(segments, highlight_indices):
    return [
        "***" + segments[i] + "***"
        if i in highlight_indices
        else segments[i]
        for i in range(len(segments))
    ]

def get_new_answer():
    new = st.text_area("Rewrite the answer", value=" ".join(st.session_state.pending))
    if st.button("Submit"):
        st.session_state.history.append({"role":"assistant", "content":new})
        st.session_state.pending = None
        st.session_state.stage = "ask"
        st.rerun()
                                        

for message in st.session_state.history:
    with st.chat_message(message["role"]):
        st.write(message["content"])

if st.session_state.stage == "ask":
    if user_input := st.chat_input("Enter a prompt"):
        st.session_state.history.append({"role":"user","content": user_input})
        with st.chat_message("user"):
            st.write(user_input)
        with st.chat_message("assistant"):
            response = st.write_stream(fake_LLM())
            st.session_state.pending = response.split(" ")
            st.session_state.stage = "validate"
            st.rerun()
    st.stop()

if st.session_state.stage == "validate":
    st.chat_input("...", disabled=True)
    response_segments = st.session_state.pending
    error_indices = locate_errors(response_segments)
    highlighted_segments = add_highlights(response_segments, error_indices)
    with st.chat_message("assistant"):
        st.markdown(" ".join(highlighted_segments))
        st.divider()
        if len(error_indices) > 0:
            st.write("Possible errors are bolded.")
            if st.button("Correct errors"):
                st.session_state.errors = error_indices
                st.session_state.stage = "correct"
                st.rerun()
        if st.button("Accept as-is"):
            st.session_state.history.append({"role":"assistant","content":" ".join(st.session_state.pending)})
            st.session_state.pending = None
            st.session_state.stage = "ask"
            st.rerun()
        if st.button("Rewrite answer"):
            st.session_state.stage = "rewrite"
            st.rerun()
    st.stop()

if st.session_state.stage == "correct":
    if len(st.session_state.errors) == 0:
        st.chat_input("Accept or rewrite the answer above.", disabled=True)
        with st.chat_message("assistant"):
            st.markdown(" ".join(st.session_state.pending))
            st.divider()
            if st.button("Accept"):
                st.session_state.history.append({"role":"assistant","content":" ".join(st.session_state.pending)})
                st.session_state.pending = None
                st.session_state.stage = "ask"
                st.rerun()
            if st.button("Rewrite answer"):
                st.session_state.stage = "rewrite"
                st.rerun()
        st.stop()
    user_input = st.chat_input("Replacement text")
    response_segments = st.session_state.pending
    error_indices = st.session_state.errors
    current_error = st.session_state.errors[0]
    highlighted_segments = add_highlights(response_segments, error_indices)
    highlighted_segments[current_error] = ":red["+highlighted_segments[current_error]+"]"
    with st.chat_message("assistant"):
        st.markdown(" ".join(highlighted_segments))
        st.divider()
        st.markdown("What should the red, bold segment say?")
        if user_input:
            st.session_state.pending[current_error] = user_input
            st.session_state.errors.pop(0)
            st.rerun()

if st.session_state.stage == "rewrite":
    st.chat_input("Rewrite the answer in the text area above.", disabled=True)
    with st.chat_message("assistant"):
        get_new_answer()

Wow, that's an amazing piece you built. I didn't know all that interaction could be done within chat_message. It makes me excited that a good UX can be built into a chatbot.

I am still not sure I understood how you did it.

  • I see you are using multiple chat_input() calls in your code. I thought we could only have one chat_input in the entire script? You have three, and yet it doesn't throw the DuplicateWidgetID error or the issue I was facing when I needed input multiple times and ended up calling chat_input() repeatedly.
  • How do st.stop() and st.rerun() help in making sure only the final accepted answer is shown in the history and UI, and not intermediate ones?

For all widgets, you can have multiple instances in an app, just not multiple instances that are exactly the same. (e.g. Use different keys so Streamlit can tell them apart.) st.chat_input is a little special when you don't put it in a container since it pins to the bottom, so typically, for design reasons, you'd only want one instance of st.chat_input within the main body of the app. In my example, I only ever have one chat input at a time. The command is called in different places, but any one script run will only call it once.

st.stop() could be replaced in most places with elif instead. Both st.stop() and st.rerun() will end the script run at that point and go no further. st.stop() will then wait for the user to do something, but st.rerun() will immediately refresh the view, typically to reflect some change that's just been made. In this case, when switching stages, we want to redraw the screen to show the next stage. st.rerun() could be replaced with callbacks in most places. The flow that happens in the example:

  1. The page loads at stage "validate"
  2. The user clicks "Accept as-is"
  3. The script reruns (still with stage "validate")
  4. When the script run gets to the accept button, it's true, so the stage is changed to "ask" and the pending message is moved into the chat history.
  5. The script reruns from st.rerun()
  6. The page loads with stage "ask" and the newly added message in the history.

You can move those things that happen inside "if button" to a callback, so they happen at the beginning of the rerun triggered by the button. I was trying to keep the example a little more "flat" to show the logic. More about button logic.
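For instance, here is a hedged sketch of the callback idea, simulated without Streamlit. The point being illustrated is that Streamlit runs a button's on_click callback before the next script run begins, so the body of that run already sees the updated stage and no st.rerun() is needed:

```python
def accept_pending(state):
    """What an on_click callback would do: commit the pending answer and switch stage."""
    state["history"].append(" ".join(state["pending"]))
    state["pending"] = None
    state["stage"] = "ask"

def simulated_rerun(state, clicked_accept=False):
    if clicked_accept:        # callbacks fire before the script body
        accept_pending(state)
    return state["stage"]     # the body then branches on the (possibly new) stage

state = {"history": [], "pending": ["hello", "world"], "stage": "validate"}
print(simulated_rerun(state))                       # validate
print(simulated_rerun(state, clicked_accept=True))  # ask
print(state["history"])                             # ['hello world']
```

In the real app this would look like st.button("Accept as-is", on_click=...) with the state mutations inside the callback instead of inside the if-button block.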

I really loved your UX. It gives me ideas I wasn't thinking of.

However, I am still battling to make it all work for my case. Just like yours, I have a main chat_input() which kick-starts the input. That input is passed to an LLM to get a response, which is passed to validate_input_query() for validation. validate_input_query() is where I also have three stages: validation, correction, ask_additional. Within those, it prompts the user multiple times for input.

The issue is, you had four chat_input call sites: the main one plus all three stages.

However, within a stage you only had one. In my case, within a stage (let's say correction), I am going to ask the user multiple times to enter a few things until the validation stage passes and says complete. That requires putting chat_input in multiple places within the same stage, which breaks it and causes errors.

To be clear, I have st.chat_input once per run. If you look at the correction stage, there is an iterative process that happens. Each run asks about "the next error." The error is processed and removed from consideration. So on the next rerun, "the next error" is something different.
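That one-error-per-run queue can be sketched like this (a plain-Python simulation; in the real app the errors list lives in st.session_state and the replacement text comes from the single st.chat_input):

```python
def correction_rerun(state, user_input):
    """One simulated run of the correction stage."""
    if not state["errors"]:
        return "done"
    current = state["errors"][0]          # this run asks about one error only
    if user_input is None:
        return f"awaiting fix for segment {current}"
    state["pending"][current] = user_input
    state["errors"].pop(0)                # next run asks about the next error
    return "applied"

state = {"pending": ["aaa", "bbb", "ccc"], "errors": [2, 0]}
print(correction_rerun(state, None))      # awaiting fix for segment 2
print(correction_rerun(state, "fixed"))   # applied
print(correction_rerun(state, "start"))   # applied
print(state["pending"], state["errors"])  # ['start', 'bbb', 'fixed'] []
```

One input call site serves any number of corrections, because the queue in state decides what that input means on each run.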

Can you describe more specifically what UI you want to accomplish? Can you, for example, give a simple LLM response and tell me what you want a user to do with it, in what order? In my example, I just generate a list of action items, then work through them in the order they were identified. Do you have something else, fundamentally, that you want to do?

Also notice how I disable the chat input on runs where I want the user to look up and take the next step within the chat message above. You can do that too, and create a form within the chat message to collect multiple bits of information in one script run.
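A form submit delivers all its fields in one rerun, so several corrections can resolve together. Here is a hedged plain-Python sketch of just that batching logic (in Streamlit this would be st.form plus st.form_submit_button inside the chat message):

```python
def apply_form_submit(state, fixes):
    """fixes maps error index -> replacement text, all delivered by one submit."""
    for idx, text in fixes.items():
        state["pending"][idx] = text
    # All fixed errors clear at once, instead of one per rerun.
    state["errors"] = [e for e in state["errors"] if e not in fixes]

state = {"pending": ["aaa", "bbb", "ccc"], "errors": [0, 2]}
apply_form_submit(state, {0: "first", 2: "third"})
print(state["pending"], state["errors"])  # ['first', 'bbb', 'third'] []
```

Whether to collect corrections one per rerun or batched in a form is purely a UX choice; the state handling is the same either way.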

(Also, I should change my example to use highlighting (background color) instead of bold so it stands out better… :thinking:)