Hi everyone,
I was wondering if you can help me with this.
I’ve built a linear regression app using Streamlit and one of the methods the user can select is gradient descent.
For my example, I’m using simulated data, which makes it converge pretty fast. Also, I’ve added an animation to show the progress of the predicted line, as the algorithm converges, along with a progress bar.
But this repeated plotting, even when sampled down to at most 100 redraws (regardless of the number of epochs set for the algorithm), slows it down significantly.
Is there any way to speed up the animation, given that the algorithm itself runs in milliseconds?
Here’s the live app to try and play with Gradient Descent to see what I mean - https://liraweb.azurewebsites.net/
Here’s the code I’ve implemented for the Gradient Descent method.
Note that this includes the animation function: I searched this community for similar posts and reused the functionality proposed here by @Adrien_Treuille (How to animate a line chart). I didn't have to use the init() function, by the way, but I've initialised the plot correctly with no issues.
# ****************************************************************************************************
# Gradient Descent method to calculate coefficients a and b
# ****************************************************************************************************
def GD_method(rl, L, epochs):
    X = rl['X']
    y = rl['y']
    n = float(len(X))  # Number of elements in X
    y_pred = [0] * len(X)

    # Set up the figure for plotting and animating Gradient Descent
    fig, ax = plt.subplots()
    ax.plot(X, rl['y_act'], label='Actual (Population Regression Line)', color='green')
    ax.plot(X, y, 'ro', label='Collected data')
    ax.plot(X, y_pred, label='Predicted (Least Squares Line)', color='purple')
    ax.set_title('Actual vs Predicted')
    ax.set_xlabel('X')
    ax.set_ylabel('y')
    ax.legend()
    the_plot = st.pyplot(plt, clear_figure=False)

    #def init():
    #    pred_line.set_ydata([0]*len(X))

    def animate(i):  # Redraw the predicted line for sampled epoch i
        ax.plot(X, y_pred, label='Predicted (Least Squares Line)', color='purple')
        the_plot.pyplot(plt, clear_figure=False)

    # *****************************
    # Performing Gradient Descent
    # *****************************
    # Initialise a and b
    a = 0
    b = 0

    # Initialise progress bar
    my_bar = st.progress(0)
    status_text = st.empty()
    pb_i = max(1, round(epochs / 100))  # Sampling step; at least 1 to avoid modulo by zero

    #init()
    for i in range(epochs):
        y_pred = a * X + b  # The current predicted value of y
        D_a = (-2 / n) * sum(X * (y - y_pred))  # Derivative wrt a
        D_b = (-2 / n) * sum(y - y_pred)  # Derivative wrt b
        a = a - L * D_a  # Update a
        b = b - L * D_b  # Update b

        # Animate sampled plots as algorithm converges along with progress bar
        if (i % pb_i) == 0 and round(i / pb_i) < 101:
            animate(i)
            my_bar.progress(round(i / pb_i))

    status_text.text('Gradient Descent converged to the optimal values. Exiting...')
    print('a converged at', a, 'b converged at', b)
    return a, b, y_pred
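For anyone who wants to reproduce the timing without Streamlit or Matplotlib, the update loop and the "at most 100 samples" rule can be sketched on their own, roughly as below. This is a simplified, plain-Python illustration and not the app's code: `gd_snapshots`, `lr`, `max_frames` and the toy data are hypothetical names and values chosen for the example. It only records the sampled (epoch, a, b) values instead of drawing them, which should make it easier to confirm that the cost sits in the redraws and not in the maths.

```python
import random


def gd_snapshots(X, y, lr, epochs, max_frames=100):
    """Gradient descent for y = a*x + b that records sampled (epoch, a, b) tuples."""
    n = float(len(X))
    a = b = 0.0
    # Sampling step; at least 1 so that small epoch counts do not give a zero step
    step = max(1, epochs // max_frames)
    frames = []
    for i in range(epochs):
        y_pred = [a * x + b for x in X]
        # Partial derivatives of the mean squared error wrt a and b
        d_a = (-2 / n) * sum(x * (yi - pi) for x, yi, pi in zip(X, y, y_pred))
        d_b = (-2 / n) * sum(yi - pi for yi, pi in zip(y, y_pred))
        a -= lr * d_a
        b -= lr * d_b
        # Keep a snapshot instead of redrawing; cap the number of snapshots
        if i % step == 0 and len(frames) < max_frames:
            frames.append((i, a, b))
    return a, b, frames


# Toy data (assumed for illustration): y = 2x + 1 plus a little Gaussian noise
random.seed(0)
X = [i / 10 for i in range(50)]
y = [2.0 * x + 1.0 + random.gauss(0, 0.1) for x in X]

a, b, frames = gd_snapshots(X, y, lr=0.01, epochs=5000)
print(len(frames), round(a, 2), round(b, 2))
```

The recorded `frames` could then be replayed into the plot after the loop finishes, so the number of redraws is decoupled from the number of epochs.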
And here is the GitHub Repository of the whole app if you want to dig around - https://github.com/etzimopoulos/LiRA-Web-App
Any suggestions would be great.
Thank you all and have a nice day!