import torch
import numpy as np
import cv2

class RemoveSmallMaskRegions:
    """ComfyUI node that removes small connected regions from masks.

    Each mask in the batch is quantized to 8 bits, split into 8-connected
    components with OpenCV, and every component whose area is below
    ``threshold`` (expressed as a fraction of the mask's total pixel count)
    is erased. Surviving components are written out as 1.0, everything else
    as 0.0, so the result is always a binary mask.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node inputs: the mask and the area-ratio threshold."""
        return {
            "required": {
                "mask": ("MASK",),
                "threshold": ("FLOAT", {
                    "default": 0.01,
                    "min": 0.0,
                    "max": 1.0,
                    "step": 0.001,
                    "display": "number"
                }),
            }
        }

    RETURN_TYPES = ("MASK",)
    FUNCTION = "execute"
    CATEGORY = "ReImage AI"

    def mask_to_numpy(self, mask_tensor: torch.Tensor) -> np.ndarray:
        """Convert a single mask tensor to a uint8 numpy array.

        Values are scaled by 255, clipped to [0, 255] and truncated, so
        anything below 1/255 becomes 0.
        """
        # detach() keeps .numpy() from failing on tensors that track gradients.
        mask_np = mask_tensor.detach().cpu().numpy()
        mask_np = (mask_np * 255).clip(0, 255).astype(np.uint8)
        return mask_np

    def numpy_to_tensor(self, numpy_array: np.ndarray) -> torch.Tensor:
        """Convert a 0-255 numpy array to a float32 tensor scaled to [0, 1]."""
        tensor = torch.from_numpy(numpy_array.astype(np.float32) / 255.0)
        return tensor

    def _filter_components(self, mask_np: np.ndarray, threshold: float) -> np.ndarray:
        """Return a 0/255 uint8 mask holding only the large-enough components."""
        height, width = mask_np.shape
        total_pixels = height * width

        _, labels, stats, _ = cv2.connectedComponentsWithStats(
            mask_np, connectivity=8
        )

        # Decide once per label whether it survives, then map every pixel
        # through that table in a single pass instead of rescanning the
        # whole image for each component.
        keep = (stats[:, cv2.CC_STAT_AREA] / total_pixels) >= threshold
        keep[0] = False  # label 0 is the background and is never kept
        lookup = np.where(keep, 255, 0).astype(np.uint8)
        return lookup[labels]

    def execute(self, mask, threshold):
        """Filter every mask in the batch and return the stacked result.

        Args:
            mask: Batch of masks, normally a tensor of shape (B, H, W). A
                single 2-D (H, W) tensor is treated as a batch of one.
            threshold: Minimum component area, as a fraction of the mask's
                pixel count, for a component to be kept.

        Returns:
            A one-element tuple holding a float32 tensor of shape (B, H, W)
            with values 0.0 or 1.0, placed on the same device as the input.
        """
        is_tensor = isinstance(mask, torch.Tensor)

        # Accept an un-batched (H, W) mask instead of iterating over its rows.
        if is_tensor and mask.dim() == 2:
            mask = mask.unsqueeze(0)

        # torch.stack rejects an empty list, so answer an empty batch directly.
        if is_tensor and mask.shape[0] == 0:
            return (torch.zeros_like(mask, dtype=torch.float32),)

        result = []
        for single_mask in mask:
            filtered_mask = self._filter_components(
                self.mask_to_numpy(single_mask), threshold
            )
            # Processing happens on the CPU; move the result back to where
            # the input lived so downstream nodes see a consistent device.
            result.append(
                self.numpy_to_tensor(filtered_mask).to(single_mask.device)
            )

        output_tensor = torch.stack(result, dim=0)
        return (output_tensor,)



class CombineMaskFromListMasks:
    """ComfyUI node that merges a collection of masks into one binary mask."""

    RETURN_TYPES = ("MASK",)
    FUNCTION = "execute"
    CATEGORY = "ReImage AI"

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the single required input: the masks to combine."""
        required_inputs = {"masks": ("MASK",)}
        return {"required": required_inputs}

    def execute(self, masks):
        """Union all entries of ``masks`` into a single mask.

        The result has the shape and dtype of ``masks[0]``. A position is set
        to 1 when at least one entry exceeds 0.2 there, and stays 0 otherwise.

        Returns:
            A one-element tuple containing the combined mask.
        """
        union = torch.zeros_like(masks[0])
        for layer in masks:
            # Mark every position where this entry is above the 0.2 cutoff.
            active = layer > 0.2
            union[active] = 1
        return (union,)