"""stable diffusion service provider"""

import os

os.environ["NLTK_DATA"] = "/mldata/nltk_data"
import sys
import torch
import tempfile
import traceback
import numpy as np
import cv2
import re
import torchvision.transforms as transforms
from nltk import pos_tag
from nltk.tokenize import word_tokenize
from PIL import Image
from pytorch_lightning import seed_everything
from grounded_sam import (
    get_mask_iou,
    sam_masks,
    Segment,
    SegmentationPipeline,
)
from torchvision.transforms.functional import pil_to_tensor, to_pil_image

import pipelines
import models
from util import tmpname

from checkpoints import OperationCheckpoint

# Root directory of the segmentation preset configs, resolved relative to this
# file. seg() looks for presets at "<root>/v<N>/<preset>.yaml".
SEG_CONFIG_ROOT = f"{os.path.dirname(__file__)}/configs/segmentation"
# Highest config version; seg() searches from this version down to v1 and uses
# the first directory that contains the requested preset.
SEG_LATEST_VERSION = 3


def classification_room(image=None):
    """Score an image against a fixed list of room types.

    The image is forwarded to clip() together with the room labels. Only
    labels whose probability is above 0.3 are kept.

    Returns:
        A list of ``(label, probability)`` tuples sorted from most to least
        likely, with "office room" reported as "office"; ``None`` when no
        label is above the threshold.

    Raises:
        Exception: if *image* is None.
    """
    if image is None:
        raise Exception("image required")

    room_labels = [
        "bathroom",
        "bedroom",
        "dining room",
        "kitchen",
        "living room",
        "office room",
        "laundry room",
        "hallway",
    ]
    # clip() returns one row of probabilities per image; we pass a single image
    scores = clip(image, room_labels).tolist()[0]

    # only look at classes of a certain probability
    confident = [(label, score) for label, score in zip(room_labels, scores) if score > 0.3]
    if not confident:
        return None

    confident.sort(key=lambda pair: pair[1], reverse=True)
    # "office room" is only the wording used for the CLIP prompt
    return [
        ("office", score) if label == "office room" else (label, score)
        for label, score in confident
    ]


def classify_room(image=None):
    """Return the single most likely room label for *image*.

    Uses classification_room() and returns the label of its first (highest
    probability) entry, or ``None`` when it reports no confident match.
    """
    ranked = classification_room(image)
    return None if ranked is None else ranked[0][0]


def classification_location(image=None):
    """Score whether an image shows an interior or an exterior.

    The image is scored with clip() against "interior", "exterior", "inside"
    and "outside". The synonym pairs are collapsed so that each of "interior"
    and "exterior" appears once, carrying the higher probability of its pair.

    Returns:
        A list of ``(label, probability)`` tuples, most likely first.

    Raises:
        Exception: if *image* is None.
    """
    if image is None:
        raise Exception("image required")

    labels = ["interior", "exterior", "inside", "outside"]
    scores = clip(image, labels).tolist()[0]
    ranked = sorted(zip(labels, scores), key=lambda pair: pair[1], reverse=True)

    # just include the highest confidence of exterior/outside and interior/inside
    canonical = {
        "interior": "interior",
        "inside": "interior",
        "exterior": "exterior",
        "outside": "exterior",
    }
    best = []
    reported = set()
    for label, score in ranked:
        group = canonical[label]
        if group in reported:
            continue
        reported.add(group)
        best.append((group, score))
    return best


def classify_location(image=None):
    """Return the most likely location label for *image*.

    Takes the label of the first entry produced by classification_location(),
    or ``None`` if that function returns ``None``.
    """
    ranked = classification_location(image)
    return None if ranked is None else ranked[0][0]


def clip(image_path, prompts):
    """Return the probabilities of each text prompt matching an image.

    Args:
        image_path: path (or anything ``Image.open`` accepts) of the image.
        prompts: list of text labels to score the image against.

    Returns:
        The softmax over ``logits_per_image`` along dim 1, i.e. the prompt
        probabilities for the image. Callers in this file read row 0 via
        ``.tolist()[0]``.
    """
    # Image.open keeps the file open; convert() produces an independent image,
    # so the source handle can be released right away.
    with Image.open(image_path) as source:
        img = source.convert("RGB")
    inputs = models.clip_processor(text=prompts, images=img, return_tensors="pt", padding=True).to("cuda")
    # Inference only: skip building an autograd graph (as conditioner_depth does).
    with torch.no_grad():
        outputs = models.clip_model(**inputs)
        logits_per_image = outputs.logits_per_image  # this is the image-text similarity score
        probs = logits_per_image.softmax(dim=1)  # we can take the softmax to get the label probabilities
    return probs


def seg(
    image=None,
    preset=None,
    classes=None,
    negative_classes=None,
    boxes=None,
    points=None,
    expand=None,
    mask_size="auto",
):
    """Segment an image and write the resulting mask to a PNG.

    The segmentation pipeline is chosen by the first matching input:

    1. ``preset``: load ``<preset>.yaml`` from the newest config version
       directory under SEG_CONFIG_ROOT that contains it.
    2. ``points`` / ``boxes``: a single prompt-driven segment returning one
       mask, selected according to ``mask_size``.
    3. ``classes`` and ``negative_classes``: a positive segment, a negative
       segment, and an output segment depending on both.
    4. otherwise: a single positive segment for ``classes``.

    Args:
        image: path of the image to segment (required).
        preset: name of a preset config file (without ".yaml").
        classes: class names for the positive segment.
        negative_classes: class names for the negative segment.
        boxes: box prompts passed to the segment.
        points: point prompts passed to the segment.
        expand: passed as the segment ``padding``.
        mask_size: "large", "small" or "auto"; only used with points/boxes.

    Returns:
        Path of the saved mask PNG, in which fully black pixels are transparent.

    Raises:
        Exception: if *image* is None, the preset is unknown, or ``mask_size``
            is not one of the supported values when points/boxes are used.
    """
    if image is None:
        raise Exception("image required")
    print(f"seg( preset={preset}, classes={classes}, points={points} )")
    pipe = None
    if preset:
        # Search newest config version first so newer presets override older ones.
        for version in range(SEG_LATEST_VERSION, 0, -1):
            preset_config_file = os.path.join(SEG_CONFIG_ROOT, f"v{version}", f"{preset}.yaml")
            if os.path.exists(preset_config_file):
                print(f"Loading {preset_config_file}...")
                pipe = SegmentationPipeline.from_config(preset_config_file, version=version)
                break
        if pipe is None:
            raise Exception(f"unknown preset: {preset}")
    elif points or boxes:
        # large = return largest mask
        # small = return smallest mask
        # auto = return best scoring mask
        mask_sorting = {
            "large": ("desc", "size"),
            "small": ("asc", "size"),
            "auto": ("desc", "score"),
        }
        # Fail with a clear message instead of an UnboundLocalError further down.
        if mask_size not in mask_sorting:
            raise Exception(f"unknown mask_size: {mask_size}")
        sort, sort_by = mask_sorting[mask_size]
        segment = Segment(
            segment_id=0,
            name="segment",
            boxes=boxes,
            points=points,
            padding=expand,
            max_num_masks=1,
            sort=sort,
            sort_by=sort_by,
        )
        pipe = SegmentationPipeline(version=1)
        pipe.add_segment(segment)
    elif classes and negative_classes:
        pos_segment = Segment(
            segment_id=0,
            name="pos",
            mask_type="pos",
            classes=classes,
        )
        neg_segment = Segment(
            segment_id=1,
            name="neg",
            mask_type="neg",
            classes=negative_classes,
        )
        output = Segment(
            segment_id=2,
            name="output",
            mask_type="pos",
            padding=expand,
            dependencies=[pos_segment, neg_segment],
        )
        pipe = SegmentationPipeline()
        pipe.add_segments([pos_segment, neg_segment, output])
    else:
        segment = Segment(
            segment_id=0,
            name="pos",
            mask_type="pos",
            classes=classes,
        )
        pipe = SegmentationPipeline()
        pipe.add_segment(segment)

    # convert() returns an independent image, so the file handle can be closed.
    with Image.open(image) as source:
        img = source.convert("RGB")
    result = pipe(image=img)

    # set all black to transparent (this app expects this)
    # Done as one vectorised pass over the whole mask rather than a Python
    # loop over every pixel.
    rgba = np.array(result.mask_image.convert("RGBA"))
    fully_black = (rgba[..., :3] == 0).all(axis=-1)
    rgba[fully_black] = 0
    mask = Image.fromarray(rgba)

    image_path = tmpname(suffix=".png", prefix="sam-")
    mask.save(image_path)
    return image_path


def mlsd(image=None):
    """Save the conditioner_mlsd() output for *image* as a PNG and return its path.

    Raises:
        Exception: if *image* is None.
    """
    if image is None:
        raise Exception("image required")
    out_path = tmpname(suffix=".png", prefix="mlsd-")
    conditioner_mlsd(image).save(out_path)
    return out_path


def hed(image=None):
    """Save the conditioner_hed() output for *image* as a PNG and return its path.

    Raises:
        Exception: if *image* is None.
    """
    if image is None:
        raise Exception("image required")
    out_path = tmpname(suffix=".png", prefix="hed-")
    conditioner_hed(image).save(out_path)
    return out_path


def canny(image=None):
    """Save the conditioner_canny() edge map for *image* as a PNG and return its path.

    Raises:
        Exception: if *image* is None.
    """
    if image is None:
        raise Exception("image required")
    out_path = tmpname(suffix=".png", prefix="canny-")
    conditioner_canny(image).save(out_path)
    return out_path


def depth(image=None, colormap=None):
    """Save the depth map of *image* as a PNG, optionally colour-mapped.

    Args:
        image: path of the source image (required).
        colormap: ``None`` or "grayscale" to keep the plain depth image;
            otherwise the name of an OpenCV colormap (e.g. "magma"), matched
            case-insensitively against ``cv2.COLORMAP_<NAME>``.

    Returns:
        Path of the saved PNG. If applying the colormap fails for any reason
        the traceback is printed and the un-colormapped image is returned.

    Raises:
        Exception: if *image* is None.
    """
    if image is None:
        raise Exception("image required")
    depth_image_path = tmpname(suffix=".png", prefix="depth-")
    conditioner_depth(image).save(depth_image_path)

    # do nothing if colormap is grayscale or None
    if colormap is None or colormap == "grayscale":
        return depth_image_path

    try:
        # Keep only word characters so the name forms a plain attribute name.
        colormap = re.sub(r"\W+", "", colormap)
        # Plain attribute lookup: no need to eval() caller-supplied text.
        colormap_id = getattr(cv2, f"COLORMAP_{colormap.upper()}")
        im = cv2.imread(depth_image_path)
        im = cv2.applyColorMap(im, colormap_id)
        cv2.imwrite(depth_image_path, im)
    except Exception:
        # best effort: just ignore and return the un-colormapped image
        traceback.print_exc(file=sys.stdout)
    return depth_image_path


def enhance_image(image=None):
    """Apply CLAHE to the lightness channel of an image and save the result.

    The image is converted BGR -> LAB, the L channel is equalised with CLAHE
    (clip limit 2.0, 8x8 tiles), and the result is converted back to BGR and
    written to a new PNG.

    Returns:
        Path of the enhanced PNG.

    Raises:
        Exception: if *image* is None.
    """
    if image is None:
        raise Exception("image required")
    out_path = tmpname(suffix=".png", prefix="enhanced-")

    lab = cv2.cvtColor(cv2.imread(image), cv2.COLOR_BGR2LAB)
    lightness, chan_a, chan_b = cv2.split(lab)

    # Equalise lightness only so the colour channels are left untouched.
    equalizer = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    merged = cv2.merge((equalizer.apply(lightness), chan_a, chan_b))

    cv2.imwrite(out_path, cv2.cvtColor(merged, cv2.COLOR_LAB2BGR))
    return out_path


def conditioner_pack(image):
    """This is just an API call to return 5 conditioned images used by remodel AI

    Returns a dict mapping "image-0" .. "image-4" to the paths of, in order:
    grayscale depth, magma-coloured depth, mlsd, canny and hed outputs.
    """
    return {
        "image-0": depth(image),
        "image-1": depth(image, colormap="magma"),
        "image-2": mlsd(image),
        "image-3": canny(image),
        "image-4": hed(image),
    }


def conditioner_mlsd(image_path):
    """Run the MLSD model on an image and return the result at the original size.

    The model is run with detect and image resolution 1024; its output is
    resized back to the width/height reported by ``cv2.imread``.
    """
    # https://huggingface.co/lllyasviel/sd-controlnet-mlsd
    orig_h, orig_w, _channels = cv2.imread(image_path).shape
    source = Image.open(image_path).convert("RGB")
    resolution = 1024
    detected = models.mlsd_model(source, detect_resolution=resolution, image_resolution=resolution)
    return detected.resize((orig_w, orig_h))


def conditioner_hed(image_path):
    """Run the HED model on an image and return the result at the original size.

    The model output is resized back to the width/height reported by
    ``cv2.imread`` for the input file.
    """
    # https://huggingface.co/lllyasviel/sd-controlnet-hed
    orig_h, orig_w, _channels = cv2.imread(image_path).shape
    source = Image.open(image_path).convert("RGB")
    detected = models.hed_model(source)
    return detected.resize((orig_w, orig_h))


def conditioner_canny(image_path):
    """Return an edge map of the image as a PIL image.

    The file is read as grayscale and processed with OpenCV's EdgeDrawing
    detector with ``PFmode`` enabled.
    """
    # Use Edge Drawing which gives better canny-like edges
    gray = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    detector = cv2.ximgproc.createEdgeDrawing()
    settings = cv2.ximgproc.EdgeDrawing.Params()
    settings.PFmode = True
    detector.setParams(settings)
    detected = detector.detectEdges(gray)
    return Image.fromarray(detector.getEdgeImage(detected))


def conditioner_depth(image_path):
    """Estimate depth for an image and return it as an 8-bit PIL image.

    The predicted depth is post-processed to the input image's size and
    min-max normalised to the range 0..255. A constant depth map (max == min)
    yields an all-zero image rather than dividing by zero.
    """
    # This code is pulled from
    # https://huggingface.co/docs/transformers/main/en/model_doc/depth_anything_v2
    # convert() returns an independent image, so the file handle can be closed.
    with Image.open(image_path) as source:
        image = source.convert("RGB")
    inputs = models.depth_processor(images=image, return_tensors="pt").to("cuda")
    with torch.no_grad():
        outputs = models.depth_model(**inputs)
    post_processed_output = models.depth_processor.post_process_depth_estimation(
        outputs,
        target_sizes=[(image.height, image.width)],
    )
    predicted_depth = post_processed_output[0]["predicted_depth"]

    shifted = predicted_depth - predicted_depth.min()
    spread = predicted_depth.max() - predicted_depth.min()
    # When every value is identical `shifted` is already all zeros; dividing
    # by the zero spread would turn it into NaN.
    normalized = shifted / spread if spread > 0 else shifted

    pixels = normalized.detach().cpu().numpy() * 255
    return Image.fromarray(pixels.astype("uint8"))


def controlnet(
    prompt=None,
    control_types=None,
    control_images_conditioned=None,
    control_scales=None,
    batch_size=None,
    steps=None,
    guidance_scale=None,
    seed=None,
    scheduler=None,
    width=None,
    height=None,
    negative_prompt=None,
    ip_adapter_image=None,
    ip_adapter_scale=None,
):
    """Generate images from a prompt guided by already-conditioned control images.

    Args:
        prompt: text prompt (required).
        control_types: passed to ``models.controlnet_get`` to select the controlnet(s).
        control_images_conditioned: paths of the conditioned control images
            (required, non-empty). Each must match ``width`` x ``height``.
        control_scales: conditioning scale(s); a single-element list is passed
            as a scalar.
        batch_size: images to generate (default 4), reduced by
            adjust_batch_size() for large images.
        steps: inference steps (default 50).
        guidance_scale: default 7.5; converted to float.
        seed: when given, passed to ``seed_everything``.
        scheduler: scheduler name (default "dpms").
        width, height: output size; must be multiples of 64. If either is
            missing, both are taken from the first control image.
        negative_prompt: optional negative prompt.
        ip_adapter_image: optional path of an IP-adapter image.
        ip_adapter_scale: IP-adapter scale (default 0.7).

    Returns:
        Dict mapping "image-<i>" to the path of each saved PNG.

    Raises:
        Exception: on missing/invalid arguments, mismatching control image
            dimensions, or when a pipeline returns an unexpected image count.
    """
    if prompt is None:
        raise Exception("prompt required")
    if (
        (width is None or height is None)
        and control_images_conditioned is not None
        and len(control_images_conditioned) > 0
    ):
        # Fall back to the size of the first control image.
        height, width, _ = cv2.imread(control_images_conditioned[0]).shape
    if width is None:
        raise Exception("width required")
    if height is None:
        raise Exception("height required")
    if control_images_conditioned is None or len(control_images_conditioned) == 0:
        raise Exception("control image(s) required")
    batch_size = 4 if batch_size is None else batch_size
    steps = 50 if steps is None else steps
    guidance_scale = 7.5 if guidance_scale is None else guidance_scale
    ip_adapter_scale = 0.7 if ip_adapter_scale is None else ip_adapter_scale
    scheduler = "dpms" if scheduler is None else scheduler
    if width % 64 != 0 or height % 64 != 0:
        raise Exception(f"unsupported size  width: {width} height: {height}")
    batch_size = adjust_batch_size(batch_size, width, height)
    guidance_scale = float(guidance_scale)
    for control_image_conditioned in control_images_conditioned:
        control_image_height, control_image_width, _ = cv2.imread(control_image_conditioned).shape
        if control_image_height % 64 != 0 or control_image_width % 64 != 0:
            # Report the offending control image's size, not the requested size.
            raise Exception(f"dimensions unsupported: {control_image_width}x{control_image_height}")
        if control_image_height != height or control_image_width != width:
            raise Exception(
                f"control image dimensions do not match: {width}x{height} {control_image_width}x{control_image_height}"
            )

    print(f"controlnet| prompt: {prompt} negative_prompt: {negative_prompt}")
    print(
        f"controlnet| control_types: {control_types} control_scales: {control_scales} control_images_conditioned: {control_images_conditioned}"
    )
    print(f"controlnet| scheduler: {scheduler} width: {width} height: {height}")
    print(
        f"controlnet| batch_size: {batch_size} steps: {steps} seed: {seed} guidance_scale: {guidance_scale}",
        flush=True,
    )
    if ip_adapter_image is not None:
        print(f"controlnet| ip_adapter_image: {ip_adapter_image} ip_adapter_scale: {ip_adapter_scale}")

    pipe = pipelines.pipe
    # We load the controlnet & scheduler at inference time
    # instead of having it permanently attached to the pipe because it is quick
    # This reduces the number of subprocs for unique pipelines
    pipe.controlnet = models.controlnet_get(control_types)
    pipe.scheduler = pipelines.scheduler_get(scheduler, pipe.scheduler.config)

    if seed is not None:
        seed_everything(seed)

    # Process the control images
    control_images = [
        Image.open(control_image_conditioned).convert("RGB")
        for control_image_conditioned in control_images_conditioned
    ]

    args = {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "num_images_per_prompt": batch_size,
        "width": width,
        "height": height,
        "image": control_images,
        "guidance_scale": guidance_scale,
        "num_inference_steps": steps,
    }
    if ip_adapter_image is not None:
        args["ip_adapter_image"] = Image.open(ip_adapter_image).convert("RGB")
        pipe.set_ip_adapter_scale(ip_adapter_scale)
    if control_scales is not None and len(control_scales) > 0:
        # A single scale is passed as a scalar, several as a list.
        args["controlnet_conditioning_scale"] = control_scales[0] if len(control_scales) == 1 else control_scales
    result = pipe(**args)

    if len(result.images) != batch_size:
        print(f"Unexpected result size {len(result.images)} != {batch_size}")
        raise Exception(f"unexpected result size {len(result.images)} != {batch_size}")

    torch.cuda.empty_cache()

    refiner = pipelines.refiner
    if refiner is not None:
        refiner_steps = 20
        refiner_strength = 0.13
        result = refiner(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_images_per_prompt=batch_size,
            num_inference_steps=refiner_steps,
            strength=refiner_strength,
            image=result.images,
        )
        if len(result.images) != batch_size:
            print(f"Unexpected result size {len(result.images)} != {batch_size}")
            raise Exception(f"unexpected result size {len(result.images)} != {batch_size}")

    results = {}
    for i, img in enumerate(result.images):
        image_path = tmpname(suffix=".png", prefix="controlnet-")
        img.save(image_path)
        results[f"image-{i}"] = image_path
    return results


def inpainting(
    prompt=None,
    input_image=None,
    mask_image=None,
    batch_size=None,
    steps=None,
    guidance_scale=None,
    strength=None,
    seed=None,
    scheduler=None,
    negative_prompt=None,
    overmask=None,
    control_types=None,
    control_images=None,
    control_images_conditioned=None,
    control_scales=None,
    ip_adapter_image=None,
    ip_adapter_scale=None,
    lora_names=None,
    lora_weights=None,
):
    """Inpaint the masked region of an image.

    Args:
        prompt: text prompt (default "").
        input_image: path of the image to edit (required, must exist).
        mask_image: path of the mask (required, must exist and match the
            input image's dimensions).
        batch_size: images to generate (default 4), reduced by
            adjust_batch_size() for large images.
        steps: inference steps (default 50).
        guidance_scale: default 7.5; converted to float.
        strength: default 0.99.
        seed: when given, passed to ``seed_everything``.
        scheduler: scheduler name (default "dpms").
        negative_prompt: optional negative prompt.
        overmask: when true (the default), each output is blended with the
            original image through the blurred mask via ``apply_overlay``.
        control_types: controlnet types; when not None, the images from
            ``control_images_conditioned`` are passed as control images.
        control_images: accepted but not read; it is replaced by the images
            loaded from ``control_images_conditioned`` when ``control_types``
            is given.
        control_images_conditioned: paths of conditioned control images.
        control_scales: conditioning scale(s); a single-element list is passed
            as a scalar.
        ip_adapter_image: optional path of an IP-adapter image.
        ip_adapter_scale: IP-adapter scale (default 0.7).
        lora_names: list of LoRA names; each is loaded from
            "./<name>.safetensors" if not already loaded. When empty or not a
            list, the currently active adapters are deleted.
        lora_weights: adapter weights passed to ``set_adapters``.

    Returns:
        The state of the OperationCheckpoint used to record inputs, masks and
        output images (``ckpt.get_state()``).

    Raises:
        Exception: on missing arguments or files, mismatching mask dimensions,
            or when the pipeline returns an unexpected image count.
    """
    if input_image is None:
        raise Exception("image required")
    if mask_image is None:
        raise Exception("mask required")
    prompt = "" if prompt is None else prompt
    batch_size = 4 if batch_size is None else batch_size
    steps = 50 if steps is None else steps
    guidance_scale = 7.5 if guidance_scale is None else guidance_scale
    strength = 0.99 if strength is None else strength
    ip_adapter_scale = 0.7 if ip_adapter_scale is None else ip_adapter_scale
    overmask = True if overmask is None else overmask
    if not os.path.isfile(input_image):
        raise Exception(f"missing file: {input_image}")
    scheduler = "dpms" if scheduler is None else scheduler
    guidance_scale = float(guidance_scale)
    ckpt = OperationCheckpoint()
    ckpt.save_result(input_image, name="orig")
    ckpt.save_result(mask_image, name="mask")

    image_input = Image.open(input_image).convert("RGB")
    width, height = image_input.size
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not os.path.isfile(mask_image):
        raise Exception(f"missing file: {mask_image}")
    mask_height, mask_width, _ = cv2.imread(mask_image).shape
    if mask_width != width or mask_height != height:
        raise Exception(f"image and mask dimensions differ: {width}x{height} {mask_width}x{mask_height}")

    image_mask = Image.open(mask_image).convert("1")
    batch_size = adjust_batch_size(batch_size, width, height)

    print(f"inpainting| prompt: {prompt} negative_prompt: {negative_prompt} ")
    print(f"inpainting| scheduler: {scheduler} width: {width} height: {height}")
    print(
        f"inpainting| control_types: {control_types} control_scales: {control_scales} control_images_conditioned: {control_images_conditioned}"
    )
    print(
        f"inpainting| batch_size: {batch_size} steps: {steps} seed: {seed} guidance_scale: {guidance_scale} strength: {strength} overmask: {overmask}",
        flush=True,
    )
    if ip_adapter_image is not None:
        print(f"inpainting| ip_adapter_image: {ip_adapter_image} ip_adapter_scale: {ip_adapter_scale}")

    pipe = pipelines.pipe

    # We load the controlnet & scheduler at inference time
    # instead of having it permanently attached to the pipe because it is quick
    # This reduces the number of subprocs for unique pipelines
    pipe.controlnet = models.controlnet_get(control_types)
    pipe.scheduler = pipelines.scheduler_get(scheduler, pipe.scheduler.config)

    if seed is not None:
        seed_everything(seed)

    # load LORAs
    current_adapters = pipe.get_list_adapters().get("unet")
    if current_adapters is None:
        current_adapters = []
    if isinstance(lora_names, list) and len(lora_names) > 0:
        for lora_name in lora_names:
            if lora_name not in current_adapters:
                print(f"Loading {lora_name}.safetensors LORA...")
                pipe.load_lora_weights(
                    f"./{lora_name}.safetensors",
                    weight_name=f"./{lora_name}.safetensors",
                    adapter_name=lora_name,
                )
        pipe.set_adapters(lora_names, adapter_weights=lora_weights)
    else:
        # No LoRAs requested: remove whatever a previous call left active.
        active_adapters = pipe.get_active_adapters()
        pipe.delete_adapters(active_adapters)

    # Process the control images
    if control_types is not None:
        control_images = []
        for control_type, control_image_conditioned in zip(control_types, control_images_conditioned):
            control_images.append(Image.open(control_image_conditioned).convert("RGB"))
            ckpt.save_result(control_image_conditioned, name=control_type)

    # dilate mask before applying mask blur
    # The dilated array is computed once and reused for both the diffusion
    # mask (blur 40) and the overlay mask (blur 15) below.
    image_mask_cv2 = np.array(image_mask.convert("RGB"))
    kernel = np.ones((15, 15), np.uint8)
    dilation = cv2.dilate(image_mask_cv2, kernel, iterations=2)
    image_mask_blur = pipe.mask_processor.blur(Image.fromarray(dilation), blur_factor=40)
    ckpt.save_image(
        np.array(image_mask_blur),
        tmp_prefix="diffusion-mask-",
        ckpt_name="diffusion-mask",
    )
    args = {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "num_images_per_prompt": batch_size,
        "image": image_input,
        "mask_image": image_mask_blur,
        "width": width,
        "height": height,
        "guidance_scale": guidance_scale,
        "strength": strength,
        "num_inference_steps": steps,
    }
    if ip_adapter_image is not None:
        ckpt.save_result(ip_adapter_image, name="style")
        args["ip_adapter_image"] = Image.open(ip_adapter_image).convert("RGB")
        pipe.set_ip_adapter_scale(ip_adapter_scale)
    if control_types is not None:
        args["control_image"] = control_images
        if control_scales is not None and len(control_scales) > 0:
            # A single scale is passed as a scalar, several as a list.
            args["controlnet_conditioning_scale"] = control_scales[0] if len(control_scales) == 1 else control_scales
    result = pipe(**args)

    if len(result.images) != batch_size:
        print(f"Unexpected result size {len(result.images)} != {batch_size}")
        raise Exception(f"unexpected result size {len(result.images)} != {batch_size}")

    # blend original input with diffusion output through a less blurred
    # version of the same dilated mask
    image_mask_blur = pipe.mask_processor.blur(Image.fromarray(dilation).convert("RGB"), blur_factor=15)
    ckpt.save_image(
        np.array(image_mask_blur),
        tmp_prefix="overlay-mask-",
        ckpt_name="overlay-mask",
    )
    for i, img in enumerate(result.images):
        if overmask:
            img = pipe.image_processor.apply_overlay(image_mask_blur, image_input, img)
        ckpt.save_image(
            np.array(img),
            tmp_prefix="inpainting-",
            ckpt_name=f"image-{i}",
            is_extra=False,
            update_idx=False,
        )
    return ckpt.get_state()


def t2i_adapter(
    prompt=None,
    adapter_types=None,
    adapter_images_conditioned=None,
    adapter_scales=None,
    lora_names=None,
    lora_weights=None,
    batch_size=None,
    steps=None,
    guidance_scale=None,
    seed=None,
    scheduler=None,
    width=None,
    height=None,
    negative_prompt=None,
):
    """Generate images from a prompt guided by already-conditioned adapter images.

    Args:
        prompt: text prompt (required).
        adapter_types: passed to ``models.adapter_get`` to select the adapter(s).
        adapter_images_conditioned: paths of the conditioned adapter images
            (required, non-empty). Each must match ``width`` x ``height``.
        adapter_scales: conditioning scale(s); a single-element list is passed
            as a scalar.
        lora_names: list of LoRA names; each is loaded from
            "./<name>.safetensors" if not already loaded. When empty or not a
            list, the currently active adapters are deleted. The refiner is
            skipped whenever LoRAs are requested.
        lora_weights: adapter weights passed to ``set_adapters``.
        batch_size: images to generate (default 4), reduced by
            adjust_batch_size() for large images.
        steps: inference steps (default 50).
        guidance_scale: default 7.5; converted to float.
        seed: when given, passed to ``seed_everything``.
        scheduler: scheduler name (default "dpms").
        width, height: output size; must be multiples of 64. If either is
            missing, both are taken from the first adapter image.
        negative_prompt: optional negative prompt.

    Returns:
        Dict mapping "image-<i>" to the path of each saved PNG.

    Raises:
        Exception: on missing/invalid arguments, mismatching adapter image
            dimensions, or when a pipeline returns an unexpected image count.
    """
    if prompt is None:
        raise Exception("prompt required")
    if (
        (width is None or height is None)
        and adapter_images_conditioned is not None
        and len(adapter_images_conditioned) > 0
    ):
        # Fall back to the size of the first adapter image.
        height, width, _ = cv2.imread(adapter_images_conditioned[0]).shape
    if width is None:
        raise Exception("width required")
    if height is None:
        raise Exception("height required")
    if adapter_images_conditioned is None or len(adapter_images_conditioned) == 0:
        raise Exception("adapter image(s) required")
    batch_size = 4 if batch_size is None else batch_size
    steps = 50 if steps is None else steps
    guidance_scale = 7.5 if guidance_scale is None else guidance_scale
    scheduler = "dpms" if scheduler is None else scheduler
    if width % 64 != 0 or height % 64 != 0:
        raise Exception(f"unsupported size  width: {width} height: {height}")
    batch_size = adjust_batch_size(batch_size, width, height)
    guidance_scale = float(guidance_scale)
    for adapter_image_conditioned in adapter_images_conditioned:
        adapter_image_height, adapter_image_width, _ = cv2.imread(adapter_image_conditioned).shape
        if adapter_image_height % 64 != 0 or adapter_image_width % 64 != 0:
            # Report the offending adapter image's size, not the requested size.
            raise Exception(f"dimensions unsupported: {adapter_image_width}x{adapter_image_height}")
        if adapter_image_height != height or adapter_image_width != width:
            raise Exception(
                f"adapter image dimensions do not match: {width}x{height} {adapter_image_width}x{adapter_image_height}"
            )

    print(f"t2i_adapter| prompt: {prompt}")
    print(f"t2i_adapter| negative_prompt: {negative_prompt}")
    print(
        f"t2i_adapter| adapter_types: {adapter_types} adapter_scales: {adapter_scales} adapter_images_conditioned: {adapter_images_conditioned}"
    )
    print(f"t2i_adapter| scheduler: {scheduler} width: {width} height: {height}")
    print(f"t2i_adapter| loras: {lora_names} lora_weights: {lora_weights}")
    print(
        f"t2i_adapter| batch_size: {batch_size} steps: {steps} seed: {seed} guidance_scale: {guidance_scale}",
        flush=True,
    )

    pipe = pipelines.pipe
    # We load the adapter & scheduler at inference time
    # instead of having it permanently attached to the pipe because it is quick
    # This reduces the number of subprocs for unique pipelines
    pipe.adapter = models.adapter_get(adapter_types)
    pipe.scheduler = pipelines.scheduler_get(scheduler, pipe.scheduler.config)

    if seed is not None:
        seed_everything(seed)

    # load LORAs
    use_loras = isinstance(lora_names, list) and len(lora_names) > 0
    current_adapters = pipe.get_list_adapters().get("unet")
    if current_adapters is None:
        current_adapters = []
    if use_loras:
        for lora_name in lora_names:
            if lora_name not in current_adapters:
                print(f"Loading {lora_name}.safetensors LORA...")
                pipe.load_lora_weights(
                    f"./{lora_name}.safetensors",
                    weight_name=f"./{lora_name}.safetensors",
                    adapter_name=lora_name,
                )
        pipe.set_adapters(lora_names, adapter_weights=lora_weights)
    else:
        # No LoRAs requested: remove whatever a previous call left active.
        active_adapters = pipe.get_active_adapters()
        pipe.delete_adapters(active_adapters)

    # Process the adapter images
    adapter_images = [
        Image.open(adapter_image_conditioned).convert("RGB")
        for adapter_image_conditioned in adapter_images_conditioned
    ]
    if len(adapter_images) == 1:
        # A single adapter image is passed on its own rather than in a list.
        adapter_images = adapter_images[0]

    args = {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "num_images_per_prompt": batch_size,
        "width": width,
        "height": height,
        "image": adapter_images,
        "guidance_scale": guidance_scale,
        "num_inference_steps": steps,
    }
    if adapter_scales is not None and len(adapter_scales) > 0:
        args["adapter_conditioning_scale"] = adapter_scales[0] if len(adapter_scales) == 1 else adapter_scales

    result = pipe(**args)

    if len(result.images) != batch_size:
        print(f"Unexpected result size {len(result.images)} != {batch_size}")
        raise Exception(f"unexpected result size {len(result.images)} != {batch_size}")

    torch.cuda.empty_cache()

    # do not use the refiner if loading LORA
    refiner = None if use_loras else pipelines.refiner
    if refiner is not None:
        refiner_steps = 20
        refiner_strength = 0.13
        result = refiner(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_images_per_prompt=batch_size,
            num_inference_steps=refiner_steps,
            strength=refiner_strength,
            image=result.images,
        )
        if len(result.images) != batch_size:
            print(f"Unexpected result size {len(result.images)} != {batch_size}")
            raise Exception(f"unexpected result size {len(result.images)} != {batch_size}")

    results = {}
    for i, img in enumerate(result.images):
        image_path = tmpname(suffix=".png", prefix="t2i-")
        img.save(image_path)
        results[f"image-{i}"] = image_path
    return results


def image_apply_mask(original_image_path, new_image_path, mask_image_path):
    """Restore the original pixels outside the mask in a generated image.

    Reads the original image, the new image and the mask, copies the new
    image's values into the original wherever the mask value is 255, and
    overwrites the file at *new_image_path* with the result. This is needed
    because stable-diffusion currently edits pixels outside the mask, which
    messes up faces/eyes etc.
    """
    original_pixels = cv2.imread(original_image_path)
    generated_pixels = cv2.imread(new_image_path)
    mask_pixels = cv2.imread(mask_image_path)

    # Positions where the mask is white; take the generated pixels there
    white = np.where(mask_pixels == 255)
    original_pixels[white] = generated_pixels[white]

    # Overwrite new
    cv2.imwrite(new_image_path, original_pixels)


def adjust_batch_size(batch_size, width, height, allow_oversize=False):
    """Cap the batch size according to the image's pixel area.

    The batch size is never more than 4 and is reduced further as the area
    (``width * height``) passes each threshold. Images over the absolute area
    limit, or with a side longer than the maximum dimension, are rejected.

    Args:
        batch_size: requested batch size.
        width, height: image dimensions in pixels.
        allow_oversize: raise the absolute limits to 2200*2200 pixels and a
            2048 pixel longest side (instead of 1536*1536 and 1920).

    Returns:
        The (possibly reduced) batch size.

    Raises:
        Exception: if the image area or either dimension is too large.
    """
    # (area above which the cap applies, maximum batch size at that area)
    area_caps = (
        (768 * 832, 3),
        (1088 * 1088, 2),
        (1152 * 1152, 1),
    )
    if allow_oversize:
        hard_area_limit, longest_side = 2200 * 2200, 2048
    else:
        # sizes above this area won't be processed at all
        hard_area_limit, longest_side = 1536 * 1536, 1920

    area = height * width
    allowed = min(4, batch_size)  # 4 is the max batch_size
    for threshold, cap in area_caps:
        if area > threshold:
            allowed = min(cap, allowed)

    if area > hard_area_limit:
        raise Exception(f"image too large: {width}x{height}")
    if height > longest_side:
        raise Exception(f"image height too large: {width}x{height}")
    if width > longest_side:
        raise Exception(f"image width too large: {width}x{height}")
    return allowed


def check_prompt(prompt):
    """Check a prompt for a word that NLTK tags as "VB".

    Returns:
        ``(word, False)`` for the first token tagged "VB", or ``(None, True)``
        when no token carries that tag.
    """
    flagged = [token for token, tag in pos_tag(word_tokenize(prompt)) if tag == "VB"]
    return (flagged[0], False) if flagged else (None, True)
