#!/usr/bin/env -S bash -c 'exec "$(dirname "$0")/venv/bin/python3" "$0" "$@"'
"""
Chatterbox TTS WebSocket server — keeps model loaded, exposes a JSON WebSocket API.

Connect to ws://host:TTS_PORT (default ws://localhost:11500).

Client → Server:
  {"type": "speak", "id"?: N, "text": "...", ...generation_opts}
  {"type": "chime", "id"?: N, "path": "..."}
  {"type": "preload", "path": "..."}
  {"type": "abort_current"}  — kill active playback, advance to next queued item
  {"type": "abort_all"}      — kill active playback + drain all queues
  {"type": "terminate"}

Server → requesting client:
  {"status": "ok", "id": N}   (speak/chime)
  {"status": "ok"}
  {"status": "error", "message": "..."}

Server → all clients (broadcast):
  {"event": "queued", "id": N}
  {"event": "started", "id": N}
  {"event": "finished", "id": N}
  {"event": "aborted", "id": N}
  {"event": "error", "id": N, "message": "..."}

Generation options (speak):
  temperature, top_p, top_k, repetition_penalty, min_p
  audio_prompt — path to reference WAV for voice cloning
  exaggeration — 0.0-1.0, full model only
  cfg_weight — full model only

Environment:
  TTS_PORT       port to listen on (default: 11500)
  HF_TOKEN_FILE  path to HuggingFace token file (default: ~/.secrets/hugging-face.token)
  HF_HUB_CACHE   path to HuggingFace hub cache (default: ~/.cache/huggingface/hub)

Usage:
  ./chatterbox-server.py
  ./chatterbox-server.py turbo   # default
  ./chatterbox-server.py full    # original model, supports exaggeration

Paralinguistic tags supported in text:
  [laugh] [chuckle] [cough] [clear throat] [sigh] [shush] [groan] [sniff] [gasp]
"""

import os
import sys
import json
import time
import queue
import atexit
import threading
import subprocess
import traceback
import tempfile
import asyncio
import itertools
from pathlib import Path

import numpy as np

# Export the HuggingFace token (if the token file exists) before any hub code runs.
TOKEN_FILE = os.environ.get('HF_TOKEN_FILE', os.path.expanduser('~/.secrets/hugging-face.token'))
try:
    with open(TOKEN_FILE) as f:
        os.environ['HF_TOKEN'] = f.read().strip()
except FileNotFoundError:
    pass


def find_hf_cache(repo_id):
    """Return the most recently modified cached snapshot dir for *repo_id*, or None.

    Looks under HF_HUB_CACHE (default ~/.cache/huggingface/hub) for
    ``models--<org>--<name>/snapshots``.
    """
    cache_dir = Path(os.environ.get('HF_HUB_CACHE', os.path.expanduser('~/.cache/huggingface/hub')))
    repo_dir = cache_dir / f"models--{repo_id.replace('/', '--')}" / 'snapshots'
    if repo_dir.exists():
        snapshots = sorted(repo_dir.iterdir(), key=lambda p: p.stat().st_mtime)
        if snapshots:
            return str(snapshots[-1])
    return None


REPO_IDS = {
    'turbo': 'ResembleAI/chatterbox-turbo',
    'full': 'ResembleAI/chatterbox',
}

VARIANT = sys.argv[1] if len(sys.argv) > 1 else 'turbo'
# Reject unknown variants up front, before the slow model imports, instead of
# failing later with a bare KeyError on REPO_IDS.
if VARIANT not in REPO_IDS:
    sys.exit(f"[chatterbox] unknown variant {VARIANT!r}; expected one of: {', '.join(REPO_IDS)}")

PORT = int(os.environ.get('TTS_PORT', 11500))
SAMPLE_RATE = 24000


def log(msg):
    """Write a prefixed diagnostic line to stderr."""
    print(f'[chatterbox] {msg}', file=sys.stderr, flush=True)


log(f'loading chatterbox-{VARIANT}...')
t0 = time.time()

import torch
import soundfile as sf
import librosa as _librosa

# Patch librosa.resample so every caller receives float32 output.
# NOTE(review): presumably the model pipeline rejects float64 — confirm.
_orig_resample = _librosa.resample


def _resample_float32(*args, **kwargs):
    return _orig_resample(*args, **kwargs).astype(np.float32)


_librosa.resample = _resample_float32

device = 'cuda' if torch.cuda.is_available() else 'cpu'

if VARIANT == 'turbo':
    from chatterbox.tts_turbo import ChatterboxTurboTTS as Model
else:
    from chatterbox.tts import ChatterboxTTS as Model

cached = find_hf_cache(REPO_IDS[VARIANT])
if cached:
    log(f'loading from cache: {cached}')
    model = Model.from_local(cached, device=device)
else:
    log('cache not found, downloading...')
    model = Model.from_pretrained(device=device)

log(f'ready on {device} ({time.time() - t0:.1f}s load time)')

_wav_cache = {}    # source path -> temp float32 mono WAV path
_chime_cache = {}  # source path -> float32 mono samples at SAMPLE_RATE
_gen_lock = threading.Lock()
_SENTINEL = object()

_id_counter = itertools.count(1)


def _next_id():
    """Return the next server-assigned job id."""
    return next(_id_counter)


_job_queue = queue.Queue()       # dicts: {'id', 'type', ...}
_playback_queue = queue.Queue()  # dicts: {'id', 'samples'}

_current_proc = None
_current_proc_lock = threading.Lock()
_abort_flag = threading.Event()

_ws_clients = set()  # asyncio.Queue per connected client
_ws_clients_lock = threading.Lock()
_ws_loop = None
_stop_event = None   # asyncio.Event created in main(); set by a "terminate" request


def broadcast(event):
    """Queue *event* (a JSON-serialisable dict) for delivery to every connected client.

    Safe to call from worker threads; a no-op until the server loop is running.
    """
    loop = _ws_loop
    if loop is None:
        return
    msg = json.dumps(event)
    with _ws_clients_lock:
        clients = list(_ws_clients)
    for q in clients:
        try:
            loop.call_soon_threadsafe(q.put_nowait, msg)
        except RuntimeError:
            # The loop has been closed (server shutting down); nobody is listening.
            return


def _cleanup_wav_cache():
    """Delete the temporary WAV files created by ensure_float32_wav()."""
    for tmp_path in list(_wav_cache.values()):
        try:
            os.unlink(tmp_path)
        except OSError:
            pass


atexit.register(_cleanup_wav_cache)


def ensure_float32_wav(path):
    """Return the path of a mono float32 WAV copy of *path* (cached per source path)."""
    cached_path = _wav_cache.get(path)
    if cached_path is not None:
        return cached_path
    wav, sr = sf.read(path, dtype='float32', always_2d=True)
    wav = wav.mean(axis=1)  # downmix to mono
    # mkstemp + close: do not keep a file handle open for the process lifetime.
    fd, tmp_path = tempfile.mkstemp(suffix='.wav')
    os.close(fd)
    try:
        sf.write(tmp_path, wav, sr, subtype='FLOAT')
    except Exception:
        os.unlink(tmp_path)
        raise
    _wav_cache[path] = tmp_path
    return tmp_path


def load_chime(path):
    """Load *path* as mono float32 samples at SAMPLE_RATE (cached per path)."""
    if path in _chime_cache:
        return _chime_cache[path]
    samples, sr = sf.read(path, dtype='float32', always_2d=True)
    samples = samples.mean(axis=1)  # downmix to mono
    if sr != SAMPLE_RATE:
        samples = _librosa.resample(samples, orig_sr=sr, target_sr=SAMPLE_RATE)
    _chime_cache[path] = samples
    return samples


def generate(text, opts):
    """Synthesise *text* and return float32 samples.

    *opts* is the request dict; recognised generation options are read from it
    with per-variant defaults, and ``audio_prompt`` (if set) is converted to a
    float32 WAV and passed as the voice reference.
    """
    t1 = time.time()

    if VARIANT == 'turbo':
        kwargs = {
            'temperature': opts.get('temperature', 0.8),
            'top_p': opts.get('top_p', 0.95),
            'top_k': opts.get('top_k', 1000),
            'repetition_penalty': opts.get('repetition_penalty', 1.2),
            'min_p': opts.get('min_p', 0.0),
        }
    else:
        kwargs = {
            'temperature': opts.get('temperature', 0.8),
            'top_p': opts.get('top_p', 1.0),
            'repetition_penalty': opts.get('repetition_penalty', 1.2),
            'min_p': opts.get('min_p', 0.05),
            'exaggeration': opts.get('exaggeration', 0.5),
            'cfg_weight': opts.get('cfg_weight', 0.5),
        }

    audio_prompt = opts.get('audio_prompt')
    if audio_prompt:
        kwargs['audio_prompt_path'] = ensure_float32_wav(audio_prompt)

    with torch.inference_mode():
        wav = model.generate(text, **kwargs)

    samples = wav.squeeze(0).cpu().numpy().astype(np.float32)
    elapsed = time.time() - t1
    duration = len(samples) / SAMPLE_RATE
    # Guard the real-time factor against an empty result (duration == 0).
    rtf = elapsed / duration if duration > 0 else float('inf')
    log(f'generated {duration:.1f}s audio in {elapsed:.1f}s  rtf={rtf:.2f}')
    return samples


def generation_worker():
    """Thread body: turn queued jobs into samples and hand them to the playback queue.

    Failures are reported to clients as an ``error`` event; the worker keeps running.
    Exits when it dequeues the shutdown sentinel.
    """
    while True:
        item = _job_queue.get()
        try:
            if item is _SENTINEL:
                break
            job_id = item['id']
            job_type = item['type']
            try:
                if job_type == 'speak':
                    with _gen_lock:
                        samples = generate(item['text'], item)
                    _playback_queue.put({'id': job_id, 'samples': samples})
                elif job_type == 'chime':
                    samples = load_chime(item['path'])
                    _playback_queue.put({'id': job_id, 'samples': samples})
            except Exception as e:
                traceback.print_exc(file=sys.stderr)
                broadcast({'event': 'error', 'id': job_id, 'message': str(e)})
        finally:
            _job_queue.task_done()


def _play_samples(samples):
    """Pipe *samples* to a ``pacat`` process and block until it exits."""
    global _current_proc
    proc = subprocess.Popen(
        ['pacat', '--format=float32le', f'--rate={SAMPLE_RATE}', '--channels=1'],
        stdin=subprocess.PIPE,
    )
    with _current_proc_lock:
        _current_proc = proc
        # An abort may have arrived after the flag was cleared but before the
        # process was registered; honour it here so the clip does not play on.
        if _abort_flag.is_set():
            proc.kill()
    try:
        try:
            proc.stdin.write(samples.tobytes())
            proc.stdin.close()
        except BrokenPipeError:
            pass  # player exited early, e.g. killed by an abort
        proc.wait()
    finally:
        with _current_proc_lock:
            _current_proc = None
        # Only reachable with a live process if something unexpected was raised.
        if proc.poll() is None:
            proc.kill()
            proc.wait()


def playback_worker():
    """Thread body: play queued sample buffers one at a time.

    Broadcasts ``started`` before each clip and ``finished``/``aborted`` after it;
    a player failure is broadcast as ``error`` and the worker moves on.
    Exits when it dequeues the shutdown sentinel.
    """
    while True:
        item = _playback_queue.get()
        try:
            if item is _SENTINEL:
                break
            job_id = item['id']
            _abort_flag.clear()
            broadcast({'event': 'started', 'id': job_id})
            try:
                _play_samples(item['samples'])
            except Exception as e:
                traceback.print_exc(file=sys.stderr)
                broadcast({'event': 'error', 'id': job_id, 'message': str(e)})
                continue
            if _abort_flag.is_set():
                broadcast({'event': 'aborted', 'id': job_id})
            else:
                broadcast({'event': 'finished', 'id': job_id})
        finally:
            _playback_queue.task_done()


def abort_current():
    """Flag the active clip as aborted and kill its player process, if any."""
    _abort_flag.set()
    with _current_proc_lock:
        if _current_proc is not None:
            _current_proc.kill()


def _drain_queue(q):
    """Empty *q* and return the ids of the removed jobs; a shutdown sentinel is re-queued."""
    drained_ids = []
    saw_sentinel = False
    while True:
        try:
            item = q.get_nowait()
        except queue.Empty:
            break
        if item is _SENTINEL:
            saw_sentinel = True
        else:
            drained_ids.append(item['id'])
        q.task_done()
    if saw_sentinel:
        q.put(_SENTINEL)
    return drained_ids


def abort_all():
    """Drop every queued job, abort the active clip, and broadcast ``aborted`` for dropped jobs."""
    drained_ids = _drain_queue(_job_queue) + _drain_queue(_playback_queue)
    abort_current()
    for jid in drained_ids:
        broadcast({'event': 'aborted', 'id': jid})


threading.Thread(target=generation_worker, daemon=True).start()
threading.Thread(target=playback_worker, daemon=True).start()


async def _ws_handler(websocket):
    """Serve one client connection: dispatch its requests and relay broadcast events."""
    from websockets.exceptions import ConnectionClosed

    q = asyncio.Queue()
    with _ws_clients_lock:
        _ws_clients.add(q)

    async def sender():
        # Forward broadcast events to this client until it disconnects.
        try:
            while True:
                await websocket.send(await q.get())
        except ConnectionClosed:
            pass

    async def reply(payload):
        await websocket.send(json.dumps(payload))

    sender_task = asyncio.create_task(sender())
    try:
        async for raw in websocket:
            try:
                msg = json.loads(raw)
            except ValueError:
                # JSONDecodeError, or UnicodeDecodeError for an undecodable binary frame.
                await reply({'status': 'error', 'message': 'invalid JSON'})
                continue
            if not isinstance(msg, dict):
                await reply({'status': 'error', 'message': 'message must be a JSON object'})
                continue

            msg_type = msg.get('type', '')

            if msg_type in ('speak', 'chime'):
                job_id = msg.get('id')
                if job_id is None:  # keep a client-supplied id of 0
                    job_id = _next_id()
                job = dict(msg)
                job['id'] = job_id
                _job_queue.put(job)
                broadcast({'event': 'queued', 'id': job_id})
                await reply({'status': 'ok', 'id': job_id})

            elif msg_type == 'preload':
                path = msg.get('path', '')
                try:
                    # File decoding/resampling is blocking; keep it off the event loop.
                    await asyncio.get_running_loop().run_in_executor(None, load_chime, path)
                    log(f'preloaded: {path}')
                    await reply({'status': 'ok'})
                except Exception as e:
                    await reply({'status': 'error', 'message': str(e)})

            elif msg_type == 'abort_current':
                abort_current()
                await reply({'status': 'ok'})

            elif msg_type == 'abort_all':
                abort_all()
                await reply({'status': 'ok'})

            elif msg_type == 'terminate':
                await reply({'status': 'ok'})
                # Let main() return normally; stopping the loop directly makes
                # asyncio.run() raise RuntimeError.
                _stop_event.set()
                break

            else:
                await reply({'status': 'error', 'message': f'unknown type: {msg_type}'})

    except ConnectionClosed:
        pass  # client went away
    except Exception as e:
        log(f'connection handler error: {e!r}')
        traceback.print_exc(file=sys.stderr)
    finally:
        with _ws_clients_lock:
            _ws_clients.discard(q)
        sender_task.cancel()


async def main():
    """Run the WebSocket server until a client sends ``terminate``."""
    global _ws_loop, _stop_event
    from websockets.asyncio.server import serve as ws_serve
    _ws_loop = asyncio.get_running_loop()
    _stop_event = asyncio.Event()
    # NOTE(review): binds all interfaces with no authentication, and requests
    # carry filesystem paths — confirm this is only exposed on trusted networks.
    async with ws_serve(_ws_handler, '0.0.0.0', PORT, reuse_address=True):
        log(f'listening on port {PORT}')
        await _stop_event.wait()


try:
    asyncio.run(main())
except KeyboardInterrupt:
    pass
finally:
    _job_queue.put(_SENTINEL)
    _playback_queue.put(_SENTINEL)