Add WebSocket broadcast to stt-server.py
Every connection receives the full event stream (vad_start, vad_end, transcript, error) from the moment it connects — no subscription handshake required. The asyncio WebSocket server runs in a daemon thread alongside the VAD loop and transcription thread. Events still go to stdout unchanged. Port is configurable via STT_PORT env var (default: 11501). Add websockets to both setup scripts. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -57,7 +57,7 @@ fi
|
|||||||
|
|
||||||
echo "==> upgrading pip + build tools"
|
echo "==> upgrading pip + build tools"
|
||||||
"${VENV}/bin/pip" install --upgrade pip wheel setuptools pybind11 --quiet
|
"${VENV}/bin/pip" install --upgrade pip wheel setuptools pybind11 --quiet
|
||||||
"${VENV}/bin/pip" install torch silero-vad
|
"${VENV}/bin/pip" install torch silero-vad websockets
|
||||||
|
|
||||||
# --- clone (skipped if already done) ---
|
# --- clone (skipped if already done) ---
|
||||||
if [ ! -d "${BUILD_DIR}/src/.git" ]; then
|
if [ ! -d "${BUILD_DIR}/src/.git" ]; then
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ fi
|
|||||||
|
|
||||||
echo "==> installing torch and faster-whisper"
|
echo "==> installing torch and faster-whisper"
|
||||||
"${VENV}/bin/pip" install --upgrade pip --quiet
|
"${VENV}/bin/pip" install --upgrade pip --quiet
|
||||||
"${VENV}/bin/pip" install torch faster-whisper silero-vad
|
"${VENV}/bin/pip" install torch faster-whisper silero-vad websockets
|
||||||
|
|
||||||
echo ""
|
echo ""
|
||||||
echo "==> done. Venv ready at ${VENV}"
|
echo "==> done. Venv ready at ${VENV}"
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
#!/usr/bin/env -S bash -c 'exec "$(dirname "$0")/venv/bin/python3" "$0" "$@"'
|
#!/usr/bin/env -S bash -c 'exec "$(dirname "$0")/venv/bin/python3" "$0" "$@"'
|
||||||
"""
|
"""
|
||||||
STT process: records audio, runs Silero VAD, transcribes with faster-whisper.
|
STT server: records audio, runs Silero VAD, transcribes with faster-whisper.
|
||||||
|
Broadcasts JSON events to all connected WebSocket clients and to stdout.
|
||||||
|
|
||||||
Events (JSON lines on stdout):
|
Events:
|
||||||
{"event": "ready"}
|
{"event": "ready"}
|
||||||
{"event": "vad_start"}
|
{"event": "vad_start"}
|
||||||
{"event": "vad_end", "duration": 1.23}
|
{"event": "vad_end", "duration": 1.23}
|
||||||
@@ -11,8 +12,14 @@ Events (JSON lines on stdout):
|
|||||||
|
|
||||||
word format: {"word": "hello", "start": 0.12, "end": 0.45, "probability": 0.99}
|
word format: {"word": "hello", "start": 0.12, "end": 0.45, "probability": 0.99}
|
||||||
|
|
||||||
|
Every WebSocket connection receives the full event stream from the moment it
|
||||||
|
connects — no subscription handshake required.
|
||||||
|
|
||||||
All log/status messages go to stderr. Stdout is machine-readable events only.
|
All log/status messages go to stderr. Stdout is machine-readable events only.
|
||||||
|
|
||||||
|
Environment:
|
||||||
|
STT_PORT WebSocket port (default: 11501)
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
./stt-server.py
|
./stt-server.py
|
||||||
./stt-server.py --model large-v3 --device cuda --compute-type int8_float16
|
./stt-server.py --model large-v3 --device cuda --compute-type int8_float16
|
||||||
@@ -26,13 +33,16 @@ import threading
|
|||||||
import queue
|
import queue
|
||||||
import subprocess
|
import subprocess
|
||||||
import traceback
|
import traceback
|
||||||
|
import asyncio
|
||||||
|
import websockets
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
SAMPLE_RATE = 16000
|
SAMPLE_RATE = 16000
|
||||||
VAD_WINDOW = 512 # samples per VAD chunk (32ms at 16kHz)
|
VAD_WINDOW = 512 # samples per VAD chunk (32ms at 16kHz)
|
||||||
PRE_ROLL_SAMPLES = 3200 # 0.2s of audio prepended to each segment
|
PRE_ROLL_SAMPLES = 3200 # 0.2s prepended to each segment for context
|
||||||
HISTORY_SAMPLES = 960000 # 60s ring buffer for pre-roll
|
HISTORY_SAMPLES = 960000 # 60s ring buffer for pre-roll
|
||||||
|
# TCP port for the WebSocket server: taken from the STT_PORT environment
# variable, falling back to 11501 when it is unset. int() raises ValueError
# at import time if STT_PORT is set to a non-integer string.
# NOTE(review): `os` is pulled in via __import__ inline, presumably because
# the module has no top-level `import os` — confirm, and prefer a normal import.
PORT = int(__import__('os').environ.get('STT_PORT', 11501))
|
||||||
|
|
||||||
|
|
||||||
def log(msg):
|
def log(msg):
|
||||||
@@ -40,10 +50,49 @@ def log(msg):
|
|||||||
sys.stderr.flush()
|
sys.stderr.flush()
|
||||||
|
|
||||||
|
|
||||||
def emit(event):
|
# --- WebSocket broadcast ---
|
||||||
sys.stdout.write(json.dumps(event) + '\n')
|
|
||||||
sys.stdout.flush()
|
|
||||||
|
|
||||||
|
_ws_loop = None
|
||||||
|
_ws_clients = set() # set of asyncio.Queue, one per connection
|
||||||
|
|
||||||
|
|
||||||
|
def _ws_fan_out(line):
    """Queue *line* for every connected client.

    Must run on the WebSocket event-loop thread: that is the only thread
    that adds to or removes from ``_ws_clients``, so iterating the set here
    cannot race with a connect/disconnect.
    """
    for client_queue in _ws_clients:
        client_queue.put_nowait(line)


def emit(event):
    """Publish one event to stdout and to all connected WebSocket clients.

    event: a JSON-serialisable object; it is encoded once and the same text
    is written to stdout (newline-terminated, flushed immediately) and
    handed to each client's queue.

    Safe to call from any thread. Stdout is always written first, so a
    problem on the WebSocket side never costs the stdout consumer an event.
    """
    line = json.dumps(event)
    sys.stdout.write(line + '\n')
    sys.stdout.flush()

    # Read the global once so the check and the use see the same object.
    loop = _ws_loop
    if loop is None or not _ws_clients:
        return  # server not started yet, or nobody is listening

    try:
        # One cross-thread wakeup per event; the per-client fan-out happens
        # on the loop thread itself.
        loop.call_soon_threadsafe(_ws_fan_out, line)
    except RuntimeError:
        # The loop has been closed (server thread exited). The event has
        # already gone to stdout, so dropping the broadcast is the
        # best-effort outcome; do not take down the calling thread.
        pass
|
||||||
|
|
||||||
|
|
||||||
|
async def ws_handler(websocket):
    """Serve one WebSocket connection: stream every emitted event to it.

    A private asyncio.Queue is registered in ``_ws_clients`` for the
    lifetime of the connection; ``emit`` feeds it and this handler forwards
    each queued line to the peer.

    Inbound frames are read and discarded concurrently. Clients are not
    expected to send anything, but reading is what lets a disconnect be
    noticed while the queue is idle — otherwise a peer that left during
    silence would stay registered until the next event made ``send`` fail.
    """
    q = asyncio.Queue()
    _ws_clients.add(q)
    log(f'client connected ({len(_ws_clients)} total)')

    async def _forward():
        # Relay queued event lines to the peer until the connection drops.
        while True:
            msg = await q.get()
            await websocket.send(msg)

    async def _drain():
        # Consume and ignore anything the client sends; returns (or raises
        # ConnectionClosed) once the connection is closed.
        async for _ in websocket:
            pass

    sender = asyncio.ensure_future(_forward())
    reader = asyncio.ensure_future(_drain())
    try:
        done, _pending = await asyncio.wait(
            {sender, reader}, return_when=asyncio.FIRST_COMPLETED
        )
        for task in done:
            task.result()  # surface ConnectionClosed (or a real error)
    except websockets.ConnectionClosed:
        pass
    finally:
        _ws_clients.discard(q)
        log(f'client disconnected ({len(_ws_clients)} remaining)')
        sender.cancel()
        reader.cancel()
        # Collect both tasks so no exception is left unretrieved.
        await asyncio.gather(sender, reader, return_exceptions=True)
|
||||||
|
|
||||||
|
|
||||||
|
async def ws_main():
    """Run the WebSocket server until the surrounding event loop is stopped.

    Publishes the running loop in the module-level ``_ws_loop`` so that
    ``emit`` can schedule work onto it from other threads, then listens on
    every interface (empty host) at ``PORT``.
    """
    global _ws_loop
    _ws_loop = asyncio.get_running_loop()

    server = websockets.serve(ws_handler, '', PORT)
    async with server:
        log(f'WebSocket listening on port {PORT}')
        # A future that is never resolved: parks this coroutine for good
        # while the server keeps accepting connections.
        never_done = asyncio.Future()
        await never_done
|
||||||
|
|
||||||
|
|
||||||
|
def start_ws_server():
    """Thread entry point: run the WebSocket server on its own event loop.

    Blocks for as long as ``ws_main`` runs, so it is meant to be the target
    of a dedicated thread, not called from the main thread.
    """
    server_coro = ws_main()
    asyncio.run(server_coro)
|
||||||
|
|
||||||
|
|
||||||
|
# --- Mic ---
|
||||||
|
|
||||||
def find_mic():
|
def find_mic():
|
||||||
candidates = [
|
candidates = [
|
||||||
@@ -63,6 +112,8 @@ def s16le_to_f32(data):
|
|||||||
return np.frombuffer(data, dtype=np.int16).astype(np.float32) / 32768.0
|
return np.frombuffer(data, dtype=np.int16).astype(np.float32) / 32768.0
|
||||||
|
|
||||||
|
|
||||||
|
# --- Args + model loading ---
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('--model', default='base.en')
|
parser.add_argument('--model', default='base.en')
|
||||||
parser.add_argument('--device', default='cuda')
|
parser.add_argument('--device', default='cuda')
|
||||||
@@ -87,7 +138,8 @@ vad = VADIterator(vad_model, sampling_rate=SAMPLE_RATE,
|
|||||||
log('VAD ready')
|
log('VAD ready')
|
||||||
|
|
||||||
|
|
||||||
# Ring buffer for pre-roll context
|
# --- Pre-roll ring buffer ---
|
||||||
|
|
||||||
history = np.zeros(HISTORY_SAMPLES, dtype=np.float32)
|
history = np.zeros(HISTORY_SAMPLES, dtype=np.float32)
|
||||||
history_pos = 0
|
history_pos = 0
|
||||||
|
|
||||||
@@ -95,7 +147,6 @@ def push_history(samples):
|
|||||||
global history_pos
|
global history_pos
|
||||||
n = len(samples)
|
n = len(samples)
|
||||||
base = history_pos % HISTORY_SAMPLES
|
base = history_pos % HISTORY_SAMPLES
|
||||||
# May wrap around — handle both cases
|
|
||||||
space = HISTORY_SAMPLES - base
|
space = HISTORY_SAMPLES - base
|
||||||
if n <= space:
|
if n <= space:
|
||||||
history[base:base + n] = samples
|
history[base:base + n] = samples
|
||||||
@@ -113,7 +164,8 @@ def get_preroll():
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
# Transcription runs in a separate thread so VAD is never blocked by GPU
|
# --- Transcription thread ---
|
||||||
|
|
||||||
transcription_queue = queue.Queue()
|
transcription_queue = queue.Queue()
|
||||||
|
|
||||||
def transcription_worker():
|
def transcription_worker():
|
||||||
@@ -152,9 +204,11 @@ def transcription_worker():
|
|||||||
|
|
||||||
|
|
||||||
threading.Thread(target=transcription_worker, daemon=True).start()
|
threading.Thread(target=transcription_worker, daemon=True).start()
|
||||||
|
threading.Thread(target=start_ws_server, daemon=True).start()
|
||||||
|
|
||||||
|
|
||||||
# Main recording + VAD loop
|
# --- Main recording + VAD loop ---
|
||||||
|
|
||||||
cmd, cmd_args = find_mic()
|
cmd, cmd_args = find_mic()
|
||||||
log(f'mic: {cmd} {" ".join(cmd_args)}')
|
log(f'mic: {cmd} {" ".join(cmd_args)}')
|
||||||
mic = subprocess.Popen([cmd] + cmd_args, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL)
|
mic = subprocess.Popen([cmd] + cmd_args, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL)
|
||||||
|
|||||||
Reference in New Issue
Block a user