Add WebSocket broadcast to stt-server.py

Every connection receives the full event stream (vad_start, vad_end,
transcript, error) from the moment it connects — no subscription
handshake required. The asyncio WebSocket server runs in a daemon thread
alongside the VAD loop and transcription thread. Events still go to
stdout unchanged.

Port is configurable via STT_PORT env var (default: 11501).
Add websockets to both setup scripts.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-06-07 08:53:54 +00:00
parent a4fe95b24a
commit 18404708e3
3 changed files with 77 additions and 23 deletions

View File

@@ -57,7 +57,7 @@ fi
echo "==> upgrading pip + build tools"
"${VENV}/bin/pip" install --upgrade pip wheel setuptools pybind11 --quiet
"${VENV}/bin/pip" install torch silero-vad
"${VENV}/bin/pip" install torch silero-vad websockets
# --- clone (skipped if already done) ---
if [ ! -d "${BUILD_DIR}/src/.git" ]; then

View File

@@ -25,7 +25,7 @@ fi
echo "==> installing torch and faster-whisper"
"${VENV}/bin/pip" install --upgrade pip --quiet
"${VENV}/bin/pip" install torch faster-whisper silero-vad
"${VENV}/bin/pip" install torch faster-whisper silero-vad websockets
echo ""
echo "==> done. Venv ready at ${VENV}"

View File

@@ -1,8 +1,9 @@
#!/usr/bin/env -S bash -c 'exec "$(dirname "$0")/venv/bin/python3" "$0" "$@"'
"""
STT process: records audio, runs Silero VAD, transcribes with faster-whisper.
STT server: records audio, runs Silero VAD, transcribes with faster-whisper.
Broadcasts JSON events to all connected WebSocket clients and to stdout.
Events (JSON lines on stdout):
Events:
{"event": "ready"}
{"event": "vad_start"}
{"event": "vad_end", "duration": 1.23}
@@ -11,8 +12,14 @@ Events (JSON lines on stdout):
word format: {"word": "hello", "start": 0.12, "end": 0.45, "probability": 0.99}
Every WebSocket connection receives the full event stream from the moment it
connects — no subscription handshake required.
All log/status messages go to stderr. Stdout is machine-readable events only.
Environment:
STT_PORT WebSocket port (default: 11501)
Usage:
./stt-server.py
./stt-server.py --model large-v3 --device cuda --compute-type int8_float16
@@ -26,13 +33,16 @@ import threading
import queue
import subprocess
import traceback
import asyncio
import websockets
import numpy as np
import torch
SAMPLE_RATE = 16000
VAD_WINDOW = 512 # samples per VAD chunk (32ms at 16kHz)
PRE_ROLL_SAMPLES = 3200 # 0.2s of audio prepended to each segment
PRE_ROLL_SAMPLES = 3200 # 0.2s prepended to each segment for context
HISTORY_SAMPLES = 960000 # 60s ring buffer for pre-roll
PORT = int(__import__('os').environ.get('STT_PORT', 11501))
def log(msg):
@@ -40,10 +50,49 @@ def log(msg):
sys.stderr.flush()
def emit(event):
sys.stdout.write(json.dumps(event) + '\n')
sys.stdout.flush()
# --- WebSocket broadcast ---
_ws_loop = None
_ws_clients = set() # set of asyncio.Queue, one per connection
def emit(event):
line = json.dumps(event)
sys.stdout.write(line + '\n')
sys.stdout.flush()
if _ws_loop is not None:
for q in list(_ws_clients):
_ws_loop.call_soon_threadsafe(q.put_nowait, line)
async def ws_handler(websocket):
q = asyncio.Queue()
_ws_clients.add(q)
log(f'client connected ({len(_ws_clients)} total)')
try:
while True:
msg = await q.get()
await websocket.send(msg)
except websockets.ConnectionClosed:
pass
finally:
_ws_clients.discard(q)
log(f'client disconnected ({len(_ws_clients)} remaining)')
async def ws_main():
global _ws_loop
_ws_loop = asyncio.get_running_loop()
async with websockets.serve(ws_handler, '', PORT):
log(f'WebSocket listening on port {PORT}')
await asyncio.Future() # run forever
def start_ws_server():
asyncio.run(ws_main())
# --- Mic ---
def find_mic():
candidates = [
@@ -63,6 +112,8 @@ def s16le_to_f32(data):
return np.frombuffer(data, dtype=np.int16).astype(np.float32) / 32768.0
# --- Args + model loading ---
parser = argparse.ArgumentParser()
parser.add_argument('--model', default='base.en')
parser.add_argument('--device', default='cuda')
@@ -87,7 +138,8 @@ vad = VADIterator(vad_model, sampling_rate=SAMPLE_RATE,
log('VAD ready')
# Ring buffer for pre-roll context
# --- Pre-roll ring buffer ---
history = np.zeros(HISTORY_SAMPLES, dtype=np.float32)
history_pos = 0
@@ -95,7 +147,6 @@ def push_history(samples):
global history_pos
n = len(samples)
base = history_pos % HISTORY_SAMPLES
# May wrap around — handle both cases
space = HISTORY_SAMPLES - base
if n <= space:
history[base:base + n] = samples
@@ -113,7 +164,8 @@ def get_preroll():
return out
# Transcription runs in a separate thread so VAD is never blocked by GPU
# --- Transcription thread ---
transcription_queue = queue.Queue()
def transcription_worker():
@@ -152,9 +204,11 @@ def transcription_worker():
threading.Thread(target=transcription_worker, daemon=True).start()
threading.Thread(target=start_ws_server, daemon=True).start()
# Main recording + VAD loop
# --- Main recording + VAD loop ---
cmd, cmd_args = find_mic()
log(f'mic: {cmd} {" ".join(cmd_args)}')
mic = subprocess.Popen([cmd] + cmd_args, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL)