How to visualize sources in Streamlit Chat App

Hi,

I’ve been following this demo to create my chat app. The app queries a RAG-based LLM deployed on Databricks.

The base functionality works well, but I am now trying to display the References right after the AI message. The References are chunks of text returned by the LLM API as a list of dictionaries.

To start simple, I display them as markdown text. The problem is that the References are only displayed for the latest message. As soon as I ask a new question, the References of the previous message get greyed out and eventually disappear, even though I am storing the References in the session state (much like the messages).

After playing around with ChatGPT, I now have a solution where the References of previous messages “randomly” pop up. For example, after 3 messages, the References of the 1st message suddenly appear again, but the 2nd message still does not show them. It’s pretty random and hard to describe.

How do I resolve the issue?

Code:

import os
import requests
import numpy as np
import pandas as pd
import json

from dotenv import load_dotenv
import streamlit as st

from parameters import DBX_MODEL_URL

# Load environment variables
load_dotenv()

# Constants
DBX_TOKEN = os.getenv('DBX_TOKEN')
SAMPLE_REFERENCES = [
    {'document_name': 'doc1', 'chunk_text': 'This is a chunk of text from the doc1.'},
    {'document_name': 'doc2', 'chunk_text': 'This is a chunk of text from the doc2.'},
]


def create_tf_serving_json(data):
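    # Wrap a dict of numpy arrays (or a single array) into the TF-serving 'inputs' payload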
    return {'inputs': {name: data[name].tolist() for name in data.keys()} if isinstance(data, dict) else data.tolist()}


def score_model(dataset):
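    # POST the dataset to the Databricks serving endpoint and return the parsed JSON response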
    headers = {'Authorization': f'Bearer {DBX_TOKEN}', 'Content-Type': 'application/json'}
    ds_dict = {'dataframe_split': dataset.to_dict(orient='split')} if isinstance(dataset, pd.DataFrame) else create_tf_serving_json(dataset)
    data_json = json.dumps(ds_dict, allow_nan=True)
    response = requests.post(DBX_MODEL_URL, headers=headers, data=data_json)
    if response.status_code != 200:
        raise Exception(f'Request failed with status {response.status_code}, {response.text}')
    return response.json()


def parse_llm_message(response):
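    # Pull the assistant's text out of the first prediction; references are stubbed for now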
    ai_answer = response['predictions'][0]['choices'][0]['message']['content']
    references = SAMPLE_REFERENCES  # temporarily hard-coded; real extraction: response['predictions'][0].get('references', [])
    return ai_answer, references


def prompt_llm(messages):
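    # Wrap the chat history the way the endpoint expects and score the model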
    input_example = np.array([{'messages': messages}])
    response = score_model(input_example)
    return parse_llm_message(response)


# Streamlit UI setup
st.title("Ask me any questions!")

# Initialize session state
if "messages" not in st.session_state:
    st.session_state.messages = []
if "references" not in st.session_state:
    st.session_state.references = []


# Display chat history
for i, message in enumerate(st.session_state.messages):
    print('Rendering message:', i)
    with st.chat_message(message["role"]):
        st.markdown(message["content"])
        
        # Check if references exist for this message and it's from the assistant
        if message["role"] == "assistant" and i < len(st.session_state.references):
            st.markdown("### References:")
            print('Found references for this message.', i)
            for ref in st.session_state.references[i]:
                st.write(f"**Document Name:** {ref['document_name']}")
                st.markdown(f"> {ref['chunk_text']}")

# Handle user input
if prompt := st.chat_input("What is up?"):
    # Display user message immediately
    st.chat_message("user").markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})

    # Get AI response and references
    ai_response, references = prompt_llm(st.session_state.messages)

    # Display AI response and references
    with st.chat_message("assistant"):
        st.markdown(ai_response)
        st.markdown("### References:")
        for ref in references:
            st.write(f"**Document Name:** {ref['document_name']}")
            st.markdown(f"> {ref['chunk_text']}")

    # Update session state with the new messages and references
    st.session_state.messages.append({"role": "assistant", "content": ai_response})
    st.session_state.references.append(references)

Screenshot 1 - References of the 1st message disappear.

Screenshot 2 - References of the 1st message appear again, but the 2nd message still does not display them.

I’m not sure about the logic here:

if message["role"] == "assistant" and i < len(st.session_state.references):

Here i indexes every message, user and assistant alike, while st.session_state.references only grows by one entry per assistant reply, so the two lists drift out of sync and references get attached to the wrong messages (or skipped entirely). That would explain the “random” behavior you’re seeing.

What about just storing the references with the messages themselves? A message in your chat history would then look like:

[
    {"role": "user", "content": "hi"},
    {"role": "assistant", "content": "hi", "references": []},
    {"role": "user", "content": "What is a cat?"},
    {
        "role": "assistant",
        "content": "A glorious, adorable creature with a boopable nose.",
        "references": ["doc1"]
    }
]
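If it helps, here is a minimal sketch of how your render loop and the final append could look under this scheme. It reuses your session-state keys and render code; the "references" key on each message dict and the history-stripping step are my assumptions, not something your endpoint necessarily requires:

# Display chat history; references now travel with each assistant message
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])
        refs = message.get("references", [])
        if refs:
            st.markdown("### References:")
            for ref in refs:
                st.write(f"**Document Name:** {ref['document_name']}")
                st.markdown(f"> {ref['chunk_text']}")

# Handle user input
if prompt := st.chat_input("What is up?"):
    st.chat_message("user").markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})

    # If the endpoint is strict about message schema, strip the extra key first
    history = [{"role": m["role"], "content": m["content"]} for m in st.session_state.messages]
    ai_response, references = prompt_llm(history)

    with st.chat_message("assistant"):
        st.markdown(ai_response)
        if references:
            st.markdown("### References:")
            for ref in references:
                st.write(f"**Document Name:** {ref['document_name']}")
                st.markdown(f"> {ref['chunk_text']}")

    # One append per turn, so a message and its references can never drift apart
    st.session_state.messages.append(
        {"role": "assistant", "content": ai_response, "references": references}
    )

Since everything that belongs to a turn lives in one dict, there is no index arithmetic to get wrong, and st.session_state.references can be deleted entirely.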