Graphviz plot

Hi, Streamlit is just great!
I'm running LightGBM and want to plot a tree, but I get the message below. The same code runs successfully in my Jupyter notebook. Can anyone advise where I went wrong? Thanks a lot!

ax = lgb.plot_tree(gbm, tree_index=3, figsize=(35, 15), show_info=['split_gain'])
st.graphviz_chart(ax)

Exception: Unhandled type for graphviz chart: <class 'matplotlib.axes._subplots.AxesSubplot'>

Hi @Haydn_He,
thanks a lot for posting! st.graphviz_chart() accepts a graphviz object (or a string in the DOT language) as its input. I think lgb.plot_tree returns a matplotlib object instead (graphviz docs). Correct me if I am wrong on this. You will need to create or obtain a graphviz object and then pass it to st.graphviz_chart().
The docs have some more details, but the tricky part is probably getting a graphviz object out of LightGBM… at least I am not an expert on that.

You can use st.pyplot() though. Just swap plt.show() with st.pyplot() and you should be able to visualize the output of lgb.plot_tree. Let me know if it does not work!

Best,
Matteo


@monchier,

I appreciate your answer. Yes, you are right: lgb.plot_tree actually returns a matplotlib object. However, st.pyplot just doesn't work; I will keep digging into this issue.

Regards

@monchier,

Thanks again for your help. I managed to get plot_tree() working, but the graph quality is bad, so I'm trying to render a Graphviz plot with st.graphviz_chart(lgb.create_tree_digraph(gbm, tree_index=1)). There are no errors, but nothing shows up. Any advice on what I missed would be appreciated.

Regards

Hey @Haydn_He,

I'm able to plot the graph with st.graphviz_chart, using the example from the LightGBM repo:

https://github.com/microsoft/LightGBM/blob/master/examples/python-guide/plot_example.py

# coding: utf-8
import lightgbm as lgb
import pandas as pd
import streamlit as st

if lgb.compat.MATPLOTLIB_INSTALLED:
    import matplotlib.pyplot as plt
else:
    raise ImportError('You need to install matplotlib for plot_example.py.')

# load or create your dataset
df_train = pd.read_csv('regression.train', header=None, sep='\t')
df_test = pd.read_csv('regression.test', header=None, sep='\t')

y_train = df_train[0]
y_test = df_test[0]
X_train = df_train.drop(0, axis=1)
X_test = df_test.drop(0, axis=1)

# create dataset for lightgbm
lgb_train = lgb.Dataset(X_train, y_train)
lgb_test = lgb.Dataset(X_test, y_test, reference=lgb_train)

# specify your configurations as a dict
params = {
    'num_leaves': 5,
    'metric': ('l1', 'l2'),
    'verbose': 0
}

evals_result = {}  # to record eval results for plotting

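# train
# note: LightGBM 4.0 removed the evals_result/verbose_eval arguments of lgb.train;
# the callbacks below record and log the metrics instead
gbm = lgb.train(params,
                lgb_train,
                num_boost_round=100,
                valid_sets=[lgb_train, lgb_test],
                feature_name=['f' + str(i + 1) for i in range(X_train.shape[-1])],
                categorical_feature=[21],
                callbacks=[lgb.record_evaluation(evals_result),
                           lgb.log_evaluation(10)])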

graph = lgb.create_tree_digraph(gbm, tree_index=53, name='Tree54')
st.graphviz_chart(graph)

@Jonathan_Rhone, excellent, thank you!

I eventually found out that this works on Linux, while on Windows nothing is displayed yet, but it's good that it works now.

Regards
