Hi, Streamlit is just great!
I’m running LightGBM and want to plot a tree, but I get the message below, even though the same code runs successfully in my Jupyter notebook. Can anyone advise where I went wrong? Thanks a lot!
ax = lgb.plot_tree(gbm, tree_index=3, figsize=(35, 15), show_info=['split_gain'])
st.graphviz_chart(ax)
Exception: Unhandled type for graphviz chart: <class 'matplotlib.axes._subplots.AxesSubplot'>
Hi @Haydn_He,
Thanks a lot for posting! st.graphviz_chart() accepts a graphviz object as its input parameter. I think lgb.plot_tree returns a matplotlib object (graphviz docs). Correct me if I am wrong on this. You will need to create or obtain a graphviz object and then pass it to st.graphviz_chart().
The docs have some more details, but the tricky part is probably getting a graphviz object out of LightGBM… at least I am not an expert on that.
You can use st.pyplot() though. Just swap plt.show() with st.pyplot() and you should be able to visualize the output of lgb.plot_tree. Let me know if it does not work!
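For what it's worth, here is a minimal sketch of the st.pyplot() route. It assumes gbm is an already trained LightGBM model, as in your snippet. The names figure_of and show_tree are hypothetical helpers, not part of Streamlit or LightGBM. Instead of relying on the current global pyplot figure, the sketch passes the Figure that owns the returned Axes to st.pyplot() explicitly.

```python
def figure_of(ax):
    """Return the matplotlib Figure that owns the given Axes."""
    # Every matplotlib Axes keeps a reference to its parent Figure in `.figure`.
    return ax.figure


def show_tree(gbm, tree_index=3):
    """Draw one tree of a trained LightGBM model in a Streamlit app."""
    # Imported here so that figure_of() can be used without these packages installed.
    import lightgbm as lgb
    import streamlit as st

    # plot_tree returns a matplotlib Axes, which st.graphviz_chart cannot handle.
    ax = lgb.plot_tree(gbm, tree_index=tree_index, figsize=(35, 15),
                       show_info=["split_gain"])
    # st.pyplot expects a Figure, so hand it the parent Figure of the Axes.
    st.pyplot(figure_of(ax))
```

In the app you would call show_tree(gbm) once the model is trained.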
Best,
Matteo
@monchier,
I appreciate your answer. Yes, you are right, the object returned by lgb.plot_tree is actually a matplotlib object. However, st.pyplot just doesn’t work; I will keep digging into this issue.
Regards
@monchier,
Thanks again for your help. I managed to get plot_tree() working, but the graph quality is bad. I’m now trying to render a graphviz plot with st.graphviz_chart(lgb.create_tree_digraph(gbm, tree_index=1)). There are no errors, but nothing shows up. Any advice on what I missed would be appreciated.
Regards
Hey @Haydn_He,
I’m able to plot the graph with st.graphviz_chart using the example from the LightGBM repo:
https://github.com/microsoft/LightGBM/blob/master/examples/python-guide/plot_example.py
# coding: utf-8
import lightgbm as lgb
import pandas as pd
import streamlit as st
if lgb.compat.MATPLOTLIB_INSTALLED:
    import matplotlib.pyplot as plt
else:
    raise ImportError('You need to install matplotlib for plot_example.py.')
# load or create your dataset
df_train = pd.read_csv('regression.train', header=None, sep='\t')
df_test = pd.read_csv('regression.test', header=None, sep='\t')
y_train = df_train[0]
y_test = df_test[0]
X_train = df_train.drop(0, axis=1)
X_test = df_test.drop(0, axis=1)
# create dataset for lightgbm
lgb_train = lgb.Dataset(X_train, y_train)
lgb_test = lgb.Dataset(X_test, y_test, reference=lgb_train)
# specify your configurations as a dict
params = {
    'num_leaves': 5,
    'metric': ('l1', 'l2'),
    'verbose': 0
}
evals_result = {}  # to record eval results for plotting
# train
gbm = lgb.train(params,
                lgb_train,
                num_boost_round=100,
                valid_sets=[lgb_train, lgb_test],
                feature_name=['f' + str(i + 1) for i in range(X_train.shape[-1])],
                categorical_feature=[21],
                evals_result=evals_result,
                verbose_eval=10)
graph = lgb.create_tree_digraph(gbm, tree_index=53, name='Tree54')
st.graphviz_chart(graph)
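If the chart comes up blank, one thing that may help narrow it down is to look at the DOT text behind the Digraph. This is only a sketch: dot_source and show_digraph are hypothetical helper names, and I have not verified that this changes anything on a setup where the chart stays empty. It relies on st.graphviz_chart also accepting a DOT string, and on graphviz objects exposing their DOT text as .source.

```python
def dot_source(graph):
    """Return DOT text for a graphviz object, or the string itself."""
    # graphviz.Graph and graphviz.Digraph expose their DOT text as `.source`.
    return graph if isinstance(graph, str) else graph.source


def show_digraph(graph):
    """Display a graph in Streamlit and echo its DOT text for inspection."""
    import streamlit as st  # imported here so dot_source() needs no extra packages

    dot = dot_source(graph)
    st.code(dot)            # shows the raw DOT text, useful if the chart is blank
    st.graphviz_chart(dot)  # graphviz_chart accepts a DOT string as well
```

With the script above, that would be show_digraph(graph) in place of the last line.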
@Jonathan_Rhone, excellent, thank you!
I eventually found out that this works on Linux, while on Windows still nothing is displayed, but it’s good that it works now.
Regards