Implementing LangGraph's tool review in Streamlit

I am trying to implement LangGraph's tool review (human-in-the-loop) flow in Streamlit. When the model makes a tool call, the graph interrupts. The Streamlit app should then request user input, and once that input is received, graph execution should continue according to the logic in the human_review_node.
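
For context, my mental model of the interrupt/resume mechanism (based on the LangGraph human-in-the-loop docs) is roughly this, with the goal of collecting the resume payload from Streamlit widgets instead of hardcoding it:

# Inside a node: interrupt() pauses the graph and surfaces this payload to the caller
human_review = interrupt({"question": "Is this correct?", "tool_call": tool_call})

# From the caller: resume the paused thread by invoking with a Command
# (config here stands in for my st.session_state.config below)
graph.invoke(Command(resume={"action": "continue"}), config=config)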

Below is my current implementation, but it has two issues:

  1. When the graph is interrupted, the radio box requesting review input does not show up. I have to type some text into the chat input and hit Enter before it appears.
  2. I have to click the Submit Review button twice before the Streamlit app reruns.

Is there a way to refresh the Streamlit app when the graph ‘interrupts’?
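
My current guess is that I have to force the rerun myself, e.g. something like the following right after resuming, but I am not sure whether st.rerun() is the right call or where it belongs relative to the interrupt check:

response = st.session_state.graph.invoke(Command(resume=cmd_data), config=st.session_state.config)
st.session_state.chat_history.append(response['messages'][-1])
st.rerun()  # force an immediate rerun so the UI picks up the new graph state?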

import streamlit as st
from typing_extensions import Literal
from langgraph.graph import StateGraph, START, END, MessagesState
from langgraph.checkpoint.memory import MemorySaver
from langgraph.types import Command, interrupt
# from langchain_anthropic import ChatAnthropic
from langchain_core.tools import tool
from langchain_core.messages import AIMessage, HumanMessage
from langchain_openai import ChatOpenAI
import json
from dotenv import load_dotenv

load_dotenv()


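# Re-render the full conversation from session state on every Streamlit rerun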
def display_chat_history():
    for msg in st.session_state.chat_history:
        if isinstance(msg, AIMessage):
            st.chat_message("assistant").write(msg.content)
        elif isinstance(msg, HumanMessage):
            st.chat_message("user").write(msg.content)

model = ChatOpenAI(
    model='gpt-4o',
    temperature=0,
    max_tokens=None,
    timeout=None,
    max_retries=2,
)


# --- Tool Definition ---
@tool
def weather_search(city: str):
    """Search for the weather"""
    return "Sunny!"

# model = ChatAnthropic(model_name="claude-3-5-sonnet-latest").bind_tools([weather_search])
model = model.bind_tools([weather_search])


class State(MessagesState):
    """Simple state."""

def call_llm(state):
    return {"messages": [model.invoke(state["messages"])]}

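# Pauses at interrupt() and routes based on the reviewer's action:
# approve the tool call, edit its args, or send feedback back to the LLM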
def human_review_node(state) -> Command[Literal["call_llm", "run_tool"]]:
    last_message = state["messages"][-1]
    tool_call = last_message.tool_calls[-1]
    # Interrupt and wait for user input
    print('graph interrupted?')
    human_review = interrupt(
        {
            "question": "Is this correct?",
            "tool_call": tool_call,
        }
    )
    review_action = human_review["action"]
    review_data = human_review.get("data")
    if review_action == "continue":
        return Command(goto="run_tool")
    elif review_action == "update":
        updated_message = {
            "role": "ai",
            "content": last_message.content,
            "tool_calls": [
                {
                    "id": tool_call["id"],
                    "name": tool_call["name"],
                    "args": review_data,
                }
            ],
            "id": last_message.id,
        }
        return Command(goto="run_tool", update={"messages": [updated_message]})
    elif review_action == "feedback":
        tool_message = {
            "role": "tool",
            "content": review_data,
            "name": tool_call["name"],
            "tool_call_id": tool_call["id"],
        }
        return Command(goto="call_llm", update={"messages": [tool_message]})

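# Execute each tool call from the last AI message and return the results as tool messages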
def run_tool(state):
    new_messages = []
    tools = {"weather_search": weather_search}
    tool_calls = state["messages"][-1].tool_calls
    for tool_call in tool_calls:
        tool = tools[tool_call["name"]]
        result = tool.invoke(tool_call["args"])
        new_messages.append(
            {
                "role": "tool",
                "name": tool_call["name"],
                "content": result,
                "tool_call_id": tool_call["id"],
            }
        )
    return {"messages": new_messages}

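# End the turn if the LLM made no tool calls; otherwise send it to human review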
def route_after_llm(state) -> Literal[END, "human_review_node"]:
    if len(state["messages"][-1].tool_calls) == 0:
        return END
    else:
        return "human_review_node"

# --- Build the Graph ---
builder = StateGraph(State)
builder.add_node(call_llm)
builder.add_node(run_tool)
builder.add_node(human_review_node)
builder.add_edge(START, "call_llm")
builder.add_conditional_edges("call_llm", route_after_llm)
builder.add_edge("run_tool", "call_llm")

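# Compile the graph once per session so the MemorySaver checkpoints survive Streamlit reruns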
if 'graph' not in st.session_state:

    memory = MemorySaver()
    st.session_state.graph = builder.compile(checkpointer=memory)
    st.session_state.config = {"configurable": {"thread_id": "5"}}

if 'chat_history' not in st.session_state:
    st.session_state.chat_history = []


st.title("Tool Review (Human-in-the-Loop)")

display_chat_history()

user_input = st.chat_input("Ask a question:")

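# A non-empty .next means the graph is paused at an interrupt, awaiting review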
if st.session_state.graph.get_state(st.session_state.config).next:

    with st.sidebar:
        st.write('awaiting human review')
    st.subheader("Tool Call Review")
    action = st.radio("Approve tool call?", ["continue", "update", "feedback"])
    review_data = None
    if action == "update":
        review_data = st.text_input("Updated tool arguments (JSON):", "{}")
    elif action == "feedback":
        review_data = st.text_area("Feedback for the model:")
    if st.button("Submit Review"):
        # Prepare the Command payload for resuming execution
        cmd_data = {"action": action}
        if action == "update":
            # human_review_node expects the updated tool args as a dict, not a JSON string
            cmd_data["data"] = json.loads(review_data)
        elif action == "feedback":
            cmd_data["data"] = review_data
        print('resuming the graph')
        response = st.session_state.graph.invoke(Command(resume=cmd_data), config=st.session_state.config)
        st.session_state.chat_history.append(response['messages'][-1])

elif user_input:

    # Wrap the input as a HumanMessage so display_chat_history can render it
    user_message = HumanMessage(content=user_input)
    st.session_state.chat_history.append(user_message)
    # With a checkpointer, only the new message needs to be sent; prior turns are already in the thread state
    response = st.session_state.graph.invoke({"messages": [user_message]}, config=st.session_state.config)
    st.session_state.chat_history.append(response['messages'][-1])