import logging
from functools import lru_cache
from typing import Any, Union

import cv2
import numpy as np
import opennsfw2
from PIL import Image

import modules.globals
from modules.typing import Frame

logger = logging.getLogger(__name__)

# Global model instance for reuse
_model = None


@lru_cache(maxsize=1)
def load_nsfw_model():
    """
    Load the NSFW prediction model, caching the result.

    Returns:
        Loaded NSFW model, or None if loading failed
    """
    try:
        logger.info("Loading NSFW detection model")
        return opennsfw2.make_open_nsfw_model()
    except Exception as e:
        logger.error(f"Failed to load NSFW model: {str(e)}")
        return None


def get_nsfw_model():
    """
    Get or initialize the NSFW model.

    Returns:
        NSFW model instance, or None if the model could not be loaded
    """
    global _model
    if _model is None:
        _model = load_nsfw_model()
    return _model
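

# Example (illustrative): because load_nsfw_model() is wrapped in
# lru_cache(maxsize=1), repeated calls to get_nsfw_model() reuse the same
# model object instead of rebuilding the opennsfw2 network each time:
#
#     model_a = get_nsfw_model()
#     model_b = get_nsfw_model()
#     assert model_a is model_b  # same cached instance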


def predict_frame(target_frame: Frame, threshold: Union[float, None] = None) -> bool:
    """
    Predict whether a frame contains NSFW content.

    Args:
        target_frame: Frame to analyze as a numpy array
        threshold: NSFW probability threshold (default: from modules.globals)

    Returns:
        True if NSFW content is detected, False otherwise
    """
    try:
        if target_frame is None:
            logger.warning("Cannot predict on None frame")
            return False

        # Get threshold from globals if not explicitly provided
        if threshold is None:
            threshold = getattr(modules.globals, 'nsfw_threshold', 0.85)

        # Convert the BGR frame (OpenCV default) to RGB when color correction is enabled
        expected_format = 'RGB' if modules.globals.color_correction else 'BGR'
        if expected_format == 'RGB' and target_frame.shape[2] == 3:
            processed_frame = cv2.cvtColor(target_frame, cv2.COLOR_BGR2RGB)
        else:
            processed_frame = target_frame

        # Convert to PIL image and preprocess for the Yahoo open_nsfw model
        image = Image.fromarray(processed_frame)
        image = opennsfw2.preprocess_image(image, opennsfw2.Preprocessing.YAHOO)

        # Get model and predict
        model = get_nsfw_model()
        if model is None:
            logger.error("NSFW model not available")
            return False

        # The model returns [SFW, NSFW] probabilities for each input image
        views = np.expand_dims(image, axis=0)
        _, probability = model.predict(views)[0]

        logger.debug(f"NSFW probability: {probability:.4f}")
        return probability > threshold

    except Exception as e:
        logger.error(f"Error during NSFW prediction: {str(e)}")
        return False
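

# Example usage (illustrative; "frame.jpg" is a placeholder path):
#
#     frame = cv2.imread("frame.jpg")  # OpenCV loads images as BGR numpy arrays
#     if frame is not None and predict_frame(frame, threshold=0.85):
#         print("NSFW content detected")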


def predict_image(target_path: str, threshold: Union[float, None] = None) -> bool:
    """
    Predict whether an image file contains NSFW content.

    Args:
        target_path: Path to the image file
        threshold: NSFW probability threshold (default: from modules.globals)

    Returns:
        True if NSFW content is detected, False otherwise
    """
    try:
        if threshold is None:
            threshold = getattr(modules.globals, 'nsfw_threshold', 0.85)
        return opennsfw2.predict_image(target_path) > threshold
    except Exception as e:
        logger.error(f"Error predicting NSFW for image {target_path}: {str(e)}")
        return False
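

# Example usage (illustrative; "photo.jpg" is a placeholder path):
#
#     if predict_image("photo.jpg", threshold=0.9):
#         print("Image flagged as NSFW")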


def predict_video(target_path: str, threshold: Union[float, None] = None) -> bool:
    """
    Predict whether a video file contains NSFW content.

    Args:
        target_path: Path to the video file
        threshold: NSFW probability threshold (default: from modules.globals)

    Returns:
        True if NSFW content is detected in any sampled frame, False otherwise
    """
    try:
        if threshold is None:
            threshold = getattr(modules.globals, 'nsfw_threshold', 0.85)
        # Sample one frame out of every 100 to keep prediction fast
        _, probabilities = opennsfw2.predict_video_frames(
            video_path=target_path,
            frame_interval=100
        )
        return any(probability > threshold for probability in probabilities)
    except Exception as e:
        logger.error(f"Error predicting NSFW for video {target_path}: {str(e)}")
        return False
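

# Minimal manual smoke test (illustrative sketch; assumes this file is run from
# the project root so the `modules` package is importable, and that the path
# argument points to an existing image or video):
if __name__ == "__main__":
    import sys

    if len(sys.argv) != 2:
        print("Usage: python <this file> <path-to-image-or-video>")
        sys.exit(1)

    path = sys.argv[1]
    if path.lower().endswith((".mp4", ".avi", ".mkv", ".mov")):
        print(f"NSFW video: {predict_video(path)}")
    else:
        print(f"NSFW image: {predict_image(path)}")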