"""diffusers service"""
import os
# Cache/data directories for the ML tooling. They are assigned before the
# imports below -- presumably so those libraries see the values when they are
# loaded (TODO confirm).
os.environ["HF_HOME"] = "/mldata/huggingface"
# SECURITY(review): a Hugging Face access token is hard-coded here and is
# therefore committed with the source. It should be revoked and supplied by
# the deployment environment / secret store instead of living in this file.
os.environ["HF_TOKEN"] = "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"
os.environ["TORCH_HOME"] = "/mldata/torch"
os.environ["NLTK_DATA"] = "/mldata/nltk_data"
import logging
logger = logging.getLogger(__name__)
import torch
import argparse
import sys
import requests
import time
import setproctitle
import signal
import traceback
import tempfile
import multiprocessing as mp
import socket
import subprocess
import json
import numpy as np
from PIL import Image

import reimage
import models
import pipelines
import subproc
from operations import (
    check_prompt,
    conditioner_pack,
    classify_room,
    classify_location,
    classification_room,
    classification_location,
    seg,
)
from util import tmpname
from interior import remodel_interior
from exterior import remodel_exterior
from upscale import upscale_hat

# True only in the process that runs this file as a script (it is set under
# the `__main__` guard). exit_cleanup() uses it to decide between shutting
# down directly and sending SIGINT to the parent process.
parent = False

# Set to True by exit_cleanup(). Once set, process() requeues the active job
# instead of working on it and the polling loop in main() stops.
exit_flag = False

# Priority penalty - incremented by iterate_penalty() each time something
# takes too long. main() exits once it exceeds 5.
penalty = 0

# The current job-request priority. main() overwrites this from --priority
# and, with --auto-priority, adds the result of the test render.
priority = 9

# A string to represent the identity of this worker. Defaults to the hostname;
# main() may replace it (--ident, CONTAINER_ID).
ident = socket.gethostname()

# Default broker to call if no job is in progress (used for notifications and
# for the priority-test job). main() replaces it with the first --host value.
default_host = "api.reimage.io"

def run_inputs(inputs):
    """Dispatch one job's inputs to the handler selected by its ``op`` entry.

    ``inputs`` is a dict. Its ``op`` key is removed and used to pick the
    operation; the remaining entries are forwarded to that operation. After
    the operation finishes the job is marked ``complete`` through
    ``reimage.update_job`` and the outputs dict is returned.

    Raises:
        Exception: when ``op`` is missing or names no known operation.
    """
    operation = inputs.pop("op", None)
    if operation is None:
        raise Exception(f"Unknown operation: {operation}")
    job = inputs.get("job")

    # A rejected prompt does not stop the job; it is only reported back as a
    # warning on the remodel operations.
    warning = None
    prompt_text = inputs.get('prompt')
    if prompt_text is not None:
        invalid_word, ok = check_prompt(prompt_text)
        if not ok:
            warning = f"imperative_prompt: \"{invalid_word}\""
            print(f"bad prompt detected: {warning}")

    if operation in ("interior-remodel", "exterior-remodel"):
        remodel = remodel_interior if operation == "interior-remodel" else remodel_exterior
        outputs = remodel(**inputs)
        reimage.update_job(job, 'complete', outputs=outputs, warning=warning)
        return outputs

    if operation == "conditioner-pack":
        outputs = conditioner_pack(inputs.get('image'))
        reimage.update_job(job, 'complete', outputs=outputs)
        return outputs

    if operation in ("classify-room", "classify-location"):
        # single-label classification: the label is both the detail and the output
        if operation == "classify-room":
            label = classify_room(image=inputs.get('image'))
            outputs = {"room_type": label}
        else:
            label = classify_location(image=inputs.get('image'))
            outputs = {"location": label}
        reimage.update_job(job, 'complete', detail=label, outputs=outputs)
        return outputs

    if operation in ("classification-room", "classification-room_type", "classification-location"):
        # probability classification: the first entry of the result supplies the label
        if operation == "classification-location":
            label_key = "location"
            classprobs = classification_location(image=inputs.get('image'))
        else:
            label_key = "room_type"
            classprobs = classification_room(image=inputs.get('image'))
        outputs = {
            label_key: None if classprobs is None else classprobs[0][0],
            "class_probabilities": classprobs,
        }
        reimage.update_job(job, 'complete', detail=classprobs, outputs=outputs)
        return outputs

    if operation in ("sam", "seg"):
        # seg() is called without the job entry
        inputs.pop('job', None)
        outputs = {"image-0": seg(**inputs)}
        reimage.update_job(job, 'complete', outputs=outputs)
        return outputs

    if operation == "upscale_hat":
        outputs = upscale_hat(**inputs)
        reimage.update_job(job, 'complete', outputs=outputs)
        return outputs

    raise Exception(f"Unknown operation: {operation}")


def process(
        job,
        op=None, # The operation of this job
        batch_size=None,
        seed=None,
        prompt=None,
        image=None,
        mask_image=None,
        preset=None, # used for remodel ops, seg
        location=None, # used for remodel ops
        room_type=None, # used for remodel ops
        structure_type=None, # used for remodel ops
        fidelity=None, # used for remodel ops
        size=None, # used for creative upscale
        classes=None, # used for seg
        negative_classes=None, # used for seg
        boxes=None, # used for seg
        points=None, # used for seg
        mask_size=None, # used for seg
        expand=None, # used for seg
        style_image=None, # used for remodel ops
        scale=None # used for accurate upscale
):
    """Process one job and report its status.

    This is the callback main() hands to ``reimage.dojob``. It marks the job
    as ``processing``, collects every argument that is not ``None`` into an
    inputs dict, rewrites the mask image in place (if any) and runs the
    operation through ``run_inputs``. Any failure is reported on the job as
    ``error``; fatal failures additionally trigger ``exit_cleanup``. Output
    files named in the result are deleted afterwards, and slow jobs raise
    the worker's penalty.

    If the worker is already shutting down (``exit_flag``), the active job is
    requeued and nothing else happens.
    """
    if exit_flag:
        reimage.requeue_active_job()
        return

    print(f"Processing {job} {op} ...", flush=True)
    elapsed_time = reimage.update_job(job, 'processing')

    # getting the job metadata and updating the job to processing should be quick
    # the server returns the elapsed time from the job assignment time to change it to processing
    # if it takes longer than 5 seconds it's likely there is a network or machine issue
    # increase penalty so priority is adjusted lower when this happens
    if isinstance(elapsed_time, int) and elapsed_time > 5000:
        print(f"Update {job} {op} to processing ({elapsed_time}ms) slow!", flush=True)
        iterate_penalty()

    outputs = {}
    # monotonic clock: the duration below drives penalties and worker exit,
    # so it must not be affected by wall-clock adjustments
    t1 = time.monotonic()
    try:
        # keep in sync with the operations handled by run_inputs()
        if op not in [
                "conditioner-pack",
                "classify-room",
                "classify-location",
                "classification-room",
                "classification-room_type",
                "classification-location",
                "upscale_hat",
                "seg",
                "sam",
                "interior-remodel",
                "exterior-remodel",
        ]:
            raise Exception(f"Unknown operation: {op}")

        # only arguments that were actually supplied are forwarded
        optional_inputs = {
            "batch_size": batch_size,
            "seed": seed,
            "prompt": prompt,
            "image": image,
            "mask_image": mask_image,
            "preset": preset,
            "location": location,
            "room_type": room_type,
            "structure_type": structure_type,
            "fidelity": fidelity,
            "size": size,
            "classes": classes,
            "negative_classes": negative_classes,
            "boxes": boxes,
            "points": points,
            "mask_size": mask_size,
            "expand": expand,
            "style_image": style_image,
            "scale": scale,
        }
        inputs = {
            "job": job,
            "op": op,
        }
        inputs.update(
            (key, value) for key, value in optional_inputs.items() if value is not None
        )

        # Edit the mask - all transparent pixels should be black
        # This is necessary because the way we grab masks on the iphone
        # can result in transparent white pixels in some cases
        if mask_image is not None:
            # close the source file before the same path is overwritten below
            with Image.open(mask_image) as mask_file:
                rgba = np.array(mask_file.convert("RGBA"))
            rgba[rgba[..., -1] < 255] = [0, 0, 0, 0]  # if < 100% opacity - set pixel to black
            Image.fromarray(rgba).convert("RGB").save(mask_image)

        outputs = run_inputs(inputs)

    except BaseException as exc:
        traceback.print_exc(file=sys.stdout)
        sys.stdout.flush()
        detail = f"{type(exc).__name__}: {str(exc)}"
        reimage.update_job(job, 'error', detail=detail, outputs={"error": detail})
        if exception_is_fatal(exc):
            # do not requeue this job in case it may cause fatal exceptions to other services
            exit_cleanup(code=1, requeue=False)
    finally:
        # delete all files in outputs (string values are treated as file names)
        if outputs:
            for filename in outputs.values():
                if not isinstance(filename, str):
                    continue
                try:
                    os.remove(filename)
                except FileNotFoundError:
                    # We don't care about file not found errors
                    pass
                except BaseException:
                    traceback.print_exc(file=sys.stdout)

    duration = time.monotonic() - t1
    print(f"Processed  {job} {op} ({duration:0.2f}s)", flush=True)
    # Check to see if this job took too long
    # if it did - send a notification, and also lower your own priority
    # if this happens many times, just exit
    if duration > 60.0:
        print(f"Processed  {job} {op} ({duration:0.2f}s) slow!\n", flush=True)
        reimage.notify(default_host, f"{setproctitle.getproctitle()} {ident} slow {op} ({duration:0.2f}s) (total: {penalty}) ")
        iterate_penalty()

        # if this job is exceptionally slow just exit
        if duration > 120.0:
            print("Extra slow job. Exiting...\n", flush=True)
            exit_cleanup(code=1)

    torch.cuda.empty_cache()


def iterate_penalty():
    """Increase the global penalty by one and publish it as the instance label.

    When ``CONTAINER_ID`` is set, the new priority/penalty pair is written to
    the instance label by running ``vastai label instance``. Labelling is
    best-effort: a missing API key, a missing ``vastai`` executable or a
    timed-out call is logged and otherwise ignored, so that bookkeeping can
    never abort the job that triggered the penalty.
    """
    global penalty
    penalty = penalty + 1
    print(f"New penalty: {penalty}", flush=True)

    container_id = os.environ.get('CONTAINER_ID')
    if container_id is None:
        return
    api_key = os.environ.get('CONTAINER_API_KEY')
    if api_key is None:
        # a None entry in the argument list would make subprocess.run raise TypeError
        print("CONTAINER_API_KEY is not set; skipping instance label update", flush=True)
        return
    try:
        subprocess.run(
            [
                "vastai", "label", "instance", container_id,
                f"diffusers-priority-{priority}-penalty-{penalty}",
                "--api-key", api_key,
            ],
            shell=False, capture_output=False, check=False, timeout=20,
        )
    except (OSError, subprocess.SubprocessError):
        # OSError: executable missing / not runnable; SubprocessError: timeout
        traceback.print_exc(file=sys.stdout)
        sys.stdout.flush()


def exception_is_fatal(exc):
    """Return True when ``exc`` matches one of the known fatal error patterns.

    The exception is rendered as ``"<TypeName>: <message>"``, lower-cased, and
    searched for each marker below. The verdict is also printed.
    """
    fatal_markers = (
        # RuntimeError: CUDA error: out of memory
        # OutOfMemoryError: CUDA out of memory. Tried to allocate ...
        "out of memory",
        # OSError: [Errno 28] No space left on device
        "space left on device",
        # IOError: [Errno 24] Too many open files
        "too many open files",
        # OSError: [Errno 122] Disk quota exceeded
        "disk quota exceeded",
        # RuntimeError: No CUDA GPUs are available
        "no cuda gpus are available",
        # Exception: Failed to start process ...
        "failed to start process",
        # RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! ...
        "expected all tensors to be on the same device",
        # OSError: TencentARC/t2i-adapter-depth-midas-sdxl-1.0 does not appear to have a file named config.json
        "does not appear to have a file",
        # RuntimeError: CUDA error: misaligned address ...
        # RuntimeError: CUDA error: an illegal memory access was encountered
        # RuntimeError: CUDA error: unknown error ...
        # RuntimeError: CUDA error: device-side assert triggered
        # RuntimeError: CUDA error: out of memory
        # RuntimeError: CUDA error: invalid device ordinal
        # RuntimeError: CUDA error: CUBLAS_STATUS_INTERNAL_ERROR when calling ...
        "runtimeerror: cuda error",
        # RuntimeError: cutlassF: no kernel found to launch!
        "runtimeerror: cutlassf",
        # LookupError: ... please use the nltk downloader to obtain the resource
        "lookuperror:",
    )
    exc_str = "{}: {}".format(type(exc).__name__, str(exc)).lower()
    if any(marker in exc_str for marker in fatal_markers):
        print(f"\nFatal Exception: {exc_str}\n", flush=True)
        return True
    print(f"\nNon-Fatal Exception: {exc_str}\n", flush=True)
    return False


def signal_handler(sig, _):
    """Signal callback: start the shutdown sequence for signal number ``sig``.

    The second argument supplied by the signal machinery is ignored.
    """
    exit_cleanup(sig=sig)


def exit_cleanup(sig=None, code=0, requeue=True):
    """Shut the worker down.

    In a process that is not the parent this only sends SIGINT to the parent
    process and returns. In the parent it sets ``exit_flag``, sends a
    "worker quitting" notification, stops the vast instance when
    ``CONTAINER_ID`` is set, puts ``None`` on every subprocess job queue,
    optionally requeues the active job, terminates child processes and then
    ends the process with ``os._exit(code)``.

    Args:
        sig: signal number that triggered the shutdown, or None.
        code: process exit code.
        requeue: whether the active job is handed back via
            ``reimage.requeue_active_job()``.

    Any exception raised along the way is printed and swallowed; in that case
    the function returns instead of exiting.
    """
    global exit_flag
    try:
        if not parent:
            print("Sending SIGINT to parent...")
            os.kill(os.getppid(), signal.SIGINT)
            return

        exit_flag = True
        sig_name = 'none' if sig is None else signal.Signals(sig).name
        reimage.notify(default_host, f"{setproctitle.getproctitle()} {ident} worker quitting! [{sig_name}]")

        # if this is a vastai instance, stop the instance
        container_id = os.environ.get('CONTAINER_ID')
        if container_id is not None:
            # the API key is a credential: it is deliberately kept out of the log
            print(f"Stopping vast container: {container_id}")
            api_key = os.environ.get('CONTAINER_API_KEY')
            if api_key is None:
                print("CONTAINER_API_KEY is not set; cannot stop vast container")
            else:
                try:
                    result = subprocess.run(
                        ["vastai", "stop", "instance", container_id, "--api-key", api_key],
                        shell=False, capture_output=False, check=False, timeout=20,
                    )
                    print(f"Stopping vast container: {result.returncode}")
                except (OSError, subprocess.SubprocessError):
                    # stopping the instance is best-effort; a missing or hung CLI
                    # must not prevent the rest of the shutdown from running
                    traceback.print_exc(file=sys.stdout)

        for process_name in subproc.subprocs:
            subproc.subprocs[process_name]["job_queue"].put(None)

        if requeue:
            reimage.requeue_active_job()
        for child in mp.active_children():
            child.terminate()

        print(f"\n\nExiting... [{setproctitle.getproctitle()}]", flush=True)
        # pylint: disable=protected-access
        os._exit(code)
    except BaseException:
        traceback.print_exc(file=sys.stdout)


def auto_priority():
    """Time a fixed interior-remodel render and map the duration to a priority offset.

    The render uses the ``test-image.jpg`` located next to this file. Files
    named in the result are deleted afterwards (never the test image itself).
    Durations up to 60 seconds map to offsets 0-5; anything else counts as a
    failed test, which kills the subprocesses, sends a notification and calls
    ``exit_cleanup(code=1)``.

    Returns:
        The priority offset (0 if the failure path returns).
    """
    started = time.time()
    test_image = os.path.join(os.path.dirname(__file__), "test-image.jpg")
    outputs = remodel_interior(
        job=reimage.Job(default_host),
        batch_size=4,
        image=test_image,
        mask_image=None,
        preset="modern",
        prompt="",
        negative_prompt=None,
        room_type="kitchen",
        style_image=None,
        seed=1,
        fidelity=0.65,
    )
    diff = time.time() - started
    print(f"\n\nPriority test: {diff:0.2f}s", flush=True)

    # remove whatever the render produced; failures here are ignored on purpose
    if outputs is not None and outputs != {}:
        for filename in outputs.values():
            try:
                if isinstance(filename, str) and "test-image.jpg" not in filename:
                    os.remove(filename)
            except BaseException:
                pass

    # (inclusive upper bound in seconds, priority offset), fastest tier first
    tiers = (
        (15.0, 0),
        (20.0, 1),
        (25.0, 2),
        (35.0, 3),
        (45.0, 4),
        (60.0, 5),
    )
    prio = 0
    # a non-positive duration matches no tier, exactly like a too-slow one
    for upper_bound, offset in (tiers if diff > 0.0 else ()):
        if diff <= upper_bound:
            prio = offset
            break
    else:
        print(f"Priority Test: {diff:0.2f}s FAIL\n", flush=True)
        subproc.subproc_kill()
        reimage.notify(default_host, f"{setproctitle.getproctitle()} {ident} worker failed priority test: {diff:0.2f}s ")
        exit_cleanup(code=1)
    print(f"Auto priority: +{prio}\n", flush=True)
    return prio


def main():
    """Parse the command line, prepare the models and poll the brokers for jobs.

    Steps, in order: install the SIGINT/SIGTERM handlers, parse arguments,
    publish the resulting settings to the module globals and to ``reimage``,
    optionally download the models, load the models, optionally run the
    priority test, announce the worker, then loop over the broker hosts
    calling ``reimage.dojob`` with ``process`` as the job callback until
    ``exit_flag`` is set or the penalty exceeds 5.
    """
    global priority, default_host, ident

    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    setproctitle.setproctitle("diffusers")
    mp.set_start_method('spawn', force=True)

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--host",
        nargs='*',
        default=["broker-0.reimage.io", "broker-1.reimage.io", "broker-2.reimage.io", "broker-3.reimage.io"],
        help="The broker host(s)",
    )
    parser.add_argument(
        "--ident",
        type=str,
        default=None,
        help="The identity",
    )
    parser.add_argument(
        "--priority",
        type=int,
        default=2,
        help="The priority for requesting jobs (0 is highest)",
    )
    parser.add_argument(
        "--timeout",
        type=int,
        default=4,
        help="The timeout for requesting jobs",
    )
    parser.add_argument(
        "--insecure",
        action='store_true',
        help="If specified HTTPS cert is not verified (for development use only)",
    )
    parser.add_argument(
        "--download",
        action='store_true',
        help="Download the models before starting",
    )
    parser.add_argument(
        "--quit-after-download",
        dest="quit_after_download",
        action='store_true',
        help="Exit once the models are downloaded (only has an effect together with --download)",
    )
    parser.add_argument(
        "--auto-priority",
        dest="auto_priority",
        action='store_true',
        help="Do a test render and set priority accordingly",
    )
    parser.add_argument(
        "--extra",
        dest="extra",
        action='store_true',
        help="return extra outputs from jobs",
    )

    opt = parser.parse_args()
    if not opt.host:
        # "--host" with no values yields an empty list, which would otherwise
        # fail later on opt.host[0] and on the modulo in the polling loop
        parser.error("--host requires at least one broker host")

    priority = opt.priority
    default_host = opt.host[0]

    reimage.VERIFY = not opt.insecure
    reimage.SEND_EXTRA = opt.extra
    if opt.insecure:
        requests.packages.urllib3.disable_warnings()
    if opt.ident is not None:
        ident = opt.ident
        reimage.ident = opt.ident
    if os.environ.get('CONTAINER_ID') is not None:
        ident = f"vast-id-{os.environ.get('CONTAINER_ID')}-{ident}"
        # publish the container-qualified identity; assigning opt.ident here
        # would drop the prefix and reset it to None when --ident is absent
        reimage.ident = ident

    # inference only service
    torch.set_grad_enabled(False)

    if opt.download:
        try:
            models.download_models()
            pipelines.download_models()
        except BaseException:
            traceback.print_exc(file=sys.stdout)
            print("\nError downloading models\n", flush=True)
            exit_cleanup(code=1)
        if opt.quit_after_download:
            exit_cleanup(code=0)

    # set only after the optional download step above
    os.environ['HF_HUB_OFFLINE'] = '1'
    models.load_models()

    if opt.auto_priority:
        try:
            priority = priority + auto_priority()
        except BaseException as e:
            traceback.print_exc(file=sys.stdout)
            sys.stdout.flush()
            if exception_is_fatal(e):
                exit_cleanup(code=1)
    if os.environ.get('CONTAINER_ID') is not None:
        subprocess.run(["vastai","label","instance",os.environ.get('CONTAINER_ID'),f"diffusers-priority-{priority}-penalty-{penalty}","--api-key",os.environ.get('CONTAINER_API_KEY')], shell=False, capture_output=False, check=False, timeout=20)

    reimage.notify(default_host, f"{setproctitle.getproctitle()} {ident} worker online! (priority: {priority})")
    print("Waiting for jobs...\n\n", flush=True)
    count = 0
    while not exit_flag:
        try:
            # if penalty is too high just exit
            if penalty > 5:
                print("Penalty threshold exceeded. Exiting...\n", flush=True)
                exit_cleanup(code=1)

            # in this section we alternate the host name for some basic load balancing
            # if the hostname does not resolve, skip to next
            host = opt.host[count % len(opt.host)]
            count = count + 1
            ipaddr = ""
            try:
                ipaddr = socket.gethostbyname(host)
                if ipaddr == "0.0.0.0":
                    continue
            except Exception:
                time.sleep(0.1)
                continue

            # the penalty only lowers the requested priority in auto-priority mode
            prio = min(9, priority + penalty) if opt.auto_priority else priority
            print(f"Checking for jobs... ({host} {ident} {ipaddr} {prio})", flush=True)
            reimage.dojob(
                host,
                "diffusers",
                process,
                prio,
                opt.timeout,
                hints=subproc.subprocs.keys(),
            )
            torch.cuda.empty_cache()
        except BaseException as e:
            traceback.print_exc(file=sys.stdout)
            sys.stdout.flush()
            if exception_is_fatal(e):
                exit_cleanup(code=1, requeue=False)


if __name__ == "__main__":
    # only the process started as a script is the parent
    parent = True

    banner = (
        "-----------------------------",
        " Launching diffusers service ",
        "-----------------------------",
    )
    print("\n".join(banner))

    try:
        main()
    except BaseException:
        # anything escaping main() is logged and turned into an error exit
        traceback.print_exc(file=sys.stdout)
        sys.stdout.flush()
        exit_cleanup(code=1)
    print(f"\n\nExiting... [{setproctitle.getproctitle()}]", flush=True)
