Enable MPS GPU Acceleration on macOS (#164)
* Enable MPS GPU Acceleration on macOS * Fix
This commit is contained in:
@@ -95,6 +95,14 @@ To install espeak-ng on Windows:
|
|||||||
|
|
||||||
For advanced configuration and usage on Windows, see the [official espeak-ng Windows guide](https://github.com/espeak-ng/espeak-ng/blob/master/docs/guide.md)
|
For advanced configuration and usage on Windows, see the [official espeak-ng Windows guide](https://github.com/espeak-ng/espeak-ng/blob/master/docs/guide.md)
|
||||||
|
|
||||||
|
### MacOS Apple Silicon GPU Acceleration
|
||||||
|
|
||||||
|
On Mac M1/M2/M3/M4 (Apple Silicon) devices, set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to enable GPU acceleration via the MPS backend, with automatic CPU fallback for operations that MPS does not yet support.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
PYTORCH_ENABLE_MPS_FALLBACK=1 python run-your-kokoro-script.py
|
||||||
|
```
|
||||||
|
|
||||||
### Conda Environment
|
### Conda Environment
|
||||||
Use the following conda `environment.yml` if you're facing any dependency issues.
|
Use the following conda `environment.yml` if you're facing any dependency issues.
|
||||||
```yaml
|
```yaml
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from misaki import en, espeak
|
|||||||
from typing import Callable, Generator, List, Optional, Tuple, Union
|
from typing import Callable, Generator, List, Optional, Tuple, Union
|
||||||
import re
|
import re
|
||||||
import torch
|
import torch
|
||||||
|
import os
|
||||||
|
|
||||||
ALIASES = {
|
ALIASES = {
|
||||||
'en-us': 'a',
|
'en-us': 'a',
|
||||||
@@ -93,8 +94,17 @@ class KPipeline:
|
|||||||
elif model:
|
elif model:
|
||||||
if device == 'cuda' and not torch.cuda.is_available():
|
if device == 'cuda' and not torch.cuda.is_available():
|
||||||
raise RuntimeError("CUDA requested but not available")
|
raise RuntimeError("CUDA requested but not available")
|
||||||
|
if device == 'mps' and not torch.backends.mps.is_available():
|
||||||
|
raise RuntimeError("MPS requested but not available")
|
||||||
|
if device == 'mps' and os.environ.get('PYTORCH_ENABLE_MPS_FALLBACK') != '1':
|
||||||
|
raise RuntimeError("MPS requested but fallback not enabled")
|
||||||
if device is None:
|
if device is None:
|
||||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
if torch.cuda.is_available():
|
||||||
|
device = 'cuda'
|
||||||
|
elif os.environ.get('PYTORCH_ENABLE_MPS_FALLBACK') == '1' and torch.backends.mps.is_available():
|
||||||
|
device = 'mps'
|
||||||
|
else:
|
||||||
|
device = 'cpu'
|
||||||
try:
|
try:
|
||||||
self.model = KModel(repo_id=repo_id).to(device).eval()
|
self.model = KModel(repo_id=repo_id).to(device).eval()
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
|
|||||||
Reference in New Issue
Block a user