"""comfy service"""
# pylint: disable=too-many-locals
# pylint: disable=too-many-branches
# pylint: disable=too-many-statements
# pylint: disable=invalid-name
# pylint: disable=broad-exception-raised
# pylint: disable=broad-exception-caught
# pylint: disable=bare-except
# pylint: disable=global-statement
import os
import sys
import subprocess
import multiprocessing as mp
import time
import socket
import traceback
import signal
import argparse
import threading
import re
import requests
import setproctitle
import psutil
from comfy_api_simplified import ComfyApiWrapper
import config_manager
import reimage

# relaunch comfy whenever the job counter reaches a multiple of this value, as a safety mechanism
COMFY_JOB_MAX = 50

# True only in the main process (it is set under the __main__ guard);
# exit_cleanup() uses it to choose between exiting and signalling its parent
parent = False

# if true, exit: the main loop stops polling and process() requeues instead of running a job
exit_flag = False

# priority penalty - iterate_penalty() adds one each time something takes too long
penalty = 0

# the current priority (replaced by the --priority option in main())
priority = 9

# a string to represent the identity of this worker (the hostname unless overridden in main())
ident = socket.gethostname()

# broker host used for reimage.notify(); main() replaces it with the first --host value
default_host = "api.reimage.io"

# the running comfy subprocess: set by launch_comfy_subproc(), cleared by terminate_comfy()
comfy_process = None

# job counter - starts at 1 and process() adds one for every job it accepts
job_count = 1


def process(
        job,
        op, # The operation of this job
        **kwargs
):
    """Run one job through comfy and report the outcome to the broker.

    Args:
        job: the job object handed to this callback (main() registers
            ``process`` with ``reimage.dojob``). Only ``job.job_id`` is read
            here; the object is otherwise passed straight to ``reimage``.
        op: the operation of this job, forwarded to the runner as ``"op"``.
        **kwargs: the remaining job inputs, forwarded unchanged to
            ``config_manager.run_inputs``.

    A failure while rendering is reported to the broker as an ``error``
    update. When ``exception_is_fatal`` classifies the failure as fatal the
    worker is shut down without requeueing the job.
    """
    global job_count

    if exit_flag:
        reimage.requeue_active_job()
        return

    job_count = job_count + 1
    print(f"Processing {job.job_id} {op} ...", flush=True)
    elapsed_time = reimage.update_job(job, 'processing')

    # getting the job metadata and updating the job to processing should be quick
    # the server returns the elapsed time from the job assignment time to change it to processing
    # if it takes longer than 5 seconds it is likely there is a network or machine issue
    # increase penalty so priority is adjusted lower when this happens
    if isinstance(elapsed_time, int) and elapsed_time > 5000:
        print(f"Update {job.job_id} {op} to processing ({elapsed_time}ms) slow!", flush=True)
        iterate_penalty()

    t1 = time.time()
    try:
        reimage.update_job(job, 'rendering')
        args = {
            "op": op,
            **kwargs,
        }
        outputs = config_manager.run_inputs(args)
        print("\n\n\n")
        reimage.update_job(job, 'complete', outputs=outputs)
    except BaseException as exc:
        # the traceback is not worth printing - its just a traceback to calling the API
        # traceback.print_exc(file=sys.stdout)
        print(f"Exception: {exc}", flush=True)
        sanitized_error = sanitize_exception_string(exc)
        outputs = {
            "error": sanitized_error,
        }
        try:
            reimage.update_job(job, 'error', detail=sanitized_error, outputs=outputs)
        finally:
            # run the fatal check even when reporting the error itself fails,
            # otherwise a broken worker would keep accepting jobs
            if exception_is_fatal(exc):
                # do not requeue this job in case it may cause fatal exceptions to other services
                exit_cleanup(code=1, requeue=False)

    t2 = time.time()
    print(f"Processed  {job.job_id} {op} ({t2-t1:0.2f}s)", flush=True)
    # Slow-job handling is currently disabled:
    # Check to see if this job took too long
    # if it did - send a notification, and also lower your own priority
    # if this happens many times, just exit
    # if t2-t1 > 300.0:
    #     print(f"Processed  {job.job_id} {op} ({t2-t1:0.2f}s) slow!\n", flush=True)
    #     reimage.notify(default_host, f"{setproctitle.getproctitle()} {ident} slow {op} ({t2-t1:0.2f}s) (total: {penalty}) ")
    #     iterate_penalty()

    #     # if this job is exceptionally slow just exit
    #     if t2-t1 > 600.0:
    #         print("Extra slow job. Exiting...\n", flush=True)
    #         exit_cleanup(code=1)


def iterate_penalty():
    """Add one to the global penalty and publish it as a vast instance label.

    The label is only written when ``CONTAINER_ID`` is set in the
    environment. Labelling is best-effort: a missing API key, a missing
    ``vastai`` executable or a timeout is logged and otherwise ignored, so a
    labelling problem never interrupts job processing.
    """
    global penalty
    penalty = penalty + 1
    print(f"New penalty: {penalty}", flush=True)

    container_id = os.environ.get('CONTAINER_ID')
    if container_id is None:
        return

    api_key = os.environ.get('CONTAINER_API_KEY')
    if api_key is None:
        # a None entry in the argument list would make subprocess.run raise TypeError
        print("CONTAINER_API_KEY is not set, skipping vast label update", flush=True)
        return

    try:
        subprocess.run(["vastai","label","instance",container_id,f"comfy-priority-{priority}-penalty-{penalty}","--api-key",api_key], shell=False, capture_output=False, check=False, timeout=20)
    except (OSError, subprocess.SubprocessError) as exc:
        # OSError: vastai cannot be executed; SubprocessError: the 20s timeout expired
        print(f"Unable to update vast label: {type(exc).__name__}: {exc}", flush=True)


def sanitize_exception_string(exc):
    """Format an exception as "<ExceptionType>: <message>" for reporting.

    Everything from a "TIPS:" marker up to the end of that line is removed
    from the message, and leading/trailing whitespace is trimmed.
    """
    message = re.sub(r"TIPS:.*", "", str(exc)).strip()
    return f"{type(exc).__name__}: {message}"


def exception_is_fatal(exc):
    """Return True when the exception means this worker cannot keep running.

    The text "<ExceptionType>: <message>" is searched, ignoring case, for a
    list of known markers. The verdict is also printed.
    """
    fatal_markers = (
        # Exception: Allocation on device This error means you ran out of memory
        "This error means you ran out of memory",
        # ConnectionClosedError: no close frame received or sent - this error happens when comfy crashes
        "ConnectionClosedError: no close frame received or sent",
        # Exception: Allocation on device This error means you ran out of memory
        "Allocation on device",
        # ConnectionError: HTTPConnectionPool(host='127.0.0.1'...
        "HTTPConnectionPool(host='127.0.0.1'",
        # ConnectionClosedError: no close frame received or
        "ConnectionClosedError",
        # [Errno 5] Input/output error - vast devices storage issues
        "[Errno 5] Input/output error",
        # gemini will return 400 for requests from blocked countries
        "400 POST https://generativelanguage.googleapis.com",
        # RuntimeError: CUDA error: out of memory
        # OutOfMemoryError: CUDA out of memory. Tried to allocate ...
        "out of memory",
        # OSError: [Errno 28] No space left on device
        "space left on device",
        # IOError: [Errno 24] Too many open files
        "too many open files",
        # OSError: [Errno 122] Disk quota exceeded
        "disk quota exceeded",
        # RuntimeError: No CUDA GPUs are available
        "no cuda gpus are available",
    )
    kind = exc.__class__.__name__
    haystack = f"{kind}: {str(exc)}".lower()
    if any(marker.lower() in haystack for marker in fatal_markers):
        print(f"\nFatal Exception: {kind}\n", flush=True)
        return True
    print(f"\nNon-Fatal Exception: {kind}\n", flush=True)
    return False


def signal_handler(sig, _):
    """Signal callback: pass the received signal number on to exit_cleanup."""
    exit_cleanup(sig=sig)


def _run_cleanup_step(step, *args, **kwargs):
    """Call ``step`` and log, rather than propagate, anything it raises."""
    try:
        step(*args, **kwargs)
    except BaseException:
        traceback.print_exc(file=sys.stdout)


def _notify_worker_quitting(sig):
    """Tell the default broker host that this worker is quitting."""
    sig_name = 'none' if sig is None else signal.Signals(sig).name
    reimage.notify(default_host, f"{setproctitle.getproctitle()} {ident} worker quitting! [{sig_name}]")


def _stop_vast_instance():
    """Stop the vast instance when CONTAINER_ID is set in the environment."""
    container_id = os.environ.get('CONTAINER_ID')
    if container_id is None:
        return
    # the API key is a secret, so it is deliberately kept out of the log
    print(f"Stopping vast container: {container_id}")
    result = subprocess.run(["vastai","stop","instance",container_id,"--api-key",os.environ.get('CONTAINER_API_KEY')], shell=False, capture_output=False, check=False, timeout=20)
    print(f"Stopping vast container: {result.returncode}")


def exit_cleanup(sig=None, code=0, requeue=True):
    """Clean up and exit the worker.

    Outside the main process (``parent`` is False) this only sends SIGINT to
    the parent process and returns. In the main process it sets
    ``exit_flag``, notifies the broker, stops the vast instance, terminates
    comfy, optionally requeues the active job, terminates multiprocessing
    children and finally ends the process with ``os._exit(code)``.

    Every cleanup step is best-effort: a failing step is logged and the
    remaining steps, including the exit itself, still run.

    Args:
        sig: the signal number that triggered the exit, or None.
        code: the process exit status.
        requeue: when True the active job is handed back via
            ``reimage.requeue_active_job``.
    """
    global exit_flag

    if not parent:
        try:
            print("Sending SIGINT to parent...")
            os.kill(os.getppid(), signal.SIGINT)
        except Exception:
            traceback.print_exc(file=sys.stdout)
        return

    exit_flag = True
    try:
        _run_cleanup_step(_notify_worker_quitting, sig)

        # if this is vastai instances, stop the instance
        _run_cleanup_step(_stop_vast_instance)

        _run_cleanup_step(terminate_comfy)

        if requeue:
            _run_cleanup_step(reimage.requeue_active_job)
        for child in mp.active_children():
            _run_cleanup_step(child.terminate)

        print(f"\n\nExiting... [{setproctitle.getproctitle()}]", flush=True)
    finally:
        # leave no matter what happened above; os._exit does not wait for
        # other threads such as the comfy output reader
        # pylint: disable=protected-access
        os._exit(code)


def launch_comfy_subproc(output):
    """Start ComfyUI and echo its output to ``output`` until the output ends.

    A base directory is prepared under /tmp/comfy-base (models and assets are
    symlinked, custom nodes are copied), then ComfyUI is started with that
    base directory. The process handle is published in the global
    ``comfy_process``. An "IMPORT FAILED" line within the first 999 output
    lines is treated as fatal: comfy is killed and reading stops.

    Args:
        output: file-like object that receives each comfy output line,
            prefixed with "---- ".
    """
    global comfy_process

    print("Launching ComfyUI...")
    os.makedirs('/tmp/comfy-base', exist_ok=True)
    subprocess.run('ln -snf /usr/local/src/comfyui/models /tmp/comfy-base/models', shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=False)
    os.makedirs('/tmp/comfy-base/input', exist_ok=True)
    subprocess.run('ln -snf /usr/local/src/comfyui/assets /tmp/comfy-base/input/assets', shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=False)
    os.makedirs('/tmp/comfy-base/custom_nodes', exist_ok=True)
    subprocess.run('cp -r /usr/local/src/comfyui/custom_nodes/* /tmp/comfy-base/custom_nodes/', shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=False)
    command = ["python3", "/usr/local/src/comfyui/main.py", "--listen", "0.0.0.0", "--base-directory", "/tmp/comfy-base", "--disable-smart-memory"]
    with subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, cwd='/tmp/comfy-base') as proc:
        comfy_process = proc
        # readline() returns b"" once comfy closes its output
        for line_count, raw_line in enumerate(iter(proc.stdout.readline, b""), start=1):
            # comfy output is not guaranteed to be valid UTF-8; a decode error
            # must not end this reader
            line = raw_line.rstrip().decode(errors="replace")
            print(f"---- {line}", file=output, flush=True)

            if line_count < 1000 and "IMPORT FAILED" in line:
                # A failed import isnt fatal in comfy, but we consider it fatal
                print(f"Error: {line}", file=output, flush=True)
                # use the local handle: terminate_comfy() may already have
                # reset the global to None
                proc.kill()
                break


def terminate_comfy():
    """Kill the comfy process tree, if a comfy process is registered.

    The global ``comfy_process`` is reset to None first. Descendants are
    killed before comfy itself. Processes that have already exited are
    skipped, so one vanished process does not prevent the others from being
    killed. Errors are logged and never propagated.
    """
    global comfy_process
    proc = comfy_process
    comfy_process = None
    if proc is None:
        return
    try:
        print("Terminating comfy...", flush=True)

        #os.kill(proc.pid, signal.SIGINT) # graceful shutdown
        try:
            parentproc = psutil.Process(proc.pid)
            targets = parentproc.children(recursive=True) + [parentproc]
        except psutil.NoSuchProcess:
            # comfy is already gone, nothing left to kill
            targets = []
        for target in targets:
            try:
                target.kill()
            except psutil.NoSuchProcess:
                # exited between enumeration and kill
                pass

        for child in mp.active_children():
            child.join(timeout=10.0)

        proc.terminate()
    except Exception:
        traceback.print_exc()


def launch_comfy(output=sys.stdout):
    """(Re)launch comfy and block until its HTTP API answers.

    Any registered comfy process is terminated first. Comfy is then started
    on a background thread running ``launch_comfy_subproc``, and the queue
    endpoint on 127.0.0.1:8188 is polled until it responds.

    Args:
        output: file-like object that receives comfy's output lines.

    Raises:
        RuntimeError: comfy stopped (its reader thread finished) before the
            API became reachable, for example after a fatal "IMPORT FAILED"
            line. Without this check the poll loop would never end.
    """
    terminate_comfy()
    reader = threading.Thread(target=launch_comfy_subproc, args=(output,))
    reader.start()
    time.sleep(1)
    api = ComfyApiWrapper("http://127.0.0.1:8188/")
    while True:
        time.sleep(.1)
        try:
            api.get_queue() # throw exception if fails
            return
        except Exception as exc:
            # the reader thread only finishes once comfy has stopped, so
            # there is nothing left to wait for
            if not reader.is_alive():
                raise RuntimeError("ComfyUI exited before its API became reachable") from exc


def main():
    """Parse the options, launch comfy and poll the brokers for jobs.

    Installs SIGINT/SIGTERM handlers, applies the command line options to the
    module globals and to ``reimage``, labels the vast instance when
    ``CONTAINER_ID`` is set, starts comfy and then loops until ``exit_flag``
    is set. Each iteration picks the next broker host in round-robin order
    and asks it for a "comfy" job via ``reimage.dojob``.
    """
    global priority, default_host, ident

    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--host",
        # at least one host is required: the first one becomes default_host
        nargs='+',
        default=["broker-0.reimage.io", "broker-1.reimage.io", "broker-2.reimage.io", "broker-3.reimage.io"],
        help="The broker host(s)",
    )
    parser.add_argument(
        "--ident",
        type=str,
        default=None,
        help="The identity",
    )
    parser.add_argument(
        "--priority",
        type=int,
        default=2,
        help="The priority for requesting jobs (0 is highest)",
    )
    parser.add_argument(
        "--timeout",
        type=int,
        default=4,
        help="The timeout for requesting jobs",
    )
    parser.add_argument(
        "--insecure",
        action='store_true',
        help="If specified HTTPS cert is not verified (for development use only)",
    )
    parser.add_argument(
        "--extra",
        dest="extra",
        action='store_true',
        help="return extra results from jobs",
    )

    opt = parser.parse_args()

    priority = opt.priority
    default_host = opt.host[0]

    reimage.VERIFY = not opt.insecure
    reimage.SEND_EXTRA = opt.extra
    if opt.insecure:
        requests.packages.urllib3.disable_warnings()
    if opt.ident is not None:
        ident = opt.ident
        reimage.ident = opt.ident

    container_id = os.environ.get('CONTAINER_ID')
    if container_id is not None:
        ident = f"vast-id-{container_id}-{ident}"
        # hand reimage the vast-prefixed identity (not opt.ident, which may be None)
        reimage.ident = ident

        # labelling is informational only; a failure must not stop the worker
        try:
            subprocess.run(["vastai","label","instance",container_id,f"comfy-priority-{priority}-penalty-{penalty}","--api-key",os.environ.get('CONTAINER_API_KEY')], shell=False, capture_output=False, check=False, timeout=20)
        except (OSError, TypeError, subprocess.SubprocessError) as exc:
            # TypeError: CONTAINER_API_KEY is unset, so the argument list holds None
            print(f"Unable to update vast label: {type(exc).__name__}: {exc}", flush=True)

    launch_comfy()

    reimage.notify(default_host, f"{setproctitle.getproctitle()} {ident} worker online! (priority: {priority})")
    print("Waiting for jobs...\n\n", flush=True)
    count = 0
    # job_count value at the last successful relaunch, so that an idle worker
    # does not relaunch comfy on every poll while job_count stays unchanged
    relaunched_at = None
    while not exit_flag:
        try:
            # if penalty is too high just exit
            if penalty > 5:
                print("Penalty threshold exceeded. Exiting...\n", flush=True)
                exit_cleanup(code=1)

            # relaunch comfy if job max is reached
            if job_count % COMFY_JOB_MAX == 0 and job_count != relaunched_at:
                print("Relaunching comfy...", flush=True)
                launch_comfy()
                relaunched_at = job_count

            # in this section we alternate the host name for some basic load balancing
            # if the hostname does not resolve, skip to next
            host = opt.host[count % len(opt.host)]
            count = count + 1
            ipaddr = ""
            try:
                ipaddr = socket.gethostbyname(host)
                if ipaddr == "0.0.0.0":
                    continue
            except Exception:
                time.sleep(0.1)
                continue

            #prio = min(9, priority + penalty) if opt.auto_priority else priority
            prio = priority
            print(f"Checking for jobs... ({host} {ident} {ipaddr} {prio}) {job_count}", flush=True)
            reimage.dojob(
                host,
                "comfy",
                process,
                prio,
                opt.timeout,
                hints=[], # XXX
            )
        except BaseException as exc:
            traceback.print_exc(file=sys.stdout)
            sys.stdout.flush()
            if exception_is_fatal(exc):
                exit_cleanup(code=1, requeue=False)


if __name__ == "__main__":
    # only the interpreter started as a script is the parent; exit_cleanup()
    # relies on this flag
    parent = True
    setproctitle.setproctitle("comfy-service")

    for banner_line in (
            "-----------------------------",
            " Launching comfy service ",
            "-----------------------------",
    ):
        print(banner_line)

    mp.set_start_method('spawn', force=True)
    try:
        main()
    except BaseException:
        # log whatever escaped main(), then shut down with a failure status
        traceback.print_exc(file=sys.stdout)
        sys.stdout.flush()
        exit_cleanup(code=1)
    print(f"\n\nExiting... [{setproctitle.getproctitle()}]", flush=True)
