Add WebSocket broadcast to stt-server.py

Every connection receives the full event stream (vad_start, vad_end,
transcript, error) from the moment it connects — no subscription
handshake required. The asyncio WebSocket server runs in a daemon thread
alongside the VAD loop and transcription thread. Events still go to
stdout unchanged.

Port is configurable via STT_PORT env var (default: 11501).
Add websockets to both setup scripts.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-06-07 08:53:54 +00:00
parent a4fe95b24a
commit 18404708e3
3 changed files with 77 additions and 23 deletions

View File

@@ -1,18 +1,25 @@
#!/usr/bin/env -S bash -c 'exec "$(dirname "$0")/venv/bin/python3" "$0" "$@"'
"""
STT process: records audio, runs Silero VAD, transcribes with faster-whisper.
STT server: records audio, runs Silero VAD, transcribes with faster-whisper.
Broadcasts JSON events to all connected WebSocket clients and to stdout.
Events (JSON lines on stdout):
Events:
{"event": "ready"}
{"event": "vad_start"}
{"event": "vad_end", "duration": 1.23}
{"event": "transcript", "text": "...", "words": [...], "duration": 1.23}
{"event": "error", "message": "..."}
{"event": "vad_end", "duration": 1.23}
{"event": "transcript", "text": "...", "words": [...], "duration": 1.23}
{"event": "error", "message": "..."}
word format: {"word": "hello", "start": 0.12, "end": 0.45, "probability": 0.99}
Every WebSocket connection receives the full event stream from the moment it
connects — no subscription handshake required.
All log/status messages go to stderr. Stdout is machine-readable events only.
Environment:
STT_PORT WebSocket port (default: 11501)
Usage:
./stt-server.py
./stt-server.py --model large-v3 --device cuda --compute-type int8_float16
@@ -26,13 +33,16 @@ import threading
import queue
import subprocess
import traceback
import asyncio
import websockets
import numpy as np
import torch
SAMPLE_RATE = 16000
VAD_WINDOW = 512 # samples per VAD chunk (32ms at 16kHz)
PRE_ROLL_SAMPLES = 3200 # 0.2s of audio prepended to each segment
HISTORY_SAMPLES = 960000 # 60s ring buffer for pre-roll
SAMPLE_RATE = 16000
VAD_WINDOW = 512 # samples per VAD chunk (32ms at 16kHz)
PRE_ROLL_SAMPLES = 3200 # 0.2s prepended to each segment for context
HISTORY_SAMPLES = 960000 # 60s ring buffer for pre-roll
PORT = int(__import__('os').environ.get('STT_PORT', 11501))
def log(msg):
@@ -40,10 +50,49 @@ def log(msg):
sys.stderr.flush()
def emit(event):
sys.stdout.write(json.dumps(event) + '\n')
sys.stdout.flush()
# --- WebSocket broadcast ---
_ws_loop = None
_ws_clients = set() # set of asyncio.Queue, one per connection
def emit(event):
line = json.dumps(event)
sys.stdout.write(line + '\n')
sys.stdout.flush()
if _ws_loop is not None:
for q in list(_ws_clients):
_ws_loop.call_soon_threadsafe(q.put_nowait, line)
async def ws_handler(websocket):
q = asyncio.Queue()
_ws_clients.add(q)
log(f'client connected ({len(_ws_clients)} total)')
try:
while True:
msg = await q.get()
await websocket.send(msg)
except websockets.ConnectionClosed:
pass
finally:
_ws_clients.discard(q)
log(f'client disconnected ({len(_ws_clients)} remaining)')
async def ws_main():
global _ws_loop
_ws_loop = asyncio.get_running_loop()
async with websockets.serve(ws_handler, '', PORT):
log(f'WebSocket listening on port {PORT}')
await asyncio.Future() # run forever
def start_ws_server():
asyncio.run(ws_main())
# --- Mic ---
def find_mic():
candidates = [
@@ -63,6 +112,8 @@ def s16le_to_f32(data):
return np.frombuffer(data, dtype=np.int16).astype(np.float32) / 32768.0
# --- Args + model loading ---
parser = argparse.ArgumentParser()
parser.add_argument('--model', default='base.en')
parser.add_argument('--device', default='cuda')
@@ -82,25 +133,25 @@ except Exception as e:
log('loading silero VAD...')
from silero_vad import load_silero_vad, VADIterator
vad_model = load_silero_vad()
vad = VADIterator(vad_model, sampling_rate=SAMPLE_RATE,
threshold=0.5, min_silence_duration_ms=500)
vad = VADIterator(vad_model, sampling_rate=SAMPLE_RATE,
threshold=0.5, min_silence_duration_ms=500)
log('VAD ready')
# Ring buffer for pre-roll context
# --- Pre-roll ring buffer ---
history = np.zeros(HISTORY_SAMPLES, dtype=np.float32)
history_pos = 0
def push_history(samples):
global history_pos
n = len(samples)
base = history_pos % HISTORY_SAMPLES
# May wrap around — handle both cases
n = len(samples)
base = history_pos % HISTORY_SAMPLES
space = HISTORY_SAMPLES - base
if n <= space:
history[base:base + n] = samples
else:
history[base:] = samples[:space]
history[base:] = samples[:space]
history[:n - space] = samples[space:]
history_pos += n
@@ -113,7 +164,8 @@ def get_preroll():
return out
# Transcription runs in a separate thread so VAD is never blocked by GPU
# --- Transcription thread ---
transcription_queue = queue.Queue()
def transcription_worker():
@@ -152,9 +204,11 @@ def transcription_worker():
threading.Thread(target=transcription_worker, daemon=True).start()
threading.Thread(target=start_ws_server, daemon=True).start()
# Main recording + VAD loop
# --- Main recording + VAD loop ---
cmd, cmd_args = find_mic()
log(f'mic: {cmd} {" ".join(cmd_args)}')
mic = subprocess.Popen([cmd] + cmd_args, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL)