How can I combine streamlit_webrtc and asyncio to run inference in parallel with the audio buffering? Is there any example of how to do this?
I am currently trying to deploy an Essentia model for music genre recognition with the help of streamlit_webrtc. Basically, I want to capture the audio stream and, once a certain number of frames has been buffered (e.g. 5 seconds), run the prediction. Unfortunately this causes a lot of lag and delay, so I am trying to run the buffering and the inference in parallel with asyncio. However, this seems to block my pipeline, possibly because streamlit_webrtc uses asyncio itself. What is the proper way to do this? Here is my code:
import json
import time
import pydub
import queue
import socket
import joblib
import librosa
import asyncio
import logging
import numpy as np
import pandas as pd
import logging.handlers
import streamlit as st
import tensorflow as tf
from pathlib import Path
from pprint import pprint
from streamlit_webrtc import WebRtcMode, webrtc_streamer
from essentia.standard import TensorflowPredictEffnetDiscogs, TensorflowPredict2D
DEBUG = False
SAMPLE_RATE = 16000
HERE = Path(__file__).parent
logger = logging.getLogger(__name__)
# Disable eager mode for tf.v1 compatibility with tf.v2
tf.compat.v1.disable_eager_execution()
import warnings
warnings.filterwarnings('ignore')
def get_metadata(fname, section=None):
    with open(fname) as f:
        meta = json.load(f)
    if section:
        return meta[section]
    return meta
MODEL_LABELS = {"ESSENTIA_genre_discogs400-discogs-effnet-1-random-forrest": get_metadata("pretrained-models/genre_discogs400-discogs-effnet-1.json", "classes")}
def load_essentia_model(model, labels, n_frames=None, overlap=None, debug=True):
    embeddings_model_args = {"graphFilename": "pretrained-models/discogs-effnet-bs64-1.pb", "output": "PartitionedCall:1"}
    predictions_model_args = {"graphFilename": "pretrained-models/genre_discogs400-discogs-effnet-1.pb", "input": "serving_default_model_Placeholder", "output": "PartitionedCall:0"}
    embeddings_model = TensorflowPredictEffnetDiscogs(**embeddings_model_args)
    predictions_model = TensorflowPredict2D(**predictions_model_args)
    return embeddings_model, predictions_model
def predict_essentia(audio, embeddings_model, predictions_model):
    # Run the models
    #print("-> Audio to predict: ", audio.shape, audio.dtype, type(audio))
    embeddings = embeddings_model(audio)
    predictions = predictions_model(embeddings)
    return predictions
async def audio_collector(webrtc_ctx, sound_chunk_queue):
    """Asynchronously collect audio frames from webrtc context"""
    while webrtc_ctx.state.playing:
        try:
            audio_frames = webrtc_ctx.audio_receiver.get_frames(timeout=1)
        except queue.Empty:
            continue
        for audio_frame in audio_frames:
            sound = pydub.AudioSegment(
                data=audio_frame.to_ndarray().tobytes(),
                sample_width=audio_frame.format.bytes,
                frame_rate=audio_frame.sample_rate,
                channels=len(audio_frame.layout.channels)
            )
            # Put the sound into the queue
            await sound_chunk_queue.put(sound)
async def run_inference(sound_chunk, model, embeddings_model, predictions_model, random_forrest_model, text_output, json_output):
    start_process = time.time()
    # Audio processing
    print(f"Inference: Received sound_chunk of length {len(sound_chunk)}")
    sound_chunk = sound_chunk.set_channels(1).set_frame_rate(SAMPLE_RATE)
    musicnn_audio_buffer = np.array(sound_chunk.get_array_of_samples(), dtype=np.int16)
    float_musicnn_audio_buffer = musicnn_audio_buffer.astype(np.float32, order='C') / 32768.0
    # Run inference
    likelihoods = predict_essentia(float_musicnn_audio_buffer, embeddings_model, predictions_model)
    results_dict = dict(zip(MODEL_LABELS[model], np.round(np.mean(likelihoods, axis=0), 9)))
    # Aggregate results into merged categories
    probabilities = random_forrest_model.predict_proba(pd.DataFrame([results_dict]))
    genres_dict = {genre: prob for genre, prob in zip(random_forrest_model.classes_, probabilities.flatten())}
    # Display the top tag and the genres_dict using Streamlit
    maximum = max(results_dict, key=results_dict.get)
    text_output.markdown(f"**Top Tag:** {maximum, results_dict[maximum]}")
    json_output.write("Genres Dict:")
    json_output.json(genres_dict)
    # End timing
    end_process = time.time()
    print(f"> Inference done. Total processing time: {end_process - start_process:.4f}s")
    print(f"Genres Dict: {genres_dict}")
    print(f">===============================================================<")
async def process_audio(sound_chunk_queue, model, embeddings_model, predictions_model, random_forrest_model, text_output, json_output):
    sound_chunk = pydub.AudioSegment.empty()
    min_buffer_size = int(0.5 * SAMPLE_RATE)  # Note: len(AudioSegment) is in milliseconds, so this is 8000 ms at 16 kHz
    while True:
        sound = await sound_chunk_queue.get()  # Get sound from the queue
        # Debugging print: track the queue size after consuming a chunk
        print(f"ProcessAudio: Retrieved frame from queue, current queue size: {sound_chunk_queue.qsize()}")
        sound_chunk += sound
        print(f"Current sound_chunk length: {len(sound_chunk)}, required: {min_buffer_size}")
        # If we have buffered enough audio, run inference
        if len(sound_chunk) >= min_buffer_size:
            print(f"Running inference with sound chunk of length {len(sound_chunk)}")
            await run_inference(sound_chunk, model, embeddings_model, predictions_model, random_forrest_model, text_output, json_output)
            # Keep only the tail of the buffer as overlap for the next inference
            last_frames = int(-0.5 * min_buffer_size)
            sound_chunk = sound_chunk[last_frames:]
        else:
            # Force inference when no new data is coming in, even if the buffer threshold hasn't been reached
            if sound_chunk_queue.qsize() == 0 and len(sound_chunk) > int(1.5 * SAMPLE_RATE):
                print(f"Running forced inference with incomplete sound chunk of length {len(sound_chunk)}")
                await run_inference(sound_chunk, model, embeddings_model, predictions_model, random_forrest_model, text_output, json_output)
                sound_chunk = pydub.AudioSegment.empty()  # Clear buffer after inference
def app_sst(model):
    webrtc_ctx = webrtc_streamer(key="music-to-genre",
                                 mode=WebRtcMode.SENDONLY,
                                 audio_receiver_size=1024 * 8,
                                 rtc_configuration={"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]},
                                 media_stream_constraints={"video": False, "audio": True})
    status_indicator = st.empty()
    text_output = st.empty()
    json_output = st.empty()
    # Load model
    embeddings_model, predictions_model = load_essentia_model(model, MODEL_LABELS[model])
    random_forrest_model = joblib.load('pretrained-classifiers/random_forest_model.pkl')
    sound_chunk_queue = asyncio.Queue()

    async def run_tasks():
        audio_task = asyncio.create_task(audio_collector(webrtc_ctx, sound_chunk_queue))
        process_task = asyncio.create_task(process_audio(sound_chunk_queue, model, embeddings_model, predictions_model, random_forrest_model, text_output, json_output))
        # Run both tasks concurrently until `webrtc_ctx.state.playing` is False
        while webrtc_ctx.state.playing:
            await asyncio.sleep(0.1)  # Non-blocking sleep to keep checking the state
        # Gracefully cancel tasks when streaming stops
        audio_task.cancel()
        process_task.cancel()

    # Start event loop for tasks
    asyncio.run(run_tasks())
def main():
    st.set_page_config(layout="wide")
    st.header("Real Time Music Genre Classification")
    st.markdown("""This demo app is using musicnn as a recognition engine.""")
    sound_only_page = "Sound only demo"
    app_mode = st.selectbox("Choose the app mode", [sound_only_page])
    model = st.selectbox("Choose the model", ["ESSENTIA_genre_discogs400-discogs-effnet-1-random-forrest"])
    if app_mode == sound_only_page:
        app_sst(model)
if __name__ == "__main__":
    logging.basicConfig(format="[%(asctime)s] %(levelname)7s from %(name)s in %(pathname)s:%(lineno)d: %(message)s", force=True)
    logger.setLevel(level=logging.DEBUG if DEBUG else logging.INFO)
    st_webrtc_logger = logging.getLogger("streamlit_webrtc")
    st_webrtc_logger.setLevel(logging.DEBUG)
    fsevents_logger = logging.getLogger("fsevents")
    fsevents_logger.setLevel(logging.WARNING)
    main()
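
One direction I have been experimenting with is to keep the asyncio tasks but push the blocking work off the event loop with loop.run_in_executor, so audio_collector can keep draining frames while a prediction runs. As far as I can tell, predict_essentia is CPU-bound and webrtc_ctx.audio_receiver.get_frames(timeout=1) is a synchronous blocking call, so calling them directly inside coroutines stalls everything else on the same loop. This is only a rough, untested sketch of what I mean (the names blocking_inference, run_inference_async and inference_executor are mine):

import asyncio
import numpy as np
from concurrent.futures import ThreadPoolExecutor

# Single worker so predictions are serialized instead of piling up
inference_executor = ThreadPoolExecutor(max_workers=1)

def blocking_inference(sound_chunk, embeddings_model, predictions_model):
    # Same preprocessing as in run_inference above, kept synchronous on purpose
    sound_chunk = sound_chunk.set_channels(1).set_frame_rate(SAMPLE_RATE)
    samples = np.array(sound_chunk.get_array_of_samples(), dtype=np.int16)
    audio = samples.astype(np.float32, order='C') / 32768.0
    return predict_essentia(audio, embeddings_model, predictions_model)

async def run_inference_async(sound_chunk, embeddings_model, predictions_model):
    loop = asyncio.get_running_loop()
    # Hand the CPU-bound Essentia calls to a worker thread so the event loop
    # (shared by audio_collector and streamlit_webrtc) keeps running
    return await loop.run_in_executor(
        inference_executor, blocking_inference,
        sound_chunk, embeddings_model, predictions_model)

I assume the same trick could be applied to get_frames, e.g. await loop.run_in_executor(None, lambda: webrtc_ctx.audio_receiver.get_frames(timeout=1)). Is this the right way to combine the two, or am I fighting against streamlit_webrtc's own event loop here?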
I would also welcome any ideas to improve speed and performance.
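
For completeness, the other direction I considered is to drop asyncio from my own code entirely and do the buffering plus inference in a plain background thread, leaving the event loop to streamlit_webrtc. Another rough, untested sketch (inference_worker and the 5000 ms threshold are placeholders I made up):

import queue
import threading
import numpy as np
import pydub

def inference_worker(frame_queue, embeddings_model, predictions_model, stop_event):
    # Runs in a plain thread: accumulate chunks and run the blocking prediction here,
    # so the Streamlit script thread is never stalled by the model
    sound_chunk = pydub.AudioSegment.empty()
    while not stop_event.is_set():
        try:
            sound_chunk += frame_queue.get(timeout=1)
        except queue.Empty:
            continue
        if len(sound_chunk) >= 5000:  # pydub lengths are in milliseconds
            mono = sound_chunk.set_channels(1).set_frame_rate(SAMPLE_RATE)
            samples = np.array(mono.get_array_of_samples(), dtype=np.int16)
            audio = samples.astype(np.float32, order='C') / 32768.0
            predictions = predict_essentia(audio, embeddings_model, predictions_model)
            print("Predicted:", predictions.shape)
            sound_chunk = pydub.AudioSegment.empty()

# In app_sst I would then start the worker once and feed it from the main loop:
# stop_event = threading.Event()
# frame_queue = queue.Queue()
# threading.Thread(target=inference_worker,
#                  args=(frame_queue, embeddings_model, predictions_model, stop_event),
#                  daemon=True).start()
# while webrtc_ctx.state.playing:
#     feed frame_queue from webrtc_ctx.audio_receiver.get_frames(timeout=1)

Would that be the more idiomatic approach with streamlit_webrtc, or is mixing asyncio with run_in_executor (as in the first sketch) preferable?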