Merge pull request 'Replace HTTP API with WebSocket server' (#4) from mikael-lovqvists-claude-agent/tts-server:websocket-api into main
Reviewed-on: #4
This commit was merged in pull request #4.
This commit is contained in:
@@ -1,18 +1,37 @@
|
||||
#!/usr/bin/env -S bash -c 'exec "$(dirname "$0")/venv/bin/python3" "$0" "$@"'
|
||||
"""
|
||||
Chatterbox TTS HTTP server — keeps model loaded, exposes a JSON HTTP API.
|
||||
Chatterbox TTS WebSocket server — keeps model loaded, exposes a JSON WebSocket API.
|
||||
|
||||
Endpoints:
|
||||
POST /speak {"text": "...", "temperature": 0.8, "top_p": 0.95, "audio_prompt": "/path.wav"}
|
||||
POST /chime {"path": "/path/to/file.wav"}
|
||||
POST /preload {"path": "/path/to/file.wav"}
|
||||
POST /command {"command": "terminate"}
|
||||
Connect to ws://host:TTS_PORT (default ws://localhost:11500).
|
||||
|
||||
All endpoints return {"status": "ok"} or {"status": "error", "message": "..."}.
|
||||
Responses are sent after audio is queued for playback (not after playback finishes).
|
||||
Client → Server:
|
||||
{"type": "speak", "id"?: N, "text": "...", ...generation_opts}
|
||||
{"type": "chime", "id"?: N, "path": "..."}
|
||||
{"type": "preload", "path": "..."}
|
||||
{"type": "abort_current"} — kill active playback, advance to next queued item
|
||||
{"type": "abort_all"} — kill active playback + drain all queues
|
||||
{"type": "terminate"}
|
||||
|
||||
Server → requesting client:
|
||||
{"status": "ok", "id": N} (speak/chime)
|
||||
{"status": "ok"}
|
||||
{"status": "error", "message": "..."}
|
||||
|
||||
Server → all clients (broadcast):
|
||||
{"event": "queued", "id": N}
|
||||
{"event": "started", "id": N}
|
||||
{"event": "finished", "id": N}
|
||||
{"event": "aborted", "id": N}
|
||||
{"event": "error", "id": N, "message": "..."}
|
||||
|
||||
Generation options (speak):
|
||||
temperature, top_p, top_k, repetition_penalty, min_p
|
||||
audio_prompt — path to reference WAV for voice cloning
|
||||
exaggeration — 0.0-1.0, full model only
|
||||
cfg_weight — full model only
|
||||
|
||||
Environment:
|
||||
TTS_PORT TCP port to listen on (default: 11500)
|
||||
TTS_PORT port to listen on (default: 11500)
|
||||
HF_TOKEN_FILE path to HuggingFace token file (default: ~/.secrets/hugging-face.token)
|
||||
HF_HUB_CACHE path to HuggingFace hub cache (default: ~/.cache/huggingface/hub)
|
||||
|
||||
@@ -23,9 +42,6 @@ Usage:
|
||||
|
||||
Paralinguistic tags supported in text:
|
||||
[laugh] [chuckle] [cough] [clear throat] [sigh] [shush] [groan] [sniff] [gasp]
|
||||
|
||||
Full model only:
|
||||
exaggeration 0.0-1.0 emotion intensity (ignored in turbo)
|
||||
"""
|
||||
|
||||
import os
|
||||
@@ -37,7 +53,8 @@ import threading
|
||||
import subprocess
|
||||
import traceback
|
||||
import tempfile
|
||||
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
||||
import asyncio
|
||||
import itertools
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
|
||||
@@ -107,25 +124,31 @@ _chime_cache = {}
|
||||
_gen_lock = threading.Lock()
|
||||
|
||||
_SENTINEL = object()
|
||||
playback_queue = queue.Queue()
|
||||
|
||||
_id_counter = itertools.count(1)
|
||||
def _next_id():
|
||||
return next(_id_counter)
|
||||
|
||||
_job_queue = queue.Queue() # dicts: {'id', 'type', ...}
|
||||
_playback_queue = queue.Queue() # dicts: {'id', 'samples'}
|
||||
|
||||
_current_proc = None
|
||||
_current_proc_lock = threading.Lock()
|
||||
_abort_flag = threading.Event()
|
||||
|
||||
_ws_clients = set() # asyncio.Queue per connected client
|
||||
_ws_clients_lock = threading.Lock()
|
||||
_ws_loop = None
|
||||
|
||||
|
||||
def playback_worker():
|
||||
while True:
|
||||
item = playback_queue.get()
|
||||
if item is _SENTINEL:
|
||||
break
|
||||
proc = subprocess.Popen(
|
||||
['pacat', '--format=float32le', f'--rate={SAMPLE_RATE}', '--channels=1'],
|
||||
stdin=subprocess.PIPE,
|
||||
)
|
||||
proc.stdin.write(item.tobytes())
|
||||
proc.stdin.close()
|
||||
proc.wait()
|
||||
playback_queue.task_done()
|
||||
|
||||
|
||||
threading.Thread(target=playback_worker, daemon=True).start()
|
||||
def broadcast(event):
|
||||
if _ws_loop is None:
|
||||
return
|
||||
msg = json.dumps(event)
|
||||
with _ws_clients_lock:
|
||||
clients = list(_ws_clients)
|
||||
for q in clients:
|
||||
_ws_loop.call_soon_threadsafe(q.put_nowait, msg)
|
||||
|
||||
|
||||
def ensure_float32_wav(path):
|
||||
@@ -185,79 +208,179 @@ def generate(text, opts):
|
||||
return samples
|
||||
|
||||
|
||||
class Handler(BaseHTTPRequestHandler):
|
||||
def send_json(self, data, status=200):
|
||||
body = json.dumps(data).encode()
|
||||
self.send_response(status)
|
||||
self.send_header('Content-Type', 'application/json')
|
||||
self.send_header('Content-Length', str(len(body)))
|
||||
self.end_headers()
|
||||
self.wfile.write(body)
|
||||
|
||||
def read_json(self):
|
||||
length = int(self.headers.get('Content-Length', 0))
|
||||
return json.loads(self.rfile.read(length))
|
||||
|
||||
def do_POST(self):
|
||||
try:
|
||||
req = self.read_json()
|
||||
except Exception:
|
||||
self.send_json({'status': 'error', 'message': 'invalid JSON'}, 400)
|
||||
return
|
||||
|
||||
if self.path == '/speak':
|
||||
text = req.pop('text', '')
|
||||
if not text:
|
||||
self.send_json({'status': 'ok'})
|
||||
return
|
||||
def generation_worker():
|
||||
while True:
|
||||
item = _job_queue.get()
|
||||
if item is _SENTINEL:
|
||||
_job_queue.task_done()
|
||||
break
|
||||
job_id = item['id']
|
||||
job_type = item['type']
|
||||
try:
|
||||
if job_type == 'speak':
|
||||
with _gen_lock:
|
||||
samples = generate(text, req)
|
||||
playback_queue.put(samples)
|
||||
self.send_json({'status': 'ok'})
|
||||
samples = generate(item['text'], item)
|
||||
_playback_queue.put({'id': job_id, 'samples': samples})
|
||||
elif job_type == 'chime':
|
||||
samples = load_chime(item['path'])
|
||||
_playback_queue.put({'id': job_id, 'samples': samples})
|
||||
except Exception as e:
|
||||
traceback.print_exc(file=sys.stderr)
|
||||
self.send_json({'status': 'error', 'message': str(e)}, 500)
|
||||
broadcast({'event': 'error', 'id': job_id, 'message': str(e)})
|
||||
_job_queue.task_done()
|
||||
|
||||
elif self.path == '/chime':
|
||||
path = req.get('path', '')
|
||||
try:
|
||||
samples = load_chime(path)
|
||||
playback_queue.put(samples)
|
||||
self.send_json({'status': 'ok'})
|
||||
except Exception as e:
|
||||
traceback.print_exc(file=sys.stderr)
|
||||
self.send_json({'status': 'error', 'message': str(e)}, 500)
|
||||
|
||||
elif self.path == '/preload':
|
||||
path = req.get('path', '')
|
||||
def playback_worker():
|
||||
global _current_proc
|
||||
while True:
|
||||
item = _playback_queue.get()
|
||||
if item is _SENTINEL:
|
||||
_playback_queue.task_done()
|
||||
break
|
||||
job_id = item['id']
|
||||
samples = item['samples']
|
||||
|
||||
_abort_flag.clear()
|
||||
broadcast({'event': 'started', 'id': job_id})
|
||||
|
||||
proc = subprocess.Popen(
|
||||
['pacat', '--format=float32le', f'--rate={SAMPLE_RATE}', '--channels=1'],
|
||||
stdin=subprocess.PIPE,
|
||||
)
|
||||
with _current_proc_lock:
|
||||
_current_proc = proc
|
||||
|
||||
try:
|
||||
load_chime(path)
|
||||
proc.stdin.write(samples.tobytes())
|
||||
proc.stdin.close()
|
||||
except BrokenPipeError:
|
||||
pass
|
||||
proc.wait()
|
||||
|
||||
with _current_proc_lock:
|
||||
_current_proc = None
|
||||
|
||||
if _abort_flag.is_set():
|
||||
broadcast({'event': 'aborted', 'id': job_id})
|
||||
else:
|
||||
broadcast({'event': 'finished', 'id': job_id})
|
||||
|
||||
_playback_queue.task_done()
|
||||
|
||||
|
||||
def abort_current():
|
||||
_abort_flag.set()
|
||||
with _current_proc_lock:
|
||||
if _current_proc is not None:
|
||||
_current_proc.kill()
|
||||
|
||||
|
||||
def abort_all():
|
||||
drained_ids = []
|
||||
|
||||
while True:
|
||||
try:
|
||||
item = _job_queue.get_nowait()
|
||||
if item is not _SENTINEL:
|
||||
drained_ids.append(item['id'])
|
||||
_job_queue.task_done()
|
||||
except queue.Empty:
|
||||
break
|
||||
|
||||
while True:
|
||||
try:
|
||||
item = _playback_queue.get_nowait()
|
||||
if item is not _SENTINEL:
|
||||
drained_ids.append(item['id'])
|
||||
_playback_queue.task_done()
|
||||
except queue.Empty:
|
||||
break
|
||||
|
||||
abort_current()
|
||||
|
||||
for jid in drained_ids:
|
||||
broadcast({'event': 'aborted', 'id': jid})
|
||||
|
||||
|
||||
threading.Thread(target=generation_worker, daemon=True).start()
|
||||
threading.Thread(target=playback_worker, daemon=True).start()
|
||||
|
||||
|
||||
async def _ws_handler(websocket):
|
||||
q = asyncio.Queue()
|
||||
with _ws_clients_lock:
|
||||
_ws_clients.add(q)
|
||||
|
||||
async def sender():
|
||||
while True:
|
||||
msg = await q.get()
|
||||
await websocket.send(msg)
|
||||
|
||||
sender_task = asyncio.create_task(sender())
|
||||
|
||||
try:
|
||||
async for raw in websocket:
|
||||
try:
|
||||
msg = json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
await websocket.send(json.dumps({'status': 'error', 'message': 'invalid JSON'}))
|
||||
continue
|
||||
|
||||
msg_type = msg.get('type', '')
|
||||
|
||||
if msg_type in ('speak', 'chime'):
|
||||
job_id = msg.get('id') or _next_id()
|
||||
job = dict(msg)
|
||||
job['id'] = job_id
|
||||
_job_queue.put(job)
|
||||
broadcast({'event': 'queued', 'id': job_id})
|
||||
await websocket.send(json.dumps({'status': 'ok', 'id': job_id}))
|
||||
|
||||
elif msg_type == 'preload':
|
||||
path = msg.get('path', '')
|
||||
try:
|
||||
await asyncio.get_running_loop().run_in_executor(None, load_chime, path)
|
||||
log(f'preloaded: {path}')
|
||||
self.send_json({'status': 'ok'})
|
||||
await websocket.send(json.dumps({'status': 'ok'}))
|
||||
except Exception as e:
|
||||
self.send_json({'status': 'error', 'message': str(e)}, 500)
|
||||
await websocket.send(json.dumps({'status': 'error', 'message': str(e)}))
|
||||
|
||||
elif self.path == '/command':
|
||||
command = req.get('command', '')
|
||||
if command == 'terminate':
|
||||
self.send_json({'status': 'ok'})
|
||||
threading.Thread(target=server.shutdown, daemon=True).start()
|
||||
else:
|
||||
self.send_json({'status': 'error', 'message': f'unknown command: {command}'}, 400)
|
||||
elif msg_type == 'abort_current':
|
||||
abort_current()
|
||||
await websocket.send(json.dumps({'status': 'ok'}))
|
||||
|
||||
elif msg_type == 'abort_all':
|
||||
abort_all()
|
||||
await websocket.send(json.dumps({'status': 'ok'}))
|
||||
|
||||
elif msg_type == 'terminate':
|
||||
await websocket.send(json.dumps({'status': 'ok'}))
|
||||
asyncio.get_running_loop().stop()
|
||||
|
||||
else:
|
||||
self.send_json({'status': 'error', 'message': 'not found'}, 404)
|
||||
await websocket.send(json.dumps({'status': 'error', 'message': f'unknown type: {msg_type}'}))
|
||||
|
||||
def log_message(self, fmt, *args):
|
||||
log(fmt % args)
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
with _ws_clients_lock:
|
||||
_ws_clients.discard(q)
|
||||
sender_task.cancel()
|
||||
|
||||
|
||||
server = ThreadingHTTPServer(('', PORT), Handler)
|
||||
async def main():
|
||||
global _ws_loop
|
||||
from websockets.asyncio.server import serve as ws_serve
|
||||
_ws_loop = asyncio.get_running_loop()
|
||||
async with ws_serve(_ws_handler, '0.0.0.0', PORT, reuse_address=True):
|
||||
log(f'listening on port {PORT}')
|
||||
await asyncio.Future()
|
||||
|
||||
|
||||
try:
|
||||
server.serve_forever()
|
||||
asyncio.run(main())
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
playback_queue.put(_SENTINEL)
|
||||
_job_queue.put(_SENTINEL)
|
||||
_playback_queue.put(_SENTINEL)
|
||||
|
||||
34
examples/abort-demo.mjs
Normal file
34
examples/abort-demo.mjs
Normal file
@@ -0,0 +1,34 @@
|
||||
// Start speaking a long sentence, then abort a few seconds in.
|
||||
// Usage: node abort-demo.mjs
|
||||
|
||||
const PORT = process.env.TTS_PORT ?? '11500'
|
||||
const text = 'This is a very long sentence that will be cut off before it finishes, ' +
|
||||
'because a few seconds after playback starts we will send an abort command ' +
|
||||
'to demonstrate the abort current functionality of the server.'
|
||||
|
||||
const ws = new WebSocket(`ws://localhost:${PORT}`)
|
||||
|
||||
ws.addEventListener('open', () => {
|
||||
ws.send(JSON.stringify({ type: 'speak', text }))
|
||||
})
|
||||
|
||||
ws.addEventListener('message', ({ data }) => {
|
||||
const msg = JSON.parse(data)
|
||||
console.log(msg)
|
||||
|
||||
if (msg.event === 'started') {
|
||||
setTimeout(() => {
|
||||
console.log('aborting...')
|
||||
ws.send(JSON.stringify({ type: 'abort_current' }))
|
||||
}, 3000)
|
||||
}
|
||||
|
||||
if (msg.event === 'aborted' || msg.event === 'finished' || msg.event === 'error') {
|
||||
ws.close()
|
||||
}
|
||||
})
|
||||
|
||||
ws.addEventListener('error', (e) => {
|
||||
console.error('error:', e.message)
|
||||
process.exit(1)
|
||||
})
|
||||
@@ -9,14 +9,20 @@ if (!path) {
|
||||
process.exit(1)
|
||||
}
|
||||
|
||||
const res = await fetch(`http://localhost:${PORT}/chime`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ path }),
|
||||
const ws = new WebSocket(`ws://localhost:${PORT}`)
|
||||
|
||||
ws.addEventListener('open', () => {
|
||||
ws.send(JSON.stringify({ type: 'chime', path }))
|
||||
})
|
||||
|
||||
const data = await res.json()
|
||||
if (data.status !== 'ok') {
|
||||
console.error('error:', data.message)
|
||||
process.exit(1)
|
||||
ws.addEventListener('message', ({ data }) => {
|
||||
const msg = JSON.parse(data)
|
||||
if (msg.event === 'finished' || msg.event === 'aborted' || msg.event === 'error') {
|
||||
ws.close()
|
||||
}
|
||||
})
|
||||
|
||||
ws.addEventListener('error', (e) => {
|
||||
console.error('error:', e.message)
|
||||
process.exit(1)
|
||||
})
|
||||
|
||||
@@ -4,14 +4,20 @@
|
||||
const PORT = process.env.TTS_PORT ?? '11500'
|
||||
const text = process.argv[2] ?? 'Hello from Node.'
|
||||
|
||||
const res = await fetch(`http://localhost:${PORT}/speak`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ text }),
|
||||
const ws = new WebSocket(`ws://localhost:${PORT}`)
|
||||
|
||||
ws.addEventListener('open', () => {
|
||||
ws.send(JSON.stringify({ type: 'speak', text }))
|
||||
})
|
||||
|
||||
const data = await res.json()
|
||||
if (data.status !== 'ok') {
|
||||
console.error('error:', data.message)
|
||||
process.exit(1)
|
||||
ws.addEventListener('message', ({ data }) => {
|
||||
const msg = JSON.parse(data)
|
||||
if (msg.event === 'finished' || msg.event === 'aborted' || msg.event === 'error') {
|
||||
ws.close()
|
||||
}
|
||||
})
|
||||
|
||||
ws.addEventListener('error', (e) => {
|
||||
console.error('error:', e.message)
|
||||
process.exit(1)
|
||||
})
|
||||
|
||||
@@ -3,14 +3,22 @@
|
||||
|
||||
const PORT = process.env.TTS_PORT ?? '11500'
|
||||
|
||||
const res = await fetch(`http://localhost:${PORT}/command`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ command: 'terminate' }),
|
||||
const ws = new WebSocket(`ws://localhost:${PORT}`)
|
||||
|
||||
ws.addEventListener('open', () => {
|
||||
ws.send(JSON.stringify({ type: 'terminate' }))
|
||||
})
|
||||
|
||||
const data = await res.json()
|
||||
if (data.status !== 'ok') {
|
||||
console.error('error:', data.message)
|
||||
ws.addEventListener('message', ({ data }) => {
|
||||
const msg = JSON.parse(data)
|
||||
if (msg.status !== 'ok') {
|
||||
console.error('error:', msg.message)
|
||||
process.exit(1)
|
||||
}
|
||||
ws.close()
|
||||
})
|
||||
|
||||
ws.addEventListener('error', (e) => {
|
||||
console.error('error:', e.message)
|
||||
process.exit(1)
|
||||
})
|
||||
|
||||
@@ -11,14 +11,20 @@ if (!audio_prompt) {
|
||||
process.exit(1)
|
||||
}
|
||||
|
||||
const res = await fetch(`http://localhost:${PORT}/speak`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ text, audio_prompt }),
|
||||
const ws = new WebSocket(`ws://localhost:${PORT}`)
|
||||
|
||||
ws.addEventListener('open', () => {
|
||||
ws.send(JSON.stringify({ type: 'speak', text, audio_prompt }))
|
||||
})
|
||||
|
||||
const data = await res.json()
|
||||
if (data.status !== 'ok') {
|
||||
console.error('error:', data.message)
|
||||
process.exit(1)
|
||||
ws.addEventListener('message', ({ data }) => {
|
||||
const msg = JSON.parse(data)
|
||||
if (msg.event === 'finished' || msg.event === 'aborted' || msg.event === 'error') {
|
||||
ws.close()
|
||||
}
|
||||
})
|
||||
|
||||
ws.addEventListener('error', (e) => {
|
||||
console.error('error:', e.message)
|
||||
process.exit(1)
|
||||
})
|
||||
|
||||
@@ -12,7 +12,7 @@ fi
|
||||
|
||||
echo "==> installing Python dependencies"
|
||||
"${VENV}/bin/pip" install --upgrade pip --quiet
|
||||
"${VENV}/bin/pip" install chatterbox-tts
|
||||
"${VENV}/bin/pip" install chatterbox-tts websockets
|
||||
|
||||
|
||||
echo ""
|
||||
|
||||
Reference in New Issue
Block a user