111 lines
		
	
	
		
			4.4 KiB
		
	
	
	
		
			Python
		
	
			
		
		
	
	
			111 lines
		
	
	
		
			4.4 KiB
		
	
	
	
		
			Python
		
	
| import numpy as np
 | |
| from sklearn.cluster import KMeans
 | |
| from sklearn.metrics import silhouette_score
 | |
| from typing import Any, List, Optional, Tuple
 | |
| import logging
 | |
| import modules.globals
 | |
| 
 | |
| logger = logging.getLogger(__name__)
 | |
| 
 | |
def _silhouette_centroids(embeddings, max_k, kmeans_init):
    """Return the centroids of the k with the best silhouette score, or None if no k could be scored."""
    best_score = -1
    best_centroids = None

    # The loop only tries 2 <= k <= n_samples - 1, so it is empty (and None
    # is returned) when there are fewer than 3 samples or max_k < 2.
    for k in range(2, min(max_k + 1, len(embeddings))):
        try:
            kmeans = KMeans(n_clusters=k, init=kmeans_init, n_init=10, random_state=0)
            labels = kmeans.fit_predict(embeddings)
            score = silhouette_score(embeddings, labels)
        except Exception as e:
            # A failure for one k must not abort the search over the others.
            logger.warning(f"Error during silhouette analysis for k={k}: {str(e)}")
            continue

        if score > best_score:
            best_score = score
            best_centroids = kmeans.cluster_centers_

    return best_centroids


def _elbow_centroids(embeddings, max_k, kmeans_init):
    """Return the centroids of the k that follows the largest single drop in inertia."""
    fits = []
    for k in range(1, min(max_k, len(embeddings)) + 1):
        # n_init is pinned to match the silhouette search, so the result does
        # not depend on the library's default for that parameter.
        kmeans = KMeans(n_clusters=k, init=kmeans_init, n_init=10, random_state=0)
        kmeans.fit(embeddings)
        fits.append((kmeans.inertia_, kmeans.cluster_centers_))

    if len(fits) == 1:
        # Just one cluster
        return fits[0][1]

    # drops[i] is the inertia gained by going from k = i + 1 to k = i + 2.
    drops = [previous[0] - current[0] for previous, current in zip(fits, fits[1:])]
    best_idx = drops.index(max(drops))
    return fits[best_idx + 1][1]


def find_cluster_centroids(embeddings, max_k=None, kmeans_init=None) -> Any:
    """
    Identifies optimal face clusters using KMeans and silhouette scoring

    The silhouette search is tried first. If it yields nothing (fewer than
    three embeddings, max_k below 2, or every candidate k failed), the
    function falls back to an elbow heuristic over the KMeans inertia.

    Args:
        embeddings: Face embedding vectors (sequence of equal-length vectors)
        max_k: Maximum number of clusters to consider (default:
            modules.globals.max_cluster_k, or 10 if that is not set).
            Values below 1 are treated as 1.
        kmeans_init: KMeans initialization method (default:
            modules.globals.kmeans_init, or 'k-means++' if that is not set)

    Returns:
        Array of optimal cluster centroids, one centroid per row.
        With fewer than two embeddings the input is returned unchanged.
        If clustering fails, a single centroid (the mean of all embeddings)
        is returned; if even the mean cannot be computed, an empty array of
        shape (0, 0) is returned. The function does not raise.
    """
    try:
        if len(embeddings) < 2:
            logger.warning("Not enough embeddings for clustering analysis")
            return embeddings  # Return the single embedding as its own cluster

        # Use settings from globals if not explicitly provided
        if max_k is None:
            max_k = getattr(modules.globals, 'max_cluster_k', 10)
        if kmeans_init is None:
            kmeans_init = getattr(modules.globals, 'kmeans_init', 'k-means++')

        # At least one cluster must be allowed, otherwise the elbow fallback
        # would have nothing to fit.
        max_k = max(1, int(max_k))

        # Convert once so every KMeans fit works on the same array.
        data = np.asarray(embeddings)

        centroids = _silhouette_centroids(data, max_k, kmeans_init)
        if centroids is None:
            # Fallback to elbow method if silhouette failed or for small datasets
            centroids = _elbow_centroids(data, max_k, kmeans_init)

        return centroids

    except Exception as e:
        logger.error(f"Error in cluster analysis: {str(e)}")
        # Return a single centroid (mean of all embeddings) as fallback
        try:
            return np.mean(embeddings, axis=0, keepdims=True)
        except Exception:
            # The mean is undefined for the inputs most likely to land here
            # (None, ragged or non-numeric data); return "no centroids"
            # instead of raising out of the error handler.
            return np.empty((0, 0))
 | |
| 
 | |
def find_closest_centroid(centroids: list, normed_face_embedding) -> Optional[Tuple[int, np.ndarray]]:
    """
    Find the closest centroid to a face embedding

    "Closest" means the centroid row with the largest dot product against
    the embedding. Neither input is normalised here.

    Args:
        centroids: List (or 2-D array) of cluster centroids, one per row
        normed_face_embedding: Face embedding vector (1-D); presumably
            already L2-normalised by the caller, as the name suggests

    Returns:
        Tuple of (centroid index as a plain int, centroid vector), or None
        if matching fails: inputs of the wrong rank, no centroids, a
        centroid width that differs from the embedding length, or any
        error while converting or comparing the inputs.
    """
    try:
        centroids = np.array(centroids)
        normed_face_embedding = np.array(normed_face_embedding)

        # Validate input shapes: rank first, then that there is at least one
        # centroid and that its width matches the embedding length. Checking
        # here avoids relying on np.dot / np.argmax raising.
        shapes_ok = (
            centroids.ndim == 2
            and normed_face_embedding.ndim == 1
            and centroids.shape[0] > 0
            and centroids.shape[1] == normed_face_embedding.shape[0]
        )
        if not shapes_ok:
            logger.warning(f"Invalid shapes: centroids {centroids.shape}, embedding {normed_face_embedding.shape}")
            return None

        # Calculate similarity (dot product) between embedding and each centroid
        similarities = np.dot(centroids, normed_face_embedding)

        # np.argmax yields a NumPy integer scalar; convert so the result
        # matches the annotated Tuple[int, np.ndarray].
        closest_centroid_index = int(np.argmax(similarities))

        return closest_centroid_index, centroids[closest_centroid_index]
    except Exception as e:
        logger.error(f"Error finding closest centroid: {str(e)}")
        return None