import sys
import os
import scipy 

import node_helpers

from comfy_extras.nodes_flux import PREFERED_KONTEXT_RESOLUTIONS
from comfy.utils import common_upscale
# from nodes import common_ksampler

from typing import List, Tuple, Optional
from PIL import Image, ImageDraw, ImageSequence, ImageOps
import numpy as np
from google import genai
from dotenv import load_dotenv
from io import BytesIO
import time
import json
import base64
import io
import yaml
import torch
import comfy 
import latent_preview
from copy import deepcopy
from dataclasses import dataclass
import concurrent.futures
from .utils import create_perceptual_difference_mask, parse_json, convert_mask_to_rgba, blur_mask, expand_mask, create_color_range_mask, bitvise_or_list_mask, align_images_phase_correlation, pil_to_tensor

# Module-wide debug switch. When True, the code below prints diagnostic
# messages and saves intermediate images/masks as PNG files using relative
# file names (e.g. "input_image.png", "output_image.png").
DEBUG = False
def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False):
    """Sample a latent with ``comfy.sample.sample`` and return it as a 1-tuple.

    Local helper; the commented-out ``from nodes import common_ksampler`` at
    the top of the file suggests it stands in for that import.

    Args:
        model: Model handed to the comfy sampling helpers.
        seed: Seed used for noise preparation and forwarded to the sampler.
        steps: Number of sampling steps.
        cfg: Guidance scale forwarded to the sampler.
        sampler_name: Sampler identifier forwarded to the sampler.
        scheduler: Scheduler identifier forwarded to the sampler.
        positive: Positive conditioning.
        negative: Negative conditioning.
        latent (dict): Must contain ``"samples"``; may contain
            ``"batch_index"`` and ``"noise_mask"``.
        denoise, start_step, last_step, force_full_denoise: Forwarded
            unchanged to ``comfy.sample.sample``.
        disable_noise (bool): When True, an all-zero CPU noise tensor is used
            instead of seeded noise.

    Returns:
        tuple: ``(out,)`` where ``out`` is a shallow copy of ``latent`` whose
        ``"samples"`` entry holds the sampled result.
    """
    source_samples = comfy.sample.fix_empty_latent_channels(model, latent["samples"])

    if not disable_noise:
        # Seeded noise; honour per-item batch indices when the latent has them.
        if "batch_index" in latent:
            batch_indices = latent["batch_index"]
        else:
            batch_indices = None
        noise = comfy.sample.prepare_noise(source_samples, seed, batch_indices)
    else:
        noise = torch.zeros(
            source_samples.size(),
            dtype=source_samples.dtype,
            layout=source_samples.layout,
            device="cpu",
        )

    latent_noise_mask = latent["noise_mask"] if "noise_mask" in latent else None

    preview_callback = latent_preview.prepare_callback(model, steps)
    hide_progress_bar = not comfy.utils.PROGRESS_BAR_ENABLED
    sampled = comfy.sample.sample(
        model, noise, steps, cfg, sampler_name, scheduler, positive, negative, source_samples,
        denoise=denoise,
        disable_noise=disable_noise,
        start_step=start_step,
        last_step=last_step,
        force_full_denoise=force_full_denoise,
        noise_mask=latent_noise_mask,
        callback=preview_callback,
        disable_pbar=hide_progress_bar,
        seed=seed,
    )

    # Shallow copy so every other key of the input latent is carried over.
    result = latent.copy()
    result["samples"] = sampled
    return (result, )

@dataclass(frozen=True, eq=False)
class SegmentationMask:
    """Class for storing segmentation mask information.

    Equality and hashing are implemented by hand because the dataclass
    generated versions do not work with a NumPy array field: tuple comparison
    of two distinct arrays raises ``ValueError`` (ambiguous truth value) and
    hashing an array raises ``TypeError``.

    Args:
        y0 (int): Top coordinate in [0..height-1]
        x0 (int): Left coordinate in [0..width-1]
        y1 (int): Bottom coordinate in [0..height-1]
        x1 (int): Right coordinate in [0..width-1]
        mask (np.ndarray): Mask array of shape [img_height, img_width] with values 0..255
        label (str): Class label for the mask
    """
    y0: int  # in [0..height - 1]
    x0: int  # in [0..width - 1]
    y1: int  # in [0..height - 1]
    x1: int  # in [0..width - 1]
    mask: np.ndarray  # [img_height, img_width] with values 0..255
    label: str

    def _scalar_key(self):
        """Return the hashable (non-array) fields as a tuple."""
        return (self.y0, self.x0, self.y1, self.x1, self.label)

    def __eq__(self, other):
        """Compare field by field, using ``np.array_equal`` for the mask."""
        if other.__class__ is not self.__class__:
            return NotImplemented
        if self._scalar_key() != other._scalar_key():
            return False
        return bool(np.array_equal(self.mask, other.mask))

    def __hash__(self):
        """Hash the scalar fields only; equal instances share these values."""
        return hash(self._scalar_key())

@dataclass(frozen=True)
class SegmentationMaskResult:
    """Class for storing segmentation mask results.
    
    Args:
        mask_image (Image.Image): Image with the segmentation mask
        mask_image_rgba (Image.Image): Second image stored alongside the mask;
            presumably an RGBA rendering of the same mask (the name and the
            imported ``convert_mask_to_rgba`` helper suggest so) - TODO confirm
            against the code that builds this result.
    """
    mask_image: Image.Image
    mask_image_rgba: Image.Image
    
class KontextSegmentor:
    def __init__(self, 
                pos_class_names: list = ['landscape'],
                pos_delta_e_threshold: list = [20.0],
                pos_color_diff_method: list = ['delta_e'],
                pos_diff_target_colors: list = [['red', 'blue']],
                pos_use_bbox=[False], 
                pos_use_expand_mask=[False],
                pos_expand_mask_pixel=[0],
                pos_use_blur_mask=[False],
                neg_class_names: list = ['building'],
                neg_delta_e_threshold: list = [20.0],
                neg_color_diff_method: list = ['color_range'],
                neg_diff_target_colors: list = [["red"]],
                neg_use_bbox=[False], 
                neg_use_expand_mask=[False],
                neg_expand_mask_pixel=[0], 
                neg_use_blur_mask=[False],
                
                ):
        """Initialize KontextSegmentor with segmentation parameters.

        Every ``pos_*`` list must have one entry per positive class and every
        ``neg_*`` list one entry per negative class; this is checked by
        ``assert_config`` at the end of construction.

        The arguments are deep-copied before being stored. The defaults in the
        signature are mutable lists shared by every call, so storing them by
        reference would let a mutation on one instance leak into the defaults
        and into other instances (and into caller-owned lists).
        
        Args:
            pos_class_names (list, optional): Positive class names. Defaults to ['landscape'].
            pos_delta_e_threshold (list, optional): Delta E thresholds for positive classes. Defaults to [20.0].
            pos_color_diff_method (list, optional): Color difference methods for positive classes. Defaults to ['delta_e'].
            pos_diff_target_colors (list, optional): Target colors for positive classes. Defaults to [['red', 'blue']].
            pos_use_bbox (list, optional): Whether to use bounding box for positive classes. Defaults to [False].
            pos_use_expand_mask (list, optional): Whether to expand mask for positive classes. Defaults to [False].
            pos_expand_mask_pixel (list, optional): Expansion size for positive masks. Defaults to [0].
            pos_use_blur_mask (list, optional): Whether to blur mask for positive classes. Defaults to [False].
            neg_class_names (list, optional): Negative class names. Defaults to ['building'].
            neg_delta_e_threshold (list, optional): Delta E thresholds for negative classes. Defaults to [20.0].
            neg_color_diff_method (list, optional): Color difference methods for negative classes. Defaults to ['color_range'].
            neg_diff_target_colors (list, optional): Target colors for negative classes. Defaults to [["red"]].
            neg_use_bbox (list, optional): Whether to use bounding box for negative classes. Defaults to [False].
            neg_use_expand_mask (list, optional): Whether to expand mask for negative classes. Defaults to [False].
            neg_expand_mask_pixel (list, optional): Expansion size for negative masks. Defaults to [0].
            neg_use_blur_mask (list, optional): Whether to blur mask for negative classes. Defaults to [False].

        Raises:
            ValueError: If an invalid color difference method is specified.
            AssertionError: If the configuration lists are inconsistent.
        """
        
        if DEBUG:
            print("[INFO] Loading FLUX Kontext Segmentor")
        
        # Positive classes (deep copies: never alias the shared default lists)
        self.pos_class_names         = deepcopy(pos_class_names)
        self.pos_delta_e_threshold   = deepcopy(pos_delta_e_threshold)
        self.pos_color_diff_method   = deepcopy(pos_color_diff_method)
        self.pos_diff_target_colors  = deepcopy(pos_diff_target_colors)
        self.pos_use_bbox            = deepcopy(pos_use_bbox)
        self.pos_use_expand_mask     = deepcopy(pos_use_expand_mask)
        self.pos_expand_mask_pixel   = deepcopy(pos_expand_mask_pixel)
        self.pos_use_blur_mask       = deepcopy(pos_use_blur_mask)
        
        # Negative classes (deep copies: never alias the shared default lists)
        self.neg_class_names         = deepcopy(neg_class_names)
        self.neg_delta_e_threshold   = deepcopy(neg_delta_e_threshold)
        self.neg_color_diff_method   = deepcopy(neg_color_diff_method)
        self.neg_diff_target_colors  = deepcopy(neg_diff_target_colors)
        self.neg_use_bbox            = deepcopy(neg_use_bbox)
        self.neg_use_expand_mask     = deepcopy(neg_use_expand_mask)
        self.neg_expand_mask_pixel   = deepcopy(neg_expand_mask_pixel)
        self.neg_use_blur_mask       = deepcopy(neg_use_blur_mask)
        self.max_workers = 1
        # Check config
        self.assert_config()
        
    def assert_config(self):
        """Validate configuration parameters and ensure they are consistent.

        The checks raise explicitly instead of using ``assert`` statements, so
        they still run when Python is started with ``-O`` and the error names
        the offending setting. The exception types are unchanged.
        
        Raises:
            ValueError: If an invalid color difference method is specified
            AssertionError: If configuration lists have inconsistent lengths,
                or a target-colors entry is not a list
        """
        # If debug mode, show config
        if DEBUG:
            print("[INFO] FLUX KONTEXT SEGMENTOR CONFIG")
            for prefix in ("pos", "neg"):
                print('-'*10)
                for suffix in ("class_names", "delta_e_threshold", "color_diff_method",
                               "diff_target_colors", "use_bbox", "use_expand_mask",
                               "expand_mask_pixel", "use_blur_mask"):
                    name = f"{prefix}_{suffix}"
                    print(f"  - {name}: ", getattr(self, name))

        # Every per-class list must have exactly one entry per class name.
        per_class_settings = ("delta_e_threshold", "color_diff_method", "diff_target_colors",
                              "use_bbox", "use_blur_mask", "use_expand_mask", "expand_mask_pixel")
        for prefix in ("pos", "neg"):
            expected = len(getattr(self, f"{prefix}_class_names"))
            for suffix in per_class_settings:
                actual = len(getattr(self, f"{prefix}_{suffix}"))
                if actual != expected:
                    raise AssertionError(
                        f"{prefix}_{suffix} has {actual} entries but "
                        f"{prefix}_class_names has {expected}"
                    )

        # Make sure color different method is valid and target colors are lists
        for prefix in ("pos", "neg"):
            diff_methods = getattr(self, f"{prefix}_color_diff_method")
            target_colors = getattr(self, f"{prefix}_diff_target_colors")
            for diff_id, diff_method in enumerate(diff_methods):
                if diff_method not in ['delta_e', 'color_range']:
                    raise ValueError("Invalid color difference method '{}', current support only 'delta_e' and 'color_range'".format(diff_method))
                if not isinstance(target_colors[diff_id], list):
                    raise AssertionError(
                        f"{prefix}_diff_target_colors[{diff_id}] must be a list, "
                        f"got {type(target_colors[diff_id]).__name__}"
                    )
            

    def _plot_segmentation_white_masks(self,
                                    img: Image.Image,
                                    segmentation_masks: List[SegmentationMask]) -> Image.Image:
        """Draw the bounding box of every segmentation mask onto a blank canvas.

        The canvas copies the mode and size of ``img`` and starts filled with
        0. Only the ``x0, y0, x1, y1`` box of each mask is drawn; the per-pixel
        ``mask`` array is not used.

        Args:
            img (Image.Image): Image whose mode and size the canvas copies
            segmentation_masks (List[SegmentationMask]): Masks whose boxes are drawn

        Returns:
            Image.Image: Canvas with one filled rectangle per mask
        """
        canvas = Image.new(img.mode, img.size, 0)
        pen = ImageDraw.Draw(canvas)

        # NOTE(review): fill=255 is white for single-channel ("L") canvases; for
        # multi-channel modes the resulting colour may differ - confirm callers
        # only pass images whose mode makes 255 mean white.
        for segment in segmentation_masks:
            top_left = (segment.x0, segment.y0)
            bottom_right = (segment.x1, segment.y1)
            pen.rectangle((top_left, bottom_right), fill=255)

        return canvas
    
    def kontext_inference(self, model, clip, vae, sampler_name, scheduler, steps, mask, image, prompt):
        """Run one prompt-guided edit of ``image`` and return the result as a PIL image.

        Steps, in order:
          1. Convert the PIL frames of ``image`` to a float tensor in [0, 1].
          2. Rescale to the entry of ``PREFERED_KONTEXT_RESOLUTIONS`` whose
             aspect ratio is closest to the input's.
          3. Encode ``prompt`` with ``clip`` and set guidance 2.5.
          4. Build the conditioning: masked pixels are neutralised to 0.5,
             VAE-encoded and attached as ``concat_latent_image`` together with
             ``concat_mask``; the unmasked image latent is appended as a
             reference latent.
          5. Sample with ``common_ksampler`` (seed fixed to 0, cfg 1.0), decode
             with ``vae`` and resize back to the input size.

        Args:
            model: Model passed to ``common_ksampler``.
            clip: Object providing ``tokenize`` and ``encode_from_tokens_scheduled``.
            vae: Object providing ``encode`` and ``decode``.
            sampler_name: Forwarded to ``common_ksampler``.
            scheduler: Forwarded to ``common_ksampler``.
            steps: Forwarded to ``common_ksampler``.
            mask: Tensor whose last two dimensions are spatial; it is reshaped
                to ``(-1, 1, H, W)`` and bilinearly resized to the image.
                Rounded values of 1 mark pixels that get neutralised.
            image (Image.Image): Input image (may contain several frames).
            prompt (str): Edit instruction.

        Returns:
            Image.Image: First decoded image, resized (LANCZOS) to the size of
            the input image.
        """
        if DEBUG: 
            image.save("input_image.png")  # Save for debugging
            print(f"Input image saved as 'input_image.png' with size: {image.size}")
        ori_width, ori_height = image.size
        # Keep the PIL input under its own name. The frame loop produces
        # tensors, and the format check after the loop must look at the PIL
        # image (a tensor has no ``format`` attribute).
        source_image = image
        output_images = []
        w, h = None, None

        excluded_formats = ['MPO']

        for frame in ImageSequence.Iterator(source_image):
            frame = node_helpers.pillow(ImageOps.exif_transpose, frame)

            if frame.mode == 'I':
                frame = frame.point(lambda px: px * (1 / 255))
            frame = frame.convert("RGB")

            # The first frame fixes the expected size; differently sized frames are skipped.
            if len(output_images) == 0:
                w = frame.size[0]
                h = frame.size[1]

            if frame.size[0] != w or frame.size[1] != h:
                continue

            frame_np = np.array(frame).astype(np.float32) / 255.0
            output_images.append(torch.from_numpy(frame_np)[None,])

        if len(output_images) > 1 and getattr(source_image, "format", None) not in excluded_formats:
            image = torch.cat(output_images, dim=0)
        else:
            image = output_images[0]
        
        # Flux Kontext image scale: pick the preferred resolution with the closest aspect ratio
        aspect_ratio = ori_width / ori_height
        _, width, height = min(
            (abs(aspect_ratio - res_w / res_h), res_w, res_h)
            for res_w, res_h in PREFERED_KONTEXT_RESOLUTIONS
        )
        image = common_upscale(image.movedim(-1, 1), width, height, "lanczos", "center").movedim(1, -1)
        
        # Ensure all operations are done without gradient tracking
        with torch.no_grad():
            # Clip text encode 
            tokens = clip.tokenize(prompt)
            clip_encode_conditioning = clip.encode_from_tokens_scheduled(tokens)

            # FluxGuidance
            conditioning = node_helpers.conditioning_set_values(clip_encode_conditioning, {"guidance": 2.5})
            
            # The negative conditioning is the same object as the guided prompt conditioning.
            negative_conditioning = conditioning 
            # Kontext Inpainting Conditioning
            # x / y are the two spatial extents of the image tensor rounded down to a multiple of 8.
            x = (image.shape[1] // 8) * 8
            y = (image.shape[2] // 8) * 8
            mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(image.shape[1], image.shape[2]), mode="bilinear")
             
            orig_pixels = image
            pixels = orig_pixels.clone()
            if pixels.shape[1] != x or pixels.shape[2] != y:
                # Centre-crop pixels and mask to the multiple-of-8 extents.
                x_offset = (pixels.shape[1] % 8) // 2
                y_offset = (pixels.shape[2] % 8) // 2
                pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:]
                mask = mask[:,:,x_offset:x + x_offset, y_offset:y + y_offset]
                
            # Pull masked pixels to 0.5 in every colour channel (m is 0 where the rounded mask is 1).
            m = (1.0 - mask.round()).squeeze(1) 
            for i in range(3):
                pixels[:,:,:,i] -= 0.5
                pixels[:,:,:,i] *= m
                pixels[:,:,:,i] += 0.5
                
            concat_latent = vae.encode(pixels)
            orig_latent = vae.encode(orig_pixels)
            
            pixel_latent = vae.encode(orig_pixels[:,:,:,:3])  
            encoded_latent = {"samples": pixel_latent}
            
            c = node_helpers.conditioning_set_values(conditioning, {"concat_latent_image": concat_latent,
                                                                "concat_mask": mask})

            positive_conditioning = node_helpers.conditioning_set_values(c, {"reference_latents": [encoded_latent["samples"]]}, append=True)
            
            out_latent = {}
            out_latent["samples"] = orig_latent
            # out_latent["noise_mask"] = mask
            
        
        with torch.no_grad():
            # KSampler, also run without gradient tracking. The seed is fixed
            # to 0, so every call uses the same sampling seed.
            decoder_latent_tuple = common_ksampler(
                model=model,
                seed=0,
                steps=steps,
                cfg=1.0, 
                sampler_name=sampler_name,
                scheduler=scheduler,
                positive=positive_conditioning,
                negative=negative_conditioning,
                latent=out_latent,
            )
        
        # Extract the dictionary from the tuple (common_ksampler returns (out, ))
        decoder_latent = decoder_latent_tuple[0]
        
        # VAE Decode
        with torch.no_grad():
            decoded_image = vae.decode(decoder_latent["samples"])

        if DEBUG:
            print(f"Shape decoded image: {len(decoded_image.shape)}")

        # Only the first decoded image is converted and returned.
        output_tensor = decoded_image[0]
        output_np = 255. * output_tensor.cpu().numpy()
        output_image = Image.fromarray(np.clip(output_np, 0, 255).astype(np.uint8))
        output_image = output_image.resize((ori_width, ori_height), Image.Resampling.LANCZOS)
        if DEBUG:
            print(f"Output image created successfully with size: {output_image.size}")
            output_image.save("output_image.png")
        return output_image
    
    def request_kontext(self, 
                        model,
                        clip,
                        vae,
                        sampler_name,
                        scheduler,
                        steps,
                        mask,
                        task_type: str, 
                        task_class_name: str,
                        task_response_modalities: List[str],
                        task_output_type: str,
                        task_data: List[Image.Image],
                        attempts: int = 1,
                        apply_quality_filtering: bool = True) -> dict:
        """Run one task through ``kontext_inference`` and wrap the result.

        For ``recolor-*`` tasks with quality filtering enabled, each result is
        compared with the input image via ``align_images_phase_correlation``.
        A result whose largest absolute translation component is at most 1.0
        is accepted immediately; otherwise the attempt with the smallest
        translation seen so far is kept as a fallback. Without quality
        filtering the first result is returned as is.

        A result dict is always returned. ``task_response`` is ``None`` only
        when no inference ran (``attempts`` <= 0).

        Args:
            model, clip, vae, sampler_name, scheduler, steps, mask: Forwarded
                unchanged to ``kontext_inference``.
            task_type (str): Type of the task (e.g., 'recolor-red')
            task_class_name (str): Class name for the task
            task_response_modalities (List[str]): Expected response modalities
                (only logged in debug mode)
            task_output_type (str): Expected output type (only logged in debug mode)
            task_data (List): ``[image, prompt]``; element 0 is the input
                image, element 1 the prompt
            attempts (int, optional): Maximum number of inference attempts. Defaults to 1.
            apply_quality_filtering (bool, optional): Enable the translation
                check for ``recolor-*`` tasks. Defaults to True.
            
        Returns:
            dict: ``{"task_type", "task_class_name", "task_response"}``
        """
        if DEBUG:
            print('-'*10)
            print('  - Task type:', task_type)
            print('  - Task class name:', task_class_name)
            print('  - Task response modalities:', task_response_modalities)
            print('  - Task output type:', task_output_type)

        use_quality_filter = apply_quality_filtering and task_type.startswith('recolor-')
        response_image = None
        # Infinity (not a finite sentinel) so the first low-quality attempt is
        # always recorded as the fallback, however large its translation.
        best_translate_dx_dy = float("inf")

        # NOTE(review): kontext_inference passes a fixed seed to the sampler, so
        # repeated attempts may reproduce the same image - confirm whether
        # retries are expected to differ.
        attempt_time = 0
        while attempt_time < attempts:
            attempt_time += 1
            start_time = time.time()
            candidate_image = self.kontext_inference(model, clip, vae, sampler_name, scheduler, steps, mask, task_data[0], task_data[1])
            end_time = time.time()
            if DEBUG:
                print(f"[INFO] Attempt {attempt_time}/{attempts} took {end_time - start_time:.2f} seconds")

            if not use_quality_filter:
                # No acceptance criterion: the first result is the answer.
                response_image = candidate_image
                break

            original_image = task_data[0]
            _, translate_dx_dy = align_images_phase_correlation(original_image, candidate_image)
            max_translate_dx_dy = max(abs(translate_dx_dy[0]), abs(translate_dx_dy[1]))
            if max_translate_dx_dy > 1.0:
                if max_translate_dx_dy < best_translate_dx_dy:
                    best_translate_dx_dy = max_translate_dx_dy
                    response_image = candidate_image
                print(f"Recolored image quality is low, translation: {translate_dx_dy}, retrying ...")
                continue
            else:
                response_image = candidate_image
                best_translate_dx_dy = max_translate_dx_dy
                print(f"Recolored image quality is good, translation: {translate_dx_dy}, using this result")
                break

        return {
            "task_type": task_type,
            "task_class_name": task_class_name,
            "task_response": response_image,
        }
    
    def submit_task(self, model, clip, vae, sampler_name, scheduler, steps, mask, task):
        """Submit a task to kontext.
        
        Args:
            task (dict): Task parameters to submit to flux kontext
            
        Returns:
            dict: Result from flux kontext
        """
        return self.request_kontext(model, clip, vae, sampler_name, scheduler, steps, mask, **task)

    def create_and_submit_tasks(self, model, clip, vae, sampler_name, scheduler, steps, mask, image: Image.Image) -> List[dict]:
        """Generate tasks for segmentation and recoloring and submit to Flux Kontext.
        
        Args:
            image (Image.Image): Image to process
            mask (Image.Image): Mask to apply to the image
        Returns:
            List[dict]: List of Flux Kontext results
        """
        tasks = []

        # 1. Recolor task for positive object
        resize_image = self.resize_keep_aspect(image)
        for class_id, class_name in enumerate(self.pos_class_names):
            for target_color in self.pos_diff_target_colors[class_id]:
                # prompt = f"Change {class_name} color to {target_color}."
                prompt = f"ONLY Change the color of the {class_name} object to {target_color}, keep all other objects unchanged. "
                tasks.append({
                    "task_type": "recolor-{}".format(target_color),
                    "task_class_name": class_name,
                    "task_response_modalities": ['Image'],
                    "task_output_type": "image",
                    "task_data": [deepcopy(resize_image), prompt],
                    "attempts": 1
                    })

        # 2. Recolor task for negative object
        for class_id, class_name in enumerate(self.neg_class_names):
            for target_color in self.neg_diff_target_colors[class_id]:
                prompt = f"ONLY Change the color of the {class_name} object to {target_color}, keep all other objects unchanged."
                # prompt = f"ONLY change the color of the {class_name} in the provided image to {target_color}. Keep rest AND image size unchanged."
                tasks.append({
                    "task_type": "recolor-{}".format(target_color),
                    "task_class_name": class_name,
                    "task_response_modalities": ['Image'],
                    "task_output_type": "image",
                    "task_data": [deepcopy(resize_image), prompt],
                    "attempts": 1
                    })

        # 4. Parallel call
        kontext_results = []
        for task in tasks:
            kontext_results.append(
                self.submit_task(model, clip, vae, sampler_name, scheduler, steps, mask, task)
            )
            
        return list(kontext_results)
            

    def resize_keep_aspect(self,
                            image: Image.Image,
                            max_size: int = 1024):
        """Resize image while maintaining aspect ratio.

        The image is scaled (up or down) so that its longer side becomes
        ``max_size``. Each output side is clamped to at least 1 pixel: for
        very elongated images the shorter side would otherwise truncate to 0,
        which is not a valid size to resize to.
        
        Args:
            image (Image.Image): Image to resize
            max_size (int, optional): Target length of the longer side. Defaults to 1024.
            
        Returns:
            Image.Image: Resized image
        """
        width, height = image.size
        ratio = min(max_size / width, max_size / height)
        # int() truncates, so guard against a zero-pixel side.
        new_width = max(1, int(width * ratio))
        new_height = max(1, int(height * ratio))
        return image.resize((new_width, new_height))

    def change_object_color(self,
                        image: Image.Image,
                        class_name: str,
                        color_diff_method: str,
                        delta_e_threshold: float,
                        use_bbox: bool,
                        use_expand_mask: bool,
                        expand_mask_pixel: int,
                        use_blur_mask: bool,
                        kontext_results: List[dict],
                        target_color: str
                        ) -> Tuple[Image.Image, Image.Image]:
        """Look up the recolored image for one class/color and derive its change mask.

        The recolored image is taken from ``kontext_results`` (exactly one
        entry must match the class and ``recolor-<target_color>`` task type).
        It is compared with the resized original to obtain a mask, and both
        are then scaled back to the original size.

        Args:
            image (Image.Image): Original image
            class_name (str): Name of the class whose result is looked up
            color_diff_method (str): Mask method, 'delta_e' or 'color_range'
            delta_e_threshold (float): Threshold used by the 'delta_e' method
            use_bbox (bool): Accepted but not used by this method
            use_expand_mask (bool): Whether to expand the mask
            expand_mask_pixel (int): Expansion radius passed to ``expand_mask``
            use_blur_mask (bool): Whether to blur the mask (radius 5)
            kontext_results (List[dict]): Results from Kontext calls
            target_color (str): Target color of the recolor task

        Returns:
            Tuple[Image.Image, Image.Image]: Recolored image and mask, both at
            the original image size

        Raises:
            AssertionError: If the number of matching results is not exactly 1
            ValueError: If the matching result holds no image, or the color
                difference method is invalid
        """
        original_size = image.size
        reference_image = self.resize_keep_aspect(image)

        # Find the single result produced for this class and target color.
        wanted_task_type = "recolor-{}".format(target_color)
        matches = []
        for result in kontext_results:
            if result["task_type"] == wanted_task_type and result["task_class_name"] == class_name:
                matches.append(result)
        assert len(matches) == 1 # Only 1 result
        recolored = matches[0]["task_response"]

        if recolored is None:
            raise ValueError("No modified image received")

        if DEBUG:
            recolored.save(f"modified_image_{class_name}_{target_color}.png")

        # Derive the mask at the working (resized) resolution.
        if color_diff_method == 'delta_e':
            change_mask = create_perceptual_difference_mask(
                reference_image,
                recolored,
                delta_e_threshold=delta_e_threshold,
            )
        elif color_diff_method == 'color_range':
            change_mask = create_color_range_mask(reference_image, recolored, target_color)
        else:
            raise ValueError("Invalid color difference calculation method")

        # Back to the original size; NEAREST keeps the mask free of
        # interpolated in-between values.
        recolored = recolored.resize(original_size)
        change_mask = change_mask.resize(original_size, Image.Resampling.NEAREST)

        # Optional post-processing: expand first, then blur.
        if use_expand_mask:
            change_mask = expand_mask(change_mask, radius=int(expand_mask_pixel))
            if DEBUG:
                change_mask.save(f"expanded_mask_{class_name}.png")

        if use_blur_mask:
            change_mask = blur_mask(change_mask, radius=5)
            if DEBUG:
                change_mask.save(f"blurred_mask_{class_name}.png")

        return recolored, change_mask
    
    def create_combine_mask(self, 
                            mask: Image.Image,
                            image: Image.Image,
                            list_class_names: list,
                            list_delta_e_threshold: list,
                            list_color_diff_method: list,
                            list_mix_target_colors: list,
                            list_use_bbox: list,
                            list_use_expand_mask: list,
                            list_expand_mask_pixel: list,
                            list_use_blur_mask: list,
                            kontext_results: List[dict],
                            ):
        """Generate combined mask from multiple classes and their configurations.

        For every class, the masks of all its target colors are merged with
        ``bitvise_or_list_mask``; the per-class masks are then merged with a
        pixel-wise maximum. The result is finally restricted to the pixels
        where ``mask`` is non-zero.

        Two edge cases are handled explicitly:
          * an empty ``list_class_names`` yields an all-zero mask instead of
            failing on a missing combined mask;
          * a ``mask`` whose size differs from the combined mask is resized
            with NEAREST before the two are intersected.
        
        Args:
            mask (Image.Image): Mask to apply to the image
            image (Image.Image): Original image
            list_class_names (list): List of class names
            list_delta_e_threshold (list): List of delta_e thresholds
            list_color_diff_method (list): List of color difference methods
            list_mix_target_colors (list): List of target colors for each class
            list_use_bbox (list): List of bbox usage flags
            list_use_expand_mask (list): List of mask expansion flags
            list_expand_mask_pixel (list): List of mask expansion sizes
            list_use_blur_mask (list): List of mask blur flags
            kontext_results (List[dict]): Results from Flux Kontext calls
            
        Returns:
            Image.Image: "L" mode mask with 255 where both ``mask`` and the
            combined class mask are non-zero, 0 elsewhere
        """
        # Combined mask over all classes; stays None when there are no classes
        all_combined_mask = None
        
        # Process each class in the list
        for class_id, class_name in enumerate(list_class_names):
            combine_masks = []
            for color in list_mix_target_colors[class_id]:
                _, color_mask = self.change_object_color(deepcopy(image),
                                                        class_name,
                                                        target_color=color,
                                                        delta_e_threshold = list_delta_e_threshold[class_id],
                                                        color_diff_method = list_color_diff_method[class_id],
                                                        use_bbox = list_use_bbox[class_id],
                                                        use_expand_mask = list_use_expand_mask[class_id],
                                                        use_blur_mask = list_use_blur_mask[class_id],
                                                        expand_mask_pixel = list_expand_mask_pixel[class_id],
                                                        kontext_results = kontext_results
                                                        )
                combine_masks.append(color_mask)
                if DEBUG:
                    color_mask.save(f"{class_name}_{color}_mask.png")

            color_mask = bitvise_or_list_mask(combine_masks)
            if DEBUG:
                color_mask.save(f"{class_name}_color_mask.png")

            # For the first class, initialize the combined mask
            if all_combined_mask is None:
                all_combined_mask = color_mask
            else:
                # Combine masks by taking the maximum value at each pixel
                all_combined_mask_array = np.maximum(np.array(all_combined_mask), np.array(color_mask))
                all_combined_mask = Image.fromarray(all_combined_mask_array)

        mask_l = mask.convert("L")
        if all_combined_mask is None:
            # No classes configured: nothing is selected.
            return Image.new("L", mask_l.size, 0)

        combined_np = np.array(all_combined_mask)
        if combined_np.ndim >= 2:
            # PIL sizes are (width, height); array shapes are (height, width).
            combined_size = (combined_np.shape[1], combined_np.shape[0])
            if mask_l.size != combined_size:
                mask_l = mask_l.resize(combined_size, Image.Resampling.NEAREST)
        mask_np = np.array(mask_l)

        # Combine the original mask with the combined mask
        intersection = ((mask_np > 0) & (combined_np > 0)).astype(np.uint8) * 255
        return Image.fromarray(intersection).convert("L")

    def __call__(self, model, clip, vae, sampler_name, scheduler, steps, mask, image) -> SegmentationMaskResult:
        """Generate a segmentation mask for the input image.

        The image is converted to a PIL image, the per-class tasks are run via
        ``create_and_submit_tasks``, and the positive-class mask minus the
        negative-class mask (floored at 0) becomes the final mask.

        Args:
            model: ComfyUI model instance
            clip: CLIP model instance
            vae: VAE model instance
            sampler_name: Name of the sampler to use
            scheduler: Scheduler for sampling
            steps: Number of steps for sampling
            mask: Region mask as a torch.Tensor (2, 3 or 4 dims), a numpy
                array or a PIL Image. It is reduced to a single-channel "L"
                image before being combined with the class masks.
            image: Input image as a torch.Tensor (3 or 4 dims, scaled from
                [0, 1] to [0, 255]), a numpy array or a PIL Image.

        Returns:
            SegmentationMaskResult: Built from the RGBA rendering of the final
            mask and the final mask itself.

        Raises:
            ValueError: If image or mask has an unsupported type, or a tensor
                with an unsupported number of dimensions.
        """
        pil_image = self._coerce_image_input(image)
        if DEBUG:
            pil_image.save("org_img.png")

        # The caller's mask is forwarded exactly as received; only the image is
        # normalised beforehand.
        # NOTE(review): the mask is only validated after the tasks have run;
        # validating first would fail faster — confirm create_and_submit_tasks
        # does not modify `mask` before reordering.
        results = self.create_and_submit_tasks(model, clip, vae, sampler_name, scheduler, steps, mask, pil_image)

        mask_image = self._coerce_mask_input(mask)
        if DEBUG:
            mask_image.save("mask_image.png")

        pos_combined_mask = self.create_combine_mask(
            mask=mask_image,
            image=pil_image,
            list_class_names=self.pos_class_names,
            list_delta_e_threshold=self.pos_delta_e_threshold,
            list_color_diff_method=self.pos_color_diff_method,
            list_mix_target_colors=self.pos_diff_target_colors,
            list_use_bbox=self.pos_use_bbox,
            list_use_blur_mask=self.pos_use_blur_mask,
            list_use_expand_mask=self.pos_use_expand_mask,
            list_expand_mask_pixel=self.pos_expand_mask_pixel,
            kontext_results=results,
        )

        # Negative classes are optional; without them the positive mask is final.
        neg_combined_mask = None
        if len(self.neg_class_names) > 0:
            neg_combined_mask = self.create_combine_mask(
                mask=mask_image,
                image=pil_image,
                list_class_names=self.neg_class_names,
                list_delta_e_threshold=self.neg_delta_e_threshold,
                list_color_diff_method=self.neg_color_diff_method,
                list_mix_target_colors=self.neg_diff_target_colors,
                list_use_bbox=self.neg_use_bbox,
                list_use_blur_mask=self.neg_use_blur_mask,
                list_use_expand_mask=self.neg_use_expand_mask,
                list_expand_mask_pixel=self.neg_expand_mask_pixel,
                kontext_results=results,
            )

        if neg_combined_mask is not None:
            # Final mask = pos_mask - neg_mask, computed in float32 so the
            # subtraction can go negative before being floored at 0.
            difference = (
                np.array(pos_combined_mask, dtype=np.float32)
                - np.array(neg_combined_mask, dtype=np.float32)
            )
            difference[difference < 0] = 0
            unified_mask = Image.fromarray(difference.astype("uint8"))
        else:
            unified_mask = pos_combined_mask

        if DEBUG:
            pos_combined_mask.save("pos_combined_mask.png")
            if neg_combined_mask is not None:
                neg_combined_mask.save("neg_combined_mask.png")
            unified_mask.save("unified_mask.png")

        # Convert to RGBA for final output
        rgba_mask = convert_mask_to_rgba(unified_mask)

        return SegmentationMaskResult(rgba_mask, unified_mask)

    @staticmethod
    def _coerce_image_input(image):
        """Return the image input (tensor, ndarray or PIL image) as a PIL image."""
        if isinstance(image, torch.Tensor):
            # A 4-dim tensor is treated as a batch and only its first entry is
            # used; a 3-dim tensor is used as-is.
            if image.ndim == 4:
                image_np = image[0].cpu().numpy()
            elif image.ndim == 3:
                image_np = image.cpu().numpy()
            else:
                raise ValueError(f"Unsupported image tensor shape: {image.shape}")
            # Convert from float [0,1] to uint8 [0,255]
            image_np = (image_np * 255).clip(0, 255).astype("uint8")
            return Image.fromarray(image_np)

        if isinstance(image, np.ndarray):
            # float32 arrays, and any array whose maximum is <= 1, are treated
            # as normalised data and scaled to [0, 255]; everything else is
            # handed to PIL unchanged.
            if image.dtype == np.float32 or image.max() <= 1.0:
                image = (image * 255).clip(0, 255).astype("uint8")
            return Image.fromarray(image)

        if isinstance(image, Image.Image):
            return image

        raise ValueError(f"Unsupported image input type: {type(image)}")

    @staticmethod
    def _coerce_mask_input(mask):
        """Return the mask input (tensor, ndarray or PIL image) as an "L" PIL image."""
        if isinstance(mask, Image.Image):
            return mask.convert("L")

        if isinstance(mask, torch.Tensor):
            # Reduce to 2 dims by taking index 0 along each leading axis.
            if mask.ndim == 4:
                mask_np = mask[0, 0].cpu().numpy()
            elif mask.ndim == 3:
                mask_np = mask[0].cpu().numpy()
            elif mask.ndim == 2:
                mask_np = mask.cpu().numpy()
            else:
                raise ValueError(f"Unsupported mask tensor shape: {mask.shape}")
        elif isinstance(mask, np.ndarray):
            mask_np = mask
            if mask_np.ndim == 3:
                # Only a leading singleton axis (with a non-singleton last
                # axis) selects plane 0 along axis 0; every other 3-dim layout
                # takes channel 0 of the last axis.
                if mask_np.shape[2] != 1 and mask_np.shape[0] == 1:
                    mask_np = mask_np[0]
                else:
                    mask_np = mask_np[:, :, 0]
        else:
            raise ValueError(f"Unsupported mask input type: {type(mask)}")

        # Any floating dtype (including float16 from half-precision tensors)
        # is converted to uint8: data with a maximum <= 1 is scaled by 255,
        # larger data is only clipped to [0, 255].
        if np.issubdtype(mask_np.dtype, np.floating):
            if mask_np.max() <= 1.0:
                mask_np = mask_np * 255
            mask_np = mask_np.clip(0, 255).astype("uint8")

        return Image.fromarray(mask_np).convert("L")
        