import os
import glob
import cv2
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision.models.segmentation import fcn_resnet50, FCN_ResNet50_Weights
import folder_paths
import comfy.model_management


class LoadDeepLabV3:
    """ComfyUI node that loads a fine-tuned single-class FCN-ResNet50 checkpoint.

    Checkpoints are looked up in ``<models_dir>/deeplabv3``.  Despite the node
    name, the architecture built here is torchvision's ``fcn_resnet50`` with
    both classifier heads replaced by 1-output convolutions.
    """

    @classmethod
    def INPUT_TYPES(cls):
        model_dir = os.path.join(folder_paths.models_dir, "deeplabv3")
        # A missing folder gives an empty dropdown instead of raising
        # FileNotFoundError; sorting makes the list order deterministic.
        if os.path.isdir(model_dir):
            list_models = sorted(
                name for name in os.listdir(model_dir)
                if os.path.isfile(os.path.join(model_dir, name))
            )
        else:
            list_models = []
        return {
            "required": {
                "model_path": (list_models, ),
            }
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "load_model"
    CATEGORY = "ReImage AI"

    def load_model(self, model_path):
        """Build the network, load ``model_path`` into it and move it to the device.

        Args:
            model_path: file name of a state dict inside ``<models_dir>/deeplabv3``.

        Returns:
            A one-element tuple with the model in eval mode on the ComfyUI
            torch device.

        Raises:
            RuntimeError: from ``load_state_dict`` (strict by default) when the
                checkpoint's keys or shapes do not match this architecture.
        """
        # Same architecture FCN_ResNet50_Weights.DEFAULT would produce
        # (21 classes, auxiliary head enabled), but without downloading
        # pretrained weights: the strict load_state_dict below overwrites
        # every parameter and buffer anyway.
        model = fcn_resnet50(
            weights=None, weights_backbone=None, num_classes=21, aux_loss=True
        )
        model.n_classes = 1
        model.n_channels = 3
        # Replace the last layer of both heads with a single-output 1x1 conv.
        model.classifier[4] = nn.Conv2d(512, 1, kernel_size=1)
        model.aux_classifier[4] = nn.Conv2d(256, 1, kernel_size=1)
        checkpoint = os.path.join(folder_paths.models_dir, "deeplabv3", model_path)
        # map_location="cpu" lets checkpoints saved on a GPU load on CPU-only
        # machines; the model is moved to the target device right after.
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        model.load_state_dict(torch.load(checkpoint, map_location="cpu"))
        device = comfy.model_management.get_torch_device()
        model.to(device)
        model.eval()
        return (model,)

def preprocess_image(image, image_size=1024):
    """Turn an image array into a normalized single-image model input batch.

    Args:
        image: numpy image indexed as (H, W, C) with C in {1, 3, 4}.  A 4th
            channel is dropped and a single channel is replicated to three.
            Presumably uint8 (as produced by ``ApplyDeepLabV3.tensor_to_numpy``);
            ToTensor only rescales uint8 input to [0, 1].
        image_size: side length of the square the image is resized to.
            Coerced with ``int()`` so a numeric string such as "640" also works.

    Returns:
        A float tensor of shape (1, 3, image_size, image_size), normalized
        with the ImageNet mean/std.
    """
    if image.shape[-1] == 4:
        image = image[..., :3]
    elif image.shape[-1] == 1:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
    # Reverse the channel order.  NOTE(review): if the input is RGB (the usual
    # ComfyUI IMAGE convention -- confirm), this produces BGR even though the
    # Normalize statistics below are in RGB order.  Kept as-is because the
    # checkpoint may have been trained on identically flipped input; verify
    # against the training pipeline before changing it.
    image = np.ascontiguousarray(image[..., ::-1])
    size = int(image_size)
    image = cv2.resize(image, (size, size))
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    return transform(image).unsqueeze(0)


class ApplyDeepLabV3:
    """ComfyUI node that runs a single-class segmentation model on an image batch.

    Every image is resized to a ``size_process`` x ``size_process`` square,
    segmented, and thresholded at probability 0.5.  The binary mask is then
    resized back to the image's own resolution and returned as floats in [0, 1].
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model": ("MODEL",),
                "images": ("IMAGE",),
                # Must be an int (not the string "640"): it is forwarded to
                # cv2.resize as a pixel size.
                "size_process": ("INT", {"default": 640, "min": 1, "max": 8192}),
            }
        }

    RETURN_TYPES = ("MASK",)
    FUNCTION = "apply_model"
    CATEGORY = "ReImage AI"

    def tensor_to_numpy(self, tensor: torch.Tensor) -> np.ndarray:
        """Convert a float image tensor in [0, 1] to a uint8 numpy array.

        The layout is unchanged (an (H, W, C) tensor gives an (H, W, C) array);
        values are scaled by 255, clipped to [0, 255] and truncated.
        """
        numpy_array = tensor.cpu().numpy()
        numpy_array = (numpy_array * 255).clip(0, 255).astype(np.uint8)
        return numpy_array

    def numpy_to_tensor(self, numpy_array: np.ndarray) -> torch.Tensor:
        """Convert a 0-255 numpy array to a float32 tensor in [0, 1], same layout."""
        tensor = torch.from_numpy(numpy_array.astype(np.float32) / 255.0)
        return tensor

    def mask_to_numpy(self, mask_tensor: torch.Tensor) -> np.ndarray:
        """Convert a float mask tensor in [0, 1] to a uint8 numpy array, same layout."""
        mask_np = mask_tensor.cpu().numpy()
        mask_np = (mask_np * 255).clip(0, 255).astype(np.uint8)
        return mask_np

    @staticmethod
    def _binarize(logits: torch.Tensor) -> np.ndarray:
        """Threshold logits at sigmoid > 0.5 and return a float32 0/255 array.

        ``logits`` is expected to be (1, 1, S, S); both leading singleton
        dimensions are squeezed away, giving an (S, S) array.
        """
        probs = torch.sigmoid(logits).squeeze(0).squeeze(0)
        # .float() gives float32 whatever the model dtype, which Pillow's
        # Image.fromarray accepts (it rejects e.g. float16).
        return ((probs > 0.5).float() * 255.0).cpu().numpy()

    @staticmethod
    def _resize_mask(mask: np.ndarray, width: int, height: int) -> np.ndarray:
        """Lanczos-resize a float 0/255 mask to (height, width) and return uint8."""
        resized = Image.fromarray(mask).resize((width, height), Image.LANCZOS)
        # Lanczos over/undershoots around sharp edges and Pillow does not clamp
        # float ("F") images, so clip before the uint8 cast; otherwise values
        # such as 270 or -15 wrap around and speckle the mask border.
        return np.clip(np.array(resized), 0, 255).astype(np.uint8)

    def apply_model(self, model, images, size_process):
        """Segment every image in ``images`` and return the batch of masks.

        Args:
            model: model from ``LoadDeepLabV3``; called as ``model(x)['out']``
                and expected to produce one logit channel.
            images: ComfyUI IMAGE batch whose elements are indexed (H, W, C).
            size_process: side length of the square inference resolution.

        Returns:
            A one-element tuple with a float tensor of shape (B, H, W) whose
            values lie in [0, 1].
        """
        device = comfy.model_management.get_torch_device()
        results = []
        for image in images:
            height, width = image.shape[0], image.shape[1]
            batch = preprocess_image(self.tensor_to_numpy(image), size_process).to(device)
            with torch.no_grad():
                logits = model(batch)['out']
            mask = self._binarize(logits)
            results.append(self.numpy_to_tensor(self._resize_mask(mask, width, height)))
        if not results:
            # torch.stack rejects an empty list; return an empty (0, H, W) batch.
            return (torch.zeros(tuple(images.shape[:3])),)
        return (torch.stack(results, dim=0),)
