"""paint operations and pipelines"""

import logging

# NOTE(review): basicConfig configures the *root* logger at import time (when
# the root logger has no handlers yet), so its ERROR level applies to every
# logger in the process, not just this module. Consider leaving logging setup
# to the application entry point.
logging.basicConfig(level=logging.ERROR)
# Module-level logger; failures are reported via logger.exception(...).
logger = logging.getLogger(__name__)
import os
import cv2
import yaml
import numpy as np
from enum import Enum
from PIL import Image
from pytorch_lightning import seed_everything

from util import check_image_sizes
import reimage
import models
import pipelines
import subproc
import subproc_util

from checkpoints import OperationCheckpoint
from exceptions import BadRequest, InternalServerError
from grounded_sam import clean_mask, get_mask_iou
from operations import (
    canny,
    classify_location,
    classify_room,
    depth,
    mlsd,
    seg,
)
from paint_utils import (
    LAB_MAX_L,
    LAB_MAX_A,
    LAB_MAX_B,
    LAB_MIN_L,
    LAB_MIN_A,
    LAB_MIN_B,
    adjust_chroma,
    adjust_lighting,
    adjust_lighting_contrast,
    get_denoised_lighting_mask,
    get_edge_mask,
    get_edges_without_surface_details,
    get_prompt_details,
    lab_to_rgb,
    remove_intra_surface_edges,
    rgb_to_lab,
    set_rgb,
)

# Bundled resources live next to this module. Resolve the directory with
# abspath so the paths stay correct even if __file__ is a bare filename:
# dirname("") + "/assets/paint" would otherwise point at the filesystem root.
_MODULE_DIR = os.path.dirname(os.path.abspath(__file__))
DEFAULT_ASSETS_DIR = os.path.join(_MODULE_DIR, "assets", "paint")
DEFAULT_CONFIG_PATH = os.path.join(_MODULE_DIR, "configs", "paint", "default.yaml")


def paint(
    job_id,
    color_r,
    color_g,
    color_b,
    image,
    mask_image,
    batch_size,
    seed,
    keep_surface_details,
    surface_type,
    overmask=True,
):
    """Paint the masked region of an image with a solid RGB color.

    Args:
        job_id: Job identifier; its status is set to ``"rendering"``.
        color_r: Red channel, an ``int`` in ``[0, 255]``.
        color_g: Green channel, an ``int`` in ``[0, 255]``.
        color_b: Blue channel, an ``int`` in ``[0, 255]``.
        image: Source image path (opened with PIL by ``PaintSurface``).
        mask_image: Paint mask image path.
        batch_size: Currently ignored; forced to 1 (larger batches ran out of
            memory).
        seed: Optional random seed forwarded to ``apply_paint``.
        keep_surface_details: ``"auto"``, ``"on"`` or ``"off"``. A bool is
            accepted too: ``True`` maps to ``"on"``, ``False`` to ``"auto"``.
        surface_type: Surface preset name, or ``"auto"``/``None`` to detect it.
        overmask: Forwarded to ``PaintPipeline.apply_paint``.

    Returns:
        The checkpoint state returned by ``PaintPipeline.apply_paint``.

    Raises:
        ValueError: If a required argument is missing or a color channel is
            not an int in ``[0, 255]``.
    """
    if job_id is None:
        raise ValueError("job_id is required")
    if image is None:
        raise ValueError("image is required")
    if mask_image is None:
        raise ValueError("mask_image is required")
    colors = [color_r, color_g, color_b]
    if None in colors:
        raise ValueError("color must be specified")
    for color in colors:
        # bool is a subclass of int, so it must be rejected explicitly.
        if isinstance(color, bool) or not isinstance(color, int) or not 0 <= color <= 255:
            raise ValueError(f"invalid color: {color}")

    check_image_sizes(image=image, mask_image=mask_image)

    # hardcode at batch_size 1 for now - technically up to 2 is supported, but 3 throws OOM
    batch_size = 1

    reimage.update_job(job_id, "rendering")
    config = PaintConfig.from_file(DEFAULT_CONFIG_PATH)
    pipe = PaintPipeline.from_config(config)
    if isinstance(keep_surface_details, bool):
        keep_surface_details = "on" if keep_surface_details else "auto"
    # A missing surface type means "detect it", not "unknown preset".
    if surface_type is None:
        surface_type = "auto"
    pipe.init_surface(
        image,
        mask_image,
        keep_surface_details=keep_surface_details,
        surface_type=surface_type,
    )
    return pipe.apply_paint(
        color=colors,
        batch_size=batch_size,
        seed=seed,
        overmask=overmask,
    )


def load_preprocessing_model(pipeline, config):
    """Load and configure a preprocessing diffusion pipeline in this process.

    The pipeline is loaded without its refiner, then given the configured
    ControlNets and scheduler (and, for the IP-Adapter pipeline, a per-layer
    IP-Adapter scale).

    Args:
        pipeline: ``"sdxl_controlnet_ip_adapter_global"`` or
            ``"sdxl_controlnet_pag"``.
        config: ``PaintConfig``; settings come from
            ``config.model.preprocess[pipeline]``.

    Raises:
        ValueError: If ``pipeline`` is not one of the supported names.
    """
    settings = config.model.preprocess[pipeline]
    pipelines.pipeline_get_loader(pipeline)(with_refiner=False)
    pipe = pipelines.pipe
    if pipeline not in ("sdxl_controlnet_ip_adapter_global", "sdxl_controlnet_pag"):
        raise ValueError(f"unknown pipeline: {pipeline}")

    pipe.controlnet = models.controlnet_get(settings.controlnets)
    if pipeline == "sdxl_controlnet_ip_adapter_global":
        # Per-layer IP-Adapter scale: only the middle attention layer of the
        # first up block is conditioned on the style image.
        pipe.set_ip_adapter_scale({"up": {"block_0": [0.0, settings.ip_adapter_scale, 0.0]}})
    pipe.scheduler = pipelines.scheduler_get(settings.scheduler, pipe.scheduler.config)


def run_preprocessing_model(
    pipeline,
    config,
    input_image,
    paint_mask_image,
    depth_image,
    canny_image,
    style_image,
    batch_size,
    prompt,
):
    """Run the loaded preprocessing pipeline and composite its outputs.

    Must run in the process where ``load_preprocessing_model`` populated
    ``pipelines.pipe``. Each generated image is overlaid onto ``input_image``
    through ``paint_mask_image``.

    Args:
        pipeline: ``"sdxl_controlnet_ip_adapter_global"`` or
            ``"sdxl_controlnet_pag"``.
        config: ``PaintConfig``; settings come from
            ``config.model.preprocess[pipeline]``.
        input_image: Original image (PIL).
        paint_mask_image: Paint mask (PIL, mode ``L``).
        depth_image: Depth ControlNet condition image.
        canny_image: Canny ControlNet condition image.
        style_image: IP-Adapter style image (IP-Adapter pipeline only).
        batch_size: Images to generate per prompt.
        prompt: Text prompt.

    Returns:
        A tuple ``(raw, overlay, overlays, overlays_lab, overlay_copies)``:
        the last raw model output (or ``None``), the last overlaid output as an
        array (or ``None``), all overlaid outputs as arrays, their LAB
        conversions, and independent copies of the overlaid arrays.

    Raises:
        ValueError: If ``pipeline`` is not one of the supported names.
    """
    settings = config.model.preprocess[pipeline]
    if pipeline not in ("sdxl_controlnet_ip_adapter_global", "sdxl_controlnet_pag"):
        raise ValueError(f"unknown pipeline: {pipeline}")

    call_args = {
        "image": [depth_image, canny_image],
        "num_images_per_prompt": batch_size,
        "num_inference_steps": settings.steps,
        "controlnet_conditioning_scale": settings.controlnet_scales,
        "guidance_scale": settings.guidance_scale,
        "prompt": prompt,
        "negative_prompt": settings.neg_prompt,
    }
    if pipeline == "sdxl_controlnet_ip_adapter_global":
        call_args["ip_adapter_image"] = style_image
    else:
        call_args["pag_scale"] = settings.pag_scale

    pipe = pipelines.pipe
    generated = pipe(**call_args)

    raw_image = None
    overlays = []
    for candidate in generated.images:
        raw_image = candidate
        composited = pipe.image_processor.apply_overlay(paint_mask_image, input_image, candidate)
        overlays.append(np.array(composited))

    overlay_copies = [arr.copy() for arr in overlays]
    overlays_lab = [rgb_to_lab(arr) for arr in overlays]
    last_overlay = overlays[-1] if overlays else None
    return raw_image, last_overlay, overlays, overlays_lab, overlay_copies


def load_model(config):
    """Load the main SDXL img2img ControlNet PAG pipeline in this process.

    Installs the ControlNets and scheduler configured in
    ``config.model.main`` on ``pipelines.pipe``.

    Args:
        config: ``PaintConfig`` holding the ``model.main`` settings.
    """
    loader = pipelines.pipeline_get_loader("sdxl_img2img_controlnet_pag")
    loader()

    main_settings = config.model.main
    pipe = pipelines.pipe
    pipe.controlnet = models.controlnet_get(main_settings.controlnets)
    pipe.scheduler = pipelines.scheduler_get(main_settings.scheduler, pipe.scheduler.config)


def run_model(
    input_images,
    prompt,
    negative_prompt,
    num_inference_steps,
    num_images_per_prompt,
    depth_image,
    canny_image,
    controlnet_conditioning_scale,
    guidance_scale,
    pag_scale,
    strength,
):
    """Run the loaded img2img ControlNet PAG pipeline over every input image.

    Must run in the process where ``load_model`` populated ``pipelines.pipe``.

    Args:
        input_images: Iterable of PIL images to refine.
        prompt: Text prompt.
        negative_prompt: Negative text prompt.
        num_inference_steps: Diffusion steps per call.
        num_images_per_prompt: Outputs generated per input image.
        depth_image: Depth ControlNet condition image.
        canny_image: Canny ControlNet condition image.
        controlnet_conditioning_scale: Per-ControlNet conditioning scales.
        guidance_scale: Classifier-free guidance scale.
        pag_scale: Perturbed-attention guidance scale.
        strength: img2img strength.

    Returns:
        A list with all generated images, ``num_images_per_prompt`` per input
        image, in input order (empty when ``input_images`` is empty).
    """
    output_images = []
    for input_image in input_images:
        result = pipelines.pipe(
            image=input_image,
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            num_images_per_prompt=num_images_per_prompt,
            control_image=[depth_image, canny_image],
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            guidance_scale=guidance_scale,
            # Previously accepted but silently dropped, so the pipeline's
            # default PAG scale was used instead of the configured one.
            pag_scale=pag_scale,
            strength=strength,
        )
        output_images.extend(result.images)
    # Return only after every input image has been processed.
    return output_images


class PaintConfig(dict):
    """``dict`` with attribute access, used for the YAML paint configuration.

    Lookups never raise for missing keys: ``cfg.key`` and ``cfg["key"]`` both
    return ``None``. Nested dicts are returned wrapped in a new ``PaintConfig``
    (a shallow copy) on every access, so mutating a returned section does not
    modify the parent config.
    """

    @classmethod
    def from_file(cls, filepath):
        """Load a config from the YAML file at ``filepath``.

        An empty YAML file yields an empty config instead of failing.
        """
        # Prefer the libyaml-backed loader; fall back to the pure-Python
        # loader of the same kind when PyYAML was built without libyaml.
        # NOTE: this is PyYAML's full (non-safe) loader; only use it on
        # trusted, bundled config files.
        loader = getattr(yaml, "CLoader", yaml.Loader)
        with open(filepath, "r", encoding="utf-8") as f:
            config_dict = yaml.load(f, Loader=loader)
        # Pass the mapping positionally so non-string YAML keys are accepted.
        return cls(config_dict or {})

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails. Dunder probes
        # (copy, pickle, numpy, hasattr checks, ...) must see a real
        # AttributeError rather than a silent None.
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(name)
        return self[name]

    def __getitem__(self, name):
        # Missing keys yield None; nested dicts are wrapped for attribute access.
        value = dict.get(self, name)
        if isinstance(value, dict):
            value = PaintConfig(value)
        return value

    # Explicit state hooks keep pickling (e.g. if the config is passed to a
    # subprocess) independent of the permissive __getattr__ above.
    def __getstate__(self):
        return self.__dict__

    def __setstate__(self, d):
        self.__dict__ = d


class PaintSurfaceType(Enum):
    """Kinds of paintable surface; every non-generic kind maps to a config preset."""

    # Values are positional: GENERIC=0, BRICK=1, ... SHUTTER=6.
    GENERIC, BRICK, CABINET, DOOR, OUTSIDE_WALL, WALL, SHUTTER = range(7)


class PaintSurface:
    """A paintable surface in an image, with its masks and condition images.

    Construction loads the image and paint mask, cleans the mask, classifies
    the scene (``classify_location`` / ``classify_room``), resolves the surface
    type (detected by segmentation unless given), and generates the canny,
    depth and MLSD condition images. Artifacts are recorded on
    ``self.checkpoint``.
    """

    def __init__(
        self,
        image_path,
        mask_image_path,
        config,
        keep_surface_details="auto",
        checkpoint=None,
        surface_type=None,
        surface_mask=None,
        assets_dir=DEFAULT_ASSETS_DIR,
    ):
        """Load the inputs, resolve the surface type and build condition images.

        Args:
            image_path: Path of the source image.
            mask_image_path: Path of the user-drawn paint mask image.
            config: ``PaintConfig`` with global and per-preset settings.
            keep_surface_details: ``"auto"``, ``"on"`` or anything else (treated
                as off). Only consulted for cabinet and wall surfaces.
            checkpoint: ``OperationCheckpoint`` that collects artifacts; a new
                one is created when ``None``.
            surface_type: Explicit ``PaintSurfaceType``; detected when ``None``.
            surface_mask: Whole-surface mask used with an explicit
                ``surface_type``; defaults to the paint mask.
            assets_dir: Directory holding reference style images.

        Raises:
            BadRequest: If the cleaned paint mask is empty.
        """
        self.config = config
        self.keep_surface_details = keep_surface_details
        self.image_path = image_path
        self.mask_image_path = mask_image_path
        self.image = Image.open(image_path).convert("RGB")
        paint_mask = np.array(Image.open(mask_image_path).convert("L"))
        paint_mask, is_empty = clean_mask(
            paint_mask,
            min_obj_size=config.mask_min_obj_size,
            hole_threshold=config.mask_hole_threshold,
        )
        if is_empty:
            raise BadRequest("Not enough paint. Apply paint to a larger area.")
        self.paint_mask = paint_mask
        self.ref_style_image = None
        self.requires_reference = False
        self.restrict_lighting_contrast = False
        self.lighting_correction_enabled = False
        self.location_type = classify_location(image_path)
        self.room_type = classify_room(image_path)
        # Candidates for surface detection. Order matters: on equal IoU the
        # earlier candidate is kept.
        self.surface_types = [
            PaintSurfaceType.BRICK,
            PaintSurfaceType.DOOR,
        ]
        if self.location_type == "interior":
            self.surface_types.append(PaintSurfaceType.CABINET)
            self.surface_types.append(PaintSurfaceType.WALL)
        else:
            self.surface_types.append(PaintSurfaceType.OUTSIDE_WALL)
            self.surface_types.append(PaintSurfaceType.SHUTTER)
        self.surface_presets = {
            PaintSurfaceType.BRICK: "brick",
            PaintSurfaceType.CABINET: "cabinet",
            PaintSurfaceType.DOOR: "door",
            PaintSurfaceType.OUTSIDE_WALL: "outside-wall",
            PaintSurfaceType.WALL: "wall",
            PaintSurfaceType.SHUTTER: "shutter",
        }
        self.assets_dir = assets_dir
        self.overrides = {}
        self.checkpoint = OperationCheckpoint() if checkpoint is None else checkpoint
        # Record through self.checkpoint: the `checkpoint` argument may be None.
        self.checkpoint.save_result(image_path, name="orig")
        self.checkpoint.save_image(
            paint_mask,
            tmp_prefix="paint-mask-",
            ckpt_name="mask",
        )
        if surface_type is None:
            self._set_surface_type()
        else:
            self.surface_type = surface_type
            self.surface_mask = self.paint_mask if surface_mask is None else surface_mask
        self._generate_condition_images()

    def _set_surface_type(self):
        """Pick the candidate surface whose segmentation best overlaps the paint mask.

        Every candidate preset is segmented with ``seg`` and cleaned. The
        candidate with the highest IoU against the paint mask becomes
        ``self.surface_type``, and its mask becomes ``self.surface_mask``. If no
        candidate has a positive IoU, the type is ``GENERIC`` and the paint mask
        is used as the surface mask.
        """
        config = self.config
        max_iou = 0
        best_surface_type = PaintSurfaceType.GENERIC
        best_surface_mask = self.paint_mask
        for surface_type in self.surface_types:
            preset = self.surface_presets[surface_type]
            surface_mask_path = seg(image=self.image_path, preset=preset)
            surface_mask_arr = np.array(Image.open(surface_mask_path).convert("L"))
            surface_mask_arr, is_empty = clean_mask(
                surface_mask_arr,
                min_obj_size=config.mask_min_obj_size,
                hole_threshold=config.mask_hole_threshold,
            )
            if is_empty:
                continue
            iou = get_mask_iou(self.paint_mask > 0, surface_mask_arr > 0)
            if iou > max_iou:
                max_iou = iou
                best_surface_type = surface_type
                best_surface_mask = surface_mask_arr
        self.surface_type = best_surface_type
        self.surface_mask = best_surface_mask

    @staticmethod
    def _edge_density(edge_arr, mask):
        """Fraction of ``mask``'s non-zero pixels that are non-zero in ``edge_arr``."""
        return np.count_nonzero(edge_arr[mask > 0]) / np.count_nonzero(mask)

    def _use_canny_without_surface_details(self, canny_arr, mlsd_arr, paint_mask):
        """Replace ``self.canny_image`` with the output of ``get_edges_without_surface_details``.

        The reduced edge map is saved to the checkpoint as
        ``canny-no-surface`` and reloaded from the saved file.
        """
        canny_no_surface_arr, _ = get_edges_without_surface_details(
            canny_arr,
            mlsd_arr,
            paint_mask,
            kernel_size=1,
            iters=1,
        )
        canny_no_surface_path = self.checkpoint.save_image(
            canny_no_surface_arr,
            tmp_prefix="paint-canny-no-surface-",
            ckpt_name="canny-no-surface",
        )
        self.canny_image = Image.open(canny_no_surface_path).convert("L")

    def _configure_cabinet(self, canny_arr, mlsd_arr):
        """Decide whether cabinet surface details are kept and set lighting options."""
        config = self.config[self.surface_presets[PaintSurfaceType.CABINET]]
        paint_mask = self.paint_mask
        canny_density = self._edge_density(canny_arr, paint_mask)
        keep_surface_details = self.keep_surface_details
        if keep_surface_details == "auto":
            keep_surface_details = "on"
            if canny_density > config.surface_density_threshold:
                # Busy area: turn details off only if few edges remain once
                # surface details are stripped.
                canny_no_surface_arr, _ = get_edges_without_surface_details(
                    canny_arr,
                    mlsd_arr,
                    paint_mask,
                    kernel_size=1,
                    iters=1,
                )
                if self._edge_density(canny_no_surface_arr, paint_mask) < config.edge_detail_threshold:
                    keep_surface_details = "off"
        if keep_surface_details == "on":
            self.lighting_mask = paint_mask
            if canny_density < config.lighting_density_threshold:
                self.restrict_lighting_contrast = True
            self.lighting_correction_enabled = True
        else:
            self._use_canny_without_surface_details(canny_arr, mlsd_arr, paint_mask)
            self.lighting_mask = paint_mask
            ref_style_image_path = os.path.join(self.assets_dir, config.ref_style_image)
            self.ref_style_image = Image.open(ref_style_image_path).convert("RGB")
            self.requires_reference = True
            self.lighting_correction_enabled = True

    def _configure_wall(self, canny_arr, mlsd_arr):
        """Decide whether wall surface details are kept and set lighting options."""
        config = self.config[self.surface_presets[PaintSurfaceType.WALL]]
        paint_mask = self.paint_mask
        surface_mask = self.surface_mask
        # Return value unused: presumably edits mlsd_arr in place, which the
        # edge stripping below then sees — confirm.
        remove_intra_surface_edges(mlsd_arr, surface_mask)
        canny_density = self._edge_density(canny_arr, paint_mask)
        keep_surface_details = self.keep_surface_details
        if keep_surface_details == "auto":
            keep_surface_details = "off" if canny_density > config.surface_density_threshold else "on"
        if keep_surface_details == "on":
            self.lighting_mask = paint_mask
            self.restrict_lighting_contrast = True
            self.lighting_correction_enabled = True
        else:
            self._use_canny_without_surface_details(canny_arr, mlsd_arr, paint_mask)
            lighting_mask = get_denoised_lighting_mask(canny_arr, surface_mask)
            self.overrides["strength"] = 0.7
            clean_surface_ratio = np.count_nonzero(lighting_mask) / np.count_nonzero(surface_mask)
            if clean_surface_ratio > config.clean_surface_threshold:
                self.lighting_mask = lighting_mask
            else:
                self.lighting_mask = np.zeros_like(paint_mask)

    def _generate_condition_images(self):
        """Create canny/depth/MLSD condition images and set per-surface options.

        Sets ``self.lighting_mask`` (saved as ``lighting-mask``) and, depending
        on the surface type, the lighting-correction, contrast-restriction and
        reference-image flags.
        """
        image_path = self.image_path
        canny_image_path = canny(image_path)
        depth_image_path = depth(image_path)
        mlsd_image_path = mlsd(image_path)
        self.canny_image = Image.open(canny_image_path).convert("L")
        self.depth_image = Image.open(depth_image_path).convert("RGB")
        self.mlsd_image = Image.open(mlsd_image_path).convert("L")
        canny_arr = np.array(self.canny_image)
        mlsd_arr = np.array(self.mlsd_image)
        checkpoint = self.checkpoint
        checkpoint.save_result(canny_image_path, name="canny")
        checkpoint.save_result(depth_image_path, name="depth")
        checkpoint.save_result(mlsd_image_path, name="mlsd")

        surface = self.surface_type
        if surface == PaintSurfaceType.CABINET:
            self._configure_cabinet(canny_arr, mlsd_arr)
        elif surface == PaintSurfaceType.WALL:
            self._configure_wall(canny_arr, mlsd_arr)
        elif surface in (PaintSurfaceType.DOOR, PaintSurfaceType.SHUTTER):
            self.lighting_mask = self.paint_mask
            self.lighting_correction_enabled = True
        else:
            # BRICK, OUTSIDE_WALL and GENERIC light the whole painted area.
            self.lighting_mask = self.paint_mask

        checkpoint.save_image(
            self.lighting_mask,
            tmp_prefix="paint-lighting-",
            ckpt_name="lighting-mask",
        )

    def get_config(self):
        """Return the preset section for this surface, or the root config if absent."""
        config = self.config
        preset = self.get_preset()
        if preset in config:
            return config[preset]
        return config

    def get_image(self, as_array=False, convert_to_lab=False):
        """Return the source image as PIL, or as an RGB/LAB array when ``as_array``."""
        if as_array:
            image = np.array(self.image)
            if convert_to_lab:
                return rgb_to_lab(image)
            return image
        return self.image

    def get_paint_mask(self, as_array=False):
        """Return the cleaned paint mask (the stored array itself when ``as_array``)."""
        if as_array:
            return self.paint_mask
        return Image.fromarray(self.paint_mask).convert("L")

    def get_ref_style_image(self, as_array=False, convert_to_lab=False):
        """Return the reference style image (``None`` if not required) like ``get_image``."""
        if self.ref_style_image is None:
            return None
        if as_array:
            image = np.array(self.ref_style_image)
            if convert_to_lab:
                return rgb_to_lab(image)
            return image
        return self.ref_style_image

    def get_prompt(self, detailed=False, model_config=None):
        """Return the model prompt, optionally extended with location/room details."""
        if model_config is None:
            model_config = self.config.model.main
        if not detailed:
            return model_config.prompt
        return " ".join(
            [
                model_config.prompt,
                get_prompt_details(self.location_type, self.room_type),
            ]
        )

    def get_lighting_params(self):
        """Return location lighting params overlaid with this surface's lighting params."""
        # config[...] builds a fresh PaintConfig copy on each access, so the
        # overrides below do not leak into the shared config.
        lighting_params = self.config[self.location_type].lighting
        surface_config = self.get_config()
        if "lighting" in surface_config:
            for key, value in surface_config.lighting.items():
                lighting_params[key] = value
        return lighting_params

    def get_model_params(self):
        """Return img2img ``strength`` (override wins) and ``prompt`` as a ``PaintConfig``."""
        model_config = self.config.model.main
        params = {
            "strength": self.overrides["strength"] if "strength" in self.overrides else model_config.strength,
            "prompt": self.get_prompt(),
        }
        return PaintConfig(**params)

    def get_depth(self):
        return self.depth_image

    def get_canny(self):
        return self.canny_image

    def get_mlsd(self):
        return self.mlsd_image

    def get_lighting_mask(self):
        return self.lighting_mask

    def get_preset(self):
        """Return the preset name of the surface type (``None`` for ``GENERIC``)."""
        return self.surface_presets.get(self.surface_type)

    def enable_lighting_correction(self):
        self.lighting_correction_enabled = True

    def disable_lighting_correction(self):
        self.lighting_correction_enabled = False

    def should_apply_preprocessing(self):
        return self.requires_reference and self.surface_type in [
            PaintSurfaceType.CABINET,
            PaintSurfaceType.SHUTTER,
        ]

    def should_apply_model(self):
        return self.surface_type in [
            PaintSurfaceType.CABINET,
            PaintSurfaceType.DOOR,
            PaintSurfaceType.SHUTTER,
            PaintSurfaceType.WALL,
        ]

    def should_apply_postprocessing(self):
        return self.surface_type in [
            PaintSurfaceType.CABINET,
            PaintSurfaceType.DOOR,
            PaintSurfaceType.SHUTTER,
            PaintSurfaceType.WALL,
        ]

    def should_apply_lighting_correction(self):
        return self.lighting_correction_enabled

    def should_restrict_lighting_contrast(self):
        return self.restrict_lighting_contrast and self.surface_type in [
            PaintSurfaceType.CABINET,
            PaintSurfaceType.WALL,
        ]


class PaintPipeline:
    def __init__(self, config, checkpoint=None):
        self.config = config
        self.use_ref_image = False
        self.surface_initialized = False
        self.preprocessed_images = None
        self.intermediate = None
        self.ref_intermediate = None
        self.surface = None
        if checkpoint is None:
            self.checkpoint = OperationCheckpoint()
        else:
            self.checkpoint = checkpoint

    @classmethod
    def from_config(cls, config, checkpoint=None):
        return cls(
            config,
            checkpoint=checkpoint,
        )

    def init_surface(
        self,
        image_path,
        mask_image_path,
        keep_surface_details="auto",
        surface_type="auto",
    ):
        """Build the ``PaintSurface`` for an image/mask pair and reset working state.

        Args:
            image_path: Path of the source image.
            mask_image_path: Path of the paint mask image.
            keep_surface_details: Forwarded to ``PaintSurface``.
            surface_type: A preset name (``"brick"``, ``"cabinet"``, ``"door"``,
                ``"outside-wall"``, ``"wall"``, ``"shutter"``), a
                ``PaintSurfaceType``, or ``"auto"``/``None`` to detect it.

        Raises:
            BadRequest: For an unknown ``surface_type``, when ``PaintSurface``
                rejects the mask, or when surface initialization fails.
        """
        preset_surface_map = {
            "brick": PaintSurfaceType.BRICK,
            "cabinet": PaintSurfaceType.CABINET,
            "door": PaintSurfaceType.DOOR,
            "outside-wall": PaintSurfaceType.OUTSIDE_WALL,
            "wall": PaintSurfaceType.WALL,
            "shutter": PaintSurfaceType.SHUTTER,
        }
        if surface_type is None or surface_type == "auto":
            resolved_surface_type = None  # let PaintSurface detect it
        elif isinstance(surface_type, PaintSurfaceType):
            resolved_surface_type = surface_type
        elif isinstance(surface_type, str) and surface_type in preset_surface_map:
            resolved_surface_type = preset_surface_map[surface_type]
        else:
            # Report bad input directly instead of as "No paint surfaces found".
            raise BadRequest(f"Unknown surface type: {surface_type}")
        try:
            self.surface = PaintSurface(
                image_path,
                mask_image_path,
                self.config,
                keep_surface_details=keep_surface_details,
                surface_type=resolved_surface_type,
                checkpoint=self.checkpoint,
            )
            self.surface_initialized = True
            self.intermediate = self.surface.get_image(as_array=True)
            self.ref_intermediate = None
        except BadRequest:
            raise
        except Exception as e:
            logger.exception(e)
            raise BadRequest("No paint surfaces found. Please check image.") from None

    def apply_paint(self, color, batch_size=1, seed=None, surface=None, overmask=True):
        """Run every paint stage and return the resulting checkpoint state.

        Stages run in order: preprocessing, color fill, lighting, model pass,
        postprocessing. Each final intermediate is saved as ``image-{i}``.

        Args:
            color: ``[r, g, b]`` paint color.
            batch_size: Images to generate in the model stages.
            seed: Optional seed passed to ``seed_everything``.
            surface: Optional surface; resolved through ``self._get_surface``.
            overmask: Forwarded to ``apply_postprocessing``.

        Returns:
            ``self.checkpoint.get_state()`` after the outputs are saved.

        Raises:
            ValueError: If ``init_surface`` has not been called.
            InternalServerError: If any stage fails (the cause is logged).
        """
        if not self.surface_initialized:
            raise ValueError("Surface not initialized.")
        try:
            if seed is not None:
                seed_everything(seed)
            target = self._get_surface(surface)
            self.apply_preprocessing(surface=target, batch_size=batch_size)
            self.apply_color(color, surface=target)
            self.apply_lighting(
                surface=target,
                restrict_contrast=target.should_restrict_lighting_contrast(),
            )
            self.apply_model(surface=target, batch_size=batch_size)
            self.apply_postprocessing(surface=target, overmask=overmask)

            for idx, final_image in enumerate(self._get_intermediates(surface=target)):
                self.checkpoint.save_image(
                    final_image,
                    tmp_prefix="paint-",
                    ckpt_name=f"image-{idx}",
                    is_extra=False,
                    update_idx=False,
                    convert_from_lab=True,
                )
            return self.checkpoint.get_state()
        except Exception as e:
            logger.exception(e)
            raise InternalServerError("Unable to paint image.") from None

    def apply_preprocessing(self, surface=None, batch_size=1):
        """Restyle the image with a preprocessing diffusion model when required.

        Does nothing unless ``surface.should_apply_preprocessing()``. The
        pipeline name comes from the surface config's ``preprocess`` entry, and
        the model is loaded and run in a dedicated subprocess. The overlaid
        outputs become ``self.intermediate``, their LAB versions
        ``self.ref_intermediate``, and the surface's img2img strength override
        is set to 0.4.

        Args:
            surface: Optional surface; resolved through ``self._get_surface``.
            batch_size: Images to generate.

        Raises:
            BaseException: Any exception object returned by the subprocess call.
        """
        surface = self._get_surface(surface)
        if not surface.should_apply_preprocessing():
            return
        pipeline = surface.get_config().preprocess
        prompt = surface.get_prompt(
            detailed=True,
            model_config=self.config.model.preprocess[pipeline],
        )
        call_args = {
            "pipeline": pipeline,
            "config": self.config,
            "input_image": surface.get_image(),
            "paint_mask_image": surface.get_paint_mask(),
            "depth_image": surface.get_depth(),
            "canny_image": surface.get_canny().convert("RGB"),
            "style_image": surface.get_ref_style_image(),
            "batch_size": batch_size,
            "prompt": prompt,
        }
        # This worker loads the pipeline without its refiner, so it gets its own
        # subproc name and is never shared with {pipeline} users that need one.
        worker_name = f"{pipeline}_norefiner"
        subproc_util.subproc_launch(
            worker_name,
            None,
            fn=load_preprocessing_model,
            fn_args={"pipeline": pipeline, "config": self.config},
        )
        result = subproc.subproc_call(worker_name, run_preprocessing_model, call_args)
        if isinstance(result, BaseException):
            raise result
        raw_image, overlay_image, intermediate, ref_intermediate, preprocessed_images = result

        if raw_image is not None:
            self.checkpoint.save_image(
                np.array(raw_image),
                tmp_prefix="preprocess-",
                ckpt_name="preprocess",
            )
        if overlay_image is not None:
            self.checkpoint.save_image(
                overlay_image,
                tmp_prefix="preprocess-overlay-",
                ckpt_name="preprocess-overlay",
            )
        self.intermediate = intermediate
        self.ref_intermediate = ref_intermediate
        self.preprocessed_images = preprocessed_images
        for ref_image_lab in ref_intermediate:
            self.checkpoint.save_image(
                ref_image_lab,
                tmp_prefix="preprocess-ref-",
                ckpt_name="preprocess-ref",
                convert_from_lab=True,
            )
        # Use a lower img2img strength once a preprocessed image is the input.
        surface.overrides["strength"] = 0.4

    def apply_color(self, color, surface=None):
        """Fill the paint mask of every intermediate image with ``color``.

        Each painted image is converted to LAB and saved as a ``painted``
        checkpoint. All of them together (a list) become ``self.intermediate``.
        ``self.base_pixel`` is set to the LAB value of the first masked pixel of
        the painted image (presumably the paint color in LAB).

        Args:
            color: ``[r, g, b]`` paint color.
            surface: Optional surface; resolved through ``self._get_surface``.
        """
        surface = self._get_surface(surface)
        paint_mask = surface.get_paint_mask(as_array=True)
        painted_images_lab = []
        for image in self._get_intermediates(surface=surface):
            painted_image = set_rgb(
                image,
                paint_mask,
                color,
            )
            painted_image_lab = rgb_to_lab(painted_image)
            paint_mask_x, paint_mask_y = np.nonzero(paint_mask)
            self.base_pixel = painted_image_lab[paint_mask_x, paint_mask_y][0]
            painted_images_lab.append(painted_image_lab)
            self.checkpoint.save_image(
                painted_image_lab,
                tmp_prefix="paint-",
                ckpt_name="painted",
                convert_from_lab=True,
            )
        # Keep every batch image: assigning inside the loop kept only the last.
        self.intermediate = painted_images_lab

    def apply_lighting(self, surface=None, restrict_contrast=True):
        """Transfer lighting from the source image onto every painted intermediate.

        For each painted image, ``adjust_lighting`` runs over the surface's
        lighting mask. The lighting source is the matching preprocessed image
        when preprocessing ran, and the original image otherwise. With
        ``restrict_contrast``, ``adjust_lighting_contrast`` also runs, and
        lighting correction is disabled on the surface if it reports the area
        as multicolored. The results (a list) become ``self.intermediate`` and
        are saved as ``lighting`` checkpoints. If the lighting mask is empty,
        images pass through unchanged.

        Args:
            surface: Optional surface; resolved through ``self._get_surface``.
            restrict_contrast: Whether to run ``adjust_lighting_contrast``.
        """
        surface = self._get_surface(surface)
        surface_config = surface.get_config()
        painted_images_lab = self._get_intermediates(surface=surface)
        if self.preprocessed_images is not None:
            input_images_lab = [rgb_to_lab(image) for image in self.preprocessed_images]
        else:
            image_lab = surface.get_image(as_array=True, convert_to_lab=True)
            input_images_lab = [image_lab] * len(painted_images_lab)
        paint_mask_x, paint_mask_y = np.nonzero(surface.get_paint_mask(as_array=True))
        mask_x, mask_y = np.nonzero(surface.get_lighting_mask())
        lit_images_lab = []
        for input_image_lab, painted_image_lab in zip(input_images_lab, painted_images_lab):
            if len(mask_x) > 0:
                painted_image_lab = adjust_lighting(
                    input_image_lab,
                    painted_image_lab,
                    mask_x,
                    mask_y,
                    adjust_near_white=surface_config.adjust_near_white,
                    **surface.get_lighting_params(),
                )
                if restrict_contrast:
                    exclusion_mask = get_edge_mask(np.array(surface.get_canny()))
                    painted_image_lab, _, _, multicolored = adjust_lighting_contrast(
                        painted_image_lab,
                        paint_mask_x,
                        paint_mask_y,
                        src_image_lab=input_image_lab if not surface.requires_reference else None,
                        chroma_dist_threshold=surface_config.chroma_dist_threshold,
                        perceptual_dist_threshold=surface_config.perceptual_dist_threshold,
                        perceptual_dist_threshold_l=surface_config.perceptual_dist_threshold_l,
                        perceptual_dist_threshold_l_easy=surface_config.perceptual_dist_threshold_l_easy,
                        perceptual_min_area_threshold=surface_config.perceptual_min_area_threshold,
                        perceptual_min_area_threshold_easy=surface_config.perceptual_min_area_threshold_easy,
                        exclusion_mask=exclusion_mask,
                    )
                    if multicolored:
                        surface.disable_lighting_correction()
            lit_images_lab.append(painted_image_lab)
            self.checkpoint.save_image(
                painted_image_lab,
                tmp_prefix="paint-lighting-",
                ckpt_name="lighting",
                convert_from_lab=True,
            )
        # Keep every batch image: assigning inside the loop kept only the last.
        self.intermediate = lit_images_lab

    def apply_model(self, surface=None, batch_size=1):
        """Refine the intermediates with the SDXL img2img ControlNet PAG model.

        Does nothing unless ``surface.should_apply_model()``. The model runs in
        the ``sdxl_img2img_controlnet_pag`` subprocess, conditioned on the
        surface's depth and canny images. Its outputs (converted to LAB) become
        ``self.intermediate`` and are saved as ``img2img`` checkpoints.

        Args:
            surface: Optional surface; resolved through ``self._get_surface``.
            batch_size: Images generated per input image.

        Raises:
            BaseException: Any exception object returned by the subprocess call.
        """
        surface = self._get_surface(surface)
        if not surface.should_apply_model():
            return
        main_settings = self.config.model.main

        pipeline = "sdxl_img2img_controlnet_pag"
        subproc_util.subproc_launch(pipeline, None, fn=load_model, fn_args={"config": self.config})

        rgb_inputs = [Image.fromarray(lab_to_rgb(image_lab)) for image_lab in self._get_intermediates(surface=surface)]
        depth_image = surface.get_depth()
        canny_image = surface.get_canny()
        params = surface.get_model_params()
        call_args = dict(
            input_images=rgb_inputs,
            prompt=params.prompt,
            negative_prompt=main_settings.neg_prompt,
            num_inference_steps=main_settings.steps,
            num_images_per_prompt=batch_size,
            depth_image=depth_image,
            canny_image=canny_image,
            controlnet_conditioning_scale=main_settings.controlnet_scales,
            guidance_scale=main_settings.guidance_scale,
            pag_scale=main_settings.pag_scale,
            strength=params.strength,
        )
        result = subproc.subproc_call(pipeline, run_model, call_args)
        if isinstance(result, BaseException):
            raise result

        refined_lab = [rgb_to_lab(np.array(generated)) for generated in result]
        self.intermediate = refined_lab

        for image_lab in refined_lab:
            self.checkpoint.save_image(
                image_lab,
                tmp_prefix="paint-img2img-",
                ckpt_name="img2img",
                convert_from_lab=True,
            )

    def apply_color_correction(self, surface=None):
        """Pull the chroma of each intermediate toward the stored base pixel.

        Each intermediate is passed through ``adjust_chroma`` together with
        ``self._get_base_pixel()``, the surface's paint mask, and the surface
        config's ``color_drift_p``.

        Args:
            surface: Surface to process; defaults to ``self.surface``.

        Side effects:
            Replaces ``self.intermediate`` with the list of adjusted LAB images,
            and checkpoints each one as ``chroma-adjust``.
        """
        target = self._get_surface(surface)
        intermediates = self._get_intermediates(surface=target)
        mask = target.get_paint_mask(as_array=True)
        surface_cfg = target.get_config()

        adjusted = [
            adjust_chroma(
                lab,
                self._get_base_pixel(),
                mask,
                color_drift_p=surface_cfg.color_drift_p,
            )
            for lab in intermediates
        ]
        self.intermediate = adjusted

        for lab in adjusted:
            self.checkpoint.save_image(
                lab,
                tmp_prefix="paint-chroma-adjust",
                ckpt_name="chroma-adjust",
                convert_from_lab=True,
            )

    def apply_lighting_correction(self, surface=None):
        """Re-light each intermediate to match a reference image inside the paint mask.

        Does nothing when ``surface.should_apply_lighting_correction()`` is
        falsy. The reference is either:

        - the original surface image, in LAB, when ``self.ref_intermediate`` is
          None; or
        - the reference intermediates otherwise.

        ``adjust_lighting`` is called per image with the paint-mask pixel
        coordinates, the surface config's ``adjust_near_white`` flag, and the
        surface's lighting params.

        Args:
            surface: Surface to process; defaults to ``self.surface``.

        Side effects:
            Replaces ``self.intermediate`` with the list of adjusted LAB images,
            and checkpoints each one as ``lighting-adjust``.
        """
        surface = self._get_surface(surface)
        if not surface.should_apply_lighting_correction():
            return
        surface_config = surface.get_config()
        # np.nonzero returns one index array per axis (rows first, then columns).
        mask_x, mask_y = np.nonzero(surface.get_paint_mask(as_array=True))
        input_images = self._get_intermediates(surface=surface)
        # Work on a copy so the baseline set below does not leak into whatever
        # mapping the surface handed out.
        lighting_params = dict(surface.get_lighting_params())
        # The lightness baseline percentile depends on the reference source.
        # Original note: "turn down brightness for ref images since they include
        # white surfaces". NOTE(review): confirm which branch that refers to.
        if self.ref_intermediate is None:
            ref_images = [surface.get_image(as_array=True, convert_to_lab=True)]
            lighting_params["lightness_baseline_p"] = 50
        else:
            ref_images = self._get_ref_intermediates()
            lighting_params["lightness_baseline_p"] = 70
        # A single reference applies to every intermediate. Without this
        # broadcast, zip() would silently drop all but the first image when
        # batch_size > 1.
        if len(ref_images) == 1 and len(input_images) > 1:
            ref_images = ref_images * len(input_images)
        output_images = []
        for img_lab, ref_img_lab in zip(input_images, ref_images):
            output_lab = adjust_lighting(
                ref_img_lab,
                img_lab,
                mask_x,
                mask_y,
                adjust_near_white=surface_config.adjust_near_white,
                **lighting_params,
            )
            output_images.append(output_lab)
        self.intermediate = output_images
        for image in output_images:
            self.checkpoint.save_image(
                image,
                tmp_prefix="paint-lighting-adjust",
                ckpt_name="lighting-adjust",
                convert_from_lab=True,
            )

    def apply_overlay(self, surface=None):
        """Feather the painted intermediates back onto the original image.

        How the mask is built:

        - the surface's paint mask is dilated (3x3 kernel, 3 iterations);
        - it is then Gaussian-blurred with an 11x11 kernel;
        - the blurred mask is checkpointed as ``blur-mask``;
        - it is scaled by 1/255 to get the blend weight ``m``. NOTE(review):
          this assumes the mask uses 0/255 values.

        Each LAB channel is blended as ``painted * m + original * (1 - m)``
        and clipped to that channel's LAB range.

        Args:
            surface: Surface to process; defaults to ``self.surface``.

        Side effects:
            The intermediate arrays are modified in place. ``self.intermediate``
            is set to the list of them, and each is checkpointed as
            ``overlay-mask``.
        """
        target = self._get_surface(surface)
        original = target.get_image(as_array=True, convert_to_lab=True)
        intermediates = self._get_intermediates(surface=target)
        mask = target.get_paint_mask(as_array=True)
        mask = cv2.dilate(mask, np.ones((3, 3), np.uint8), iterations=3)
        feathered = cv2.GaussianBlur(mask, (11, 11), 0)
        self.checkpoint.save_image(
            feathered,
            tmp_prefix="paint-blur-mask",
            ckpt_name="blur-mask",
        )
        weight = feathered / 255
        inverse_weight = 1 - weight
        # (channel index, lower bound, upper bound) for L, a and b.
        channel_bounds = (
            (0, LAB_MIN_L, LAB_MAX_L),
            (1, LAB_MIN_A, LAB_MAX_A),
            (2, LAB_MIN_B, LAB_MAX_B),
        )

        blended = []
        for lab in intermediates:
            for channel, low, high in channel_bounds:
                mixed = lab[:, :, channel] * weight + original[:, :, channel] * inverse_weight
                lab[:, :, channel] = np.clip(mixed, low, high)
            blended.append(lab)
        self.intermediate = blended

        for lab in blended:
            self.checkpoint.save_image(
                lab,
                tmp_prefix="paint-overlay-mask",
                ckpt_name="overlay-mask",
                convert_from_lab=True,
            )

    def apply_postprocessing(self, surface=None, overmask=True):
        """Run the post-model correction chain for a surface.

        The steps run only when ``surface.should_apply_postprocessing()`` is
        truthy, in this order:

        1. chroma correction;
        2. lighting correction;
        3. optionally, the feathered overlay onto the original image.

        Args:
            surface: Surface to process; defaults to ``self.surface``.
            overmask: When truthy, finish with ``apply_overlay``.
        """
        target = self._get_surface(surface)
        if target.should_apply_postprocessing():
            self.apply_color_correction(surface=target)
            self.apply_lighting_correction(surface=target)
            if overmask:
                self.apply_overlay(surface=target)

    def __call__(self, color, batch_size=1, seed=None, surface=None):
        """Delegate to ``apply_paint`` with the given color, batch size and seed.

        NOTE(review): ``surface`` is resolved via ``_get_surface`` but is not
        forwarded to ``apply_paint``. Confirm whether that is intentional.
        """
        self._get_surface(surface)
        return self.apply_paint(color, seed=seed, batch_size=batch_size)

    def _get_intermediates(self, surface=None):
        """Return ``self.intermediate`` as a list, wrapping a single value.

        ``surface`` is resolved through ``_get_surface``, but the result is not
        used to select the intermediates.
        """
        self._get_surface(surface)
        current = self.intermediate
        return current if isinstance(current, list) else [current]

    def _get_ref_intermediates(self):
        """Return ``self.ref_intermediate`` as a list, wrapping a single value."""
        refs = self.ref_intermediate
        if not isinstance(refs, list):
            refs = [refs]
        return refs

    def _get_surface(self, surface):
        """Return ``surface`` when given, otherwise fall back to ``self.surface``."""
        if surface is not None:
            return surface
        return self.surface

    def _get_base_pixel(self):
        """Return the stored ``base_pixel`` (the target passed to ``adjust_chroma``)."""
        pixel = self.base_pixel
        return pixel
