import cv2
from PIL import Image
import numpy as np
import torch
from typing import Union, List

class HighlightMaskBorder:
    """ComfyUI node that draws a red outline along the external contours of a mask."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": ("IMAGE",),
                "mask": ("MASK",),
                "thickness": ("INT", {
                    "default": 3,
                    "min": 1,
                    "max": 50,
                    "step": 1
                }),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "execute"
    CATEGORY = "image/mask"

    def tensor_to_image(self, tensor):
        """Convert a [0, 1] float tensor into a uint8 NumPy array usable by PIL.

        Tensors with three or more dimensions are treated as batched and only
        their first item is converted; a 2-D tensor is treated as a single
        (unbatched) mask and converted as-is.

        Args:
            tensor: torch tensor, e.g. (B, H, W, C) image, (B, H, W) or (H, W) mask.

        Returns:
            np.ndarray: uint8 array of shape (H, W, 3) or (H, W). Any alpha
            channel is dropped and a single channel is squeezed away.

        Raises:
            ValueError: if the item is not a 2-D mask or an image with 1, 3 or
                4 channels.
        """
        if tensor.dim() >= 3:
            # Strip the batch dimension; only the first item is converted.
            tensor = tensor[0]
        image = tensor.mul(255).clamp(0, 255).byte().cpu().numpy()

        if image.ndim == 3:
            channels = image.shape[-1]
            if channels == 4:
                # Ignore the alpha channel.
                image = np.ascontiguousarray(image[..., :3])
            elif channels == 1:
                image = image[..., 0]
            elif channels != 3:
                raise ValueError(
                    f"Unsupported channel count {channels}; expected 1, 3 or 4"
                )
        elif image.ndim != 2:
            raise ValueError(f"Unsupported tensor shape {tuple(tensor.shape)}")
        return image

    def pil2tensor(self, image: "Union[Image.Image, List[Image.Image]]") -> torch.Tensor:
        """Convert PIL image(s) to a float tensor of shape (N, H, W, 3) in [0, 1].

        Args:
            image: A single PIL image, or a non-empty list/tuple of same-sized images.

        Returns:
            torch.Tensor: (1, H, W, 3) for a single image, (N, H, W, 3) for a list.

        Raises:
            ValueError: if an empty list is given.
        """
        if isinstance(image, (list, tuple)):
            if not image:
                raise ValueError("pil2tensor received an empty image list")
            return torch.cat([self.pil2tensor(item) for item in image], dim=0)

        # Normalise any mode (RGBA, L, P, ...) to 3-channel RGB.
        if image.mode != "RGB":
            image = image.convert("RGB")

        img_array = np.array(image).astype(np.float32) / 255.0
        # Add the leading batch dimension expected by ComfyUI.
        return torch.from_numpy(img_array)[None,]

    def execute(self, image, mask, thickness=3):
        """Outline the mask border in red on every image of the batch.

        Args:
            image: IMAGE tensor, (B, H, W, C) with values in [0, 1]; an unbatched
                (H, W, C) tensor is also accepted.
            mask: MASK tensor, (B, H, W) or (H, W) with values in [0, 1]. If it has
                fewer items than ``image``, its last item is reused.
            thickness: Border thickness in pixels.

        Returns:
            tuple: One-element tuple holding a (B, H, W, 3) float tensor.
        """
        if image.dim() == 3:
            image = image.unsqueeze(0)
        if mask.dim() == 2:
            mask = mask.unsqueeze(0)

        mask_count = mask.shape[0]
        results = []
        for i in range(image.shape[0]):
            j = min(i, mask_count - 1)
            # Slicing keeps a batch dim of 1, which tensor_to_image strips.
            image_pil = Image.fromarray(self.tensor_to_image(image[i:i + 1]))
            mask_pil = Image.fromarray(self.tensor_to_image(mask[j:j + 1]))
            highlighted = self.highlight_mask_border_pil(
                image_pil, mask_pil, color=(255, 0, 0), thickness=thickness
            )
            results.append(self.pil2tensor(highlighted))
        return (torch.cat(results, dim=0),)

    def highlight_mask_border_pil(self, image_pil, mask_pil, color=(255, 0, 0), thickness=2):
        """
        Draw a colored border along the mask outline on a PIL image.

        Only external contours are drawn (holes inside the mask are not outlined).
        Any non-zero mask pixel counts as foreground. A mask whose size differs
        from the image is resized to the image size with nearest-neighbour
        sampling, so that contour coordinates line up with the image.

        Args:
            image_pil (PIL.Image): Input image; converted to RGB if needed.
            mask_pil (PIL.Image): Mask image (0 background, non-zero foreground).
            color (tuple): RGB color for border (default red).
            thickness (int): Border thickness in pixels.

        Returns:
            PIL.Image: RGB output image with colored border.
        """
        # Work directly in RGB so the RGB `color` can be used without swapping.
        img = np.array(image_pil.convert("RGB"))

        if mask_pil.size != image_pil.size:
            mask_pil = mask_pil.resize(image_pil.size, Image.NEAREST)
        mask = np.array(mask_pil.convert("L"))

        # [-2] selects the contour list under both OpenCV 3 (3 return values)
        # and OpenCV 4 (2 return values).
        contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]

        if len(contours) > 0:
            cv2.drawContours(img, contours, -1, tuple(int(c) for c in color), int(thickness))

        return Image.fromarray(img)






