Enable MPS GPU Accerlation on MacOS (#164)

* Enable MPS GPU Accerlation on MacOS

* Fix
This commit is contained in:
fondoger
2025-04-11 02:40:25 +08:00
committed by GitHub
parent 1c7bdd971d
commit 6d87f4ae7a
2 changed files with 19 additions and 1 deletions

View File

@@ -95,6 +95,14 @@ To install espeak-ng on Windows:
For advanced configuration and usage on Windows, see the [official espeak-ng Windows guide](https://github.com/espeak-ng/espeak-ng/blob/master/docs/guide.md) For advanced configuration and usage on Windows, see the [official espeak-ng Windows guide](https://github.com/espeak-ng/espeak-ng/blob/master/docs/guide.md)
### MacOS Apple Silicon GPU Acceleration
On Mac M1/M2/M3/M4 devices, you can explicitly specify the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to enable GPU acceleration.
```bash
PYTORCH_ENABLE_MPS_FALLBACK=1 python run-your-kokoro-script.py
```
### Conda Environment ### Conda Environment
Use the following conda `environment.yml` if you're facing any dependency issues. Use the following conda `environment.yml` if you're facing any dependency issues.
```yaml ```yaml

View File

@@ -6,6 +6,7 @@ from misaki import en, espeak
from typing import Callable, Generator, List, Optional, Tuple, Union from typing import Callable, Generator, List, Optional, Tuple, Union
import re import re
import torch import torch
import os
ALIASES = { ALIASES = {
'en-us': 'a', 'en-us': 'a',
@@ -93,8 +94,17 @@ class KPipeline:
elif model: elif model:
if device == 'cuda' and not torch.cuda.is_available(): if device == 'cuda' and not torch.cuda.is_available():
raise RuntimeError("CUDA requested but not available") raise RuntimeError("CUDA requested but not available")
if device == 'mps' and not torch.backends.mps.is_available():
raise RuntimeError("MPS requested but not available")
if device == 'mps' and os.environ.get('PYTORCH_ENABLE_MPS_FALLBACK') != '1':
raise RuntimeError("MPS requested but fallback not enabled")
if device is None: if device is None:
device = 'cuda' if torch.cuda.is_available() else 'cpu' if torch.cuda.is_available():
device = 'cuda'
elif os.environ.get('PYTORCH_ENABLE_MPS_FALLBACK') == '1' and torch.backends.mps.is_available():
device = 'mps'
else:
device = 'cpu'
try: try:
self.model = KModel(repo_id=repo_id).to(device).eval() self.model = KModel(repo_id=repo_id).to(device).eval()
except RuntimeError as e: except RuntimeError as e: