I just started using the Streamlit library and I love it tremendously. But I encountered a small problem in this process. This problem occurred when I added new codes to store historical data in the AI Chat I developed.
I tried the if “history” not in st.session_state: solution in the documentation, but nothing changed.
memory_prompt_template = """<s>[INST] You are an AI chatbot, an expert in economics and finance, chatting with a human. Answer his questions.
Previous conversation: {history}
Human: {human_input}
AI: [/INST]"""
with open("config.yaml", "r") as f:
config = yaml.safe_load(f)
def create_llm(model_path = config["model_path"]["large"], model_type = config["model_type"], model_config = config["model_config"]):
llm = CTransformers(model=model_path, model_type=model_type, config=model_config)
return llm
def create_embeddings(embeddings_path = config["embeddings_path"]):
return HuggingFaceInstructEmbeddings(embeddings_path)
def create_chat_memory(chat_history):
return ConversationBufferWindowMemory(memory_key="history", chat_memory=chat_history, k=3)
def create_prompt_from_template(template):
return PromptTemplate.from_template(template)
def create_llm_chain(llm, chat_prompt, memory):
return LLMChain(llm=llm, prompt=chat_prompt, memory=memory)
def load_normal_chain(chat_history):
return chatChain(chat_history)
def save_chat_history_json(chat_history, file_path):
with open(file_path, "w") as f:
json_data = [message.dict() for message in chat_history]
json.dump(json_data, f)
def load_chat_history_json(file_path):
with open(file_path, "r") as f:
json_data = json.load(f)
messages = [HumanMessage(**message) if message["type"] == "human" else AIMessage(**message) for message in json_data]
return messages
def get_timestamp():
return datetime.now().strftime("%Y-%m-%d %H:%M:%S")
class chatChain:
def __init__(self, chat_history):
self.memory = create_chat_memory(chat_history)
llm = create_llm()
chat_prompt = create_prompt_from_template(memory_prompt_template)
self.llm_chain = create_llm_chain(llm, chat_prompt, self.memory)
def run(self, user_input):
return self.llm_chain.run(human_input= user_input, history=self.memory.chat_memory.messages, stop="Human:")
def load_chain(chat_history):
return load_normal_chain(chat_history)
def clear_input_field():
st.session_state.user_question = st.session_state.user_input
st.session_state.user_input = ""
def set_send_input():
st.session_state.send_input = True
clear_input_field()
def save_chat_history():
if st.session_state.history != []:
if st.session_state.session_key == "new_session":
st.session_state.new_session_key = get_timestamp()
save_chat_history_json(st.session_state.history, config["chat_history_path"] + st.session_state.new_session_key + ".json")
else:
save_chat_history_json(st.session_state.history, config["chat_history_path"] + st.session_state.session_key + ".json")
def main():
st.title("Example AI Model")
chat_container = st.container()
st.sidebar.title("Chat Sessions")
chat_sessions = ["new_session"] + os.listdir(config["chat_history_path"])
if "send_input" not in st.session_state:
st.session_state.send_input = False
st.session_state.user_question = ""
st.session_state.new_session_key = None
st.sidebar.selectbox("Select a chat session", chat_sessions, key="session_key")
chat_history = StreamlitChatMessageHistory(key="history")
llm_chain = load_chain(chat_history)
user_input = st.text_input("Ask your questions to model.", key="user_input", on_change=set_send_input)
send_button = st.button("Ask", key="send_button")
if send_button or st.session_state.send_input:
if st.session_state.user_question != "":
with chat_container:
llm_response = llm_chain.run(st.session_state.user_question)
st.session_state.user_question = ""
if chat_history.messages:
with chat_container:
st.write("Chat History:")
for message in chat_history.messages:
st.chat_message(message.type).write(message.content)
save_chat_history()
if __name__ == "__main__":
main()