feat: first commit

This commit is contained in:
Simon Guigui
2023-11-25 18:22:45 +01:00
commit c0948a60c3
13 changed files with 1618 additions and 0 deletions

3
.dockerignore Normal file
View File

@@ -0,0 +1,3 @@
.venv
.git
data

117
.gitignore vendored Normal file
View File

@@ -0,0 +1,117 @@
output/*
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# Pycharm
.idea
data

27
.pre-commit-config.yaml Normal file
View File

@@ -0,0 +1,27 @@
---
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v2.3.0
hooks:
- id: check-yaml
args: [--allow-multiple-documents]
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: local
hooks:
- id: black
name: black
entry: poetry run black src
language: system
types: [python]
pass_filenames: false
- id: ruff
name: ruff
entry: poetry run ruff --fix src
language: system
types: [python]
pass_filenames: false
- repo: https://github.com/pre-commit/mirrors-prettier
rev: v2.3.2
hooks:
- id: prettier

9
Dockerfile Normal file
View File

@@ -0,0 +1,9 @@
FROM rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1
RUN pip install -U pip && pip install wyoming==1.2.0 openai-whisper==20231117 tokenizers==0.13.*
COPY src /src
WORKDIR /src
ENTRYPOINT ["/src/run.sh"]

29
README.md Normal file
View File

@@ -0,0 +1,29 @@
# rocm-wyoming-whisper
A docker image and a few lines of python to use OpenAI whisper with Rhasspy and/or Home Assistant on AMD GPUs with ROCm.
## Run with docker-compose
```shell
docker-compose up -d
```
## Run with Docker
Build docker image:
```shell
docker build -t wyoming-whisper .
```
Run docker image:
```shell
docker run --entrypoint '' -v $(pwd)/data:/data -v $(pwd)/src:/src -it --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --group-add video -p 10300:10300 wyoming-whisper bash
```
Run script:
```shell
python -m wyoming_whisper --download-dir /data --model medium --debug
```

22
docker-compose.yml Normal file
View File

@@ -0,0 +1,22 @@
services:
whisper:
container_name: whisper
restart: unless-stopped
build:
context: .
group_add:
- video
volumes:
- ./data:/data
command:
- --download-dir
- /data
- --language
- en
- --device
- cuda
ports:
- 10300:10300
devices:
- "/dev/dri"
- "/dev/kfd"

1093
poetry.lock generated Normal file

File diff suppressed because it is too large Load Diff

18
pyproject.toml Normal file
View File

@@ -0,0 +1,18 @@
[tool.poetry]
name = "wyoming-whisper"
version = "0.1.0"
description = ""
authors = ["Simon Guigui <simon.guigui@parkoview.com>"]
readme = "README.md"
[tool.poetry.dependencies]
python = "^3.10"
openai-whisper = "20231117"
wyoming = "^1.2.0"
tokenizers = "0.13.*"
black = "^23.11.0"
ruff = "^0.1.6"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

3
src/run.sh Executable file
View File

@@ -0,0 +1,3 @@
#!/usr/bin/env bash
exec python3 -m wyoming_whisper --download-dir /data "$@"

View File

@@ -0,0 +1 @@
"""Wyoming server for faster-whisper."""

103
src/wyoming_whisper/__main__.py Executable file
View File

@@ -0,0 +1,103 @@
import argparse
import asyncio
import logging
from functools import partial
from wyoming.info import AsrModel, AsrProgram, Attribution, Info
from wyoming.server import AsyncServer
from .const import WHISPER_LANGUAGES
import whisper
from .handler import FasterWhisperEventHandler
_LOGGER = logging.getLogger(__name__)
async def main() -> None:
    """CLI entry point: load a Whisper model and serve Wyoming ASR requests."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="medium", help="Name of whisper model to use (default medium)")
    parser.add_argument("--uri", help="unix:// or tcp://", default="tcp://0.0.0.0:10300")
    parser.add_argument("--download-dir", required=True, help="Directory to download models into")
    parser.add_argument("--device", default="cuda", help="Device to use for inference (default: cuda)")
    parser.add_argument("--language", default="en", help="Default language to set for transcription (default: en)")
    parser.add_argument("--beam-size", type=int, default=5)
    parser.add_argument("--debug", action="store_true", help="Log DEBUG messages")
    cli = parser.parse_args()

    log_level = logging.DEBUG if cli.debug else logging.INFO
    logging.basicConfig(level=log_level)
    _LOGGER.debug(cli)

    _LOGGER.debug("Loading %s", cli.model)
    model = whisper.load_model(cli.model, device=cli.device, download_root=cli.download_dir)

    # Describe the single loaded model for Wyoming "describe" queries.
    asr_model = AsrModel(
        name=cli.model,
        description=cli.model,
        attribution=Attribution(
            name="whisper",
            url="https://github.com/openai/whisper",
        ),
        installed=True,
        languages=WHISPER_LANGUAGES,
    )
    program = AsrProgram(
        name="whisper",
        description="Whisper",
        attribution=Attribution(
            name="OpenAI",
            url="https://github.com/openai/whisper",
        ),
        installed=True,
        models=[asr_model],
    )
    info = Info(asr=[program])

    server = AsyncServer.from_uri(cli.uri)
    _LOGGER.info("Ready")

    # One lock shared by all connection handlers; the model is not re-entrant.
    lock = asyncio.Lock()
    handler_factory = partial(FasterWhisperEventHandler, info, cli, model, lock)
    await server.run(handler_factory)
if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # Exit quietly on Ctrl-C instead of printing a traceback.
        pass

View File

@@ -0,0 +1,101 @@
# Language codes accepted for transcription (the 99 entries shipped here,
# sorted alphabetically); advertised to Wyoming clients via AsrModel.languages.
WHISPER_LANGUAGES = (
    "af am ar as az ba be bg bn bo br bs ca cs cy da de el en es "
    "et eu fa fi fo fr gl gu ha haw he hi hr ht hu hy id is it ja "
    "jw ka kk km kn ko la lb ln lo lt lv mg mi mk ml mn mr ms mt "
    "my ne nl nn no oc pa pl ps pt ro ru sa sd si sk sl sn so sq "
    "sr su sv sw ta te tg th tk tl tr tt uk ur uz vi yi yo zh"
).split()

View File

@@ -0,0 +1,92 @@
"""Event handler for clients of the server."""
import argparse
import asyncio
import logging
import numpy as np
from wyoming.asr import Transcribe, Transcript
from wyoming.audio import AudioChunk, AudioChunkConverter, AudioStop
from wyoming.event import Event
from wyoming.info import Describe, Info
from wyoming.server import AsyncEventHandler
from whisper import Whisper
_LOGGER = logging.getLogger(__name__)
class FasterWhisperEventHandler(AsyncEventHandler):
    """Per-connection Wyoming event handler that transcribes audio with Whisper.

    NOTE(review): despite the name, this wraps openai-whisper, not
    faster-whisper (see the ``whisper`` import); the name is kept so existing
    callers are unaffected.
    """

    def __init__(
        self,
        wyoming_info: Info,
        cli_args: argparse.Namespace,
        model: Whisper,
        model_lock: asyncio.Lock,
        *args,
        **kwargs,
    ) -> None:
        """Store the shared model/lock and prepare per-connection state.

        Args:
            wyoming_info: server capability description sent on "describe".
            cli_args: parsed CLI namespace (uses ``beam_size`` and ``language``).
            model: loaded Whisper model shared across all connections.
            model_lock: serializes access to the shared model.
        """
        super().__init__(*args, **kwargs)
        self.cli_args = cli_args
        self.wyoming_info_event = wyoming_info.event()
        self.model = model
        self.model_lock = model_lock
        # Raw PCM accumulated for the current utterance (16 kHz, 16-bit, mono).
        self.audio = bytes()
        self.audio_converter = AudioChunkConverter(
            rate=16000,
            width=2,
            channels=1,
        )
        self._language = self.cli_args.language

    async def handle_event(self, event: Event) -> bool:
        """Handle one Wyoming event; return False to close the connection."""
        if Describe.is_type(event.type):
            await self.write_event(self.wyoming_info_event)
            _LOGGER.debug("Sent info")
            return True

        if Transcribe.is_type(event.type):
            transcribe = Transcribe.from_event(event)
            if transcribe.language:
                self._language = transcribe.language
                _LOGGER.debug("Language set to %s", transcribe.language)
            return True

        if AudioChunk.is_type(event.type):
            if not self.audio:
                _LOGGER.debug("Receiving audio")
            chunk = AudioChunk.from_event(event)
            # Normalize incoming audio to 16 kHz / 16-bit / mono before buffering.
            chunk = self.audio_converter.convert(chunk)
            self.audio += chunk.audio
            return True

        if AudioStop.is_type(event.type):
            _LOGGER.debug("Audio stopped")
            # Convert int16 PCM to float32 in [-1.0, 1.0).  Bug fix: the
            # divisor must be 32768.0 (2**15); the previous 32678.0 was a typo
            # that slightly mis-scaled every sample.
            audio = (
                np.frombuffer(self.audio, dtype=np.int16).astype(np.float32) / 32768.0
            )
            async with self.model_lock:
                response = self.model.transcribe(
                    audio,
                    beam_size=self.cli_args.beam_size,
                    language=self._language,
                )
            segments = response["segments"]
            text = " ".join(segment["text"] for segment in segments)
            _LOGGER.info(text)

            await self.write_event(Transcript(text=text).event())
            _LOGGER.debug("Completed request")

            # Reset per-utterance state for the next request on this connection.
            self.audio = bytes()
            self._language = self.cli_args.language

            # Close the connection after delivering the transcript.
            return False

        return True