import cv2
import numpy as np
from skimage import exposure
import torch
from torchvision.utils import save_image
from torchvision.io import read_image, ImageReadMode
from typing import Any, Dict, Optional, Tuple
# Optional import for Delta E 2000 - install with: pip install colorspacious
try:
    from colorspacious import cspace_convert
    HAS_COLORSPACIOUS = True
except ImportError:
    HAS_COLORSPACIOUS = False
    print("Warning: colorspacious not installed. Delta E 2000 will not be available.")

class LABHistogramMatcher:
    @classmethod
    def INPUT_TYPES(cls) -> Dict[str, Any]:
        """
        Describe the inputs this ComfyUI node accepts.

        Returns:
            A dict with a "required" section (source image and mask, reference
            image, transfer method, per-channel blend weights, per-channel
            variance-reduction strengths and the metrics toggle) and an
            "optional" section holding the reference mask.
        """

        def float_slider(default: float, upper: float) -> Tuple[str, Dict[str, float]]:
            # All FLOAT widgets share min=0.0 and step=0.01; only the default
            # and the upper bound differ between them.
            return ("FLOAT", {"default": default, "min": 0.0, "max": upper, "step": 0.01})

        method_choices = ["Histogram Matching", "Reinhard Transfer"]

        required_inputs: Dict[str, Any] = {
            "source_image": ("IMAGE",),
            "source_mask": ("MASK",),
            "reference_image": ("IMAGE",),
            "method": (method_choices, {"default": "Reinhard Transfer"}),
            "l_blend": float_slider(0.0, 1.0),
            "a_blend": float_slider(1.0, 1.0),
            "b_blend": float_slider(1.0, 1.0),
            "a_var_reduction_strength": float_slider(1.0, 5.0),
            "b_var_reduction_strength": float_slider(1.0, 5.0),
            "calculate_color_difference": ("BOOLEAN", {"default": True}),
        }
        optional_inputs: Dict[str, Any] = {"reference_mask": ("MASK",)}

        return {"required": required_inputs, "optional": optional_inputs}

    # Output socket types, in the order `execute` returns them: the corrected
    # image batch first, then the formatted metrics report.
    RETURN_TYPES: Tuple[str, ...] = ("IMAGE", "STRING")
    # Display names for the two outputs above (same order).
    RETURN_NAMES: Tuple[str, ...] = ("result_image", "color_metrics")
    # Name of the method that does the work — presumably looked up by ComfyUI
    # when the node runs; confirm against the host's node API.
    FUNCTION: str = "execute"
    # Category label — presumably where the node is listed in ComfyUI's menu.
    CATEGORY: str = "ReImage AI"

    def tensor_to_numpy(self, tensor: torch.Tensor) -> np.ndarray:
        """Convert ComfyUI image tensor (H, W, C) to numpy array (H, W, C)."""
        numpy_array = tensor.cpu().numpy()
        numpy_array = (numpy_array * 255).clip(0, 255).astype(np.uint8)
        return numpy_array

    def mask_to_numpy(self, mask_tensor: torch.Tensor) -> np.ndarray:
        """Convert ComfyUI mask tensor (H, W) to numpy array (H, W)."""
        mask_np = mask_tensor.cpu().numpy()
        mask_np = (mask_np * 255).clip(0, 255).astype(np.uint8)
        return mask_np

    def numpy_to_tensor(self, numpy_array: np.ndarray) -> torch.Tensor:
        """Convert numpy array (H, W, C) to ComfyUI image tensor (H, W, C)."""
        tensor = torch.from_numpy(numpy_array.astype(np.float32) / 255.0)
        return tensor

    def rgb_to_lab(self, image_arr: np.ndarray) -> np.ndarray:
        """
        Convert a 0-255 RGB image (H, W, 3) to a float32 LAB image (H, W, 3).

        The input is rescaled to 0-1 float32 before being handed to OpenCV;
        the caller's array is not modified.
        """
        unit_rgb = image_arr.astype(np.float32)  # astype returns a fresh copy
        unit_rgb /= 255.0
        return cv2.cvtColor(unit_rgb, cv2.COLOR_RGB2LAB)

    def lab_to_rgb(self, image_arr: np.ndarray) -> np.ndarray:
        """
        Convert a float LAB image (H, W, 3) to a uint8 RGB image (H, W, 3).

        Args:
            image_arr: LAB image in the float convention produced by
                `rgb_to_lab`. It is coerced to float32 first, since OpenCV's
                Lab conversion is documented for 8-bit and 32-bit float input.

        Returns:
            uint8 RGB image with values in [0, 255].
        """
        lab_f32 = np.asarray(image_arr, dtype=np.float32)
        rgb_unit = cv2.cvtColor(lab_f32, cv2.COLOR_LAB2RGB)
        # Round to nearest before the integer cast. A plain cast truncates, so a
        # pixel whose round trip lands at e.g. 126.9999 would drop to 126 and
        # the image would drift darker by one level.
        return np.clip(np.rint(255 * rgb_unit), 0, 255).astype(np.uint8)

    def calculate_delta_e_2000_ab_only(self, lab1: np.ndarray, lab2: np.ndarray) -> float:
        """
        Calculate CIEDE2000 color difference between two LAB color arrays using only A and B channels.
        L channel is set to a fixed value to focus on chromaticity.
        Returns the mean Delta E 2000 value.
        """
        if not HAS_COLORSPACIOUS:
            print("colorspacious not available, skipping Delta E 2000")
            return -1.0
            
        try:
            # Ensure we have valid data
            if len(lab1) == 0 or len(lab2) == 0:
                return float('inf')
            
            # Sample if arrays are very large for performance
            max_samples = 5000  # Reduced for ComfyUI performance
            if len(lab1) > max_samples:
                indices = np.random.choice(len(lab1), max_samples, replace=False)
                lab1_sample = lab1[indices]
            else:
                lab1_sample = lab1
                
            if len(lab2) > max_samples:
                indices = np.random.choice(len(lab2), max_samples, replace=False)
                lab2_sample = lab2[indices]
            else:
                lab2_sample = lab2
            
            # Use the smaller sample size
            min_len = min(len(lab1_sample), len(lab2_sample))
            lab1_sample = lab1_sample[:min_len]
            lab2_sample = lab2_sample[:min_len]
            
            # Create modified LAB arrays with fixed L value (50) to focus on A&B channels only
            fixed_l = 50.0  # Middle lightness value
            lab1_ab_only = lab1_sample.copy()
            lab2_ab_only = lab2_sample.copy()
            lab1_ab_only[:, 0] = fixed_l
            lab2_ab_only[:, 0] = fixed_l
            
            # Calculate pairwise Delta E 2000
            delta_e_values = []
            for i in range(min_len):
                try:
                    # Convert to format expected by colorspacious
                    color1 = lab1_ab_only[i].reshape(1, 3)
                    color2 = lab2_ab_only[i].reshape(1, 3)
                    
                    # Calculate Delta E 2000
                    delta_e = cspace_convert(color1, "CIELab", "DeltaE00", color2)
                    if not np.isnan(delta_e) and not np.isinf(delta_e):
                        delta_e_values.append(delta_e)
                except:
                    continue
            
            if delta_e_values:
                return float(np.mean(delta_e_values))
            else:
                return float('inf')
                
        except Exception as e:
            print(f"Error calculating Delta E 2000 (A&B only): {e}")
            return float('inf')

    def calculate_delta_e_76_ab_only(self, lab1: np.ndarray, lab2: np.ndarray) -> float:
        """
        Calculate CIE76 Delta E color difference between two LAB color arrays using only A and B channels.
        Returns the mean Delta E 76 value for chromaticity only.
        """
        try:
            if len(lab1) == 0 or len(lab2) == 0:
                return float('inf')
            
            # Sample if arrays are very large
            max_samples = 5000  # Reduced for ComfyUI performance
            if len(lab1) > max_samples:
                indices = np.random.choice(len(lab1), max_samples, replace=False)
                lab1_sample = lab1[indices]
            else:
                lab1_sample = lab1
                
            if len(lab2) > max_samples:
                indices = np.random.choice(len(lab2), max_samples, replace=False)
                lab2_sample = lab2[indices]
            else:
                lab2_sample = lab2
            
            min_len = min(len(lab1_sample), len(lab2_sample))
            lab1_sample = lab1_sample[:min_len]
            lab2_sample = lab2_sample[:min_len]
            
            # Calculate CIE76 Delta E using only A and B channels: sqrt((a2-a1)² + (b2-b1)²)
            diff_a = lab2_sample[:, 1] - lab1_sample[:, 1]  # A channel difference
            diff_b = lab2_sample[:, 2] - lab1_sample[:, 2]  # B channel difference
            delta_e_76_ab = np.sqrt(diff_a**2 + diff_b**2)
            
            # Remove invalid values
            valid_mask = ~(np.isnan(delta_e_76_ab) | np.isinf(delta_e_76_ab))
            if np.any(valid_mask):
                return float(np.mean(delta_e_76_ab[valid_mask]))
            else:
                return float('inf')
                
        except Exception as e:
            print(f"Error calculating Delta E 76 (A&B only): {e}")
            return float('inf')

    def calculate_rmse_lab_ab_only(self, lab1: np.ndarray, lab2: np.ndarray) -> float:
        """
        Calculate Root Mean Square Error in LAB color space using only A and B channels.
        """
        try:
            if len(lab1) == 0 or len(lab2) == 0:
                return float('inf')
            
            min_len = min(len(lab1), len(lab2))
            lab1_sample = lab1[:min_len]
            lab2_sample = lab2[:min_len]
            
            # Calculate RMSE for A and B channels only
            diff_a = lab2_sample[:, 1] - lab1_sample[:, 1]
            diff_b = lab2_sample[:, 2] - lab1_sample[:, 2]
            mse_ab = np.mean(diff_a**2 + diff_b**2)
            rmse_ab = np.sqrt(mse_ab)
            
            return float(rmse_ab) if not (np.isnan(rmse_ab) or np.isinf(rmse_ab)) else float('inf')
            
        except Exception as e:
            print(f"Error calculating RMSE (A&B only): {e}")
            return float('inf')

    def calculate_mae_lab_ab_only(self, lab1: np.ndarray, lab2: np.ndarray) -> float:
        """
        Calculate Mean Absolute Error in LAB color space using only A and B channels.
        """
        try:
            if len(lab1) == 0 or len(lab2) == 0:
                return float('inf')
            
            min_len = min(len(lab1), len(lab2))
            lab1_sample = lab1[:min_len]
            lab2_sample = lab2[:min_len]
            
            # Calculate MAE for A and B channels only
            mae_a = np.mean(np.abs(lab2_sample[:, 1] - lab1_sample[:, 1]))
            mae_b = np.mean(np.abs(lab2_sample[:, 2] - lab1_sample[:, 2]))
            mae_ab = (mae_a + mae_b) / 2  # Average of A and B channel MAE
            
            return float(mae_ab) if not (np.isnan(mae_ab) or np.isinf(mae_ab)) else float('inf')
            
        except Exception as e:
            print(f"Error calculating MAE (A&B only): {e}")
            return float('inf')

    def calculate_color_metrics(self, 
                              source_lab: np.ndarray, 
                              source_mask: np.ndarray,
                              reference_lab: np.ndarray, 
                              reference_mask: Optional[np.ndarray] = None) -> Dict[str, float]:
        """
        Calculate comprehensive color difference metrics between masked regions using only A and B channels.
        """
        # Extract masked pixels from source
        source_mask_bool = source_mask > 0
        if not np.any(source_mask_bool):
            print("Warning: Empty source mask for color metrics.")
            return {
                'delta_e_2000_ab': float('inf'),
                'delta_e_76_ab': float('inf'),
                'rmse_lab_ab': float('inf'),
                'mae_lab_ab': float('inf')
            }
        
        source_pixels = source_lab[source_mask_bool]
        
        # Extract masked pixels from reference
        if reference_mask is not None:
            reference_mask_bool = reference_mask > 0
            if not np.any(reference_mask_bool):
                print("Warning: Empty reference mask. Using full reference image for metrics.")
                reference_pixels = reference_lab.reshape(-1, 3)
            else:
                reference_pixels = reference_lab[reference_mask_bool]
        else:
            reference_pixels = reference_lab.reshape(-1, 3)
        
        # Calculate metrics for A and B channels only
        metrics = {}
        
        # Calculate metrics for A and B channels only
        # Delta E 2000 (A&B channels only)
        metrics['delta_e_2000_ab'] = self.calculate_delta_e_2000_ab_only(source_pixels, reference_pixels)
        
        # Delta E 76 (A&B channels only)
        metrics['delta_e_76_ab'] = self.calculate_delta_e_76_ab_only(source_pixels, reference_pixels)
        
        # RMSE in LAB space (A&B channels only)
        metrics['rmse_lab_ab'] = self.calculate_rmse_lab_ab_only(source_pixels, reference_pixels)
        
        # Mean Absolute Error in LAB space (A&B channels only)
        metrics['mae_lab_ab'] = self.calculate_mae_lab_ab_only(source_pixels, reference_pixels)
        
        return metrics

    def apply_histogram_matching_transfer(self, src_pixels, ref_pixels):
        """
        Remap `src_pixels` so their histogram matches that of `ref_pixels`.

        Thin wrapper around `skimage.exposure.match_histograms`; the result is
        returned unchanged.
        """
        matched_pixels = exposure.match_histograms(src_pixels, ref_pixels)
        return matched_pixels

    def apply_reinhard_transfer(self, src_pixels, ref_pixels):
        """Applies Reinhard transfer to the LAB pixels."""
        src_mean, src_std = np.mean(src_pixels), np.std(src_pixels)
        ref_mean, ref_std = np.mean(ref_pixels), np.std(ref_pixels)

        if src_std == 0:
            return src_pixels - src_mean + ref_mean
    
        temp = src_pixels - src_mean
        temp = (ref_std / src_std) * temp
        return temp + ref_mean

    def apply_histogram_matching(
        self,
        src_image: np.ndarray,
        src_mask: np.ndarray,
        ref_image: np.ndarray,
        ref_mask: Optional[np.ndarray] = None,
        l_blend: float = 0.0,
        a_blend: float = 1.0,
        b_blend: float = 1.0,
        method: str = "Reinhard Transfer",
        a_var_reduction_strength: float = 1.0,
        b_var_reduction_strength: float = 1.0,
    ) -> np.ndarray:
        """Apply LAB histogram matching with channel-specific blending control."""
        LAB_MAX_L, LAB_MIN_L = 100.0, 0.0
        LAB_MAX_A, LAB_MIN_A = 127.0, -128.0
        LAB_MAX_B, LAB_MIN_B = 127.0, -128.0

        lab_A = self.rgb_to_lab(src_image.copy())
        lab_B = self.rgb_to_lab(ref_image.copy())
        modified_lab_A = lab_A.copy()

        mask_indices = src_mask > 0
        if not np.any(mask_indices):
            print("Warning: Empty source mask. Returning original image.")
            return src_image

        l_channel_A_masked = modified_lab_A[mask_indices, 0]
        a_channel_A_masked = modified_lab_A[mask_indices, 1]
        b_channel_A_masked = modified_lab_A[mask_indices, 2]

        if ref_mask is not None:
            mask_B_indices = ref_mask > 0
            if not np.any(mask_B_indices):
                print("Warning: Empty reference mask. Using full reference image.")
                l_channel_B_reference = lab_B[:, :, 0].flatten()
                a_channel_B_reference = lab_B[:, :, 1].flatten()
                b_channel_B_reference = lab_B[:, :, 2].flatten()
            else:
                l_channel_B_reference = lab_B[mask_B_indices, 0]
                a_channel_B_reference = lab_B[mask_B_indices, 1]
                b_channel_B_reference = lab_B[mask_B_indices, 2]
        else:
            l_channel_B_reference = lab_B[:, :, 0].flatten()
            a_channel_B_reference = lab_B[:, :, 1].flatten()
            b_channel_B_reference = lab_B[:, :, 2].flatten()

        transfer_func = (
            self.apply_histogram_matching_transfer
            if method == "Histogram Matching"
            else self.apply_reinhard_transfer
        )

        # --- L Channel Processing ---
        if l_blend > 0.0:
            matched_l = transfer_func(l_channel_A_masked, l_channel_B_reference)
            final_l = ((1 - l_blend) * l_channel_A_masked) + (l_blend * matched_l)
        else:
            final_l = l_channel_A_masked

        # --- A Channel Processing ---
        if a_blend > 0.0:
            matched_a = transfer_func(a_channel_A_masked, a_channel_B_reference)
            final_a = ((1 - a_blend) * a_channel_A_masked) + (a_blend * matched_a)
        else:
            final_a = a_channel_A_masked

        # --- B Channel Processing ---
        if b_blend > 0.0:
            matched_b = transfer_func(b_channel_A_masked, b_channel_B_reference)
            final_b = ((1 - b_blend) * b_channel_A_masked) + (b_blend * matched_b)
        else:
            final_b = b_channel_A_masked
        
        # --- VARIANCE REDUCTION (STD DEV CLAMPING) ---
        if a_var_reduction_strength > 0 and len(final_a) > 1:
            ref_std_a = np.std(a_channel_B_reference)
            current_mean_a = np.mean(final_a)
            current_std_a = np.std(final_a)
            
            target_std_a = ref_std_a / (a_var_reduction_strength + 1e-6)

            if current_std_a > target_std_a:
                print(f"current_std_a: {current_std_a}, target_std_a: {target_std_a}")
                normalized_a = (final_a - current_mean_a) / (current_std_a + 1e-6)
                final_a = normalized_a * target_std_a + current_mean_a

        if b_var_reduction_strength > 0 and len(final_b) > 1:
            ref_std_b = np.std(b_channel_B_reference)
            current_mean_b = np.mean(final_b)
            current_std_b = np.std(final_b)

            target_std_b = ref_std_b / (b_var_reduction_strength + 1e-6)

            if current_std_b > target_std_b:
                print(f"current_std_b: {current_std_b}, target_std_b: {target_std_b}")
                normalized_b = (final_b - current_mean_b) / (current_std_b + 1e-6)
                final_b = normalized_b * target_std_b + current_mean_b
        
        # --- Update and Finalize ---
        modified_lab_A[mask_indices, 0] = final_l
        modified_lab_A[mask_indices, 1] = final_a
        modified_lab_A[mask_indices, 2] = final_b

        modified_lab_A[:, :, 0] = np.clip(modified_lab_A[:, :, 0], LAB_MIN_L, LAB_MAX_L)
        modified_lab_A[:, :, 1] = np.clip(modified_lab_A[:, :, 1], LAB_MIN_A, LAB_MAX_A)
        modified_lab_A[:, :, 2] = np.clip(modified_lab_A[:, :, 2], LAB_MIN_B, LAB_MAX_B)

        result_image_rgb = self.lab_to_rgb(modified_lab_A)
        return result_image_rgb

    def execute(
        self,
        source_image: torch.Tensor,
        source_mask: torch.Tensor,
        reference_image: torch.Tensor,
        method: str,
        l_blend: float,
        a_blend: float,
        b_blend: float,
        a_var_reduction_strength: float,
        b_var_reduction_strength: float,
        calculate_color_difference: bool = True,
        reference_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, str]:
        """
        Colour-correct every image of the batch and report before/after metrics.

        Args:
            source_image: Batch of images, indexed along dimension 0.
            source_mask: Batch of masks, indexed along dimension 0. When it
                holds fewer masks than there are images, the last mask is
                reused for the remaining images.
            reference_image: Reference batch; only its first image is used.
            method: Transfer method name passed to `apply_histogram_matching`.
            l_blend, a_blend, b_blend: Per-channel blend weights.
            a_var_reduction_strength, b_var_reduction_strength: Per-channel
                variance-reduction strengths.
            calculate_color_difference: When True, metrics are computed before
                and after the correction and averaged over the batch.
            reference_mask: Optional reference mask batch; only the first mask
                is used.

        Returns:
            Tuple of the stacked result tensor and the metrics report (or a
            short notice when metrics are disabled).

        Raises:
            ValueError: If the source batch contains no images.
        """
        metric_keys = ('delta_e_2000_ab', 'delta_e_76_ab', 'rmse_lab_ab', 'mae_lab_ab')
        totals_before = dict.fromkeys(metric_keys, 0.0)
        totals_after = dict.fromkeys(metric_keys, 0.0)

        # The reference is shared by the whole batch, so prepare it once.
        ref_image_np = self.tensor_to_numpy(reference_image[0])
        ref_mask_np = (
            self.mask_to_numpy(reference_mask[0])
            if reference_mask is not None
            else None
        )
        ref_lab = self.rgb_to_lab(ref_image_np.copy()) if calculate_color_difference else None

        num_images = source_image.shape[0]
        last_mask_index = source_mask.shape[0] - 1

        result = []
        for idx in range(num_images):
            src_image = self.tensor_to_numpy(source_image[idx])
            # A single mask is commonly supplied for a whole batch; clamp the
            # index so it is reused instead of raising IndexError.
            src_mask = self.mask_to_numpy(source_mask[min(idx, last_mask_index)])

            if calculate_color_difference:
                src_lab = self.rgb_to_lab(src_image.copy())
                metrics_before = self.calculate_color_metrics(
                    src_lab, src_mask, ref_lab, ref_mask_np
                )
                for key in metric_keys:
                    totals_before[key] += metrics_before[key]

            result_np = self.apply_histogram_matching(
                src_image,
                src_mask,
                ref_image_np,
                ref_mask_np,
                l_blend,
                a_blend,
                b_blend,
                method,
                a_var_reduction_strength,
                b_var_reduction_strength,
            )

            if calculate_color_difference:
                result_lab = self.rgb_to_lab(result_np.copy())
                metrics_after = self.calculate_color_metrics(
                    result_lab, src_mask, ref_lab, ref_mask_np
                )
                for key in metric_keys:
                    totals_after[key] += metrics_after[key]

            result.append(self.numpy_to_tensor(result_np))

        if not result:
            raise ValueError("No valid images were processed.")

        def calculate_improvement(before, after):
            # Relative reduction in percent; 0 when it cannot be computed.
            if before == 0 or before == float('inf') or after == float('inf'):
                return 0.0
            return ((before - after) / before) * 100.0

        if calculate_color_difference:
            before = {key: totals_before[key] / num_images for key in metric_keys}
            after = {key: totals_after[key] / num_images for key in metric_keys}
            gain = {key: calculate_improvement(before[key], after[key]) for key in metric_keys}

            metrics_text = f"""Color Accuracy Metrics (A&B Channels Only):

BEFORE {method}:
Delta E 2000 (A&B): {before['delta_e_2000_ab']:.2f}
Delta E 76 (A&B): {before['delta_e_76_ab']:.2f}
RMSE LAB (A&B): {before['rmse_lab_ab']:.2f}
MAE LAB (A&B): {before['mae_lab_ab']:.2f}

AFTER {method}:
Delta E 2000 (A&B): {after['delta_e_2000_ab']:.2f}
Delta E 76 (A&B): {after['delta_e_76_ab']:.2f}
RMSE LAB (A&B): {after['rmse_lab_ab']:.2f}
MAE LAB (A&B): {after['mae_lab_ab']:.2f}

IMPROVEMENT:
Delta E 2000: {gain['delta_e_2000_ab']:+.1f}%
Delta E 76: {gain['delta_e_76_ab']:+.1f}%
RMSE: {gain['rmse_lab_ab']:+.1f}%
MAE: {gain['mae_lab_ab']:+.1f}%

Interpretation (Delta E A&B):
< 1.0: Imperceptible chromaticity difference
1.0-2.0: Very small chromaticity difference
2.0-4.0: Small chromaticity difference  
4.0-6.0: Acceptable chromaticity difference
> 6.0: Large chromaticity difference

Note: Lightness (L) channel excluded from evaluation
Positive improvement % = better color accuracy"""
        else:
            metrics_text = "Color difference calculation disabled"

        result_tensor = torch.stack(result, dim=0)
        return (result_tensor, metrics_text)


# Example usage and testing functions
def create_test_data():
    """
    Build random 100x100 sample data for exercising the colour metrics.

    Returns:
        Tuple of (source image, source mask, reference image, reference mask).
        Images are random uint8 RGB arrays; masks are uint8 arrays that are
        255 inside a centred square and 0 elsewhere.
    """
    side = 100

    def centred_square(start, stop):
        square = np.zeros((side, side), dtype=np.uint8)
        square[start:stop, start:stop] = 255
        return square

    # Source: full-range noise with a 50x50 masked square in the middle.
    test_image = np.random.randint(0, 255, (side, side, 3), dtype=np.uint8)
    test_mask = centred_square(25, 75)

    # Reference: mid-range noise with a slightly larger 60x60 square.
    ref_image = np.random.randint(100, 200, (side, side, 3), dtype=np.uint8)
    ref_mask = centred_square(20, 80)

    return test_image, test_mask, ref_image, ref_mask


if __name__ == "__main__":
    # Smoke test: run the node once on random data and print the metrics report.
    matcher = LABHistogramMatcher()
    sample_arrays = create_test_data()

    # Mimic the tensors the node receives: float32 in 0-1 with a leading batch axis.
    src_tensor, src_mask_tensor, ref_tensor, ref_mask_tensor = (
        torch.from_numpy(array.astype(np.float32) / 255.0).unsqueeze(0)
        for array in sample_arrays
    )

    # Run with metrics enabled; only the report is printed.
    result_tensor, metrics_string = matcher.execute(
        src_tensor,
        src_mask_tensor,
        ref_tensor,
        "Reinhard Transfer",
        l_blend=0.0,
        a_blend=1.0,
        b_blend=1.0,
        a_var_reduction_strength=2.0,
        b_var_reduction_strength=1.5,
        calculate_color_difference=True,
        reference_mask=ref_mask_tensor,
    )

    print("\nColor Metrics Output:")
    print(metrics_string)
