fix & add trt support

pull/1094/head
NeuroDonu 2025-04-19 16:03:49 +03:00 committed by GitHub
parent 75b5b096d6
commit 890beb0eae
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 30 additions and 9 deletions

View File

@ -48,6 +48,17 @@ def pre_start() -> bool:
return True return True
TENSORRT_AVAILABLE = False
try:
import torch_tensorrt
TENSORRT_AVAILABLE = True
except ImportError as im:
print(f"TensorRT is not available: {im}")
pass
except Exception as e:
print(f"TensorRT is not available: {e}")
pass
def get_face_enhancer() -> Any: def get_face_enhancer() -> Any:
global FACE_ENHANCER global FACE_ENHANCER
@ -55,16 +66,26 @@ def get_face_enhancer() -> Any:
if FACE_ENHANCER is None: if FACE_ENHANCER is None:
model_path = os.path.join(models_dir, "GFPGANv1.4.pth") model_path = os.path.join(models_dir, "GFPGANv1.4.pth")
match platform.system(): selected_device = None
case "Darwin": # Mac OS device_priority = []
if torch.backends.mps.is_available():
mps_device = torch.device("mps")
FACE_ENHANCER = gfpgan.GFPGANer(model_path=model_path, upscale=1, device=mps_device) # type: ignore[attr-defined]
else:
FACE_ENHANCER = gfpgan.GFPGANer(model_path=model_path, upscale=1) # type: ignore[attr-defined]
case _: # Other OS
FACE_ENHANCER = gfpgan.GFPGANer(model_path=model_path, upscale=1) # type: ignore[attr-defined]
if TENSORRT_AVAILABLE and torch.cuda.is_available():
selected_device = torch.device("cuda")
device_priority.append("TensorRT+CUDA")
elif torch.cuda.is_available():
selected_device = torch.device("cuda")
device_priority.append("CUDA")
elif torch.backends.mps.is_available() and platform.system() == "Darwin":
selected_device = torch.device("mps")
device_priority.append("MPS")
elif not torch.cuda.is_available():
selected_device = torch.device("cpu")
device_priority.append("CPU")
FACE_ENHANCER = gfpgan.GFPGANer(model_path=model_path, upscale=1, device=selected_device)
# for debug:
print(f"Selected device: {selected_device} and device priority: {device_priority}")
return FACE_ENHANCER return FACE_ENHANCER