diff --git a/setup-venv-local-build.sh b/setup-venv-local-build.sh index 7ccbbc0..e2dcf1d 100755 --- a/setup-venv-local-build.sh +++ b/setup-venv-local-build.sh @@ -57,7 +57,7 @@ fi echo "==> upgrading pip + build tools" "${VENV}/bin/pip" install --upgrade pip wheel setuptools pybind11 --quiet -"${VENV}/bin/pip" install torch silero-vad +"${VENV}/bin/pip" install torch silero-vad websockets # --- clone (skipped if already done) --- if [ ! -d "${BUILD_DIR}/src/.git" ]; then diff --git a/setup-venv.sh b/setup-venv.sh index c017c9d..356b6cc 100755 --- a/setup-venv.sh +++ b/setup-venv.sh @@ -25,7 +25,7 @@ fi echo "==> installing torch and faster-whisper" "${VENV}/bin/pip" install --upgrade pip --quiet -"${VENV}/bin/pip" install torch faster-whisper silero-vad +"${VENV}/bin/pip" install torch faster-whisper silero-vad websockets echo "" echo "==> done. Venv ready at ${VENV}" diff --git a/stt-server.py b/stt-server.py index 8e7041a..38b994b 100755 --- a/stt-server.py +++ b/stt-server.py @@ -1,18 +1,25 @@ #!/usr/bin/env -S bash -c 'exec "$(dirname "$0")/venv/bin/python3" "$0" "$@"' """ -STT process: records audio, runs Silero VAD, transcribes with faster-whisper. +STT server: records audio, runs Silero VAD, transcribes with faster-whisper. +Broadcasts JSON events to all connected WebSocket clients and to stdout. -Events (JSON lines on stdout): +Events: {"event": "ready"} {"event": "vad_start"} - {"event": "vad_end", "duration": 1.23} - {"event": "transcript", "text": "...", "words": [...], "duration": 1.23} - {"event": "error", "message": "..."} + {"event": "vad_end", "duration": 1.23} + {"event": "transcript", "text": "...", "words": [...], "duration": 1.23} + {"event": "error", "message": "..."} word format: {"word": "hello", "start": 0.12, "end": 0.45, "probability": 0.99} +Every WebSocket connection receives the full event stream from the moment it +connects — no subscription handshake required. + All log/status messages go to stderr. Stdout is machine-readable events only. +Environment: + STT_PORT WebSocket port (default: 11501) + Usage: ./stt-server.py ./stt-server.py --model large-v3 --device cuda --compute-type int8_float16 @@ -26,13 +33,16 @@ import threading import queue import subprocess import traceback +import asyncio +import websockets import numpy as np import torch -SAMPLE_RATE = 16000 -VAD_WINDOW = 512 # samples per VAD chunk (32ms at 16kHz) -PRE_ROLL_SAMPLES = 3200 # 0.2s of audio prepended to each segment -HISTORY_SAMPLES = 960000 # 60s ring buffer for pre-roll +SAMPLE_RATE = 16000 +VAD_WINDOW = 512 # samples per VAD chunk (32ms at 16kHz) +PRE_ROLL_SAMPLES = 3200 # 0.2s prepended to each segment for context +HISTORY_SAMPLES = 960000 # 60s ring buffer for pre-roll +PORT = int(__import__('os').environ.get('STT_PORT', 11501)) def log(msg): @@ -40,10 +50,49 @@ def log(msg): sys.stderr.flush() -def emit(event): - sys.stdout.write(json.dumps(event) + '\n') - sys.stdout.flush() +# --- WebSocket broadcast --- +_ws_loop = None +_ws_clients = set() # set of asyncio.Queue, one per connection + + +def emit(event): + line = json.dumps(event) + sys.stdout.write(line + '\n') + sys.stdout.flush() + if _ws_loop is not None: + for q in list(_ws_clients): + _ws_loop.call_soon_threadsafe(q.put_nowait, line) + + +async def ws_handler(websocket): + q = asyncio.Queue() + _ws_clients.add(q) + log(f'client connected ({len(_ws_clients)} total)') + try: + while True: + msg = await q.get() + await websocket.send(msg) + except websockets.ConnectionClosed: + pass + finally: + _ws_clients.discard(q) + log(f'client disconnected ({len(_ws_clients)} remaining)') + + +async def ws_main(): + global _ws_loop + _ws_loop = asyncio.get_running_loop() + async with websockets.serve(ws_handler, '', PORT): + log(f'WebSocket listening on port {PORT}') + await asyncio.Future() # run forever + + +def start_ws_server(): + asyncio.run(ws_main()) + + +# --- Mic --- def find_mic(): candidates = [ @@ -63,6 +112,8 @@ def s16le_to_f32(data): return np.frombuffer(data, dtype=np.int16).astype(np.float32) / 32768.0 +# --- Args + model loading --- + parser = argparse.ArgumentParser() parser.add_argument('--model', default='base.en') parser.add_argument('--device', default='cuda') @@ -82,25 +133,25 @@ except Exception as e: log('loading silero VAD...') from silero_vad import load_silero_vad, VADIterator vad_model = load_silero_vad() -vad = VADIterator(vad_model, sampling_rate=SAMPLE_RATE, - threshold=0.5, min_silence_duration_ms=500) +vad = VADIterator(vad_model, sampling_rate=SAMPLE_RATE, + threshold=0.5, min_silence_duration_ms=500) log('VAD ready') -# Ring buffer for pre-roll context +# --- Pre-roll ring buffer --- + history = np.zeros(HISTORY_SAMPLES, dtype=np.float32) history_pos = 0 def push_history(samples): global history_pos - n = len(samples) - base = history_pos % HISTORY_SAMPLES - # May wrap around — handle both cases + n = len(samples) + base = history_pos % HISTORY_SAMPLES space = HISTORY_SAMPLES - base if n <= space: history[base:base + n] = samples else: - history[base:] = samples[:space] + history[base:] = samples[:space] history[:n - space] = samples[space:] history_pos += n @@ -113,7 +164,8 @@ def get_preroll(): return out -# Transcription runs in a separate thread so VAD is never blocked by GPU +# --- Transcription thread --- + transcription_queue = queue.Queue() def transcription_worker(): @@ -152,9 +204,11 @@ def transcription_worker(): threading.Thread(target=transcription_worker, daemon=True).start() +threading.Thread(target=start_ws_server, daemon=True).start() -# Main recording + VAD loop +# --- Main recording + VAD loop --- + cmd, cmd_args = find_mic() log(f'mic: {cmd} {" ".join(cmd_args)}') mic = subprocess.Popen([cmd] + cmd_args, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL)