"""paint operations and pipelines"""

import logging

# NOTE(review): basicConfig runs at import time and configures the process-wide root
# logger (it is a no-op if the root logger already has handlers).
logging.basicConfig(level=logging.ERROR)
# Module logger; currently unused in this file.
logger = logging.getLogger(__name__)
import os
import cv2
import yaml
import numpy as np
from enum import Enum
from PIL import Image
from typing import Any, Dict, Optional, Tuple
import torch
from PIL import Image  
from .operations import (
    canny,
    depth,
    mlsd,
)

from .paint_utils import (
    LAB_MAX_L,
    LAB_MAX_A,
    LAB_MAX_B,
    LAB_MIN_L,
    LAB_MIN_A,
    LAB_MIN_B,
    adjust_chroma,
    adjust_lighting,
    adjust_lighting_contrast,
    get_denoised_lighting_mask,
    get_edge_mask,
    get_edges_without_surface_details,
    get_prompt_details,
    rgb_to_lab,
    set_rgb,
)

# Default paint config bundled with this package, resolved relative to this module's directory.
DEFAULT_CONFIG_PATH = f"{os.path.dirname(__file__)}/configs/paint/default.yaml"
from .util import  tmpname

class OperationCheckpoint:
    """Collect the file paths of intermediate images produced by the paint pipeline.

    ``state`` maps checkpoint keys to image paths. ``save_idx`` is a running counter
    embedded in "extra" keys (``extra-NN`` / ``extra-NN-<name>``) so the keys reflect
    the order in which results were saved.
    """

    def __init__(self):
        self.save_idx = 0
        self.state = {}

    def save_image(
        self,
        img_arr,
        tmp_prefix="paint-",
        ckpt_name=None,
        is_extra=True,
        update_idx=True,
        convert_from_lab=False,
    ):
        """Write ``img_arr`` to a temporary PNG and record it as a checkpoint.

        Args:
            img_arr: Image array accepted by ``PIL.Image.fromarray``; when
                ``convert_from_lab`` is set, a LAB image that cv2 can convert to
                float RGB (presumably float32 LAB — TODO confirm at call sites).
            tmp_prefix: Prefix for the temporary file name.
            ckpt_name: Optional checkpoint name (see ``save_result``).
            is_extra: Whether the key gets the ``extra-NN-`` prefix.
            update_idx: Whether to advance ``save_idx`` after saving.
            convert_from_lab: Convert LAB to 8-bit RGB before saving.

        Returns:
            Path of the written PNG.
        """
        if convert_from_lab:
            rgb = cv2.cvtColor(img_arr, cv2.COLOR_LAB2RGB)
            # Float RGB is nominally in [0, 1], but out-of-gamut LAB values can land
            # outside it; clip so they saturate instead of wrapping in the uint8 cast.
            img_arr = np.clip(255 * rgb, 0, 255).astype(np.uint8)
        img = Image.fromarray(img_arr).convert("RGB")
        image_path = tmpname(suffix=".png", prefix=tmp_prefix)
        img.save(image_path)
        self.save_result(
            image_path, name=ckpt_name, is_extra=is_extra, update_idx=update_idx
        )
        return image_path

    def save_result(self, image_path, name=None, is_extra=True, update_idx=True):
        """Record ``image_path`` in ``state``.

        Key is ``extra-NN-<name>`` for extra named results, ``<name>`` for named
        non-extra results, and ``extra-NN`` when no name is given (``NN`` is
        ``save_idx``, zero-padded to two digits).
        """
        idx = self.save_idx
        if not name:
            key = f"extra-{idx:02}"
        elif is_extra:
            key = f"extra-{idx:02}-{name}"
        else:
            key = name
        self.state[key] = image_path
        if update_idx:
            self.save_idx += 1

    def merge(self, results, prefix=None):
        """Copy entries from ``results`` into ``state``, optionally only keys starting with ``prefix``."""
        for key, value in results.items():
            if not prefix or key.startswith(prefix):
                self.state[key] = value

    def get_state(self):
        """Return the live key -> path mapping (not a copy)."""
        return self.state


class Paint:
    """ComfyUI node that repaints the masked region of each input image in a solid colour.

    Outputs the painted batch and a 1000x1000 swatch of the requested colour.
    """

    # Surface presets accepted from the ``surface_type`` input; anything else becomes "wall".
    _SUPPORTED_SURFACE_TYPES = ("wall", "cabinet", "shutter", "outside-wall", "brick", "door")

    @classmethod
    def INPUT_TYPES(cls) -> Dict[str, Any]:
        """Describe the node's inputs in ComfyUI's ``INPUT_TYPES`` format."""
        return {
            "required": {
                "color_r": (
                    "INT",
                    {"default": 0, "min": 0, "max": 255, "step": 1},
                ),
                "color_g": (
                    "INT",
                    {"default": 0, "min": 0, "max": 255, "step": 1},
                ),
                "color_b": (
                    "INT",
                    {"default": 0, "min": 0, "max": 255, "step": 1},
                ),
                "images": ("IMAGE",),
                "mask_image": ("IMAGE",),
                "location": (['interior', 'exterior'], {"default": "interior"},),
                "keep_surface_details": ("BOOLEAN", {"default": True}),
                "surface_type": ("STRING", {}),
            },
        }

    # ``Tuple[str, ...]``: ``Tuple[str, str, ...]`` is rejected by ``typing`` before Python 3.11.
    RETURN_TYPES: Tuple[str, ...] = ("IMAGE", "IMAGE",)
    RETURN_NAMES: Tuple[str, ...] = ("result_image", "color_image",)
    FUNCTION: str = "paint"
    CATEGORY: str = "Custom/Paint"

    def tensor_to_numpy(self, tensor: torch.Tensor) -> np.ndarray:
        """Convert a ComfyUI image tensor (H, W, C) with values in [0, 1] to a uint8 array."""
        numpy_array = tensor.detach().cpu().numpy()
        return (numpy_array * 255).clip(0, 255).astype(np.uint8)

    def mask_to_numpy(self, mask_tensor: torch.Tensor) -> np.ndarray:
        """Scale a mask tensor with values in [0, 1] to uint8.

        Works element-wise, so it accepts (H, W) masks as well as (H, W, C)
        IMAGE-typed masks such as the ``mask_image`` input of this node.
        """
        mask_np = mask_tensor.detach().cpu().numpy()
        return (mask_np * 255).clip(0, 255).astype(np.uint8)

    def numpy_to_tensor(self, numpy_array: np.ndarray) -> torch.Tensor:
        """Convert a uint8 array (H, W, C) to a float32 ComfyUI image tensor in [0, 1]."""
        return torch.from_numpy(numpy_array.astype(np.float32) / 255.0)

    def lab_to_rgb(self, image_arr: np.ndarray) -> np.ndarray:
        """Convert a LAB image (H, W, 3) to uint8 RGB, saturating out-of-gamut values."""
        image_arr = cv2.cvtColor(image_arr, cv2.COLOR_LAB2RGB)
        return np.clip(255 * image_arr, 0, 255).astype(np.uint8)

    def _write_inputs(self, image: torch.Tensor, mask: torch.Tensor) -> Tuple[str, str]:
        """Save one image/mask pair as PNGs under ``temp/`` and return their paths."""
        # The pipeline consumes file paths, so round-trip through disk. The directory
        # must exist before saving. NOTE: fixed file names mean concurrent runs of this
        # node overwrite each other's inputs.
        os.makedirs("temp", exist_ok=True)
        image_path = "temp/image.png"
        mask_path = "temp/mask.png"
        Image.fromarray(self.tensor_to_numpy(image)).save(image_path)
        Image.fromarray(self.mask_to_numpy(mask)).save(mask_path)
        return image_path, mask_path

    def paint(
        self,
        color_r,
        color_g,
        color_b,
        images,
        mask_image,
        location,
        keep_surface_details,
        surface_type,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Paint the masked area of every image in ``images`` with ``(color_r, color_g, color_b)``.

        Args:
            color_r, color_g, color_b: Target colour channels (0-255).
            images: Batch of ComfyUI images (presumably (B, H, W, 3) floats in [0, 1]).
            mask_image: Batch of IMAGE-typed masks. Mask ``i`` is used for image ``i``;
                the last mask is reused when fewer masks than images are supplied.
            location: "interior" or "exterior"; forwarded to the pipeline.
            keep_surface_details: bool (True -> "on", False -> "auto") or a mode string.
            surface_type: Surface preset name; unknown values fall back to "wall".

        Returns:
            ``(result_image, color_image)``. ``result_image`` is ``images`` unchanged
            when every mask is empty.

        Raises:
            Exception: if ``images``, ``mask_image`` or any colour channel is missing.
        """
        if images is None:
            raise Exception("image is required")
        if mask_image is None:
            raise Exception("mask_image is required")
        if None in [color_r, color_g, color_b]:
            raise Exception("color must be specified")

        color = (color_r, color_g, color_b)
        # Solid 1000x1000 swatch of the requested colour, returned as a batch of one.
        swatch = np.full((1000, 1000, 3), color, dtype=np.uint8)
        color_image = self.numpy_to_tensor(swatch).unsqueeze(0)

        # Nothing to paint anywhere in the batch: hand the input back untouched.
        if torch.all(mask_image == 0):
            print("Warning: Empty mask image. Returning original image.")
            return (images, color_image,)

        if surface_type is not None and surface_type not in self._SUPPORTED_SURFACE_TYPES:
            surface_type = "wall"
        if isinstance(keep_surface_details, bool):
            keep_surface_details = "on" if keep_surface_details else "auto"

        # hardcode at batch_size 1 for now - technically up to 2 is supported, but 3 throws OOM
        batch_size = 1
        overmask = True
        seed = 10
        # Load the config once per call rather than once per image.
        config = PaintConfig.from_file(DEFAULT_CONFIG_PATH)

        results = []
        for idx, image in enumerate(images):
            mask = mask_image[min(idx, len(mask_image) - 1)]
            image_path, mask_path = self._write_inputs(image, mask)
            # A fresh pipeline per image: the pipeline keeps per-surface state.
            pipe = PaintPipeline.from_config(config)
            pipe.init_surface(
                image_path,
                mask_path,
                location,
                keep_surface_details=keep_surface_details,
                surface_type=surface_type,
            )
            result_lab = pipe.apply_paint(
                color=list(color),
                batch_size=batch_size,
                seed=seed,
                overmask=overmask,
            )
            results.append(self.numpy_to_tensor(self.lab_to_rgb(result_lab)))

        return (torch.stack(results, dim=0), color_image,)


class PaintConfig(dict):
    """Dict-backed config with attribute access and wrapping of nested sections.

    Missing keys read as ``None`` (both ``cfg["x"]`` and ``cfg.x``) instead of raising.
    Nested dicts are returned as new, shallow ``PaintConfig`` copies, so assigning
    keys on a section obtained this way does not modify the parent config.
    """

    @classmethod
    def from_file(cls, filepath):
        """Load a YAML config file into a ``PaintConfig``.

        Uses libyaml's ``CSafeLoader`` when PyYAML was built with it and falls back to
        the pure-Python ``SafeLoader`` otherwise. An empty file yields an empty config.
        """
        # Safe loading: the config is plain data and the path is caller-supplied, so
        # arbitrary Python object construction (the full Loader/CLoader) is not allowed.
        loader = getattr(yaml, "CSafeLoader", yaml.SafeLoader)
        with open(filepath, "r", encoding="utf-8") as f:
            config_dict = yaml.load(f, Loader=loader)
        # ``cls(mapping)`` rather than ``cls(**mapping)`` so non-string keys also work.
        return cls(config_dict or {})

    def __getattr__(self, name):
        # Only reached for names that are not real attributes; treat them as keys.
        return self[name]

    def __getitem__(self, name):
        value = dict.get(self, name)
        if isinstance(value, dict):
            value = PaintConfig(value)
        return value

    def __getstate__(self):
        return self.__dict__

    def __setstate__(self, d):
        self.__dict__ = d


class PaintSurfaceType(Enum):
    """Kinds of surface a paint job can target.

    Every member except GENERIC has a named config preset in
    ``PaintSurface.surface_presets``; GENERIC has none.
    """

    GENERIC = 0
    BRICK = 1
    CABINET = 2
    DOOR = 3
    OUTSIDE_WALL = 4
    WALL = 5
    SHUTTER = 6


class PaintSurface:
    """One surface to paint: the source image, its paint mask and derived condition images.

    Loads the image and mask from disk, records both in the checkpoint, resolves the
    surface type and surface mask, and builds the canny edge image and lighting mask.
    """

    def __init__(
        self,
        image_path,
        mask_image_path,
        config,
        location=None,
        keep_surface_details="auto",
        checkpoint=None,
        surface_type=None,
        surface_mask=None,
    ):
        """
        Args:
            image_path: Path of the source image (loaded as RGB).
            mask_image_path: Path of the paint mask (loaded as 8-bit greyscale).
            config: PaintConfig-like object with per-location and per-preset sections.
            location: Config section used for lighting parameters (e.g. "interior").
            keep_surface_details: Stored on the instance; not used within this class.
            checkpoint: OperationCheckpoint to record images in; a new one if None.
            surface_type: PaintSurfaceType, or None to let ``_set_surface_type`` choose.
            surface_mask: Optional surface mask; defaults to the paint mask.
        """
        self.config = config
        self.keep_surface_details = keep_surface_details
        self.image_path = image_path
        self.mask_image_path = mask_image_path
        # Convert inside ``with`` so PIL's lazily-read file handles are closed promptly.
        with Image.open(image_path) as src:
            self.image = src.convert("RGB")
        with Image.open(mask_image_path) as src:
            self.paint_mask = np.array(src.convert("L"))
        self.ref_style_image = None
        self.requires_reference = False
        self.restrict_lighting_contrast = False
        self.lighting_correction_enabled = False
        self.location_type = location

        # Surface type -> config section name; GENERIC deliberately has no preset.
        self.surface_presets = {
            PaintSurfaceType.BRICK: "brick",
            PaintSurfaceType.CABINET: "cabinet",
            PaintSurfaceType.DOOR: "door",
            PaintSurfaceType.OUTSIDE_WALL: "outside-wall",
            PaintSurfaceType.WALL: "wall",
            PaintSurfaceType.SHUTTER: "shutter",
        }

        self.overrides = {}
        self.checkpoint = OperationCheckpoint() if checkpoint is None else checkpoint
        # Record through self.checkpoint: the ``checkpoint`` argument may be None.
        self.checkpoint.save_result(image_path, name="orig")
        self.checkpoint.save_image(
            self.paint_mask,
            tmp_prefix="paint-mask-",
            ckpt_name="mask",
        )
        # Set the surface mask first so an overriding _set_surface_type may refine it.
        self.surface_mask = self.paint_mask if surface_mask is None else surface_mask
        if surface_type is None:
            self._set_surface_type()
        else:
            self.surface_type = surface_type
        self._generate_condition_images()

    def _set_surface_type(self):
        """Choose a surface type when the caller gave none.

        Defaults to GENERIC, which has no preset, so ``get_config`` returns the
        whole config. Subclasses may override this to detect the surface.
        """
        self.surface_type = PaintSurfaceType.GENERIC

    def _generate_condition_images(self):
        """Build the canny edge image and lighting mask and record them in the checkpoint."""
        canny_image_path = canny(self.image_path)
        with Image.open(canny_image_path) as edges:
            self.canny_image = edges.convert("L")
        self.checkpoint.save_result(canny_image_path, name="canny")

        # The paint mask doubles as the lighting mask.
        self.lighting_mask = self.paint_mask
        self.restrict_lighting_contrast = True
        self.lighting_correction_enabled = True

        self.checkpoint.save_image(
            self.lighting_mask,
            tmp_prefix="paint-lighting-",
            ckpt_name="lighting-mask",
        )

    def get_config(self):
        """Return the preset's config section if present, else the whole config."""
        config = self.config
        preset = self.get_preset()
        if preset in config:
            return config[preset]
        return config

    def get_image(self, as_array=False, convert_to_lab=False):
        """Return the source image as PIL, as an RGB array, or (with both flags) as LAB."""
        if as_array:
            image = np.array(self.image)
            if convert_to_lab:
                return rgb_to_lab(image)
            return image
        return self.image

    def get_paint_mask(self, as_array=False):
        """Return the paint mask as the stored uint8 array or as a PIL "L" image."""
        if as_array:
            return self.paint_mask
        return Image.fromarray(self.paint_mask).convert("L")

    def get_ref_style_image(self, as_array=False, convert_to_lab=False):
        """Like ``get_image`` for the reference style image; None when there is none."""
        if self.ref_style_image is None:
            return None
        if as_array:
            image = np.array(self.ref_style_image)
            if convert_to_lab:
                return rgb_to_lab(image)
            return image
        return self.ref_style_image

    def get_lighting_params(self):
        """Return the location's lighting params, overridden key-by-key by the preset's ``lighting``.

        NOTE(review): relies on ``self.config`` handing out copies of nested sections
        (as PaintConfig does) so the override does not leak into the shared config.
        """
        lighting_params = self.config[self.location_type].lighting
        surface_config = self.get_config()
        if "lighting" in surface_config:
            for key, value in surface_config.lighting.items():
                lighting_params[key] = value
        return lighting_params

    def get_canny(self):
        return self.canny_image

    def get_lighting_mask(self):
        return self.lighting_mask

    def get_preset(self):
        """Return the config section name for the surface type, or None (e.g. GENERIC)."""
        return self.surface_presets.get(self.surface_type)

    def enable_lighting_correction(self):
        self.lighting_correction_enabled = True

    def disable_lighting_correction(self):
        self.lighting_correction_enabled = False

    def should_apply_lighting_correction(self):
        return self.lighting_correction_enabled

    def should_restrict_lighting_contrast(self):
        """Contrast restriction applies only to cabinet and wall surfaces."""
        return self.restrict_lighting_contrast and self.surface_type in [
            PaintSurfaceType.CABINET,
            PaintSurfaceType.WALL,
        ]


class PaintPipeline:
    """Stateful pipeline that paints a surface and restores its lighting.

    ``init_surface`` loads the image and mask; ``apply_paint`` then runs colour fill,
    lighting transfer and post-processing (colour correction, lighting correction,
    optional overlay). Working images ("intermediates") live in ``self.intermediate``
    as one array or a list of arrays. Each stage reads them through
    ``_get_intermediates``, replaces them with its outputs, and records every output
    in ``self.checkpoint``.
    """

    def __init__(self, config, checkpoint=None):
        self.config = config
        self.use_ref_image = False
        self.surface_initialized = False
        self.preprocessed_images = None
        self.intermediate = None
        self.ref_intermediate = None
        self.surface = None
        # Set by apply_color; consumed by apply_color_correction.
        self.base_pixel = None
        if checkpoint is None:
            self.checkpoint = OperationCheckpoint()
        else:
            self.checkpoint = checkpoint

    @classmethod
    def from_config(cls, config, checkpoint=None):
        """Alternate constructor; equivalent to ``cls(config, checkpoint=checkpoint)``."""
        return cls(config, checkpoint=checkpoint)

    def init_surface(
        self,
        image_path,
        mask_image_path,
        location,
        keep_surface_details="auto",
        surface_type="auto",
    ):
        """Load the surface to paint and reset the working images.

        Args:
            image_path: Path of the source image.
            mask_image_path: Path of the paint mask.
            location: Config section used for lighting parameters (e.g. "interior").
            keep_surface_details: Forwarded to PaintSurface.
            surface_type: "brick", "cabinet", "door", "outside-wall", "wall",
                "shutter", or "auto" to let PaintSurface choose.

        Raises:
            KeyError: for any other ``surface_type`` value.
        """
        preset_surface_map = {
            "brick": PaintSurfaceType.BRICK,
            "cabinet": PaintSurfaceType.CABINET,
            "door": PaintSurfaceType.DOOR,
            "outside-wall": PaintSurfaceType.OUTSIDE_WALL,
            "wall": PaintSurfaceType.WALL,
            "shutter": PaintSurfaceType.SHUTTER,
        }
        resolved_type = None if surface_type == "auto" else preset_surface_map[surface_type]
        self.surface = PaintSurface(
            image_path,
            mask_image_path,
            self.config,
            location=location,
            keep_surface_details=keep_surface_details,
            surface_type=resolved_type,
            checkpoint=self.checkpoint,
        )
        self.surface_initialized = True
        self.intermediate = self.surface.get_image(as_array=True)
        self.ref_intermediate = None

    def apply_paint(self, color, batch_size=1, seed=None, surface=None, overmask=True):
        """Run the whole pipeline and return the first resulting LAB image.

        ``batch_size`` and ``seed`` are accepted for API compatibility; they are unused
        while the model stage is disabled.

        Raises:
            ValueError: if ``init_surface`` has not been called.
        """
        if not self.surface_initialized:
            raise ValueError("Surface not initialized.")
        surface = self._get_surface(surface)
        self.apply_color(color, surface=surface)
        self.apply_lighting(
            surface=surface,
            restrict_contrast=surface.should_restrict_lighting_contrast(),
        )
        # NOTE: the model stage (apply_model) is disabled, hence batch_size/seed are unused.
        self.apply_postprocessing(surface=surface, overmask=overmask)
        return self._get_intermediates(surface=surface)[0]

    def apply_color(self, color, surface=None):
        """Fill the masked area of every intermediate with ``color`` and convert to LAB."""
        surface = self._get_surface(surface)
        paint_mask = surface.get_paint_mask(as_array=True)
        mask_rows, mask_cols = np.nonzero(paint_mask)
        painted_images = []
        for image in self._get_intermediates(surface=surface):
            painted_lab = rgb_to_lab(set_rgb(image, paint_mask, color))
            # LAB value of the first masked pixel; later passed to adjust_chroma.
            self.base_pixel = painted_lab[mask_rows, mask_cols][0]
            painted_images.append(painted_lab)
            self.checkpoint.save_image(
                painted_lab,
                tmp_prefix="paint-",
                ckpt_name="painted",
                convert_from_lab=True,
            )
        self.intermediate = painted_images

    def apply_lighting(self, surface=None, restrict_contrast=True):
        """Transfer the source lighting onto each painted LAB image.

        The lighting source is ``preprocessed_images`` when set, otherwise the surface
        image. With ``restrict_contrast``, ``adjust_lighting_contrast`` also runs, with
        an edge mask derived from the canny image as its exclusion mask. If it reports a
        multicoloured result, the later lighting-correction step is disabled.
        """
        surface = self._get_surface(surface)
        surface_config = surface.get_config()
        painted_images_lab = self._get_intermediates(surface=surface)
        if self.preprocessed_images is not None:
            input_images_lab = [rgb_to_lab(image) for image in self.preprocessed_images]
        else:
            image_lab = surface.get_image(as_array=True, convert_to_lab=True)
            input_images_lab = [image_lab] * len(painted_images_lab)
        paint_mask_x, paint_mask_y = np.nonzero(surface.get_paint_mask(as_array=True))
        mask_x, mask_y = np.nonzero(surface.get_lighting_mask())
        output_images = []
        for input_image_lab, painted_image_lab in zip(input_images_lab, painted_images_lab):
            if len(mask_x) > 0:
                painted_image_lab = adjust_lighting(
                    input_image_lab,
                    painted_image_lab,
                    mask_x,
                    mask_y,
                    adjust_near_white=surface_config.adjust_near_white,
                    **surface.get_lighting_params(),
                )
                if restrict_contrast:
                    exclusion_mask = get_edge_mask(np.array(surface.get_canny()))
                    painted_image_lab, _, _, multicolored = adjust_lighting_contrast(
                        painted_image_lab,
                        paint_mask_x,
                        paint_mask_y,
                        src_image_lab=input_image_lab if not surface.requires_reference else None,
                        chroma_dist_threshold=surface_config.chroma_dist_threshold,
                        perceptual_dist_threshold=surface_config.perceptual_dist_threshold,
                        perceptual_dist_threshold_l=surface_config.perceptual_dist_threshold_l,
                        perceptual_dist_threshold_l_easy=surface_config.perceptual_dist_threshold_l_easy,
                        perceptual_min_area_threshold=surface_config.perceptual_min_area_threshold,
                        perceptual_min_area_threshold_easy=surface_config.perceptual_min_area_threshold_easy,
                        exclusion_mask=exclusion_mask,
                    )
                    if multicolored:
                        surface.disable_lighting_correction()
            output_images.append(painted_image_lab)
            self.checkpoint.save_image(
                painted_image_lab,
                tmp_prefix="paint-lighting-",
                ckpt_name="lighting",
                convert_from_lab=True,
            )
        self.intermediate = output_images

    def apply_color_correction(self, surface=None):
        """Run ``adjust_chroma`` on every intermediate against the base pixel from apply_color."""
        surface = self._get_surface(surface)
        paint_mask = surface.get_paint_mask(as_array=True)
        config = surface.get_config()
        output_images = [
            adjust_chroma(
                img_lab,
                self._get_base_pixel(),
                paint_mask,
                color_drift_p=config.color_drift_p,
            )
            for img_lab in self._get_intermediates(surface=surface)
        ]
        self.intermediate = output_images
        for image in output_images:
            self.checkpoint.save_image(
                image,
                tmp_prefix="paint-chroma-adjust",
                ckpt_name="chroma-adjust",
                convert_from_lab=True,
            )

    def apply_lighting_correction(self, surface=None):
        """Re-apply lighting from the reference image(s); skipped when disabled on the surface."""
        surface = self._get_surface(surface)
        if not surface.should_apply_lighting_correction():
            return
        surface_config = surface.get_config()
        mask_x, mask_y = np.nonzero(surface.get_paint_mask(as_array=True))
        input_images = self._get_intermediates(surface=surface)
        lighting_params = surface.get_lighting_params()
        # turn down brightness for ref images since they include white surfaces
        if self.ref_intermediate is None:
            # No reference images: the original photo is the reference for every image.
            ref_lab = surface.get_image(as_array=True, convert_to_lab=True)
            ref_images = [ref_lab] * len(input_images)
            lighting_params["lightness_baseline_p"] = 50
        else:
            ref_images = self._get_ref_intermediates()
            lighting_params["lightness_baseline_p"] = 70
        output_images = [
            adjust_lighting(
                ref_img_lab,
                img_lab,
                mask_x,
                mask_y,
                adjust_near_white=surface_config.adjust_near_white,
                **lighting_params,
            )
            for img_lab, ref_img_lab in zip(input_images, ref_images)
        ]
        self.intermediate = output_images
        for image in output_images:
            self.checkpoint.save_image(
                image,
                tmp_prefix="paint-lighting-adjust",
                ckpt_name="lighting-adjust",
                convert_from_lab=True,
            )

    def apply_overlay(self, surface=None):
        """Blend each intermediate with the original image through a feathered paint mask.

        The mask is dilated (3x3 kernel, 3 iterations) and Gaussian-blurred (11x11), so
        the painted result fades into the original at the mask edge. Each LAB channel is
        clipped to its valid range. Intermediates are modified in place.
        """
        surface = self._get_surface(surface)
        orig_lab = surface.get_image(as_array=True, convert_to_lab=True)
        input_images = self._get_intermediates(surface=surface)
        paint_mask = surface.get_paint_mask(as_array=True)
        paint_mask = cv2.dilate(paint_mask, np.ones((3, 3), np.uint8), iterations=3)
        mask_blur = cv2.GaussianBlur(paint_mask, (11, 11), 0)
        self.checkpoint.save_image(
            mask_blur,
            tmp_prefix="paint-blur-mask",
            ckpt_name="blur-mask",
        )
        weight = mask_blur / 255
        channel_bounds = (
            (LAB_MIN_L, LAB_MAX_L),
            (LAB_MIN_A, LAB_MAX_A),
            (LAB_MIN_B, LAB_MAX_B),
        )
        output_images = []
        for img_lab in input_images:
            for channel, (low, high) in enumerate(channel_bounds):
                img_lab[:, :, channel] = np.clip(
                    img_lab[:, :, channel] * weight + orig_lab[:, :, channel] * (1 - weight),
                    low,
                    high,
                )
            output_images.append(img_lab)
        self.intermediate = output_images
        for image in output_images:
            self.checkpoint.save_image(
                image,
                tmp_prefix="paint-overlay-mask",
                ckpt_name="overlay-mask",
                convert_from_lab=True,
            )

    def apply_postprocessing(self, surface=None, overmask=True):
        """Colour correction, then lighting correction, then (if ``overmask``) the overlay."""
        surface = self._get_surface(surface)
        self.apply_color_correction(surface=surface)
        self.apply_lighting_correction(surface=surface)
        if overmask:
            self.apply_overlay(surface=surface)

    def __call__(self, color, batch_size=1, seed=None, surface=None):
        """Shorthand for ``apply_paint`` on ``surface`` (default: the initialized surface)."""
        surface = self._get_surface(surface)
        return self.apply_paint(color, batch_size=batch_size, seed=seed, surface=surface)

    def _get_intermediates(self, surface=None):
        """Return the working images as a list (wrapping a single array)."""
        intermediate = self.intermediate
        if isinstance(intermediate, list):
            return intermediate
        return [intermediate]

    def _get_ref_intermediates(self):
        """Return the reference images as a list (wrapping a single array)."""
        intermediate = self.ref_intermediate
        if isinstance(intermediate, list):
            return intermediate
        return [intermediate]

    def _get_surface(self, surface):
        """Return ``surface`` if given, else the surface set by ``init_surface``."""
        return self.surface if surface is None else surface

    def _get_base_pixel(self):
        """Return the LAB base pixel recorded by ``apply_color``.

        Raises:
            ValueError: if ``apply_color`` has not run yet.
        """
        if self.base_pixel is None:
            raise ValueError("Base pixel not set; apply_color must run first.")
        return self.base_pixel
