Problem about altair chart in streamlit

How can i fix this


to this

import numpy as np
from matplotlib import pyplot as plt
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
import streamlit as st
import pandas as pd
import altair as alt
N = 30
N_test = 20 
def main():
    #khởi tạo bộ tạo số ngẫu nhiên
    np.random.seed(100)
    X_true = np.linspace(0, 5, 51)
    y_true = 3*(X_true -2) * (X_true - 3)*(X_true-4)
    
    X = np.random.rand(N, 1)*5
    y = 3*(X -2) * (X - 3)*(X-4) +  10*np.random.randn(N, 1)
    X = X.tolist()
    y=y.tolist()
    X_test = (np.random.rand(N_test,1) - 1/8) *10
    y_test = 3*(X_test -2) * (X_test - 3)*(X_test-4) +  10*np.random.randn(N_test, 1)
    c = (
        alt.Chart(pd.DataFrame({"x":X,"y":y}))
        .mark_circle()
        .encode(
            x="x",
            y="y",
            
        )
    )
    z2="blue"
    line1 = (
    alt.Chart(pd.DataFrame({"x": X_true, "y": y_true,"z":z2}))
    .mark_line()
    .encode(x="x:Q", y="y:Q",color = alt.Color("z", scale=None),)
    )
    
    poly_features = PolynomialFeatures(degree=2, include_bias=False)
    X_poly = poly_features.fit_transform(X)
    lin_reg = LinearRegression()
    lin_reg.fit(X_poly, y)
    
    st.write(lin_reg.intercept_, lin_reg.coef_)
    w0 = lin_reg.intercept_[0]
    w1 = lin_reg.coef_[0,0]
    w2 = lin_reg.coef_[0,1]
    
    y_predict = w0 + X_true*w1 + X_true**2*w2
    z3="yellow"
    line2 = (
    alt.Chart(pd.DataFrame({"x": X_true, "y": y_predict,"z":z3}))
    .mark_line()
    .encode(x="x:Q", y="y:Q",color = alt.Color("z", scale=None),)
    )

    #tinh sai so tren tap test
    X_test_poly = poly_features.fit_transform(X_test)
    y_test_predict = lin_reg.predict(X_test_poly)
    mse_test = mean_squared_error(y_test, y_test_predict)
    rmse_test = np.sqrt(mse_test)       
    st.write('Sai so binh phuong trung binh - test: ')
    st.write('%.4f' % rmse_test)

    st.altair_chart(c+line1+line2, use_container_width=True)
    
if __name__ == '__main__':
    main()

This topic was automatically closed 365 days after the last reply. New replies are no longer allowed.