Update face_enhancer.py for apple silicon mps

pull/829/head
Rishon 2024-12-14 16:47:07 +08:00 committed by GitHub
parent c72582506d
commit de4f765878
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 4 additions and 1 deletions

View File

@ -49,7 +49,10 @@ def get_face_enhancer() -> Any:
with THREAD_LOCK:
if FACE_ENHANCER is None:
model_path = os.path.join(models_dir, 'GFPGANv1.4.pth')
FACE_ENHANCER = gfpgan.GFPGANer(model_path=model_path, upscale=1) # type: ignore[attr-defined]
mps_device = None
if torch.backends.mps.is_available():
mps_device = torch.device("mps")
FACE_ENHANCER = gfpgan.GFPGANer(model_path=model_path, upscale=1, device=mps_device) # type: ignore[attr-defined]
return FACE_ENHANCER