I have tried to put together a script that contains all the functions that get called. The script itself might not run as-is, since it needs a few more global variables and some LLM API calls, but it shows how the bottom-most function, run_chatbot(), runs first to start the chat UI and its st.chat_input().
run_chatbot() passes the first input text to process_chat_response(), which in turn calls validate_input_query(); that is where the remaining chat inputs are triggered via the get_user_input() function defined at the top. Each get_user_input() call inside validate_input_query() uses st.chat_input() to collect the user's answer after first prompting them through stream_message() (which is just a shared wrapper around st.chat_message()).
import time
import json
from collections import defaultdict
from typing import Any, Dict, Optional, Type, Union

import streamlit as st
from pydantic import BaseModel, ValidationError
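# ---------------------------------------------------------------------------
# Hedged stubs (my assumptions, not the original definitions): the real app
# defines InputQuery, SYSTEM_PROMPT, tools, and chat_completion_request as
# globals elsewhere. Minimal placeholders like these let the control flow
# below be traced end to end without the real LLM backend.
# ---------------------------------------------------------------------------
class InputQuery(BaseModel):
    # Hypothetical schema fields; replace with the real segment-creation schema
    segment_name: Optional[str] = None
    criteria: Optional[str] = None

SYSTEM_PROMPT = (
    "You are a segmentation assistant. Respond in {language}.\n"
    "Conversation so far:\n{chat_history}"
)

tools = None  # The real script passes OpenAI-style tool definitions here

def chat_completion_request(messages, tools=None, response_format=None):
    """Placeholder for the real LLM API call (e.g. an OpenAI chat completion)."""
    raise NotImplementedError("Wire this up to your LLM provider.")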
# Utility to stream messages one by one in the chat interface
def stream_message(response_text: str, avatar: str = "🤖", delay: float = 0.07, max_words: int = 25) -> None:
"""
Streams response text in the chat interface with a typing effect, simulating ChatGPT's output.
Args:
response_text (str): The full response to be displayed.
avatar (str): Avatar for the assistant in the chat.
delay (float): Time delay between streaming each word.
max_words (int): Maximum words before inserting a line break.
"""
# Store the assistant's message in session state
st.session_state.messages.append({"role": "assistant", "content": response_text})
with st.chat_message("assistant", avatar=avatar):
message_placeholder = st.empty()
full_response = ""
words_count = 0
# Split response_text while preserving the markdown (** for bold text)
words = response_text.split(' ')
for i, chunk in enumerate(words):
# Add space after each word
full_response += chunk + " "
words_count += 1
            # Insert a line break once max_words is reached and the current word ends a sentence
if words_count >= max_words and chunk.endswith(".") and i < len(words) - 1:
full_response += "\n\n" # Insert line break for readability
words_count = 0
# Sleep to simulate delay in typing
time.sleep(delay)
message_placeholder.markdown(
                f'{full_response}▌'
)
message_placeholder.markdown(full_response)
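# Example call (assumes st.session_state.messages was initialized by run_chatbot):
#   stream_message("Done! Your segment **weekly_active** was created.")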
def get_user_input(message: str, key: str) -> str:
"""
Handles user input via the chat interface, displays it, and stores it in session state.
Args:
        message (str): The message shown to the user as a prompt for input.
        key (str): A unique widget key for st.chat_input.
Returns:
str: The user input.
"""
# Stream the prompt message
stream_message(message)
print(f'key: {key}')
    # Accept user input via chat input (returns None until the user submits)
    user_input = st.chat_input(message, key=key)
    print(f"user_input : {user_input}")
    if user_input:
        # Display the user input back in the chat
        with st.chat_message("user", avatar=None):
            st.markdown(user_input)
        # Store the user input in session state for persistence
        st.session_state.messages.append({"role": "user", "content": user_input})
    return user_input
def get_unique_key(prefix: str) -> str:
"""
Generates a unique key by using a session-based counter.
Args:
prefix (str): A prefix string to identify the type of input.
Returns:
str: A unique key combining the prefix and the current step counter.
"""
if "input_counter" not in st.session_state:
st.session_state.input_counter = 0 # Initialize the counter
print(st.session_state.input_counter)
st.session_state.input_counter += 1 # Increment the counter for each input field
return f"{prefix}_{st.session_state.input_counter}"
def validate_input_query(merged_arguments: Dict[str, Any], schema_model: Type[BaseModel] = InputQuery) -> BaseModel:
"""
    Validates the parsed arguments from a previous LLM agent call against schema_model,
    re-prompting the user for any missing or invalid fields.
    Args:
        merged_arguments (Dict): The parsed arguments from the previous LLM call.
        schema_model (BaseModel): The InputQuery schema class.
    Returns:
        BaseModel: The validated Pydantic model.
"""
while True:
errors = {}
try:
validated_query = schema_model(**merged_arguments)
# Convert the validated Pydantic model into a dictionary for easier iteration
validated_data = validated_query.dict()
message = "Great! Here's what we learnt from your intent:\n\n"
# Stream each key-value pair one by one with bold keys
for key, value in validated_data.items():
message += f"**{key}**: {value}\n\n" # Key is bolded
            stream_message(message, avatar="🤖")  # Stream the full message to the chat
print(f'validated query in validate_input_query:{validated_query}\n')
break # Exit loop if validation is successful
        except ValidationError as e:
# Collect all field errors
print(f'e value: {e}')
for error in e.errors():
print(f'error: {error}')
if error['loc']:
field = error['loc'][0]
error_message = error['msg']
errors[field] = error_message
else:
# Handle general errors (like missing required fields)
print(f"General validation error: {error['msg']}")
missing_fields_message = error['msg']
if "Missing required fields" in missing_fields_message:
missing_fields = missing_fields_message.split(": ")[-1].split(", ")
for field in missing_fields:
# Prompt the user for missing mandatory fields using get_user_input()
user_input = get_user_input(f"**{field} is mandatory. Please provide a value for '{field}':**", key = get_unique_key(f"mandatory_field_{field}"))
if user_input is None:
continue
else:
merged_arguments[field] = user_input
st.session_state.chat_history.append({"role": "user", "content": user_input})
            if errors:
                # Show all specific errors to the user at once
                stream_message("The following fields need fixing:\n\n")
                for field, error_message in errors.items():
                    stream_message(f"- **{field}**: {error_message}\n\n")
# Prompt user to correct all specific errors together
                for field, error_message in errors.items():
# Ask for input for the field
stream_message(f"Please provide a valid value for '{field}':\n\n")
user_input = get_user_input(f"Please provide a valid value for '{field}':", key = get_unique_key(f"correct_field_{field}"))
if user_input is not None:
merged_arguments[field] = user_input
st.session_state.chat_history.append({"role": "user", "content": user_input})
while True:
            user_choice = get_user_input(
                "Would you like to 1) Add more details, or 2) Correct specific fields? Enter 1 or 2 (or 'n' to finalize):",
                key=get_unique_key("user_choice_key"),
            )
if user_choice is None:
continue # Wait for valid input
if user_choice == 'n':
return validated_query
elif user_choice == '1':
                additional_info = get_user_input("Please enter additional details:", key=get_unique_key("additional_info_key"))
if additional_info:
validated_query = process_chat_response(additional_info, tools = tools)
                    if isinstance(validated_query, str):
                        print(f'Yes {validated_query} is a string')
                        continue
break
elif user_choice == '2':
                field_to_correct = get_user_input("Enter the field name you want to correct:", key=get_unique_key("field_to_correct_key"))
if field_to_correct in merged_arguments:
                    new_value = get_user_input(f"Enter the new value for '{field_to_correct}':", key=get_unique_key(f"new_value_{field_to_correct}"))
if new_value is not None:
merged_arguments[field_to_correct] = new_value
validated_query = schema_model(**merged_arguments) # Re-validate with the updated field
break
else:
stream_message(f"Field '{field_to_correct}' is not present in the current query.")
else:
stream_message("Invalid option. Please enter 1, 2, or 'n'.")
def process_chat_response(
user_query: str,
system_prompt: str = SYSTEM_PROMPT,
tools: Any = None,
    schema_model: Type[BaseModel] = InputQuery,
language: str = "English"
) -> Union[BaseModel, str]:
"""
Processes the chat response from a completion request and outputs function details and arguments.
This function sends a list of messages to a chat completion request, processes the response to extract
function calls and arguments, and prints relevant information. It also merges function arguments and
initializes an `InputQuery` object using the merged arguments.
Args:
        user_query (str): Query entered by the user.
        system_prompt (str): System prompt template, formatted with {chat_history} and {language}.
        tools (Any): The tools to be used with the chat completion request.
        schema_model (BaseModel): The Pydantic model class used for validating the user query.
        language (str): Language in which the response should be generated.
    Returns:
        response_output (Union[BaseModel, str]): The response for the query.
"""
# Ensure chat_history is initialized in session state
if "chat_history" not in st.session_state:
st.session_state.chat_history = []
# Convert chat history into a formatted string
chat_history_str = "\n".join([f"{msg['role'].capitalize()}: {msg['content']}" for msg in st.session_state.chat_history])
# Format the system prompt with the chat history
    formatted_system_prompt = system_prompt.format(chat_history=chat_history_str, language=language)
messages = [{"role": "system", "content": formatted_system_prompt}]
messages.append({"role": "user", "content": user_query})
print(f'Full Prompt: {messages}')
print(f'Tools : {tools}')
response = chat_completion_request(messages, tools=tools, response_format={"type": "text"})
print(f'\n{response}')
merged_arguments = defaultdict(lambda: None)
    if response.choices[0].finish_reason == "tool_calls":
        # Merge the arguments from every tool call into a single dict
        for tool_call in response.choices[0].message.tool_calls:
            function_arguments = json.loads(tool_call.function.arguments)
            merged_arguments.update(function_arguments)
            print(f'\nfunction call arguments: {function_arguments}')
        merged_arguments = dict(merged_arguments)
        print(f"Merged Arguments: {merged_arguments}")
# Convert merged_arguments to a JSON-like string and escape curly braces
merged_arguments_str = str(merged_arguments).replace("{", "{{").replace("}", "}}")
# Append the user's query and the assistant's response to the chat history
st.session_state.chat_history.append({"role": "user", "content": user_query})
st.session_state.chat_history.append({"role": "assistant", "content": merged_arguments_str})
# Verifying the Output with Verifier LLM Agent
#verifier_response = verifier_agent_response(user_query, merged_arguments, tools)
print()
#print(f"Verifier LLM Agent Response: {verifier_response}")
# Validate the InputQuery object with re-prompting if necessary
final_response = validate_input_query(merged_arguments, schema_model)
print(f"Process Chat Final Response: {final_response}")
elif response.choices[0].finish_reason == 'stop' and response.choices[0].message.content is not None:
final_response = response.choices[0].message.content.strip()
# Verifying the Output with Verifier LLM Agent
#verifier_response = verifier_agent_response(user_query, final_response, tools)
print()
#print(f"Verifier LLM Agent Response: {verifier_response}")
# Append the user's query and the assistant's response to the chat history
st.session_state.chat_history.append({"role": "user", "content": user_query})
st.session_state.chat_history.append({"role": "assistant", "content": final_response})
#print(f'\n{final_response}')
print(f"chat history: {st.session_state.chat_history}")
return final_response
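# Illustration of the tool-call merge above (hypothetical arguments):
#   call 1 arguments: {"segment_name": "weekly_active"}
#   call 2 arguments: {"criteria": "logged_in_last_7_days"}
#   merged_arguments -> {"segment_name": "weekly_active", "criteria": "logged_in_last_7_days"}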
def run_chatbot():
# Modularize chatbot UI elements
"""Run the Streamlit chatbot interface."""
language = st.sidebar.radio("Choose Response Language", ('English', 'German', 'French'), horizontal=True)
if "messages" not in st.session_state:
st.session_state.messages = [{"role": "assistant", "content": "Helloš , I'm Eliza, your Segmentation Assistant. I can currently help you create segments."}]
for message in st.session_state.messages:
if message["role"] == "assistant":
with st.chat_message(message["role"], avatar = "š"):
st.markdown(f'<div style="color: black;">{message["content"]}</div>', unsafe_allow_html=True)
else:
with st.chat_message(message["role"]):
st.markdown(f':black[{message["content"]}]')
    if user_input := st.chat_input("Ask something", key="user_input_main"):
        with st.chat_message("user", avatar=None):
st.markdown(user_input)
st.session_state.messages.append({"role": "user", "content": user_input})
        assistant_response = process_chat_response(user_input, tools=tools, language=language)
        # validate_input_query may return a Pydantic model rather than a string,
        # so only strip and stream plain-text responses here
        if isinstance(assistant_response, str):
            assistant_response = assistant_response.strip()
            # Stream the assistant's response in real time
            stream_message(assistant_response)
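
# Entry point: `streamlit run this_script.py` executes the module top to bottom,
# so this call is what actually starts the chat UI described above.
run_chatbot()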