Hi all,
I’m trying to make a heatmap that could be queried, meaning that when user clicks on a heatmap cell, it applies a filter on the data therefore changing the heatmap itself.
The issue comes when I try to resize the height of the heatmap depending on the number of remaining lines in the plot. I have tried a bunch of differents ways to do this but none seem to work, here is the code I’m running on React side :
import React, { useEffect, useState, ReactNode } from "react";
import { Streamlit, withStreamlitConnection, ComponentProps, StreamlitComponentBase } from "streamlit-component-lib";
import Plot from "react-plotly.js"
class CustomHeat extends StreamlitComponentBase {
click_num= this.props.args.click_num; // Create a new variable of type number and initialize it to 0
public render = (): ReactNode => {
// let totalHeight: number = this.props.args.height;
const { data, layout, frames, config } = JSON.parse(this.props.args.spec)
const override_height = this.props.args["override_height"];
Streamlit.setFrameHeight(override_height);
return (
<Plot
data={data}
layout={layout}
frames={frames}
config={config}
onClick={this.handleClick}
style={{width: '100%', height: override_height}}
/>
)
}
private handleClick = (eventData: any) => {
this.click_num += 1;
Streamlit.setComponentValue(
eventData.points.map((p: any) => {
return { index: p.pointIndex, x: p.x, y: p.y, click_num: this.click_num}
})
)
}
}
On python side, in init.py :
import os
import streamlit.components.v1 as components
# Define location of the packaged frontend build
parent_dir = os.path.dirname(os.path.abspath(__file__))
build_dir = os.path.join(parent_dir, "frontend/build")
# Create a function _component_func which will call the frontend component when run
_component_func = components.declare_component(
"custom_heatmap"
# , url="http://localhost:3001" # Fetch frontend component from local webserver
, path=build_dir
)
# Define a public function for the package,
# which wraps the caller to the frontend code
def st_custom_plot(fig, click_num, override_height, key):
component_value = _component_func(spec=fig.to_json(), click_num=click_num, override_height=override_height, key=key)
return component_value
On python side when actually using the component :
# Define a function to remove an element from the session state
def remove_element_from_session_state():
if st.session_state.heatmap_filter is not None:
st.success(f"'{st.session_state.heatmap_filter}' removed from filters.")
st.session_state.heatmap_filter = None
# st.session_state.heatmap[0]["y"] = None
st.session_state.has_been_clicked = False
else:
st.warning(f"No filter applied.")
def get_heatmap_data(df, y_col):
# Format the table used for the heatmap
heatmap_data = df.loc[:, ["Country", "Cat", y_col]]
heatmap_data = heatmap_data.groupby(["Country", "Cat"], as_index=False).mean()
heatmap_pivot = heatmap_data.pivot(index='Country', columns='Cat', values=y_col)
return heatmap_pivot
def get_fig_heatmap(heatmap_pivot_filtered, num_height):
layout = go.Layout(
# paper_bgcolor='rgba(0,0,0,0)',
plot_bgcolor='rgba(217,217,214,0.1)',
title='Heatmap',
xaxis=dict(title='Criticality'),
yaxis=dict(title='Country'),
height=int(num_height*50), width=650
)
fig = go.Figure(data=go.Heatmap(
z=heatmap_pivot_filtered.values,
x=heatmap_pivot_filtered.columns,
y=heatmap_pivot_filtered.index,
type='heatmap',
hoverongaps = False,
colorscale= [
[0.0, "#FDFD96"], # pastel yellow
[1.0, "#FF6961"] # pastel red
]
,
# colorscale='gray_r',
zmin=heatmap_pivot_filtered.values.min(),
zmax=heatmap_pivot_filtered.values.max()), layout=layout)
# fig.update_layout(height=50*num_height)
return fig
def session_print(x):
if x in st.session_state:
st.write(st.session_state[x])
else:
st.write(f"{x} is not yet defined")
def update_heatmap_vars():
if "heatmap" in st.session_state: # if the heatmap has been rendered at least once
if st.session_state.heatmap is not None: # if the heatmap has been clicked on at least once
if st.session_state.click_num < st.session_state.heatmap[0]["click_num"]:
st.session_state.click_num = st.session_state.heatmap[0]["click_num"]
st.session_state.heatmap_filter = st.session_state.heatmap[0]["y"]
# else:
# st.session_state.heatmap_filter = None
else:
st.session_state.heatmap_filter = None
else:
st.session_state.heatmap_filter = None
def filter_heatmap_data(heatmap_pivot):
if "heatmap" in st.session_state: # if the heatmap has been rendered at least once
if st.session_state.heatmap is not None: #if the heatmap has been clicked on at least once
if st.session_state.heatmap_filter is not None:
heatmap_pivot_filtered = heatmap_pivot.loc[heatmap_pivot.index == st.session_state.heatmap_filter, :]
else:
heatmap_pivot_filtered = heatmap_pivot.copy()
else:
heatmap_pivot_filtered = heatmap_pivot.copy()
else:
heatmap_pivot_filtered = heatmap_pivot.copy()
return heatmap_pivot_filtered
def filter_all_data(df):
if "heatmap" in st.session_state:
if st.session_state.heatmap is not None:
if st.session_state.heatmap_filter is not None:
country = st.session_state.heatmap[0]["y"]
df = df.loc[df.Country == country, :]
return df
def display_heatmap(df, y_col):
# Transform data to the right format
heatmap_pivot = get_heatmap_data(df, y_col)
update_heatmap_vars()
# Button to remove the applied filter
if st.button("Remove Heatmap Filter", key="remove_heatmap"):
remove_element_from_session_state()
# Filter the data
heatmap_pivot_filtered = filter_heatmap_data(heatmap_pivot)
num_height_ticks = heatmap_pivot_filtered.shape[0]
if num_height_ticks != 0:
# Define the heatmap fig
fig = get_fig_heatmap(heatmap_pivot_filtered, num_height=num_height_ticks)
# Send and receive plot from React side
if "heatmap" in st.session_state:
del st.session_state.heatmap
st_custom_plot(fig, click_num=st.session_state.click_num,
override_height=int(num_height_ticks*50) , key="heatmap")
else:
st.warning("Current filter selection leads to no data")
# Apply filters
df = filter_all_data(df)
return df
This code may actually be useful for anyone trying to have a bi-directional plotly heatmap (it was inspired by many other similar use case).
Here is what the heatmap looks like at the end before filtering (many blanks for privacy purposes):
And after filtering without all the code related to num_height and override_height.:
If I now try to resize the fig in the same manner as I’m doing in the script above, the plot is fine at first rendering, however when I try to filter the lines and therefore change the height of the fig, its starts bugging. The plot keeps switching between new height and the old height at a high rate but it can’t seem to decide which one to choose.
Does anyone have any idea why ? Or just in which direction I should research ?
Sorry if the code isn’t perfect, it’s my first streamlit app and I still have a lot to learn !