fix & add trt support
							parent
							
								
									75b5b096d6
								
							
						
					
					
						commit
						890beb0eae
					
				|  | @ -48,6 +48,17 @@ def pre_start() -> bool: | |||
|     return True | ||||
| 
 | ||||
| 
 | ||||
| TENSORRT_AVAILABLE = False | ||||
| try: | ||||
|     import torch_tensorrt | ||||
|     TENSORRT_AVAILABLE = True | ||||
| except ImportError as im: | ||||
|     print(f"TensorRT is not available: {im}") | ||||
|     pass | ||||
| except Exception as e: | ||||
|     print(f"TensorRT is not available: {e}") | ||||
|     pass | ||||
| 
 | ||||
| def get_face_enhancer() -> Any: | ||||
|     global FACE_ENHANCER | ||||
| 
 | ||||
|  | @ -55,16 +66,26 @@ def get_face_enhancer() -> Any: | |||
|         if FACE_ENHANCER is None: | ||||
|             model_path = os.path.join(models_dir, "GFPGANv1.4.pth") | ||||
|              | ||||
|             match platform.system(): | ||||
|                 case "Darwin":  # Mac OS | ||||
|                     if torch.backends.mps.is_available(): | ||||
|                         mps_device = torch.device("mps") | ||||
|                         FACE_ENHANCER = gfpgan.GFPGANer(model_path=model_path, upscale=1, device=mps_device)  # type: ignore[attr-defined] | ||||
|                     else: | ||||
|                         FACE_ENHANCER = gfpgan.GFPGANer(model_path=model_path, upscale=1)  # type: ignore[attr-defined] | ||||
|                 case _:  # Other OS | ||||
|                     FACE_ENHANCER = gfpgan.GFPGANer(model_path=model_path, upscale=1)  # type: ignore[attr-defined] | ||||
|             selected_device = None | ||||
|             device_priority = [] | ||||
| 
 | ||||
|             if TENSORRT_AVAILABLE and torch.cuda.is_available(): | ||||
|                 selected_device = torch.device("cuda") | ||||
|                 device_priority.append("TensorRT+CUDA") | ||||
|             elif torch.cuda.is_available(): | ||||
|                 selected_device = torch.device("cuda") | ||||
|                 device_priority.append("CUDA") | ||||
|             elif torch.backends.mps.is_available() and platform.system() == "Darwin": | ||||
|                 selected_device = torch.device("mps") | ||||
|                 device_priority.append("MPS") | ||||
|             elif not torch.cuda.is_available(): | ||||
|                 selected_device = torch.device("cpu") | ||||
|                 device_priority.append("CPU") | ||||
|              | ||||
|             FACE_ENHANCER = gfpgan.GFPGANer(model_path=model_path, upscale=1, device=selected_device) | ||||
| 
 | ||||
|             # for debug: | ||||
|             print(f"Selected device: {selected_device} and device priority: {device_priority}") | ||||
|     return FACE_ENHANCER | ||||
| 
 | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue