Files
tts-server/chatterbox-server.py
mikael-lovqvists-claude-agent f4ae96c6b9 Replace HTTP API with WebSocket server
Single port (TTS_PORT) handles both the WS upgrade handshake and
connections. Adds job queue, generation worker, playback events
(queued/started/finished/aborted/error), and abort_current/abort_all
commands. Fixes BrokenPipeError when pacat is killed mid-write.
Updates all examples to use WebSocket; adds abort-demo.mjs.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-06-09 21:32:58 +00:00

387 lines
10 KiB
Python
Executable File

#!/usr/bin/env -S bash -c 'exec "$(dirname "$0")/venv/bin/python3" "$0" "$@"'
"""
Chatterbox TTS WebSocket server — keeps model loaded, exposes a JSON WebSocket API.
Connect to ws://host:TTS_PORT (default ws://localhost:11500).
Client → Server:
{"type": "speak", "id"?: N, "text": "...", ...generation_opts}
{"type": "chime", "id"?: N, "path": "..."}
{"type": "preload", "path": "..."}
{"type": "abort_current"} — kill active playback, advance to next queued item
{"type": "abort_all"} — kill active playback + drain all queues
{"type": "terminate"}
Server → requesting client:
{"status": "ok", "id": N} (speak/chime)
{"status": "ok"}
{"status": "error", "message": "..."}
Server → all clients (broadcast):
{"event": "queued", "id": N}
{"event": "started", "id": N}
{"event": "finished", "id": N}
{"event": "aborted", "id": N}
{"event": "error", "id": N, "message": "..."}
Generation options (speak):
temperature, top_p, top_k, repetition_penalty, min_p
audio_prompt — path to reference WAV for voice cloning
exaggeration — 0.0-1.0, full model only
cfg_weight — full model only
Environment:
TTS_PORT port to listen on (default: 11500)
HF_TOKEN_FILE path to HuggingFace token file (default: ~/.secrets/hugging-face.token)
HF_HUB_CACHE path to HuggingFace hub cache (default: ~/.cache/huggingface/hub)
Usage:
./chatterbox-server.py
./chatterbox-server.py turbo # default
./chatterbox-server.py full # original model, supports exaggeration
Paralinguistic tags supported in text:
[laugh] [chuckle] [cough] [clear throat] [sigh] [shush] [groan] [sniff] [gasp]
"""
import os
import sys
import json
import time
import queue
import threading
import subprocess
import traceback
import tempfile
import asyncio
import itertools
from pathlib import Path
import numpy as np
# Optional HuggingFace access token. When the file exists, its stripped
# contents are exported as HF_TOKEN (replacing any existing value); a missing
# file is silently ignored.
TOKEN_FILE = os.environ.get(
    'HF_TOKEN_FILE',
    os.path.expanduser('~/.secrets/hugging-face.token'),
)
try:
    token_file = open(TOKEN_FILE)
except FileNotFoundError:
    pass
else:
    with token_file:
        os.environ['HF_TOKEN'] = token_file.read().strip()
def find_hf_cache(repo_id):
    """Return the most recently modified local snapshot dir for *repo_id*, or None.

    Looks under $HF_HUB_CACHE (default ~/.cache/huggingface/hub) for
    `models--<org>--<name>/snapshots/*` and returns the entry with the newest
    mtime as a string. Returns None when that directory is missing or empty.
    """
    hub_root = os.environ.get('HF_HUB_CACHE', os.path.expanduser('~/.cache/huggingface/hub'))
    snapshot_root = Path(hub_root) / ('models--' + repo_id.replace('/', '--')) / 'snapshots'
    if not snapshot_root.exists():
        return None
    oldest_first = sorted(snapshot_root.iterdir(), key=lambda entry: entry.stat().st_mtime)
    return str(oldest_first[-1]) if oldest_first else None
# Model variant from the first CLI argument ('turbo' when none is given).
VARIANT = (sys.argv[1:] or ['turbo'])[0]
# WebSocket listen port, overridable via $TTS_PORT.
PORT = int(os.environ.get('TTS_PORT', '11500'))
# Output sample rate used for generated audio, chimes and pacat playback.
SAMPLE_RATE = 24000
def log(msg):
    """Write *msg* to stderr with a '[chatterbox]' tag and flush immediately."""
    stream = sys.stderr
    stream.write(f'[chatterbox] {msg}\n')
    stream.flush()
log(f'loading chatterbox-{VARIANT}...')
# Start of the load-time measurement reported by the 'ready on ...' log line.
t0 = time.time()
# torch / soundfile / librosa are imported here, after the 'loading' log line.
import torch
import soundfile as sf
import librosa as _librosa
# Monkey-patch librosa.resample so it always returns float32. This affects any
# code that looks up `librosa.resample` at call time (including load_chime()).
# NOTE(review): presumably works around a float64/float32 dtype mismatch in
# chatterbox's own audio preprocessing - confirm before removing.
# The patch must be installed before chatterbox is imported further down.
_orig_resample = _librosa.resample
def _resample_float32(*args, **kwargs):
    """Call the original librosa.resample and cast its result to float32."""
    return _orig_resample(*args, **kwargs).astype(np.float32)
_librosa.resample = _resample_float32
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# HuggingFace repository for each supported model variant; its keys are also
# the accepted values of the VARIANT command-line argument.
REPO_IDS = {
    'turbo': 'ResembleAI/chatterbox-turbo',
    'full': 'ResembleAI/chatterbox',
}
# Fail fast with a readable message on an unknown variant. Otherwise any value
# other than 'turbo' would import the full model class and then die with a bare
# KeyError at REPO_IDS[VARIANT] below.
if VARIANT not in REPO_IDS:
    log(f"unknown model variant {VARIANT!r}; expected one of: {', '.join(REPO_IDS)}")
    sys.exit(2)
if VARIANT == 'turbo':
    from chatterbox.tts_turbo import ChatterboxTurboTTS as Model
else:
    from chatterbox.tts import ChatterboxTTS as Model
# Prefer the newest locally cached snapshot; otherwise fall back to
# from_pretrained(), which downloads the weights.
cached = find_hf_cache(REPO_IDS[VARIANT])
if cached:
    log(f'loading from cache: {cached}')
    model = Model.from_local(cached, device=device)
else:
    log('cache not found, downloading...')
    model = Model.from_pretrained(device=device)
log(f'ready on {device} ({time.time() - t0:.1f}s load time)')
_wav_cache = {}
_chime_cache = {}
_gen_lock = threading.Lock()
_SENTINEL = object()
_id_counter = itertools.count(1)
def _next_id():
return next(_id_counter)
_job_queue = queue.Queue() # dicts: {'id', 'type', ...}
_playback_queue = queue.Queue() # dicts: {'id', 'samples'}
_current_proc = None
_current_proc_lock = threading.Lock()
_abort_flag = threading.Event()
_ws_clients = set() # asyncio.Queue per connected client
_ws_clients_lock = threading.Lock()
_ws_loop = None
def broadcast(event):
    """Send *event* (a JSON-serialisable dict) to every connected WebSocket client.

    Callable from any thread: the serialized message is handed to each client's
    asyncio.Queue via call_soon_threadsafe on the server loop. This is a no-op
    before main() has set the loop. After the loop has closed (during
    shutdown) the event is dropped instead of raising in the calling worker thread.
    """
    # Read the global once so a concurrent reassignment cannot bite between uses.
    loop = _ws_loop
    if loop is None:
        return
    msg = json.dumps(event)
    with _ws_clients_lock:
        clients = list(_ws_clients)
    for q in clients:
        try:
            loop.call_soon_threadsafe(q.put_nowait, msg)
        except RuntimeError:
            # Loop already closed: the server has shut down, nobody left to notify.
            return
def ensure_float32_wav(path):
    """Return the path of a mono float32 WAV copy of the audio file at *path*.

    The file is read, down-mixed to mono by averaging channels, and written once
    to a temporary '.wav' file. The result is memoised in _wav_cache, so later
    calls with the same *path* reuse it. The temp file is never deleted by this
    code, so it lives for the life of the process.
    """
    cached = _wav_cache.get(path)
    if cached is not None:
        return cached
    wav, sr = sf.read(path, dtype='float32', always_2d=True)
    wav = wav.mean(axis=1)
    # Only reserve a unique name here and close our handle right away:
    # soundfile reopens the path itself, which fails on platforms that do not
    # allow a second open of a NamedTemporaryFile (Windows).
    with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp:
        out_path = tmp.name
    try:
        sf.write(out_path, wav, sr, subtype='FLOAT')
    except Exception:
        # Do not leave an empty/partial temp file behind on failure.
        try:
            os.unlink(out_path)
        except OSError:
            pass
        raise
    _wav_cache[path] = out_path
    return out_path
def load_chime(path):
    """Decode the audio file at *path* into mono float32 samples at SAMPLE_RATE.

    Channels are averaged to mono and resampled only if the file's rate differs
    from SAMPLE_RATE. Results are memoised per path in _chime_cache.
    """
    try:
        return _chime_cache[path]
    except KeyError:
        pass
    data, rate = sf.read(path, dtype='float32', always_2d=True)
    mono = data.mean(axis=1)
    if rate != SAMPLE_RATE:
        mono = _librosa.resample(mono, orig_sr=rate, target_sr=SAMPLE_RATE)
    _chime_cache[path] = mono
    return mono
def generate(text, opts):
    """Synthesise *text* with the loaded model and return float32 samples.

    *opts* is the job dict. Only the generation options relevant to VARIANT are
    read from it; missing ones fall back to the defaults below. If
    'audio_prompt' is given, it is converted by ensure_float32_wav() and passed
    as audio_prompt_path for voice cloning.
    """
    started = time.monotonic()  # monotonic: immune to wall-clock adjustments
    if VARIANT == 'turbo':
        kwargs = {
            'temperature': opts.get('temperature', 0.8),
            'top_p': opts.get('top_p', 0.95),
            'top_k': opts.get('top_k', 1000),
            'repetition_penalty': opts.get('repetition_penalty', 1.2),
            'min_p': opts.get('min_p', 0.0),
        }
    else:
        kwargs = {
            'temperature': opts.get('temperature', 0.8),
            'top_p': opts.get('top_p', 1.0),
            'repetition_penalty': opts.get('repetition_penalty', 1.2),
            'min_p': opts.get('min_p', 0.05),
            'exaggeration': opts.get('exaggeration', 0.5),
            'cfg_weight': opts.get('cfg_weight', 0.5),
        }
    audio_prompt = opts.get('audio_prompt')
    if audio_prompt:
        kwargs['audio_prompt_path'] = ensure_float32_wav(audio_prompt)
    with torch.inference_mode():
        wav = model.generate(text, **kwargs)
    # NOTE(review): assumes model.generate() returns a (1, N) tensor; squeeze(0)
    # drops the leading axis - confirm against the chatterbox API.
    samples = wav.squeeze(0).cpu().numpy().astype(np.float32)
    elapsed = time.monotonic() - started
    duration = len(samples) / SAMPLE_RATE
    # An empty result previously raised ZeroDivisionError here, turning a
    # finished generation into a job 'error'; log it without an RTF instead.
    rtf = f'{elapsed / duration:.2f}' if duration > 0 else 'n/a'
    log(f'generated {duration:.1f}s audio in {elapsed:.1f}s rtf={rtf}')
    return samples
# abort_all() bumps this counter. generation_worker() records its value right
# after dequeuing a job and drops the result if it changed in the meantime. That
# way abort_all also cancels the job being synthesised, which sits in neither
# queue and so cannot be drained.
_abort_epoch = 0
# Held by abort_all() while bumping the epoch and draining the queues, and by
# generation_worker() while checking the epoch and enqueuing a result.
_abort_epoch_lock = threading.Lock()


def _render_job(item):
    """Return the samples for one queued job, or None for an unrecognised type."""
    job_type = item['type']
    if job_type == 'speak':
        with _gen_lock:
            return generate(item['text'], item)
    if job_type == 'chime':
        return load_chime(item['path'])
    return None


def _hand_off_for_playback(job_id, samples, epoch):
    """Queue *samples* for playback unless abort_all() ran since *epoch* was read."""
    # The check and the put happen under the lock abort_all() holds while it
    # drains. So the item is either drained by abort_all or recognised as stale
    # here; it is never played after an abort_all.
    with _abort_epoch_lock:
        stale = epoch != _abort_epoch
        if not stale:
            _playback_queue.put({'id': job_id, 'samples': samples})
    if stale:
        broadcast({'event': 'aborted', 'id': job_id})


def generation_worker():
    """Turn queued jobs into playable sample buffers until _SENTINEL arrives.

    'speak' jobs are synthesised with generate() and 'chime' jobs are decoded
    with load_chime(). Results go to _playback_queue as {'id', 'samples'}.
    Failures are printed and broadcast as an 'error' event.
    """
    while True:
        item = _job_queue.get()
        if item is _SENTINEL:
            _job_queue.task_done()
            break
        # Read as soon as the job is ours. An abort_all() that lands in the tiny
        # window between get() and this read goes unnoticed (best effort).
        epoch = _abort_epoch
        job_id = item['id']
        try:
            samples = _render_job(item)
        except Exception as e:
            traceback.print_exc(file=sys.stderr)
            broadcast({'event': 'error', 'id': job_id, 'message': str(e)})
        else:
            if samples is not None:
                _hand_off_for_playback(job_id, samples, epoch)
        finally:
            _job_queue.task_done()


def _play_item(job_id, samples):
    """Stream one buffer through a fresh pacat process and broadcast the outcome."""
    global _current_proc
    _abort_flag.clear()
    broadcast({'event': 'started', 'id': job_id})
    # An OSError here (e.g. pacat not installed) propagates to playback_worker,
    # which reports it as an 'error' event and keeps going.
    proc = subprocess.Popen(
        ['pacat', '--format=float32le', f'--rate={SAMPLE_RATE}', '--channels=1'],
        stdin=subprocess.PIPE,
    )
    with _current_proc_lock:
        _current_proc = proc
        # abort_current() may have run after the flag was cleared but before
        # proc was published; honour it now rather than playing the whole item
        # and then reporting it as 'aborted'.
        if _abort_flag.is_set():
            proc.kill()
    try:
        try:
            proc.stdin.write(samples.tobytes())
        except BrokenPipeError:
            pass  # pacat exited early (e.g. killed by abort_current)
        try:
            proc.stdin.close()
        except BrokenPipeError:
            pass  # flushing leftovers into a dead pacat; the pipe is closed anyway
        returncode = proc.wait()
    finally:
        with _current_proc_lock:
            _current_proc = None
    if _abort_flag.is_set():
        broadcast({'event': 'aborted', 'id': job_id})
    else:
        if returncode != 0:
            log(f'pacat exited with status {returncode} for job {job_id}')
        broadcast({'event': 'finished', 'id': job_id})


def playback_worker():
    """Play _playback_queue items one at a time until _SENTINEL arrives.

    Each item gets 'started', then 'finished' or 'aborted'. If playback fails
    (e.g. pacat cannot be launched) it gets an 'error' event instead, and the
    worker stays alive for later items.
    """
    while True:
        item = _playback_queue.get()
        if item is _SENTINEL:
            _playback_queue.task_done()
            break
        try:
            _play_item(item['id'], item['samples'])
        except Exception as e:
            traceback.print_exc(file=sys.stderr)
            broadcast({'event': 'error', 'id': item['id'], 'message': str(e)})
        finally:
            _playback_queue.task_done()


def abort_current():
    """Kill the pacat process playing now; playback_worker then reports 'aborted'."""
    _abort_flag.set()
    with _current_proc_lock:
        if _current_proc is not None:
            try:
                _current_proc.kill()
            except ProcessLookupError:
                pass  # already exited and reaped by playback's wait()


def _drain(q):
    """Empty *q* without blocking and return the ids of the removed jobs.

    A _SENTINEL found in the queue is put back so its worker still shuts down.
    """
    ids = []
    found_sentinel = False
    while True:
        try:
            item = q.get_nowait()
        except queue.Empty:
            break
        if item is _SENTINEL:
            found_sentinel = True
        else:
            ids.append(item['id'])
        q.task_done()
    if found_sentinel:
        q.put(_SENTINEL)
    return ids


def abort_all():
    """Cancel pending jobs, queued audio, the in-flight generation and current playback.

    Every drained job gets an 'aborted' event. The job being generated gets its
    'aborted' event from generation_worker once generation completes.
    """
    global _abort_epoch
    with _abort_epoch_lock:
        _abort_epoch += 1
        drained_ids = _drain(_job_queue) + _drain(_playback_queue)
    abort_current()
    for jid in drained_ids:
        broadcast({'event': 'aborted', 'id': jid})
# Start the two pipeline workers as daemon threads, so they never keep the
# process alive once the main thread is done.
for _worker in (generation_worker, playback_worker):
    threading.Thread(target=_worker, daemon=True).start()
del _worker
# Created by main(); a 'terminate' request sets it so main() leaves its serve()
# context and asyncio.run() returns normally.
_shutdown_event = None


async def _dispatch(msg, reply):
    """Execute one parsed request object and send its status reply via *reply*."""
    msg_type = msg.get('type', '')
    if msg_type in ('speak', 'chime'):
        # Keep a client-supplied id as-is (including 0, which an `or` fallback
        # would silently replace); otherwise assign a server id.
        job_id = msg.get('id')
        if job_id is None:
            job_id = _next_id()
        job = dict(msg)
        job['id'] = job_id
        _job_queue.put(job)
        broadcast({'event': 'queued', 'id': job_id})
        await reply({'status': 'ok', 'id': job_id})
    elif msg_type == 'preload':
        path = msg.get('path', '')
        try:
            await asyncio.get_running_loop().run_in_executor(None, load_chime, path)
        except Exception as e:
            await reply({'status': 'error', 'message': str(e)})
            return
        log(f'preloaded: {path}')
        await reply({'status': 'ok'})
    elif msg_type == 'abort_current':
        abort_current()
        await reply({'status': 'ok'})
    elif msg_type == 'abort_all':
        abort_all()
        await reply({'status': 'ok'})
    elif msg_type == 'terminate':
        await reply({'status': 'ok'})
        log('terminate requested')
        # Calling loop.stop() inside asyncio.run() makes run_until_complete
        # raise "Event loop stopped before Future completed" (traceback, exit 1).
        # Signal main() to return cleanly instead.
        _shutdown_event.set()
    else:
        await reply({'status': 'error', 'message': f'unknown type: {msg_type}'})


async def _ws_handler(websocket):
    """Serve one WebSocket client for the lifetime of its connection.

    Registers a per-client asyncio.Queue in _ws_clients so broadcast() can reach
    this client from any thread; a sender task forwards that queue to the socket.
    Each incoming frame is parsed as a JSON request object and answered with a
    status reply (protocol in the module docstring).
    """
    from websockets.exceptions import ConnectionClosed

    outbox = asyncio.Queue()
    with _ws_clients_lock:
        _ws_clients.add(outbox)

    async def reply(payload):
        await websocket.send(json.dumps(payload))

    async def sender():
        # Forward broadcast events; end quietly once the peer is gone so the
        # task does not finish with an unretrieved ConnectionClosed.
        try:
            while True:
                await websocket.send(await outbox.get())
        except ConnectionClosed:
            pass

    sender_task = asyncio.create_task(sender())
    try:
        async for raw in websocket:
            try:
                msg = json.loads(raw)
            except ValueError:
                # JSONDecodeError, or UnicodeDecodeError for a non-UTF-8 binary frame.
                await reply({'status': 'error', 'message': 'invalid JSON'})
                continue
            if not isinstance(msg, dict):
                # Valid JSON that is not an object would make msg.get() raise
                # and end the whole session; reject just this request instead.
                await reply({'status': 'error', 'message': 'expected a JSON object'})
                continue
            await _dispatch(msg, reply)
    except ConnectionClosed:
        pass  # peer went away mid-reply: normal end of session
    except Exception:
        # Unexpected failure in this session: keep serving others, but report it.
        traceback.print_exc(file=sys.stderr)
    finally:
        with _ws_clients_lock:
            _ws_clients.discard(outbox)
        sender_task.cancel()


async def main():
    """Run the WebSocket server on PORT until a 'terminate' request arrives."""
    global _ws_loop, _shutdown_event
    from websockets.asyncio.server import serve as ws_serve
    _ws_loop = asyncio.get_running_loop()
    _shutdown_event = asyncio.Event()
    async with ws_serve(_ws_handler, '0.0.0.0', PORT, reuse_address=True):
        log(f'listening on port {PORT}')
        await _shutdown_event.wait()
    log('server stopped')
try:
    asyncio.run(main())
except KeyboardInterrupt:
    pass  # Ctrl-C: fall through to the worker shutdown below
finally:
    # Ask both pipeline workers to exit. They are daemon threads and nothing
    # joins them here, so the process does not wait for them to drain.
    for _q in (_job_queue, _playback_queue):
        _q.put(_SENTINEL)