Replace HTTP API with WebSocket server

Single port (TTS_PORT) handles both the WS upgrade handshake and
connections. Adds job queue, generation worker, playback events
(queued/started/finished/aborted/error), and abort_current/abort_all
commands. Fixes BrokenPipeError when pacat is killed mid-write.
Updates all examples to use WebSocket; adds abort-demo.mjs.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-06-09 21:32:58 +00:00
parent b24414c3f3
commit f4ae96c6b9
7 changed files with 310 additions and 127 deletions

View File

@@ -1,18 +1,37 @@
#!/usr/bin/env -S bash -c 'exec "$(dirname "$0")/venv/bin/python3" "$0" "$@"'
"""
Chatterbox TTS HTTP server — keeps model loaded, exposes a JSON HTTP API.
Chatterbox TTS WebSocket server — keeps model loaded, exposes a JSON WebSocket API.
Endpoints:
POST /speak {"text": "...", "temperature": 0.8, "top_p": 0.95, "audio_prompt": "/path.wav"}
POST /chime {"path": "/path/to/file.wav"}
POST /preload {"path": "/path/to/file.wav"}
POST /command {"command": "terminate"}
Connect to ws://host:TTS_PORT (default ws://localhost:11500).
All endpoints return {"status": "ok"} or {"status": "error", "message": "..."}.
Responses are sent after audio is queued for playback (not after playback finishes).
Client → Server:
{"type": "speak", "id"?: N, "text": "...", ...generation_opts}
{"type": "chime", "id"?: N, "path": "..."}
{"type": "preload", "path": "..."}
{"type": "abort_current"} — kill active playback, advance to next queued item
{"type": "abort_all"} — kill active playback + drain all queues
{"type": "terminate"}
Server → requesting client:
{"status": "ok", "id": N} (speak/chime)
{"status": "ok"}
{"status": "error", "message": "..."}
Server → all clients (broadcast):
{"event": "queued", "id": N}
{"event": "started", "id": N}
{"event": "finished", "id": N}
{"event": "aborted", "id": N}
{"event": "error", "id": N, "message": "..."}
Generation options (speak):
temperature, top_p, top_k, repetition_penalty, min_p
audio_prompt — path to reference WAV for voice cloning
exaggeration — 0.0-1.0, full model only
cfg_weight — full model only
Environment:
TTS_PORT TCP port to listen on (default: 11500)
TTS_PORT port to listen on (default: 11500)
HF_TOKEN_FILE path to HuggingFace token file (default: ~/.secrets/hugging-face.token)
HF_HUB_CACHE path to HuggingFace hub cache (default: ~/.cache/huggingface/hub)
@@ -23,9 +42,6 @@ Usage:
Paralinguistic tags supported in text:
[laugh] [chuckle] [cough] [clear throat] [sigh] [shush] [groan] [sniff] [gasp]
Full model only:
exaggeration 0.0-1.0 emotion intensity (ignored in turbo)
"""
import os
@@ -37,7 +53,8 @@ import threading
import subprocess
import traceback
import tempfile
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
import asyncio
import itertools
from pathlib import Path
import numpy as np
@@ -106,26 +123,32 @@ _wav_cache = {}
_chime_cache = {}
_gen_lock = threading.Lock()
_SENTINEL = object()
playback_queue = queue.Queue()
_SENTINEL = object()
_id_counter = itertools.count(1)
def _next_id():
return next(_id_counter)
_job_queue = queue.Queue() # dicts: {'id', 'type', ...}
_playback_queue = queue.Queue() # dicts: {'id', 'samples'}
_current_proc = None
_current_proc_lock = threading.Lock()
_abort_flag = threading.Event()
_ws_clients = set() # asyncio.Queue per connected client
_ws_clients_lock = threading.Lock()
_ws_loop = None
def playback_worker():
while True:
item = playback_queue.get()
if item is _SENTINEL:
break
proc = subprocess.Popen(
['pacat', '--format=float32le', f'--rate={SAMPLE_RATE}', '--channels=1'],
stdin=subprocess.PIPE,
)
proc.stdin.write(item.tobytes())
proc.stdin.close()
proc.wait()
playback_queue.task_done()
threading.Thread(target=playback_worker, daemon=True).start()
def broadcast(event):
if _ws_loop is None:
return
msg = json.dumps(event)
with _ws_clients_lock:
clients = list(_ws_clients)
for q in clients:
_ws_loop.call_soon_threadsafe(q.put_nowait, msg)
def ensure_float32_wav(path):
@@ -185,79 +208,179 @@ def generate(text, opts):
return samples
class Handler(BaseHTTPRequestHandler):
def send_json(self, data, status=200):
body = json.dumps(data).encode()
self.send_response(status)
self.send_header('Content-Type', 'application/json')
self.send_header('Content-Length', str(len(body)))
self.end_headers()
self.wfile.write(body)
def read_json(self):
length = int(self.headers.get('Content-Length', 0))
return json.loads(self.rfile.read(length))
def do_POST(self):
def generation_worker():
while True:
item = _job_queue.get()
if item is _SENTINEL:
_job_queue.task_done()
break
job_id = item['id']
job_type = item['type']
try:
req = self.read_json()
except Exception:
self.send_json({'status': 'error', 'message': 'invalid JSON'}, 400)
return
if self.path == '/speak':
text = req.pop('text', '')
if not text:
self.send_json({'status': 'ok'})
return
try:
if job_type == 'speak':
with _gen_lock:
samples = generate(text, req)
playback_queue.put(samples)
self.send_json({'status': 'ok'})
except Exception as e:
traceback.print_exc(file=sys.stderr)
self.send_json({'status': 'error', 'message': str(e)}, 500)
samples = generate(item['text'], item)
_playback_queue.put({'id': job_id, 'samples': samples})
elif job_type == 'chime':
samples = load_chime(item['path'])
_playback_queue.put({'id': job_id, 'samples': samples})
except Exception as e:
traceback.print_exc(file=sys.stderr)
broadcast({'event': 'error', 'id': job_id, 'message': str(e)})
_job_queue.task_done()
elif self.path == '/chime':
path = req.get('path', '')
try:
samples = load_chime(path)
playback_queue.put(samples)
self.send_json({'status': 'ok'})
except Exception as e:
traceback.print_exc(file=sys.stderr)
self.send_json({'status': 'error', 'message': str(e)}, 500)
elif self.path == '/preload':
path = req.get('path', '')
try:
load_chime(path)
log(f'preloaded: {path}')
self.send_json({'status': 'ok'})
except Exception as e:
self.send_json({'status': 'error', 'message': str(e)}, 500)
def playback_worker():
global _current_proc
while True:
item = _playback_queue.get()
if item is _SENTINEL:
_playback_queue.task_done()
break
job_id = item['id']
samples = item['samples']
elif self.path == '/command':
command = req.get('command', '')
if command == 'terminate':
self.send_json({'status': 'ok'})
threading.Thread(target=server.shutdown, daemon=True).start()
else:
self.send_json({'status': 'error', 'message': f'unknown command: {command}'}, 400)
_abort_flag.clear()
broadcast({'event': 'started', 'id': job_id})
proc = subprocess.Popen(
['pacat', '--format=float32le', f'--rate={SAMPLE_RATE}', '--channels=1'],
stdin=subprocess.PIPE,
)
with _current_proc_lock:
_current_proc = proc
try:
proc.stdin.write(samples.tobytes())
proc.stdin.close()
except BrokenPipeError:
pass
proc.wait()
with _current_proc_lock:
_current_proc = None
if _abort_flag.is_set():
broadcast({'event': 'aborted', 'id': job_id})
else:
self.send_json({'status': 'error', 'message': 'not found'}, 404)
broadcast({'event': 'finished', 'id': job_id})
def log_message(self, fmt, *args):
log(fmt % args)
_playback_queue.task_done()
def abort_current():
_abort_flag.set()
with _current_proc_lock:
if _current_proc is not None:
_current_proc.kill()
def abort_all():
drained_ids = []
while True:
try:
item = _job_queue.get_nowait()
if item is not _SENTINEL:
drained_ids.append(item['id'])
_job_queue.task_done()
except queue.Empty:
break
while True:
try:
item = _playback_queue.get_nowait()
if item is not _SENTINEL:
drained_ids.append(item['id'])
_playback_queue.task_done()
except queue.Empty:
break
abort_current()
for jid in drained_ids:
broadcast({'event': 'aborted', 'id': jid})
threading.Thread(target=generation_worker, daemon=True).start()
threading.Thread(target=playback_worker, daemon=True).start()
async def _ws_handler(websocket):
q = asyncio.Queue()
with _ws_clients_lock:
_ws_clients.add(q)
async def sender():
while True:
msg = await q.get()
await websocket.send(msg)
sender_task = asyncio.create_task(sender())
try:
async for raw in websocket:
try:
msg = json.loads(raw)
except json.JSONDecodeError:
await websocket.send(json.dumps({'status': 'error', 'message': 'invalid JSON'}))
continue
msg_type = msg.get('type', '')
if msg_type in ('speak', 'chime'):
job_id = msg.get('id') or _next_id()
job = dict(msg)
job['id'] = job_id
_job_queue.put(job)
broadcast({'event': 'queued', 'id': job_id})
await websocket.send(json.dumps({'status': 'ok', 'id': job_id}))
elif msg_type == 'preload':
path = msg.get('path', '')
try:
await asyncio.get_running_loop().run_in_executor(None, load_chime, path)
log(f'preloaded: {path}')
await websocket.send(json.dumps({'status': 'ok'}))
except Exception as e:
await websocket.send(json.dumps({'status': 'error', 'message': str(e)}))
elif msg_type == 'abort_current':
abort_current()
await websocket.send(json.dumps({'status': 'ok'}))
elif msg_type == 'abort_all':
abort_all()
await websocket.send(json.dumps({'status': 'ok'}))
elif msg_type == 'terminate':
await websocket.send(json.dumps({'status': 'ok'}))
asyncio.get_running_loop().stop()
else:
await websocket.send(json.dumps({'status': 'error', 'message': f'unknown type: {msg_type}'}))
except Exception:
pass
finally:
with _ws_clients_lock:
_ws_clients.discard(q)
sender_task.cancel()
async def main():
global _ws_loop
from websockets.asyncio.server import serve as ws_serve
_ws_loop = asyncio.get_running_loop()
async with ws_serve(_ws_handler, '0.0.0.0', PORT, reuse_address=True):
log(f'listening on port {PORT}')
await asyncio.Future()
server = ThreadingHTTPServer(('', PORT), Handler)
log(f'listening on port {PORT}')
try:
server.serve_forever()
asyncio.run(main())
except KeyboardInterrupt:
pass
finally:
playback_queue.put(_SENTINEL)
_job_queue.put(_SENTINEL)
_playback_queue.put(_SENTINEL)