Files
stt-server/stt-server.py

263 lines
7.1 KiB
Python
Executable File

#!/usr/bin/env -S bash -c 'exec "$(dirname "$0")/venv/bin/python3" "$0" "$@"'
"""
STT server: records audio, runs Silero VAD, transcribes with faster-whisper.
Broadcasts JSON events to all connected WebSocket clients (nothing is written to stdout).
Events:
{"event": "ready"}
{"event": "vad_start"}
{"event": "vad_end", "duration": 1.23}
{"event": "transcript", "text": "...", "words": [...], "duration": 1.23}
{"event": "error", "message": "..."}
word format: {"word": "hello", "start": 0.12, "end": 0.45, "probability": 0.99}
Every WebSocket connection receives the full event stream from the moment it
connects — no subscription handshake required.
Machine-readable events are sent over WebSocket only.
Pass --verbose to enable logging to stderr (startup, VAD events, transcripts).
Errors always go to stderr regardless of verbosity.
Environment:
STT_PORT WebSocket port (default: 11501)
Usage:
./stt-server.py
./stt-server.py --model large-v3 --device cuda --compute-type int8_float16 --verbose
"""
import argparse
import asyncio
import json
import os
import queue
import shutil
import signal
import subprocess
import sys
import threading
import traceback

import numpy as np
import torch
import websockets
SAMPLE_RATE = 16000  # Hz; must match the 16 kHz rate requested from the mic command in find_mic()
VAD_WINDOW = 512  # samples per VAD chunk (32 ms at 16 kHz)
PRE_ROLL_SAMPLES = SAMPLE_RATE // 5  # 0.2 s prepended to each segment for context
HISTORY_SAMPLES = 60 * SAMPLE_RATE  # 60 s ring buffer the pre-roll is read from
PORT = int(os.environ.get('STT_PORT', 11501))  # WebSocket port, overridable via STT_PORT
def log(msg, error=False):
    """Write ``msg`` to stderr with an ``[stt]`` prefix and flush immediately.

    Messages with ``error=True`` are always written; all others only when the
    module-level ``verbose`` flag (set from --verbose) is true.
    """
    if not error and not verbose:
        return
    stream = sys.stderr
    stream.write(f'[stt] {msg}\n')
    stream.flush()
# --- WebSocket broadcast ---
_ws_loop = None  # event loop of the WebSocket thread; assigned in ws_main()
_ws_clients = set()  # set of asyncio.Queue, one per connection


def emit(event):
    """Serialize ``event`` to JSON and queue it for every connected client.

    Called from the main VAD loop and the transcription worker. Delivery goes
    through ``call_soon_threadsafe`` on the WebSocket event loop, which puts
    the message into each client's queue. Until that loop has been published
    the serialized event is simply dropped.
    """
    payload = json.dumps(event)
    loop = _ws_loop
    if loop is None:
        return
    # Snapshot the client set; connections register/unregister on the loop thread.
    for client_queue in tuple(_ws_clients):
        loop.call_soon_threadsafe(client_queue.put_nowait, payload)
async def ws_handler(websocket):
    """Serve one WebSocket connection by forwarding every emitted event to it.

    Registers a private asyncio.Queue in ``_ws_clients`` (emit() fills it) and
    sends each queued message until sending fails with ConnectionClosed. The
    queue is always unregistered on the way out. Messages sent by the client
    are never read.
    """
    inbox = asyncio.Queue()
    _ws_clients.add(inbox)
    log(f'client connected ({len(_ws_clients)} total)')
    try:
        while True:
            await websocket.send(await inbox.get())
    except websockets.ConnectionClosed:
        pass  # disconnect is noticed on the next send attempt
    finally:
        _ws_clients.discard(inbox)
        log(f'client disconnected ({len(_ws_clients)} remaining)')
async def ws_main():
    """Run the WebSocket server on PORT until the process exits.

    Publishes the running event loop in ``_ws_loop`` first so emit() can hand
    events over from other threads.
    NOTE(review): host '' binds every interface and ws_handler performs no
    authentication, so anyone who can reach PORT receives the transcripts.
    """
    global _ws_loop
    _ws_loop = asyncio.get_running_loop()
    server = websockets.serve(ws_handler, '', PORT)
    async with server:
        log(f'WebSocket listening on port {PORT}')
        never_done = asyncio.Future()  # nothing ever resolves this future
        await never_done
def start_ws_server():
    """Thread target: drive ws_main() to completion on a fresh asyncio event loop."""
    server_coro = ws_main()
    asyncio.run(server_coro)
# --- Mic ---
def find_mic():
    """Pick the first available raw-audio capture command.

    Both candidates are configured to write 16 kHz, mono, signed 16-bit
    little-endian PCM to stdout; ``parec`` is preferred over ``arecord``.

    Returns:
        (cmd, cmd_args): the executable name and its argument list.

    Raises:
        RuntimeError: if neither command is found on PATH.
    """
    candidates = [
        ('parec', ['--format=s16le', '--rate=16000', '--channels=1', '--latency-msec=50']),
        ('arecord', ['-f', 'S16_LE', '-r', '16000', '-c', '1', '-t', 'raw', '-q']),
    ]
    for cmd, cmd_args in candidates:
        # Search PATH in-process instead of spawning `which`, whose absence
        # used to surface as FileNotFoundError rather than the error below.
        if shutil.which(cmd) is not None:
            return cmd, cmd_args
    raise RuntimeError('no mic capture command found — need parec or arecord')
def s16le_to_f32(data):
    """Convert raw signed 16-bit little-endian PCM bytes to float32 samples.

    Each sample is scaled by 1/32768, so values lie in [-1.0, 1.0). The byte
    order is pinned to little-endian ('<i2') to match the s16le / S16_LE
    capture format regardless of the host's native byte order.

    Raises:
        ValueError: if ``len(data)`` is not a multiple of 2 (raised by numpy).
    """
    pcm = np.frombuffer(data, dtype='<i2')
    return pcm.astype(np.float32) / 32768.0
# --- Args + model loading ---
parser = argparse.ArgumentParser(
    description='STT server: Silero VAD + faster-whisper; JSON events are '
    'broadcast over WebSocket (port from STT_PORT, default 11501).',
)
parser.add_argument('--model', default='base.en',
                    help='faster-whisper model name (default: %(default)s)')
parser.add_argument('--device', default='cuda',
                    help='device to load the model on; falls back to cpu with '
                         'int8 compute if loading fails (default: %(default)s)')
parser.add_argument('--compute-type', default='int8_float16',
                    help='compute type passed to WhisperModel (default: %(default)s)')
parser.add_argument('--verbose', '-v', action='store_true',
                    help='log startup, VAD events and transcripts to stderr')
args = parser.parse_args()
verbose = args.verbose  # read by log()
log(f'loading faster-whisper {args.model} ({args.device}, {args.compute_type})...')
from faster_whisper import WhisperModel


def _load_whisper_model():
    """Load the WhisperModel selected on the command line.

    If construction on the requested device/compute type raises, the failure
    is logged as an error and the model is loaded on cpu with int8 compute.
    """
    try:
        whisper = WhisperModel(args.model, device=args.device, compute_type=args.compute_type)
        log(f'model ready on {args.device}')
    except Exception as exc:
        log(f'{args.device} failed ({exc}), falling back to cpu', error=True)
        whisper = WhisperModel(args.model, device='cpu', compute_type='int8')
        log('model ready on cpu')
    return whisper


model = _load_whisper_model()
log('loading silero VAD...')
from silero_vad import VADIterator, load_silero_vad

vad_model = load_silero_vad()
# Streaming detector fed one VAD_WINDOW chunk per call by the main loop; it
# returns None or a dict carrying a 'start' or 'end' timestamp.
vad = VADIterator(
    vad_model,
    sampling_rate=SAMPLE_RATE,
    threshold=0.5,
    min_silence_duration_ms=500,
)
log('VAD ready')
# --- Pre-roll ring buffer ---
# Circular store of the most recent mic samples. history_pos counts every
# sample ever pushed; absolute sample i lives at history[i % HISTORY_SAMPLES].
history = np.zeros(HISTORY_SAMPLES, dtype=np.float32)
history_pos = 0


def push_history(samples):
    """Append ``samples`` to the ring buffer.

    ``samples`` is a 1-D sequence (the main loop passes float32 arrays of one
    VAD window). If more than HISTORY_SAMPLES arrive in one call, only the
    newest HISTORY_SAMPLES are stored, because the older ones would be
    overwritten anyway. ``history_pos`` still advances by the full length.
    Empty input is a no-op.
    """
    global history_pos
    n = len(samples)
    if n > HISTORY_SAMPLES:
        # Skip the prefix that would be overwritten within this same call.
        history_pos += n - HISTORY_SAMPLES
        samples = samples[n - HISTORY_SAMPLES:]
        n = HISTORY_SAMPLES
    base = history_pos % HISTORY_SAMPLES
    first = min(n, HISTORY_SAMPLES - base)  # portion that fits before wrapping
    history[base:base + first] = samples[:first]
    history[:n - first] = samples[first:]  # wrapped remainder (may be empty)
    history_pos += n


def get_preroll():
    """Return the newest PRE_ROLL_SAMPLES samples as a new float32 array, oldest first.

    Returns fewer samples (possibly none) if fewer than PRE_ROLL_SAMPLES have
    been pushed so far.
    """
    start = max(0, history_pos - PRE_ROLL_SAMPLES)
    # Fancy indexing returns a copy, so later pushes cannot alter the result.
    return history[np.arange(start, history_pos) % HISTORY_SAMPLES]
# --- Transcription thread ---
transcription_queue = queue.Queue()  # (samples, duration) items from the VAD loop; None = stop


def transcription_worker():
    """Transcribe queued speech segments until a ``None`` sentinel arrives.

    Each item is a ``(samples, duration)`` tuple put by the main VAD loop.
    ``samples`` is the float32 audio of one utterance, pre-roll included, and
    ``duration`` is its VAD-measured length in seconds. A non-empty result is
    broadcast as a ``transcript`` event with per-word timings. Any exception
    is logged and broadcast as an ``error`` event, so one bad segment does not
    stop the worker.
    """
    while True:
        item = transcription_queue.get()
        if item is None:
            # Mark the sentinel done too; otherwise transcription_queue.join()
            # would wait forever on a task that is never completed.
            transcription_queue.task_done()
            break
        samples, duration = item
        try:
            segments, _ = model.transcribe(
                samples,
                language='en',
                word_timestamps=True,
                vad_filter=False,  # audio is already VAD-gated by the main loop
            )
            pieces = []
            words = []
            # Consume segments inside the try so errors raised while iterating
            # are reported like any other transcription failure.
            for seg in segments:
                pieces.append(seg.text)
                words.extend(
                    {
                        'word': w.word,
                        'start': round(float(w.start), 4),
                        'end': round(float(w.end), 4),
                        'probability': round(float(w.probability), 4),
                    }
                    for w in (seg.words or [])
                )
            text = ''.join(pieces).strip()
            log(f'transcript: {json.dumps(text)} ({len(words)} words)')
            if text:
                emit({'event': 'transcript', 'text': text, 'words': words, 'duration': round(duration, 3)})
        except Exception:
            msg = traceback.format_exc()
            log(f'transcription error:\n{msg}', error=True)
            emit({'event': 'error', 'message': msg})
        finally:
            transcription_queue.task_done()
# Start the background workers: transcription of queued segments, then the
# WebSocket server that delivers emitted events.
for _worker in (transcription_worker, start_ws_server):
    threading.Thread(target=_worker, daemon=True).start()
del _worker
# --- Main recording + VAD loop ---
cmd, cmd_args = find_mic()
log(f'mic: {cmd} {" ".join(cmd_args)}')
# Raw PCM is read from mic.stdout; the capture tool's own stderr is discarded.
mic = subprocess.Popen(
    [cmd, *cmd_args],
    stdout=subprocess.PIPE,
    stderr=subprocess.DEVNULL,
)
def shutdown(sig=None, frame=None):
    """SIGTERM/SIGINT handler: stop the mic process, post the worker's stop
    sentinel, and exit with status 0.

    ``sig`` and ``frame`` are the standard signal-handler arguments and are
    unused; both default to None so the function can also be called directly.
    """
    mic.terminate()
    transcription_queue.put(None)  # transcription_worker() exits on None
    raise SystemExit(0)
signal.signal(signal.SIGTERM, shutdown)
signal.signal(signal.SIGINT, shutdown)
emit({'event': 'ready'})
FRAME_BYTES = VAD_WINDOW * 2  # one VAD window of s16le audio (2 bytes per sample)
speech_samples = []  # float32 chunks of the current utterance, pre-roll first
speech_start = None  # VAD start time (s) of the current utterance, or None when idle
while True:
    # Read exactly one VAD window. Iterating the binary pipe line by line
    # (`for chunk in mic.stdout`) splits the PCM stream on arbitrary b'\n'
    # bytes, so a long run without 0x0A (e.g. digital silence, all zero
    # bytes) would block VAD processing until one appeared.
    raw = mic.stdout.read(FRAME_BYTES)
    if len(raw) < FRAME_BYTES:
        break  # EOF: capture process exited; any partial frame is dropped
    f32 = s16le_to_f32(raw)
    result = vad(torch.from_numpy(f32), return_seconds=True)
    if result is not None:
        if 'start' in result:
            speech_start = result['start']
            # history does not contain this frame yet, so the pre-roll ends
            # just before it and the append below does not duplicate it.
            speech_samples = [get_preroll()]
            log(f'VAD start at {speech_start:.2f}s')
            emit({'event': 'vad_start'})
        elif 'end' in result and speech_start is not None:
            duration = result['end'] - speech_start
            log(f'VAD end at {result["end"]:.2f}s (duration {duration:.2f}s)')
            emit({'event': 'vad_end', 'duration': round(duration, 3)})
            segment = np.concatenate(speech_samples)
            transcription_queue.put((segment, duration))
            speech_samples = []
            speech_start = None
            vad.reset_states()
    if speech_start is not None:
        speech_samples.append(f32)
    # Record the frame only after any pre-roll for it has been taken.
    push_history(f32)
log(f'mic capture ended (exit code {mic.wait()})', error=True)