"""
models.py contains many supporting models, such as ControlNet and ControlNet-related models, CLIP, etc.
It contains the functions to load and evaluate them, along with related helper functions.
"""

import os
import torch

from controlnet_aux import (
    MLSDdetector,
    PidiNetDetector,
)
from transformers import (
    AutoImageProcessor,
    AutoModelForDepthEstimation,
    CLIPProcessor,
    CLIPModel,
    SamModel,
    SamProcessor,
    AutoProcessor,
    GroundingDinoForObjectDetection,
)
from diffusers import (
    ControlNetModel,
    T2IAdapter,
    MultiAdapter,
)
from diffusers.models.controlnets.multicontrolnet import MultiControlNetModel
from sam3.model_builder import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor
from simple_lama_inpainting import SimpleLama
from huggingface_hub import snapshot_download, hf_hub_download

# Hugging Face access token, read once at import time. Importing this module
# raises KeyError when HF_TOKEN is missing from the environment.
access_token = os.environ["HF_TOKEN"]
# Module-level singletons. All start as None and are filled in by the
# corresponding load_*() functions defined below.
depth_processor = None
depth_model = None
hed_model = None
mlsd_model = None
clip_model = None
clip_processor = None
sam_processor = None
sam_model = None
sam3_processor = None
# The two Grounding DINO globals become dicts keyed by "base" / "tiny" once
# load_grounding_dino_models() has run.
grounding_dino_model = None
grounding_dino_processor = None
lama_model = None

# Cache of ControlNet models keyed by control type; entries are filled on
# first use by controlnet_get().
controlnet_models = {
    "canny-sdxl": None,
    "depth-sdxl": None,
    "canny": None,
    "depth": None,
    "mlsd": None,
    "tile": None,
}

# Cache of T2I adapter models keyed by adapter type; filled on first use by
# adapter_get(). Only "depth" is pre-seeded here: adapter_get() also accepts
# "canny" and inserts that key when it is first loaded.
adapter_models = {
    "depth": None,
}


def load_models():
    """Run every eager model loader in this module, in a fixed order.

    The order is: depth, HED, MLSD, CLIP, SAM1, SAM3, Grounding DINO, LaMa.
    ControlNet and adapter models are not touched here; they are loaded on
    demand by ``controlnet_get`` / ``adapter_get``.
    """
    print("Loading models...")
    loaders = (
        load_depth_model,
        load_hed_model,
        load_mlsd_model,
        load_clip_models,
        load_sam1_model,
        load_sam3_model,
        load_grounding_dino_models,
        load_lama_model,
    )
    for run_loader in loaders:
        run_loader()
    print("Loading models... done\n\n")


def download_models():
    """Pre-fetch every model and data file that this module may later load.

    Each ``from_pretrained`` / ``hf_hub_download`` / constructor call below is
    made only for its download side effect; the returned objects are
    discarded. The ``load_*`` functions and ``controlnet_get`` /
    ``adapter_get`` later read most of these with ``local_files_only=True``,
    so the repositories listed here must stay in sync with the ones they load.

    Authenticated downloads use the module-level ``access_token`` (taken from
    the ``HF_TOKEN`` environment variable). No credential is embedded in the
    source; the token must therefore be authorised for every repository it is
    passed to below (the ``reimager/*`` ControlNets and ``facebook/sam3``).
    """
    print("Downloading models...", flush=True)

    # nltk is imported locally; nothing else in this module uses it.
    import nltk
    nltk_dir = os.environ.get("NLTK_DATA", "/mldata/nltk_data")
    for resource in ("averaged_perceptron_tagger_eng", "punkt_tab"):
        nltk.download(resource, download_dir=nltk_dir)

    # SDXL ControlNets published with an fp16 safetensors variant.
    for repo_id in (
        "diffusers/controlnet-canny-sdxl-1.0-mid",
        "diffusers/controlnet-depth-sdxl-1.0-mid",
    ):
        ControlNetModel.from_pretrained(
            repo_id,
            variant="fp16",
            use_safetensors=True,
            torch_dtype=torch.bfloat16,
        )

    # Same settings, but these repositories need the access token.
    for repo_id in (
        "reimager/controlnet-canny-sdxl-white-cabinets-1.0",
        "reimager/controlnet-depth-sdxl-white-cabinets-1.0",
    ):
        ControlNetModel.from_pretrained(
            repo_id,
            variant="fp16",
            use_safetensors=True,
            torch_dtype=torch.bfloat16,
            token=access_token,
        )

    # SD 1.5 ControlNets.
    for repo_id in (
        "lllyasviel/control_v11p_sd15_canny",
        "lllyasviel/control_v11f1p_sd15_depth",
        "lllyasviel/control_v11p_sd15_mlsd",
        "lllyasviel/control_v11f1e_sd15_tile",
    ):
        ControlNetModel.from_pretrained(repo_id, torch_dtype=torch.bfloat16)

    # SDXL T2I adapters. No ``variant`` is requested: adapter_get() loads the
    # repositories' default weight files, so those are the ones that must be
    # cached here.
    # TODO(review): if the fp16 variant is wanted, pass variant="fp16" both
    # here and in adapter_get() at the same time.
    for repo_id in (
        "TencentARC/t2i-adapter-depth-midas-sdxl-1.0",
        "TencentARC/t2i-adapter-canny-sdxl-1.0",
    ):
        T2IAdapter.from_pretrained(repo_id, torch_dtype=torch.bfloat16)

    # Conditioning detectors and CLIP.
    AutoImageProcessor.from_pretrained("depth-anything/Depth-Anything-V2-Base-hf")
    AutoModelForDepthEstimation.from_pretrained("depth-anything/Depth-Anything-V2-Base-hf")
    PidiNetDetector.from_pretrained("lllyasviel/Annotators")
    MLSDdetector.from_pretrained("lllyasviel/ControlNet")
    CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
    CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")

    # Segmentation and detection models.
    SamProcessor.from_pretrained("Zigeng/SlimSAM-uniform-77")
    SamModel.from_pretrained("Zigeng/SlimSAM-uniform-77")
    for repo_id in (
        "IDEA-Research/grounding-dino-base",
        "IDEA-Research/grounding-dino-tiny",
    ):
        AutoProcessor.from_pretrained(repo_id)
        GroundingDinoForObjectDetection.from_pretrained(repo_id)

    # Use the environment-provided token rather than a literal secret.
    hf_hub_download(
        repo_id="facebook/sam3",
        filename="sam3.pt",
        local_files_only=False,
        token=access_token,
    )

    # NOTE(review): constructing SimpleLama presumably fetches its weights on
    # first use - confirm against the simple_lama_inpainting package.
    SimpleLama()

    print("Downloading models... done\n\n", flush=True)


def load_depth_model():
    """Load the Depth Anything V2 (base) processor and model, once.

    Does nothing when ``depth_model`` is already set. Otherwise both globals
    are populated from locally cached files (``local_files_only=True``) and
    the model is moved to ``"cuda"``.
    """
    global depth_processor, depth_model
    if depth_model is None:
        print("Loading depth conditioner model...")
        repo_id = "depth-anything/Depth-Anything-V2-Base-hf"
        depth_processor = AutoImageProcessor.from_pretrained(repo_id, local_files_only=True)
        estimator = AutoModelForDepthEstimation.from_pretrained(repo_id, local_files_only=True)
        depth_model = estimator.to("cuda")
        print("Loading depth conditioner model... done")


def load_hed_model():
    """Load the edge detector stored in ``hed_model``, once.

    Despite the "hed" name, the detector instantiated is ``PidiNetDetector``
    from the ``lllyasviel/Annotators`` repository. It is moved to ``"cuda"``
    after loading. Does nothing when ``hed_model`` is already set.

    NOTE(review): unlike most loaders in this module, no
    ``local_files_only=True`` is passed here - confirm that is intended.
    """
    global hed_model
    if hed_model is None:
        print("Loading hed conditioner model...")
        hed_model = PidiNetDetector.from_pretrained("lllyasviel/Annotators")
        hed_model.to("cuda")
        print("Loading hed conditioner model... done")


def load_mlsd_model():
    """Load the MLSD detector into ``mlsd_model``, once.

    The detector comes from the ``lllyasviel/ControlNet`` repository and is
    moved to ``"cuda"`` after loading. Does nothing when ``mlsd_model`` is
    already set.
    """
    global mlsd_model
    if mlsd_model is None:
        print("Loading mlsd conditioner model...")
        mlsd_model = MLSDdetector.from_pretrained("lllyasviel/ControlNet")
        mlsd_model.to("cuda")
        print("Loading mlsd conditioner model... done")


def load_clip_models():
    """Load the CLIP model and its processor, once.

    Both globals are (re)loaded from locally cached files whenever either of
    them is still ``None``; when both are set the call is a no-op.
    """
    global clip_model, clip_processor
    if clip_model is None or clip_processor is None:
        print("Loading CLIP model...")
        repo_id = "openai/clip-vit-large-patch14"
        clip_model = CLIPModel.from_pretrained(repo_id, local_files_only=True)
        clip_processor = CLIPProcessor.from_pretrained(repo_id, local_files_only=True)
        # Only the model is moved to the GPU; the processor is CPU only.
        clip_model.to("cuda")
        print("Loading CLIP model... done")


def load_sam1_model():
    """Load the SlimSAM processor and model from the local cache, once.

    Populates the ``sam_processor`` and ``sam_model`` globals. Like the other
    loaders in this module the call is memoized: when both globals are already
    set nothing is reloaded.

    NOTE(review): the model is not moved to "cuda" here, unlike most sibling
    loaders - confirm this is intentional.
    """
    global sam_processor
    global sam_model
    # Skip the reload when a previous call already populated both globals.
    if sam_model is not None and sam_processor is not None:
        return
    print("Loading SAM1 model...")
    sam_processor = SamProcessor.from_pretrained("Zigeng/SlimSAM-uniform-77", local_files_only=True)
    sam_model = SamModel.from_pretrained("Zigeng/SlimSAM-uniform-77", local_files_only=True)
    print("Loading SAM1 model... done")


def load_sam3_model():
    """Build the SAM3 image model and store its processor, once.

    The BPE vocabulary file is searched for in a fixed list of locations and
    the first existing path is used. The checkpoint is resolved from the local
    Hugging Face cache (``local_files_only=True``), the model is built on
    ``"cuda"`` and wrapped in a ``Sam3Processor`` stored in the
    ``sam3_processor`` global.

    The call is memoized: when ``sam3_processor`` is already set nothing is
    rebuilt.
    """
    global sam3_processor
    # Skip the rebuild when a previous call already created the processor.
    if sam3_processor is not None:
        return
    print("Loading SAM3 model...")
    candidate_bpe_paths = (
        "/usr/local/src/sam3/assets/bpe_simple_vocab_16e6.txt.gz",
        f"{os.path.dirname(__file__)}/assets/bpe_simple_vocab_16e6.txt.gz",
        "./assets/bpe_simple_vocab_16e6.txt.gz",
    )
    # First existing candidate wins; fall back to "" when none exists.
    sam3_bpe_path = next(
        (path for path in candidate_bpe_paths if os.path.exists(path)),
        "",
    )
    if not sam3_bpe_path:
        # Surface the miss instead of handing an empty path on silently.
        print(f"Warning: SAM3 BPE vocab not found in any of {list(candidate_bpe_paths)}")
    model_path = hf_hub_download(repo_id="facebook/sam3", filename="sam3.pt", local_files_only=True)
    sam3_model = build_sam3_image_model(
        bpe_path=sam3_bpe_path,
        device="cuda",
        eval_mode=True,
        checkpoint_path=model_path,
        load_from_HF=False,
        enable_segmentation=True,
        enable_inst_interactivity=False,
    )
    sam3_processor = Sam3Processor(sam3_model, device="cuda")
    print("Loading SAM3 model... done")


def load_grounding_dino_models():
    """Load the "base" and "tiny" Grounding DINO processors and models, once.

    After the call ``grounding_dino_processor`` and ``grounding_dino_model``
    are dicts keyed by ``"base"`` and ``"tiny"``; the models are moved to
    ``"cuda"``. Everything is read from the local cache
    (``local_files_only=True``).

    The call is memoized: when both globals are already populated nothing is
    reloaded. The dicts are built locally and published only after every
    entry has loaded, so the globals never hold a partially filled mapping.
    """
    global grounding_dino_processor
    global grounding_dino_model
    # Truthiness covers both the initial None and an empty dict.
    if grounding_dino_model and grounding_dino_processor:
        return
    print("Loading Grounding DINO model...")
    processors = {}
    detectors = {}
    for size in ("base", "tiny"):
        repo_id = f"IDEA-Research/grounding-dino-{size}"
        processors[size] = AutoProcessor.from_pretrained(repo_id, local_files_only=True)
        detectors[size] = GroundingDinoForObjectDetection.from_pretrained(
            repo_id, local_files_only=True
        ).to("cuda")
    grounding_dino_processor = processors
    grounding_dino_model = detectors

    print("Loading Grounding DINO model... done")


def load_lama_model():
    """Return the shared ``SimpleLama`` instance, creating it on first use.

    The instance is kept in the ``lama_model`` global, so later calls return
    the same object without constructing it again.
    """
    global lama_model
    if lama_model is not None:
        return lama_model
    print("Loading LAMA model...")
    lama_model = SimpleLama()
    print("Loading LAMA model... done")
    return lama_model


# Supported control types mapped to (Hugging Face repo id, extra
# ``from_pretrained`` keyword arguments). This is the single source of truth
# for both validation and loading.
_CONTROLNET_SOURCES = {
    "canny-sdxl": (
        "diffusers/controlnet-canny-sdxl-1.0-mid",
        {"variant": "fp16", "use_safetensors": True},
    ),
    "depth-sdxl": (
        "diffusers/controlnet-depth-sdxl-1.0-mid",
        {"variant": "fp16", "use_safetensors": True},
    ),
    "canny": ("lllyasviel/control_v11p_sd15_canny", {}),
    "depth": ("lllyasviel/control_v11f1p_sd15_depth", {}),
    "mlsd": ("lllyasviel/control_v11p_sd15_mlsd", {}),
    "tile": ("lllyasviel/control_v11f1e_sd15_tile", {}),
}


def controlnet_get(control_type):
    """Return the ControlNet model for ``control_type``, loading it on first use.

    Args:
        control_type: ``None``, one of the keys of ``_CONTROLNET_SOURCES``, or
            a list of such values.

    Returns:
        ``None`` when ``control_type`` is ``None``. For a single type, the
        memoized ``ControlNetModel`` (loaded from the local cache in bfloat16
        and moved to "cuda"). For a one-element list, the result for that
        element. For a longer list, a new ``MultiControlNetModel`` wrapping
        the individual models, moved to "cuda"; the wrapper itself is not
        memoized.

    Raises:
        ValueError: for an empty list or an unsupported control type.
            ``ValueError`` is an ``Exception`` subclass, so callers that
            catch ``Exception`` are unaffected.
    """
    if control_type is None:
        return None
    if isinstance(control_type, list):
        if not control_type:
            raise ValueError(f"unknown control type: {control_type}")
        if len(control_type) == 1:
            return controlnet_get(control_type[0])
        combined = MultiControlNetModel([controlnet_get(entry) for entry in control_type])
        combined.to("cuda")
        return combined
    # The isinstance check keeps unhashable inputs on the ValueError path
    # instead of failing the dict lookup with TypeError.
    if not isinstance(control_type, str) or control_type not in _CONTROLNET_SOURCES:
        raise ValueError(f"unknown control type: {control_type}")

    cached = controlnet_models.get(control_type)
    if cached is None:
        print(f"Loading {control_type} controlnet model...")
        repo_id, extra_kwargs = _CONTROLNET_SOURCES[control_type]
        model = ControlNetModel.from_pretrained(
            repo_id,
            torch_dtype=torch.bfloat16,
            local_files_only=True,
            **extra_kwargs,
        )
        # Memoize only after the GPU move succeeds, so a failed move cannot
        # leave a CPU-resident model in the cache for later calls.
        model.to("cuda")
        controlnet_models[control_type] = model
        cached = model
    return cached


# Supported adapter types mapped to their Hugging Face repository ids.
_ADAPTER_REPOS = {
    "depth": "TencentARC/t2i-adapter-depth-midas-sdxl-1.0",
    "canny": "TencentARC/t2i-adapter-canny-sdxl-1.0",
}


def adapter_get(adapter_type):
    """Return the T2I adapter model for ``adapter_type``, loading it on first use.

    Args:
        adapter_type: ``None``, ``"depth"``, ``"canny"``, or a list of such
            values.

    Returns:
        ``None`` when ``adapter_type`` is ``None``. For a single type, the
        memoized ``T2IAdapter`` (loaded from the local cache in bfloat16 and
        moved to "cuda"). For a one-element list, the result for that
        element. For a longer list, a new ``MultiAdapter`` wrapping the
        individual adapters, moved to "cuda"; the wrapper itself is not
        memoized.

    Raises:
        ValueError: for an empty list or an unsupported adapter type.
            ``ValueError`` is an ``Exception`` subclass, so callers that
            catch ``Exception`` are unaffected.
    """
    if adapter_type is None:
        return None
    if isinstance(adapter_type, list):
        if not adapter_type:
            raise ValueError(f"unknown adapter type: {adapter_type}")
        if len(adapter_type) == 1:
            return adapter_get(adapter_type[0])
        combined = MultiAdapter([adapter_get(entry) for entry in adapter_type])
        combined.to("cuda")
        return combined
    # The isinstance check keeps unhashable inputs on the ValueError path
    # instead of failing the dict lookup with TypeError.
    if not isinstance(adapter_type, str) or adapter_type not in _ADAPTER_REPOS:
        raise ValueError(f"unknown adapter type: {adapter_type}")

    cached = adapter_models.get(adapter_type)
    if cached is None:
        print(f"Loading {adapter_type} adapter model...")
        # No ``variant`` is requested, so the repository's default weight
        # files are loaded - the same files download_models() caches.
        # TODO(review): if the fp16 variant is wanted, pass variant="fp16"
        # both here and in download_models() at the same time.
        model = T2IAdapter.from_pretrained(
            _ADAPTER_REPOS[adapter_type],
            torch_dtype=torch.bfloat16,
            local_files_only=True,
        )
        # Memoize only after the GPU move succeeds, so a failed move cannot
        # leave a CPU-resident adapter in the cache for later calls.
        model.to("cuda")
        adapter_models[adapter_type] = model
        cached = model
    return cached
