feat: first commit

This commit is contained in:
Simon Guigui
2023-11-25 18:22:45 +01:00
commit c0948a60c3
13 changed files with 1618 additions and 0 deletions

3
src/run.sh Executable file
View File

@@ -0,0 +1,3 @@
#!/usr/bin/env bash
# Launch the wyoming_whisper server with /data as the model download
# directory. Any arguments given to this script are forwarded unchanged.
# `exec` replaces this shell with the Python process.
exec python3 -m wyoming_whisper \
    --download-dir /data \
    "$@"

View File

@@ -0,0 +1 @@
"""Wyoming server for faster-whisper."""

103
src/wyoming_whisper/__main__.py Executable file
View File

@@ -0,0 +1,103 @@
import argparse
import asyncio
import logging
from functools import partial
from wyoming.info import AsrModel, AsrProgram, Attribution, Info
from wyoming.server import AsyncServer
from .const import WHISPER_LANGUAGES
import whisper
from .handler import FasterWhisperEventHandler
_LOGGER = logging.getLogger(__name__)
async def main() -> None:
    """Parse command-line options, load the Whisper model and serve clients.

    The function builds the Wyoming ``Info`` description advertising the
    loaded model, creates a server from ``--uri`` and runs it with a handler
    factory that shares one model instance and one lock across connections.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--model",
        default="medium",
        help="Name of whisper model to use (default medium)",
    )
    cli.add_argument(
        "--uri",
        default="tcp://0.0.0.0:10300",
        help="unix:// or tcp://",
    )
    cli.add_argument(
        "--download-dir",
        required=True,
        help="Directory to download models into",
    )
    cli.add_argument(
        "--device",
        default="cuda",
        help="Device to use for inference (default: cuda)",
    )
    cli.add_argument(
        "--language",
        default="en",
        help="Default language to set for transcription (default: en)",
    )
    cli.add_argument("--beam-size", type=int, default=5)
    cli.add_argument("--debug", action="store_true", help="Log DEBUG messages")
    args = cli.parse_args()

    if args.debug:
        log_level = logging.DEBUG
    else:
        log_level = logging.INFO
    logging.basicConfig(level=log_level)
    _LOGGER.debug(args)

    # Load the model once; every client handler receives this same instance.
    _LOGGER.debug("Loading %s", args.model)
    whisper_model = whisper.load_model(
        args.model,
        device=args.device,
        download_root=args.download_dir,
    )

    # Describe the single installed model and the program that provides it.
    upstream_url = "https://github.com/openai/whisper"
    model_entry = AsrModel(
        name=args.model,
        description=args.model,
        attribution=Attribution(name="whisper", url=upstream_url),
        installed=True,
        languages=WHISPER_LANGUAGES,
    )
    program_entry = AsrProgram(
        name="whisper",
        description="Whisper",
        attribution=Attribution(name="OpenAI", url=upstream_url),
        installed=True,
        models=[model_entry],
    )
    wyoming_info = Info(asr=[program_entry])

    # One lock shared by all handlers, so only one transcription runs at a time.
    model_lock = asyncio.Lock()
    handler_factory = partial(
        FasterWhisperEventHandler,
        wyoming_info,
        args,
        whisper_model,
        model_lock,
    )

    server = AsyncServer.from_uri(args.uri)
    _LOGGER.info("Ready")
    await server.run(handler_factory)
if __name__ == "__main__":
    # Script entry point: run the async server until it finishes.
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # Ctrl+C is the expected way to stop the server; exit without a traceback.
        pass

View File

@@ -0,0 +1,101 @@
# Language codes advertised for the model in the Wyoming info message.
# Kept in alphabetical order, one row per initial letter.
WHISPER_LANGUAGES = [
    "af", "am", "ar", "as", "az",
    "ba", "be", "bg", "bn", "bo", "br", "bs",
    "ca", "cs", "cy",
    "da", "de",
    "el", "en", "es", "et", "eu",
    "fa", "fi", "fo", "fr",
    "gl", "gu",
    "ha", "haw", "he", "hi", "hr", "ht", "hu", "hy",
    "id", "is", "it",
    "ja", "jw",
    "ka", "kk", "km", "kn", "ko",
    "la", "lb", "ln", "lo", "lt", "lv",
    "mg", "mi", "mk", "ml", "mn", "mr", "ms", "mt", "my",
    "ne", "nl", "nn", "no",
    "oc",
    "pa", "pl", "ps", "pt",
    "ro", "ru",
    "sa", "sd", "si", "sk", "sl", "sn", "so", "sq", "sr", "su", "sv", "sw",
    "ta", "te", "tg", "th", "tk", "tl", "tr", "tt",
    "uk", "ur", "uz",
    "vi",
    "yi", "yo",
    "zh",
]

View File

@@ -0,0 +1,92 @@
"""Event handler for clients of the server."""
import argparse
import asyncio
import logging
import numpy as np
from wyoming.asr import Transcribe, Transcript
from wyoming.audio import AudioChunk, AudioChunkConverter, AudioStop
from wyoming.event import Event
from wyoming.info import Describe, Info
from wyoming.server import AsyncEventHandler
from whisper import Whisper
_LOGGER = logging.getLogger(__name__)
class FasterWhisperEventHandler(AsyncEventHandler):
    """Event handler for clients.

    One instance is created per client connection. It buffers incoming audio
    chunks and, when the audio stream stops, transcribes the buffered audio
    with the shared Whisper model and sends the transcript back.
    """

    def __init__(
        self,
        wyoming_info: Info,
        cli_args: argparse.Namespace,
        model: Whisper,
        model_lock: asyncio.Lock,
        *args,
        **kwargs,
    ) -> None:
        """Store shared resources and initialise per-connection state.

        Args:
            wyoming_info: Service description sent in reply to ``Describe``.
            cli_args: Parsed command-line options; ``language`` and
                ``beam_size`` are read from it.
            model: Loaded Whisper model used for transcription.
            model_lock: Lock serialising access to ``model``.
            *args: Forwarded to ``AsyncEventHandler.__init__``.
            **kwargs: Forwarded to ``AsyncEventHandler.__init__``.
        """
        super().__init__(*args, **kwargs)

        self.cli_args = cli_args
        # Build the info event once; it is re-sent for every Describe request.
        self.wyoming_info_event = wyoming_info.event()
        self.model = model
        self.model_lock = model_lock
        # Converted PCM bytes accumulated since the last transcription.
        self.audio = bytes()
        # Every incoming chunk is converted to rate=16000, width=2, channels=1.
        self.audio_converter = AudioChunkConverter(
            rate=16000,
            width=2,
            channels=1,
        )
        # Language for the next transcription; a Transcribe event may override it.
        self._language = self.cli_args.language

    async def handle_event(self, event: Event) -> bool:
        """Process one client event.

        Handles ``Describe`` (reply with service info), ``Transcribe`` (set the
        language for the next transcription), ``AudioChunk`` (buffer audio) and
        ``AudioStop`` (transcribe the buffer and reply with a ``Transcript``).
        Any other event type is ignored.

        Returns:
            ``False`` after an ``AudioStop`` has been answered, ``True``
            otherwise. NOTE(review): ``False`` presumably tells the Wyoming
            server to stop handling this client - confirm against
            ``AsyncEventHandler``.
        """
        if Describe.is_type(event.type):
            await self.write_event(self.wyoming_info_event)
            _LOGGER.debug("Sent info")
            return True

        if Transcribe.is_type(event.type):
            transcribe = Transcribe.from_event(event)
            if transcribe.language:
                self._language = transcribe.language
                _LOGGER.debug("Language set to %s", transcribe.language)
            return True

        if AudioChunk.is_type(event.type):
            if not self.audio:
                _LOGGER.debug("Receiving audio")

            chunk = AudioChunk.from_event(event)
            chunk = self.audio_converter.convert(chunk)
            self.audio += chunk.audio
            return True

        if AudioStop.is_type(event.type):
            _LOGGER.debug("Audio stopped")

            # Scale signed 16-bit samples to float32 in [-1.0, 1.0). The
            # divisor is 2**15 = 32768; the previous value 32678 was a digit
            # transposition that pushed full-scale samples outside that range.
            audio = (
                np.frombuffer(self.audio, dtype=np.int16).astype(np.float32) / 32768.0
            )

            async with self.model_lock:
                # NOTE(review): transcribe() is a synchronous call, so the
                # event loop does nothing else until it returns.
                response = self.model.transcribe(
                    audio,
                    beam_size=self.cli_args.beam_size,
                    language=self._language,
                )

            segments = response["segments"]
            text = " ".join(segment["text"] for segment in segments)
            _LOGGER.info(text)

            await self.write_event(Transcript(text=text).event())
            _LOGGER.debug("Completed request")

            # Reset per-request state so nothing leaks into a later request.
            self.audio = bytes()
            self._language = self.cli_args.language

            return False

        return True