I have some a piece of code on matplotlib plot, but found it was rendered very very slow(hours) in streamlit web page, while it finished only in some seconds in local pycharm.
Any one know what’s the cause ??
Paste my relevant code as below:
@st.cache
def plot_learning_curve(model, X, y, cv, train_sizes):
train_sizes, train_scores, test_scores = learning_curve(model, X=X, y=y, cv=cv, train_sizes=train_sizes)
train_scores_mean = np.mean(train_scores, axis=1)
train_scores_std = np.std(train_scores, axis=1)
test_scores_mean = np.mean(test_scores, axis=1)
test_scores_std = np.std(test_scores, axis=1)
fig, ax = plt.subplots(figsize=(12, 10))
ax.set_xlabel("Training examples", fontsize=15)
ax.set_ylabel("Accuracy", fontsize=15)
ax.set_title("Model learning curve", fontsize=20)
ax.grid()
ax.fill_between(train_sizes, train_scores_mean - train_scores_std,
train_scores_mean + train_scores_std, alpha=0.1,
color="r")
ax.fill_between(train_sizes, test_scores_mean - test_scores_std,
test_scores_mean + test_scores_std, alpha=0.1,
color="g")
ax.plot(train_sizes, train_scores_mean, 'o-', color="r",
label="Training score")
ax.plot(train_sizes, test_scores_mean, 'o-', color="g",
label="Test score")
ax.legend(loc="best", fontsize=14)
return fig
fig = plot_learning_curve(clf, X, y, cv, train_sizes)
st.pyplot(fig)