diff --git a/modules/capturer.py b/modules/capturer.py index fd49d46..9a04ca0 100644 --- a/modules/capturer.py +++ b/modules/capturer.py @@ -4,13 +4,22 @@ import cv2 def get_video_frame(video_path: str, frame_number: int = 0) -> Any: capture = cv2.VideoCapture(video_path) + + # Set MJPEG format to ensure correct color space handling + capture.set(cv2.CAP_PROP_FOURCC, cv2.VideoWriter_fourcc(*'MJPG')) + # Force OpenCV to convert to RGB + capture.set(cv2.CAP_PROP_CONVERT_RGB, 1) + frame_total = capture.get(cv2.CAP_PROP_FRAME_COUNT) capture.set(cv2.CAP_PROP_POS_FRAMES, min(frame_total, frame_number - 1)) has_frame, frame = capture.read() - capture.release() + if has_frame: - return frame - return None + # Convert the frame color if necessary + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + + capture.release() + return frame if has_frame else None def get_video_frame_total(video_path: str) -> int: diff --git a/modules/predicter.py b/modules/predicter.py index dbb680e..4931076 100644 --- a/modules/predicter.py +++ b/modules/predicter.py @@ -1,6 +1,7 @@ import numpy import opennsfw2 from PIL import Image +import cv2 # Add OpenCV import from modules.typing import Frame @@ -10,6 +11,8 @@ MAX_PROBABILITY = 0.85 model = None def predict_frame(target_frame: Frame) -> bool: + # Convert the frame to RGB before processing + target_frame = cv2.cvtColor(target_frame, cv2.COLOR_BGR2RGB) image = Image.fromarray(target_frame) image = opennsfw2.preprocess_image(image, opennsfw2.Preprocessing.YAHOO) global model diff --git a/modules/processors/frame/face_swapper.py b/modules/processors/frame/face_swapper.py index 4b4a222..cde43f0 100644 --- a/modules/processors/frame/face_swapper.py +++ b/modules/processors/frame/face_swapper.py @@ -49,6 +49,8 @@ def swap_face(source_face: Face, target_face: Face, temp_frame: Frame) -> Frame: def process_frame(source_face: Face, temp_frame: Frame) -> Frame: + # Ensure the frame is in RGB format + temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_BGR2RGB) if modules.globals.many_faces: many_faces = get_many_faces(temp_frame) if many_faces: