Merge 532c8e57db into e879d2ca64
				
					
				
			
						commit
						d79355d86a
					
				|  | @ -1,32 +1,74 @@ | |||
| from typing import Any | ||||
| from typing import Any, Optional | ||||
| import cv2 | ||||
| import modules.globals  # Import the globals to check the color correction toggle | ||||
| import modules.globals | ||||
| import logging | ||||
| 
 | ||||
| logger = logging.getLogger(__name__) | ||||
| 
 | ||||
| def get_video_frame(video_path: str, frame_number: int = 0) -> Any: | ||||
|     capture = cv2.VideoCapture(video_path) | ||||
| def get_video_frame(video_path: str, frame_number: int = 0) -> Optional[Any]: | ||||
|     """ | ||||
|     Extract a specific frame from a video file with proper color handling. | ||||
|      | ||||
|     # Set MJPEG format to ensure correct color space handling | ||||
|     capture.set(cv2.CAP_PROP_FOURCC, cv2.VideoWriter_fourcc(*'MJPG')) | ||||
|     Args: | ||||
|         video_path: Path to the video file | ||||
|         frame_number: Frame number to extract (defaults to first frame) | ||||
|          | ||||
|     # Only force RGB conversion if color correction is enabled | ||||
|     if modules.globals.color_correction: | ||||
|         capture.set(cv2.CAP_PROP_CONVERT_RGB, 1) | ||||
|     Returns: | ||||
|         Video frame as numpy array or None if frame extraction fails | ||||
|     """ | ||||
|     try: | ||||
|         capture = cv2.VideoCapture(video_path) | ||||
|         if not capture.isOpened(): | ||||
|             logger.error(f"Failed to open video: {video_path}") | ||||
|             return None | ||||
| 
 | ||||
|     frame_total = capture.get(cv2.CAP_PROP_FRAME_COUNT) | ||||
|     capture.set(cv2.CAP_PROP_POS_FRAMES, min(frame_total, frame_number - 1)) | ||||
|     has_frame, frame = capture.read() | ||||
|         # Set MJPEG format to ensure correct color space handling | ||||
|         capture.set(cv2.CAP_PROP_FOURCC, cv2.VideoWriter_fourcc(*'MJPG')) | ||||
|          | ||||
|     if has_frame and modules.globals.color_correction: | ||||
|         # Convert the frame color if necessary | ||||
|         frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) | ||||
|         # Configure color conversion based on setting | ||||
|         if modules.globals.color_correction: | ||||
|             capture.set(cv2.CAP_PROP_CONVERT_RGB, 1) | ||||
|         else: | ||||
|             capture.set(cv2.CAP_PROP_CONVERT_RGB, 0)  # Explicitly disable if not needed | ||||
|          | ||||
|     capture.release() | ||||
|     return frame if has_frame else None | ||||
|         frame_total = capture.get(cv2.CAP_PROP_FRAME_COUNT) | ||||
|         # Ensure frame_number is valid (0-based index) | ||||
|         target_frame = max(0, min(frame_total - 1, frame_number)) | ||||
|         capture.set(cv2.CAP_PROP_POS_FRAMES, target_frame) | ||||
|         has_frame, frame = capture.read() | ||||
| 
 | ||||
|         # Only convert manually if color_correction is enabled but capture didn't handle it | ||||
|         if has_frame and modules.globals.color_correction and frame is not None: | ||||
|             frame_channels = frame.shape[2] if len(frame.shape) == 3 else 1 | ||||
|             if frame_channels == 3:  # Only convert if we have a color image | ||||
|                 frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) | ||||
| 
 | ||||
|         capture.release() | ||||
|         return frame if has_frame else None | ||||
|     except Exception as e: | ||||
|         logger.error(f"Error processing video frame: {str(e)}") | ||||
|         return None | ||||
| 
 | ||||
| 
 | ||||
| def get_video_frame_total(video_path: str) -> int: | ||||
|     capture = cv2.VideoCapture(video_path) | ||||
|     video_frame_total = int(capture.get(cv2.CAP_PROP_FRAME_COUNT)) | ||||
|     capture.release() | ||||
|     return video_frame_total | ||||
|     """ | ||||
|     Get the total number of frames in a video file. | ||||
|      | ||||
|     Args: | ||||
|         video_path: Path to the video file | ||||
|          | ||||
|     Returns: | ||||
|         Total number of frames in the video | ||||
|     """ | ||||
|     try: | ||||
|         capture = cv2.VideoCapture(video_path) | ||||
|         if not capture.isOpened(): | ||||
|             logger.error(f"Failed to open video for frame counting: {video_path}") | ||||
|             return 0 | ||||
|              | ||||
|         video_frame_total = int(capture.get(cv2.CAP_PROP_FRAME_COUNT)) | ||||
|         capture.release() | ||||
|         return video_frame_total | ||||
|     except Exception as e: | ||||
|         logger.error(f"Error counting video frames: {str(e)}") | ||||
|         return 0 | ||||
|  | @ -1,32 +1,111 @@ | |||
| import numpy as np | ||||
| from sklearn.cluster import KMeans | ||||
| from sklearn.metrics import silhouette_score | ||||
| from typing import Any | ||||
| from typing import Any, List, Optional, Tuple | ||||
| import logging | ||||
| import modules.globals | ||||
| 
 | ||||
| logger = logging.getLogger(__name__) | ||||
| 
 | ||||
| def find_cluster_centroids(embeddings, max_k=10) -> Any: | ||||
|     inertia = [] | ||||
|     cluster_centroids = [] | ||||
|     K = range(1, max_k+1) | ||||
| def find_cluster_centroids(embeddings, max_k=None, kmeans_init=None) -> Any: | ||||
|     """ | ||||
|     Identifies optimal face clusters using KMeans and silhouette scoring | ||||
|      | ||||
|     for k in K: | ||||
|         kmeans = KMeans(n_clusters=k, random_state=0) | ||||
|         kmeans.fit(embeddings) | ||||
|         inertia.append(kmeans.inertia_) | ||||
|         cluster_centroids.append({"k": k, "centroids": kmeans.cluster_centers_}) | ||||
|     Args: | ||||
|         embeddings: Face embedding vectors | ||||
|         max_k: Maximum number of clusters to consider (default: from globals) | ||||
|         kmeans_init: KMeans initialization method (default: from globals) | ||||
|          | ||||
|     diffs = [inertia[i] - inertia[i+1] for i in range(len(inertia)-1)] | ||||
|     optimal_centroids = cluster_centroids[diffs.index(max(diffs)) + 1]['centroids'] | ||||
|     Returns: | ||||
|         Array of optimal cluster centroids | ||||
|     """ | ||||
|     try: | ||||
|         if len(embeddings) < 2: | ||||
|             logger.warning("Not enough embeddings for clustering analysis") | ||||
|             return embeddings  # Return the single embedding as its own cluster | ||||
|              | ||||
|     return optimal_centroids | ||||
|         # Use settings from globals if not explicitly provided | ||||
|         if max_k is None: | ||||
|             max_k = getattr(modules.globals, 'max_cluster_k', 10) | ||||
|         if kmeans_init is None: | ||||
|             kmeans_init = getattr(modules.globals, 'kmeans_init', 'k-means++') | ||||
|          | ||||
| def find_closest_centroid(centroids: list, normed_face_embedding) -> list: | ||||
|         # Try silhouette method first | ||||
|         best_k = 2  # Start with minimum viable cluster count | ||||
|         best_score = -1 | ||||
|         best_centroids = None | ||||
|          | ||||
|         # We need at least 3 samples to calculate silhouette score | ||||
|         if len(embeddings) >= 3: | ||||
|             # Find optimal k using silhouette analysis | ||||
|             for k in range(2, min(max_k+1, len(embeddings))): | ||||
|                 try: | ||||
|                     kmeans = KMeans(n_clusters=k, init=kmeans_init, n_init=10, random_state=0) | ||||
|                     labels = kmeans.fit_predict(embeddings) | ||||
|                      | ||||
|                     # Calculate silhouette score | ||||
|                     score = silhouette_score(embeddings, labels) | ||||
|                      | ||||
|                     if score > best_score: | ||||
|                         best_score = score | ||||
|                         best_k = k | ||||
|                         best_centroids = kmeans.cluster_centers_ | ||||
|                 except Exception as e: | ||||
|                     logger.warning(f"Error during silhouette analysis for k={k}: {str(e)}") | ||||
|                     continue | ||||
|          | ||||
|         # Fallback to elbow method if silhouette failed or for small datasets | ||||
|         if best_centroids is None: | ||||
|             inertia = [] | ||||
|             cluster_centroids = [] | ||||
|             K = range(1, min(max_k+1, len(embeddings)+1)) | ||||
| 
 | ||||
|             for k in K: | ||||
|                 kmeans = KMeans(n_clusters=k, init=kmeans_init, random_state=0) | ||||
|                 kmeans.fit(embeddings) | ||||
|                 inertia.append(kmeans.inertia_) | ||||
|                 cluster_centroids.append({"k": k, "centroids": kmeans.cluster_centers_}) | ||||
| 
 | ||||
|             if len(inertia) > 1: | ||||
|                 diffs = [inertia[i] - inertia[i+1] for i in range(len(inertia)-1)] | ||||
|                 best_idx = diffs.index(max(diffs)) | ||||
|                 best_centroids = cluster_centroids[best_idx + 1]['centroids'] | ||||
|             else: | ||||
|                 # Just one cluster | ||||
|                 best_centroids = cluster_centroids[0]['centroids'] | ||||
|          | ||||
|         return best_centroids | ||||
|          | ||||
|     except Exception as e: | ||||
|         logger.error(f"Error in cluster analysis: {str(e)}") | ||||
|         # Return a single centroid (mean of all embeddings) as fallback | ||||
|         return np.mean(embeddings, axis=0, keepdims=True) | ||||
| 
 | ||||
| def find_closest_centroid(centroids: list, normed_face_embedding) -> Optional[Tuple[int, np.ndarray]]: | ||||
|     """ | ||||
|     Find the closest centroid to a face embedding | ||||
|      | ||||
|     Args: | ||||
|         centroids: List of cluster centroids | ||||
|         normed_face_embedding: Normalized face embedding vector | ||||
|          | ||||
|     Returns: | ||||
|         Tuple of (centroid index, centroid vector) or None if matching fails | ||||
|     """ | ||||
|     try: | ||||
|         centroids = np.array(centroids) | ||||
|         normed_face_embedding = np.array(normed_face_embedding) | ||||
|          | ||||
|         # Validate input shapes | ||||
|         if len(centroids.shape) != 2 or len(normed_face_embedding.shape) != 1: | ||||
|             logger.warning(f"Invalid shapes: centroids {centroids.shape}, embedding {normed_face_embedding.shape}") | ||||
|             return None | ||||
|              | ||||
|         # Calculate similarity (dot product) between embedding and each centroid | ||||
|         similarities = np.dot(centroids, normed_face_embedding) | ||||
|         closest_centroid_index = np.argmax(similarities) | ||||
|          | ||||
|         return closest_centroid_index, centroids[closest_centroid_index] | ||||
|     except ValueError: | ||||
|     except Exception as e: | ||||
|         logger.error(f"Error finding closest centroid: {str(e)}") | ||||
|         return None | ||||
|  | @ -1,17 +1,36 @@ | |||
| import os | ||||
| from typing import List, Dict, Any | ||||
| import json | ||||
| import logging | ||||
| from typing import List, Dict, Any, Optional | ||||
| 
 | ||||
| logger = logging.getLogger(__name__) | ||||
| 
 | ||||
| # Core paths | ||||
| ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) | ||||
| WORKFLOW_DIR = os.path.join(ROOT_DIR, "workflow") | ||||
| CONFIG_PATH = os.path.join(ROOT_DIR, "config.json") | ||||
| 
 | ||||
| # Default configuration settings | ||||
| DEFAULT_SETTINGS = { | ||||
|     'max_cluster_k': 10, | ||||
|     'kmeans_init': 'k-means++', | ||||
|     'nsfw_threshold': 0.85, | ||||
|     'mask_feather_ratio': 8, | ||||
|     'mask_down_size': 0.50, | ||||
|     'mask_size': 1 | ||||
| } | ||||
| 
 | ||||
| # File type definitions | ||||
| file_types = [ | ||||
|     ("Image", ("*.png", "*.jpg", "*.jpeg", "*.gif", "*.bmp")), | ||||
|     ("Video", ("*.mp4", "*.mkv")), | ||||
| ] | ||||
| 
 | ||||
| # Runtime variables | ||||
| source_target_map = [] | ||||
| simple_map = {} | ||||
| 
 | ||||
| # Paths and processing options | ||||
| source_path = None | ||||
| target_path = None | ||||
| output_path = None | ||||
|  | @ -21,7 +40,7 @@ keep_audio = True | |||
| keep_frames = False | ||||
| many_faces = False | ||||
| map_faces = False | ||||
| color_correction = False  # New global variable for color correction toggle | ||||
| color_correction = False | ||||
| nsfw_filter = False | ||||
| video_encoder = None | ||||
| video_quality = None | ||||
|  | @ -38,6 +57,70 @@ webcam_preview_running = False | |||
| show_fps = False | ||||
| mouth_mask = False | ||||
| show_mouth_mask_box = False | ||||
| mask_feather_ratio = 8 | ||||
| mask_down_size = 0.50 | ||||
| mask_size = 1 | ||||
| 
 | ||||
| # Masking parameters - moved from hardcoded to configurable | ||||
| mask_feather_ratio = DEFAULT_SETTINGS['mask_feather_ratio'] | ||||
| mask_down_size = DEFAULT_SETTINGS['mask_down_size'] | ||||
| mask_size = DEFAULT_SETTINGS['mask_size'] | ||||
| 
 | ||||
| # Advanced parameters | ||||
| max_cluster_k = DEFAULT_SETTINGS['max_cluster_k'] | ||||
| kmeans_init = DEFAULT_SETTINGS['kmeans_init'] | ||||
| nsfw_threshold = DEFAULT_SETTINGS['nsfw_threshold'] | ||||
| 
 | ||||
| def init() -> None: | ||||
|     """ | ||||
|     Initialize the globals module and load settings | ||||
|     Should be called explicitly by the application during startup | ||||
|     """ | ||||
|     load_settings() | ||||
|     logger.info("Globals module initialized") | ||||
| 
 | ||||
| def load_settings() -> None: | ||||
|     """ | ||||
|     Load user settings from config file | ||||
|     """ | ||||
|     global mask_feather_ratio, mask_down_size, mask_size | ||||
|     global max_cluster_k, kmeans_init, nsfw_threshold | ||||
|      | ||||
|     try: | ||||
|         if os.path.exists(CONFIG_PATH): | ||||
|             with open(CONFIG_PATH, 'r') as f: | ||||
|                 config = json.load(f) | ||||
|                  | ||||
|             # Apply settings from config, falling back to defaults | ||||
|             mask_feather_ratio = config.get('mask_feather_ratio', DEFAULT_SETTINGS['mask_feather_ratio']) | ||||
|             mask_down_size = config.get('mask_down_size', DEFAULT_SETTINGS['mask_down_size']) | ||||
|             mask_size = config.get('mask_size', DEFAULT_SETTINGS['mask_size']) | ||||
|             max_cluster_k = config.get('max_cluster_k', DEFAULT_SETTINGS['max_cluster_k']) | ||||
|             kmeans_init = config.get('kmeans_init', DEFAULT_SETTINGS['kmeans_init']) | ||||
|             nsfw_threshold = config.get('nsfw_threshold', DEFAULT_SETTINGS['nsfw_threshold']) | ||||
|              | ||||
|             logger.info("Settings loaded from config file") | ||||
|     except Exception as e: | ||||
|         logger.error(f"Error loading settings: {str(e)}") | ||||
|         # Use defaults if loading fails | ||||
| 
 | ||||
| def save_settings() -> None: | ||||
|     """ | ||||
|     Save current settings to config file | ||||
|     """ | ||||
|     try: | ||||
|         config = { | ||||
|             'mask_feather_ratio': mask_feather_ratio, | ||||
|             'mask_down_size': mask_down_size, | ||||
|             'mask_size': mask_size, | ||||
|             'max_cluster_k': max_cluster_k, | ||||
|             'kmeans_init': kmeans_init, | ||||
|             'nsfw_threshold': nsfw_threshold | ||||
|         } | ||||
|          | ||||
|         with open(CONFIG_PATH, 'w') as f: | ||||
|             json.dump(config, f, indent=2) | ||||
|              | ||||
|         logger.info("Settings saved to config file") | ||||
|     except Exception as e: | ||||
|         logger.error(f"Error saving settings: {str(e)}") | ||||
| 
 | ||||
| # Don't load settings at import time to avoid side effects | ||||
| # Will be called explicitly by the application's initialization | ||||
|  | @ -1,36 +1,131 @@ | |||
| import numpy | ||||
| import numpy as np | ||||
| import opennsfw2 | ||||
| from PIL import Image | ||||
| import cv2  # Add OpenCV import | ||||
| import modules.globals  # Import globals to access the color correction toggle | ||||
| import cv2 | ||||
| import modules.globals | ||||
| import logging | ||||
| from functools import lru_cache | ||||
| from typing import Union, Any | ||||
| 
 | ||||
| from modules.typing import Frame | ||||
| 
 | ||||
| MAX_PROBABILITY = 0.85 | ||||
| logger = logging.getLogger(__name__) | ||||
| 
 | ||||
| # Preload the model once for efficiency | ||||
| model = None | ||||
| # Global model instance for reuse | ||||
| _model = None | ||||
| 
 | ||||
| def predict_frame(target_frame: Frame) -> bool: | ||||
|     # Convert the frame to RGB before processing if color correction is enabled | ||||
|     if modules.globals.color_correction: | ||||
|         target_frame = cv2.cvtColor(target_frame, cv2.COLOR_BGR2RGB) | ||||
| @lru_cache(maxsize=1) | ||||
| def load_nsfw_model(): | ||||
|     """ | ||||
|     Load the NSFW prediction model with caching | ||||
|      | ||||
|     image = Image.fromarray(target_frame) | ||||
|     image = opennsfw2.preprocess_image(image, opennsfw2.Preprocessing.YAHOO) | ||||
|     global model | ||||
|     if model is None:  | ||||
|         model = opennsfw2.make_open_nsfw_model() | ||||
|     Returns: | ||||
|         Loaded NSFW model | ||||
|     """ | ||||
|     try: | ||||
|         logger.info("Loading NSFW detection model") | ||||
|         return opennsfw2.make_open_nsfw_model() | ||||
|     except Exception as e: | ||||
|         logger.error(f"Failed to load NSFW model: {str(e)}") | ||||
|         return None | ||||
| 
 | ||||
|     views = numpy.expand_dims(image, axis=0) | ||||
|     _, probability = model.predict(views)[0] | ||||
|     return probability > MAX_PROBABILITY | ||||
| def get_nsfw_model(): | ||||
|     """ | ||||
|     Get or initialize the NSFW model | ||||
|      | ||||
|     Returns: | ||||
|         NSFW model instance | ||||
|     """ | ||||
|     global _model | ||||
|     if _model is None: | ||||
|         _model = load_nsfw_model() | ||||
|     return _model | ||||
| 
 | ||||
| def predict_image(target_path: str) -> bool: | ||||
|     return opennsfw2.predict_image(target_path) > MAX_PROBABILITY | ||||
| def predict_frame(target_frame: Frame, threshold=None) -> bool: | ||||
|     """ | ||||
|     Predict if a frame contains NSFW content | ||||
|      | ||||
|     Args: | ||||
|         target_frame: Frame to analyze as numpy array | ||||
|         threshold: NSFW probability threshold (default: from globals) | ||||
|          | ||||
| def predict_video(target_path: str) -> bool: | ||||
|     _, probabilities = opennsfw2.predict_video_frames(video_path=target_path, frame_interval=100) | ||||
|     return any(probability > MAX_PROBABILITY for probability in probabilities) | ||||
|     Returns: | ||||
|         True if NSFW content detected, False otherwise | ||||
|     """ | ||||
|     try: | ||||
|         if target_frame is None: | ||||
|             logger.warning("Cannot predict on None frame") | ||||
|             return False | ||||
|              | ||||
|         # Get threshold from globals if not explicitly provided | ||||
|         if threshold is None: | ||||
|             threshold = getattr(modules.globals, 'nsfw_threshold', 0.85) | ||||
|          | ||||
|         # Convert the frame to RGB if needed | ||||
|         expected_format = 'RGB' if modules.globals.color_correction else 'BGR' | ||||
|         if expected_format == 'RGB' and target_frame.shape[2] == 3: | ||||
|             processed_frame = cv2.cvtColor(target_frame, cv2.COLOR_BGR2RGB) | ||||
|         else: | ||||
|             processed_frame = target_frame | ||||
|              | ||||
|         # Convert to PIL image and preprocess | ||||
|         image = Image.fromarray(processed_frame) | ||||
|         image = opennsfw2.preprocess_image(image, opennsfw2.Preprocessing.YAHOO) | ||||
|          | ||||
|         # Get model and predict | ||||
|         model = get_nsfw_model() | ||||
|         if model is None: | ||||
|             logger.error("NSFW model not available") | ||||
|             return False | ||||
|              | ||||
|         views = np.expand_dims(image, axis=0) | ||||
|         _, probability = model.predict(views)[0] | ||||
|          | ||||
|         logger.debug(f"NSFW probability: {probability:.4f}") | ||||
|         return probability > threshold | ||||
|          | ||||
|     except Exception as e: | ||||
|         logger.error(f"Error during NSFW prediction: {str(e)}") | ||||
|         return False | ||||
| 
 | ||||
| def predict_image(target_path: str, threshold=None) -> bool: | ||||
|     """ | ||||
|     Predict if an image file contains NSFW content | ||||
|      | ||||
|     Args: | ||||
|         target_path: Path to image file | ||||
|         threshold: NSFW probability threshold (default: from globals) | ||||
|          | ||||
|     Returns: | ||||
|         True if NSFW content detected, False otherwise | ||||
|     """ | ||||
|     try: | ||||
|         if threshold is None: | ||||
|             threshold = getattr(modules.globals, 'nsfw_threshold', 0.85) | ||||
|         return opennsfw2.predict_image(target_path) > threshold | ||||
|     except Exception as e: | ||||
|         logger.error(f"Error predicting NSFW for image {target_path}: {str(e)}") | ||||
|         return False | ||||
| 
 | ||||
| def predict_video(target_path: str, threshold=None) -> bool: | ||||
|     """ | ||||
|     Predict if a video file contains NSFW content | ||||
|      | ||||
|     Args: | ||||
|         target_path: Path to video file | ||||
|         threshold: NSFW probability threshold (default: from globals) | ||||
|          | ||||
|     Returns: | ||||
|         True if NSFW content detected, False otherwise | ||||
|     """ | ||||
|     try: | ||||
|         if threshold is None: | ||||
|             threshold = getattr(modules.globals, 'nsfw_threshold', 0.85) | ||||
|         _, probabilities = opennsfw2.predict_video_frames( | ||||
|             video_path=target_path,  | ||||
|             frame_interval=100 | ||||
|         ) | ||||
|         return any(probability > threshold for probability in probabilities) | ||||
|     except Exception as e: | ||||
|         logger.error(f"Error predicting NSFW for video {target_path}: {str(e)}") | ||||
|         return False | ||||
		Loading…
	
		Reference in New Issue