Replace HTTP API with WebSocket server
Single port (TTS_PORT) handles both the WS upgrade handshake and connections. Adds job queue, generation worker, playback events (queued/started/finished/aborted/error), and abort_current/abort_all commands. Fixes BrokenPipeError when pacat is killed mid-write. Updates all examples to use WebSocket; adds abort-demo.mjs. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -1,18 +1,37 @@
|
||||
#!/usr/bin/env -S bash -c 'exec "$(dirname "$0")/venv/bin/python3" "$0" "$@"'
|
||||
"""
|
||||
Chatterbox TTS HTTP server — keeps model loaded, exposes a JSON HTTP API.
|
||||
Chatterbox TTS WebSocket server — keeps model loaded, exposes a JSON WebSocket API.
|
||||
|
||||
Endpoints:
|
||||
POST /speak {"text": "...", "temperature": 0.8, "top_p": 0.95, "audio_prompt": "/path.wav"}
|
||||
POST /chime {"path": "/path/to/file.wav"}
|
||||
POST /preload {"path": "/path/to/file.wav"}
|
||||
POST /command {"command": "terminate"}
|
||||
Connect to ws://host:TTS_PORT (default ws://localhost:11500).
|
||||
|
||||
All endpoints return {"status": "ok"} or {"status": "error", "message": "..."}.
|
||||
Responses are sent after audio is queued for playback (not after playback finishes).
|
||||
Client → Server:
|
||||
{"type": "speak", "id"?: N, "text": "...", ...generation_opts}
|
||||
{"type": "chime", "id"?: N, "path": "..."}
|
||||
{"type": "preload", "path": "..."}
|
||||
{"type": "abort_current"} — kill active playback, advance to next queued item
|
||||
{"type": "abort_all"} — kill active playback + drain all queues
|
||||
{"type": "terminate"}
|
||||
|
||||
Server → requesting client:
|
||||
{"status": "ok", "id": N} (speak/chime)
|
||||
{"status": "ok"}
|
||||
{"status": "error", "message": "..."}
|
||||
|
||||
Server → all clients (broadcast):
|
||||
{"event": "queued", "id": N}
|
||||
{"event": "started", "id": N}
|
||||
{"event": "finished", "id": N}
|
||||
{"event": "aborted", "id": N}
|
||||
{"event": "error", "id": N, "message": "..."}
|
||||
|
||||
Generation options (speak):
|
||||
temperature, top_p, top_k, repetition_penalty, min_p
|
||||
audio_prompt — path to reference WAV for voice cloning
|
||||
exaggeration — 0.0-1.0, full model only
|
||||
cfg_weight — full model only
|
||||
|
||||
Environment:
|
||||
TTS_PORT TCP port to listen on (default: 11500)
|
||||
TTS_PORT port to listen on (default: 11500)
|
||||
HF_TOKEN_FILE path to HuggingFace token file (default: ~/.secrets/hugging-face.token)
|
||||
HF_HUB_CACHE path to HuggingFace hub cache (default: ~/.cache/huggingface/hub)
|
||||
|
||||
@@ -23,9 +42,6 @@ Usage:
|
||||
|
||||
Paralinguistic tags supported in text:
|
||||
[laugh] [chuckle] [cough] [clear throat] [sigh] [shush] [groan] [sniff] [gasp]
|
||||
|
||||
Full model only:
|
||||
exaggeration 0.0-1.0 emotion intensity (ignored in turbo)
|
||||
"""
|
||||
|
||||
import os
|
||||
@@ -37,7 +53,8 @@ import threading
|
||||
import subprocess
|
||||
import traceback
|
||||
import tempfile
|
||||
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
||||
import asyncio
|
||||
import itertools
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
|
||||
@@ -106,26 +123,32 @@ _wav_cache = {}
|
||||
_chime_cache = {}
|
||||
_gen_lock = threading.Lock()
|
||||
|
||||
_SENTINEL = object()
|
||||
playback_queue = queue.Queue()
|
||||
_SENTINEL = object()
|
||||
|
||||
_id_counter = itertools.count(1)
|
||||
def _next_id():
|
||||
return next(_id_counter)
|
||||
|
||||
_job_queue = queue.Queue() # dicts: {'id', 'type', ...}
|
||||
_playback_queue = queue.Queue() # dicts: {'id', 'samples'}
|
||||
|
||||
_current_proc = None
|
||||
_current_proc_lock = threading.Lock()
|
||||
_abort_flag = threading.Event()
|
||||
|
||||
_ws_clients = set() # asyncio.Queue per connected client
|
||||
_ws_clients_lock = threading.Lock()
|
||||
_ws_loop = None
|
||||
|
||||
|
||||
def playback_worker():
|
||||
while True:
|
||||
item = playback_queue.get()
|
||||
if item is _SENTINEL:
|
||||
break
|
||||
proc = subprocess.Popen(
|
||||
['pacat', '--format=float32le', f'--rate={SAMPLE_RATE}', '--channels=1'],
|
||||
stdin=subprocess.PIPE,
|
||||
)
|
||||
proc.stdin.write(item.tobytes())
|
||||
proc.stdin.close()
|
||||
proc.wait()
|
||||
playback_queue.task_done()
|
||||
|
||||
|
||||
threading.Thread(target=playback_worker, daemon=True).start()
|
||||
def broadcast(event):
|
||||
if _ws_loop is None:
|
||||
return
|
||||
msg = json.dumps(event)
|
||||
with _ws_clients_lock:
|
||||
clients = list(_ws_clients)
|
||||
for q in clients:
|
||||
_ws_loop.call_soon_threadsafe(q.put_nowait, msg)
|
||||
|
||||
|
||||
def ensure_float32_wav(path):
|
||||
@@ -185,79 +208,179 @@ def generate(text, opts):
|
||||
return samples
|
||||
|
||||
|
||||
class Handler(BaseHTTPRequestHandler):
|
||||
def send_json(self, data, status=200):
|
||||
body = json.dumps(data).encode()
|
||||
self.send_response(status)
|
||||
self.send_header('Content-Type', 'application/json')
|
||||
self.send_header('Content-Length', str(len(body)))
|
||||
self.end_headers()
|
||||
self.wfile.write(body)
|
||||
|
||||
def read_json(self):
|
||||
length = int(self.headers.get('Content-Length', 0))
|
||||
return json.loads(self.rfile.read(length))
|
||||
|
||||
def do_POST(self):
|
||||
def generation_worker():
|
||||
while True:
|
||||
item = _job_queue.get()
|
||||
if item is _SENTINEL:
|
||||
_job_queue.task_done()
|
||||
break
|
||||
job_id = item['id']
|
||||
job_type = item['type']
|
||||
try:
|
||||
req = self.read_json()
|
||||
except Exception:
|
||||
self.send_json({'status': 'error', 'message': 'invalid JSON'}, 400)
|
||||
return
|
||||
|
||||
if self.path == '/speak':
|
||||
text = req.pop('text', '')
|
||||
if not text:
|
||||
self.send_json({'status': 'ok'})
|
||||
return
|
||||
try:
|
||||
if job_type == 'speak':
|
||||
with _gen_lock:
|
||||
samples = generate(text, req)
|
||||
playback_queue.put(samples)
|
||||
self.send_json({'status': 'ok'})
|
||||
except Exception as e:
|
||||
traceback.print_exc(file=sys.stderr)
|
||||
self.send_json({'status': 'error', 'message': str(e)}, 500)
|
||||
samples = generate(item['text'], item)
|
||||
_playback_queue.put({'id': job_id, 'samples': samples})
|
||||
elif job_type == 'chime':
|
||||
samples = load_chime(item['path'])
|
||||
_playback_queue.put({'id': job_id, 'samples': samples})
|
||||
except Exception as e:
|
||||
traceback.print_exc(file=sys.stderr)
|
||||
broadcast({'event': 'error', 'id': job_id, 'message': str(e)})
|
||||
_job_queue.task_done()
|
||||
|
||||
elif self.path == '/chime':
|
||||
path = req.get('path', '')
|
||||
try:
|
||||
samples = load_chime(path)
|
||||
playback_queue.put(samples)
|
||||
self.send_json({'status': 'ok'})
|
||||
except Exception as e:
|
||||
traceback.print_exc(file=sys.stderr)
|
||||
self.send_json({'status': 'error', 'message': str(e)}, 500)
|
||||
|
||||
elif self.path == '/preload':
|
||||
path = req.get('path', '')
|
||||
try:
|
||||
load_chime(path)
|
||||
log(f'preloaded: {path}')
|
||||
self.send_json({'status': 'ok'})
|
||||
except Exception as e:
|
||||
self.send_json({'status': 'error', 'message': str(e)}, 500)
|
||||
def playback_worker():
|
||||
global _current_proc
|
||||
while True:
|
||||
item = _playback_queue.get()
|
||||
if item is _SENTINEL:
|
||||
_playback_queue.task_done()
|
||||
break
|
||||
job_id = item['id']
|
||||
samples = item['samples']
|
||||
|
||||
elif self.path == '/command':
|
||||
command = req.get('command', '')
|
||||
if command == 'terminate':
|
||||
self.send_json({'status': 'ok'})
|
||||
threading.Thread(target=server.shutdown, daemon=True).start()
|
||||
else:
|
||||
self.send_json({'status': 'error', 'message': f'unknown command: {command}'}, 400)
|
||||
_abort_flag.clear()
|
||||
broadcast({'event': 'started', 'id': job_id})
|
||||
|
||||
proc = subprocess.Popen(
|
||||
['pacat', '--format=float32le', f'--rate={SAMPLE_RATE}', '--channels=1'],
|
||||
stdin=subprocess.PIPE,
|
||||
)
|
||||
with _current_proc_lock:
|
||||
_current_proc = proc
|
||||
|
||||
try:
|
||||
proc.stdin.write(samples.tobytes())
|
||||
proc.stdin.close()
|
||||
except BrokenPipeError:
|
||||
pass
|
||||
proc.wait()
|
||||
|
||||
with _current_proc_lock:
|
||||
_current_proc = None
|
||||
|
||||
if _abort_flag.is_set():
|
||||
broadcast({'event': 'aborted', 'id': job_id})
|
||||
else:
|
||||
self.send_json({'status': 'error', 'message': 'not found'}, 404)
|
||||
broadcast({'event': 'finished', 'id': job_id})
|
||||
|
||||
def log_message(self, fmt, *args):
|
||||
log(fmt % args)
|
||||
_playback_queue.task_done()
|
||||
|
||||
|
||||
def abort_current():
|
||||
_abort_flag.set()
|
||||
with _current_proc_lock:
|
||||
if _current_proc is not None:
|
||||
_current_proc.kill()
|
||||
|
||||
|
||||
def abort_all():
|
||||
drained_ids = []
|
||||
|
||||
while True:
|
||||
try:
|
||||
item = _job_queue.get_nowait()
|
||||
if item is not _SENTINEL:
|
||||
drained_ids.append(item['id'])
|
||||
_job_queue.task_done()
|
||||
except queue.Empty:
|
||||
break
|
||||
|
||||
while True:
|
||||
try:
|
||||
item = _playback_queue.get_nowait()
|
||||
if item is not _SENTINEL:
|
||||
drained_ids.append(item['id'])
|
||||
_playback_queue.task_done()
|
||||
except queue.Empty:
|
||||
break
|
||||
|
||||
abort_current()
|
||||
|
||||
for jid in drained_ids:
|
||||
broadcast({'event': 'aborted', 'id': jid})
|
||||
|
||||
|
||||
threading.Thread(target=generation_worker, daemon=True).start()
|
||||
threading.Thread(target=playback_worker, daemon=True).start()
|
||||
|
||||
|
||||
async def _ws_handler(websocket):
|
||||
q = asyncio.Queue()
|
||||
with _ws_clients_lock:
|
||||
_ws_clients.add(q)
|
||||
|
||||
async def sender():
|
||||
while True:
|
||||
msg = await q.get()
|
||||
await websocket.send(msg)
|
||||
|
||||
sender_task = asyncio.create_task(sender())
|
||||
|
||||
try:
|
||||
async for raw in websocket:
|
||||
try:
|
||||
msg = json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
await websocket.send(json.dumps({'status': 'error', 'message': 'invalid JSON'}))
|
||||
continue
|
||||
|
||||
msg_type = msg.get('type', '')
|
||||
|
||||
if msg_type in ('speak', 'chime'):
|
||||
job_id = msg.get('id') or _next_id()
|
||||
job = dict(msg)
|
||||
job['id'] = job_id
|
||||
_job_queue.put(job)
|
||||
broadcast({'event': 'queued', 'id': job_id})
|
||||
await websocket.send(json.dumps({'status': 'ok', 'id': job_id}))
|
||||
|
||||
elif msg_type == 'preload':
|
||||
path = msg.get('path', '')
|
||||
try:
|
||||
await asyncio.get_running_loop().run_in_executor(None, load_chime, path)
|
||||
log(f'preloaded: {path}')
|
||||
await websocket.send(json.dumps({'status': 'ok'}))
|
||||
except Exception as e:
|
||||
await websocket.send(json.dumps({'status': 'error', 'message': str(e)}))
|
||||
|
||||
elif msg_type == 'abort_current':
|
||||
abort_current()
|
||||
await websocket.send(json.dumps({'status': 'ok'}))
|
||||
|
||||
elif msg_type == 'abort_all':
|
||||
abort_all()
|
||||
await websocket.send(json.dumps({'status': 'ok'}))
|
||||
|
||||
elif msg_type == 'terminate':
|
||||
await websocket.send(json.dumps({'status': 'ok'}))
|
||||
asyncio.get_running_loop().stop()
|
||||
|
||||
else:
|
||||
await websocket.send(json.dumps({'status': 'error', 'message': f'unknown type: {msg_type}'}))
|
||||
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
with _ws_clients_lock:
|
||||
_ws_clients.discard(q)
|
||||
sender_task.cancel()
|
||||
|
||||
|
||||
async def main():
|
||||
global _ws_loop
|
||||
from websockets.asyncio.server import serve as ws_serve
|
||||
_ws_loop = asyncio.get_running_loop()
|
||||
async with ws_serve(_ws_handler, '0.0.0.0', PORT, reuse_address=True):
|
||||
log(f'listening on port {PORT}')
|
||||
await asyncio.Future()
|
||||
|
||||
|
||||
server = ThreadingHTTPServer(('', PORT), Handler)
|
||||
log(f'listening on port {PORT}')
|
||||
try:
|
||||
server.serve_forever()
|
||||
asyncio.run(main())
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
playback_queue.put(_SENTINEL)
|
||||
_job_queue.put(_SENTINEL)
|
||||
_playback_queue.put(_SENTINEL)
|
||||
|
||||
Reference in New Issue
Block a user