Update face_enhancer.py for apple silicon mps
parent
c72582506d
commit
de4f765878
|
@ -49,7 +49,10 @@ def get_face_enhancer() -> Any:
|
|||
with THREAD_LOCK:
|
||||
if FACE_ENHANCER is None:
|
||||
model_path = os.path.join(models_dir, 'GFPGANv1.4.pth')
|
||||
FACE_ENHANCER = gfpgan.GFPGANer(model_path=model_path, upscale=1) # type: ignore[attr-defined]
|
||||
mps_device = None
|
||||
if torch.backends.mps.is_available():
|
||||
mps_device = torch.device("mps")
|
||||
FACE_ENHANCER = gfpgan.GFPGANer(model_path=model_path, upscale=1, device=mps_device) # type: ignore[attr-defined]
|
||||
return FACE_ENHANCER
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue