feat: first commit
This commit is contained in:
3
src/run.sh
Executable file
3
src/run.sh
Executable file
@@ -0,0 +1,3 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
exec python3 -m wyoming_whisper --download-dir /data "$@"
|
||||
1
src/wyoming_whisper/__init__.py
Normal file
1
src/wyoming_whisper/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Wyoming server for faster-whisper."""
|
||||
103
src/wyoming_whisper/__main__.py
Executable file
103
src/wyoming_whisper/__main__.py
Executable file
@@ -0,0 +1,103 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import logging
|
||||
from functools import partial
|
||||
|
||||
from wyoming.info import AsrModel, AsrProgram, Attribution, Info
|
||||
from wyoming.server import AsyncServer
|
||||
|
||||
from .const import WHISPER_LANGUAGES
|
||||
import whisper
|
||||
from .handler import FasterWhisperEventHandler
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
"""Main entry point."""
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
default="medium",
|
||||
help="Name of whisper model to use (default medium)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--uri", help="unix:// or tcp://", default="tcp://0.0.0.0:10300"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--download-dir",
|
||||
required=True,
|
||||
help="Directory to download models into",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
default="cuda",
|
||||
help="Device to use for inference (default: cuda)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--language",
|
||||
default="en",
|
||||
help="Default language to set for transcription (default: en)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--beam-size",
|
||||
type=int,
|
||||
default=5,
|
||||
)
|
||||
|
||||
parser.add_argument("--debug", action="store_true", help="Log DEBUG messages")
|
||||
args = parser.parse_args()
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG if args.debug else logging.INFO)
|
||||
_LOGGER.debug(args)
|
||||
|
||||
_LOGGER.debug("Loading %s", args.model)
|
||||
whisper_model = whisper.load_model(
|
||||
args.model, device=args.device, download_root=args.download_dir
|
||||
)
|
||||
|
||||
wyoming_info = Info(
|
||||
asr=[
|
||||
AsrProgram(
|
||||
name="whisper",
|
||||
description="Whisper",
|
||||
attribution=Attribution(
|
||||
name="OpenAI",
|
||||
url="https://github.com/openai/whisper",
|
||||
),
|
||||
installed=True,
|
||||
models=[
|
||||
AsrModel(
|
||||
name=args.model,
|
||||
description=args.model,
|
||||
attribution=Attribution(
|
||||
name="whisper",
|
||||
url="https://github.com/openai/whisper",
|
||||
),
|
||||
installed=True,
|
||||
languages=WHISPER_LANGUAGES,
|
||||
)
|
||||
],
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
server = AsyncServer.from_uri(args.uri)
|
||||
_LOGGER.info("Ready")
|
||||
model_lock = asyncio.Lock()
|
||||
await server.run(
|
||||
partial(
|
||||
FasterWhisperEventHandler,
|
||||
wyoming_info,
|
||||
args,
|
||||
whisper_model,
|
||||
model_lock,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
asyncio.run(main())
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
101
src/wyoming_whisper/const.py
Normal file
101
src/wyoming_whisper/const.py
Normal file
@@ -0,0 +1,101 @@
|
||||
WHISPER_LANGUAGES = [
|
||||
"af",
|
||||
"am",
|
||||
"ar",
|
||||
"as",
|
||||
"az",
|
||||
"ba",
|
||||
"be",
|
||||
"bg",
|
||||
"bn",
|
||||
"bo",
|
||||
"br",
|
||||
"bs",
|
||||
"ca",
|
||||
"cs",
|
||||
"cy",
|
||||
"da",
|
||||
"de",
|
||||
"el",
|
||||
"en",
|
||||
"es",
|
||||
"et",
|
||||
"eu",
|
||||
"fa",
|
||||
"fi",
|
||||
"fo",
|
||||
"fr",
|
||||
"gl",
|
||||
"gu",
|
||||
"ha",
|
||||
"haw",
|
||||
"he",
|
||||
"hi",
|
||||
"hr",
|
||||
"ht",
|
||||
"hu",
|
||||
"hy",
|
||||
"id",
|
||||
"is",
|
||||
"it",
|
||||
"ja",
|
||||
"jw",
|
||||
"ka",
|
||||
"kk",
|
||||
"km",
|
||||
"kn",
|
||||
"ko",
|
||||
"la",
|
||||
"lb",
|
||||
"ln",
|
||||
"lo",
|
||||
"lt",
|
||||
"lv",
|
||||
"mg",
|
||||
"mi",
|
||||
"mk",
|
||||
"ml",
|
||||
"mn",
|
||||
"mr",
|
||||
"ms",
|
||||
"mt",
|
||||
"my",
|
||||
"ne",
|
||||
"nl",
|
||||
"nn",
|
||||
"no",
|
||||
"oc",
|
||||
"pa",
|
||||
"pl",
|
||||
"ps",
|
||||
"pt",
|
||||
"ro",
|
||||
"ru",
|
||||
"sa",
|
||||
"sd",
|
||||
"si",
|
||||
"sk",
|
||||
"sl",
|
||||
"sn",
|
||||
"so",
|
||||
"sq",
|
||||
"sr",
|
||||
"su",
|
||||
"sv",
|
||||
"sw",
|
||||
"ta",
|
||||
"te",
|
||||
"tg",
|
||||
"th",
|
||||
"tk",
|
||||
"tl",
|
||||
"tr",
|
||||
"tt",
|
||||
"uk",
|
||||
"ur",
|
||||
"uz",
|
||||
"vi",
|
||||
"yi",
|
||||
"yo",
|
||||
"zh",
|
||||
]
|
||||
92
src/wyoming_whisper/handler.py
Normal file
92
src/wyoming_whisper/handler.py
Normal file
@@ -0,0 +1,92 @@
|
||||
"""Event handler for clients of the server."""
|
||||
import argparse
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
from wyoming.asr import Transcribe, Transcript
|
||||
from wyoming.audio import AudioChunk, AudioChunkConverter, AudioStop
|
||||
from wyoming.event import Event
|
||||
from wyoming.info import Describe, Info
|
||||
from wyoming.server import AsyncEventHandler
|
||||
|
||||
from whisper import Whisper
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FasterWhisperEventHandler(AsyncEventHandler):
|
||||
"""Event handler for clients."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
wyoming_info: Info,
|
||||
cli_args: argparse.Namespace,
|
||||
model: Whisper,
|
||||
model_lock: asyncio.Lock,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.cli_args = cli_args
|
||||
self.wyoming_info_event = wyoming_info.event()
|
||||
self.model = model
|
||||
self.model_lock = model_lock
|
||||
self.audio = bytes()
|
||||
self.audio_converter = AudioChunkConverter(
|
||||
rate=16000,
|
||||
width=2,
|
||||
channels=1,
|
||||
)
|
||||
self._language = self.cli_args.language
|
||||
|
||||
async def handle_event(self, event: Event) -> bool:
|
||||
if Describe.is_type(event.type):
|
||||
await self.write_event(self.wyoming_info_event)
|
||||
_LOGGER.debug("Sent info")
|
||||
return True
|
||||
|
||||
if Transcribe.is_type(event.type):
|
||||
transcribe = Transcribe.from_event(event)
|
||||
if transcribe.language:
|
||||
self._language = transcribe.language
|
||||
_LOGGER.debug("Language set to %s", transcribe.language)
|
||||
return True
|
||||
|
||||
if AudioChunk.is_type(event.type):
|
||||
if not self.audio:
|
||||
_LOGGER.debug("Receiving audio")
|
||||
|
||||
chunk = AudioChunk.from_event(event)
|
||||
chunk = self.audio_converter.convert(chunk)
|
||||
self.audio += chunk.audio
|
||||
|
||||
return True
|
||||
|
||||
if AudioStop.is_type(event.type):
|
||||
_LOGGER.debug("Audio stopped")
|
||||
audio = (
|
||||
np.frombuffer(self.audio, dtype=np.int16).astype(np.float32) / 32678.0
|
||||
)
|
||||
async with self.model_lock:
|
||||
response = self.model.transcribe(
|
||||
audio,
|
||||
beam_size=self.cli_args.beam_size,
|
||||
language=self._language,
|
||||
)
|
||||
segments = response["segments"]
|
||||
|
||||
text = " ".join(segment["text"] for segment in segments)
|
||||
_LOGGER.info(text)
|
||||
|
||||
await self.write_event(Transcript(text=text).event())
|
||||
_LOGGER.debug("Completed request")
|
||||
|
||||
# Reset
|
||||
self.audio = bytes()
|
||||
self._language = self.cli_args.language
|
||||
|
||||
return False
|
||||
|
||||
return True
|
||||
Reference in New Issue
Block a user