From f4ae96c6b99c57f54ac38f4f216ab88405e9178d Mon Sep 17 00:00:00 2001
From: mikael-lovqvists-claude-agent
Date: Tue, 9 Jun 2026 21:32:58 +0000
Subject: [PATCH] Replace HTTP API with WebSocket server

Single port (TTS_PORT) handles both the WS upgrade handshake and
connections. Adds job queue, generation worker, playback events
(queued/started/finished/aborted/error), and abort_current/abort_all
commands. Fixes BrokenPipeError when pacat is killed mid-write.
Updates all examples to use WebSocket; adds abort-demo.mjs.

Co-Authored-By: Claude Sonnet 4.6
---
Reviewer notes for this revision (text between '---' and the first diff
is ignored by git am):
 - terminate: main() now waits on an asyncio.Event that the handler sets.
   Calling loop.stop() under asyncio.run() raises "RuntimeError: Event loop
   stopped before Future completed", so the server exited with a traceback.
 - _ws_handler: an explicit "id": 0 is honoured instead of being replaced
   (`msg.get('id') or _next_id()` treated 0 as missing).
 - _ws_handler: handler exceptions are logged instead of being swallowed
   by `except Exception: pass`.
 - playback_worker: an abort that lands between _abort_flag.clear() and the
   registration of the pacat process now kills the process; before, the job
   played to the end but was still reported as aborted.
 - Docstrings added to the new worker, abort and WebSocket functions.
 - NOTE: this text was recovered from a copy whose line breaks and
   indentation had been collapsed. Indentation, column alignment and blank
   context lines are reconstructed, and the index hashes below still name
   the original blobs; check against the repository before applying.

 chatterbox-server.py     | 332 +++++++++++++++++++++++++++++----------
 examples/abort-demo.mjs  |  34 +++++
 examples/chime.mjs       |  22 ++-
 examples/speak.mjs       |  22 ++-
 examples/terminate.mjs   |  24 ++-
 examples/voice-clone.mjs |  24 +--
 setup-venv.sh            |   2 +-
 7 files changed, 333 insertions(+), 127 deletions(-)
 create mode 100644 examples/abort-demo.mjs

diff --git a/chatterbox-server.py b/chatterbox-server.py
index efde2a4..5aab813 100755
--- a/chatterbox-server.py
+++ b/chatterbox-server.py
@@ -1,18 +1,37 @@
 #!/usr/bin/env -S bash -c 'exec "$(dirname "$0")/venv/bin/python3" "$0" "$@"'
 """
-Chatterbox TTS HTTP server — keeps model loaded, exposes a JSON HTTP API.
+Chatterbox TTS WebSocket server — keeps model loaded, exposes a JSON WebSocket API.
 
-Endpoints:
-  POST /speak    {"text": "...", "temperature": 0.8, "top_p": 0.95, "audio_prompt": "/path.wav"}
-  POST /chime    {"path": "/path/to/file.wav"}
-  POST /preload  {"path": "/path/to/file.wav"}
-  POST /command  {"command": "terminate"}
+Connect to ws://host:TTS_PORT (default ws://localhost:11500).
 
-All endpoints return {"status": "ok"} or {"status": "error", "message": "..."}.
-Responses are sent after audio is queued for playback (not after playback finishes).
+Client → Server:
+  {"type": "speak", "id"?: N, "text": "...", ...generation_opts}
+  {"type": "chime", "id"?: N, "path": "..."}
+  {"type": "preload", "path": "..."}
+  {"type": "abort_current"}  — kill active playback, advance to next queued item
+  {"type": "abort_all"}      — kill active playback + drain all queues
+  {"type": "terminate"}
+
+Server → requesting client:
+  {"status": "ok", "id": N}   (speak/chime)
+  {"status": "ok"}
+  {"status": "error", "message": "..."}
+
+Server → all clients (broadcast):
+  {"event": "queued", "id": N}
+  {"event": "started", "id": N}
+  {"event": "finished", "id": N}
+  {"event": "aborted", "id": N}
+  {"event": "error", "id": N, "message": "..."}
+
+Generation options (speak):
+  temperature, top_p, top_k, repetition_penalty, min_p
+  audio_prompt  — path to reference WAV for voice cloning
+  exaggeration  — 0.0-1.0, full model only
+  cfg_weight    — full model only
 
 Environment:
-  TTS_PORT        TCP port to listen on (default: 11500)
+  TTS_PORT        port to listen on (default: 11500)
   HF_TOKEN_FILE   path to HuggingFace token file (default: ~/.secrets/hugging-face.token)
   HF_HUB_CACHE    path to HuggingFace hub cache (default: ~/.cache/huggingface/hub)
 
@@ -23,9 +42,6 @@ Usage:
 
 Paralinguistic tags supported in text:
   [laugh] [chuckle] [cough] [clear throat] [sigh] [shush] [groan] [sniff] [gasp]
-
-Full model only:
-  exaggeration  0.0-1.0  emotion intensity (ignored in turbo)
 """
 
 import os
@@ -37,7 +53,8 @@ import threading
 import subprocess
 import traceback
 import tempfile
-from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
+import asyncio
+import itertools
 from pathlib import Path
 
 import numpy as np
@@ -106,26 +123,33 @@ _wav_cache = {}
 _chime_cache = {}
 _gen_lock = threading.Lock()
 
-_SENTINEL      = object()
-playback_queue = queue.Queue()
+_SENTINEL = object()
+
+_id_counter = itertools.count(1)
+def _next_id():
+    return next(_id_counter)
+
+_job_queue = queue.Queue()  # dicts: {'id', 'type', ...}
+_playback_queue = queue.Queue()  # dicts: {'id', 'samples'}
+
+_current_proc = None
+_current_proc_lock = threading.Lock()
+_abort_flag = threading.Event()
+
+_ws_clients = set()  # asyncio.Queue per connected client
+_ws_clients_lock = threading.Lock()
+_ws_loop = None
 
 
-def playback_worker():
-    while True:
-        item = playback_queue.get()
-        if item is _SENTINEL:
-            break
-        proc = subprocess.Popen(
-            ['pacat', '--format=float32le', f'--rate={SAMPLE_RATE}', '--channels=1'],
-            stdin=subprocess.PIPE,
-        )
-        proc.stdin.write(item.tobytes())
-        proc.stdin.close()
-        proc.wait()
-        playback_queue.task_done()
-
-
-threading.Thread(target=playback_worker, daemon=True).start()
+def broadcast(event):
+    """Send `event` as JSON to every connected client; callable from any thread."""
+    if _ws_loop is None:
+        return
+    msg = json.dumps(event)
+    with _ws_clients_lock:
+        clients = list(_ws_clients)
+    for q in clients:
+        _ws_loop.call_soon_threadsafe(q.put_nowait, msg)
 
 
 def ensure_float32_wav(path):
@@ -185,79 +209,201 @@ def generate(text, opts):
     return samples
 
 
-class Handler(BaseHTTPRequestHandler):
-    def send_json(self, data, status=200):
-        body = json.dumps(data).encode()
-        self.send_response(status)
-        self.send_header('Content-Type', 'application/json')
-        self.send_header('Content-Length', str(len(body)))
-        self.end_headers()
-        self.wfile.write(body)
-
-    def read_json(self):
-        length = int(self.headers.get('Content-Length', 0))
-        return json.loads(self.rfile.read(length))
-
-    def do_POST(self):
+def generation_worker():
+    """Turn queued speak/chime jobs into samples and pass them to the playback queue."""
+    while True:
+        item = _job_queue.get()
+        if item is _SENTINEL:
+            _job_queue.task_done()
+            break
+        job_id = item['id']
+        job_type = item['type']
         try:
-            req = self.read_json()
-        except Exception:
-            self.send_json({'status': 'error', 'message': 'invalid JSON'}, 400)
-            return
-
-        if self.path == '/speak':
-            text = req.pop('text', '')
-            if not text:
-                self.send_json({'status': 'ok'})
-                return
-            try:
+            if job_type == 'speak':
                 with _gen_lock:
-                    samples = generate(text, req)
-                playback_queue.put(samples)
-                self.send_json({'status': 'ok'})
-            except Exception as e:
-                traceback.print_exc(file=sys.stderr)
-                self.send_json({'status': 'error', 'message': str(e)}, 500)
+                    samples = generate(item['text'], item)
+                _playback_queue.put({'id': job_id, 'samples': samples})
+            elif job_type == 'chime':
+                samples = load_chime(item['path'])
+                _playback_queue.put({'id': job_id, 'samples': samples})
+        except Exception as e:
+            traceback.print_exc(file=sys.stderr)
+            broadcast({'event': 'error', 'id': job_id, 'message': str(e)})
+        _job_queue.task_done()
 
-        elif self.path == '/chime':
-            path = req.get('path', '')
-            try:
-                samples = load_chime(path)
-                playback_queue.put(samples)
-                self.send_json({'status': 'ok'})
-            except Exception as e:
-                traceback.print_exc(file=sys.stderr)
-                self.send_json({'status': 'error', 'message': str(e)}, 500)
 
-        elif self.path == '/preload':
-            path = req.get('path', '')
-            try:
-                load_chime(path)
-                log(f'preloaded: {path}')
-                self.send_json({'status': 'ok'})
-            except Exception as e:
-                self.send_json({'status': 'error', 'message': str(e)}, 500)
+def playback_worker():
+    """Play queued samples through pacat one job at a time, broadcasting started/finished/aborted."""
+    global _current_proc
+    while True:
+        item = _playback_queue.get()
+        if item is _SENTINEL:
+            _playback_queue.task_done()
+            break
+        job_id = item['id']
+        samples = item['samples']
 
-        elif self.path == '/command':
-            command = req.get('command', '')
-            if command == 'terminate':
-                self.send_json({'status': 'ok'})
-                threading.Thread(target=server.shutdown, daemon=True).start()
-            else:
-                self.send_json({'status': 'error', 'message': f'unknown command: {command}'}, 400)
+        _abort_flag.clear()
+        broadcast({'event': 'started', 'id': job_id})
+        proc = subprocess.Popen(
+            ['pacat', '--format=float32le', f'--rate={SAMPLE_RATE}', '--channels=1'],
+            stdin=subprocess.PIPE,
+        )
+        with _current_proc_lock:
+            _current_proc = proc
+            # An abort may have arrived before the process was registered.
+            if _abort_flag.is_set():
+                proc.kill()
+
+        try:
+            proc.stdin.write(samples.tobytes())
+            proc.stdin.close()
+        except BrokenPipeError:
+            pass
+        proc.wait()
+
+        with _current_proc_lock:
+            _current_proc = None
+
+        if _abort_flag.is_set():
+            broadcast({'event': 'aborted', 'id': job_id})
         else:
-            self.send_json({'status': 'error', 'message': 'not found'}, 404)
+            broadcast({'event': 'finished', 'id': job_id})
 
-    def log_message(self, fmt, *args):
-        log(fmt % args)
+        _playback_queue.task_done()
+
+
+def abort_current():
+    """Mark the current playback as aborted and kill its pacat process, if one is registered."""
+    _abort_flag.set()
+    with _current_proc_lock:
+        if _current_proc is not None:
+            _current_proc.kill()
+
+
+def abort_all():
+    """Drop every queued job, abort current playback, and broadcast 'aborted' for the dropped ids.
+
+    A job the generation worker has already dequeued sits in neither queue,
+    so it is not cancelled here and is still played once generated.
+    """
+    drained_ids = []
+
+    while True:
+        try:
+            item = _job_queue.get_nowait()
+            if item is not _SENTINEL:
+                drained_ids.append(item['id'])
+            _job_queue.task_done()
+        except queue.Empty:
+            break
+
+    while True:
+        try:
+            item = _playback_queue.get_nowait()
+            if item is not _SENTINEL:
+                drained_ids.append(item['id'])
+            _playback_queue.task_done()
+        except queue.Empty:
+            break
+
+    abort_current()
+
+    for jid in drained_ids:
+        broadcast({'event': 'aborted', 'id': jid})
+
+
+threading.Thread(target=generation_worker, daemon=True).start()
+threading.Thread(target=playback_worker, daemon=True).start()
+
+
+_ws_stop = None  # asyncio.Event created in main(); set by the 'terminate' command
+
+
+async def _ws_handler(websocket):
+    """Serve one client connection: dispatch its JSON commands and relay broadcast events to it."""
+    q = asyncio.Queue()
+    with _ws_clients_lock:
+        _ws_clients.add(q)
+
+    async def sender():
+        while True:
+            msg = await q.get()
+            await websocket.send(msg)
+
+    sender_task = asyncio.create_task(sender())
+
+    try:
+        async for raw in websocket:
+            try:
+                msg = json.loads(raw)
+            except json.JSONDecodeError:
+                await websocket.send(json.dumps({'status': 'error', 'message': 'invalid JSON'}))
+                continue
+
+            msg_type = msg.get('type', '')
+
+            if msg_type in ('speak', 'chime'):
+                job_id = msg.get('id')
+                if job_id is None:
+                    job_id = _next_id()
+                job = dict(msg)
+                job['id'] = job_id
+                _job_queue.put(job)
+                broadcast({'event': 'queued', 'id': job_id})
+                await websocket.send(json.dumps({'status': 'ok', 'id': job_id}))
+
+            elif msg_type == 'preload':
+                path = msg.get('path', '')
+                try:
+                    await asyncio.get_running_loop().run_in_executor(None, load_chime, path)
+                    log(f'preloaded: {path}')
+                    await websocket.send(json.dumps({'status': 'ok'}))
+                except Exception as e:
+                    await websocket.send(json.dumps({'status': 'error', 'message': str(e)}))
+
+            elif msg_type == 'abort_current':
+                abort_current()
+                await websocket.send(json.dumps({'status': 'ok'}))
+
+            elif msg_type == 'abort_all':
+                abort_all()
+                await websocket.send(json.dumps({'status': 'ok'}))
+
+            elif msg_type == 'terminate':
+                await websocket.send(json.dumps({'status': 'ok'}))
+                # loop.stop() here would make asyncio.run() raise RuntimeError;
+                # wake main() instead so the server shuts down cleanly.
+                _ws_stop.set()
+
+            else:
+                await websocket.send(json.dumps({'status': 'error', 'message': f'unknown type: {msg_type}'}))
+
+    except Exception as e:
+        # Client disconnects land here too; log instead of silently hiding bugs.
+        log(f'connection handler ended: {e!r}')
+    finally:
+        with _ws_clients_lock:
+            _ws_clients.discard(q)
+        sender_task.cancel()
+
+
+async def main():
+    """Serve WebSocket clients on PORT until a 'terminate' command sets _ws_stop."""
+    global _ws_loop, _ws_stop
+    from websockets.asyncio.server import serve as ws_serve
+    _ws_loop = asyncio.get_running_loop()
+    _ws_stop = asyncio.Event()
+    async with ws_serve(_ws_handler, '0.0.0.0', PORT, reuse_address=True):
+        log(f'listening on port {PORT}')
+        await _ws_stop.wait()
 
 
-server = ThreadingHTTPServer(('', PORT), Handler)
-log(f'listening on port {PORT}')
 
 try:
-    server.serve_forever()
+    asyncio.run(main())
 except KeyboardInterrupt:
     pass
 finally:
-    playback_queue.put(_SENTINEL)
+    _job_queue.put(_SENTINEL)
+    _playback_queue.put(_SENTINEL)
diff --git a/examples/abort-demo.mjs b/examples/abort-demo.mjs
new file mode 100644
index 0000000..d1d16e0
--- /dev/null
+++ b/examples/abort-demo.mjs
@@ -0,0 +1,34 @@
+// Start speaking a long sentence, then abort a few seconds in.
+// Usage: node abort-demo.mjs
+
+const PORT = process.env.TTS_PORT ?? '11500'
+const text = 'This is a very long sentence that will be cut off before it finishes, ' +
+  'because a few seconds after playback starts we will send an abort command ' +
+  'to demonstrate the abort current functionality of the server.'
+
+const ws = new WebSocket(`ws://localhost:${PORT}`)
+
+ws.addEventListener('open', () => {
+  ws.send(JSON.stringify({ type: 'speak', text }))
+})
+
+ws.addEventListener('message', ({ data }) => {
+  const msg = JSON.parse(data)
+  console.log(msg)
+
+  if (msg.event === 'started') {
+    setTimeout(() => {
+      console.log('aborting...')
+      ws.send(JSON.stringify({ type: 'abort_current' }))
+    }, 3000)
+  }
+
+  if (msg.event === 'aborted' || msg.event === 'finished' || msg.event === 'error') {
+    ws.close()
+  }
+})
+
+ws.addEventListener('error', (e) => {
+  console.error('error:', e.message)
+  process.exit(1)
+})
diff --git a/examples/chime.mjs b/examples/chime.mjs
index 0247451..8b5371d 100644
--- a/examples/chime.mjs
+++ b/examples/chime.mjs
@@ -9,14 +9,20 @@ if (!path) {
   process.exit(1)
 }
 
-const res = await fetch(`http://localhost:${PORT}/chime`, {
-  method: 'POST',
-  headers: { 'Content-Type': 'application/json' },
-  body: JSON.stringify({ path }),
+const ws = new WebSocket(`ws://localhost:${PORT}`)
+
+ws.addEventListener('open', () => {
+  ws.send(JSON.stringify({ type: 'chime', path }))
 })
 
-const data = await res.json()
-if (data.status !== 'ok') {
-  console.error('error:', data.message)
+ws.addEventListener('message', ({ data }) => {
+  const msg = JSON.parse(data)
+  if (msg.event === 'finished' || msg.event === 'aborted' || msg.event === 'error') {
+    ws.close()
+  }
+})
+
+ws.addEventListener('error', (e) => {
+  console.error('error:', e.message)
   process.exit(1)
-}
+})
diff --git a/examples/speak.mjs b/examples/speak.mjs
index c609de3..61fa305 100644
--- a/examples/speak.mjs
+++ b/examples/speak.mjs
@@ -4,14 +4,20 @@
 const PORT = process.env.TTS_PORT ?? '11500'
 const text = process.argv[2] ?? 'Hello from Node.'
 
-const res = await fetch(`http://localhost:${PORT}/speak`, {
-  method: 'POST',
-  headers: { 'Content-Type': 'application/json' },
-  body: JSON.stringify({ text }),
+const ws = new WebSocket(`ws://localhost:${PORT}`)
+
+ws.addEventListener('open', () => {
+  ws.send(JSON.stringify({ type: 'speak', text }))
 })
 
-const data = await res.json()
-if (data.status !== 'ok') {
-  console.error('error:', data.message)
+ws.addEventListener('message', ({ data }) => {
+  const msg = JSON.parse(data)
+  if (msg.event === 'finished' || msg.event === 'aborted' || msg.event === 'error') {
+    ws.close()
+  }
+})
+
+ws.addEventListener('error', (e) => {
+  console.error('error:', e.message)
   process.exit(1)
-}
+})
diff --git a/examples/terminate.mjs b/examples/terminate.mjs
index 800f91e..8e8ff07 100644
--- a/examples/terminate.mjs
+++ b/examples/terminate.mjs
@@ -3,14 +3,22 @@
 
 const PORT = process.env.TTS_PORT ?? '11500'
 
-const res = await fetch(`http://localhost:${PORT}/command`, {
-  method: 'POST',
-  headers: { 'Content-Type': 'application/json' },
-  body: JSON.stringify({ command: 'terminate' }),
+const ws = new WebSocket(`ws://localhost:${PORT}`)
+
+ws.addEventListener('open', () => {
+  ws.send(JSON.stringify({ type: 'terminate' }))
 })
 
-const data = await res.json()
-if (data.status !== 'ok') {
-  console.error('error:', data.message)
+ws.addEventListener('message', ({ data }) => {
+  const msg = JSON.parse(data)
+  if (msg.status !== 'ok') {
+    console.error('error:', msg.message)
+    process.exit(1)
+  }
+  ws.close()
+})
+
+ws.addEventListener('error', (e) => {
+  console.error('error:', e.message)
   process.exit(1)
-}
+})
diff --git a/examples/voice-clone.mjs b/examples/voice-clone.mjs
index c8998b9..cbee145 100644
--- a/examples/voice-clone.mjs
+++ b/examples/voice-clone.mjs
@@ -2,7 +2,7 @@
 // The server reads the audio_prompt path from its own filesystem.
 // Usage: node voice-clone.mjs /path/to/reference.wav "Text to speak"
 
-const PORT         = process.env.TTS_PORT ?? '11500'
+const PORT = process.env.TTS_PORT ?? '11500'
 const audio_prompt = process.argv[2]
 const text = process.argv[3] ?? 'Hello, this is a cloned voice.'
 
@@ -11,14 +11,20 @@ if (!audio_prompt) {
   process.exit(1)
 }
 
-const res = await fetch(`http://localhost:${PORT}/speak`, {
-  method: 'POST',
-  headers: { 'Content-Type': 'application/json' },
-  body: JSON.stringify({ text, audio_prompt }),
+const ws = new WebSocket(`ws://localhost:${PORT}`)
+
+ws.addEventListener('open', () => {
+  ws.send(JSON.stringify({ type: 'speak', text, audio_prompt }))
 })
 
-const data = await res.json()
-if (data.status !== 'ok') {
-  console.error('error:', data.message)
+ws.addEventListener('message', ({ data }) => {
+  const msg = JSON.parse(data)
+  if (msg.event === 'finished' || msg.event === 'aborted' || msg.event === 'error') {
+    ws.close()
+  }
+})
+
+ws.addEventListener('error', (e) => {
+  console.error('error:', e.message)
   process.exit(1)
-}
+})
diff --git a/setup-venv.sh b/setup-venv.sh
index 726ab57..5d5873c 100755
--- a/setup-venv.sh
+++ b/setup-venv.sh
@@ -12,7 +12,7 @@ fi
 
 echo "==> installing Python dependencies"
 "${VENV}/bin/pip" install --upgrade pip --quiet
-"${VENV}/bin/pip" install chatterbox-tts
+"${VENV}/bin/pip" install chatterbox-tts websockets
 
 echo ""
 
-- 
2.49.1