How can I animate a line chart for Gradient Descent?

Hi everyone,

I was wondering if you could help me with this.

I’ve built a linear regression app using Streamlit, and one of the methods the user can select is gradient descent.

For my example I’m using simulated data, which converges pretty fast. I’ve also added an animation that shows the predicted line moving as the algorithm converges, along with a progress bar.

However, this repeated plotting slows it down significantly, even though I sample it to a maximum of 100 frames regardless of the number of epochs set for convergence.

Is there any way to speed up the animation, given that the algorithm itself runs in milliseconds?
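Since the gradient descent itself is cheap, one option is to run it to completion first, recording a bounded number of `(a, b)` snapshots, and only then replay those snapshots with whichever plotting library draws fastest. A minimal NumPy sketch of that separation, assuming the same update rule as `GD_method`; the names `gd_snapshots` and `max_frames` are hypothetical:

```python
import math

import numpy as np


def gd_snapshots(X, y, L, epochs, max_frames=100):
    """Batch gradient descent for y ~ a*X + b, keeping at most max_frames + 1 (a, b) snapshots.

    gd_snapshots and max_frames are illustrative names; the update rule mirrors GD_method.
    """
    X = np.asarray(X, dtype=float)
    y = np.asarray(y, dtype=float)
    n = len(X)
    step = max(1, math.ceil(epochs / max_frames))  # record one snapshot every `step` epochs
    a, b = 0.0, 0.0
    snapshots = []
    for i in range(epochs):
        if i % step == 0:
            snapshots.append((a, b))  # parameters before this epoch's update
        residual = y - (a * X + b)
        D_a = (-2 / n) * np.dot(X, residual)  # derivative wrt a
        D_b = (-2 / n) * residual.sum()  # derivative wrt b
        a -= L * D_a
        b -= L * D_b
    snapshots.append((a, b))  # always include the final fit
    return a, b, snapshots
```

With this split, the number of redraws depends only on `max_frames`, not on the number of epochs, and the math finishes before any chart is drawn.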

Here’s the live app if you want to play with Gradient Descent and see what I mean: https://liraweb.azurewebsites.net/

Here’s the code I’ve implemented for the Gradient Descent method.

Note that this includes the animation function. I searched this community for similar posts and reused the approach proposed by @Adrien_Treuille (How to animate a line chart). By the way, I didn’t need the init() function; initialising the plot directly works with no issues.

```python
import matplotlib.pyplot as plt
import numpy as np
import streamlit as st

# ****************************************************************************************************
# Gradient Descent method to calculate coefficients a and b
# ****************************************************************************************************

def GD_method(rl, L, epochs):
    X = rl['X']
    y = rl['y']
    n = float(len(X))  # Number of elements in X
    y_pred = np.zeros(len(X))

    # Set up the figure for plotting and animating Gradient Descent
    fig, ax = plt.subplots()
    ax.plot(X, rl['y_act'], label='Actual (Population Regression Line)', color='green')
    ax.plot(X, y, 'ro', label='Collected data')
    # Keep a handle on the predicted line so each frame updates it instead of adding a new line
    (pred_line,) = ax.plot(X, y_pred, label='Predicted (Least Squares Line)', color='purple')
    ax.set_title('Actual vs Predicted')
    ax.set_xlabel('X')
    ax.set_ylabel('y')
    ax.legend()
    the_plot = st.pyplot(fig, clear_figure=False)

    def animate(y_values):
        # Update the existing line, rescale the axes and redraw the figure in the same slot
        pred_line.set_ydata(y_values)
        ax.relim()
        ax.autoscale_view()
        the_plot.pyplot(fig, clear_figure=False)

    # *****************************
    # Performing Gradient Descent
    # *****************************
    # Initialise a and b
    a = 0.0
    b = 0.0
    # Initialise progress bar
    my_bar = st.progress(0)
    status_text = st.empty()
    # Plot roughly every epochs/100 iterations; max() avoids a zero step when epochs < 100
    pb_i = max(1, epochs // 100)
    for i in range(epochs):
        y_pred = a * X + b  # The current predicted value of y
        D_a = (-2 / n) * np.sum(X * (y - y_pred))  # Derivative wrt a
        D_b = (-2 / n) * np.sum(y - y_pred)  # Derivative wrt b
        a = a - L * D_a  # Update a
        b = b - L * D_b  # Update b

        # Animate sampled plots as the algorithm converges, along with the progress bar
        if i % pb_i == 0 and i // pb_i <= 100:
            animate(y_pred)
            my_bar.progress(i // pb_i)
    my_bar.progress(100)
    status_text.text('Gradient Descent converged to the optimal values. Exiting...')
    print('a converged at', a, 'b converged at', b)

    return a, b, y_pred
```

And here is the GitHub repository for the whole app if you want to dig around: https://github.com/etzimopoulos/LiRA-Web-App

Any suggestions would be great.

Thank you all and have a nice day!

Hey @etzimopoulos,

For this animation task, I expect you will have more luck using Altair to plot the graph rather than Matplotlib. A similar conclusion was reached in this benchmark by another user.
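A rough way to run that kind of comparison locally is to time the gradient descent updates separately from rendering one image per frame. This sketch assumes, as a simplified model, that each Matplotlib update in Streamlit costs about one `fig.savefig` to PNG on the server; `time_gd_vs_png` and its parameters are hypothetical names:

```python
import io
import time

import matplotlib

matplotlib.use("Agg")  # headless backend: render to memory, no window
import matplotlib.pyplot as plt
import numpy as np


def time_gd_vs_png(n_points=1000, frames=100, L=0.01):
    """Return (seconds for `frames` gradient descent updates, seconds for `frames` PNG renders)."""
    rng = np.random.default_rng(0)
    X = rng.uniform(0, 10, n_points)
    y = 3 * X + 2 + rng.normal(0, 1, n_points)

    # Gradient descent updates only (same update rule as GD_method)
    start = time.perf_counter()
    a = b = 0.0
    for _ in range(frames):
        residual = y - (a * X + b)
        a += L * (2 / n_points) * np.dot(X, residual)
        b += L * (2 / n_points) * residual.sum()
    gd_seconds = time.perf_counter() - start

    # One PNG per frame, roughly what an image-based chart has to produce for each update
    fig, ax = plt.subplots()
    ax.plot(X, y, "ro")
    (line,) = ax.plot(X, a * X + b, color="purple")
    start = time.perf_counter()
    for _ in range(frames):
        line.set_ydata(a * X + b)
        fig.savefig(io.BytesIO(), format="png")
    png_seconds = time.perf_counter() - start
    plt.close(fig)
    return gd_seconds, png_seconds
```

Calling `print(time_gd_vs_png())` shows how the two costs compare on your machine; the network transfer of each image comes on top of the render time.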

This is mainly because, for Matplotlib, Streamlit creates an image of each frame on the server, sends it to the browser and then discards it, whereas for Altair it only sends the chart specification and data as JSON and lets the browser do the rendering.
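A minimal sketch of the Altair approach, assuming the `(a, b)` values have already been collected into a list of snapshots (for example every few epochs of the existing loop). `st.empty()`, `altair_chart` and layering charts with `+` are standard Streamlit/Altair calls; the function names and the `delay` parameter are hypothetical:

```python
import time

import numpy as np
import pandas as pd


def prediction_frame(X, a, b):
    """DataFrame with the predicted line y = a*X + b for one animation frame."""
    X = np.asarray(X, dtype=float)
    return pd.DataFrame({"X": X, "y": a * X + b})


def y_domain(X, y, snapshots):
    """(min, max) over the observed y and every predicted line, so the axis stays fixed."""
    X = np.asarray(X, dtype=float)
    preds = [a * X + b for a, b in snapshots]
    all_y = np.concatenate([np.asarray(y, dtype=float), *preds])
    return float(all_y.min()), float(all_y.max())


def animate_with_altair(X, y, snapshots, delay=0.05):
    """Replay (a, b) snapshots as an Altair chart in a single Streamlit placeholder."""
    # Imported here only so the helpers above work without Streamlit/Altair installed
    import altair as alt
    import streamlit as st

    data = pd.DataFrame({"X": np.asarray(X, dtype=float), "y": np.asarray(y, dtype=float)})
    y_scale = alt.Scale(domain=list(y_domain(X, y, snapshots)))
    points = alt.Chart(data).mark_point(color="red").encode(x="X", y=alt.Y("y", scale=y_scale))

    placeholder = st.empty()  # each frame overwrites this one slot
    for a, b in snapshots:
        line = (
            alt.Chart(prediction_frame(X, a, b))
            .mark_line(color="purple")
            .encode(x="X", y=alt.Y("y", scale=y_scale))
        )
        placeholder.altair_chart(points + line, use_container_width=True)
        time.sleep(delay)  # give the browser time to draw each frame
```

In the app this would be called with `rl['X']`, `rl['y']` and the snapshot list; each frame then ships a chart spec plus a small data table instead of a rendered image.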

You could also try Plotly, but I found that repeatedly remounting the Plotly chart takes noticeably longer than with Altair. Also, from memory, Streamlit doesn’t use Plotly’s update_traces to update part of an existing chart, so every update recreates the whole chart.
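A different Plotly route, sketched here as an untested assumption rather than something from this app: build all frames up front and let Plotly’s own client-side animation (a `Play` button using the `animate` method) replay them, so the chart is sent and mounted only once. The figure is built as a plain dict in Plotly’s JSON format; `plotly_animation_figure` and `frame_ms` are hypothetical names:

```python
import numpy as np


def plotly_animation_figure(X, y, snapshots, frame_ms=50):
    """Plotly figure dict whose frames replay the predicted line for each (a, b) snapshot."""
    X = np.asarray(X, dtype=float)
    y = np.asarray(y, dtype=float)
    preds = [a * X + b for a, b in snapshots]
    all_y = np.concatenate([y, *preds])
    x_list = X.tolist()
    first = preds[0] if preds else np.zeros_like(X)

    return {
        "data": [
            {"type": "scatter", "mode": "markers", "x": x_list, "y": y.tolist(),
             "name": "Collected data", "marker": {"color": "red"}},
            {"type": "scatter", "mode": "lines", "x": x_list, "y": first.tolist(),
             "name": "Predicted", "line": {"color": "purple"}},
        ],
        "layout": {
            # Fixed ranges so the axes stay put while the frames play
            "xaxis": {"range": [float(X.min()), float(X.max())]},
            "yaxis": {"range": [float(all_y.min()), float(all_y.max())]},
            "updatemenus": [{
                "type": "buttons",
                "buttons": [{
                    "label": "Play",
                    "method": "animate",
                    "args": [None, {"frame": {"duration": frame_ms, "redraw": False},
                                    "fromcurrent": True}],
                }],
            }],
        },
        # Each frame only replaces trace 1 (the predicted line)
        "frames": [
            {"name": str(k), "traces": [1],
             "data": [{"type": "scatter", "x": x_list, "y": p.tolist()}]}
            for k, p in enumerate(preds)
        ],
    }
```

In Streamlit it could be displayed with `st.plotly_chart(go.Figure(fig_dict))` after `import plotly.graph_objects as go`. The trade-off is that the user presses Play instead of watching the animation run alongside the computation.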

Other, harder ideas to test:

  • This would actually make a pretty good Streamlit Component: an animated Plotly chart that only updates its traces, rather than remounting the full graph, when it receives new data.
  • MAYBE you could try the Streamlit Echarts component: use a key parameter to prevent the component from remounting, and send it the updated data to animate it (see the end of the README). It’s basically the same idea as the first bullet point, but implemented in ECharts.

Fanilo


Hi @andfanilo,

Thank you for your comments, very interesting. I’ll check the links, do my homework and see if any of these work.

I haven’t come across Altair before so will give it a go and see.

Thanks,
Angelo


Hi @andfanilo, just as an update, I’ve now made some progress on that and Altair works really nicely.

I’ve stumbled across another issue, though, that prevents me from selecting the animation method. I’ve implemented it so I can select Matplotlib, Altair, Plotly or No Animation (to run quickly).

I’ve raised the issue here; could you help shed some light on it?

I’ve followed an approach similar to the one in the thread here: I added a “Predict” button for the Gradient Descent method, then a selectbox for the plotting method, and then another “Go” button to run the animation. The Predict --> Select --> Go combination doesn’t work: only the first option in the selectbox works, and choosing any other option takes me back to the Predict button, as if it had never been clicked.
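A likely cause is that `st.button` returns `True` only on the rerun triggered by the click; changing the selectbox triggers another rerun, the Predict button is then `False`, and everything nested under it disappears. A minimal sketch of the usual fix, assuming a Streamlit version that provides `st.session_state`: remember the Predict click in session state. `remember_click`, `gradient_descent_controls` and the `"predict_clicked"` key are hypothetical names:

```python
def remember_click(state, key, clicked):
    """Latch a button click in `state` (e.g. st.session_state) so it survives later reruns."""
    if clicked:
        state[key] = True
    return state.get(key, False)


def gradient_descent_controls():
    """Predict -> select plotting method -> Go, with the Predict click remembered across reruns."""
    import streamlit as st  # imported here so remember_click can be tested without Streamlit

    # st.button is True only on the rerun caused by the click, so store it in session state
    if remember_click(st.session_state, "predict_clicked", st.button("Predict")):
        method = st.selectbox("Plotting method", ["Matplotlib", "Altair", "Plotly", "No Animation"])
        if st.button("Go"):
            return method  # the caller runs Gradient Descent with this animation method
    return None
```

In the app script, `method = gradient_descent_controls()` followed by `if method: ...` would then start the chosen animation only after Go is pressed.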

What am I doing wrong here?