import cv2 
from PIL import Image
import numpy as np
import torch 
from typing import Union, List
from .utils import align_images_phase_correlation, create_perceptual_difference_mask

class RefineMask:
    """ComfyUI node that refines a mask using the colour difference of two images.

    A perceptual colour-difference mask between ``image_1`` and ``image_2`` is
    computed with ``create_perceptual_difference_mask``. If an input ``mask`` is
    supplied, it is intersected with that difference mask. The intersection is
    used only when it keeps at least ``MIN_KEPT_FRACTION`` of the input mask's
    non-zero pixels; otherwise the input mask is returned as-is.
    """

    # Minimum fraction of the input mask's non-zero pixels the refined
    # (intersected) mask must retain for the refinement to be accepted.
    MIN_KEPT_FRACTION = 0.75

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image_1": ("IMAGE",),
                "image_2": ("IMAGE",),
                "color_diff_threshold": ("FLOAT", {"default": 15.0, "min": 0.0, "max": 100.0, "step": 0.1, "tooltip": "The threshold for color difference to consider a pixel as changed."}),
            },
            "optional": {
                "mask": ("MASK",),
            }
        }

    RETURN_TYPES = ("MASK",)
    FUNCTION = "execute"
    CATEGORY = "ReImage AI"

    def tensor_to_image(self, tensor):
        """Convert a ComfyUI IMAGE/MASK tensor to a uint8 numpy array.

        Values are multiplied by 255, clamped to [0, 255] and truncated to
        uint8 (values are presumably in [0, 1] — ComfyUI convention).

        Args:
            tensor: A torch tensor. Inputs with more than 2 dimensions are
                treated as batched (e.g. ``[B, H, W]`` masks or
                ``[B, H, W, C]`` images) and only item 0 is converted; a 2D
                input is treated as a single ``[H, W]`` mask.

        Returns:
            A uint8 ``np.ndarray``: ``[H, W]`` for masks and single-channel
            images, ``[H, W, 3]`` for RGB/RGBA images (alpha is dropped);
            any other layout is returned as converted.
        """
        if tensor.dim() > 2:
            # Batched input: only the first image/mask is used.
            tensor = tensor[0]
        image = tensor.mul(255).clamp(0, 255).byte().cpu().numpy()
        if image.ndim == 3:
            if image.shape[-1] == 4:
                # Drop the alpha channel; callers build RGB PIL images.
                image = image[..., :3]
            elif image.shape[-1] == 1:
                # Squeeze a single channel so PIL sees a 2D greyscale array.
                image = image[..., 0]
        return image

    def pil2tensor(self, image: "Union[Image.Image, List[Image.Image]]") -> torch.Tensor:
        """Convert a PIL image (or a list of them) to a float tensor in [0, 1].

        Mode 'L' images stay single-channel and give ``[1, H, W]`` (MASK
        layout); every other mode is converted to RGB and gives
        ``[1, H, W, 3]``. For a list, the per-image tensors are concatenated
        along dim 0, so all images must share the same size and layout.

        Raises:
            ValueError: if ``image`` is an empty list/tuple.
        """
        if isinstance(image, (list, tuple)):
            if not image:
                raise ValueError("pil2tensor received an empty image list")
            return torch.cat([self.pil2tensor(item) for item in image], dim=0)

        if image.mode not in ("RGB", "L"):
            # Covers RGBA (alpha dropped), palette, CMYK, '1', etc.
            image = image.convert("RGB")

        img_array = np.array(image).astype(np.float32) / 255.0
        # Prepend a batch dimension of size 1.
        return torch.from_numpy(img_array)[None,]

    def _refine_with_input_mask(self, mask, color_mask):
        """Intersect the input MASK tensor with ``color_mask`` (a PIL image).

        Returns an 'L' image holding the binary (0/255) intersection when it
        keeps at least ``MIN_KEPT_FRACTION`` of the input mask's non-zero
        pixels, otherwise the input mask itself (8-bit, not binarised).
        """
        input_mask_pil = Image.fromarray(self.tensor_to_image(mask))
        input_mask_np = np.array(input_mask_pil.convert("L"))
        # NOTE(review): assumes the mask and colour mask share the same H, W;
        # a size mismatch raises a numpy broadcasting error here.
        intersection = (input_mask_np > 0) & (np.array(color_mask) > 0)

        kept_pixels = np.count_nonzero(intersection)
        input_pixels = np.count_nonzero(input_mask_np)
        if kept_pixels >= self.MIN_KEPT_FRACTION * input_pixels:
            # uint8 2D array -> Pillow infers mode 'L'.
            return Image.fromarray(intersection.astype(np.uint8) * 255)
        # Refinement discarded too much of the input mask; fall back to it.
        return Image.fromarray(input_mask_np)

    def execute(self, image_1, image_2, color_diff_threshold, mask=None):
        """Compute the refined mask.

        Args:
            image_1: IMAGE tensor; only the first batch item is used.
            image_2: IMAGE tensor; only the first batch item is used.
            color_diff_threshold: Forwarded to
                ``create_perceptual_difference_mask`` as ``delta_e_threshold``.
            mask: Optional MASK tensor; only the first batch item is used.

        Returns:
            A 1-tuple with the resulting MASK tensor of shape ``[1, H, W]``
            (ComfyUI node functions return a tuple matching ``RETURN_TYPES``).
        """
        image1_pil = Image.fromarray(self.tensor_to_image(image_1))
        image2_pil = Image.fromarray(self.tensor_to_image(image_2))

        # NOTE(review): the translation estimated here is currently unused; the
        # call is kept so runtime behaviour (including any errors the helper
        # raises) is unchanged. Remove it or use the translation deliberately.
        align_images_phase_correlation(image1_pil, image2_pil)

        color_mask = create_perceptual_difference_mask(
            image1_pil,
            image2_pil,
            delta_e_threshold=color_diff_threshold,
        )

        if mask is None:
            final_mask_pil = color_mask
        else:
            final_mask_pil = self._refine_with_input_mask(mask, color_mask)

        # Force 8-bit greyscale so pil2tensor yields a [1, H, W] MASK rather
        # than an RGB tensor if the helper returns a non-'L' image (e.g. '1').
        final_mask = self.pil2tensor(final_mask_pil.convert("L"))

        return (final_mask,)
