experiments with websocket-stream

This commit is contained in:
lexasub 2025-04-16 06:34:09 +04:00
parent f0d2bfbfb7
commit e144b459c4
12 changed files with 559 additions and 2 deletions

View File

@ -83,6 +83,7 @@ option(WHISPER_BUILD_SERVER "whisper: build server example" ${WHISPER_STANDALO
# 3rd party libs
option(WHISPER_CURL "whisper: use libcurl to download model from an URL" OFF)
option(WHISPER_SDL2 "whisper: support for libSDL2" OFF)
option(WEBSOCKET "whisper: support for websocket" OFF)
if (CMAKE_SYSTEM_NAME MATCHES "Linux")
option(WHISPER_FFMPEG "whisper: support building and linking with ffmpeg libs (avcodec, swresample, ...)" OFF)

View File

@ -100,8 +100,10 @@ if (EMSCRIPTEN)
add_subdirectory(bench.wasm)
elseif(CMAKE_JS_VERSION)
add_subdirectory(addon.node)
else()
add_subdirectory(cli)
else()
if (WEBSOCKET)
add_subdirectory(cli)
endif()
add_subdirectory(bench)
add_subdirectory(server)
add_subdirectory(quantize)

View File

@ -0,0 +1,8 @@
set(TARGET whisper-websocket-stream)
add_executable(${TARGET} main.cpp whisper-server.cpp message-buffer.cpp)
find_package(ixwebsocket)
find_package(CURL REQUIRED)
include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE common whisper ixwebsocket z CURL::libcurl ${CMAKE_THREAD_LIBS_INIT})
install(TARGETS ${TARGET} RUNTIME)

View File

@ -0,0 +1,90 @@
# WebSocket Whisper Stream Example
This example demonstrates a WebSocket-based real-time audio transcription service using the Whisper model. The server captures audio from clients, processes it using the Whisper model, and sends transcriptions back through WebSocket connections.
## Features
- Real-time audio transcription
- WebSocket communication for audio and transcription data
- Configurable parameters for model, language, and processing settings
- Integration with backend services via HTTP requests
## Usage
Run the server with the following command:
```bash
./build/bin/whisper-websocket-stream -m ./models/ggml-large-v3-turbo.bin -t 8 --host 0.0.0.0 --port 9002 --forward-url http://localhost:8080/completion
```
### Parameters
- `-m` or `--model`: Path to the Whisper model file.
- `-t` or `--threads`: Number of threads for processing.
- `-H` or `--host`: Hostname or IP address to bind the server to.
- `-p` or `--port`: Port number for the server.
- `-f` or `--forward-url`: URL to forward transcriptions to a backend service.
- `-nm` or `--max-messages`: Maximum number of messages before sending to the backend.
- `-l` or `--language`: Spoken language for transcription.
- `-vth` or `--vad-thold`: Voice activity detection threshold.
- `-tr` or `--translate`: Enable translation to English.
- `-ng` or `--no-gpu`: Disable GPU usage.
- `-bs` or `--beam-size`: Beam size for beam search.
## Building
To build the server, follow these steps:
```bash
# Install dependencies
git clone --depth 1 https://github.com/machinezone/IXWebSocket/
cd IXWebSocket
mkdir -p build && cd build && cmake -GNinja .. && sudo ninja -j$((npoc)) install
# Build the project
#cuda is optional
git clone --depth 1 https://github.com/ggml-org/whisper.cpp
cd whisper.cpp
mkdir -p build && cd build
cmake -GNinja -DCMAKE_BUILD_TYPE=Release -DWEBSOCKET=ON -DGGML_CUDA ..
ninja -j$((npoc))
# Run the server
./bin/whisper-websocket-stream --help
```
## Client Integration
Clients can connect to the WebSocket server and send audio data. The server processes the audio and sends transcriptions back through the WebSocket connection.
### Example Client Code (JavaScript)
```javascript
const socket = new WebSocket('ws://localhost:9002');
socket.onopen = () => {
console.log('Connected to WebSocket server');
};
socket.onmessage = (event) => {
console.log('Transcription:', event.data);
};
socket.onclose = () => {
console.log('Disconnected from WebSocket server');
};
// Function to send audio data to the server
function sendAudioData(audioData) {
socket.send(audioData);
}
```
## Backend Integration
The server can forward transcriptions to a backend service via HTTP requests. Configure the `forward_url` parameter to specify the backend service URL.
## Dependencies
- whisper.cpp
- ixwebsocket for WebSocket communication
- libcurl for HTTP requests
```

View File

@ -0,0 +1,15 @@
#ifndef CLIENT_SESSION_H
#define CLIENT_SESSION_H
#include <vector>
#include <mutex>
#include <atomic>
#include "ixwebsocket/IXWebSocketServer.h"
#include "message-buffer.h"
struct ClientSession {
std::vector<float> pcm_buffer;
std::mutex mtx;
std::atomic<bool> active{true};
ix::WebSocket *connection;
MessageBuffer buffToBackend;
};
#endif

View File

@ -0,0 +1,61 @@
<!DOCTYPE html>
<html>
<head>
<title>Mic to WebSocket</title>
</head>
<body>
<button id="startBtn">Start Mic</button>
<div id="status"></div>
<script>
const startBtn = document.getElementById('startBtn');
const statusDiv = document.getElementById('status');
let isRecording = false;
let socket;
startBtn.addEventListener('click', async () => {
if (!isRecording) {
try {
socket = new WebSocket('ws://192.168.2.109:9002');
const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
const audioContext = new AudioContext({sampleRate: 16000});
const source = audioContext.createMediaStreamSource(stream);
const processor = audioContext.createScriptProcessor(1024, 1, 1);
source.connect(processor);
processor.connect(audioContext.destination);
function floatTo16BitPCM(input) {
const output = new Int16Array(input.length);
for (let i = 0; i < input.length; i++) {
output[i] = Math.max(-1, Math.min(1, input[i])) * 0x7FFF;
}
return output;
}
processor.onaudioprocess = (e) => {
const input = e.inputBuffer.getChannelData(0);
const int16Data = floatTo16BitPCM(input);
if (socket.readyState === WebSocket.OPEN) {
socket.send(int16Data.buffer);
}
};
statusDiv.textContent = 'Recording...';
startBtn.textContent = 'Stop';
isRecording = true;
} catch (err) {
console.error('Error accessing microphone:', err);
statusDiv.textContent = 'Error accessing microphone';
}
} else {
if (socket) socket.close();
statusDiv.textContent = 'Stopped';
startBtn.textContent = 'Start Mic';
isRecording = false;
}
});
</script>
</body>
</html>

View File

@ -0,0 +1,74 @@
#include <string>
#include "whisper.h"
#include "server-params.h"
#include "whisper-server.h"
#define CONVERT_FROM_PCM_16
std::string forward_url = "http://127.0.0.1:8080/completion";
size_t max_messages = 1000;
void print_usage(int argc, char** argv, const ServerParams& params) {
fprintf(stderr, "\n");
fprintf(stderr, "usage: %s [options]\n", argv[0]);
fprintf(stderr, "\n");
fprintf(stderr, "options:\n");
fprintf(stderr, " -h, --help show this help message and exit\n");
fprintf(stderr, " -H HOST, --host HOST [%-7s] hostname or ip\n", params.host.c_str());
fprintf(stderr, " -p PORT, --port PORT [%-7d] server port\n", params.port);
fprintf(stderr, " -f FORWARD_URL, --forward-url FORWARD_URL [%-7s] forward url\n", forward_url.c_str());
fprintf(stderr, " -t N, --threads N [%-7d] number of threads\n", params.n_threads);
fprintf(stderr, " -nm max_messages, --max-messages max_messages [%-7d] max messages before send to backend\n", max_messages);
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity threshold\n", params.vad_thold);
fprintf(stderr, " -tr, --translate [%-7s] translate to english\n", params.translate ? "true" : "false");
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
fprintf(stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n", params.beam_size);
fprintf(stderr, "\n");
}
bool parse_params(int argc, char** argv, ServerParams& params) {
for (int i = 1; i < argc; i++) {
std::string arg = argv[i];
if (arg == "-h" || arg == "--help") {
print_usage(argc, argv, params);
exit(0);
}
else if (arg == "-H" || arg == "--host") { params.host = argv[++i]; }
else if (arg == "-p" || arg == "--port") { params.port = std::stoi(argv[++i]); }
else if (arg == "-f" || arg == "--forward-url") { forward_url = argv[++i]; }
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
else if (arg == "-nm" || arg == "--max-messages") { max_messages = std::stoi(argv[++i]); }
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); }
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-bs" || arg == "--beam-size") { params.beam_size = std::stoi(argv[++i]); }
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
print_usage(argc, argv, params);
return false;
}
}
return true;
}
int main(int argc, char** argv) {
ServerParams params;
if (!parse_params(argc, argv, params)) {
return 1;
}
if (params.port < 1 || params.port > 65535) {
throw std::invalid_argument("Invalid port number");
}
if (params.language != "auto" && whisper_lang_id(params.language.c_str()) == -1) {
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
return 1;
}
WhisperServer server(params);
server.run();
return 0;
}

View File

@ -0,0 +1,79 @@
#include <sstream>
#include <mutex>
#include <curl/curl.h>
#include "message-buffer.h"
extern std::string forward_url;
extern size_t max_messages;
namespace {
std::stringstream ss;
std::mutex mtx;
size_t current_count = 0;
static size_t write_callback(char* ptr, size_t size, size_t nmemb, void* userdata) {
((std::string*)userdata)->append(ptr, size * nmemb);
return size * nmemb;
}
}
void MessageBuffer::add_message(const char* msg) {
std::lock_guard<std::mutex> lock(mtx);
ss << std::string(msg) << '\n';
if (++current_count >= max_messages) {
flush();
}
}
std::string MessageBuffer::get_payload() {
std::lock_guard<std::mutex> lock(mtx);
return ss.str();
}
void MessageBuffer::flush() {
std::string payload = get_payload();
if (!payload.empty()) {
send_via_http(payload);
ss.str(""); //clear string stream
current_count = 0;
}
}
void MessageBuffer::send_via_http(const std::string& data) {
CURL* curl = curl_easy_init();
if (!curl) {
printf("CURL init failed");
return;
}
//make headers
struct curl_slist* headers = NULL;
headers = curl_slist_append(headers, "Content-Type: text/plain");
std::string cid_header = "X-Connection-ID: " + connection_id;
headers = curl_slist_append(headers, cid_header.c_str());
//config curl
std::string response;
printf("sending to %s\n", forward_url.c_str());
curl_easy_setopt(curl, CURLOPT_URL, forward_url.c_str());
curl_easy_setopt(curl, CURLOPT_POST, 1L);
curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers);
curl_easy_setopt(curl, CURLOPT_POSTFIELDS, data.c_str());
curl_easy_setopt(curl, CURLOPT_POSTFIELDSIZE, data.size());
curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, write_callback);
curl_easy_setopt(curl, CURLOPT_WRITEDATA, &response);
curl_easy_setopt(curl, CURLOPT_TIMEOUT, 5L);
curl_easy_setopt(curl, CURLOPT_CONNECTTIMEOUT, 2L);
//run curl
for (int retry = 0; retry < 3; ++retry) {
CURLcode res = curl_easy_perform(curl);
if (res == CURLE_OK) {
printf("[Response (%s): %s\n", connection_id.c_str(), response.c_str());
break;
}
printf("[CURL error: %s\n", curl_easy_strerror(res));
}
//clean
curl_slist_free_all(headers);
curl_easy_cleanup(curl);
}

View File

@ -0,0 +1,14 @@
#ifndef MESSAGE_BUFFER_H
#define MESSAGE_BUFFER_H
class MessageBuffer {
public:
std::string connection_id;
void add_message(const char* msg);
std::string get_payload();
void flush();
void send_via_http(const std::string& data);
};
#endif

View File

@ -0,0 +1,23 @@
#ifndef SERVER_PARAMS_H
#define SERVER_PARAMS_H
#include <thread>
struct ServerParams {
int32_t port = 9002;
int32_t n_threads = std::min(4, (int32_t)std::thread::hardware_concurrency());
int32_t audio_ctx = 0;
int32_t beam_size = -1;
float vad_thold = 0.6f;
bool translate = false;
bool print_special = false;
bool no_timestamps = true;
bool tinydiarize = false;
bool use_gpu = true;
bool flash_attn = true;
std::string language = "en";
std::string model = "ggml-large-v3-turbo.bin";
std::string host = "0.0.0.0";
};
#endif

View File

@ -0,0 +1,177 @@
#include <unordered_map>
#include <memory>
#include <atomic>
#include <mutex>
#include <chrono>
#include <random>
#include <sstream>
#include "whisper-server.h"
#include "client-session.h"
#include "whisper.h"
namespace {
ServerParams params;
std::unordered_map<std::string, std::unique_ptr<ClientSession>> clients;
std::mutex clients_mtx;
std::thread processor_thread;
std::atomic<bool> running{true};
std::mutex g_ctx_mtx;
whisper_context* g_ctx = nullptr;
constexpr int CHUNK_SIZE = 3 * 16000;
}
std::string generate_uuid_v4() {
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_int_distribution<> dis(0, 15);
std::uniform_int_distribution<> dis2(8, 11);
std::stringstream ss;
ss << std::hex;
for (int i = 0; i < 8; i++) ss << dis(gen);
ss << "-";
for (int i = 0; i < 4; i++) ss << dis(gen);
ss << "-4"; // v4
for (int i = 0; i < 3; i++) ss << dis(gen);
ss << "-";
ss << dis2(gen);
for (int i = 0; i < 3; i++) ss << dis(gen);
ss << "-";
for (int i = 0; i < 12; i++) ss << dis(gen);
return ss.str();
}
void handleMessage(std::shared_ptr<ix::ConnectionState> state,
ix::WebSocket& ws,
const ix::WebSocketMessagePtr& msg) {
const std::string client_id = state->getId();
if (msg->type == ix::WebSocketMessageType::Open) {
printf("[%s] new client\n", client_id.c_str());
std::lock_guard<std::mutex> lock(clients_mtx);
clients[client_id] = std::make_unique<ClientSession>();
// UUID v4
clients[client_id]->buffToBackend.connection_id = generate_uuid_v4();
ws.sendText("CONNECTION_ID:" + clients[client_id]->buffToBackend.connection_id);
clients[client_id]->connection = &ws;
}
else if (msg->type == ix::WebSocketMessageType::Close) {
printf("[%s] delete client\n", client_id.c_str());
clients[client_id]->buffToBackend.flush();
std::lock_guard<std::mutex> lock(clients_mtx);
if (clients.count(client_id)) {
clients[client_id]->active = false;
clients.erase(client_id);
}
}
else if (msg->type == ix::WebSocketMessageType::Message && msg->binary) {
std::lock_guard<std::mutex> lock(clients_mtx);
if (!clients.count(client_id)) return;
auto& session = *clients[client_id];
const auto& data = msg->str;
#ifdef CONVERT_FROM_PCM_16
const int16_t* pcm16 = reinterpret_cast<const int16_t*>(data.data());
size_t n_samples = data.size() / sizeof(int16_t);
std::lock_guard<std::mutex> session_lock(session.mtx);
for (size_t i = 0; i < n_samples; i++) {
session.pcm_buffer.push_back(pcm16[i] / 32768.0f);
}
#else
const int32_t* pcm32 = reinterpret_cast<const int32_t*>(data.data());
//also we may use memcpy ))
size_t n_samples = data.size() / sizeof(int32_t);
std::lock_guard<std::mutex> session_lock(session.mtx);
for (size_t i = 0; i < n_samples; i++) {
session.pcm_buffer.push_back(pcm32[i]);
}
#endif
}
}
void processChunk(std::vector<float> &chunk, const std::string &id, ClientSession *session) {
std::lock_guard<std::mutex> ctx_lock(g_ctx_mtx);
whisper_full_params wparams = whisper_full_default_params(
params.beam_size > 1 ? WHISPER_SAMPLING_BEAM_SEARCH
: WHISPER_SAMPLING_GREEDY);
wparams.print_progress = false;
wparams.print_special = params.print_special;
wparams.print_realtime = false;
wparams.print_timestamps = !params.no_timestamps;
wparams.translate = params.translate;
wparams.language = params.language.c_str();
wparams.n_threads = params.n_threads;
wparams.beam_search.beam_size = params.beam_size;
wparams.audio_ctx = params.audio_ctx;
wparams.tdrz_enable = params.tinydiarize;
if (whisper_full(g_ctx, wparams, chunk.data(), chunk.size()) == 0) {
const char* text = whisper_full_get_segment_text(g_ctx, 0);
printf("[%s] %s\n", id.c_str(), text);
session->connection->sendText(text);
session->buffToBackend.add_message(text);
}
whisper_reset_timings(g_ctx);
}
void process() {
while (running) {
std::this_thread::sleep_for(std::chrono::milliseconds(100));
std::lock_guard<std::mutex> lock(clients_mtx);
for (auto& [id, session] : clients) {
std::lock_guard<std::mutex> session_lock(session->mtx);
if (session->pcm_buffer.size() < CHUNK_SIZE) continue;
std::vector<float> chunk(
session->pcm_buffer.begin(),
session->pcm_buffer.begin() + CHUNK_SIZE
);
session->pcm_buffer.erase(
session->pcm_buffer.begin(),
session->pcm_buffer.begin() + CHUNK_SIZE
);
processChunk(chunk, id, session.get());
}
}
}
WhisperServer::WhisperServer(const ServerParams& _params) : server(params.port, params.host) {
params = _params;
whisper_context_params cparams = whisper_context_default_params();
cparams.use_gpu = params.use_gpu;
cparams.flash_attn = params.flash_attn;
g_ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
server.setTLSOptions({});
server.setOnClientMessageCallback([this](auto&&... args) {
handleMessage(args...);
});
processor_thread = std::thread([this] { process(); });
}
WhisperServer::~WhisperServer() {
running = false;
server.stop();
if (processor_thread.joinable()) processor_thread.join();
std::lock_guard<std::mutex> lock(clients_mtx);
for (auto& [id, session] : clients) {
session->buffToBackend.flush();
}
whisper_free(g_ctx);
}
void WhisperServer::run() {
server.listenAndStart();
while (running) std::this_thread::sleep_for(std::chrono::seconds(1));
}

View File

@ -0,0 +1,13 @@
#ifndef WHISPER_SERVER_H
#define WHISPER_SERVER_H
#include "server-params.h"
#include "ixwebsocket/IXWebSocketServer.h"
class WhisperServer {
ix::WebSocketServer server;
public:
WhisperServer(const ServerParams& params);
~WhisperServer();
void run();
};
#endif