"""
This module holds the loaded diffusers pipelines (plus the HAT super-resolution model) and the functions that load them.
"""
import os
import sys
from diffusers import (
    StableDiffusionXLAdapterPipeline,
    StableDiffusionXLControlNetPipeline,
    StableDiffusionXLControlNetInpaintPipeline,
    StableDiffusionXLControlNetPAGPipeline,
    StableDiffusionXLInpaintPipeline,
    DiffusionPipeline,
)
from diffusers.schedulers import (
    DDIMScheduler,
    PNDMScheduler,
    EulerDiscreteScheduler,
    EulerAncestralDiscreteScheduler,
    DPMSolverMultistepScheduler,
    LMSDiscreteScheduler,
    UniPCMultistepScheduler,
)
from transformers import CLIPVisionModelWithProjection
import torch
import yaml

from huggingface_hub import snapshot_download
from functools import partial

# Hugging Face access token, read once when this module is imported. Indexing
# os.environ raises KeyError when HF_TOKEN is unset, so the variable is required.
access_token = os.environ["HF_TOKEN"]

# Module-level state filled in by the loader functions below; both start out empty.
#   pipe:    the currently loaded pipeline (or HAT model), if any.
#   refiner: the currently loaded SDXL refiner, if any.
pipe = refiner = None

# Scheduler instances built so far, keyed by scheduler name (if any).
schedulers = dict()


def load_pipeline(
    model,
    clazz,
    load_gpu=True,
    scheduler_clazz=DPMSolverMultistepScheduler,
    scheduler_karras=True,
    variant=None,
    controlnet=False,
    t2i_adapter=False,
    image_encoder_name=None,
    image_encoder_clazz=None,
    image_encoder_subfolder=None,
    ip_adapter_name=None,
    ip_adapter_subfolder=None,
    ip_adapter_weight_name=None,
    vae=None,
    vae_clazz=None,
    enable_model_cpu_offload=False,
    local_files_only=True,
):
    """Load a diffusers pipeline with the options shared by every loader in this module.

    All weights are loaded as ``torch.bfloat16`` and authenticated with ``access_token``.

    Args:
        model: Hugging Face repo id (or local path) of the base pipeline.
        clazz: Pipeline class whose ``from_pretrained`` builds the pipeline.
        load_gpu: When True, move the pipeline to CUDA (or enable model CPU offload,
            see ``enable_model_cpu_offload``). When False, nothing is moved.
        scheduler_clazz: Scheduler class loaded from the ``scheduler`` subfolder of
            ``model``; it replaces the scheduler created by ``from_pretrained``.
        scheduler_karras: Passed to the scheduler as ``use_karras_sigmas``.
        variant: Optional weight variant (e.g. ``"fp16"``).
        controlnet: When True, build the pipeline with ``controlnet=None``; the real
            ControlNet is loaded at inference time.
        t2i_adapter: When True, build the pipeline with ``adapter=None``; the real
            T2I adapter is loaded at inference time.
        image_encoder_name, image_encoder_clazz, image_encoder_subfolder: When all
            three are given, load an image encoder and hand it to the pipeline.
        ip_adapter_name, ip_adapter_subfolder, ip_adapter_weight_name: When all
            three are given, load IP-Adapter weights into the pipeline.
        vae, vae_clazz: When both are given, load ``vae`` with ``vae_clazz`` and use
            it in place of the pipeline's own VAE.
        enable_model_cpu_offload: Only used with ``load_gpu``; enable model CPU
            offload instead of moving every module to CUDA.
        local_files_only: Forwarded to every load in this function; False allows
            downloading from the Hub.

    Returns:
        The loaded pipeline.
    """

    def _all_given(*values):
        # `is not None` instead of `None not in [...]`, which relies on `==`.
        return all(value is not None for value in values)

    model_args = {
        "torch_dtype": torch.bfloat16,
        "token": access_token,
    }
    if t2i_adapter is True:
        # The pipeline constructor requires an adapter; the real T2I adapter is
        # loaded at inference time.
        model_args["adapter"] = None
    if controlnet is True:
        # The pipeline constructor requires a controlnet; the real ControlNet is
        # loaded at inference time.
        model_args["controlnet"] = None
    if variant is not None:
        model_args["variant"] = variant
    if _all_given(vae, vae_clazz):
        # Honour local_files_only here as well, like every other load below.
        model_args["vae"] = vae_clazz.from_pretrained(
            vae,
            torch_dtype=torch.bfloat16,
            token=access_token,
            local_files_only=local_files_only,
        )
    if _all_given(image_encoder_clazz, image_encoder_name, image_encoder_subfolder):
        model_args["image_encoder"] = image_encoder_clazz.from_pretrained(
            image_encoder_name,
            subfolder=image_encoder_subfolder,
            torch_dtype=torch.bfloat16,
            token=access_token,
            local_files_only=local_files_only,
        )

    p = clazz.from_pretrained(model, local_files_only=local_files_only, **model_args)
    # Swap in scheduler_clazz, configured from the model's own scheduler config.
    p.scheduler = scheduler_clazz.from_pretrained(
        model,
        subfolder="scheduler",
        token=access_token,
        local_files_only=local_files_only,
        use_karras_sigmas=scheduler_karras,
    )

    if _all_given(ip_adapter_name, ip_adapter_subfolder, ip_adapter_weight_name):
        p.load_ip_adapter(
            ip_adapter_name,
            subfolder=ip_adapter_subfolder,
            weight_name=ip_adapter_weight_name,
            token=access_token,
            local_files_only=local_files_only,
        )

    if not load_gpu:
        return p

    if enable_model_cpu_offload:
        p.enable_model_cpu_offload()
    else:
        # Optional components such as image_encoder must be moved to CUDA
        # explicitly, in addition to p.to("cuda").
        _, optional_params = p._get_signature_keys(p)
        for name in optional_params:
            module = getattr(p, name, None)
            if isinstance(module, torch.nn.Module):
                module.to("cuda")
        p.to("cuda")

    return p


def load_refiner(pipeline, refiner_name, local_files_only=True, load_gpu=True):
    """Load an SDXL refiner that reuses ``text_encoder_2`` and ``vae`` from ``pipeline``.

    Args:
        pipeline: The already-loaded base pipeline; its ``text_encoder_2`` and ``vae``
            are passed to the refiner instead of being loaded again.
        refiner_name: Hugging Face repo id (or local path) of the refiner.
        local_files_only: When False, allow downloading weights from the Hub.
        load_gpu: When True (the default), enable model CPU offload on the refiner.
            Pass False to skip installing offload hooks, mirroring
            ``load_pipeline(load_gpu=False)``.

    Returns:
        The refiner pipeline.
    """
    r = DiffusionPipeline.from_pretrained(
        refiner_name,
        text_encoder_2=pipeline.text_encoder_2,
        vae=pipeline.vae,
        torch_dtype=torch.bfloat16,
        use_safetensors=True,
        variant="fp16",
        # Authenticate the same way as every load in load_pipeline.
        token=access_token,
        local_files_only=local_files_only,
    )
    if load_gpu:
        r.enable_model_cpu_offload()
    return r


def load_sdxl_controlnet(load_gpu=True, local_files_only=True, with_refiner=True):
    """Load ``SG161222/RealVisXL_V5.0`` as an SDXL ControlNet pipeline.

    The pipeline is stored in the module-level ``pipe`` with its ControlNet slot left
    empty (``controlnet=True``). When ``with_refiner`` is True, the SDXL refiner is
    also loaded into the module-level ``refiner``.
    """
    global pipe, refiner
    pipe = load_pipeline(
        "SG161222/RealVisXL_V5.0",
        StableDiffusionXLControlNetPipeline,
        variant="fp16",
        controlnet=True,
        load_gpu=load_gpu,
        local_files_only=local_files_only,
    )
    if not with_refiner:
        return
    refiner = load_refiner(
        pipe,
        "stabilityai/stable-diffusion-xl-refiner-1.0",
        local_files_only=local_files_only,
    )


def load_sdxl_controlnet_ip_adapter(load_gpu=True, local_files_only=True, with_refiner=True):
    """Load ``SG161222/RealVisXL_V5.0`` as an SDXL ControlNet pipeline with an IP-Adapter.

    The image encoder comes from ``h94/IP-Adapter`` (``models/image_encoder``) and the
    adapter weights from ``sdxl_models/ip-adapter-plus_sdxl_vit-h.safetensors``. The
    pipeline is stored in the module-level ``pipe``.

    Args:
        load_gpu: Forwarded to ``load_pipeline``.
        local_files_only: When False, allow downloading weights from the Hub.
        with_refiner: Also load the SDXL refiner into the module-level ``refiner``.
            Defaults to True so existing callers keep getting a refiner; the other
            RealVisXL loaders accept the same flag.
    """
    global pipe, refiner
    pipe = load_pipeline(
        "SG161222/RealVisXL_V5.0",
        StableDiffusionXLControlNetPipeline,
        load_gpu=load_gpu,
        local_files_only=local_files_only,
        variant="fp16",
        controlnet=True,
        image_encoder_name="h94/IP-Adapter",
        image_encoder_clazz=CLIPVisionModelWithProjection,
        image_encoder_subfolder="models/image_encoder",
        ip_adapter_name="h94/IP-Adapter",
        ip_adapter_subfolder="sdxl_models",
        ip_adapter_weight_name="ip-adapter-plus_sdxl_vit-h.safetensors",
        enable_model_cpu_offload=True,  # required to fit within 24G with refiner
    )
    if with_refiner:
        refiner = load_refiner(
            pipe,
            "stabilityai/stable-diffusion-xl-refiner-1.0",
            local_files_only=local_files_only,
        )


def load_sdxl_controlnet_ip_adapter_global(load_gpu=True, local_files_only=True, with_refiner=False):
    """Load ``SG161222/RealVisXL_V5.0`` as an SDXL ControlNet pipeline with the base IP-Adapter.

    Uses ``h94/IP-Adapter`` weights ``sdxl_models/ip-adapter_sdxl.safetensors``; no
    image encoder is passed to ``load_pipeline``. Model CPU offload is requested. The
    pipeline goes into the module-level ``pipe``; when ``with_refiner`` is True the
    SDXL refiner goes into ``refiner``.
    """
    global pipe, refiner
    ip_adapter_options = {
        "ip_adapter_name": "h94/IP-Adapter",
        "ip_adapter_subfolder": "sdxl_models",
        "ip_adapter_weight_name": "ip-adapter_sdxl.safetensors",
    }
    pipe = load_pipeline(
        "SG161222/RealVisXL_V5.0",
        StableDiffusionXLControlNetPipeline,
        controlnet=True,
        variant="fp16",
        enable_model_cpu_offload=True,
        load_gpu=load_gpu,
        local_files_only=local_files_only,
        **ip_adapter_options,
    )
    if not with_refiner:
        return
    refiner = load_refiner(
        pipe,
        "stabilityai/stable-diffusion-xl-refiner-1.0",
        local_files_only=local_files_only,
    )


def load_sdxl_controlnet_pag(load_gpu=True, local_files_only=True, with_refiner=False):
    """Load ``SG161222/RealVisXL_V5.0`` as an SDXL ControlNet PAG pipeline.

    Same setup as ``load_sdxl_controlnet`` but built with
    ``StableDiffusionXLControlNetPAGPipeline``; the refiner is opt-in here.
    """
    global pipe, refiner
    refiner_repo = "stabilityai/stable-diffusion-xl-refiner-1.0"
    pipe = load_pipeline(
        "SG161222/RealVisXL_V5.0",
        StableDiffusionXLControlNetPAGPipeline,
        variant="fp16",
        controlnet=True,
        local_files_only=local_files_only,
        load_gpu=load_gpu,
    )
    if with_refiner:
        refiner = load_refiner(pipe, refiner_repo, local_files_only=local_files_only)


def load_sdxl_t2i_adapter(load_gpu=True, local_files_only=True, with_refiner=True):
    """Load ``SG161222/RealVisXL_V5.0`` as an SDXL T2I-Adapter pipeline.

    The adapter slot is left empty (``t2i_adapter=True``). The pipeline is stored in
    the module-level ``pipe``; when ``with_refiner`` is True the SDXL refiner is also
    loaded into ``refiner``.
    """
    global pipe, refiner
    pipe = load_pipeline(
        "SG161222/RealVisXL_V5.0",
        StableDiffusionXLAdapterPipeline,
        t2i_adapter=True,
        variant="fp16",
        local_files_only=local_files_only,
        load_gpu=load_gpu,
    )
    if not with_refiner:
        return
    refiner = load_refiner(
        pipe,
        "stabilityai/stable-diffusion-xl-refiner-1.0",
        local_files_only=local_files_only,
    )


def load_sdxl_inpainting(load_gpu=True, local_files_only=True):
    """Load ``reimager/juggernaut-xl-inpainting`` as a plain SDXL inpainting pipeline.

    The result is stored in the module-level ``pipe``; no refiner is loaded.
    """
    global pipe
    options = dict(
        load_gpu=load_gpu,
        local_files_only=local_files_only,
        variant="fp16",
    )
    pipe = load_pipeline(
        "reimager/juggernaut-xl-inpainting",
        StableDiffusionXLInpaintPipeline,
        **options,
    )


def load_sdxl_inpainting_controlnet(load_gpu=True, local_files_only=True):
    """Load ``reimager/juggernaut-xl-inpainting`` as an SDXL ControlNet inpainting pipeline.

    The ControlNet slot is left empty (``controlnet=True``); the result is stored in
    the module-level ``pipe``.
    """
    global pipe
    options = dict(
        controlnet=True,
        variant="fp16",
        load_gpu=load_gpu,
        local_files_only=local_files_only,
    )
    pipe = load_pipeline(
        "reimager/juggernaut-xl-inpainting",
        StableDiffusionXLControlNetInpaintPipeline,
        **options,
    )


def load_sdxl_inpainting_ip_adapter(load_gpu=True, local_files_only=True):
    """Load ``reimager/juggernaut-xl-inpainting`` as an SDXL inpainting pipeline with an IP-Adapter.

    The image encoder comes from ``h94/IP-Adapter`` (``models/image_encoder``) and the
    adapter weights from ``sdxl_models/ip-adapter-plus_sdxl_vit-h.safetensors``. The
    result is stored in the module-level ``pipe``.
    """
    global pipe
    image_encoder_options = {
        "image_encoder_name": "h94/IP-Adapter",
        "image_encoder_clazz": CLIPVisionModelWithProjection,
        "image_encoder_subfolder": "models/image_encoder",
    }
    ip_adapter_options = {
        "ip_adapter_name": "h94/IP-Adapter",
        "ip_adapter_subfolder": "sdxl_models",
        "ip_adapter_weight_name": "ip-adapter-plus_sdxl_vit-h.safetensors",
    }
    pipe = load_pipeline(
        "reimager/juggernaut-xl-inpainting",
        StableDiffusionXLInpaintPipeline,
        variant="fp16",
        load_gpu=load_gpu,
        local_files_only=local_files_only,
        **image_encoder_options,
        **ip_adapter_options,
    )


def load_sdxl_inpainting_controlnet_ip_adapter(load_gpu=True, local_files_only=True):
    """Load ``reimager/juggernaut-xl-inpainting`` as an SDXL ControlNet inpainting pipeline with an IP-Adapter.

    Combines an empty ControlNet slot (``controlnet=True``) with the image encoder from
    ``h94/IP-Adapter`` (``models/image_encoder``) and the adapter weights
    ``sdxl_models/ip-adapter-plus_sdxl_vit-h.safetensors``. The result is stored in
    the module-level ``pipe``.
    """
    global pipe
    ip_repo = "h94/IP-Adapter"
    pipe = load_pipeline(
        "reimager/juggernaut-xl-inpainting",
        StableDiffusionXLControlNetInpaintPipeline,
        controlnet=True,
        variant="fp16",
        image_encoder_clazz=CLIPVisionModelWithProjection,
        image_encoder_name=ip_repo,
        image_encoder_subfolder="models/image_encoder",
        ip_adapter_name=ip_repo,
        ip_adapter_subfolder="sdxl_models",
        ip_adapter_weight_name="ip-adapter-plus_sdxl_vit-h.safetensors",
        load_gpu=load_gpu,
        local_files_only=local_files_only,
    )


def load_hat_superres(scale: int, load_gpu=True, local_files_only=True) -> None:
    """Load the HAT super-resolution model for ``scale`` into the module-level ``pipe``.

    The basicsr options are read from ``configs/hat/HAT_X{scale}.yml`` next to this
    file. The checkpoint ``HAT_X{scale}.pth`` is taken from a snapshot of the
    ``reimager/HAT_SR`` Hub repo. Currently accepted scale factors are 2, 3 and 4.

    Args:
        scale (int): The upscaling factor.
        load_gpu (bool): Set ``pipe.device`` to CUDA when True, CPU otherwise.
        local_files_only (bool): When False, allow downloading the checkpoint.

    Raises:
        FileNotFoundError: If there is no config for ``scale``.
    """
    global pipe
    from basicsr import build_model
    # Imported only for the side effect of registering HAT with basicsr's model registry.
    from hat import HAT, HATModel  # noqa: F401

    model_name = f"HAT_X{scale}"
    print(f"Loading {model_name} model...")

    # os.path.join avoids resolving to "/configs/hat" when dirname(__file__) is "".
    config_path = os.path.join(
        os.path.dirname(__file__), "configs", "hat", f"{model_name}.yml"
    )
    # CLoader only exists when PyYAML is built with libyaml; fall back to the
    # pure-Python equivalent. NOTE: both are full (unsafe) loaders — acceptable for
    # this bundled config, never use them on untrusted YAML.
    yaml_loader = getattr(yaml, "CLoader", yaml.Loader)
    with open(config_path, "r", encoding="utf-8") as f:
        opt = yaml.load(f, Loader=yaml_loader)

    ckpt_dir = snapshot_download(repo_id="reimager/HAT_SR", local_files_only=local_files_only)
    opt["path"]["pretrain_network_g"] = os.path.join(ckpt_dir, f"{model_name}.pth")

    device = torch.device("cuda" if load_gpu else "cpu")
    pipe = build_model(opt)
    pipe.opt = opt
    # NOTE(review): this only records the device on the model object; whether the
    # network weights are actually moved depends on build_model/opt — confirm.
    pipe.device = device
    print(f"Loading {model_name} model... done")

# pipeline_loaders maps a human-readable pipeline name to the function that loads it.
# Every entry is called with the keyword arguments load_gpu and local_files_only.
pipeline_loaders = dict(
    [
        ("sdxl_controlnet", load_sdxl_controlnet),
        ("sdxl_controlnet_ip_adapter", load_sdxl_controlnet_ip_adapter),
        ("sdxl_controlnet_ip_adapter_global", load_sdxl_controlnet_ip_adapter_global),
        ("sdxl_controlnet_pag", load_sdxl_controlnet_pag),
        ("sdxl_t2i_adapter", load_sdxl_t2i_adapter),
        ("sdxl_inpainting", load_sdxl_inpainting),
        ("sdxl_inpainting_ip_adapter", load_sdxl_inpainting_ip_adapter),
        ("sdxl_inpainting_controlnet", load_sdxl_inpainting_controlnet),
        ("sdxl_inpainting_controlnet_ip_adapter", load_sdxl_inpainting_controlnet_ip_adapter),
        # HAT super-resolution: one loader per supported upscaling factor.
        ("hat_superres_2X", partial(load_hat_superres, scale=2)),
        ("hat_superres_3X", partial(load_hat_superres, scale=3)),
        ("hat_superres_4X", partial(load_hat_superres, scale=4)),
    ]
)


def free_components(pipe):
    """Detach the large sub-models from a loaded pipeline so their memory can be reclaimed.

    Deletes each attribute in a fixed list of component names from ``pipe``. Names
    the object does not have are skipped, so this also works for objects that are not
    diffusers pipelines (e.g. the HAT model). Passing ``None`` is a no-op.

    Args:
        pipe: The pipeline (or other model object) to strip, or None.
    """
    if pipe is None:
        return
    components = (
        "text_encoder",
        "text_encoder_2",
        "tokenizer",
        "tokenizer_2",
        "unet",
        "vae",
        "controlnet",
        "adapter",
        # Attached by load_pipeline when an image encoder is configured.
        "image_encoder",
    )
    for component in components:
        try:
            delattr(pipe, component)
        except AttributeError:
            # The object has no such component (or it cannot be deleted); skip it.
            # Any other error is a real bug and is no longer silently swallowed.
            pass


def free_pipelines():
    """Drop the loaded pipeline and refiner and release cached GPU memory.

    Components are detached from each object before the module-level reference is
    cleared, so the pipeline objects no longer keep them alive.
    """
    global pipe, refiner
    import gc

    free_components(pipe)
    pipe = None
    free_components(refiner)
    refiner = None
    # Collect reference cycles first so the tensors they hold are actually freed;
    # empty_cache can only return memory that is no longer occupied by live tensors.
    gc.collect()
    torch.cuda.empty_cache()


def pipeline_get_loader(pipeline):
    """Return the loader registered under the name ``pipeline``, or None if there is none."""
    try:
        return pipeline_loaders[pipeline]
    except KeyError:
        return None


def scheduler_get(scheduler, scheduler_config):
    """Return the scheduler instance registered under the name ``scheduler``.

    The instance is built once from ``scheduler_config`` plus the extra arguments in
    ``scheduler_confs`` and then cached in the module-level ``schedulers`` dict. Later
    calls for the same name return the cached instance and ignore ``scheduler_config``.

    Raises:
        Exception: If ``scheduler`` is not a key of ``scheduler_confs``.
    """
    conf = scheduler_confs.get(scheduler)
    if conf is None:
        raise Exception(f"Unknown scheduler: {scheduler}")
    if scheduler not in schedulers:
        print(f"Loading {scheduler} scheduler...")
        schedulers[scheduler] = conf["class"].from_config(scheduler_config, **conf["args"])
    return schedulers[scheduler]


# The schedulers that can be swapped in at inference time, keyed by the name passed
# to scheduler_get. "class" is the scheduler class and "args" holds extra keyword
# arguments forwarded to its from_config.
# NOTE(review): not every class here may accept use_karras_sigmas (e.g. PNDM, DDIM);
# confirm how from_config treats the unsupported keyword.
scheduler_confs = {
    name: {"class": scheduler_class, "args": {"use_karras_sigmas": True}}
    for name, scheduler_class in (
        ("pndm", PNDMScheduler),
        ("dpms", DPMSolverMultistepScheduler),
        ("ddim", DDIMScheduler),
        ("euler", EulerDiscreteScheduler),
        ("euler-a", EulerAncestralDiscreteScheduler),
        ("lms", LMSDiscreteScheduler),
        ("unipc", UniPCMultistepScheduler),
    )
}


def download_models():
    """Fetch the weights of every registered pipeline into the local cache.

    Each loader is called with ``load_gpu=False`` and ``local_files_only=False`` so
    that weights are downloaded rather than read only from the local cache. The
    loaded objects are discarded afterwards.
    """
    global pipe, refiner
    print("Downloading pipeline models...")

    for name, loader in pipeline_loaders.items():
        # Drop the previous pipeline/refiner before loading the next one so their
        # weights can be garbage-collected, and a refiner from an earlier loader
        # does not linger after a loader that loads none.
        pipe = None
        refiner = None
        print(f"Downloading {name}...")
        loader(load_gpu=False, local_files_only=False)
        sys.stdout.flush()
        sys.stderr.flush()

    pipe = None
    refiner = None

    print("Downloading pipeline models... done\n\n")
