fix(face_swapper): replacing with new code

pull/738/head
GhoulBoii 2024-10-24 00:55:18 +05:30
parent d616c513c9
commit f2894470a7
No known key found for this signature in database
GPG Key ID: F443BDA92506073F
1 changed files with 93 additions and 45 deletions

View File

@ -1,8 +1,8 @@
from typing import Any, List, Tuple from typing import Any, List, Tuple, Optional
import cv2 import cv2
import numpy as np
import insightface import insightface
import threading import threading
import numpy as np
import modules.globals import modules.globals
import modules.processors.frame.core import modules.processors.frame.core
@ -20,11 +20,7 @@ from modules.cluster_analysis import find_closest_centroid
FACE_SWAPPER = None FACE_SWAPPER = None
THREAD_LOCK = threading.Lock() THREAD_LOCK = threading.Lock()
NAME = "DLC.FACE-SWAPPER" NAME = "DLC.FACE-SWAPPER"
BLUR_AMOUNT = 12
# Add mouth landmarks indices for masking
MOUTH_LANDMARKS = list(
range(46, 68)
) # Common indices for mouth landmarks in facial detection
def pre_check() -> bool: def pre_check() -> bool:
@ -67,47 +63,105 @@ def get_face_swapper() -> Any:
return FACE_SWAPPER return FACE_SWAPPER
def create_mouth_mask(face: Face, frame_shape: Tuple[int, int]) -> np.ndarray: def create_face_mask(face: Face, frame: Frame) -> np.ndarray:
"""Create a mask for the mouth region""" """Create a binary mask for the face region."""
mask = np.zeros(frame_shape[:2], dtype=np.uint8) mask = np.zeros(frame.shape[:2], dtype=np.uint8)
landmarks = face.landmark_2d_106
# Get mouth landmarks from the face if landmarks is not None:
landmarks = face.kps hull = cv2.convexHull(landmarks.astype(np.int32))
mouth_points = landmarks[MOUTH_LANDMARKS].astype(np.int32) cv2.fillConvexPoly(mask, hull, 255)
# Create a polygon around the mouth region
cv2.fillPoly(mask, [mouth_points], 255)
# Dilate the mask slightly to ensure smooth blending
kernel = np.ones((5, 5), np.uint8)
mask = cv2.dilate(mask, kernel, iterations=2)
# Blur the mask edges
mask = cv2.GaussianBlur(mask, (15, 15), 10)
return mask return mask
def blend_with_mask( def create_lower_mouth_mask(
swapped_frame: Frame, original_frame: Frame, mask: np.ndarray face: Face, frame: Frame
) -> Frame: ) -> Tuple[np.ndarray, np.ndarray, Tuple[int, int, int, int], np.ndarray]:
"""Blend the swapped face with the original frame using the mouth mask""" """Create a mask for the lower mouth region."""
mask_3channel = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR) / 255.0 mask = np.zeros(frame.shape[:2], dtype=np.uint8)
landmarks = face.landmark_2d_106
# Blend the images based on the mask if landmarks is not None:
blended = swapped_frame * (1 - mask_3channel) + original_frame * mask_3channel # Extract mouth landmarks
return blended.astype(np.uint8) mouth_points = landmarks[84:96] # Adjust indices based on your landmark format
lower_lip = mouth_points[6:12] # Lower lip points
# Create polygon for lower mouth area
lower_lip_polygon = cv2.convexHull(lower_lip.astype(np.int32))
cv2.fillConvexPoly(mask, lower_lip_polygon, 255)
# Get bounding box
x, y, w, h = cv2.boundingRect(lower_lip_polygon)
mouth_box = (x, y, w, h)
# Extract the mouth region
mouth_cutout = frame[y : y + h, x : x + w].copy()
return mask, mouth_cutout, mouth_box, lower_lip_polygon
return None, None, None, None
def apply_mouth_area(
frame: Frame,
mouth_cutout: np.ndarray,
mouth_box: Tuple[int, int, int, int],
face_mask: np.ndarray,
lower_lip_polygon: Optional[np.ndarray],
) -> Frame:
"""Apply the original mouth area back to the face-swapped frame."""
if mouth_cutout is None or mouth_box is None:
return frame
x, y, w, h = mouth_box
# Create a blurred version of the mask
mask = np.zeros(frame.shape[:2], dtype=np.uint8)
if lower_lip_polygon is not None:
cv2.fillConvexPoly(mask, lower_lip_polygon, 255)
else:
mask[y : y + h, x : x + w] = 255
# Blur the mask
blurred_mask = cv2.GaussianBlur(mask, (BLUR_AMOUNT * 2 + 1, BLUR_AMOUNT * 2 + 1), 0)
blurred_mask = blurred_mask / 255.0
# Create 3-channel mask
blurred_mask_3channel = np.repeat(blurred_mask[:, :, np.newaxis], 3, axis=2)
# Blend the original mouth area with the swapped face
frame_copy = frame.copy()
frame_copy[y : y + h, x : x + w] = mouth_cutout
# Combine using the blurred mask
result = (
frame_copy * blurred_mask_3channel + frame * (1 - blurred_mask_3channel)
).astype(np.uint8)
return result
def swap_face(source_face: Face, target_face: Face, temp_frame: Frame) -> Frame: def swap_face(source_face: Face, target_face: Face, temp_frame: Frame) -> Frame:
# Store the original frame for mouth preservation face_swapper = get_face_swapper()
original_frame = temp_frame.copy() # Apply the face swap
swapped_frame = face_swapper.get(
# Perform the face swap
swapped_frame = get_face_swapper().get(
temp_frame, target_face, source_face, paste_back=True temp_frame, target_face, source_face, paste_back=True
) )
if modules.globals.mouth_mask:
# Create masks
face_mask = create_face_mask(target_face, temp_frame)
mouth_mask, mouth_cutout, mouth_box, lower_lip_polygon = (
create_lower_mouth_mask(target_face, temp_frame)
)
if mouth_mask is not None:
# Apply the mouth area preservation
swapped_frame = apply_mouth_area(
swapped_frame, mouth_cutout, mouth_box, face_mask, lower_lip_polygon
)
return swapped_frame
def process_frame(source_face: Face, temp_frame: Frame) -> Frame: def process_frame(source_face: Face, temp_frame: Frame) -> Frame:
# Ensure the frame is in RGB format if color correction is enabled # Ensure the frame is in RGB format if color correction is enabled
@ -123,6 +177,7 @@ def process_frame(source_face: Face, temp_frame: Frame) -> Frame:
target_face = get_one_face(temp_frame) target_face = get_one_face(temp_frame)
if target_face: if target_face:
temp_frame = swap_face(source_face, target_face, temp_frame) temp_frame = swap_face(source_face, target_face, temp_frame)
return temp_frame return temp_frame
@ -133,7 +188,6 @@ def process_frame_v2(temp_frame: Frame, temp_frame_path: str = "") -> Frame:
for map in modules.globals.souce_target_map: for map in modules.globals.souce_target_map:
target_face = map["target"]["face"] target_face = map["target"]["face"]
temp_frame = swap_face(source_face, target_face, temp_frame) temp_frame = swap_face(source_face, target_face, temp_frame)
elif not modules.globals.many_faces: elif not modules.globals.many_faces:
for map in modules.globals.souce_target_map: for map in modules.globals.souce_target_map:
if "source" in map: if "source" in map:
@ -150,11 +204,9 @@ def process_frame_v2(temp_frame: Frame, temp_frame_path: str = "") -> Frame:
for f in map["target_faces_in_frame"] for f in map["target_faces_in_frame"]
if f["location"] == temp_frame_path if f["location"] == temp_frame_path
] ]
for frame in target_frame: for frame in target_frame:
for target_face in frame["faces"]: for target_face in frame["faces"]:
temp_frame = swap_face(source_face, target_face, temp_frame) temp_frame = swap_face(source_face, target_face, temp_frame)
elif not modules.globals.many_faces: elif not modules.globals.many_faces:
for map in modules.globals.souce_target_map: for map in modules.globals.souce_target_map:
if "source" in map: if "source" in map:
@ -164,7 +216,6 @@ def process_frame_v2(temp_frame: Frame, temp_frame_path: str = "") -> Frame:
if f["location"] == temp_frame_path if f["location"] == temp_frame_path
] ]
source_face = map["source"]["face"] source_face = map["source"]["face"]
for frame in target_frame: for frame in target_frame:
for target_face in frame["faces"]: for target_face in frame["faces"]:
temp_frame = swap_face(source_face, target_face, temp_frame) temp_frame = swap_face(source_face, target_face, temp_frame)
@ -175,7 +226,6 @@ def process_frame_v2(temp_frame: Frame, temp_frame_path: str = "") -> Frame:
source_face = default_source_face() source_face = default_source_face()
for target_face in detected_faces: for target_face in detected_faces:
temp_frame = swap_face(source_face, target_face, temp_frame) temp_frame = swap_face(source_face, target_face, temp_frame)
elif not modules.globals.many_faces: elif not modules.globals.many_faces:
if detected_faces: if detected_faces:
if len(detected_faces) <= len( if len(detected_faces) <= len(
@ -186,7 +236,6 @@ def process_frame_v2(temp_frame: Frame, temp_frame_path: str = "") -> Frame:
modules.globals.simple_map["target_embeddings"], modules.globals.simple_map["target_embeddings"],
detected_face.normed_embedding, detected_face.normed_embedding,
) )
temp_frame = swap_face( temp_frame = swap_face(
modules.globals.simple_map["source_faces"][ modules.globals.simple_map["source_faces"][
closest_centroid_index closest_centroid_index
@ -205,7 +254,6 @@ def process_frame_v2(temp_frame: Frame, temp_frame_path: str = "") -> Frame:
closest_centroid_index, _ = find_closest_centroid( closest_centroid_index, _ = find_closest_centroid(
detected_faces_centroids, target_embedding detected_faces_centroids, target_embedding
) )
temp_frame = swap_face( temp_frame = swap_face(
modules.globals.simple_map["source_faces"][i], modules.globals.simple_map["source_faces"][i],
detected_faces[closest_centroid_index], detected_faces[closest_centroid_index],