WebSocket server, language/task args, verbose flag, misc improvements #2

Merged
mikael-lovqvist merged 13 commits from mikael-lovqvists-claude-agent/stt-server:websocket-server into main 2026-06-07 09:27:02 +00:00
3 changed files with 77 additions and 23 deletions
Showing only changes of commit 18404708e3 - Show all commits

View File

@@ -57,7 +57,7 @@ fi
echo "==> upgrading pip + build tools"
"${VENV}/bin/pip" install --upgrade pip wheel setuptools pybind11 --quiet
"${VENV}/bin/pip" install torch silero-vad
"${VENV}/bin/pip" install torch silero-vad websockets
# --- clone (skipped if already done) ---
if [ ! -d "${BUILD_DIR}/src/.git" ]; then

View File

@@ -25,7 +25,7 @@ fi
echo "==> installing torch and faster-whisper"
"${VENV}/bin/pip" install --upgrade pip --quiet
"${VENV}/bin/pip" install torch faster-whisper silero-vad
"${VENV}/bin/pip" install torch faster-whisper silero-vad websockets
echo ""
echo "==> done. Venv ready at ${VENV}"

View File

@@ -1,18 +1,25 @@
#!/usr/bin/env -S bash -c 'exec "$(dirname "$0")/venv/bin/python3" "$0" "$@"'
"""
STT process: records audio, runs Silero VAD, transcribes with faster-whisper.
STT server: records audio, runs Silero VAD, transcribes with faster-whisper.
Broadcasts JSON events to all connected WebSocket clients and to stdout.
Events (JSON lines on stdout):
Events:
{"event": "ready"}
{"event": "vad_start"}
{"event": "vad_end", "duration": 1.23}
{"event": "transcript", "text": "...", "words": [...], "duration": 1.23}
{"event": "error", "message": "..."}
{"event": "vad_end", "duration": 1.23}
{"event": "transcript", "text": "...", "words": [...], "duration": 1.23}
{"event": "error", "message": "..."}
word format: {"word": "hello", "start": 0.12, "end": 0.45, "probability": 0.99}
Every WebSocket connection receives the full event stream from the moment it
connects — no subscription handshake required.
All log/status messages go to stderr. Stdout is machine-readable events only.
Environment:
STT_PORT WebSocket port (default: 11501)
Usage:
./stt-server.py
./stt-server.py --model large-v3 --device cuda --compute-type int8_float16
@@ -26,13 +33,16 @@ import threading
import asyncio
import os
import queue
import subprocess
import traceback

import numpy as np
import torch
import websockets
SAMPLE_RATE = 16000
VAD_WINDOW = 512 # samples per VAD chunk (32ms at 16kHz)
PRE_ROLL_SAMPLES = 3200 # 0.2s of audio prepended to each segment
HISTORY_SAMPLES = 960000 # 60s ring buffer for pre-roll
SAMPLE_RATE = 16000
VAD_WINDOW = 512 # samples per VAD chunk (32ms at 16kHz)
PRE_ROLL_SAMPLES = 3200 # 0.2s prepended to each segment for context
HISTORY_SAMPLES = 960000 # 60s ring buffer for pre-roll
# WebSocket listen port; override with the STT_PORT environment variable.
PORT = int(os.environ.get('STT_PORT', '11501'))
def log(msg):
@@ -40,10 +50,49 @@ def log(msg):
sys.stderr.flush()
def emit(event):
    """Print *event* to stdout as a single JSON line, flushing right away."""
    print(json.dumps(event), file=sys.stdout, flush=True)
# --- WebSocket broadcast ---
# Event loop owned by the WebSocket thread; stays None until ws_main() sets it.
_ws_loop = None
# One asyncio.Queue per connected client; each handler drains its own queue.
_ws_clients = set()


def emit(event):
    """Publish *event* on stdout and to every connected WebSocket client.

    The event is serialised once to a JSON line. Stdout always receives it;
    client queues receive it only once the WebSocket loop exists. Hand-off to
    the queues goes through ``call_soon_threadsafe`` because callers run on
    threads other than the one driving the event loop.
    """
    payload = json.dumps(event)
    sys.stdout.write(f'{payload}\n')
    sys.stdout.flush()

    loop = _ws_loop
    if loop is None:
        return
    # Snapshot the set: handlers add/remove queues from the loop thread.
    for client_queue in tuple(_ws_clients):
        loop.call_soon_threadsafe(client_queue.put_nowait, payload)
async def ws_handler(websocket):
    """Stream broadcast events to one WebSocket client until it disconnects.

    A private ``asyncio.Queue`` is registered in ``_ws_clients`` (the set that
    ``emit`` feeds); every line that arrives on it is sent to *websocket*.
    The queue is unregistered again when the handler exits.

    Waiting on the queue alone would keep a departed client registered until
    the next event is broadcast and the send fails, so each wait is raced
    against the connection's ``wait_closed()``.
    """
    q = asyncio.Queue()
    _ws_clients.add(q)
    log(f'client connected ({len(_ws_clients)} total)')

    closed = asyncio.ensure_future(websocket.wait_closed())
    pending_get = None
    try:
        while True:
            pending_get = asyncio.ensure_future(q.get())
            done, _ = await asyncio.wait(
                {pending_get, closed},
                return_when=asyncio.FIRST_COMPLETED,
            )
            if pending_get not in done:
                # Peer went away while no events were flowing.
                break
            await websocket.send(pending_get.result())
    except websockets.ConnectionClosed:
        pass
    finally:
        # Cancelling an already-finished future is a no-op.
        closed.cancel()
        if pending_get is not None:
            pending_get.cancel()
        _ws_clients.discard(q)
        log(f'client disconnected ({len(_ws_clients)} remaining)')
async def ws_main():
    """Run the WebSocket server on PORT for the lifetime of the process.

    Publishes the running event loop in ``_ws_loop`` first, so that ``emit``
    can schedule queue writes onto it from other threads.
    """
    global _ws_loop
    _ws_loop = asyncio.get_running_loop()

    server = websockets.serve(ws_handler, '', PORT)
    async with server:
        log(f'WebSocket listening on port {PORT}')
        # A future nobody resolves: parks this coroutine forever.
        never_done = asyncio.Future()
        await never_done
def start_ws_server():
    """Thread target: drive ws_main() on a fresh asyncio event loop."""
    server_coro = ws_main()
    asyncio.run(server_coro)
# --- Mic ---
def find_mic():
candidates = [
@@ -63,6 +112,8 @@ def s16le_to_f32(data):
return np.frombuffer(data, dtype=np.int16).astype(np.float32) / 32768.0
# --- Args + model loading ---
parser = argparse.ArgumentParser()
parser.add_argument('--model', default='base.en')
parser.add_argument('--device', default='cuda')
@@ -82,25 +133,25 @@ except Exception as e:
log('loading silero VAD...')
from silero_vad import load_silero_vad, VADIterator
vad_model = load_silero_vad()
vad = VADIterator(vad_model, sampling_rate=SAMPLE_RATE,
threshold=0.5, min_silence_duration_ms=500)
vad = VADIterator(vad_model, sampling_rate=SAMPLE_RATE,
threshold=0.5, min_silence_duration_ms=500)
log('VAD ready')
# Ring buffer for pre-roll context
# --- Pre-roll ring buffer ---
history = np.zeros(HISTORY_SAMPLES, dtype=np.float32)
history_pos = 0
def push_history(samples):
    """Append *samples* to the pre-roll ring buffer.

    ``history_pos`` counts every sample ever written; the write offset is that
    count modulo ``HISTORY_SAMPLES``. A write that would run past the end of
    the buffer is split, with the remainder stored at the front.
    """
    global history_pos
    count = len(samples)
    start = history_pos % HISTORY_SAMPLES
    tail_room = HISTORY_SAMPLES - start
    if count > tail_room:
        # Wraps: fill up to the end, then continue from index 0.
        history[start:] = samples[:tail_room]
        history[:count - tail_room] = samples[tail_room:]
    else:
        history[start:start + count] = samples
    history_pos += count
@@ -113,7 +164,8 @@ def get_preroll():
return out
# Transcription runs in a separate thread so VAD is never blocked by GPU
# --- Transcription thread ---
transcription_queue = queue.Queue()
def transcription_worker():
@@ -152,9 +204,11 @@ def transcription_worker():
threading.Thread(target=transcription_worker, daemon=True).start()
threading.Thread(target=start_ws_server, daemon=True).start()
# Main recording + VAD loop
# --- Main recording + VAD loop ---
cmd, cmd_args = find_mic()
log(f'mic: {cmd} {" ".join(cmd_args)}')
mic = subprocess.Popen([cmd] + cmd_args, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL)