Deep-Live-Cam/modules/predicter.py

131 lines
3.9 KiB
Python

import numpy as np
import opennsfw2
from PIL import Image
import cv2
import modules.globals
import logging
from functools import lru_cache
from typing import Union, Any
from modules.typing import Frame
logger: logging.Logger = logging.getLogger(__name__)

# Shared NSFW classifier handle. It starts empty and is populated by
# get_nsfw_model() on first use so that later predictions reuse one instance.
_model: Any = None
@lru_cache(maxsize=1)
def load_nsfw_model():
    """
    Build the OpenNSFW2 classifier, memoized so it is constructed once.

    Returns:
        The model returned by ``opennsfw2.make_open_nsfw_model()``, or None
        if constructing it raised.

    Note:
        ``lru_cache`` memoizes the None returned on failure as well, so after
        a failed load every later call returns None without retrying until
        ``load_nsfw_model.cache_clear()`` is called.
    """
    try:
        logger.info("Loading NSFW detection model")
        return opennsfw2.make_open_nsfw_model()
    except Exception as e:
        # logger.exception keeps the traceback that str(e) alone throws away.
        # This path only runs on a cache miss, so the extra output is cheap.
        logger.exception("Failed to load NSFW model: %s", e)
        return None
def get_nsfw_model():
    """
    Return the shared NSFW model, loading it on first request.

    The result is stored in the module-level ``_model``. While that is still
    None (nothing loaded yet, or the previous load returned None),
    ``load_nsfw_model()`` is consulted again.

    Returns:
        The model instance, or None if loading failed.
    """
    global _model
    if _model is not None:
        return _model
    _model = load_nsfw_model()
    return _model
def predict_frame(target_frame: Frame, threshold=None) -> bool:
    """
    Predict whether a single frame contains NSFW content.

    Args:
        target_frame: Frame to analyze, as a numpy array (2-D grayscale or
            3-D with a trailing channel axis).
        threshold: NSFW probability above which the frame is flagged.
            Defaults to ``modules.globals.nsfw_threshold``, or 0.85 if that
            attribute is not set.

    Returns:
        True if the predicted NSFW probability exceeds ``threshold``.
        False otherwise, including when the frame is None, the model is
        unavailable, or prediction raises (the check fails open).
    """
    if target_frame is None:
        logger.warning("Cannot predict on None frame")
        return False
    try:
        if threshold is None:
            threshold = getattr(modules.globals, 'nsfw_threshold', 0.85)

        # Fetch the model before any image work so nothing is wasted when it
        # failed to load.
        model = get_nsfw_model()
        if model is None:
            logger.error("NSFW model not available")
            return False

        # Swap BGR -> RGB only for 3-channel frames. The ndim check keeps a
        # 2-D grayscale frame from raising IndexError on shape[2], which used
        # to be swallowed below and silently reported as "not NSFW".
        # NOTE(review): the swap is tied to color_correction; confirm which
        # channel order upstream frames actually have in each mode.
        processed_frame = target_frame
        if (
            modules.globals.color_correction
            and target_frame.ndim == 3
            and target_frame.shape[2] == 3
        ):
            processed_frame = cv2.cvtColor(target_frame, cv2.COLOR_BGR2RGB)

        # Convert to a PIL image and apply the Yahoo OpenNSFW preprocessing.
        image = Image.fromarray(processed_frame)
        image = opennsfw2.preprocess_image(image, opennsfw2.Preprocessing.YAHOO)

        # The model expects a batch dimension; predict on a batch of one.
        views = np.expand_dims(image, axis=0)
        _, probability = model.predict(views)[0]
        # Lazy %-formatting: this runs per frame, and the message is only
        # rendered when DEBUG logging is enabled.
        logger.debug("NSFW probability: %.4f", probability)
        # Comparing a numpy scalar yields np.bool_; coerce to honour "-> bool".
        return bool(probability > threshold)
    except Exception as e:
        # Plain error (no traceback) because this runs per frame and a
        # persistent failure would otherwise flood the log.
        logger.error("Error during NSFW prediction: %s", e)
        return False
def predict_image(target_path: str, threshold=None) -> bool:
    """
    Predict whether an image file contains NSFW content.

    Args:
        target_path: Filesystem path of the image to classify.
        threshold: NSFW probability above which the image is flagged.
            Falls back to ``modules.globals.nsfw_threshold``, or 0.85 if that
            attribute is not set.

    Returns:
        The result of comparing opennsfw2's probability against the
        threshold. False if anything raises (the error is logged).
    """
    try:
        limit = (
            threshold
            if threshold is not None
            else getattr(modules.globals, 'nsfw_threshold', 0.85)
        )
        # NOTE(review): opennsfw2.predict_image appears to build its own model
        # rather than reusing get_nsfw_model(). Confirm, and consider routing
        # through the shared model.
        probability = opennsfw2.predict_image(target_path)
        return probability > limit
    except Exception as error:
        logger.error(f"Error predicting NSFW for image {target_path}: {str(error)}")
        return False
def predict_video(target_path: str, threshold=None, *, frame_interval: int = 100) -> bool:
    """
    Predict whether a video file contains NSFW content.

    Args:
        target_path: Filesystem path of the video to classify.
        threshold: NSFW probability above which any sampled frame flags the
            video. Defaults to ``modules.globals.nsfw_threshold``, or 0.85 if
            that attribute is not set.
        frame_interval: Passed through to
            ``opennsfw2.predict_video_frames`` as its sampling interval.
            Defaults to 100, the value previously hard-coded here.

    Returns:
        True if any probability returned by opennsfw2 exceeds ``threshold``.
        False otherwise, including when no probabilities are returned or
        prediction raises (the error is logged).
    """
    try:
        if threshold is None:
            threshold = getattr(modules.globals, 'nsfw_threshold', 0.85)
        # predict_video_frames returns a pair; only the probabilities are used.
        _, probabilities = opennsfw2.predict_video_frames(
            video_path=target_path,
            frame_interval=frame_interval,
        )
        return any(probability > threshold for probability in probabilities)
    except Exception as e:
        logger.error("Error predicting NSFW for video %s: %s", target_path, e)
        return False