forked from efforting.tech/stt-server
Add WebSocket broadcast to stt-server.py
Every connection receives the full event stream (vad_start, vad_end, transcript, error) from the moment it connects — no subscription handshake required. The asyncio WebSocket server runs in a daemon thread alongside the VAD loop and transcription thread. Events still go to stdout unchanged. Port is configurable via STT_PORT env var (default: 11501). Add websockets to both setup scripts. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -57,7 +57,7 @@ fi
|
||||
|
||||
echo "==> upgrading pip + build tools"
|
||||
"${VENV}/bin/pip" install --upgrade pip wheel setuptools pybind11 --quiet
|
||||
"${VENV}/bin/pip" install torch silero-vad
|
||||
"${VENV}/bin/pip" install torch silero-vad websockets
|
||||
|
||||
# --- clone (skipped if already done) ---
|
||||
if [ ! -d "${BUILD_DIR}/src/.git" ]; then
|
||||
|
||||
@@ -25,7 +25,7 @@ fi
|
||||
|
||||
echo "==> installing torch and faster-whisper"
|
||||
"${VENV}/bin/pip" install --upgrade pip --quiet
|
||||
"${VENV}/bin/pip" install torch faster-whisper silero-vad
|
||||
"${VENV}/bin/pip" install torch faster-whisper silero-vad websockets
|
||||
|
||||
echo ""
|
||||
echo "==> done. Venv ready at ${VENV}"
|
||||
|
||||
@@ -1,18 +1,25 @@
|
||||
#!/usr/bin/env -S bash -c 'exec "$(dirname "$0")/venv/bin/python3" "$0" "$@"'
|
||||
"""
|
||||
STT process: records audio, runs Silero VAD, transcribes with faster-whisper.
|
||||
STT server: records audio, runs Silero VAD, transcribes with faster-whisper.
|
||||
Broadcasts JSON events to all connected WebSocket clients and to stdout.
|
||||
|
||||
Events (JSON lines on stdout):
|
||||
Events:
|
||||
{"event": "ready"}
|
||||
{"event": "vad_start"}
|
||||
{"event": "vad_end", "duration": 1.23}
|
||||
{"event": "transcript", "text": "...", "words": [...], "duration": 1.23}
|
||||
{"event": "error", "message": "..."}
|
||||
{"event": "vad_end", "duration": 1.23}
|
||||
{"event": "transcript", "text": "...", "words": [...], "duration": 1.23}
|
||||
{"event": "error", "message": "..."}
|
||||
|
||||
word format: {"word": "hello", "start": 0.12, "end": 0.45, "probability": 0.99}
|
||||
|
||||
Every WebSocket connection receives the full event stream from the moment it
|
||||
connects — no subscription handshake required.
|
||||
|
||||
All log/status messages go to stderr. Stdout is machine-readable events only.
|
||||
|
||||
Environment:
|
||||
STT_PORT WebSocket port (default: 11501)
|
||||
|
||||
Usage:
|
||||
./stt-server.py
|
||||
./stt-server.py --model large-v3 --device cuda --compute-type int8_float16
|
||||
@@ -26,13 +33,16 @@ import threading
|
||||
import queue
|
||||
import subprocess
|
||||
import traceback
|
||||
import asyncio
|
||||
import websockets
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
SAMPLE_RATE = 16000
|
||||
VAD_WINDOW = 512 # samples per VAD chunk (32ms at 16kHz)
|
||||
PRE_ROLL_SAMPLES = 3200 # 0.2s of audio prepended to each segment
|
||||
HISTORY_SAMPLES = 960000 # 60s ring buffer for pre-roll
|
||||
SAMPLE_RATE = 16000
|
||||
VAD_WINDOW = 512 # samples per VAD chunk (32ms at 16kHz)
|
||||
PRE_ROLL_SAMPLES = 3200 # 0.2s prepended to each segment for context
|
||||
HISTORY_SAMPLES = 960000 # 60s ring buffer for pre-roll
|
||||
PORT = int(__import__('os').environ.get('STT_PORT', 11501))
|
||||
|
||||
|
||||
def log(msg):
|
||||
@@ -40,10 +50,49 @@ def log(msg):
|
||||
sys.stderr.flush()
|
||||
|
||||
|
||||
def emit(event):
|
||||
sys.stdout.write(json.dumps(event) + '\n')
|
||||
sys.stdout.flush()
|
||||
# --- WebSocket broadcast ---
|
||||
|
||||
_ws_loop = None
|
||||
_ws_clients = set() # set of asyncio.Queue, one per connection
|
||||
|
||||
|
||||
def emit(event):
|
||||
line = json.dumps(event)
|
||||
sys.stdout.write(line + '\n')
|
||||
sys.stdout.flush()
|
||||
if _ws_loop is not None:
|
||||
for q in list(_ws_clients):
|
||||
_ws_loop.call_soon_threadsafe(q.put_nowait, line)
|
||||
|
||||
|
||||
async def ws_handler(websocket):
|
||||
q = asyncio.Queue()
|
||||
_ws_clients.add(q)
|
||||
log(f'client connected ({len(_ws_clients)} total)')
|
||||
try:
|
||||
while True:
|
||||
msg = await q.get()
|
||||
await websocket.send(msg)
|
||||
except websockets.ConnectionClosed:
|
||||
pass
|
||||
finally:
|
||||
_ws_clients.discard(q)
|
||||
log(f'client disconnected ({len(_ws_clients)} remaining)')
|
||||
|
||||
|
||||
async def ws_main():
|
||||
global _ws_loop
|
||||
_ws_loop = asyncio.get_running_loop()
|
||||
async with websockets.serve(ws_handler, '', PORT):
|
||||
log(f'WebSocket listening on port {PORT}')
|
||||
await asyncio.Future() # run forever
|
||||
|
||||
|
||||
def start_ws_server():
|
||||
asyncio.run(ws_main())
|
||||
|
||||
|
||||
# --- Mic ---
|
||||
|
||||
def find_mic():
|
||||
candidates = [
|
||||
@@ -63,6 +112,8 @@ def s16le_to_f32(data):
|
||||
return np.frombuffer(data, dtype=np.int16).astype(np.float32) / 32768.0
|
||||
|
||||
|
||||
# --- Args + model loading ---
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--model', default='base.en')
|
||||
parser.add_argument('--device', default='cuda')
|
||||
@@ -82,25 +133,25 @@ except Exception as e:
|
||||
log('loading silero VAD...')
|
||||
from silero_vad import load_silero_vad, VADIterator
|
||||
vad_model = load_silero_vad()
|
||||
vad = VADIterator(vad_model, sampling_rate=SAMPLE_RATE,
|
||||
threshold=0.5, min_silence_duration_ms=500)
|
||||
vad = VADIterator(vad_model, sampling_rate=SAMPLE_RATE,
|
||||
threshold=0.5, min_silence_duration_ms=500)
|
||||
log('VAD ready')
|
||||
|
||||
|
||||
# Ring buffer for pre-roll context
|
||||
# --- Pre-roll ring buffer ---
|
||||
|
||||
history = np.zeros(HISTORY_SAMPLES, dtype=np.float32)
|
||||
history_pos = 0
|
||||
|
||||
def push_history(samples):
|
||||
global history_pos
|
||||
n = len(samples)
|
||||
base = history_pos % HISTORY_SAMPLES
|
||||
# May wrap around — handle both cases
|
||||
n = len(samples)
|
||||
base = history_pos % HISTORY_SAMPLES
|
||||
space = HISTORY_SAMPLES - base
|
||||
if n <= space:
|
||||
history[base:base + n] = samples
|
||||
else:
|
||||
history[base:] = samples[:space]
|
||||
history[base:] = samples[:space]
|
||||
history[:n - space] = samples[space:]
|
||||
history_pos += n
|
||||
|
||||
@@ -113,7 +164,8 @@ def get_preroll():
|
||||
return out
|
||||
|
||||
|
||||
# Transcription runs in a separate thread so VAD is never blocked by GPU
|
||||
# --- Transcription thread ---
|
||||
|
||||
transcription_queue = queue.Queue()
|
||||
|
||||
def transcription_worker():
|
||||
@@ -152,9 +204,11 @@ def transcription_worker():
|
||||
|
||||
|
||||
threading.Thread(target=transcription_worker, daemon=True).start()
|
||||
threading.Thread(target=start_ws_server, daemon=True).start()
|
||||
|
||||
|
||||
# Main recording + VAD loop
|
||||
# --- Main recording + VAD loop ---
|
||||
|
||||
cmd, cmd_args = find_mic()
|
||||
log(f'mic: {cmd} {" ".join(cmd_args)}')
|
||||
mic = subprocess.Popen([cmd] + cmd_args, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL)
|
||||
|
||||
Reference in New Issue
Block a user