import os
from typing import List, Tuple, Optional
from PIL import Image, ImageDraw
import numpy as np
from google import genai
from dotenv import load_dotenv
from io import BytesIO
from google.genai import types
import time
import json
import base64
import io
import yaml
import torch
from copy import deepcopy
from dataclasses import dataclass
import concurrent.futures
from .utils import create_perceptual_difference_mask, parse_json, convert_mask_to_rgba, blur_mask, expand_mask, create_color_range_mask, bitvise_or_list_mask, align_images_phase_correlation

DEBUG = False

@dataclass(frozen=True)
class SegmentationMask:
    """Immutable record for one segmented object: bounding box, mask and label.

    The box is stored in absolute pixel coordinates of the full image. Code in
    this module uses ``y1``/``x1`` as the exclusive end of a slice
    (``array[y0:y1, x0:x1]``), so they may be one past the last covered pixel.

    ``frozen=True`` only prevents re-binding the fields; the contents of the
    ``mask`` array itself can still be modified in place.

    Args:
        y0 (int): Top coordinate in [0..height-1]
        x0 (int): Left coordinate in [0..width-1]
        y1 (int): Bottom coordinate (used as an exclusive slice end)
        x1 (int): Right coordinate (used as an exclusive slice end)
        mask (np.ndarray): Mask array of shape [img_height, img_width] with values 0..255
        label (str): Class label for the mask
    """
    y0: int  # top edge of the bounding box, in pixels
    x0: int  # left edge of the bounding box, in pixels
    y1: int  # bottom edge of the bounding box, used as an exclusive slice end
    x1: int  # right edge of the bounding box, used as an exclusive slice end
    mask: np.ndarray  # [img_height, img_width] with values 0..255
    label: str  # class label attached to this mask

@dataclass(frozen=True)
class GeminiMaskResult:
    """Immutable container for the mask images produced by the segmentor.

    Args:
        mask_image (Image.Image): PIL Image containing the mask
        mask_image_rgb (Image.Image): Second PIL Image stored with the result.
            NOTE(review): presumably an RGB rendering of the same mask - confirm
            against the code that constructs this object.
    """
    mask_image: Image.Image  # PIL Image holding the mask
    mask_image_rgb: Image.Image  # PIL Image; presumably an RGB version of the mask - TODO confirm

class GeminiSegmentor:
    def __init__(self,
                api_key: Optional[str] = None,
                model_segmentation: str = "gemini-2.5-pro-exp-03-25",
                model_color: str = "gemini-2.0-flash-exp-image-generation",
                pos_class_names: list = ['landscape'],
                pos_delta_e_threshold: list = [20.0],
                pos_color_diff_method: list = ['delta_e'],
                pos_diff_target_colors: list = [['red', 'blue']],
                pos_use_bbox=[False], 
                pos_use_expand_mask=[False],
                pos_expand_mask_pixel=[0],
                pos_use_blur_mask=[False],
                neg_class_names: list = ['building'],
                neg_delta_e_threshold: list = [20.0],
                neg_color_diff_method: list = ['color_range'],
                neg_diff_target_colors: list = [["red"]],
                neg_use_bbox=[False], 
                neg_use_expand_mask=[False],
                neg_expand_mask_pixel=[0], 
                neg_use_blur_mask=[False],
                max_workers=4
                ):
        """Initialize GeminiSegmentor with API key and segmentation parameters.

        Every ``pos_*`` list must have one entry per positive class and every
        ``neg_*`` list one entry per negative class; this is verified by
        ``assert_config`` at the end of construction.

        All list arguments are deep-copied before being stored, so neither the
        shared default lists in the signature nor lists owned by the caller are
        aliased by the instance.

        Args:
            api_key (Optional[str], optional): Gemini API key. If None, will try to load from .env file. Defaults to None.
            model_segmentation (str, optional): Gemini model name for segmentation task. Defaults to "gemini-2.5-pro-exp-03-25".
            model_color (str, optional): Gemini model name for re-color task. Defaults to "gemini-2.0-flash-exp-image-generation".
            pos_class_names (list, optional): Positive class names. Defaults to ['landscape'].
            pos_delta_e_threshold (list, optional): Delta E thresholds for positive classes. Defaults to [20.0].
            pos_color_diff_method (list, optional): Color difference methods for positive classes. Defaults to ['delta_e'].
            pos_diff_target_colors (list, optional): Target colors for positive classes. Defaults to [['red', 'blue']].
            pos_use_bbox (list, optional): Whether to use bounding box for positive classes. Defaults to [False].
            pos_use_expand_mask (list, optional): Whether to expand mask for positive classes. Defaults to [False].
            pos_expand_mask_pixel (list, optional): Expansion size for positive masks. Defaults to [0].
            pos_use_blur_mask (list, optional): Whether to blur mask for positive classes. Defaults to [False].
            neg_class_names (list, optional): Negative class names. Defaults to ['building'].
            neg_delta_e_threshold (list, optional): Delta E thresholds for negative classes. Defaults to [20.0].
            neg_color_diff_method (list, optional): Color difference methods for negative classes. Defaults to ['color_range'].
            neg_diff_target_colors (list, optional): Target colors for negative classes. Defaults to [["red"]].
            neg_use_bbox (list, optional): Whether to use bounding box for negative classes. Defaults to [False].
            neg_use_expand_mask (list, optional): Whether to expand mask for negative classes. Defaults to [False].
            neg_expand_mask_pixel (list, optional): Expansion size for negative masks. Defaults to [0].
            neg_use_blur_mask (list, optional): Whether to blur mask for negative classes. Defaults to [False].
            max_workers (int, optional): Maximum number of concurrent workers. Defaults to 4.

        Raises:
            ValueError: If no API key is provided and GEMINI_API_KEY not found in .env file,
                or if a color difference method is not supported
            AssertionError: If the per-class configuration lists have inconsistent lengths
        """
        if api_key is None:
            # Fall back to the environment; values from a .env file override
            # variables that are already set in the process.
            load_dotenv(override=True)
            api_key = os.getenv("GEMINI_API_KEY")
            if api_key is None:
                raise ValueError("No API key provided and GEMINI_API_KEY not found in .env file")

        if DEBUG:
            print('[INFO] Loading Gemini Segmentor')

        self.client = genai.Client(api_key=api_key)
        self.model_segmentation = model_segmentation
        self.model_color = model_color
        self.max_workers = max_workers

        # The defaults in the signature are created once and shared by every
        # call. Storing them (or caller-owned lists) by reference would let an
        # in-place edit on one instance leak into others, hence the copies.

        # Positive classes
        self.pos_class_names         = deepcopy(pos_class_names)
        self.pos_delta_e_threshold   = deepcopy(pos_delta_e_threshold)
        self.pos_color_diff_method   = deepcopy(pos_color_diff_method)
        self.pos_diff_target_colors  = deepcopy(pos_diff_target_colors)
        self.pos_use_bbox            = deepcopy(pos_use_bbox)
        self.pos_use_expand_mask     = deepcopy(pos_use_expand_mask)
        self.pos_expand_mask_pixel   = deepcopy(pos_expand_mask_pixel)
        self.pos_use_blur_mask       = deepcopy(pos_use_blur_mask)

        # Negative classes
        self.neg_class_names         = deepcopy(neg_class_names)
        self.neg_delta_e_threshold   = deepcopy(neg_delta_e_threshold)
        self.neg_color_diff_method   = deepcopy(neg_color_diff_method)
        self.neg_diff_target_colors  = deepcopy(neg_diff_target_colors)
        self.neg_use_bbox            = deepcopy(neg_use_bbox)
        self.neg_use_expand_mask     = deepcopy(neg_use_expand_mask)
        self.neg_expand_mask_pixel   = deepcopy(neg_expand_mask_pixel)
        self.neg_use_blur_mask       = deepcopy(neg_use_blur_mask)

        # Check config
        self.assert_config()

    def assert_config(self):
        """Check that the positive and negative class settings are coherent.

        In debug mode the full configuration is printed first. Then, for both
        the positive and the negative group, every per-class list must be as
        long as the list of class names, every color difference method must be
        a supported one and every target-colors entry must be a list.

        Raises:
            ValueError: If an invalid color difference method is specified
            AssertionError: If configuration lists have inconsistent lengths
                or a target-colors entry is not a list
        """
        if DEBUG:
            # Dump the whole configuration, one labelled line per setting
            print("[INFO] GEMINI CONFIG")
            print("  - model_segmentation: ", self.model_segmentation)
            print("  - model_color: ", self.model_color)
            positive_report = [
                ("  - pos_class_names: ", self.pos_class_names),
                ("  - pos_delta_e_threshold: ", self.pos_delta_e_threshold),
                ("  - pos_color_diff_method: ", self.pos_color_diff_method),
                ("  - pos_diff_target_colors: ", self.pos_diff_target_colors),
                ("  - pos_use_bbox: ", self.pos_use_bbox),
                ("  - pos_use_expand_mask: ", self.pos_use_expand_mask),
                ("  - pos_expand_mask_pixel: ", self.pos_expand_mask_pixel),
                ("  - pos_use_blur_mask: ", self.pos_use_blur_mask),
            ]
            negative_report = [
                ("  - neg_class_names: ", self.neg_class_names),
                ("  - neg_delta_e_threshold: ", self.neg_delta_e_threshold),
                ("  - neg_color_diff_method: ", self.neg_color_diff_method),
                ("  - neg_diff_target_colors: ", self.neg_diff_target_colors),
                ("  - neg_use_bbox: ", self.neg_use_bbox),
                ("  - neg_use_expand_mask: ", self.neg_use_expand_mask),
                ("  - neg_expand_mask_pixel: ", self.neg_expand_mask_pixel),
                ("  - neg_use_blur_mask: ", self.neg_use_blur_mask),
            ]
            for report in (positive_report, negative_report):
                print('-'*10)
                for caption, value in report:
                    print(caption, value)

        groups = ("pos", "neg")

        # Every per-class list has to line up with the class names of its group
        per_class_settings = (
            "delta_e_threshold",
            "color_diff_method",
            "diff_target_colors",
            "use_bbox",
            "use_blur_mask",
            "use_expand_mask",
            "expand_mask_pixel",
        )
        for group in groups:
            for setting in per_class_settings:
                assert len(getattr(self, group + "_class_names")) == len(getattr(self, group + "_" + setting))

        # Only known color difference methods, each paired with a list of colors
        supported_methods = ('delta_e', 'color_range')
        for group in groups:
            methods = getattr(self, group + "_color_diff_method")
            target_colors = getattr(self, group + "_diff_target_colors")
            for position, method in enumerate(methods):
                if method not in supported_methods:
                    raise ValueError("Invalid color difference method '{}', current support only 'delta_e' and 'color_range'".format(method))
                assert isinstance(target_colors[position], list)

    def from_config(self, filepath):
        """Load segmentation configuration from a YAML file.

        The file must provide ``model_segmentation``, ``model_color`` and a
        ``segments`` list (``max_workers`` is optional, default 4). Each
        segment needs ``id``, ``name``, ``mask_type`` ("pos" or "neg"),
        ``classes`` (list of class names) and ``color_diff_method``
        ("delta_e" or "color_range"); the remaining per-segment keys are
        optional. Every class listed in a segment receives that segment's
        settings. All current positive/negative settings of the instance are
        replaced.

        Args:
            filepath (str): Path to the configuration YAML file

        Raises:
            ValueError: If the configuration file is empty
            KeyError: If a mandatory key is missing from the file
            AssertionError: If ``max_workers``, ``mask_type`` or
                ``color_diff_method`` holds an invalid value
        """
        # CLoader is only exposed by PyYAML builds that include the LibYAML
        # bindings; fall back to the pure-Python loader instead of failing
        # with AttributeError on other installations.
        # NOTE(review): Loader/CLoader are PyYAML's full loaders. If config
        # files can come from an untrusted source, switch to SafeLoader.
        yaml_loader = getattr(yaml, "CLoader", None) or yaml.Loader
        with open(filepath, "r") as f:
            config = yaml.load(f, Loader=yaml_loader)
        if not config:
            raise ValueError("received empty config file")

        # Global config
        self.model_segmentation = config["model_segmentation"]
        self.model_color = config["model_color"]
        self.max_workers = config.get("max_workers", 4)
        assert self.max_workers > 0, "Invalid max worker: {}".format(self.max_workers)

        # Positive classes
        self.pos_class_names         = []
        self.pos_delta_e_threshold   = []
        self.pos_color_diff_method   = []
        self.pos_diff_target_colors  = []
        self.pos_use_bbox            = []
        self.pos_use_blur_mask       = []
        self.pos_use_expand_mask     = []
        self.pos_expand_mask_pixel   = []

        # Negative classes
        self.neg_class_names         = []
        self.neg_delta_e_threshold   = []
        self.neg_color_diff_method   = []
        self.neg_diff_target_colors  = []
        self.neg_use_bbox            = []
        self.neg_use_blur_mask       = []
        self.neg_use_expand_mask     = []
        self.neg_expand_mask_pixel   = []

        for s in config["segments"]:
            # "id" and "name" are not used below; they are still read so that
            # a segment lacking them is rejected exactly as before.
            _segment_id = s["id"] # 1, 2, 3 ...
            _segment_name = s["name"] # floor1, landscape2, ...
            mask_type = s["mask_type"] # pos / neg
            assert mask_type in ["pos", "neg"], "Invalid mask_type '{}', current support only 'pos' and 'neg'".format(mask_type)
            classes = s["classes"] # LIST: floor, cabinets, ...
            color_diff_method = s["color_diff_method"] # delta_e / color_range
            assert color_diff_method in ["delta_e", "color_range"], "Invalid color difference method '{}', current support only 'delta_e' and 'color_range'".format(color_diff_method)

            # Settings shared by every class of this segment, keyed by the
            # attribute suffix they are appended to.
            segment_settings = {
                "delta_e_threshold": s.get("delta_e_threshold", 25.0),
                "color_diff_method": color_diff_method,
                "diff_target_colors": s.get("diff_target_colors", ["red"]),
                "use_bbox": s.get("use_bbox", False),
                "use_expand_mask": s.get("use_expand_mask", False),
                "use_blur_mask": s.get("use_blur_mask", False),
                "expand_mask_pixel": s.get("expand_mask_pixel", 10),
            }

            # Anything that is not "pos" goes to the negative group, which is
            # what the original if/else did.
            prefix = "pos" if mask_type == "pos" else "neg"
            for class_name in classes:
                getattr(self, prefix + "_class_names").append(class_name)
                for suffix, value in segment_settings.items():
                    getattr(self, prefix + "_" + suffix).append(value)

        # Make sure all config is valid
        self.assert_config()


    def _parse_segmentation_masks(self,
                                predicted_str: str,
                                img_height: int,
                                img_width: int) -> List[SegmentationMask]:
        """Parse segmentation masks from JSON string.

        Each item is expected to carry ``box_2d`` as ``[y0, x0, y1, x1]`` on a
        0..1000 scale, ``mask`` as a base64 PNG data URI and ``label``. The PNG
        is resized to the bounding box and pasted into an otherwise empty
        full-image array.

        The text comes from a model, so it is treated as unreliable: an item
        that is malformed, has an empty box, or whose mask cannot be decoded
        is reported with ``print`` and skipped instead of aborting the whole
        parse. Boxes reaching outside the image are clipped to the image.

        Args:
            predicted_str (str): JSON string containing mask data
            img_height (int): Height of the original image
            img_width (int): Width of the original image

        Returns:
            List[SegmentationMask]: List of parsed segmentation masks (empty
            when the text is not valid JSON or not a JSON list)
        """
        try:
            items = json.loads(parse_json(predicted_str))
        except ValueError as e:
            import logging
            logging.warning(f"_parse_segmentation_masks: Skipping output, reason: {e}")
            return []
        if not isinstance(items, list):
            # Iterating a dict or a string would only produce confusing errors
            print("Invalid segmentation output, expected a JSON list")
            return []

        png_prefix = "data:image/png;base64,"
        masks = []

        for item in items:
            try:
                raw_box = item["box_2d"]
                label = item["label"]
                png_str = item["mask"]
                # Box values are relative to a 1000 x 1000 grid
                abs_y0 = int(raw_box[0] / 1000 * img_height)
                abs_x0 = int(raw_box[1] / 1000 * img_width)
                abs_y1 = int(raw_box[2] / 1000 * img_height)
                abs_x1 = int(raw_box[3] / 1000 * img_width)
            except (KeyError, IndexError, TypeError, ValueError):
                print("Invalid segmentation item", item)
                continue

            if abs_y0 >= abs_y1 or abs_x0 >= abs_x1:
                print("Invalid bounding box", raw_box)
                continue

            if not isinstance(png_str, str) or not png_str.startswith(png_prefix):
                print("Invalid mask")
                continue

            # Part of the box that actually lies on the image. Without this,
            # a box leaving the image makes the paste below fail on a shape
            # mismatch (or wrap around for negative coordinates).
            clip_y0, clip_y1 = max(abs_y0, 0), min(abs_y1, img_height)
            clip_x0, clip_x1 = max(abs_x0, 0), min(abs_x1, img_width)
            if clip_y0 >= clip_y1 or clip_x0 >= clip_x1:
                print("Invalid bounding box", raw_box)
                continue

            try:
                mask = Image.open(io.BytesIO(base64.b64decode(png_str.removeprefix(png_prefix))))
                # Force single-band 8-bit data so the array is 2-D with values
                # 0..255 whatever mode the PNG was saved in. NEAREST keeps
                # the original values (no interpolated in-between values).
                mask = mask.convert("L").resize((abs_x1 - abs_x0, abs_y1 - abs_y0),
                                                resample=Image.Resampling.NEAREST)
            except (ValueError, OSError, MemoryError):
                print("Invalid mask")
                continue

            box_mask = np.array(mask)
            np_mask = np.zeros((img_height, img_width), dtype=np.uint8)
            # Paste only the visible part of the resized mask
            np_mask[clip_y0:clip_y1, clip_x0:clip_x1] = box_mask[
                clip_y0 - abs_y0:clip_y1 - abs_y0,
                clip_x0 - abs_x0:clip_x1 - abs_x0,
            ]

            masks.append(SegmentationMask(clip_y0, clip_x0, clip_y1, clip_x1, np_mask, label))

        return masks

    def _plot_segmentation_white_masks(self,
                                    img: Image.Image,
                                    segmentation_masks: List[SegmentationMask]) -> Image.Image:
        """Render the bounding boxes of all masks as filled rectangles.

        A new image with the mode and size of ``img`` is created with every
        pixel set to 0, and the bounding box of each mask is filled with the
        value 255.

        NOTE(review): 255 is white on a single-band image; for multi-band
        modes such as RGB an integer fill may not mean white - confirm which
        image modes callers pass in.

        Args:
            img (Image.Image): Original image, used only for its mode and size
            segmentation_masks (List[SegmentationMask]): Masks whose bounding boxes are drawn

        Returns:
            Image.Image: The rendered box image
        """
        canvas = Image.new(img.mode, img.size, 0)
        pen = ImageDraw.Draw(canvas)

        for segment in segmentation_masks:
            top_left = (segment.x0, segment.y0)
            bottom_right = (segment.x1, segment.y1)
            pen.rectangle((top_left, bottom_right), fill=255)

        return canvas

    def request_gemini(self,
                        task_type: str,
                        task_class_name: str,
                        task_model_name: str,
                        task_response_modalities: list,
                        task_output_type: str,
                        task_data: list,
                        attempts: int = 2,
                        apply_quality_filtering: bool = True):
        """Perform a request to Gemini API and handle the response.

        Only one situation triggers another request: an image task whose type
        starts with ``recolor-`` with quality filtering enabled. There the
        returned image is aligned against ``task_data[0]``; a shift of more
        than one pixel on either axis (or a response without an image) causes
        a retry, up to ``attempts`` requests in total. If no attempt is good
        enough, the image with the smallest shift is returned (None if no
        attempt produced an image). For every other task exactly one request
        is made. Exceptions raised by the API call are not caught or retried.

        Args:
            task_type (str): Type of task (segment, recolor-red, etc.)
            task_class_name (str): Name of the class to process
            task_model_name (str): Name of the Gemini model to use
            task_response_modalities (list): Response modalities (Text, Image, etc.)
            task_output_type (str): Expected output type ('image' or 'str')
            task_data (list): Input data for the API; for recolor tasks the first
                element is the image the result is compared against
            attempts (int, optional): Maximum number of requests. Defaults to 2.
            apply_quality_filtering (bool, optional): Whether recolor results are
                checked for misalignment and retried. Defaults to True.

        Returns:
            dict: Task result containing task_type, task_class_name, and task_response
            (a PIL image or None for image tasks, the response text otherwise)

        Raises:
            AssertionError: If attempts is not positive or expected content is missing
        """
        assert attempts > 0, "Invalid attempts: {}".format(attempts)
        if DEBUG:
            print('-'*10)
            print('  - Task type:', task_type)
            print('  - Task class name:', task_class_name)
            print('  - Task response modalities:', task_response_modalities)

        wants_image = task_output_type == 'image'
        quality_filtering = apply_quality_filtering and task_type.startswith('recolor-')

        # Least-misaligned recolor result seen so far. Infinity (rather than a
        # large magic number) guarantees the first measured result is kept.
        best_response_image = None
        best_translate_dx_dy = float("inf")

        response = None
        start_time = time.time()
        for _ in range(attempts):
            response = self.client.models.generate_content(
                model=task_model_name,
                contents=task_data,
                config=types.GenerateContentConfig(
                    response_modalities=task_response_modalities
                )
            )
            if not wants_image:
                break
            assert response.candidates[0].content is not None # Make sure that gemini has output image
            if not quality_filtering:
                break

            response_image = self._extract_first_inline_image(response)
            if response_image is None:
                # Nothing to evaluate, ask again
                continue

            # Calculate the quality of the recolored image by aligning it with the original image
            original_image = task_data[0]
            _, translate_dx_dy = align_images_phase_correlation(original_image, response_image)
            max_translate = max(abs(translate_dx_dy[0]), abs(translate_dx_dy[1]))
            if max_translate > 1.0:
                if max_translate < best_translate_dx_dy:
                    best_response_image = response_image
                    best_translate_dx_dy = max_translate
                print(f"Recolored image quality is low, translation: {translate_dx_dy}, retrying...")
                continue

            best_response_image = response_image
            best_translate_dx_dy = max_translate
            print(f"Recolored image quality is good, translation: {translate_dx_dy}")
            break  # Exit the loop if quality is acceptable
        end_time = time.time()
        print(f"Time taken for gemini calling: {end_time - start_time} seconds")

        if not wants_image:
            return {
                "task_type": task_type,
                "task_class_name": task_class_name,
                "task_response": response.text
            }

        if quality_filtering:
            response_image = best_response_image
        else:
            response_image = self._extract_first_inline_image(response)
        return {
            "task_type": task_type,
            "task_class_name": task_class_name,
            "task_response": response_image
        }

    @staticmethod
    def _extract_first_inline_image(response):
        """Return the first inline image of the first candidate, or None if there is none."""
        for part in response.candidates[0].content.parts:
            if part.inline_data is None:
                continue
            payload = part.inline_data.data
            try:
                return Image.open(BytesIO(payload))
            except Exception:
                # The payload could not be opened as-is; retry assuming it is
                # base64 encoded. An error from this second try propagates.
                return Image.open(BytesIO(base64.b64decode(payload)))
        return None

    def submit_task(self, task):
        """Run one task description through ``request_gemini``.

        Single-argument adapter so that tasks can be dispatched with
        ``executor.map``.

        Args:
            task (dict): Keyword arguments for ``request_gemini``

        Returns:
            dict: Whatever ``request_gemini`` returns for this task
        """
        outcome = self.request_gemini(**task)
        return outcome

    def create_and_submit_tasks(self, image: Image.Image) -> List[dict]:
        """Generate tasks for segmentation and recoloring and submit to Gemini.
        
        Args:
            image (Image.Image): Image to process
            
        Returns:
            List[dict]: List of Gemini API results
        """
        tasks = []

        # 1. Segment image task
        use_bbox_class_names = [clss for ix, clss in enumerate(self.pos_class_names) if self.pos_use_bbox[ix]]
        use_bbox_class_names += [clss for ix, clss in enumerate(self.neg_class_names) if self.neg_use_bbox[ix]]
        if len(use_bbox_class_names) > 0:
            prompt = f"""
            Give the segmentation masks for the following classes if they appear in the image:
            {', '.join(use_bbox_class_names)}

            Output a JSON list of segmentation masks where each entry contains:
            1. The 2D bounding box in the key "box_2d"
            2. The segmentation mask in key "mask"
            3. The text label in the key "label" (use exactly one of the class names listed above)

            Only include objects that clearly match one of the specified classes.
            """
            tasks.append({
                "task_type": "segment",
                "task_class_name": "all",
                "task_model_name": self.model_segmentation,
                "task_response_modalities": ['Text'],
                "task_output_type": "str",
                "task_data": [image, prompt],
                "attempts": 3
            })

        # 2. Recolor task for positive object
        resize_image = self.resize_keep_aspect(image)
        for class_id, class_name in enumerate(self.pos_class_names):
            for target_color in self.pos_diff_target_colors[class_id]:
                # prompt = f"Change {class_name} color to {target_color}."
                prompt = f"only change the {class_name} color to {target_color}. Keep the color of all other objects "
                tasks.append({
                    "task_type": "recolor-{}".format(target_color),
                    "task_class_name": class_name,
                    "task_model_name": self.model_color,
                    "task_response_modalities": ['Text', 'Image'],
                    "task_output_type": "image",
                    "task_data": [deepcopy(resize_image), prompt],
                    "attempts": 3
                    })

        # 3. Recolor task for negative object
        for class_id, class_name in enumerate(self.neg_class_names):
            for target_color in self.neg_diff_target_colors[class_id]:
                prompt = f"only change the {class_name} color to {target_color}. Keep the color of all other objects"
                # prompt = f"ONLY change the color of the {class_name} in the provided image to {target_color}. Keep rest AND image size unchanged."
                tasks.append({
                    "task_type": "recolor-{}".format(target_color),
                    "task_class_name": class_name,
                    "task_model_name": self.model_color,
                    "task_response_modalities": ['Text', 'Image'],
                    "task_output_type": "image",
                    "task_data": [deepcopy(resize_image), prompt],
                    "attempts": 3
                    })

        # 4. Parallel call
        with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            gemini_results = executor.map(self.submit_task, tasks)

        return list(gemini_results)
            

    def segment_image(self,
                    image: Image.Image,
                    gemini_results: List[dict]) -> Tuple[Optional[List[SegmentationMask]], Optional[Image.Image]]:
        """Turn the segmentation response for ``image`` into masks and a preview.

        The single result with ``task_type == "segment"`` is looked up in
        ``gemini_results`` and parsed into segmentation masks. A copy of the
        image is annotated with one colored outline (and the label, if any)
        per mask.

        Args:
            image (Image.Image): Image to segment
            gemini_results (List[dict]): Results from Gemini API calls

        Returns:
            Tuple containing two elements:
            - List[SegmentationMask]: List of segmentation masks
            - Image.Image: Visualization image with colored bounding boxes
            Both elements are None when no class is configured to use a
            bounding box.

        Raises:
            AssertionError: If ``gemini_results`` does not contain exactly one
                segmentation result
        """
        # Gather object for bounding box segmentation
        use_bbox_class_names = [clss for ix, clss in enumerate(self.pos_class_names) if self.pos_use_bbox[ix]]
        use_bbox_class_names += [clss for ix, clss in enumerate(self.neg_class_names) if self.neg_use_bbox[ix]]
        if DEBUG:
            print('[INFO] Bounding box segmentation: ', use_bbox_class_names)
        if len(use_bbox_class_names) == 0:
            # No segmentation task was requested, so there is nothing to parse
            return None, None
        img_width, img_height = image.size

        segment_results = [res for res in gemini_results if res["task_type"] == "segment"]
        assert len(segment_results) == 1
        response_text = segment_results[0]["task_response"]

        segmentation_masks = self._parse_segmentation_masks(response_text, img_height=img_height, img_width=img_width)

        # Create colored visualization
        visualize_img = image.copy()
        draw = ImageDraw.Draw(visualize_img)
        colors = ['red', 'green', 'blue', 'yellow', 'orange', 'pink', 'purple', 'brown', 'gray']

        for i, mask in enumerate(segmentation_masks):
            # Cycle through the palette when there are more masks than colors
            color = colors[i % len(colors)]
            draw.rectangle(
                ((mask.x0, mask.y0), (mask.x1, mask.y1)),
                outline=color,
                width=4
            )
            if mask.label:
                # Label is placed 20 px above the top-left corner of the box
                draw.text((mask.x0 + 8, mask.y0 - 20), mask.label, fill=color)

        return segmentation_masks, visualize_img

    def resize_keep_aspect(self,
                            image: Image.Image,
                            max_size: int = 1024):
        """Resize image while maintaining aspect ratio.

        The image is scaled by a single factor so that its longer side becomes
        ``max_size``; images smaller than that are scaled up. The scaled sides
        are truncated to whole pixels but never below 1, so that very elongated
        images do not end up with an empty side that cannot be resized to.

        Args:
            image (Image.Image): Image to resize
            max_size (int, optional): Maximum dimension (width or height). Defaults to 1024.

        Returns:
            Image.Image: Resized image
        """
        width, height = image.size
        ratio = min(max_size/width, max_size/height)
        # Truncation can reach 0 for extreme aspect ratios (e.g. 4096 x 1)
        new_width = max(1, int(width*ratio))
        new_height = max(1, int(height*ratio))
        return image.resize((new_width, new_height))

    def change_object_color(self,
                        image: Image.Image,
                        class_name: str,
                        color_diff_method: str,
                        delta_e_threshold: float,
                        use_bbox: bool,
                        use_expand_mask: bool,
                        expand_mask_pixel: int,
                        use_blur_mask: bool,
                        segmentation_masks: List[SegmentationMask],
                        gemini_results: List[dict],
                        target_color: str
                        ) -> Tuple[Image.Image, Image.Image]:
        """Derive the mask of a class from its recolored image.

        The recolored image for (``class_name``, ``target_color``) is taken
        from ``gemini_results`` and compared with a resized copy of ``image``
        to obtain a mask of the changed pixels. Image and mask are scaled back
        to the size of ``image``; the mask is then optionally restricted to the
        bounding boxes of the class, expanded and blurred, in that order.

        Args:
            image (Image.Image): Original image
            class_name (str): Name of the class to modify
            color_diff_method (str): Method for color difference detection ('delta_e' or 'color_range')
            delta_e_threshold (float): Threshold for delta_e method
            use_bbox (bool): Whether to clip mask by bounding box
            use_expand_mask (bool): Whether to expand the mask boundary
            expand_mask_pixel (int): Number of pixels to expand mask
            use_blur_mask (bool): Whether to blur the mask
            segmentation_masks (List[SegmentationMask]): List of segmented masks
            gemini_results (List[dict]): Results from Gemini API calls
            target_color (str): Target color for object

        Returns:
            Tuple[Image.Image, Image.Image]: Modified image and color mask

        Raises:
            ValueError: If no modified image received or color difference method is invalid
            AssertionError: If there is not exactly one matching recolor result
        """
        original_size = image.size
        reference_image = self.resize_keep_aspect(image)

        # Exactly one recolor result must exist for this class / color pair
        candidates = [
            res for res in gemini_results
            if res["task_type"] == "recolor-{}".format(target_color) and res["task_class_name"] == class_name
        ]
        assert len(candidates) == 1 # Only 1 result
        modified_image = candidates[0]["task_response"]
        if modified_image is None:
            raise ValueError("No modified image received")

        if DEBUG:
            modified_image.save("modified_image_{}_{}.png".format(class_name, target_color))

        # Mask of the pixels that differ between reference and recolored image
        if color_diff_method not in ('delta_e', 'color_range'):
            raise ValueError("Invalid color difference calculation method")
        if color_diff_method == 'delta_e':
            color_mask = create_perceptual_difference_mask(reference_image,
                                                        modified_image,
                                                        delta_e_threshold=delta_e_threshold)
        else:
            color_mask = create_color_range_mask(reference_image,
                                                modified_image,
                                                target_color)

        # Back to the original resolution; NEAREST keeps the mask free of
        # interpolated values
        modified_image = modified_image.resize(original_size)
        color_mask = color_mask.resize(original_size, Image.Resampling.NEAREST)

        if use_bbox:
            # Keep the mask only inside the bounding boxes labelled with this class
            mask_values = np.array(color_mask)
            clipped_values = np.array(Image.new(color_mask.mode, color_mask.size, 0))
            for segment in segmentation_masks:
                if segment.label != class_name:
                    continue
                rows = slice(segment.y0, segment.y1)
                cols = slice(segment.x0, segment.x1)
                clipped_values[rows, cols] = np.maximum(clipped_values[rows, cols], mask_values[rows, cols])
            color_mask = Image.fromarray(clipped_values.astype('uint8'))

        if use_expand_mask:
            color_mask = expand_mask(color_mask, radius=int(expand_mask_pixel))
            if DEBUG:
                color_mask.save("expanded_mask_{}.png".format(class_name))

        if use_blur_mask:
            color_mask = blur_mask(color_mask, radius=5)
            if DEBUG:
                color_mask.save("blurred_mask_{}.png".format(class_name))

        return modified_image, color_mask
    
    def create_combine_mask(self, image: Image.Image,
                            list_class_names: list,
                            list_delta_e_threshold: list,
                            list_color_diff_method: list,
                            list_mix_target_colors: list,
                            list_use_bbox: list,
                            list_use_expand_mask: list,
                            list_expand_mask_pixel: list,
                            list_use_blur_mask: list,
                            segmentation_masks: List[SegmentationMask],
                            gemini_results: List[dict],
                            ):
        """Build a single mask covering every class in ``list_class_names``.

        Every ``list_*`` argument is a parallel list: the entry at position
        ``i`` configures the class at ``list_class_names[i]``. For each class,
        one mask is produced per target color via ``change_object_color``; the
        per-color masks are merged with ``bitvise_or_list_mask`` and the
        per-class results are merged with a pixel-wise maximum.

        Args:
            image (Image.Image): Original image (a deep copy is handed to each recolor step)
            list_class_names (list): Class names to process, in order
            list_delta_e_threshold (list): delta_e threshold per class
            list_color_diff_method (list): Color difference method per class
            list_mix_target_colors (list): Iterable of target colors per class
            list_use_bbox (list): Bounding-box clipping flag per class
            list_use_expand_mask (list): Mask expansion flag per class
            list_expand_mask_pixel (list): Mask expansion size per class
            list_use_blur_mask (list): Mask blur flag per class
            segmentation_masks (List[SegmentationMask]): Masks forwarded to ``change_object_color``
            gemini_results (List[dict]): Gemini results forwarded to ``change_object_color``

        Returns:
            Image.Image: The merged mask, or None when ``list_class_names`` is empty
        """
        merged_mask = None

        for idx, class_name in enumerate(list_class_names):
            # One mask per requested recolor of this class
            per_color_masks = []
            for color in list_mix_target_colors[idx]:
                # Only the mask is needed here; the recolored image is discarded
                _, recolor_mask = self.change_object_color(
                    deepcopy(image),
                    class_name,
                    target_color=color,
                    delta_e_threshold=list_delta_e_threshold[idx],
                    color_diff_method=list_color_diff_method[idx],
                    use_bbox=list_use_bbox[idx],
                    use_expand_mask=list_use_expand_mask[idx],
                    use_blur_mask=list_use_blur_mask[idx],
                    expand_mask_pixel=list_expand_mask_pixel[idx],
                    segmentation_masks=segmentation_masks,
                    gemini_results=gemini_results,
                )
                per_color_masks.append(recolor_mask)
                if DEBUG:
                    recolor_mask.save(f"{class_name}_{color}_mask.png")

            # Collapse the per-color masks into one mask for this class
            class_mask = bitvise_or_list_mask(per_color_masks)
            if DEBUG:
                class_mask.save(f"{class_name}_color_mask.png")

            # The first class seeds the result as-is
            if merged_mask is None:
                merged_mask = class_mask
                continue

            # Later classes are folded in with a pixel-wise maximum
            stacked = np.maximum(np.array(merged_mask), np.array(class_mask))
            merged_mask = Image.fromarray(stacked)

        return merged_mask

    def __call__(self, image) -> GeminiMaskResult:
        """Generate the final segmentation mask for an image.

        The positive-class mask is computed first; when negative classes are
        configured their mask is subtracted from it (clamped at zero).

        Args:
            image: Input image as one of:
                - ``torch.Tensor`` with 3 dimensions, or 4 dimensions in which
                  case only the first item is used. Values are multiplied by
                  255 and clipped, so float data in [0, 1] is assumed.
                - ``np.ndarray``. float32 arrays, or arrays whose maximum is
                  <= 1.0, are multiplied by 255 and cast to uint8; other
                  arrays are passed to PIL unchanged.
                - ``PIL.Image.Image``, used as-is.

        Returns:
            GeminiMaskResult: ``mask_image`` is the RGBA conversion of the
            final mask and ``mask_image_rgb`` is the final mask itself.

        Raises:
            ValueError: If the input type is unsupported or a tensor has a
                number of dimensions other than 3 or 4
        """
        # Normalise every supported input type to a PIL image
        if isinstance(image, torch.Tensor):
            # Convert from (1, H, W, C) or (B, H, W, C) to (H, W, C)
            if image.ndim == 4:
                image_np = image[0].cpu().numpy()
            elif image.ndim == 3:
                image_np = image.cpu().numpy()
            else:
                raise ValueError(f"Unsupported image tensor shape: {image.shape}")
            # Convert from float [0,1] to uint8 [0,255]
            image_np = (image_np * 255).clip(0, 255).astype("uint8")
            pil_image = Image.fromarray(image_np)
        elif isinstance(image, np.ndarray):
            # Heuristic: float32 data or values within [0, 1] are treated as normalised
            if image.dtype == np.float32 or image.max() <= 1.0:
                image = (image * 255).clip(0, 255).astype("uint8")
            pil_image = Image.fromarray(image)
        elif isinstance(image, Image.Image):
            pil_image = image
        else:
            raise ValueError(f"Unsupported image input type: {type(image)}")

        # Debug artefacts are written only when DEBUG is enabled, so a normal
        # call does not touch the working directory
        if DEBUG:
            pil_image.save("org_img.png")

        # Collect tasks and request to gemini
        gemini_results = self.create_and_submit_tasks(pil_image)

        # Get segmentation masks and modified image using bounding boxes
        segmentation_masks, colored_img = self.segment_image(image = pil_image, gemini_results = gemini_results)
        if DEBUG and colored_img is not None:
            colored_img.save("colored_img.png")

        # Create a combined mask for all positive classes
        pos_combined_mask = self.create_combine_mask(image = pil_image,
                                                list_class_names = self.pos_class_names,
                                                list_delta_e_threshold = self.pos_delta_e_threshold,
                                                list_color_diff_method = self.pos_color_diff_method,
                                                list_mix_target_colors = self.pos_diff_target_colors,
                                                list_use_bbox = self.pos_use_bbox,
                                                list_use_blur_mask = self.pos_use_blur_mask,
                                                list_use_expand_mask = self.pos_use_expand_mask,
                                                list_expand_mask_pixel = self.pos_expand_mask_pixel,
                                                segmentation_masks = segmentation_masks,
                                                gemini_results = gemini_results
                                                )

        # Create a combined mask for all negative classes (if any are configured)
        neg_combined_mask = None
        if len(self.neg_class_names) > 0:
            neg_combined_mask = self.create_combine_mask(image = pil_image,
                                                    list_class_names = self.neg_class_names,
                                                    list_delta_e_threshold = self.neg_delta_e_threshold,
                                                    list_color_diff_method = self.neg_color_diff_method,
                                                    list_mix_target_colors = self.neg_diff_target_colors,
                                                    list_use_bbox = self.neg_use_bbox,
                                                    list_use_blur_mask = self.neg_use_blur_mask,
                                                    list_use_expand_mask = self.neg_use_expand_mask,
                                                    list_expand_mask_pixel = self.neg_expand_mask_pixel,
                                                    segmentation_masks = segmentation_masks,
                                                    gemini_results = gemini_results
                                                    )

        if neg_combined_mask is not None:
            # Final mask = pos_mask - neg_mask, computed in float32 so the
            # subtraction can go negative before being clamped at zero
            pos_array = np.array(pos_combined_mask, dtype = np.float32)
            neg_array = np.array(neg_combined_mask, dtype = np.float32)
            unified_mask = np.maximum(pos_array - neg_array, 0)
            unified_mask = Image.fromarray(unified_mask.astype("uint8"))
        else:
            unified_mask = pos_combined_mask

        if DEBUG:
            pos_combined_mask.save("pos_combined_mask.png")
            if neg_combined_mask is not None:
                neg_combined_mask.save("neg_combined_mask.png")
            unified_mask.save("unified_mask.png")

        # Convert to RGBA for final output
        rgba_mask = convert_mask_to_rgba(unified_mask)

        return GeminiMaskResult(rgba_mask, unified_mask)


if __name__ == "__main__":
    # Manual smoke run: segment 'roof tiles' in a local sample image and
    # write the resulting mask to the working directory.
    positive_config = {
        "pos_class_names": ['roof tiles'],
        "pos_delta_e_threshold": [20.0],
        "pos_color_diff_method": ['delta_e'],
        "pos_diff_target_colors": [['red', 'blue']],
        "pos_use_bbox": [False],
        "pos_use_blur_mask": [False],
        "pos_use_expand_mask": [True],
        "pos_expand_mask_pixel": [10],
    }
    # No negative classes in this run: every negative option is an empty list
    negative_config = {
        option: []
        for option in (
            "neg_class_names",
            "neg_delta_e_threshold",
            "neg_color_diff_method",
            "neg_diff_target_colors",
            "neg_use_bbox",
            "neg_use_blur_mask",
            "neg_use_expand_mask",
            "neg_expand_mask_pixel",
        )
    }
    segmentor = GeminiSegmentor(
        model_segmentation="gemini-2.5-pro-exp-03-25",
        model_color="gemini-2.0-flash-exp-image-generation",
        max_workers=1,
        **positive_config,
        **negative_config,
    )

    source_image = Image.open("/workspace/Owens.jpg")
    # Resizing to (1280, 1024) is currently disabled; the image is saved as loaded
    source_image.save("resize_img.png")

    result = segmentor(source_image)
    result.mask_image.save("gemini_mask.png")
