Merge pull request 'WebSocket server, language/task args, verbose flag, misc improvements' (#2) from mikael-lovqvists-claude-agent/stt-server:websocket-server into main

Reviewed-on: #2
This commit was merged in pull request #2.
This commit is contained in:
2026-06-07 09:27:01 +00:00
7 changed files with 209 additions and 36 deletions

9
NOTES.md Normal file
View File

@@ -0,0 +1,9 @@
# Notes
## TranscriptionInfo — unused fields
`model.transcribe()` returns a `TranscriptionInfo` object as its second value. We currently use `language` and `language_probability`. Other available fields:
- **`all_language_probs`** — full ranked list of `(language, probability)` tuples for the segment. Useful for debugging misdetection — e.g. when the model hallucinates Sinhala on noise, this would show Sinhala at the top with a high probability. Could be included in transcript events or exposed as a diagnostic endpoint.
- **`duration`** — total audio duration fed to the model.
- **`duration_after_vad`** — speech duration according to Whisper's internal VAD (not meaningful since we pass `vad_filter=False`).

View File

@@ -16,3 +16,47 @@ This project started as a [vibe-coded](https://en.wikipedia.org/wiki/Vibe_coding
### Setup [venv](https://docs.python.org/3/library/venv.html) for [python](https://www.python.org/) ### Setup [venv](https://docs.python.org/3/library/venv.html) for [python](https://www.python.org/)
We will have two different setups here depending on if you want to build ctranslate2 locally or not. This shall be documented. We will have two different setups here depending on if you want to build ctranslate2 locally or not. This shall be documented.
## Model selection
Pass `--model <name>` to `stt-server.py`. Models are downloaded automatically from HuggingFace on first use.
| Model | VRAM | Quality | Notes |
|-------|------|---------|-------|
| `base.en` | ~0.5 GB (`float16`) / ~1 GB (`float32`) | Low | Default. Fast, but struggles with similar-sounding consonants (V/B/D). |
| `small.en` | ~1 GB (`float16`) / ~2 GB (`float32`) | Medium | Noticeable improvement over base for most speech. |
| `medium.en` | ~2.5 GB (`float16`) / ~5 GB (`float32`) | Good | Recommended starting point for production use. |
| `large-v3` | ~5 GB (`float16`) / ~10 GB (`float32`) | Best | Highest accuracy, use if VRAM allows. |
English-only models (`.en` suffix) are faster and more accurate than multilingual models for English speech.
## Compute type
Pass `--compute-type <type>` to control the numeric precision used during inference.
| Type | Notes |
|------|-------|
| `int8_float16` | Default. Good balance of speed and accuracy on modern GPUs. |
| `float16` | Slightly better accuracy, higher VRAM usage. |
| `int8` | CPU-friendly, lower quality. |
If you see a CUDA error about mismatched library versions at startup, use `setup-venv-local-build.sh` to build ctranslate2 against your system CUDA version rather than using the PyPI wheel.
## Language and translation
By default the server auto-detects the spoken language and transcribes it as-is.
| Argument | Default | Notes |
|----------|---------|-------|
| `--language <code>` | none (auto-detect) | Force a specific language, e.g. `--language en` or `--language sv`. Speeds up detection and avoids misidentification. |
| `--task transcribe` | default | Output text in the spoken language. |
| `--task translate` | | Translate speech to English regardless of source language. |
> [!NOTE]
> The `.en` model variants (`base.en`, `small.en` etc.) are English-only and do not support `--task translate` or non-English `--language`. Use a multilingual model (`large-v3`, `medium`) for multilingual or translation use cases.
> [!WARNING]
> Omitting `--language` with a multilingual model and English-only speech may cause occasional misdetection. Pass `--language en` to avoid this if you only speak English.

24
examples/listen.mjs Normal file
View File

@@ -0,0 +1,24 @@
// Connect to the STT server and print all events.
// Usage: node listen.mjs
const PORT = process.env.STT_PORT ?? '11501'
const ws = new WebSocket(`ws://localhost:${PORT}`)
ws.addEventListener('open', () => {
process.stderr.write(`connected to ws://localhost:${PORT}\n`)
})
ws.addEventListener('message', ({ data }) => {
const event = JSON.parse(data)
console.log(event)
})
ws.addEventListener('close', () => {
process.stderr.write('disconnected\n')
process.exit(0)
})
ws.addEventListener('error', (err) => {
process.stderr.write(`error: ${err.message}\n`)
process.exit(1)
})

19
examples/transcripts.mjs Normal file
View File

@@ -0,0 +1,19 @@
// Connect to the STT server and print transcript text only.
// Usage: node transcripts.mjs
const PORT = process.env.STT_PORT ?? '11501'
const ws = new WebSocket(`ws://localhost:${PORT}`)
ws.addEventListener('open', () => {
process.stderr.write(`connected to ws://localhost:${PORT}\n`)
})
ws.addEventListener('message', ({ data }) => {
const event = JSON.parse(data)
if (event.event === 'transcript') {
console.log(event.text)
}
})
ws.addEventListener('close', () => process.exit(0))
ws.addEventListener('error', () => process.exit(1))

View File

@@ -57,7 +57,7 @@ fi
echo "==> upgrading pip + build tools" echo "==> upgrading pip + build tools"
"${VENV}/bin/pip" install --upgrade pip wheel setuptools pybind11 --quiet "${VENV}/bin/pip" install --upgrade pip wheel setuptools pybind11 --quiet
"${VENV}/bin/pip" install torch silero-vad "${VENV}/bin/pip" install torch silero-vad websockets
# --- clone (skipped if already done) --- # --- clone (skipped if already done) ---
if [ ! -d "${BUILD_DIR}/src/.git" ]; then if [ ! -d "${BUILD_DIR}/src/.git" ]; then

View File

@@ -25,7 +25,7 @@ fi
echo "==> installing torch and faster-whisper" echo "==> installing torch and faster-whisper"
"${VENV}/bin/pip" install --upgrade pip --quiet "${VENV}/bin/pip" install --upgrade pip --quiet
"${VENV}/bin/pip" install torch faster-whisper silero-vad "${VENV}/bin/pip" install torch faster-whisper silero-vad websockets
echo "" echo ""
echo "==> done. Venv ready at ${VENV}" echo "==> done. Venv ready at ${VENV}"

View File

@@ -1,8 +1,9 @@
#!/usr/bin/env -S bash -c 'exec "$(dirname "$0")/venv/bin/python3" "$0" "$@"' #!/usr/bin/env -S bash -c 'exec "$(dirname "$0")/venv/bin/python3" "$0" "$@"'
""" """
STT process: records audio, runs Silero VAD, transcribes with faster-whisper. STT server: records audio, runs Silero VAD, transcribes with faster-whisper.
Broadcasts JSON events to all connected WebSocket clients and to stdout.
Events (JSON lines on stdout): Events:
{"event": "ready"} {"event": "ready"}
{"event": "vad_start"} {"event": "vad_start"}
{"event": "vad_end", "duration": 1.23} {"event": "vad_end", "duration": 1.23}
@@ -11,13 +12,22 @@ Events (JSON lines on stdout):
word format: {"word": "hello", "start": 0.12, "end": 0.45, "probability": 0.99} word format: {"word": "hello", "start": 0.12, "end": 0.45, "probability": 0.99}
All log/status messages go to stderr. Stdout is machine-readable events only. Every WebSocket connection receives the full event stream from the moment it
connects — no subscription handshake required.
Machine-readable events are sent over WebSocket only.
Pass --verbose to enable logging to stderr (startup, VAD events, transcripts).
Errors always go to stderr regardless of verbosity.
Environment:
STT_PORT WebSocket port (default: 11501)
Usage: Usage:
./stt-server.py ./stt-server.py
./stt-server.py --model large-v3 --device cuda --compute-type int8_float16 ./stt-server.py --model large-v3 --device cuda --compute-type int8_float16 --verbose
""" """
import os
import sys import sys
import json import json
import signal import signal
@@ -26,24 +36,65 @@ import threading
import queue import queue
import subprocess import subprocess
import traceback import traceback
import asyncio
import websockets
import numpy as np import numpy as np
import torch import torch
SAMPLE_RATE = 16000 SAMPLE_RATE = 16000
VAD_WINDOW = 512 # samples per VAD chunk (32ms at 16kHz) VAD_WINDOW = 512 # samples per VAD chunk (32ms at 16kHz)
PRE_ROLL_SAMPLES = 3200 # 0.2s of audio prepended to each segment PRE_ROLL_SAMPLES = 3200 # 0.2s prepended to each segment for context
HISTORY_SAMPLES = 960000 # 60s ring buffer for pre-roll HISTORY_SAMPLES = 960000 # 60s ring buffer for pre-roll
PORT = int(__import__('os').environ.get('STT_PORT', 11501))
def log(msg): def log(msg, error=False):
if error or verbose:
sys.stderr.write(f'[stt] {msg}\n') sys.stderr.write(f'[stt] {msg}\n')
sys.stderr.flush() sys.stderr.flush()
def emit(event): # --- WebSocket broadcast ---
sys.stdout.write(json.dumps(event) + '\n')
sys.stdout.flush()
_ws_loop = None
_ws_clients = set() # set of asyncio.Queue, one per connection
def emit(event):
line = json.dumps(event)
if _ws_loop is not None:
for q in list(_ws_clients):
_ws_loop.call_soon_threadsafe(q.put_nowait, line)
async def ws_handler(websocket):
q = asyncio.Queue()
_ws_clients.add(q)
log(f'client connected ({len(_ws_clients)} total)')
try:
while True:
msg = await q.get()
await websocket.send(msg)
except websockets.ConnectionClosed:
pass
finally:
_ws_clients.discard(q)
log(f'client disconnected ({len(_ws_clients)} remaining)')
async def ws_main():
global _ws_loop
_ws_loop = asyncio.get_running_loop()
async with websockets.serve(ws_handler, '', PORT):
log(f'WebSocket listening on port {PORT}')
await asyncio.Future() # run forever
def start_ws_server():
asyncio.run(ws_main())
# --- Mic ---
def find_mic(): def find_mic():
candidates = [ candidates = [
@@ -63,19 +114,39 @@ def s16le_to_f32(data):
return np.frombuffer(data, dtype=np.int16).astype(np.float32) / 32768.0 return np.frombuffer(data, dtype=np.int16).astype(np.float32) / 32768.0
# --- Args + model loading ---
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--model', default='base.en') parser.add_argument('--model', default='base.en')
parser.add_argument('--device', default='cuda') parser.add_argument('--device', default='cuda')
parser.add_argument('--compute-type', default='int8_float16') parser.add_argument('--compute-type', default='int8_float16')
parser.add_argument('--language', default=None, help='language code (e.g. en, sv) or None for auto-detect')
parser.add_argument('--task', default='transcribe', choices=['transcribe', 'translate'], help='transcribe keeps the source language; translate converts to English')
parser.add_argument('--verbose', '-v', action='store_true')
args = parser.parse_args() args = parser.parse_args()
verbose = args.verbose
token_file = os.environ.get('HF_TOKEN_FILE', os.path.expanduser('~/.secrets/hugging-face.token'))
try:
with open(token_file) as f:
os.environ['HF_TOKEN'] = f.read().strip()
except FileNotFoundError:
pass
from faster_whisper import WhisperModel
from huggingface_hub import snapshot_download
try:
snapshot_download(f'Systran/faster-whisper-{args.model}', local_files_only=True)
except Exception:
log(f'downloading model {args.model}...')
log(f'loading faster-whisper {args.model} ({args.device}, {args.compute_type})...') log(f'loading faster-whisper {args.model} ({args.device}, {args.compute_type})...')
from faster_whisper import WhisperModel
try: try:
model = WhisperModel(args.model, device=args.device, compute_type=args.compute_type) model = WhisperModel(args.model, device=args.device, compute_type=args.compute_type)
log(f'model ready on {args.device}') log(f'model ready on {args.device}')
except Exception as e: except Exception as e:
log(f'{args.device} failed ({e}), falling back to cpu') log(f'{args.device} failed ({e}), falling back to cpu', error=True)
model = WhisperModel(args.model, device='cpu', compute_type='int8') model = WhisperModel(args.model, device='cpu', compute_type='int8')
log('model ready on cpu') log('model ready on cpu')
@@ -87,7 +158,8 @@ vad = VADIterator(vad_model, sampling_rate=SAMPLE_RATE,
log('VAD ready') log('VAD ready')
# Ring buffer for pre-roll context # --- Pre-roll ring buffer ---
history = np.zeros(HISTORY_SAMPLES, dtype=np.float32) history = np.zeros(HISTORY_SAMPLES, dtype=np.float32)
history_pos = 0 history_pos = 0
@@ -95,7 +167,6 @@ def push_history(samples):
global history_pos global history_pos
n = len(samples) n = len(samples)
base = history_pos % HISTORY_SAMPLES base = history_pos % HISTORY_SAMPLES
# May wrap around — handle both cases
space = HISTORY_SAMPLES - base space = HISTORY_SAMPLES - base
if n <= space: if n <= space:
history[base:base + n] = samples history[base:base + n] = samples
@@ -113,7 +184,8 @@ def get_preroll():
return out return out
# Transcription runs in a separate thread so VAD is never blocked by GPU # --- Transcription thread ---
transcription_queue = queue.Queue() transcription_queue = queue.Queue()
def transcription_worker(): def transcription_worker():
@@ -123,9 +195,10 @@ def transcription_worker():
break break
samples, duration = item samples, duration = item
try: try:
segments, _ = model.transcribe( segments, info = model.transcribe(
samples, samples,
language='en', language=args.language,
task=args.task,
word_timestamps=True, word_timestamps=True,
vad_filter=False, vad_filter=False,
) )
@@ -140,21 +213,25 @@ def transcription_worker():
'end': round(float(w.end), 4), 'end': round(float(w.end), 4),
'probability': round(float(w.probability), 4), 'probability': round(float(w.probability), 4),
}) })
log(f'transcript: {json.dumps(text.strip())} ({len(words)} words)') language = info.language
lang_prob = round(float(info.language_probability), 3)
log(f'transcript [{language} {lang_prob}]: {json.dumps(text.strip())} ({len(words)} words)')
if text.strip(): if text.strip():
emit({'event': 'transcript', 'text': text.strip(), 'words': words, 'duration': round(duration, 3)}) emit({'event': 'transcript', 'text': text.strip(), 'words': words, 'duration': round(duration, 3), 'language': language, 'language_probability': lang_prob})
except Exception: except Exception:
msg = traceback.format_exc() msg = traceback.format_exc()
log(f'transcription error:\n{msg}') log(f'transcription error:\n{msg}', error=True)
emit({'event': 'error', 'message': msg}) emit({'event': 'error', 'message': msg})
finally: finally:
transcription_queue.task_done() transcription_queue.task_done()
threading.Thread(target=transcription_worker, daemon=True).start() threading.Thread(target=transcription_worker, daemon=True).start()
threading.Thread(target=start_ws_server, daemon=True).start()
# Main recording + VAD loop # --- Main recording + VAD loop ---
cmd, cmd_args = find_mic() cmd, cmd_args = find_mic()
log(f'mic: {cmd} {" ".join(cmd_args)}') log(f'mic: {cmd} {" ".join(cmd_args)}')
mic = subprocess.Popen([cmd] + cmd_args, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL) mic = subprocess.Popen([cmd] + cmd_args, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL)