"""interior remodel functions"""

import reimage
import subproc
import subproc_util
import models
import pipelines
import tempfile
import numpy as np
from PIL import Image
from skimage import measure
from util import check_image_sizes
from checkpoints import OperationCheckpoint
from operations import (
    canny,
    depth,
    inpainting,
    controlnet,
)


def load_preprocessing_models():
    """Load the LaMa model and the "sdxl_controlnet" pipeline.

    Passed as the ``fn`` hook to ``subproc_util.subproc_launch`` by
    ``run_preprocess`` before ``preprocess`` runs in that subprocess.
    """
    models.load_lama_model()
    load_controlnet = pipelines.pipeline_get_loader("sdxl_controlnet")
    load_controlnet()


def load_inpaint_model():
    """Load the "sdxl_inpainting_controlnet" pipeline.

    Passed as the ``fn`` hook to ``subproc_util.subproc_launch`` by
    ``run_inpaint`` before ``inpaint`` runs in that subprocess.
    """
    load_inpainting = pipelines.pipeline_get_loader("sdxl_inpainting_controlnet")
    load_inpainting()


def _partition_region_masks(regions, height, width, threshold):
    """Split labelled mask regions into grouped small masks and individual large masks.

    Regions are visited in ascending order of area. A region whose area
    fraction (``area / (height * width)``) exceeds ``threshold`` gets its own
    mask in the large list. Smaller regions are packed greedily: each one is
    added to the most recently opened small group while the group's combined
    area fraction stays at or below ``threshold``; otherwise it opens a new
    group.

    Args:
        regions: iterable of objects exposing ``area`` and ``coords`` (as
            produced by ``skimage.measure.regionprops``).
        height: number of rows of the output masks.
        width: number of columns of the output masks.
        threshold: area fraction separating small from large regions.

    Returns:
        ``(small_masks, large_masks)``: two lists of float arrays of shape
        ``(height, width)`` holding 1 inside the region(s) and 0 elsewhere.
    """
    num_pixels = width * height
    small_masks = []
    small_areas = []  # running pixel count of each entry in small_masks
    large_masks = []
    for region in sorted(regions, key=lambda r: r.area):
        coords = tuple(np.array(region.coords).T)
        if region.area / num_pixels > threshold:
            mask = np.zeros((height, width))
            mask[coords] = 1
            large_masks.append(mask)
        elif small_areas and (small_areas[-1] + region.area) / num_pixels <= threshold:
            # Still fits in the current small group: merge it in.
            small_areas[-1] += region.area
            small_masks[-1][coords] = 1
        else:
            mask = np.zeros((height, width))
            mask[coords] = 1
            small_areas.append(region.area)
            small_masks.append(mask)
    return small_masks, large_masks


def preprocess(
    input_image=None,
    mask_image=None,
    seed=None,
    batch_size=None,
    lama_mask_threshold=0.03,
):
    """Remove masked objects with LaMa; for large regions also render an empty layout.

    The mask is split into 8-connected regions (``connectivity=2``). Regions
    covering more than ``lama_mask_threshold`` of the image are handled one
    by one as "large". Smaller regions are packed into groups whose combined
    area stays within the threshold. LaMa is applied sequentially, once per
    group or region, each pass feeding on the previous result.

    If there are no large regions, the LaMa result is saved as ``"image-0"``
    and the function returns. Otherwise the LaMa result is saved as
    ``"lama"``. The function then saves the union of the large regions as
    ``"large-mask"``. Finally, a depth map of the LaMa result drives a depth
    ControlNet render prompted for an empty, clean scene.

    Args:
        input_image: source image (anything ``Image.open`` accepts); also
            saved to the checkpoint as "input".
        mask_image: mask image; any non-zero pixel is treated as masked. It is
            assumed to match ``input_image`` in size (``remodel_remove``
            checks this via ``check_image_sizes``).
        seed: forwarded to ``controlnet``.
        batch_size: not used here; the ControlNet render always uses
            ``batch_size=1``.
        lama_mask_threshold: area fraction separating small from large regions.

    Returns:
        The operation checkpoint state (``ckpt.get_state()``).

    Raises:
        Exception: if ``input_image`` or ``mask_image`` is None.
    """
    if input_image is None:
        raise Exception("image required")
    if mask_image is None:
        raise Exception("mask required")
    lama = models.load_lama_model()
    ckpt = OperationCheckpoint()
    ckpt.save_result(input_image, name="input")
    ckpt.save_result(mask_image, name="mask")

    # Load the inputs. The context managers release file handles Pillow
    # opened itself once the converted copies have been made.
    with Image.open(input_image) as src:
        image = src.convert("RGB")
    with Image.open(mask_image) as src:
        mask = src.convert("L")
    width, height = image.size

    # Label the connected regions of the mask and bucket them by area.
    binary_mask = (np.array(mask) > 0).astype(np.uint8)
    labels_mask = measure.label(binary_mask, connectivity=2)
    regions = measure.regionprops(labels_mask)
    small_masks, large_masks = _partition_region_masks(
        regions, height, width, lama_mask_threshold
    )

    # Run LaMa once per mask, chaining each result into the next pass.
    for region_mask in small_masks + large_masks:
        region_mask = (255 * region_mask).astype(np.uint8)
        ckpt.save_image(region_mask, ckpt_name="lama-mask", tmp_prefix="remove-")
        image = lama(image, Image.fromarray(region_mask))
    lama_image_path = ckpt.save_image(
        np.array(image),
        ckpt_name="image-0" if not large_masks else "lama",
        tmp_prefix="remove-",
        is_extra=False,
        update_idx=False,
    )

    # Small regions only: the LaMa output is the final image, so skip depth.
    if not large_masks:
        return ckpt.get_state()

    # Use the LaMa output as the depth source for the layout render.
    models.load_depth_model()
    depth_image_path = depth(lama_image_path)
    combined_large_mask = np.zeros((height, width), dtype=np.uint8)
    for large_mask in large_masks:
        combined_large_mask[large_mask > 0] = 255
    ckpt.save_image(
        combined_large_mask,
        ckpt_name="large-mask",
        tmp_prefix="remove-",
        is_extra=False,
        update_idx=False,
    )
    results = controlnet(
        prompt="empty, nothing, clean, blank",
        negative_prompt="cluttered, dirty, objects, items, messy",
        batch_size=1,
        steps=25,
        guidance_scale=7.5,
        seed=seed,
        control_types=["depth-sdxl"],
        control_images_conditioned=[depth_image_path],
        control_scales=[0.8],
    )
    ckpt.merge(results)
    return ckpt.get_state()


def inpaint(
    input_image=None,
    layout_image=None,
    mask_image=None,
    seed=None,
    batch_size=None,
):
    """Inpaint the masked area of ``input_image`` guided by a layout image.

    Depth and Canny control maps are computed from ``layout_image``. They are
    passed to ``inpainting`` (depth scale 0.95, Canny scale 0.3) together with
    the input image and mask, using prompts that ask for an empty, clean scene.

    Args:
        input_image: image to inpaint, forwarded to ``inpainting``.
        layout_image: reference image from which the depth and Canny maps are
            derived; also saved to the checkpoint as "layout".
        mask_image: mask forwarded to ``inpainting``.
        seed: forwarded to ``inpainting``.
        batch_size: forwarded to ``inpainting``.

    Returns:
        The operation checkpoint state: the saved layout merged with the
        inpainting results.

    Raises:
        Exception: if ``input_image``, ``mask_image`` or ``layout_image`` is None.
    """
    if input_image is None:
        raise Exception("image required")
    if mask_image is None:
        raise Exception("mask required")
    # Both control maps below come from layout_image; fail fast with a clear
    # message rather than an opaque error from depth()/canny().
    if layout_image is None:
        raise Exception("layout required")
    ckpt = OperationCheckpoint()
    ckpt.save_result(layout_image, name="layout")
    models.load_depth_model()
    depth_image_path = depth(layout_image)
    canny_image_path = canny(layout_image)

    results = inpainting(
        prompt="empty, nothing, clean, blank",
        negative_prompt="cluttered, dirty, objects, items, messy",
        input_image=input_image,
        mask_image=mask_image,
        batch_size=batch_size,
        steps=25,
        guidance_scale=7.5,
        strength=0.65,
        seed=seed,
        overmask=True,
        control_types=["depth-sdxl", "canny-sdxl"],
        control_images_conditioned=[depth_image_path, canny_image_path],
        control_scales=[0.95, 0.3],
    )
    ckpt.merge(results)
    return ckpt.get_state()


def run_preprocess(
    job_id=None,
    input_image=None,
    mask_image=None,
    batch_size=None,
    seed=None,
):
    """Run ``preprocess`` inside the "lama" subprocess.

    The subprocess is launched with ``load_preprocessing_models`` as its
    loader hook. The job is then marked "rendering" and ``preprocess`` is
    invoked there with the given arguments.

    Returns:
        Whatever ``subproc.subproc_call`` returns. ``remodel_remove`` treats a
        ``BaseException`` instance as a failure and otherwise uses it as the
        checkpoint state.
    """
    subprocess_name = "lama"
    subproc_util.subproc_launch(subprocess_name, job_id, fn=load_preprocessing_models)
    call_args = dict(
        input_image=input_image,
        mask_image=mask_image,
        batch_size=batch_size,
        seed=seed,
    )
    reimage.update_job(job_id, "rendering")
    return subproc.subproc_call(subprocess_name, preprocess, args=call_args)


def run_inpaint(
    job_id=None,
    input_image=None,
    layout_image=None,
    mask_image=None,
    batch_size=None,
    seed=None,
):
    """Run ``inpaint`` inside the "sdxl_inpaint_controlnet" subprocess.

    The subprocess is launched with ``load_inpaint_model`` as its loader hook,
    then ``inpaint`` is invoked there with the given arguments.

    Returns:
        Whatever ``subproc.subproc_call`` returns. ``remodel_remove`` treats a
        ``BaseException`` instance as a failure and otherwise merges it into
        its checkpoint.
    """
    subprocess_name = "sdxl_inpaint_controlnet"
    subproc_util.subproc_launch(subprocess_name, job_id, fn=load_inpaint_model)
    call_args = dict(
        input_image=input_image,
        layout_image=layout_image,
        mask_image=mask_image,
        seed=seed,
        batch_size=batch_size,
    )
    return subproc.subproc_call(subprocess_name, inpaint, args=call_args)


def remodel_remove(
    job_id,
    image,
    mask_image,
    seed,
    batch_size,
):
    """Remove the masked objects from ``image``.

    Stage 1 runs ``preprocess`` (LaMa removal) in its subprocess. If that
    stage produced no "large-mask", its state is returned unchanged.
    Otherwise stage 2 runs ``inpaint`` over the large mask. It uses the LaMa
    output ("lama") as the input image and the rendered "image-0" as the
    layout reference.

    Args:
        job_id: job identifier, used for subprocess launch and status updates.
        image: source image.
        mask_image: mask marking the objects to remove.
        seed: forwarded to both stages.
        batch_size: forwarded to both stages.

    Returns:
        The stage-1 state when there is no large mask. Otherwise, a checkpoint
        state combining the stage-1 results (merged with ``prefix="extra"``)
        and the stage-2 results.

    Raises:
        Exception: if ``job_id``, ``image`` or ``mask_image`` is None.
        BaseException: re-raised when either stage returns an exception object.
    """
    required = (
        (job_id, "job_id is required"),
        (image, "image is required"),
        (mask_image, "mask image is required"),
    )
    for value, message in required:
        if value is None:
            raise Exception(message)

    check_image_sizes(image=image, mask_image=mask_image)

    checkpoint = OperationCheckpoint()

    # Stage 1: LaMa removal, plus a layout render when large regions exist.
    lama_state = run_preprocess(
        job_id=job_id,
        input_image=image,
        mask_image=mask_image,
        batch_size=batch_size,
        seed=seed,
    )
    if isinstance(lama_state, BaseException):
        raise lama_state
    checkpoint.merge(lama_state, prefix="extra")

    # Only small regions were masked: the LaMa output is the final result.
    if "large-mask" not in lama_state:
        return lama_state

    # Stage 2: repaint the large regions guided by the rendered layout.
    inpaint_state = run_inpaint(
        job_id=job_id,
        input_image=lama_state["lama"],
        layout_image=lama_state["image-0"],
        mask_image=lama_state["large-mask"],
        batch_size=batch_size,
        seed=seed,
    )
    if isinstance(inpaint_state, BaseException):
        raise inpaint_state
    checkpoint.merge(inpaint_state)
    return checkpoint.get_state()
