Streamlit crossfiltering: unable to get crossfiltering between plots to work

Hi, I am new to Streamlit. I have a project deadline in two days, and I'm struggling to enable crossfiltering between plots. I am following https://www.youtube.com/watch?v=htXgwEXwmNs for the implementation, but I can't get crossfiltering to work.

This is my whole code; I don't know where I'm making a mistake. I will also add the dataset I'm using, which is the Australian shark dataset. GitHub repo for the code and data: GitHub - NK150600/Australian_shark

If anyone knows how to implement this, it would mean the world to me.
import streamlit as st
import pandas as pd
import plotly.express as px

def initialize_state():
    """Create the crossfilter selection sets in Streamlit Session State.

    ``st.session_state["map_query"]`` and ``st.session_state["scatter_query"]``
    are set to empty sets only when they do not exist yet, so selections
    stored earlier in the session survive Streamlit reruns.
    """
    for q in ["map", "scatter"]:
        if f"{q}_query" not in st.session_state:
            st.session_state[f"{q}_query"] = set()

def query_data(df: pd.DataFrame, map_query=None) -> pd.DataFrame:
    """Flag the rows of ``df`` that match the current map filter.

    No rows are removed. A copy of ``df`` is returned with two extra columns:

    - ``lat_lon``: the string key ``"<Latitude>-<Longitude>"`` of each row;
    - ``selected``: True for rows whose key is in ``map_query``, or for every
      row when ``map_query`` is empty.

    Args:
        df: DataFrame with ``Latitude`` and ``Longitude`` columns. It is not
            modified.
        map_query: set of ``"<lat>-<lon>"`` strings. When None (the default),
            ``st.session_state["map_query"]`` is used.

    Returns:
        The annotated copy of ``df``.
    """
    if map_query is None:
        map_query = st.session_state["map_query"]

    # Work on a copy so the caller's DataFrame is not mutated.
    out = df.copy()

    # Combined key so a row can be matched against selected points.
    out["lat_lon"] = out["Latitude"].astype(str) + "-" + out["Longitude"].astype(str)

    # An empty filter means "everything is selected".
    out["selected"] = True
    if map_query:
        out["selected"] = out["lat_lon"].isin(map_query)

    return out

def display_map(location_data):
    """Build the incident map figure.

    One marker is drawn per row of ``location_data``, hovering the shark's
    scientific name, on an OpenStreetMap base layer.

    If ``location_data`` has a boolean ``selected`` column (as added by
    ``query_data``) and not every row is selected, the selected rows are
    passed to Plotly's ``selectedpoints``. The remaining markers are then
    drawn with the "unselected" (dimmed) style, so a selection made on the
    other chart becomes visible on the map.

    Args:
        location_data: DataFrame with ``Latitude``, ``Longitude`` and
            ``Shark.scientific.name`` columns, and optionally ``selected``.

    Returns:
        The Plotly figure.
    """
    fig = px.scatter_mapbox(
        location_data,
        lat='Latitude',
        lon='Longitude',
        hover_name='Shark.scientific.name',
        zoom=3,
    )

    fig.update_layout(
        mapbox_style='open-street-map',
        height=600,
        margin=dict(l=0, r=0, t=0, b=0),
        autosize=True,
    )

    # Only touch selectedpoints when some rows are filtered out. Setting it at
    # all would apply the selected/unselected styling even with no filter.
    if "selected" in location_data.columns and not location_data["selected"].all():
        # A single trace without color grouping: positions map 1:1 to rows.
        fig.update_traces(
            selectedpoints=[i for i, keep in enumerate(location_data["selected"]) if keep]
        )

    return fig

def scatterplot(data):

col1 = st.columns(1)
with col1[0]:
    sel = st.selectbox('select interested feature',
                        options=['Incident Hotspots by Location','Temperature Trends and Incident Counts','Number of Incidents by State','Average Temperature by State','Temperature vs. Incidents (State)'])
# with col2:
#     y_axis = st.selectbox('Select y-axis value', options=['no.incidents','Latitude','Longitude','temp'])
fig = None
incidents_by_temp = None
incidents_by_location = None
# if x_axis == 'temp' and y_axis == 'no.incidents':
if sel == 'Temperature Trends and Incident Counts' :
    incidents_by_temp = (
    data.groupby('temp')
    .size()  # Count the number of rows (incidents) per temperature
    .reset_index(name='no_of_incidents')  # Rename the count column
    )

    fig = px.scatter(
    incidents_by_temp, 
    x= 'temp', 
    y= 'no_of_incidents',
    title="Scatterplot: Temperature vs Number of Incidents",
    labels={"temperature": "Temperature", "no_of_incidents": "Number of Incidents"},
    template="plotly_white"
    )

# if x_axis in ['Longitude','Latitude'] and y_axis in ['Latitude','Longitude'] :
elif sel == 'Incident Hotspots by Location' :
    incidents_by_location = data.groupby(['Latitude', 'Longitude']).size().reset_index(name='no_of_incidents')

    # Create scatter plot
    fig = px.scatter(
        incidents_by_location,
        x= 'Latitude',
        y= 'Longitude',
        size='no_of_incidents',  # Size of points based on number of incidents
        color='no_of_incidents',  # Color points based on number of incidents
        title="Scatter Plot: Number of Incidents vs Location",
        labels={"Longitude": "Longitude", "Latitude": "Latitude", "no_of_incidents": "Number of Incidents"},
        template="plotly_white"
    )

# if x_axis == 'Location' and y_axis == 'no.incidents':
elif sel == 'Number of Incidents by State' :
    incidents_by_location = data['State'].value_counts().reset_index()
    incidents_by_location.columns = ['State', 'Number of Incidents']
    # Create scatter plot
    fig = px.bar(
        incidents_by_location,
        x='State',
        y='Number of Incidents',
        title='Number of Incidents by State',
        labels={'State': 'State', 'Number of Incidents': 'Number of Incidents'},
        template='plotly_white'
    )

# if x_axis == 'Location' and y_axis == 'temp':
elif sel == 'Average Temperature by State' :
    # Calculate average temperature by location
    avg_temp_by_location = data.groupby('State')['temp'].mean().reset_index()

    # Create a bar chart
    fig = px.bar(
        avg_temp_by_location,
        x='State',
        y='temp',
        title='Average Temperature by State',
        labels={'State': 'State', 'Temperature': 'Average Temperature'},
        template='plotly_white'
    )

elif sel == 'Temperature vs. Incidents (State)' :
    # Group the data by temperature and state, and count the number of incidents
    incidents_by_temp = data.groupby(['temp', 'State']).size().reset_index(name='no_of_incidents')
    
    # Create the scatter plot
    fig = px.scatter(
        incidents_by_temp,
        x='temp',  # Temperature on the x-axis
        y='no_of_incidents',  # Number of incidents on the y-axis
        size='no_of_incidents',  # Point size based on number of incidents
        color='State',  # Color points based on State
        title="Scatterplot: Temperature vs Number of Incidents (State-wise)",
        labels={"temp": "Temperature", "no_of_incidents": "Number of Incidents"},
        template="plotly_white",
    )  

if fig is not None:
    if sel in ['Temperature Trends and Incident Counts', 'Temperature vs. Incidents (State)'] and incidents_by_temp is not None:
        fig.update_traces(
            selectedpoints=[
                i for i, point in enumerate(
                    incidents_by_temp['temp'].astype(str) + '-' +
                    incidents_by_temp['no_of_incidents'].astype(str)
                ) if point in st.session_state["scatter_query"]
            ]
        )
    elif sel == 'Incident Hotspots by Location' and incidents_by_location is not None:
        fig.update_traces(
            selectedpoints=[
                i for i, point in enumerate(
                    incidents_by_location['Latitude'].astype(str) + '-' +
                    incidents_by_location['Longitude'].astype(str)
                ) if point in st.session_state["scatter_query"]
            ]
        )
    
return fig
# st.plotly_chart(fig, use_container_width=True)

def render_plotly_ui(df, options):
    """Draw the charts and collect the points currently selected on them.

    The map is always drawn. The companion chart, together with its view
    picker, is drawn only when ``options == 'temperature'``.

    Selections are cross-wired so that each chart's selection drives the
    other chart:

    - map selection -> ``"scatter_query"`` (``"<lat>-<lon>"`` strings), which
      ``scatterplot`` uses for highlighting;
    - chart selection -> ``"map_query"`` (``"<x>-<y>"`` strings), which
      ``query_data`` matches against ``"<Latitude>-<Longitude>"`` keys.

    Args:
        df: incident DataFrame, as returned by ``query_data``.
        options: the sidebar "Focus Area" choice.

    Returns:
        dict with ``"map_query"`` and ``"scatter_query"`` sets. A chart that
        is not shown contributes an empty set.
    """
    px_map = display_map(df)
    # Build the companion figure before drawing the map so its selectbox stays
    # above the map, as before. Skip it entirely when it won't be displayed.
    px_plot = scatterplot(df) if options == 'temperature' else None
    selection_modes = ('points', 'box', 'lasso')

    plot_map = st.plotly_chart(
        px_map, on_select="rerun", selection_mode=selection_modes, use_container_width=True
    )
    # Stays None for other focus areas. Previously this name was left
    # undefined there, raising NameError further down.
    plot_scatter = None
    if px_plot is not None:
        plot_scatter = st.plotly_chart(
            px_plot, on_select="rerun", selection_mode=selection_modes, use_container_width=True
        )

    with st.expander('selected'):
        st.write(plot_map)
    if plot_scatter is not None:
        with st.expander('selected scattermap'):
            st.write(plot_scatter)

    current_query = {}

    scatter_points = plot_scatter["selection"]["points"] if plot_scatter is not None else []
    current_query["map_query"] = {f"{point['x']}-{point['y']}" for point in scatter_points}
    st.write(current_query["map_query"])

    current_query["scatter_query"] = {
        f"{point['lat']}-{point['lon']}" for point in plot_map["selection"]["points"]
    }
    st.write(current_query["scatter_query"])

    # Debug output of the collected selections and the session state.
    st.write(current_query)
    st.write(st.session_state)

    return current_query

def update_state(current_query):
    """Store the latest selections in Session State; rerun if any changed.

    For each of ``"map_query"`` and ``"scatter_query"``: if the new set
    differs from the stored one, it replaces it. After the loop, if anything
    changed, ``st.rerun()`` is called so the charts are rebuilt with the new
    filters.

    Args:
        current_query: dict of selection sets, as returned by
            ``render_plotly_ui``. A missing key is treated as an empty set.
    """
    rerunn = False
    for q in ["map", "scatter"]:
        key = f"{q}_query"
        new_value = current_query.get(key, set())
        # Compare with != rather than a set difference: ``new - old`` is empty
        # whenever the selection shrinks or is cleared, which left stale
        # filters in place forever.
        if new_value != st.session_state[key]:
            st.session_state[key] = new_value
            rerunn = True

    if rerunn:
        st.rerun()

def main():
    """Load the data, draw the sidebar and charts, then sync selections.

    Reads ``complete_data_final.csv``, flags rows with ``query_data``, shows
    the sidebar (image, title, focus-area radio, temperature slider), renders
    the charts and stores the resulting selections via ``update_state``.
    """
    df = pd.read_csv("complete_data_final.csv")
    transformed_df = query_data(df)

    st.sidebar.image("shark.jpg")
    st.sidebar.write("AUSTRALIAN SHARK")
    st.sidebar.markdown("-----------------")

    options = st.sidebar.radio('Focus Area', options=['temperature', 'shark type', 'victims info'])

    # NOTE(review): the slider's return value is discarded, so it does not
    # filter anything yet.
    st.sidebar.slider("Temp", 0.0, 100.0, (25.0, 75.0))

    current_query = render_plotly_ui(transformed_df, options)
    update_state(current_query)

if __name__ == '__main__':
    # Page config is the first Streamlit command executed by the script.
    st.set_page_config(layout='wide')
    initialize_state()
    main()