feat: first commit
This commit is contained in:
3
.dockerignore
Normal file
3
.dockerignore
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
.venv
|
||||||
|
.git
|
||||||
|
data
|
||||||
117
.gitignore
vendored
Normal file
117
.gitignore
vendored
Normal file
@@ -0,0 +1,117 @@
|
|||||||
|
output/*
|
||||||
|
|
||||||
|
# Byte-compiled / optimized / DLL files
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
|
||||||
|
# Distribution / packaging
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
pip-wheel-metadata/
|
||||||
|
share/python-wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
MANIFEST
|
||||||
|
|
||||||
|
# PyInstaller
|
||||||
|
# Usually these files are written by a python script from a template
|
||||||
|
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||||
|
*.manifest
|
||||||
|
*.spec
|
||||||
|
|
||||||
|
# Installer logs
|
||||||
|
pip-log.txt
|
||||||
|
pip-delete-this-directory.txt
|
||||||
|
|
||||||
|
# Unit test / coverage reports
|
||||||
|
htmlcov/
|
||||||
|
.tox/
|
||||||
|
.nox/
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
.cache
|
||||||
|
nosetests.xml
|
||||||
|
coverage.xml
|
||||||
|
*.cover
|
||||||
|
.hypothesis/
|
||||||
|
.pytest_cache/
|
||||||
|
|
||||||
|
# Translations
|
||||||
|
*.mo
|
||||||
|
*.pot
|
||||||
|
|
||||||
|
# Django stuff:
|
||||||
|
*.log
|
||||||
|
local_settings.py
|
||||||
|
db.sqlite3
|
||||||
|
|
||||||
|
# Flask stuff:
|
||||||
|
instance/
|
||||||
|
.webassets-cache
|
||||||
|
|
||||||
|
# Scrapy stuff:
|
||||||
|
.scrapy
|
||||||
|
|
||||||
|
# Sphinx documentation
|
||||||
|
docs/_build/
|
||||||
|
|
||||||
|
# PyBuilder
|
||||||
|
target/
|
||||||
|
|
||||||
|
# Jupyter Notebook
|
||||||
|
.ipynb_checkpoints
|
||||||
|
|
||||||
|
# IPython
|
||||||
|
profile_default/
|
||||||
|
ipython_config.py
|
||||||
|
|
||||||
|
# pipenv
|
||||||
|
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||||
|
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||||
|
# having no cross-platform support, pipenv may install dependencies that don’t work, or not
|
||||||
|
# install all needed dependencies.
|
||||||
|
#Pipfile.lock
|
||||||
|
|
||||||
|
# celery beat schedule file
|
||||||
|
celerybeat-schedule
|
||||||
|
|
||||||
|
# SageMath parsed files
|
||||||
|
*.sage.py
|
||||||
|
|
||||||
|
# Environments
|
||||||
|
.env
|
||||||
|
.venv
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
|
||||||
|
# mkdocs documentation
|
||||||
|
/site
|
||||||
|
|
||||||
|
# mypy
|
||||||
|
.mypy_cache/
|
||||||
|
.dmypy.json
|
||||||
|
dmypy.json
|
||||||
|
|
||||||
|
# Pyre type checker
|
||||||
|
.pyre/
|
||||||
|
|
||||||
|
# Pycharm
|
||||||
|
.idea
|
||||||
|
|
||||||
|
data
|
||||||
27
.pre-commit-config.yaml
Normal file
27
.pre-commit-config.yaml
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
---
|
||||||
|
repos:
|
||||||
|
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||||
|
rev: v2.3.0
|
||||||
|
hooks:
|
||||||
|
- id: check-yaml
|
||||||
|
args: [--allow-multiple-documents]
|
||||||
|
- id: end-of-file-fixer
|
||||||
|
- id: trailing-whitespace
|
||||||
|
- repo: local
|
||||||
|
hooks:
|
||||||
|
- id: black
|
||||||
|
name: black
|
||||||
|
entry: poetry run black src
|
||||||
|
language: system
|
||||||
|
types: [python]
|
||||||
|
pass_filenames: false
|
||||||
|
- id: ruff
|
||||||
|
name: ruff
|
||||||
|
entry: poetry run ruff --fix src
|
||||||
|
language: system
|
||||||
|
types: [python]
|
||||||
|
pass_filenames: false
|
||||||
|
- repo: https://github.com/pre-commit/mirrors-prettier
|
||||||
|
rev: v2.3.2
|
||||||
|
hooks:
|
||||||
|
- id: prettier
|
||||||
9
Dockerfile
Normal file
9
Dockerfile
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
FROM rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1
|
||||||
|
|
||||||
|
RUN pip install -U pip && pip install wyoming==1.2.0 openai-whisper==20231117 tokenizers==0.13.*
|
||||||
|
|
||||||
|
COPY src /src
|
||||||
|
|
||||||
|
WORKDIR /src
|
||||||
|
|
||||||
|
ENTRYPOINT ["/src/run.sh"]
|
||||||
29
README.md
Normal file
29
README.md
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
# rocm-wyoming-whisper
|
||||||
|
|
||||||
|
A docker image and a few lines of python to use OpenAI whisper with Rhasspy and/or Home Assistant on AMD GPUs with ROCm.
|
||||||
|
|
||||||
|
## Run with docker-compose
|
||||||
|
|
||||||
|
```shell
|
||||||
|
docker-compose up -d
|
||||||
|
```
|
||||||
|
|
||||||
|
## Run with Docker
|
||||||
|
|
||||||
|
Build docker image:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
docker build -t wyoming-whisper .
|
||||||
|
```
|
||||||
|
|
||||||
|
Run docker image:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
docker run --entrypoint '' -v $(pwd)/data:/data -v $(pwd)/src:/src -it --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --group-add video -p 10300:10300 wyoming-whisper bash
|
||||||
|
```
|
||||||
|
|
||||||
|
Run script:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
python -m wyoming_whisper --download-dir /data --model medium --debug
|
||||||
|
```
|
||||||
22
docker-compose.yml
Normal file
22
docker-compose.yml
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
services:
|
||||||
|
whisper:
|
||||||
|
container_name: whisper
|
||||||
|
restart: unless-stopped
|
||||||
|
build:
|
||||||
|
context: .
|
||||||
|
group_add:
|
||||||
|
- video
|
||||||
|
volumes:
|
||||||
|
- ./data:/data
|
||||||
|
command:
|
||||||
|
- --download-dir
|
||||||
|
- /data
|
||||||
|
- --language
|
||||||
|
- en
|
||||||
|
- --device
|
||||||
|
- cuda
|
||||||
|
ports:
|
||||||
|
- 10300:10300
|
||||||
|
devices:
|
||||||
|
- "/dev/dri"
|
||||||
|
- "/dev/kfd"
|
||||||
1093
poetry.lock
generated
Normal file
1093
poetry.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
18
pyproject.toml
Normal file
18
pyproject.toml
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
[tool.poetry]
|
||||||
|
name = "wyoming-whisper"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = ""
|
||||||
|
authors = ["Simon Guigui <simon.guigui@parkoview.com>"]
|
||||||
|
readme = "README.md"
|
||||||
|
|
||||||
|
[tool.poetry.dependencies]
|
||||||
|
python = "^3.10"
|
||||||
|
openai-whisper = "20231117"
|
||||||
|
wyoming = "^1.2.0"
|
||||||
|
tokenizers = "0.13.*"
|
||||||
|
black = "^23.11.0"
|
||||||
|
ruff = "^0.1.6"
|
||||||
|
|
||||||
|
[build-system]
|
||||||
|
requires = ["poetry-core"]
|
||||||
|
build-backend = "poetry.core.masonry.api"
|
||||||
3
src/run.sh
Executable file
3
src/run.sh
Executable file
@@ -0,0 +1,3 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
|
exec python3 -m wyoming_whisper --download-dir /data "$@"
|
||||||
1
src/wyoming_whisper/__init__.py
Normal file
1
src/wyoming_whisper/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Wyoming server for faster-whisper."""
|
||||||
103
src/wyoming_whisper/__main__.py
Executable file
103
src/wyoming_whisper/__main__.py
Executable file
@@ -0,0 +1,103 @@
|
|||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
from wyoming.info import AsrModel, AsrProgram, Attribution, Info
|
||||||
|
from wyoming.server import AsyncServer
|
||||||
|
|
||||||
|
from .const import WHISPER_LANGUAGES
|
||||||
|
import whisper
|
||||||
|
from .handler import FasterWhisperEventHandler
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def main() -> None:
|
||||||
|
"""Main entry point."""
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"--model",
|
||||||
|
default="medium",
|
||||||
|
help="Name of whisper model to use (default medium)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--uri", help="unix:// or tcp://", default="tcp://0.0.0.0:10300"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--download-dir",
|
||||||
|
required=True,
|
||||||
|
help="Directory to download models into",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--device",
|
||||||
|
default="cuda",
|
||||||
|
help="Device to use for inference (default: cuda)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--language",
|
||||||
|
default="en",
|
||||||
|
help="Default language to set for transcription (default: en)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--beam-size",
|
||||||
|
type=int,
|
||||||
|
default=5,
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument("--debug", action="store_true", help="Log DEBUG messages")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.DEBUG if args.debug else logging.INFO)
|
||||||
|
_LOGGER.debug(args)
|
||||||
|
|
||||||
|
_LOGGER.debug("Loading %s", args.model)
|
||||||
|
whisper_model = whisper.load_model(
|
||||||
|
args.model, device=args.device, download_root=args.download_dir
|
||||||
|
)
|
||||||
|
|
||||||
|
wyoming_info = Info(
|
||||||
|
asr=[
|
||||||
|
AsrProgram(
|
||||||
|
name="whisper",
|
||||||
|
description="Whisper",
|
||||||
|
attribution=Attribution(
|
||||||
|
name="OpenAI",
|
||||||
|
url="https://github.com/openai/whisper",
|
||||||
|
),
|
||||||
|
installed=True,
|
||||||
|
models=[
|
||||||
|
AsrModel(
|
||||||
|
name=args.model,
|
||||||
|
description=args.model,
|
||||||
|
attribution=Attribution(
|
||||||
|
name="whisper",
|
||||||
|
url="https://github.com/openai/whisper",
|
||||||
|
),
|
||||||
|
installed=True,
|
||||||
|
languages=WHISPER_LANGUAGES,
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
server = AsyncServer.from_uri(args.uri)
|
||||||
|
_LOGGER.info("Ready")
|
||||||
|
model_lock = asyncio.Lock()
|
||||||
|
await server.run(
|
||||||
|
partial(
|
||||||
|
FasterWhisperEventHandler,
|
||||||
|
wyoming_info,
|
||||||
|
args,
|
||||||
|
whisper_model,
|
||||||
|
model_lock,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
try:
|
||||||
|
asyncio.run(main())
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
pass
|
||||||
101
src/wyoming_whisper/const.py
Normal file
101
src/wyoming_whisper/const.py
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
WHISPER_LANGUAGES = [
|
||||||
|
"af",
|
||||||
|
"am",
|
||||||
|
"ar",
|
||||||
|
"as",
|
||||||
|
"az",
|
||||||
|
"ba",
|
||||||
|
"be",
|
||||||
|
"bg",
|
||||||
|
"bn",
|
||||||
|
"bo",
|
||||||
|
"br",
|
||||||
|
"bs",
|
||||||
|
"ca",
|
||||||
|
"cs",
|
||||||
|
"cy",
|
||||||
|
"da",
|
||||||
|
"de",
|
||||||
|
"el",
|
||||||
|
"en",
|
||||||
|
"es",
|
||||||
|
"et",
|
||||||
|
"eu",
|
||||||
|
"fa",
|
||||||
|
"fi",
|
||||||
|
"fo",
|
||||||
|
"fr",
|
||||||
|
"gl",
|
||||||
|
"gu",
|
||||||
|
"ha",
|
||||||
|
"haw",
|
||||||
|
"he",
|
||||||
|
"hi",
|
||||||
|
"hr",
|
||||||
|
"ht",
|
||||||
|
"hu",
|
||||||
|
"hy",
|
||||||
|
"id",
|
||||||
|
"is",
|
||||||
|
"it",
|
||||||
|
"ja",
|
||||||
|
"jw",
|
||||||
|
"ka",
|
||||||
|
"kk",
|
||||||
|
"km",
|
||||||
|
"kn",
|
||||||
|
"ko",
|
||||||
|
"la",
|
||||||
|
"lb",
|
||||||
|
"ln",
|
||||||
|
"lo",
|
||||||
|
"lt",
|
||||||
|
"lv",
|
||||||
|
"mg",
|
||||||
|
"mi",
|
||||||
|
"mk",
|
||||||
|
"ml",
|
||||||
|
"mn",
|
||||||
|
"mr",
|
||||||
|
"ms",
|
||||||
|
"mt",
|
||||||
|
"my",
|
||||||
|
"ne",
|
||||||
|
"nl",
|
||||||
|
"nn",
|
||||||
|
"no",
|
||||||
|
"oc",
|
||||||
|
"pa",
|
||||||
|
"pl",
|
||||||
|
"ps",
|
||||||
|
"pt",
|
||||||
|
"ro",
|
||||||
|
"ru",
|
||||||
|
"sa",
|
||||||
|
"sd",
|
||||||
|
"si",
|
||||||
|
"sk",
|
||||||
|
"sl",
|
||||||
|
"sn",
|
||||||
|
"so",
|
||||||
|
"sq",
|
||||||
|
"sr",
|
||||||
|
"su",
|
||||||
|
"sv",
|
||||||
|
"sw",
|
||||||
|
"ta",
|
||||||
|
"te",
|
||||||
|
"tg",
|
||||||
|
"th",
|
||||||
|
"tk",
|
||||||
|
"tl",
|
||||||
|
"tr",
|
||||||
|
"tt",
|
||||||
|
"uk",
|
||||||
|
"ur",
|
||||||
|
"uz",
|
||||||
|
"vi",
|
||||||
|
"yi",
|
||||||
|
"yo",
|
||||||
|
"zh",
|
||||||
|
]
|
||||||
92
src/wyoming_whisper/handler.py
Normal file
92
src/wyoming_whisper/handler.py
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
"""Event handler for clients of the server."""
|
||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from wyoming.asr import Transcribe, Transcript
|
||||||
|
from wyoming.audio import AudioChunk, AudioChunkConverter, AudioStop
|
||||||
|
from wyoming.event import Event
|
||||||
|
from wyoming.info import Describe, Info
|
||||||
|
from wyoming.server import AsyncEventHandler
|
||||||
|
|
||||||
|
from whisper import Whisper
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class FasterWhisperEventHandler(AsyncEventHandler):
|
||||||
|
"""Event handler for clients."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
wyoming_info: Info,
|
||||||
|
cli_args: argparse.Namespace,
|
||||||
|
model: Whisper,
|
||||||
|
model_lock: asyncio.Lock,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
self.cli_args = cli_args
|
||||||
|
self.wyoming_info_event = wyoming_info.event()
|
||||||
|
self.model = model
|
||||||
|
self.model_lock = model_lock
|
||||||
|
self.audio = bytes()
|
||||||
|
self.audio_converter = AudioChunkConverter(
|
||||||
|
rate=16000,
|
||||||
|
width=2,
|
||||||
|
channels=1,
|
||||||
|
)
|
||||||
|
self._language = self.cli_args.language
|
||||||
|
|
||||||
|
async def handle_event(self, event: Event) -> bool:
|
||||||
|
if Describe.is_type(event.type):
|
||||||
|
await self.write_event(self.wyoming_info_event)
|
||||||
|
_LOGGER.debug("Sent info")
|
||||||
|
return True
|
||||||
|
|
||||||
|
if Transcribe.is_type(event.type):
|
||||||
|
transcribe = Transcribe.from_event(event)
|
||||||
|
if transcribe.language:
|
||||||
|
self._language = transcribe.language
|
||||||
|
_LOGGER.debug("Language set to %s", transcribe.language)
|
||||||
|
return True
|
||||||
|
|
||||||
|
if AudioChunk.is_type(event.type):
|
||||||
|
if not self.audio:
|
||||||
|
_LOGGER.debug("Receiving audio")
|
||||||
|
|
||||||
|
chunk = AudioChunk.from_event(event)
|
||||||
|
chunk = self.audio_converter.convert(chunk)
|
||||||
|
self.audio += chunk.audio
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
if AudioStop.is_type(event.type):
|
||||||
|
_LOGGER.debug("Audio stopped")
|
||||||
|
audio = (
|
||||||
|
np.frombuffer(self.audio, dtype=np.int16).astype(np.float32) / 32678.0
|
||||||
|
)
|
||||||
|
async with self.model_lock:
|
||||||
|
response = self.model.transcribe(
|
||||||
|
audio,
|
||||||
|
beam_size=self.cli_args.beam_size,
|
||||||
|
language=self._language,
|
||||||
|
)
|
||||||
|
segments = response["segments"]
|
||||||
|
|
||||||
|
text = " ".join(segment["text"] for segment in segments)
|
||||||
|
_LOGGER.info(text)
|
||||||
|
|
||||||
|
await self.write_event(Transcript(text=text).event())
|
||||||
|
_LOGGER.debug("Completed request")
|
||||||
|
|
||||||
|
# Reset
|
||||||
|
self.audio = bytes()
|
||||||
|
self._language = self.cli_args.language
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
Reference in New Issue
Block a user