"""diffusers service"""
import os
# The cache/data locations and credentials below are exported before torch and the
# project modules are imported, presumably so those libraries see them at import
# time -- TODO confirm before reordering the imports.
os.environ["HF_HOME"] = "/mldata/huggingface"
# SECURITY(review): the HF_TOKEN assignment that follows hard-codes a Hugging Face
# access token in source control. The token should be revoked and supplied through
# the deployment environment (or a secret store) instead of being committed here.
os.environ["HF_TOKEN"] = "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"
os.environ["TORCH_HOME"] = "/mldata/torch"
os.environ["NLTK_DATA"] = "/mldata/nltk_data"
import logging
logger = logging.getLogger(__name__)
import torch
import argparse
import sys
import requests
import time
import setproctitle
import signal
import traceback
import tempfile
import multiprocessing as mp
import socket
import subprocess
import json
import numpy as np
from PIL import Image

import reimage
import models
import pipelines
import subproc
from operations import (
    check_prompt,
    conditioner_pack,
    classify_room,
    classify_location,
    classification_room,
    classification_location,
    seg,
)
from interior import remodel_interior
from exterior import remodel_exterior
from landscaping import remodel_landscaping
from floor import remodel_floor
from wall import remodel_wall
from reskin import remodel_reskin
from remove import remodel_remove
from replace import remodel_replace
from paint import paint
from upscale import upscale_controlnet_tile


# True only in the process that was launched as a script (set in the __main__ guard).
# exit_cleanup() uses it to choose between shutting down and signalling its parent.
parent = False

# set by exit_cleanup() in the parent; once true, process() requeues its job
# instead of running it and the polling loop in main() stops
exit_flag = False

# count of jobs that took too long (more than 60s, see process())
slow_count = 0

# the current priority used when requesting jobs (0 is highest);
# main() overwrites it from --priority and may add the auto-priority penalty
priority = 9

# a string to represent the identity of this worker; defaults to the hostname,
# main() may replace it with --ident and/or prefix it with the vast container id
ident = socket.gethostname()

# Operations accepted by process(); anything else is rejected up front.
_KNOWN_OPS = (
    "remodel",
    "paint",
    "conditioner-pack",
    "classify-room",
    "classify-location",
    "classification-room",
    "classification-location",
    "upscale_controlnet_tile",
    "seg",
    "sam",
    "interior-remodel",
    "exterior-remodel",
    "reskin",
    "remove",
    "replace",
    "wall",
    "floor",
    "landscaping",
)

# Stand-alone op names whose spelling differs from the matching "remodel" mode.
_REMODEL_OP_ALIASES = {
    "interior-remodel": "interior",
    "exterior-remodel": "exterior",
}


def _remodel_handlers():
    """Map each remodel mode to (handler, names of the keyword arguments it receives)."""
    common = (
        "job_id",
        "image",
        "prompt",
        "negative_prompt",
        "style_image",
        "fidelity",
        "seed",
        "batch_size",
    )
    return {
        "interior": (remodel_interior, common + ("mask_image", "preset", "room_type")),
        "exterior": (remodel_exterior, common + ("mask_image", "preset", "structure_type")),
        "landscaping": (remodel_landscaping, common + ("mask_image", "preset")),
        "floor": (remodel_floor, common + ("preset",)),
        "wall": (remodel_wall, common + ("preset",)),
        "replace": (remodel_replace, common + ("mask_image",)),
        "remove": (remodel_remove, ("job_id", "image", "mask_image", "batch_size", "seed")),
        "reskin": (remodel_reskin, common + ("mask_image",)),
    }


def _blacken_transparent_mask(mask_image):
    """Rewrite the mask file in place so every pixel that is not fully opaque is black."""
    # This is necessary because the way we grab masks on the iphone
    # can result in transparent white pixels in some cases
    with Image.open(mask_image) as mask:
        # convert() loads the pixel data, so the file can be closed before it is overwritten
        rgba = np.array(mask.convert("RGBA"))
    rgba[rgba[..., -1] < 255] = [0, 0, 0, 0]  # if < 100% opacity - set pixel to black
    Image.fromarray(rgba).convert("RGB").save(mask_image)


def _remove_result_files(results):
    """Delete every file named by a string value of the results mapping (best effort)."""
    if not results or not isinstance(results, dict):
        return
    for filename in results.values():
        if not isinstance(filename, str):
            continue
        try:
            os.remove(filename)
        except FileNotFoundError:
            # We don't care about file not found errors
            pass
        except Exception:  # pylint: disable=broad-except
            traceback.print_exc(file=sys.stdout)


def process(
        job_id,
        op=None, # The operation of this job
        mode=None, # The mode of the operation
        batch_size=4,
        seed=None,
        prompt=None,
        image=None,
        mask_image=None,
        preset=None, # used for remodel ops, seg
        room_type=None, # used for remodel ops
        structure_type=None, # used for remodel ops
        fidelity=None, # used for remodel ops
        size=None, # used for upscale
        classes=None, # used for seg
        negative_classes=None, # used for seg
        boxes=None, # used for seg
        points=None, # used for seg
        mask_size=None, # used for seg
        color_r=None, # used for paint
        color_g=None, # used for paint
        color_b=None, # used for paint
        keep_surface_details="auto", # used for paint
        surface_type="auto", # used for paint
        expand=None, # used for seg
        style_image=None, # used for remodel ops
        strength=None, # used for upscale
):
    """Run one job and report its outcome through reimage.update_job.

    The job is marked 'processing', dispatched on ``op`` (and on ``mode`` when
    ``op`` is "remodel"), then marked 'complete' or 'error'. Files named in the
    results are deleted afterwards. If the worker is already exiting the job is
    requeued instead of being run.

    A job that takes longer than 60s bumps ``slow_count`` and sends a
    notification; one slower than 120s, or more than 5 slow jobs in total,
    triggers exit_cleanup(code=1).

    Args:
        job_id: Identifier passed to reimage.update_job and to the handlers.
        op: Operation name; must be one of ``_KNOWN_OPS``.
        mode: Remodel mode, only consulted when ``op`` is "remodel".
        image, mask_image, style_image: Inputs forwarded to the handlers.
            ``mask_image`` is opened with PIL and rewritten in place, so it is
            expected to be a file path.
        The remaining keyword arguments are forwarded to the handlers noted in
        the signature comments.

    Returns:
        None. The outcome is reported through reimage.update_job.
    """
    global slow_count

    if exit_flag:
        reimage.requeue_active_job()
        return

    print(f"Processing {job_id} {op} ...", flush=True)
    reimage.update_job(job_id, 'processing')

    results = {}
    t1 = time.time()
    try:
        if op not in _KNOWN_OPS:
            raise Exception(f"Unknown operation: {op}")

        # Edit the mask - all transparent pixels should be black
        if mask_image is not None:
            _blacken_transparent_mask(mask_image)

        warning = None
        if prompt is not None:
            invalid_word, ok = check_prompt(prompt)
            if not ok:
                warning = f"imperative_prompt: \"{invalid_word}\""
                print(f"bad prompt detected: {warning}")

        # "remodel" selects its handler through `mode`; the stand-alone remodel
        # ops select it through their own name.
        remodel_mode = mode if op == "remodel" else _REMODEL_OP_ALIASES.get(op, op)
        handlers = _remodel_handlers()

        if remodel_mode in handlers:
            available = {
                "job_id": job_id,
                "image": image,
                "mask_image": mask_image,
                "style_image": style_image,
                "preset": preset,
                "prompt": prompt,
                "negative_prompt": None,
                "room_type": room_type,
                "structure_type": structure_type,
                "fidelity": fidelity,
                "seed": seed,
                "batch_size": batch_size,
            }
            handler, arg_names = handlers[remodel_mode]
            results = handler(**{name: available[name] for name in arg_names})
            reimage.update_job(job_id, 'complete', results=results, warning=warning)
        elif op == "remodel":
            # Without this the job would never leave the 'processing' state
            raise Exception(f"Unknown remodel mode: {mode}")
        elif op == "paint":
            results = paint(
                job_id=job_id,
                batch_size=batch_size,
                image=image,
                mask_image=mask_image,
                color_r=color_r,
                color_g=color_g,
                color_b=color_b,
                keep_surface_details=keep_surface_details,
                surface_type=surface_type,
                seed=seed,
            )
            reimage.update_job(job_id, 'complete', results=results)
        elif op == "conditioner-pack":
            results = conditioner_pack(image)
            reimage.update_job(job_id, 'complete', results=results)
        elif op == "classify-room":
            detail = classify_room(image=image)
            reimage.update_job(job_id, 'complete', detail=detail)
        elif op == "classify-location":
            detail = classify_location(image=image)
            reimage.update_job(job_id, 'complete', detail=detail)
        elif op == "classification-room":
            detail = classification_room(image=image)
            reimage.update_job(job_id, 'complete', detail=detail)
        elif op == "classification-location":
            detail = classification_location(image=image)
            reimage.update_job(job_id, 'complete', detail=detail)
        elif op in ["sam", "seg"]:
            results = {}
            results["image-0"] = seg(
                image=image,
                preset=preset,
                classes=classes,
                negative_classes=negative_classes,
                boxes=boxes,
                points=points,
                expand_pixels=expand,
                mask_size=mask_size if mask_size else "large",
            )
            reimage.update_job(job_id, 'complete', results=results)
        elif op == "upscale_controlnet_tile":
            results = upscale_controlnet_tile(
                job_id=job_id,
                image=image,
                prompt=prompt,
                seed=seed,
                size=size,
                strength=strength,
            )
            reimage.update_job(job_id, 'complete', results=results)
    except BaseException as e:
        traceback.print_exc(file=sys.stdout)
        sys.stdout.flush()
        detail = f"{type(e).__name__}: {str(e)}"
        reimage.update_job(job_id, 'error', detail=detail)
        if exception_is_fatal(e):
            # do not requeue this job in case it may cause fatal exceptions to other services
            exit_cleanup(code=1, requeue=False)
    finally:
        # delete all files in results
        _remove_result_files(results)

    t2 = time.time()
    elapsed = t2 - t1
    print(f"Processed  {job_id} {op} ({elapsed:0.2f}s)", flush=True)
    # Check to see if this job took too long
    # if it did - send a notification, and also lower your own priority
    # if this happens many times, just exit
    if elapsed > 60.0:
        slow_count = slow_count + 1
        print(f"Processed  {job_id} {op} ({elapsed:0.2f}s) slow!\n", flush=True)
        reimage.notify(f"{setproctitle.getproctitle()} {ident} slow {op} ({elapsed:0.2f}s) (total: {slow_count}) ")
        container_id = os.environ.get('CONTAINER_ID')
        if container_id is not None:
            # Labelling the instance is cosmetic: a timeout or a missing CLI/API key
            # must not prevent the slow-job exit checks below from running.
            try:
                subprocess.run(
                    [
                        "vastai", "label", "instance", container_id,
                        f"diffusers-priority-{priority}-slow_count-{slow_count}",
                        "--api-key", os.environ.get('CONTAINER_API_KEY'),
                    ],
                    shell=False, capture_output=False, check=False, timeout=20,
                )
            except Exception:  # pylint: disable=broad-except
                traceback.print_exc(file=sys.stdout)

        # if this job is exceptionally slow just exit
        if elapsed > 120.0:
            print("Extra slow job. Exiting...\n", flush=True)
            exit_cleanup(code=1)
        # if too many slow jobs just exit
        if slow_count > 5:
            print("Slow count exceeded. Exiting...\n", flush=True)
            exit_cleanup(code=1)

    torch.cuda.empty_cache()


def exception_is_fatal(exc):
    """Return True when the exception text matches a known unrecoverable failure.

    The exception is rendered as "<TypeName>: <message>", lower-cased, and
    searched for each marker below. The verdict is also printed to stdout.
    """
    fatal_markers = (
        # RuntimeError: CUDA error: out of memory
        # OutOfMemoryError: CUDA out of memory. Tried to allocate ...
        "out of memory",
        # OSError: [Errno 28] No space left on device
        "space left on device",
        # IOError: [Errno 24] Too many open files
        "too many open files",
        # OSError: [Errno 122] Disk quota exceeded
        "disk quota exceeded",
        # RuntimeError: No CUDA GPUs are available
        "no cuda gpus are available",
        # Exception: Failed to start process img2img_controlnet_sdxl SG161222/RealVisXL_V3.0
        "failed to start process",
        # RuntimeError: Expected all tensors to be on the same device, but found
        # at least two devices, cpu and cuda:0! ...
        "expected all tensors to be on the same device",
        # OSError: TencentARC/t2i-adapter-depth-midas-sdxl-1.0 does not appear to
        # have a file named config.json
        "does not appear to have a file",
        # RuntimeError: CUDA error: misaligned address ...
        # RuntimeError: CUDA error: an illegal memory access was encountered
        # RuntimeError: CUDA error: unknown error ...
        # RuntimeError: CUDA error: device-side assert triggered
        # RuntimeError: CUDA error: invalid device ordinal
        # RuntimeError: CUDA error: CUBLAS_STATUS_INTERNAL_ERROR when calling ...
        "runtimeerror: cuda error",
        # RuntimeError: cutlassF: no kernel found to launch!
        "runtimeerror: cutlassf",
        # LookupError: ... please use the nltk downloader to obtain the resource
        "lookuperror:",
    )
    description = f"{type(exc).__name__}: {str(exc)}".lower()
    if any(marker in description for marker in fatal_markers):
        print(f"\nFatal Exception: {description}\n", flush=True)
        return True
    print(f"\nNon-Fatal Exception: {description}\n", flush=True)
    return False


def signal_handler(sig, _):
    """Signal callback: hand the received signal number to exit_cleanup (frame is ignored)."""
    exit_cleanup(sig=sig)


def exit_cleanup(sig=None, code=0, requeue=True):
    """Shut the worker down.

    In a process that is not the parent this only sends SIGINT to the parent
    process and returns, leaving the actual shutdown to the parent.

    In the parent it sets ``exit_flag``, then runs the cleanup steps in order:
    notify, stop the vast instance (when CONTAINER_ID is set), tell every
    subprocess queue to stop, optionally requeue the active job, and terminate
    the multiprocessing children. Each step is best-effort: a failure is
    printed and the next step still runs, and the process always terminates
    with ``os._exit(code)``.

    Args:
        sig: Signal number that triggered the shutdown, or None.
        code: Process exit status.
        requeue: When true the active job is handed back via
            reimage.requeue_active_job().
    """
    global exit_flag

    if not parent:
        try:
            print("Sending SIGINT to parent...")
            os.kill(os.getppid(), signal.SIGINT)
        except Exception:  # pylint: disable=broad-except
            traceback.print_exc(file=sys.stdout)
        return

    exit_flag = True

    def notify_quitting():
        sig_name = 'none' if sig is None else signal.Signals(sig).name
        reimage.notify(f"{setproctitle.getproctitle()} {ident} worker quitting! [{sig_name}]")

    def stop_vast_instance():
        # if this is vastai instances, stop the instance
        container_id = os.environ.get('CONTAINER_ID')
        if container_id is None:
            return
        # The API key is deliberately not printed: it is a credential
        print(f"Stopping vast container: {container_id}")
        result = subprocess.run(
            ["vastai", "stop", "instance", container_id, "--api-key", os.environ.get('CONTAINER_API_KEY')],
            shell=False, capture_output=False, check=False, timeout=20,
        )
        print(f"Stopping vast container: {result.returncode}")

    def stop_subprocs():
        # A None job is put on every subprocess queue before the children are terminated
        for process_name in subproc.subprocs:
            subproc.subprocs[process_name]["job_queue"].put(None)

    def terminate_children():
        for child in mp.active_children():
            child.terminate()

    steps = [notify_quitting, stop_vast_instance, stop_subprocs]
    if requeue:
        steps.append(reimage.requeue_active_job)
    steps.append(terminate_children)

    try:
        for step in steps:
            try:
                step()
            except BaseException:  # pylint: disable=broad-except
                # We are exiting regardless; a failed step must not keep the worker alive
                traceback.print_exc(file=sys.stdout)
        print(f"\n\nExiting... [{setproctitle.getproctitle()}]", flush=True)
    finally:
        # pylint: disable=protected-access
        os._exit(code)


def auto_priority():
    """Time a test interior remodel and translate the duration into a priority penalty.

    Returns:
        An int from 0 (15s or less) to 5 (up to 60s) that the caller adds to the
        base priority. When the render takes longer than 60s (or the measured
        duration is not positive) the test is reported as failed, the
        subprocesses are killed and exit_cleanup(code=1) is called.

    Raises:
        Whatever remodel_interior raises; the placeholder files are removed first.
    """
    # (upper bound in seconds) -> penalty is the index of the first bound not exceeded
    thresholds = (15.0, 20.0, 25.0, 35.0, 45.0, 60.0)

    t1 = time.time()
    batch_size = 4
    output_image_names = []
    # NOTE(review): these placeholder files are never passed to remodel_interior;
    # they are kept only to preserve the original timing and side effects.
    # NOTE(review): the result of remodel_interior is discarded here; if it names
    # output files (as process() assumes) they are not cleaned up -- confirm.
    try:
        for _ in range(batch_size):
            fd, temp_path = tempfile.mkstemp(suffix=".png", prefix="remodel-")
            # mkstemp returns an open descriptor; close it so it is not leaked
            os.close(fd)
            output_image_names.append(temp_path)
        remodel_interior(
            job_id=0,
            batch_size=batch_size,
            image=os.path.join(os.path.dirname(__file__), "test-image.jpg"),
            mask_image=None,
            preset="modern",
            prompt="",
            negative_prompt=None,
            room_type="kitchen",
            style_image=None,
            seed=1,
            fidelity=0.65,
        )
    finally:
        for temp_path in output_image_names:
            try:
                os.remove(temp_path)
            except FileNotFoundError:
                pass
    t2 = time.time()
    diff = t2 - t1
    print(f"\n\nPriority test: {diff:0.2f}s", flush=True)

    prio = 0
    if 0.0 < diff <= thresholds[-1]:
        prio = next(index for index, limit in enumerate(thresholds) if diff <= limit)
    else:
        print(f"Priority Test: {diff:0.2f}s FAIL\n", flush=True)
        subproc.subproc_kill()
        reimage.notify(f"{setproctitle.getproctitle()} {ident} worker failed priority test: {diff:0.2f}s ")
        exit_cleanup(code=1)
    print(f"Auto priority: +{prio}\n", flush=True)
    return prio


def main():
    """Configure the worker from the command line, load the models and poll for jobs.

    Steps: install SIGINT/SIGTERM handlers, parse arguments, configure the
    ``reimage`` client, optionally download the models, load the models with
    the Hugging Face hub in offline mode, optionally run the auto-priority
    test, then loop asking each broker host in turn for a job until
    ``exit_flag`` is set.
    """
    global priority
    global ident

    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--host",
        nargs='*',
        default=["broker-0.reimage.io", "broker-1.reimage.io", "broker-2.reimage.io", "broker-3.reimage.io"],
        help="The broker host(s)",
    )
    parser.add_argument(
        "--ident",
        type=str,
        default=None,
        help="The identity",
    )
    parser.add_argument(
        "--priority",
        type=int,
        default=2,
        help="The priority for requesting jobs (0 is highest)",
    )
    parser.add_argument(
        "--timeout",
        type=int,
        default=4,
        help="The timeout for requesting jobs",
    )
    parser.add_argument(
        "--insecure",
        action='store_true',
        help="If specified HTTPS cert is not verified (for development use only)",
    )
    parser.add_argument(
        "--download",
        action='store_true',
        help="Download the models before loading them",
    )
    parser.add_argument(
        "--quit-after-download",
        dest="quit_after_download",
        action='store_true',
        help="Just download the models and exit",
    )
    parser.add_argument(
        "--auto-priority",
        dest="auto_priority",
        action='store_true',
        help="Do a test render and set priority accordingly",
    )
    parser.add_argument(
        "--extra",
        dest="extra",
        action='store_true',
        help="return extra results from jobs",
    )

    opt = parser.parse_args()
    if not opt.host:
        # nargs='*' accepts a bare "--host", which would leave nothing to poll
        parser.error("--host requires at least one broker host")

    priority = opt.priority

    reimage.BASE_URL = f"https://{opt.host[0]}"
    reimage.VERIFY = not opt.insecure
    reimage.SEND_EXTRA = opt.extra
    if opt.insecure:
        requests.packages.urllib3.disable_warnings()
    if opt.ident is not None:
        ident = opt.ident
        reimage.ident = opt.ident
    container_id = os.environ.get('CONTAINER_ID')
    if container_id is not None:
        ident = f"vast-id-{container_id}-{ident}"
        # propagate the prefixed identity (previously this assigned opt.ident,
        # which is None unless --ident was given)
        reimage.ident = ident

    # inference only service
    torch.set_grad_enabled(False)

    if opt.download:
        try:
            models.download_models()
            pipelines.download_models()
        except Exception:  # pylint: disable=broad-except
            traceback.print_exc(file=sys.stdout)
            print("\nError downloading models\n", flush=True)
            exit_cleanup(code=1)
        if opt.quit_after_download:
            exit_cleanup(code=0)

    os.environ['HF_HUB_OFFLINE'] = '1'
    models.load_models()

    if opt.auto_priority:
        try:
            priority = priority + auto_priority()
        except BaseException as e:
            traceback.print_exc(file=sys.stdout)
            sys.stdout.flush()
            if exception_is_fatal(e):
                exit_cleanup(code=1)
    if container_id is not None:
        # Labelling the instance is cosmetic; do not abort start-up if it fails
        try:
            subprocess.run(
                [
                    "vastai", "label", "instance", container_id,
                    f"diffusers-priority-{priority}-slow_count-{slow_count}",
                    "--api-key", os.environ.get('CONTAINER_API_KEY'),
                ],
                shell=False, capture_output=False, check=False, timeout=20,
            )
        except Exception:  # pylint: disable=broad-except
            traceback.print_exc(file=sys.stdout)

    reimage.notify(f"{setproctitle.getproctitle()} {ident} worker online! (priority: {priority})")
    print("Waiting for jobs...\n\n", flush=True)
    count = 0
    while not exit_flag:
        try:
            # in this section we alternate the host name for some basic load balancing
            # if the hostname does not resolve, skip to next
            host = opt.host[count % len(opt.host)]
            count = count + 1
            try:
                ipaddr = socket.gethostbyname(host)
            except (OSError, UnicodeError):
                # resolution failures are OSError subclasses; malformed names raise UnicodeError
                time.sleep(0.1)
                continue
            if ipaddr == "0.0.0.0":
                # brief pause so the loop cannot spin if every host resolves to 0.0.0.0
                time.sleep(0.1)
                continue
            reimage.BASE_URL = f"https://{host}"

            prio = min(9, priority + slow_count) if opt.auto_priority else priority
            print(f"Checking for jobs... ({reimage.BASE_URL} {ident} {ipaddr} {prio})")
            reimage.dojob(
                "diffusers",
                process,
                prio,
                opt.timeout,
                hints=subproc.subprocs.keys(),
            )
            torch.cuda.empty_cache()
        except BaseException as e:
            traceback.print_exc(file=sys.stdout)
            sys.stdout.flush()
            if exception_is_fatal(e):
                exit_cleanup(code=1, requeue=False)


if __name__ == "__main__":
    # Script entry point: mark this process as the parent, name it, and run main().
    parent = True
    setproctitle.setproctitle("diffusers")

    for banner_line in (
            "-----------------------------",
            " Launching diffusers service ",
            "-----------------------------",
    ):
        print(banner_line)

    mp.set_start_method('spawn', force=True)
    try:
        main()
    except BaseException:
        # Anything escaping main() is logged and turned into an exit with status 1
        traceback.print_exc(file=sys.stdout)
        sys.stdout.flush()
        exit_cleanup(code=1)
    print(f"\n\nExiting... [{setproctitle.getproctitle()}]", flush=True)
