import cv2
import uuid
from skimage import color
from skimage.util import img_as_float
from PIL import Image, ImageFilter
import numpy as np
import cv2
import torch

# Debug switch: when True, align_images_phase_correlation prints the detected
# shift and saves the warped image as a PNG in the current working directory.
DEBUG = False

def pil_to_tensor(image: Image.Image) -> torch.Tensor:
    """Convert a PIL image into a batched float tensor.

    Args:
        image (Image.Image): Input image; converted to RGB when it is in any other mode

    Returns:
        torch.Tensor: float32 tensor laid out as [1, H, W, C] with values scaled to [0, 1]
    """
    rgb_image = image if image.mode == "RGB" else image.convert("RGB")
    # Pixel data as float32, divided by 255 so values land in [0, 1]; layout [H, W, C]
    pixels = np.array(rgb_image).astype(np.float32) / 255.0
    # Prepend the batch axis: [1, H, W, C]
    return torch.from_numpy(pixels).unsqueeze(0)

def parse_json(text: str) -> str:
    """Extract JSON array from text.

    The returned substring runs from the first '[' to the last ']' in the
    text, both inclusive. The content in between is not validated as JSON.

    Args:
        text (str): Input text containing JSON array

    Returns:
        str: Extracted JSON array as string

    Raises:
        ValueError: If no JSON array is found in the text, i.e. a bracket is
            missing or the last ']' appears before the first '['
    """
    start = text.find('[')
    end = text.rfind(']')

    # `end < start` covers both a missing ']' (end == -1) and a closing bracket
    # that precedes the opening one, which would otherwise yield an empty slice.
    if start == -1 or end < start:
        raise ValueError("No JSON array found in the text")

    return text[start:end + 1]

def create_color_range_mask(original_image: Image.Image,
                            modified_image: Image.Image,
                            target_color: str = 'red') -> Image.Image:
    """Create mask highlighting areas where color changes occurred.

    For ``target_color='red'`` (case-insensitive) the mask marks pixels that
    fall inside the red HSV ranges in the modified image but not in the
    original one. For any other value the mask marks pixels whose summed
    absolute RGB difference exceeds 100.

    The raw mask is cleaned with a 5x5 morphological close followed by an
    open, and only the 7 largest external contours (by area) are kept, filled.

    Args:
        original_image (Image.Image): The original image
        modified_image (Image.Image): The modified image to compare against;
            resized to the original's size when the sizes differ
        target_color (str, optional): Color to detect. Defaults to 'red'.

    Returns:
        Image.Image: Mask image highlighting areas of color change
    """
    # Normalize both inputs to 3-channel RGB so the HSV conversion and the
    # per-channel difference below work for grayscale / palette / RGBA inputs
    # and so both arrays are guaranteed to have the same channel count.
    if original_image.mode != 'RGB':
        original_image = original_image.convert('RGB')
    if modified_image.mode != 'RGB':
        modified_image = modified_image.convert('RGB')
    if modified_image.size != original_image.size:
        modified_image = modified_image.resize(original_image.size)

    orig_array = np.array(original_image)
    mod_array = np.array(modified_image)

    if target_color.lower() == 'red':
        # HSV is only needed for the red detection path.
        orig_hsv = cv2.cvtColor(orig_array, cv2.COLOR_RGB2HSV)
        mod_hsv = cv2.cvtColor(mod_array, cv2.COLOR_RGB2HSV)

        # Red sits at both ends of the hue axis, hence two separate bands.
        lower_red1 = np.array([0, 70, 50])
        upper_red1 = np.array([10, 255, 255])
        lower_red2 = np.array([170, 70, 50])
        upper_red2 = np.array([180, 255, 255])

        mod_red_mask = cv2.bitwise_or(
            cv2.inRange(mod_hsv, lower_red1, upper_red1),
            cv2.inRange(mod_hsv, lower_red2, upper_red2),
        )
        orig_red_mask = cv2.bitwise_or(
            cv2.inRange(orig_hsv, lower_red1, upper_red1),
            cv2.inRange(orig_hsv, lower_red2, upper_red2),
        )

        # Keep only red that is new in the modified image.
        change_mask = cv2.bitwise_and(mod_red_mask, cv2.bitwise_not(orig_red_mask))
    else:
        # Generic fallback: summed absolute per-channel difference, thresholded.
        diff = np.abs(mod_array.astype(int) - orig_array.astype(int))
        diff_sum = np.sum(diff, axis=2)
        change_mask = (diff_sum > 100).astype(np.uint8) * 255

    # Close small holes first, then open to drop isolated specks.
    kernel = np.ones((5, 5), np.uint8)
    cleaned_mask = cv2.morphologyEx(change_mask, cv2.MORPH_CLOSE, kernel)
    cleaned_mask = cv2.morphologyEx(cleaned_mask, cv2.MORPH_OPEN, kernel)

    contours, _ = cv2.findContours(cleaned_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return Image.fromarray(cleaned_mask)

    # Redraw only the largest regions, filled (thickness -1).
    largest = sorted(contours, key=cv2.contourArea, reverse=True)[:7]
    final_mask = np.zeros_like(cleaned_mask)
    for contour in largest:
        cv2.drawContours(final_mask, [contour], -1, 255, -1)

    return Image.fromarray(final_mask)

def create_perceptual_difference_mask(
    original_image: Image.Image,
    modified_image: Image.Image,
    delta_e_threshold: float = 25.0, # Tune this threshold
    use_morphology: bool = True,
    kernel_size: int = 5) -> Image.Image:
    """Build a binary mask of pixels that differ perceptually between two images.

    Both images are compared in CIELAB space using the CIEDE2000 color
    difference; pixels whose difference exceeds the threshold become 255,
    all others 0.

    Args:
        original_image (Image.Image): The original image
        modified_image (Image.Image): The modified image to compare against;
            resized to the original's size when the sizes differ
        delta_e_threshold (float, optional): Difference above which a pixel is marked. Defaults to 25.0.
        use_morphology (bool, optional): Whether to clean the mask with an open followed by a close. Defaults to True.
        kernel_size (int, optional): Diameter of the elliptical kernel; cleanup is skipped unless it is positive. Defaults to 5.

    Returns:
        Image.Image: Single-channel ('L') mask highlighting perceptual differences
    """
    reference = original_image if original_image.mode == 'RGB' else original_image.convert('RGB')
    candidate = modified_image if modified_image.mode == 'RGB' else modified_image.convert('RGB')
    if candidate.size != reference.size:
        candidate = candidate.resize(reference.size)

    # RGB -> float -> Lab for each image
    lab_reference = color.rgb2lab(img_as_float(np.array(reference)))
    lab_candidate = color.rgb2lab(img_as_float(np.array(candidate)))

    exceeds_threshold = color.deltaE_ciede2000(lab_reference, lab_candidate) > delta_e_threshold
    binary_mask = exceeds_threshold.astype(np.uint8) * 255

    if use_morphology and kernel_size > 0:
        ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
        # Open first, then close, one iteration each
        for operation in (cv2.MORPH_OPEN, cv2.MORPH_CLOSE):
            binary_mask = cv2.morphologyEx(binary_mask, operation, ellipse, iterations=1)

    return Image.fromarray(binary_mask, mode='L')

def blur_mask(mask_img: Image.Image, radius: int = 10) -> Image.Image:
    """Return a Gaussian-blurred copy of the mask.

    Args:
        mask_img (Image.Image): The mask image to blur
        radius (int, optional): Radius passed to the Gaussian blur filter. Defaults to 10.

    Returns:
        Image.Image: Blurred mask image
    """
    gaussian = ImageFilter.GaussianBlur(radius=radius)
    return mask_img.filter(gaussian)

def convert_mask_to_rgba(mask_img: Image.Image) -> Image.Image:
    """Turn a mask into an RGBA image whose alpha channel is the mask itself.

    Args:
        mask_img (Image.Image): The mask image to convert; converted to 'L' when needed

    Returns:
        Image.Image: RGBA image where black areas are transparent (alpha=0)
                    and white areas are opaque (alpha=255); every non-black
                    pixel gets white color channels
    """
    gray_mask = mask_img if mask_img.mode == 'L' else mask_img.convert('L')
    alpha = np.array(gray_mask)

    # Color channels: 255 wherever the mask is non-zero, 0 elsewhere
    white = np.where(alpha > 0, 255, 0).astype(np.uint8)
    # Stack R, G, B and the original mask values as alpha into [H, W, 4]
    rgba = np.dstack((white, white, white, alpha))

    return Image.fromarray(rgba, mode='RGBA')

def expand_mask(mask_img: Image.Image, radius: int = 10) -> Image.Image:
    """Grow the mask outward by dilating it with a square kernel.

    Args:
        mask_img (Image.Image): The mask image to expand; converted to 'L' when needed
        radius (int, optional): Number of pixels to expand in each direction. Defaults to 10.

    Returns:
        Image.Image: Expanded mask image in 'L' mode
    """
    gray_mask = mask_img if mask_img.mode == 'L' else mask_img.convert('L')

    # A (2r + 1) square kernel reaches `radius` pixels on every side of a pixel
    side = 2 * radius + 1
    square_kernel = np.ones((side, side), np.uint8)
    grown = cv2.dilate(np.array(gray_mask), square_kernel, iterations=1)

    return Image.fromarray(grown, mode='L')

def apply_soft_overlay(
    image_path: str,
    mask_path: str,
    overlay_color: tuple = (0, 128, 255, 0),  # R, G, B, A
    opacity: float = 0.5) -> Image.Image:
    """Tint the masked region of an image with a translucent color.

    The mask's alpha channel (scaled to [0, 1]) multiplied by ``opacity`` is
    the per-pixel blend weight between the image and the overlay color. Only
    the first three components of ``overlay_color`` are used, and the output
    alpha channel is forced to 255.

    Args:
        image_path (str): Path to the original image
        mask_path (str): Path to the mask image; resized to the image's size
        overlay_color (tuple, optional): RGBA color for overlay. Defaults to (0, 128, 255, 0).
        opacity (float, optional): Opacity of the overlay. Defaults to 0.5.

    Returns:
        Image.Image: Image with overlay applied to masked regions
    """
    base = Image.open(image_path).convert("RGBA")
    overlay_mask = Image.open(mask_path).convert("RGBA").resize(base.size, Image.LANCZOS)

    pixels = np.array(base).astype(np.float32)
    mask_pixels = np.array(overlay_mask).astype(np.float32)

    # Per-pixel weight of the overlay color; the image keeps the remainder
    weight = (mask_pixels[:, :, 3] / 255.0) * opacity
    keep = 1 - weight

    for channel in range(3):  # R, G, B only
        pixels[:, :, channel] = pixels[:, :, channel] * keep + overlay_color[channel] * weight
    pixels[:, :, 3] = 255

    return Image.fromarray(np.uint8(pixels))

def bitvise_or_list_mask(list_mask):
    """Perform bitwise OR operation on a list of masks.

    All masks are OR-ed together element-wise and the result is binarized:
    every value greater than zero becomes 255, everything else 0.

    Args:
        list_mask (list): List of mask images (anything ``np.array`` can
            convert); must contain at least one mask

    Returns:
        Image.Image: Combined mask image

    Raises:
        ValueError: If ``list_mask`` is empty
    """
    if len(list_mask) == 0:
        raise ValueError("list_mask must contain at least one mask")

    combined = np.array(list_mask[0])
    for mask in list_mask[1:]:
        combined = np.bitwise_or(combined, np.array(mask))

    # Binarize via a comparison so boolean masks also end up as 0/255
    # rather than 0/1 after the uint8 conversion.
    binary = (combined > 0).astype(np.uint8) * 255
    return Image.fromarray(binary)
    
def tmpname(prefix="tmpfile", suffix=".tmp", directory="/tmp"):
    """Generate a temporary filename with random component.

    Only the path string is built; no file is created. The random component
    is the first 8 hex characters of a UUID4.

    Args:
        prefix (str, optional): Prefix for the filename. Defaults to "tmpfile".
        suffix (str, optional): Suffix for the filename. Defaults to ".tmp".
        directory (str, optional): Directory the path points into. Trailing
            slashes are ignored. Defaults to "/tmp".

    Returns:
        str: Path to temporary file
    """
    # NOTE(review): the name is not reserved on disk, so callers that need a
    # race-free temporary file should use tempfile.mkstemp instead.
    token = uuid.uuid4().hex[:8]
    base_dir = directory.rstrip("/")
    return f"{base_dir}/{prefix}{token}{suffix}"

def check_image_sizes(**kwargs):
    """Verify that all images are the same size and SD compatible.

    The ``image`` entry is the reference. Every other entry, except
    ``style_image`` and entries whose value is None, must be readable, have
    dimensions that are multiples of 64, and match the reference size.

    NOTE(review): the multiple-of-64 check is applied to the additional
    images only; the reference image is checked indirectly, and only when at
    least one other image is supplied. Confirm whether that is intended.

    Args:
        **kwargs: Dictionary of image paths to check; must include ``image``

    Returns:
        tuple: Width and height of the images

    Raises:
        Exception: If images are incompatible or of different sizes, or if
            the ``image`` entry is missing
    """
    # Check for the key itself: a call with other names but no "image" would
    # otherwise fail below with a bare KeyError.
    if "image" not in kwargs:
        raise Exception("image required")

    reference = cv2.imread(kwargs["image"])
    if reference is None:
        raise Exception("bad image format: image")
    height, width = reference.shape[:2]

    for name, img in kwargs.items():
        # "image" was validated above; "style_image" can be any size and
        # need not match image; None means the input was not supplied.
        if name in ("image", "style_image") or img is None:
            continue
        cvimg = cv2.imread(img)
        if cvimg is None:
            raise Exception(f"bad image format: {name}")
        h, w = cvimg.shape[:2]
        if h % 64 != 0 or w % 64 != 0:
            raise Exception(f"image dimensions unsupported: {w}x{h}")
        if h != height or w != width:
            raise Exception(f"image dimensions differ: {width}x{height} {w}x{h}")
    return (width, height)

def align_images_phase_correlation(img_ref_pil: Image.Image,
                                img_to_align_pil: Image.Image,
                                max_shift: int = 15
                                ) -> "tuple[Image.Image, tuple[float, float]]":
    """Align img_to_align_pil to img_ref_pil using a translation from phase correlation.

    The shift is estimated on grayscale versions of both images and then
    applied, negated, to the full image to be aligned.

    Args:
        img_ref_pil (Image.Image): Reference image
        img_to_align_pil (Image.Image): Image to translate onto the reference;
            resized to the reference size first when the sizes differ
        max_shift (int, optional): Maximum absolute shift per axis, in pixels;
            larger detected shifts are clipped. Pass None to disable the cap.
            Defaults to 15.

    Returns:
        tuple: ``(aligned_image, (dx, dy))`` where ``aligned_image`` is the
            aligned version of ``img_to_align_pil`` and ``(dx, dy)`` is the
            detected (possibly clipped) shift. When both components are below
            0.5 px the image is returned without warping.
    """
    # Phase correlation needs equally sized inputs. Resize the image itself
    # (not just its grayscale copy) so the estimated shift is applied in the
    # same pixel grid it was measured in and the result matches the reference.
    if img_to_align_pil.size != img_ref_pil.size:
        print("Warning (Align): Image shapes differ. Resizing to_align to ref.")
        img_to_align_pil = img_to_align_pil.resize(img_ref_pil.size)

    img_ref_gray_np = np.array(img_ref_pil.convert('L'), dtype=np.float32)
    img_to_align_gray_np = np.array(img_to_align_pil.convert('L'), dtype=np.float32)

    # OpenCV's phaseCorrelate returns ((dx, dy), response)
    shift, _ = cv2.phaseCorrelate(img_ref_gray_np, img_to_align_gray_np)
    dx, dy = shift[0], shift[1]

    # Cap the shift if max_shift is defined
    if max_shift is not None:
        dx = np.clip(dx, -max_shift, max_shift)
        dy = np.clip(dy, -max_shift, max_shift)

    if DEBUG:
        print(f"[Align] Detected shift (dx, dy): ({dx:.2f}, {dy:.2f})")

    # If shift is negligible (under half a pixel on both axes), skip the warp
    if abs(dx) < 0.5 and abs(dy) < 0.5:
        if DEBUG:
            print("[Align] Negligible shift. Returning original image to align.")
        return img_to_align_pil, (dx, dy)

    # Translate by -dx, -dy to move to_align back onto ref
    h, w = img_to_align_gray_np.shape
    M = np.float32([[1, 0, -dx], [0, 1, -dy]])

    # Warp the image in its own mode (not the grayscale copy); uncovered
    # border pixels are filled by replicating the nearest edge pixel.
    img_to_align_np = np.array(img_to_align_pil)
    aligned_img_np = cv2.warpAffine(img_to_align_np, M, (w, h),
                                    borderMode=cv2.BORDER_REPLICATE)

    aligned_img_pil = Image.fromarray(aligned_img_np)
    if DEBUG:
        print("[Align] Image warped successfully.")
        aligned_img_pil.save('dewraped_image_{}_{}.png'.format(dx, dy))  # Save for debugging
    return aligned_img_pil, (dx, dy)
