Add WebSocket broadcast to stt-server.py

Every connection receives the full event stream (vad_start, vad_end,
transcript, error) from the moment it connects — no subscription
handshake required. The asyncio WebSocket server runs in a daemon thread
alongside the VAD loop and transcription thread. Events still go to
stdout unchanged.

Port is configurable via STT_PORT env var (default: 11501).
Add websockets to both setup scripts.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-06-07 08:53:54 +00:00
parent a4fe95b24a
commit 18404708e3
3 changed files with 77 additions and 23 deletions

View File

@@ -57,7 +57,7 @@ fi
echo "==> upgrading pip + build tools" echo "==> upgrading pip + build tools"
"${VENV}/bin/pip" install --upgrade pip wheel setuptools pybind11 --quiet "${VENV}/bin/pip" install --upgrade pip wheel setuptools pybind11 --quiet
"${VENV}/bin/pip" install torch silero-vad "${VENV}/bin/pip" install torch silero-vad websockets
# --- clone (skipped if already done) --- # --- clone (skipped if already done) ---
if [ ! -d "${BUILD_DIR}/src/.git" ]; then if [ ! -d "${BUILD_DIR}/src/.git" ]; then

View File

@@ -25,7 +25,7 @@ fi
echo "==> installing torch and faster-whisper" echo "==> installing torch and faster-whisper"
"${VENV}/bin/pip" install --upgrade pip --quiet "${VENV}/bin/pip" install --upgrade pip --quiet
"${VENV}/bin/pip" install torch faster-whisper silero-vad "${VENV}/bin/pip" install torch faster-whisper silero-vad websockets
echo "" echo ""
echo "==> done. Venv ready at ${VENV}" echo "==> done. Venv ready at ${VENV}"

View File

@@ -1,8 +1,9 @@
#!/usr/bin/env -S bash -c 'exec "$(dirname "$0")/venv/bin/python3" "$0" "$@"' #!/usr/bin/env -S bash -c 'exec "$(dirname "$0")/venv/bin/python3" "$0" "$@"'
""" """
STT process: records audio, runs Silero VAD, transcribes with faster-whisper. STT server: records audio, runs Silero VAD, transcribes with faster-whisper.
Broadcasts JSON events to all connected WebSocket clients and to stdout.
Events (JSON lines on stdout): Events:
{"event": "ready"} {"event": "ready"}
{"event": "vad_start"} {"event": "vad_start"}
{"event": "vad_end", "duration": 1.23} {"event": "vad_end", "duration": 1.23}
@@ -11,8 +12,14 @@ Events (JSON lines on stdout):
word format: {"word": "hello", "start": 0.12, "end": 0.45, "probability": 0.99} word format: {"word": "hello", "start": 0.12, "end": 0.45, "probability": 0.99}
Every WebSocket connection receives the full event stream from the moment it
connects — no subscription handshake required.
All log/status messages go to stderr. Stdout is machine-readable events only. All log/status messages go to stderr. Stdout is machine-readable events only.
Environment:
STT_PORT WebSocket port (default: 11501)
Usage: Usage:
./stt-server.py ./stt-server.py
./stt-server.py --model large-v3 --device cuda --compute-type int8_float16 ./stt-server.py --model large-v3 --device cuda --compute-type int8_float16
@@ -26,13 +33,16 @@ import threading
import queue import queue
import subprocess import subprocess
import traceback import traceback
import asyncio
import websockets
import numpy as np import numpy as np
import torch import torch
SAMPLE_RATE = 16000 SAMPLE_RATE = 16000
VAD_WINDOW = 512 # samples per VAD chunk (32ms at 16kHz) VAD_WINDOW = 512 # samples per VAD chunk (32ms at 16kHz)
PRE_ROLL_SAMPLES = 3200 # 0.2s of audio prepended to each segment PRE_ROLL_SAMPLES = 3200 # 0.2s prepended to each segment for context
HISTORY_SAMPLES = 960000 # 60s ring buffer for pre-roll HISTORY_SAMPLES = 960000 # 60s ring buffer for pre-roll
PORT = int(__import__('os').environ.get('STT_PORT', 11501))
def log(msg): def log(msg):
@@ -40,10 +50,49 @@ def log(msg):
sys.stderr.flush() sys.stderr.flush()
def emit(event): # --- WebSocket broadcast ---
sys.stdout.write(json.dumps(event) + '\n')
sys.stdout.flush()
_ws_loop = None
_ws_clients = set() # set of asyncio.Queue, one per connection
def emit(event):
line = json.dumps(event)
sys.stdout.write(line + '\n')
sys.stdout.flush()
if _ws_loop is not None:
for q in list(_ws_clients):
_ws_loop.call_soon_threadsafe(q.put_nowait, line)
async def ws_handler(websocket):
q = asyncio.Queue()
_ws_clients.add(q)
log(f'client connected ({len(_ws_clients)} total)')
try:
while True:
msg = await q.get()
await websocket.send(msg)
except websockets.ConnectionClosed:
pass
finally:
_ws_clients.discard(q)
log(f'client disconnected ({len(_ws_clients)} remaining)')
async def ws_main():
global _ws_loop
_ws_loop = asyncio.get_running_loop()
async with websockets.serve(ws_handler, '', PORT):
log(f'WebSocket listening on port {PORT}')
await asyncio.Future() # run forever
def start_ws_server():
asyncio.run(ws_main())
# --- Mic ---
def find_mic(): def find_mic():
candidates = [ candidates = [
@@ -63,6 +112,8 @@ def s16le_to_f32(data):
return np.frombuffer(data, dtype=np.int16).astype(np.float32) / 32768.0 return np.frombuffer(data, dtype=np.int16).astype(np.float32) / 32768.0
# --- Args + model loading ---
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--model', default='base.en') parser.add_argument('--model', default='base.en')
parser.add_argument('--device', default='cuda') parser.add_argument('--device', default='cuda')
@@ -87,7 +138,8 @@ vad = VADIterator(vad_model, sampling_rate=SAMPLE_RATE,
log('VAD ready') log('VAD ready')
# Ring buffer for pre-roll context # --- Pre-roll ring buffer ---
history = np.zeros(HISTORY_SAMPLES, dtype=np.float32) history = np.zeros(HISTORY_SAMPLES, dtype=np.float32)
history_pos = 0 history_pos = 0
@@ -95,7 +147,6 @@ def push_history(samples):
global history_pos global history_pos
n = len(samples) n = len(samples)
base = history_pos % HISTORY_SAMPLES base = history_pos % HISTORY_SAMPLES
# May wrap around — handle both cases
space = HISTORY_SAMPLES - base space = HISTORY_SAMPLES - base
if n <= space: if n <= space:
history[base:base + n] = samples history[base:base + n] = samples
@@ -113,7 +164,8 @@ def get_preroll():
return out return out
# Transcription runs in a separate thread so VAD is never blocked by GPU # --- Transcription thread ---
transcription_queue = queue.Queue() transcription_queue = queue.Queue()
def transcription_worker(): def transcription_worker():
@@ -152,9 +204,11 @@ def transcription_worker():
threading.Thread(target=transcription_worker, daemon=True).start() threading.Thread(target=transcription_worker, daemon=True).start()
threading.Thread(target=start_ws_server, daemon=True).start()
# Main recording + VAD loop # --- Main recording + VAD loop ---
cmd, cmd_args = find_mic() cmd, cmd_args = find_mic()
log(f'mic: {cmd} {" ".join(cmd_args)}') log(f'mic: {cmd} {" ".join(cmd_args)}')
mic = subprocess.Popen([cmd] + cmd_args, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL) mic = subprocess.Popen([cmd] + cmd_args, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL)