Merge pull request 'WebSocket server, language/task args, verbose flag, misc improvements' (#2) from mikael-lovqvists-claude-agent/stt-server:websocket-server into main
Reviewed-on: #2
This commit was merged in pull request #2.
This commit is contained in:
9
NOTES.md
Normal file
9
NOTES.md
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
# Notes
|
||||||
|
|
||||||
|
## TranscriptionInfo — unused fields
|
||||||
|
|
||||||
|
`model.transcribe()` returns a `TranscriptionInfo` object as its second value. We currently use `language` and `language_probability`. Other available fields:
|
||||||
|
|
||||||
|
- **`all_language_probs`** — full ranked list of `(language, probability)` tuples for the segment. Useful for debugging misdetection — e.g. when the model hallucinates Sinhala on noise, this would show Sinhala at the top with a high probability. Could be included in transcript events or exposed as a diagnostic endpoint.
|
||||||
|
- **`duration`** — total audio duration fed to the model.
|
||||||
|
- **`duration_after_vad`** — speech duration according to Whisper's internal VAD (not meaningful since we pass `vad_filter=False`).
|
||||||
44
README.md
44
README.md
@@ -16,3 +16,47 @@ This project started as a [vibe-coded](https://en.wikipedia.org/wiki/Vibe_coding
|
|||||||
### Setup [venv](https://docs.python.org/3/library/venv.html) for [python](https://www.python.org/)
|
### Setup [venv](https://docs.python.org/3/library/venv.html) for [python](https://www.python.org/)
|
||||||
|
|
||||||
We will have two different setups here depending on if you want to build ctranslate2 locally or not. This shall be documented.
|
We will have two different setups here depending on if you want to build ctranslate2 locally or not. This shall be documented.
|
||||||
|
|
||||||
|
|
||||||
|
## Model selection
|
||||||
|
|
||||||
|
Pass `--model <name>` to `stt-server.py`. Models are downloaded automatically from HuggingFace on first use.
|
||||||
|
|
||||||
|
| Model | VRAM | Quality | Notes |
|
||||||
|
|-------|------|---------|-------|
|
||||||
|
| `base.en` | ~0.5 GB (`float16`) / ~1 GB (`float32`) | Low | Default. Fast, but struggles with similar-sounding consonants (V/B/D). |
|
||||||
|
| `small.en` | ~1 GB (`float16`) / ~2 GB (`float32`) | Medium | Noticeable improvement over base for most speech. |
|
||||||
|
| `medium.en` | ~2.5 GB (`float16`) / ~5 GB (`float32`) | Good | Recommended starting point for production use. |
|
||||||
|
| `large-v3` | ~5 GB (`float16`) / ~10 GB (`float32`) | Best | Highest accuracy, use if VRAM allows. |
|
||||||
|
|
||||||
|
English-only models (`.en` suffix) are faster and more accurate than multilingual models for English speech.
|
||||||
|
|
||||||
|
|
||||||
|
## Compute type
|
||||||
|
|
||||||
|
Pass `--compute-type <type>` to control the numeric precision used during inference.
|
||||||
|
|
||||||
|
| Type | Notes |
|
||||||
|
|------|-------|
|
||||||
|
| `int8_float16` | Default. Good balance of speed and accuracy on modern GPUs. |
|
||||||
|
| `float16` | Slightly better accuracy, higher VRAM usage. |
|
||||||
|
| `int8` | CPU-friendly, lower quality. |
|
||||||
|
|
||||||
|
If you see a CUDA error about mismatched library versions at startup, use `setup-venv-local-build.sh` to build ctranslate2 against your system CUDA version rather than using the PyPI wheel.
|
||||||
|
|
||||||
|
|
||||||
|
## Language and translation
|
||||||
|
|
||||||
|
By default the server auto-detects the spoken language and transcribes it as-is.
|
||||||
|
|
||||||
|
| Argument | Default | Notes |
|
||||||
|
|----------|---------|-------|
|
||||||
|
| `--language <code>` | none (auto-detect) | Force a specific language, e.g. `--language en` or `--language sv`. Skips language detection and avoids misidentification. |
|
||||||
|
| `--task transcribe` | default | Output text in the spoken language. |
|
||||||
|
| `--task translate` | | Translate speech to English regardless of source language. |
|
||||||
|
|
||||||
|
> [!NOTE]
|
||||||
|
> The `.en` model variants (`base.en`, `small.en` etc.) are English-only and do not support `--task translate` or non-English `--language`. Use a multilingual model (`large-v3`, `medium`) for multilingual or translation use cases.
|
||||||
|
|
||||||
|
> [!WARNING]
|
||||||
|
> Omitting `--language` with a multilingual model and English-only speech may cause occasional misdetection. Pass `--language en` to avoid this if you only speak English.
|
||||||
24
examples/listen.mjs
Normal file
24
examples/listen.mjs
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
// Connect to the STT server and print all events.
// Usage: node listen.mjs
//
// The port is read from the STT_PORT environment variable (default 11501).
// Uses the global WebSocket constructor; nothing is imported.

const PORT = process.env.STT_PORT ?? '11501'
const ws = new WebSocket(`ws://localhost:${PORT}`)

// Status messages go to stderr so stdout carries only the events.
ws.addEventListener('open', () => {
  process.stderr.write(`connected to ws://localhost:${PORT}\n`)
})

// Each message is parsed as JSON and printed as-is.
ws.addEventListener('message', ({ data }) => {
  const event = JSON.parse(data)
  console.log(event)
})

ws.addEventListener('close', () => {
  process.stderr.write('disconnected\n')
  process.exit(0)
})

ws.addEventListener('error', (err) => {
  // An error event may carry an empty or missing `message`; fall back to the
  // wrapped error (if any) and finally to a fixed string so the output is
  // never "error: undefined".
  const reason = err.message || err.error?.message || 'connection failed'
  process.stderr.write(`error: ${reason}\n`)
  process.exit(1)
})
|
||||||
19
examples/transcripts.mjs
Normal file
19
examples/transcripts.mjs
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
// Connect to the STT server and print transcript text only.
// Usage: node transcripts.mjs
//
// The port is read from the STT_PORT environment variable (default 11501).
// Uses the global WebSocket constructor; nothing is imported.

const PORT = process.env.STT_PORT ?? '11501'
const ws = new WebSocket(`ws://localhost:${PORT}`)

// Status messages go to stderr so stdout carries only transcript text.
ws.addEventListener('open', () => {
  process.stderr.write(`connected to ws://localhost:${PORT}\n`)
})

// Events other than `transcript` are ignored.
ws.addEventListener('message', ({ data }) => {
  const event = JSON.parse(data)
  if (event.event === 'transcript') {
    console.log(event.text)
  }
})

ws.addEventListener('close', () => process.exit(0))

ws.addEventListener('error', (err) => {
  // Report the failure instead of exiting silently with status 1. The error
  // event may carry an empty or missing `message`, hence the fallbacks.
  const reason = err.message || err.error?.message || 'connection failed'
  process.stderr.write(`error: ${reason}\n`)
  process.exit(1)
})
|
||||||
@@ -57,7 +57,7 @@ fi
|
|||||||
|
|
||||||
echo "==> upgrading pip + build tools"
|
echo "==> upgrading pip + build tools"
|
||||||
"${VENV}/bin/pip" install --upgrade pip wheel setuptools pybind11 --quiet
|
"${VENV}/bin/pip" install --upgrade pip wheel setuptools pybind11 --quiet
|
||||||
"${VENV}/bin/pip" install torch silero-vad
|
"${VENV}/bin/pip" install torch silero-vad websockets
|
||||||
|
|
||||||
# --- clone (skipped if already done) ---
|
# --- clone (skipped if already done) ---
|
||||||
if [ ! -d "${BUILD_DIR}/src/.git" ]; then
|
if [ ! -d "${BUILD_DIR}/src/.git" ]; then
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ fi
|
|||||||
|
|
||||||
echo "==> installing torch and faster-whisper"
|
echo "==> installing torch and faster-whisper"
|
||||||
"${VENV}/bin/pip" install --upgrade pip --quiet
|
"${VENV}/bin/pip" install --upgrade pip --quiet
|
||||||
"${VENV}/bin/pip" install torch faster-whisper silero-vad
|
"${VENV}/bin/pip" install torch faster-whisper silero-vad websockets
|
||||||
|
|
||||||
echo ""
|
echo ""
|
||||||
echo "==> done. Venv ready at ${VENV}"
|
echo "==> done. Venv ready at ${VENV}"
|
||||||
|
|||||||
117
stt-server.py
117
stt-server.py
@@ -1,8 +1,9 @@
|
|||||||
#!/usr/bin/env -S bash -c 'exec "$(dirname "$0")/venv/bin/python3" "$0" "$@"'
|
#!/usr/bin/env -S bash -c 'exec "$(dirname "$0")/venv/bin/python3" "$0" "$@"'
|
||||||
"""
|
"""
|
||||||
STT process: records audio, runs Silero VAD, transcribes with faster-whisper.
|
STT server: records audio, runs Silero VAD, transcribes with faster-whisper.
|
||||||
|
Broadcasts JSON events to all connected WebSocket clients.
|
||||||
|
|
||||||
Events (JSON lines on stdout):
|
Events:
|
||||||
{"event": "ready"}
|
{"event": "ready"}
|
||||||
{"event": "vad_start"}
|
{"event": "vad_start"}
|
||||||
{"event": "vad_end", "duration": 1.23}
|
{"event": "vad_end", "duration": 1.23}
|
||||||
@@ -11,13 +12,22 @@ Events (JSON lines on stdout):
|
|||||||
|
|
||||||
word format: {"word": "hello", "start": 0.12, "end": 0.45, "probability": 0.99}
|
word format: {"word": "hello", "start": 0.12, "end": 0.45, "probability": 0.99}
|
||||||
|
|
||||||
All log/status messages go to stderr. Stdout is machine-readable events only.
|
Every WebSocket connection receives the full event stream from the moment it
|
||||||
|
connects — no subscription handshake required.
|
||||||
|
|
||||||
|
Machine-readable events are sent over WebSocket only.
|
||||||
|
Pass --verbose to enable logging to stderr (startup, VAD events, transcripts).
|
||||||
|
Errors always go to stderr regardless of verbosity.
|
||||||
|
|
||||||
|
Environment:
|
||||||
|
STT_PORT WebSocket port (default: 11501)
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
./stt-server.py
|
./stt-server.py
|
||||||
./stt-server.py --model large-v3 --device cuda --compute-type int8_float16
|
./stt-server.py --model large-v3 --device cuda --compute-type int8_float16 --verbose
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
import sys
|
import sys
|
||||||
import json
|
import json
|
||||||
import signal
|
import signal
|
||||||
@@ -26,24 +36,65 @@ import threading
|
|||||||
import queue
|
import queue
|
||||||
import subprocess
|
import subprocess
|
||||||
import traceback
|
import traceback
|
||||||
|
import asyncio
|
||||||
|
import websockets
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
SAMPLE_RATE = 16000
|
SAMPLE_RATE = 16000
|
||||||
VAD_WINDOW = 512 # samples per VAD chunk (32ms at 16kHz)
|
VAD_WINDOW = 512 # samples per VAD chunk (32ms at 16kHz)
|
||||||
PRE_ROLL_SAMPLES = 3200 # 0.2s of audio prepended to each segment
|
PRE_ROLL_SAMPLES = 3200 # 0.2s prepended to each segment for context
|
||||||
HISTORY_SAMPLES = 960000 # 60s ring buffer for pre-roll
|
HISTORY_SAMPLES = 960000 # 60s ring buffer for pre-roll
|
||||||
|
# WebSocket port; overridable through the STT_PORT environment variable.
PORT             = int(os.environ.get('STT_PORT', '11501'))
|
||||||
|
|
||||||
|
|
||||||
def log(msg):
|
def log(msg, error=False):
|
||||||
|
if error or verbose:
|
||||||
sys.stderr.write(f'[stt] {msg}\n')
|
sys.stderr.write(f'[stt] {msg}\n')
|
||||||
sys.stderr.flush()
|
sys.stderr.flush()
|
||||||
|
|
||||||
|
|
||||||
def emit(event):
|
# --- WebSocket broadcast ---
|
||||||
sys.stdout.write(json.dumps(event) + '\n')
|
|
||||||
sys.stdout.flush()
|
|
||||||
|
|
||||||
|
_ws_loop = None
|
||||||
|
_ws_clients = set() # set of asyncio.Queue, one per connection
|
||||||
|
|
||||||
|
|
||||||
|
def emit(event):
    """Serialize `event` to JSON and queue it for every registered client.

    The event is always serialized first. If the WebSocket loop has not been
    published yet (`_ws_loop` is None) nothing else happens. Otherwise each
    queue currently in `_ws_clients` receives the JSON string through the
    loop's `call_soon_threadsafe`.
    """
    payload = json.dumps(event)
    loop = _ws_loop
    if loop is None:
        return
    # Iterate over a snapshot so changes to the set during the loop are harmless.
    for client_queue in tuple(_ws_clients):
        loop.call_soon_threadsafe(client_queue.put_nowait, payload)
|
||||||
|
|
||||||
|
|
||||||
|
async def ws_handler(websocket):
    """Serve one WebSocket connection.

    Registers a fresh asyncio.Queue in `_ws_clients`, then forwards every item
    taken from that queue to `websocket` until sending raises
    `websockets.ConnectionClosed`. The queue is always unregistered on exit.
    """
    outbox = asyncio.Queue()
    _ws_clients.add(outbox)
    log(f'client connected ({len(_ws_clients)} total)')
    try:
        while True:
            await websocket.send(await outbox.get())
    except websockets.ConnectionClosed:
        # A closed connection is the normal way for this handler to finish.
        pass
    finally:
        _ws_clients.discard(outbox)
        log(f'client disconnected ({len(_ws_clients)} remaining)')
|
||||||
|
|
||||||
|
|
||||||
|
async def ws_main():
    """Publish the running event loop in `_ws_loop` and serve WebSocket clients.

    Starts `websockets.serve` with `ws_handler` on `PORT` and then waits on a
    future that is never resolved, so the coroutine only ends by cancellation.
    """
    global _ws_loop
    _ws_loop = asyncio.get_running_loop()
    # NOTE(review): the empty host string presumably binds all interfaces - confirm.
    server = websockets.serve(ws_handler, '', PORT)
    async with server:
        log(f'WebSocket listening on port {PORT}')
        forever = asyncio.Future()  # never completed: run forever
        await forever
|
||||||
|
|
||||||
|
|
||||||
|
def start_ws_server():
    """Run `ws_main` to completion on a new event loop via `asyncio.run`."""
    server_coro = ws_main()
    asyncio.run(server_coro)
|
||||||
|
|
||||||
|
|
||||||
|
# --- Mic ---
|
||||||
|
|
||||||
def find_mic():
|
def find_mic():
|
||||||
candidates = [
|
candidates = [
|
||||||
@@ -63,19 +114,39 @@ def s16le_to_f32(data):
|
|||||||
return np.frombuffer(data, dtype=np.int16).astype(np.float32) / 32768.0
|
return np.frombuffer(data, dtype=np.int16).astype(np.float32) / 32768.0
|
||||||
|
|
||||||
|
|
||||||
|
# --- Args + model loading ---
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('--model', default='base.en')
|
parser.add_argument('--model', default='base.en')
|
||||||
parser.add_argument('--device', default='cuda')
|
parser.add_argument('--device', default='cuda')
|
||||||
parser.add_argument('--compute-type', default='int8_float16')
|
parser.add_argument('--compute-type', default='int8_float16')
|
||||||
|
# Leaving --language out keeps the default of None (auto-detect); the literal
# text "None" on the command line would be passed through as a string.
parser.add_argument('--language', default=None, help='language code (e.g. en, sv); omit to auto-detect')
|
||||||
|
parser.add_argument('--task', default='transcribe', choices=['transcribe', 'translate'], help='transcribe keeps the source language; translate converts to English')
|
||||||
|
parser.add_argument('--verbose', '-v', action='store_true')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
verbose = args.verbose
|
||||||
|
|
||||||
|
token_file = os.environ.get('HF_TOKEN_FILE', os.path.expanduser('~/.secrets/hugging-face.token'))
|
||||||
|
try:
|
||||||
|
with open(token_file) as f:
|
||||||
|
os.environ['HF_TOKEN'] = f.read().strip()
|
||||||
|
except FileNotFoundError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
from faster_whisper import WhisperModel
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
|
try:
|
||||||
|
snapshot_download(f'Systran/faster-whisper-{args.model}', local_files_only=True)
|
||||||
|
except Exception:
|
||||||
|
log(f'downloading model {args.model}...')
|
||||||
|
|
||||||
log(f'loading faster-whisper {args.model} ({args.device}, {args.compute_type})...')
|
log(f'loading faster-whisper {args.model} ({args.device}, {args.compute_type})...')
|
||||||
from faster_whisper import WhisperModel
|
|
||||||
try:
|
try:
|
||||||
model = WhisperModel(args.model, device=args.device, compute_type=args.compute_type)
|
model = WhisperModel(args.model, device=args.device, compute_type=args.compute_type)
|
||||||
log(f'model ready on {args.device}')
|
log(f'model ready on {args.device}')
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log(f'{args.device} failed ({e}), falling back to cpu')
|
log(f'{args.device} failed ({e}), falling back to cpu', error=True)
|
||||||
model = WhisperModel(args.model, device='cpu', compute_type='int8')
|
model = WhisperModel(args.model, device='cpu', compute_type='int8')
|
||||||
log('model ready on cpu')
|
log('model ready on cpu')
|
||||||
|
|
||||||
@@ -87,7 +158,8 @@ vad = VADIterator(vad_model, sampling_rate=SAMPLE_RATE,
|
|||||||
log('VAD ready')
|
log('VAD ready')
|
||||||
|
|
||||||
|
|
||||||
# Ring buffer for pre-roll context
|
# --- Pre-roll ring buffer ---
|
||||||
|
|
||||||
history = np.zeros(HISTORY_SAMPLES, dtype=np.float32)
|
history = np.zeros(HISTORY_SAMPLES, dtype=np.float32)
|
||||||
history_pos = 0
|
history_pos = 0
|
||||||
|
|
||||||
@@ -95,7 +167,6 @@ def push_history(samples):
|
|||||||
global history_pos
|
global history_pos
|
||||||
n = len(samples)
|
n = len(samples)
|
||||||
base = history_pos % HISTORY_SAMPLES
|
base = history_pos % HISTORY_SAMPLES
|
||||||
# May wrap around — handle both cases
|
|
||||||
space = HISTORY_SAMPLES - base
|
space = HISTORY_SAMPLES - base
|
||||||
if n <= space:
|
if n <= space:
|
||||||
history[base:base + n] = samples
|
history[base:base + n] = samples
|
||||||
@@ -113,7 +184,8 @@ def get_preroll():
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
# Transcription runs in a separate thread so VAD is never blocked by GPU
|
# --- Transcription thread ---
|
||||||
|
|
||||||
transcription_queue = queue.Queue()
|
transcription_queue = queue.Queue()
|
||||||
|
|
||||||
def transcription_worker():
|
def transcription_worker():
|
||||||
@@ -123,9 +195,10 @@ def transcription_worker():
|
|||||||
break
|
break
|
||||||
samples, duration = item
|
samples, duration = item
|
||||||
try:
|
try:
|
||||||
segments, _ = model.transcribe(
|
segments, info = model.transcribe(
|
||||||
samples,
|
samples,
|
||||||
language='en',
|
language=args.language,
|
||||||
|
task=args.task,
|
||||||
word_timestamps=True,
|
word_timestamps=True,
|
||||||
vad_filter=False,
|
vad_filter=False,
|
||||||
)
|
)
|
||||||
@@ -140,21 +213,25 @@ def transcription_worker():
|
|||||||
'end': round(float(w.end), 4),
|
'end': round(float(w.end), 4),
|
||||||
'probability': round(float(w.probability), 4),
|
'probability': round(float(w.probability), 4),
|
||||||
})
|
})
|
||||||
log(f'transcript: {json.dumps(text.strip())} ({len(words)} words)')
|
language = info.language
|
||||||
|
lang_prob = round(float(info.language_probability), 3)
|
||||||
|
log(f'transcript [{language} {lang_prob}]: {json.dumps(text.strip())} ({len(words)} words)')
|
||||||
if text.strip():
|
if text.strip():
|
||||||
emit({'event': 'transcript', 'text': text.strip(), 'words': words, 'duration': round(duration, 3)})
|
emit({'event': 'transcript', 'text': text.strip(), 'words': words, 'duration': round(duration, 3), 'language': language, 'language_probability': lang_prob})
|
||||||
except Exception:
|
except Exception:
|
||||||
msg = traceback.format_exc()
|
msg = traceback.format_exc()
|
||||||
log(f'transcription error:\n{msg}')
|
log(f'transcription error:\n{msg}', error=True)
|
||||||
emit({'event': 'error', 'message': msg})
|
emit({'event': 'error', 'message': msg})
|
||||||
finally:
|
finally:
|
||||||
transcription_queue.task_done()
|
transcription_queue.task_done()
|
||||||
|
|
||||||
|
|
||||||
threading.Thread(target=transcription_worker, daemon=True).start()
|
threading.Thread(target=transcription_worker, daemon=True).start()
|
||||||
|
threading.Thread(target=start_ws_server, daemon=True).start()
|
||||||
|
|
||||||
|
|
||||||
# Main recording + VAD loop
|
# --- Main recording + VAD loop ---
|
||||||
|
|
||||||
cmd, cmd_args = find_mic()
|
cmd, cmd_args = find_mic()
|
||||||
log(f'mic: {cmd} {" ".join(cmd_args)}')
|
log(f'mic: {cmd} {" ".join(cmd_args)}')
|
||||||
mic = subprocess.Popen([cmd] + cmd_args, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL)
|
mic = subprocess.Popen([cmd] + cmd_args, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL)
|
||||||
|
|||||||
Reference in New Issue
Block a user