import cv2
import torch
import numpy as np

from fast_pytorch_kmeans import KMeans
from fast_pytorch_kmeans.init_methods import _kpp
from PIL import Image
from skimage import measure, morphology

# Per-channel (min, max) bounds for LAB values. Within this part of the file
# only the lightness bounds are referenced (to clip adjusted L values).
LAB_MIN_L, LAB_MAX_L = 0, 100
LAB_MIN_A, LAB_MAX_A = -127, 127
LAB_MIN_B, LAB_MAX_B = -127, 127


def adjust_chroma(img_lab, lab_color, paint_mask, color_drift_p=95):
    """Snap the chroma of painted pixels that are already close to ``lab_color``.

    Each pixel gets a distance to the target color: the absolute lightness
    difference (against the integer-truncated target lightness) plus the
    Euclidean distance over the remaining channels. The cutoff is the
    ``color_drift_p``-th percentile of that distance over the whole image,
    not only over the painted area.

    Args:
        img_lab: Image array indexed as ``[row, col, channel]`` with lightness
            in channel 0. Modified in place.
        lab_color: Array holding the target color; must support ``reshape``.
        paint_mask: 2-D array; values ``> 0`` mark pixels eligible for change.
        color_drift_p: Percentile (0-100) of the distance used as the cutoff.

    Returns:
        ``img_lab`` itself, with channels ``1:`` of every eligible pixel whose
        distance is strictly below the cutoff replaced by ``lab_color[1:]``.
    """
    target_lightness = int(lab_color[0])
    target_chroma = lab_color[1:]

    lightness_gap = np.abs(img_lab[:, :, 0] - target_lightness)
    chroma_gap = np.linalg.norm(img_lab[:, :, 1:] - target_chroma.reshape(1, 1, -1), axis=2)
    distance = lightness_gap + chroma_gap

    cutoff = np.percentile(distance, color_drift_p)
    snap = (paint_mask > 0) & (distance < cutoff)
    img_lab[snap, 1:] = target_chroma
    return img_lab


def adjust_lighting(
    src_image_lab,
    dst_image_lab,
    mask_x,
    mask_y,
    dst_base_max_offset=5,
    dst_base_max_lightness=90,
    dst_base_min_lightness=10,
    dst_max_lightness=110,
    dst_min_lightness=-10,
    lightness_baseline_p=70,
    adjust_near_white=True,
    near_white_lower_p=0,
    near_white_upper_p=100,
    near_white_max_lighting_change=30,
):
    """Re-light the selected destination pixels with the source's lightness variation.

    For the pixels at ``(mask_x, mask_y)`` the source lightness is expressed as
    an offset from its ``lightness_baseline_p``-th percentile. Those offsets
    are added to a single base lightness derived from the median destination
    lightness, clipped to ``[LAB_MIN_L, LAB_MAX_L]`` and written to channel 0
    of a copy of the destination.

    Args:
        src_image_lab: Source image, indexed ``[row, col, channel]`` with
            lightness in channel 0.
        dst_image_lab: Destination image with the same indexing; not modified.
        mask_x: Index array for the first axis of the selected pixels.
        mask_y: Index array for the second axis of the selected pixels.
        dst_base_max_offset: Largest shift applied to the base lightness when
            the offsets would exceed the ``dst_*_lightness`` limits.
        dst_base_max_lightness: Upper clamp for the final base lightness.
        dst_base_min_lightness: Lower clamp for the final base lightness.
        dst_max_lightness: Brightness limit used to decide whether the base
            lightness is shifted downwards.
        dst_min_lightness: Darkness limit used to decide whether the base
            lightness is shifted upwards.
        lightness_baseline_p: Percentile of the source lightness used as the
            zero point for the offsets.
        adjust_near_white: Whether to rescale the offsets when the selected
            destination pixels are classified as near white.
        near_white_lower_p: Lower percentile of the offsets used for rescaling.
        near_white_upper_p: Upper percentile of the offsets used for rescaling.
        near_white_max_lighting_change: Value the larger absolute percentile
            offset is mapped to when rescaling.

    Returns:
        A copy of ``dst_image_lab`` with adjusted lightness in the selection.
        If the selection is empty, the copy is returned unchanged.
    """
    dst_image_lab = dst_image_lab.copy()
    src_lightness = src_image_lab[mask_x, mask_y, 0]
    if np.size(src_lightness) == 0:
        # Percentile/median/max are undefined on an empty selection (np.max
        # raises ValueError), and there is nothing to adjust anyway.
        return dst_image_lab

    src_baseline_lightness = np.percentile(src_lightness, lightness_baseline_p)
    src_lightness_diff = src_lightness - src_baseline_lightness
    dst_lightness = dst_image_lab[mask_x, mask_y, 0]
    dst_base_lightness = np.median(dst_lightness)
    max_src_diff = np.max(src_lightness_diff)
    min_src_diff = np.min(src_lightness_diff)

    # adjust max lighting change for near white colors
    if adjust_near_white and is_near_white(dst_image_lab[mask_x, mask_y]):
        lower = np.percentile(src_lightness_diff, near_white_lower_p)
        upper = np.percentile(src_lightness_diff, near_white_upper_p)
        max_abs_diff = np.max(np.abs([lower, upper]))
        # Zero offsets are deliberately left untouched so that a zero
        # ``max_abs_diff`` cannot turn them into NaN.
        nonzero = src_lightness_diff != 0
        src_lightness_diff[nonzero] = near_white_max_lighting_change * src_lightness_diff[nonzero] / max_abs_diff
        # NOTE(review): ``lower`` is stored as the max and ``upper`` as the
        # min, which looks swapped, and both are the pre-rescale values.
        # Behavior is preserved as found; confirm the intent before changing.
        max_src_diff = lower
        min_src_diff = upper

    # adjust lighting for extreme colors
    if dst_base_lightness > dst_max_lightness - max_src_diff:
        # if selection is still too bright, adjust downwards but not more than dst_base_max_offset
        dst_base_lightness = max(dst_base_lightness - dst_base_max_offset, dst_max_lightness - max_src_diff)
    elif dst_base_lightness + min_src_diff < dst_min_lightness:
        # if selection is still too dark, adjust upwards but not more than dst_base_max_offset
        dst_base_lightness = min(dst_base_lightness + dst_base_max_offset, dst_min_lightness - min_src_diff)
    dst_base_lightness = np.clip(dst_base_lightness, dst_base_min_lightness, dst_base_max_lightness)
    dst_image_lab[mask_x, mask_y, 0] = np.clip(dst_base_lightness + src_lightness_diff, LAB_MIN_L, LAB_MAX_L)
    return dst_image_lab


def adjust_lighting_contrast(
    painted_image_lab,
    paint_mask_x,
    paint_mask_y,
    src_image_lab=None,
    chroma_dist_threshold=0.0,
    perceptual_dist_threshold=1.0,
    perceptual_dist_threshold_l=1.0,
    perceptual_dist_threshold_l_easy=1.0,
    perceptual_min_area_threshold=0.0,
    perceptual_min_area_threshold_easy=0.0,
    lighting_baseline_p=50,
    near_white_baseline_p=70,
    write=True,
    exclusion_mask=None,
):
    """Detect a two-tone painted area and optionally pull both tones together.

    The selected pixels are split into two clusters. The cluster whose
    centroid compares greater (component by component, first channel first)
    is treated as the light one. The area is flagged as multicolored when the
    cluster separation score and the relative size of the largest connected
    dark region exceed the given thresholds. If ``src_image_lab`` is provided,
    the separation score is taken from the source pixels instead, and a
    sufficiently large chroma difference between the source centroids counts
    as an additional trigger.

    When the area is multicolored and ``write`` is true, the lightness of both
    clusters is shifted towards a common base lightness (a percentile between
    the two centroid lightness values) and clipped to
    ``[LAB_MIN_L, LAB_MAX_L]``. This modifies ``painted_image_lab`` in place.

    Args:
        painted_image_lab: Image indexed ``[row, col, channel]`` with lightness
            in channel 0.
        paint_mask_x: First-axis indices of the selected pixels.
        paint_mask_y: Second-axis indices of the selected pixels.
        src_image_lab: Optional source image used for the separation score.
        chroma_dist_threshold: Minimum per-channel chroma difference between
            source centroids for the chroma trigger.
        perceptual_dist_threshold: Separation score required together with the
            chroma trigger.
        perceptual_dist_threshold_l: Separation score for the strict check.
        perceptual_dist_threshold_l_easy: Separation score for the easy check.
        perceptual_min_area_threshold: Area ratio required by the strict check.
        perceptual_min_area_threshold_easy: Area ratio required by the easy
            check and by the chroma trigger.
        lighting_baseline_p: Percentile used for the base lightness.
        near_white_baseline_p: Percentile used instead when the light cluster
            is classified as near white.
        write: Whether to apply the lightness correction.
        exclusion_mask: Optional 2-D array; selected pixels where it is
            non-zero are ignored unless that would leave no pixels at all.

    Returns:
        Tuple ``(painted_image_lab, light_centroid, dark_centroid,
        multicolored)`` with both centroids converted to NumPy arrays.
    """
    coords = np.stack([paint_mask_x.reshape(-1), paint_mask_y.reshape(-1)], axis=1)
    if exclusion_mask is not None:
        kept = coords[exclusion_mask[coords[:, 0], coords[:, 1]] == 0]
        if len(kept) > 0:
            coords = kept
    rows, cols = coords[:, 0], coords[:, 1]

    painted_lab_values = painted_image_lab[rows, cols]
    centroids, labels, s_dist = get_lighting_centroids(painted_lab_values)
    light_label = 0 if tuple(centroids[0]) > tuple(centroids[1]) else 1
    dark_label = 1 - light_label

    # Rasterize the dark cluster to measure its largest connected region.
    dark_coords = coords[labels == dark_label]
    dark_mask = np.zeros_like(painted_image_lab[:, :, 0])
    dark_mask[dark_coords[:, 0], dark_coords[:, 1]] = 1
    regions = measure.regionprops(measure.label(dark_mask, connectivity=2))
    area_threshold = 0
    if len(regions) > 0:
        largest_area = max(region.area for region in regions)
        ratio = largest_area / len(painted_lab_values)
        area_threshold = min(ratio, 1 - ratio)

    chroma_split = False
    if src_image_lab is not None:
        # use source image for centroid distance if provided
        src_lab_values = src_image_lab[rows, cols]
        src_centroids, _, s_dist = get_lighting_centroids(src_lab_values)
        src_chroma_diff = (src_centroids[0][1:] - src_centroids[1][1:]).abs()
        chroma_split = (src_chroma_diff > chroma_dist_threshold).any() and s_dist > perceptual_dist_threshold

    multicolored = False
    if chroma_split or s_dist > perceptual_dist_threshold_l_easy:
        multicolored = area_threshold > perceptual_min_area_threshold_easy
    elif s_dist > perceptual_dist_threshold_l:
        multicolored = area_threshold > perceptual_min_area_threshold

    if multicolored and write:
        light_coords = coords[labels == light_label]
        near_white = is_near_white(painted_image_lab[light_coords[:, 0], light_coords[:, 1]])
        p = near_white_baseline_p if near_white else lighting_baseline_p
        base_lightness = np.percentile([centroids[light_label][0], centroids[dark_label][0]], p)
        dark_diff = float(base_lightness - centroids[dark_label][0])
        light_diff = float(centroids[light_label][0] - base_lightness)
        # Raise the dark cluster and lower the light cluster by their
        # respective distance to the base lightness.
        for group, shift in ((dark_coords, dark_diff), (light_coords, -light_diff)):
            group_rows, group_cols = group[:, 0], group[:, 1]
            painted_image_lab[group_rows, group_cols, 0] = np.clip(
                painted_image_lab[group_rows, group_cols, 0] + shift,
                LAB_MIN_L,
                LAB_MAX_L,
            )

    light_centroid = centroids[light_label].numpy()
    dark_centroid = centroids[dark_label].numpy()
    return painted_image_lab, light_centroid, dark_centroid, multicolored


def get_lighting_centroids(pixels):
    """Split pixels into two k-means clusters and score how well they separate.

    Args:
        pixels: Array-like of pixel rows, one row per pixel; converted with
            ``torch.tensor``.

    Returns:
        Tuple ``(centroids, labels, score)``:

        * ``centroids`` - tensor with two rows, one per cluster.
        * ``labels`` - per-pixel cluster index (0 or 1).
        * ``score`` - silhouette-style coefficient: the mean distance to the
          other cluster's centroid minus the mean distance to the own
          centroid, divided by the larger of the two.

        When the input holds at most one distinct pixel value, no clustering
        is run: the first pixel value is returned for both centroids, all
        labels are zero (as a float tensor) and the score is ``0``.
    """
    X = torch.tensor(pixels)
    n_clusters = 2
    # if only 1 unique pixel, exit early
    if len(torch.unique(X, dim=0)) <= 1:
        # A single-row input would otherwise yield only one centroid, and
        # callers index both centroids[0] and centroids[1]; duplicate the row.
        degenerate_centroids = X[[0, 0]] if len(X) == 1 else X[:2]
        return degenerate_centroids, torch.zeros(len(X)), 0

    initial_centroids = _kpp(X, n_clusters)
    kmeans = KMeans(n_clusters=n_clusters, mode="euclidean")
    labels = kmeans.fit_predict(X, centroids=initial_centroids)
    other_labels = 1 - labels
    # For every pixel, pair its own centroid (column 0) with the other
    # cluster's centroid (column 1) and measure the distance to each.
    current_centroids = kmeans.centroids[labels].unsqueeze(1)
    other_centroids = kmeans.centroids[other_labels].unsqueeze(1)
    kmeans_centroids = torch.concatenate([current_centroids, other_centroids], axis=1)
    dists = ((kmeans_centroids - X.unsqueeze(1)) ** 2).sum(axis=2).sqrt()
    intra_mean_dist = dists[:, 0].mean()
    inter_mean_dist = dists[:, 1].mean()
    silhouette_coef = (inter_mean_dist - intra_mean_dist) / max(inter_mean_dist, intra_mean_dist)
    return kmeans.centroids, labels, silhouette_coef


def is_near_white(pixels, near_white_a=10, near_white_b=10, near_white_l=85):
    """Tell whether a set of pixels is, by median, close to white.

    Args:
        pixels: 2-D array with one pixel per row; columns 0, 1 and 2 are read
            as lightness and the two chroma channels.
        near_white_a: Largest allowed median magnitude of column 1.
        near_white_b: Largest allowed median magnitude of column 2.
        near_white_l: Smallest required median magnitude of column 0.

    Returns:
        Truthy when the median lightness reaches ``near_white_l`` and both
        median chroma magnitudes stay within their limits.
    """
    magnitudes = np.abs(pixels)
    lightness_med, a_med, b_med = (np.median(magnitudes[:, channel]) for channel in range(3))
    return lightness_med >= near_white_l and a_med <= near_white_a and b_med <= near_white_b


def set_rgb(image_arr, mask, color, convert_to_lab=False):
    """Fill the masked pixels of an image with one color.

    Args:
        image_arr: Image array; modified in place.
        mask: Array whose values ``> 0`` select the pixels to overwrite.
        color: Value assigned to every selected pixel.
        convert_to_lab: When true, return the result of ``rgb_to_lab`` on the
            filled image instead of the image itself.

    Returns:
        The filled ``image_arr``, or its LAB conversion if requested.
    """
    selected = mask > 0
    image_arr[selected] = color
    return rgb_to_lab(image_arr) if convert_to_lab else image_arr


def lab_to_rgb(image_arr):
    """Convert a LAB image to an 8-bit RGB image.

    Args:
        image_arr: LAB image accepted by ``cv2.cvtColor`` with
            ``cv2.COLOR_LAB2RGB``. Presumably the float32 layout produced by
            ``rgb_to_lab`` - confirm against callers.

    Returns:
        ``uint8`` array holding the converted image scaled by 255.
    """
    rgb = cv2.cvtColor(image_arr, cv2.COLOR_LAB2RGB)
    # Round to the nearest level instead of truncating: a value such as
    # 99.9999 must map to 100, otherwise RGB -> LAB -> RGB round trips darken
    # by one level. Clipping keeps out-of-range values from wrapping around
    # in the uint8 cast.
    return np.clip(np.rint(255 * rgb), 0, 255).astype(np.uint8)


def rgb_to_lab(image_arr):
    """Convert an RGB image to LAB.

    The input is cast to float32 and divided by 255 before the OpenCV
    ``COLOR_RGB2LAB`` conversion.

    Args:
        image_arr: RGB image array; presumably 0-255 valued - confirm against
            callers.

    Returns:
        The float32 LAB image returned by ``cv2.cvtColor``.
    """
    normalized = image_arr.astype(np.float32) / 255
    return cv2.cvtColor(normalized, cv2.COLOR_RGB2LAB)


def get_edges_without_surface_details(canny_arr, mlsd_arr, paint_mask, kernel_size=1, iters=1):
    """Drop edges inside the painted area that are not backed by a detected line.

    Args:
        canny_arr: Edge image; not modified.
        mlsd_arr: Line image that is dilated to form the set of edges to keep.
        paint_mask: Array whose values ``> 0`` mark the painted area.
        kernel_size: Side length of the square dilation kernel.
        iters: Number of dilation iterations.

    Returns:
        Tuple ``(edges, surface_edge_mask)``: a copy of ``canny_arr`` zeroed
        at painted pixels outside the dilated lines, and the boolean mask of
        painted pixels covered by the dilated lines.
    """
    painted = paint_mask > 0
    kernel = np.ones((kernel_size, kernel_size), np.uint8)
    thickened_lines = cv2.dilate(mlsd_arr, kernel, iterations=iters)
    surface_edge_mask = painted & (thickened_lines > 0)

    canny_no_surface_arr = canny_arr.copy()
    canny_no_surface_arr[painted & ~surface_edge_mask] = 0
    return canny_no_surface_arr, surface_edge_mask


def remove_intra_surface_edges(edge_arr, mask, kernel_size=3, iters=2):
    """Erase edges that lie in the eroded interior of a mask.

    Args:
        edge_arr: Edge image; modified in place.
        mask: Mask image that is eroded to obtain its interior.
        kernel_size: Side length of the square erosion kernel.
        iters: Number of erosion iterations.

    Returns:
        Tuple ``(edge_arr, intra_mask)``: the edited edge image and the
        eroded mask.
    """
    kernel = np.ones((kernel_size, kernel_size), np.uint8)
    intra_mask = cv2.erode(mask, kernel, iterations=iters)
    interior = intra_mask > 0
    edge_arr[interior] = 0
    return edge_arr, intra_mask


def get_denoised_lighting_mask(
    canny_arr,
    paint_mask,
    noise_kernel_size=5,
    smoothing_kernel_size=11,
    noise_threshold=10000,
):
    """Build a mask of large, edge-free regions inside the painted area.

    Steps: invert the edge image, discard everything outside the painted area
    and near (dilated) edges, grow the remainder, smooth it with an elliptical
    morphological opening, restrict it to the painted area again and remove
    connected regions smaller than ``noise_threshold``.

    Args:
        canny_arr: Edge image; presumably uint8 with 255 on edges, since it is
            inverted as ``255 - canny_arr`` - confirm against callers.
        paint_mask: Array whose zero entries mark pixels outside the painted
            area.
        noise_kernel_size: Side length of the square kernel for both dilations.
        smoothing_kernel_size: Diameter of the elliptical opening kernel.
        noise_threshold: ``min_size`` passed to ``remove_small_objects``.

    Returns:
        ``uint8`` mask with 255 on the kept regions and 0 elsewhere.
    """
    outside_paint = paint_mask == 0
    noise_kernel = np.ones((noise_kernel_size, noise_kernel_size), np.uint8)

    edge_free = 255 - canny_arr
    edge_free[outside_paint] = 0
    near_edges = cv2.dilate(canny_arr, noise_kernel, iterations=2) > 0
    edge_free[near_edges] = 0

    edge_free = cv2.dilate(edge_free, noise_kernel, iterations=2)
    smoothing_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (smoothing_kernel_size, smoothing_kernel_size))
    edge_free = cv2.morphologyEx(edge_free, cv2.MORPH_OPEN, smoothing_kernel, iterations=3)
    # The dilation above can spill past the painted area, so mask it again.
    edge_free[outside_paint] = 0

    kept = morphology.remove_small_objects(edge_free > 0, min_size=noise_threshold, connectivity=2)
    return (255 * kept).astype(np.uint8)


def get_edge_mask(canny_arr, kernel_size=3, hole_threshold=500, iters=2):
    """Turn an edge image into a filled mask.

    The edges are dilated, holes smaller than ``hole_threshold`` are filled,
    and the result is eroded with the same kernel and iteration count.

    Args:
        canny_arr: Edge image.
        kernel_size: Side length of the square dilation/erosion kernel.
        hole_threshold: ``area_threshold`` passed to ``remove_small_holes``.
        iters: Number of iterations for both the dilation and the erosion.

    Returns:
        The eroded ``uint8`` mask (255 where set before erosion).
    """
    kernel = np.ones((kernel_size, kernel_size), np.uint8)
    dilated = cv2.dilate(canny_arr, kernel, iterations=iters)
    filled = morphology.remove_small_holes(dilated > 0, area_threshold=hole_threshold)
    filled_u8 = (255 * filled).astype(np.uint8)
    return cv2.erode(filled_u8, kernel, iterations=iters)


def get_prompt_details(location_type, room_type):
    """Return the scene description fragment for a location.

    Args:
        location_type: ``"interior"`` selects the room-based description; any
            other value selects the exterior one.
        room_type: Room name used for interiors.

    Returns:
        ``room_type`` for interiors (with a cabinet hint appended for
        kitchens), otherwise the fixed exterior description.
    """
    if location_type != "interior":
        return "house, white shutters"
    if room_type == "kitchen":
        return room_type + " with white cabinets"
    return room_type
