import os
from PIL import Image
import numpy as np
import torch
import numpy as np
from typing import Any, Dict, Optional, Tuple, List

from .utils import create_perceptual_difference_mask, parse_json, convert_mask_to_rgba, blur_mask, expand_mask, create_color_range_mask, bitvise_or_list_mask, align_images_phase_correlation, pil_to_tensor
from .gemini_segmentor import GeminiSegmentor, GeminiMaskResult 
from .kontext_segmentor import KontextSegmentor

import comfy 

class SegmentationProcessor:
    """
    ComfyUI Node for Gemini Segmentation.
    """

    @classmethod
    def INPUT_TYPES(cls) -> Dict[str, Any]:
        return {
            "required": {
                "image": ("IMAGE", {}),
                "model_segment": (["Gemini", "Kontext"] , {"default": "Gemini"} ),
                "api_key": ("STRING", {"default": "", "multiline": False}),
                "model_segmentation": ("STRING", {"default": "gemini-2.5-pro-exp-03-25"}),
                "model_color": ("STRING", {"default": "gemini-2.0-flash-exp-image-generation"}),
                "segmentor_config": ("SEGMENTOR_CONFIG",),  
                "max_workers": ("INT", {"default": 1, "min": 1}),
            }, 
            "optional": {
                "model": ("MODEL", ), 
                "clip": ("CLIP", ), 
                "vae": ("VAE", ), 
                "sampler_name": (comfy.samplers.KSampler.SAMPLERS, {"tooltip": "The algorithm used when sampling, this can affect the quality, speed, and style of the generated output."}),
                "scheduler": (comfy.samplers.KSampler.SCHEDULERS, {"tooltip": "The scheduler controls how noise is gradually removed to form the image."}),
                "steps": ("INT", {"default": 20, "min": 1, "max": 10000, "tooltip": "The number of steps used in the denoising process."}),
                "mask": ("MASK", )
            }
        }
    RETURN_TYPES = ("IMAGE","IMAGE",)
    FUNCTION = "execute"

    CATEGORY = "ReImage AI"

    def execute(
        self,
        image,
        model_segment,
        api_key,
        model_segmentation,
        model_color,
        segmentor_config, 
        max_workers, 
        model = None, 
        clip = None, 
        vae = None, 
        sampler_name = None, 
        scheduler = None, 
        steps = None, 
        mask = None
    ):
        
        pos_class_names = segmentor_config["pos_class_names"]
        pos_delta_e_threshold = segmentor_config["pos_delta_e_threshold"]
        pos_color_diff_method = segmentor_config["pos_color_diff_method"]
        pos_diff_target_colors = segmentor_config["pos_diff_target_colors"]
        pos_use_bbox = segmentor_config["pos_use_bbox"]
        pos_use_expand_mask = segmentor_config["pos_use_expand_mask"]
        pos_expand_mask_pixel = segmentor_config["pos_expand_mask_pixel"]
        pos_use_blur_mask = segmentor_config["pos_use_blur_mask"]
        neg_class_names = segmentor_config["neg_class_names"]
        neg_delta_e_threshold = segmentor_config["neg_delta_e_threshold"]
        neg_color_diff_method = segmentor_config["neg_color_diff_method"]
        neg_diff_target_colors = segmentor_config["neg_diff_target_colors"]
        neg_use_bbox = segmentor_config["neg_use_bbox"]
        neg_use_expand_mask = segmentor_config["neg_use_expand_mask"]
        neg_expand_mask_pixel = segmentor_config["neg_expand_mask_pixel"]
        neg_use_blur_mask = segmentor_config["neg_use_blur_mask"]
        # model_segment = "Kontext"
        print("model_segment", model_segment)
        if model_segment == "Gemini": 
            try:
                segmentor = GeminiSegmentor(
                    api_key=api_key if api_key.strip() else None,
                    model_segmentation=model_segmentation,
                    model_color=model_color,
                    pos_class_names=pos_class_names,
                    pos_delta_e_threshold=pos_delta_e_threshold,
                    pos_color_diff_method=pos_color_diff_method,
                    pos_diff_target_colors=pos_diff_target_colors,
                    pos_use_bbox=pos_use_bbox,
                    pos_use_blur_mask=pos_use_blur_mask,
                    pos_use_expand_mask=pos_use_expand_mask,
                    pos_expand_mask_pixel=pos_expand_mask_pixel,
                    neg_class_names=neg_class_names,
                    neg_delta_e_threshold=neg_delta_e_threshold,
                    neg_color_diff_method=neg_color_diff_method,
                    neg_diff_target_colors=neg_diff_target_colors,
                    neg_use_bbox=neg_use_bbox,
                    neg_use_blur_mask=neg_use_blur_mask,
                    neg_use_expand_mask=neg_use_expand_mask,
                    neg_expand_mask_pixel=neg_expand_mask_pixel,
                    max_workers=max_workers
                )
                result = segmentor(image)
                rgba_tensor = pil_to_tensor(result.mask_image) if isinstance(result.mask_image, Image.Image) else result.mask_image
                rgb_tensor = pil_to_tensor(result.mask_image_rgb) if isinstance(result.mask_image_rgb, Image.Image) else result.mask_image_rgb
                return (rgba_tensor, rgb_tensor)
            except Exception as e:
                raise f"SegmentorNode Error: {e}"
        
        elif model_segment == "Kontext":
            if model is None or clip is None or vae is None or sampler_name is None or scheduler is None or steps is None or mask is None:
                raise ValueError("Model, CLIP, VAE, sampler_name, scheduler, mask, and steps must be provided for Kontext segmentation.")
            try: 
                segmentor = KontextSegmentor(
                    pos_class_names=pos_class_names,
                    pos_delta_e_threshold=pos_delta_e_threshold,
                    pos_color_diff_method=pos_color_diff_method,
                    pos_diff_target_colors=pos_diff_target_colors,
                    pos_use_bbox=pos_use_bbox,
                    pos_use_expand_mask=pos_use_expand_mask,
                    pos_expand_mask_pixel=pos_expand_mask_pixel,
                    pos_use_blur_mask=pos_use_blur_mask,
                    neg_class_names=neg_class_names,
                    neg_delta_e_threshold=neg_delta_e_threshold,
                    neg_color_diff_method=neg_color_diff_method,
                    neg_diff_target_colors=neg_diff_target_colors,
                    neg_use_bbox=neg_use_bbox,
                    neg_use_expand_mask=neg_use_expand_mask,
                    neg_expand_mask_pixel=neg_expand_mask_pixel,
                    neg_use_blur_mask=neg_use_blur_mask
                )

                result = segmentor(model, clip, vae, sampler_name, scheduler, steps, mask, image)
                rgba_tensor = pil_to_tensor(result.mask_image) if isinstance(result.mask_image, Image.Image) else result.mask_image
                rgb_tensor = pil_to_tensor(result.mask_image_rgba) if isinstance(result.mask_image_rgba, Image.Image) else result.mask_image_rgba
                return (rgba_tensor, rgb_tensor)
            except Exception as e:
                raise RuntimeError(f"Failed to process image with Kontext Segmentor: {e}")
