diff --git a/modules/processors/frame/face_enhancer.py b/modules/processors/frame/face_enhancer.py index 608071a..d714fdb 100644 --- a/modules/processors/frame/face_enhancer.py +++ b/modules/processors/frame/face_enhancer.py @@ -9,23 +9,35 @@ import modules.processors.frame.core from modules.core import update_status from modules.face_analyser import get_one_face from modules.typing import Frame, Face -from modules.utilities import conditional_download, resolve_relative_path, is_image, is_video +from modules.utilities import ( + conditional_download, + resolve_relative_path, + is_image, + is_video, +) FACE_ENHANCER = None THREAD_SEMAPHORE = threading.Semaphore() THREAD_LOCK = threading.Lock() -NAME = 'DLC.FACE-ENHANCER' +NAME = "DLC.FACE-ENHANCER" def pre_check() -> bool: - download_directory_path = resolve_relative_path('..\models') - conditional_download(download_directory_path, ['https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/GFPGANv1.4.pth']) + download_directory_path = resolve_relative_path("..\models") + conditional_download( + download_directory_path, + [ + "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/GFPGANv1.4.pth" + ], + ) return True def pre_start() -> bool: - if not is_image(modules.globals.target_path) and not is_video(modules.globals.target_path): - update_status('Select an image or video for target path.', NAME) + if not is_image(modules.globals.target_path) and not is_video( + modules.globals.target_path + ): + update_status("Select an image or video for target path.", NAME) return False return True @@ -35,21 +47,18 @@ def get_face_enhancer() -> Any: with THREAD_LOCK: if FACE_ENHANCER is None: - if os.name == 'nt': - model_path = resolve_relative_path('..\models\GFPGANv1.4.pth') + if os.name == "nt": + model_path = resolve_relative_path("..\models\GFPGANv1.4.pth") # todo: set models path https://github.com/TencentARC/GFPGAN/issues/399 else: - model_path = resolve_relative_path('../models/GFPGANv1.4.pth') - FACE_ENHANCER = gfpgan.GFPGANer(model_path=model_path, upscale=1) # type: ignore[attr-defined] + model_path = resolve_relative_path("../models/GFPGANv1.4.pth") + FACE_ENHANCER = gfpgan.GFPGANer(model_path=model_path, upscale=1) # type: ignore[attr-defined] return FACE_ENHANCER def enhance_face(temp_frame: Frame) -> Frame: with THREAD_SEMAPHORE: - _, _, temp_frame = get_face_enhancer().enhance( - temp_frame, - paste_back=True - ) + _, _, temp_frame = get_face_enhancer().enhance(temp_frame, paste_back=True) return temp_frame @@ -60,7 +69,9 @@ def process_frame(source_face: Face, temp_frame: Frame) -> Frame: return temp_frame -def process_frames(source_path: str, temp_frame_paths: List[str], progress: Any = None) -> None: +def process_frames( + source_path: str, temp_frame_paths: List[str], progress: Any = None +) -> None: for temp_frame_path in temp_frame_paths: temp_frame = cv2.imread(temp_frame_path) result = process_frame(None, temp_frame) @@ -77,3 +88,10 @@ def process_image(source_path: str, target_path: str, output_path: str) -> None: def process_video(source_path: str, temp_frame_paths: List[str]) -> None: modules.processors.frame.core.process_video(None, temp_frame_paths, process_frames) + + +def process_frame_v2(temp_frame: Frame) -> Frame: + target_face = get_one_face(temp_frame) + if target_face: + temp_frame = enhance_face(temp_frame) + return temp_frame