"""Helper functions to call functions in subprocesses"""
import multiprocessing as mp
import os
import queue
import sys
import time
import traceback

import setproctitle
import torch

# torch does not provide a way to free GPU memory by unloading pipelines & models,
# so large pipelines are loaded in subprocesses that can be killed (killing one frees its GPU memory).
# This module provides utility functions to start those subprocesses and helpers to call into them.
# Each subprocess is named after "{pipeline}" and loads the specified pipeline.
# Each subprocess can be instructed to call a target function with specified parameters and return the result.
# The loaded pipeline will be passed to the target function for use.

# Upper bound on the combined size (as estimated by subproc_size) of all
# subprocesses alive at once. The right value depends on the available GPU
# memory and on how much memory each subprocess consumes.
MAX_SUBPROCESS_COUNT = 1

# Registry of live subprocesses, keyed by subprocess name. Each value is a dict:
#   "process":      the multiprocessing Process object
#   "start_time":   when the process was started (int seconds since the epoch)
#   "last_time":    when the process was last given a job (int seconds since the epoch)
#   "job_queue":    queue used to deliver jobs to the process
#   "result_queue": queue used to retrieve results from the process
subprocs = dict()


def subproc_running(subproc_name):
    """Report whether a subprocess is registered under ``subproc_name``.

    A name stays registered from a successful start until it is killed.
    """
    registered_names = subprocs.keys()
    return subproc_name in registered_names


def subproc_size(subproc_name):
    """Return a rough estimate of the capacity one subprocess occupies.

    ``subproc_name`` is currently ignored: every subprocess counts as one unit.
    """
    # sdxl-based models need the whole GPU, so each subprocess is a full unit
    size_estimate = 1
    return size_estimate


def subproc_count():
    """Return the total estimated size of all registered subprocesses.

    This is the sum of subproc_size over every registered name; 0 when none.
    """
    return sum(subproc_size(name) for name in subprocs)


def subproc_kill(grace_period=5.0):
    """Stop the least recently used subprocess and remove it from ``subprocs``.

    The worker is first asked to exit cleanly by sending a ``None`` job, and is
    given up to ``grace_period`` seconds to do so. It is forcibly terminated
    only if it is still alive after that. Both of its queues are then closed.
    Does nothing when no subprocess is registered.

    Args:
        grace_period: seconds to wait for a clean exit before terminating.
    """
    if not subprocs:
        return
    # min() keeps the first entry among ties, matching sorted(...)[0]
    name, sub = min(subprocs.items(), key=lambda item: item[1]['last_time'])
    print(f"Killing {name} subprocess...")
    process = sub["process"]
    try:
        # The job queue is bounded; if the worker is already dead and the queue
        # is full, an unbounded put() would block this process forever.
        sub["job_queue"].put(None, timeout=1)
    except queue.Full:
        pass
    # Give the worker a chance to finish its own shutdown before using force.
    process.join(timeout=grace_period)
    if process.is_alive():
        process.terminate()
        process.join()
    sub["job_queue"].close()
    sub["result_queue"].close()
    del subprocs[name]
    print(f"Killed  {name} subprocess...")
    return


def subproc_kill_if_necessary(headroom_needed):
    """Kill least-recently-used subprocesses until ``headroom_needed`` fits.

    Stops as soon as the total estimated size plus ``headroom_needed`` is within
    MAX_SUBPROCESS_COUNT. It also stops once no subprocess is left: killing
    cannot create more room, and subproc_kill does nothing on an empty
    registry, so the loop would otherwise never end when ``headroom_needed``
    alone exceeds the limit.

    Args:
        headroom_needed: estimated size of the subprocess about to be started.
    """
    while subprocs and subproc_count() + headroom_needed > MAX_SUBPROCESS_COUNT:
        subproc_kill()


def subproc_start(subproc_name, start_timeout=300):
    """Start a worker subprocess and register it under ``subproc_name``.

    Returns True immediately if a subprocess with that name is already
    registered. Otherwise:

    1. Frees capacity by killing least-recently-used subprocesses.
    2. Spawns a process running subproc_main.
    3. Waits up to ``start_timeout`` seconds for its readiness handshake
       (``True`` on the result queue).

    Args:
        subproc_name: registry key for the new subprocess.
        start_timeout: seconds to wait for the readiness handshake.

    Returns:
        True if the subprocess is running and registered. False if it failed
        to start; in that case the process is terminated, its queues are
        closed, and nothing is registered.
    """
    if subproc_running(subproc_name):
        return True

    # If we've reached our max subproc count - kill the least recently used subproc
    subproc_kill_if_necessary(subproc_size(subproc_name))

    print(f"Starting Process {subproc_name}", flush=True)
    job_queue = mp.Queue(5)
    result_queue = mp.Queue(5)
    subproc_process = mp.Process(target=subproc_main, args=(job_queue, result_queue))
    started_at = int(time.time())

    started = False
    try:
        subproc_process.start()
        started = result_queue.get(timeout=start_timeout) is True
    except queue.Empty:
        pass  # no handshake within start_timeout: treated as a failed start below
    finally:
        if not started:
            # Never leave a half-started worker behind: it may still be loading
            # and holding GPU memory, and nothing would ever reap it.
            print(f"ERROR: Failed to start process {subproc_name}")
            if subproc_process.is_alive():
                subproc_process.terminate()
                subproc_process.join()
            job_queue.close()
            result_queue.close()

    if not started:
        return False

    subprocs[subproc_name] = {
        "process": subproc_process,
        "start_time": started_at,
        "last_time": started_at,
        "job_queue": job_queue,
        "result_queue": result_queue,
    }
    print(f"Started  Process {subproc_name}", flush=True)
    return True


def subproc_main(job_queue, result_queue):
    """Body of a worker subprocess (the Process target used by subproc_start).

    Puts ``True`` on ``result_queue`` as a readiness handshake, then serves jobs
    from ``job_queue`` until it receives ``None``. A job is a dict with a
    callable ``"target"`` and a dict ``"args"``; ``None`` args mean "no
    arguments". The worker runs ``target(**args)``.

    Exactly one item is put on ``result_queue`` per job: either the return
    value, or the exception that was raised. A malformed job produces a
    TypeError. The parent blocks waiting on that queue, so no job may go
    unanswered and no job may produce two replies.
    """
    os.environ['HF_HUB_OFFLINE'] = '1'
    result_queue.put(True)  # readiness handshake

    while True:
        job = job_queue.get()
        if job is None:
            break  # quit
        try:
            target = job.get('target')
            if target is None or not callable(target):
                raise TypeError(f"Invalid target {target}")
            args = job.get('args')
            if args is None:
                args = {}  # subproc_call's default: call target with no arguments
            if not isinstance(args, dict):
                raise TypeError(f"Invalid args {args}")
            result = target(**args)
        except Exception as e:
            traceback.print_exc(file=sys.stdout)
            result = e  # the parent re-raises exceptions it receives
        result_queue.put(result)

        # Kept outside the job's try-block: a failure here must not put a
        # second reply on the queue and desynchronize later calls.
        try:
            torch.cuda.empty_cache()
        except Exception:
            traceback.print_exc(file=sys.stdout)

    # stop process
    job_queue.close()
    result_queue.close()
    print(f"Subproc stopping... {setproctitle.getproctitle()}", flush=True)


def subproc_call(subproc_name, target, args=None):
    """Run ``target(**args)`` in the named subprocess and return its result.

    Args:
        subproc_name: name of a subprocess previously started with subproc_start.
        target: callable to run in the subprocess. It is sent through a
            multiprocessing queue, so it must be picklable.
        args: keyword arguments for ``target``; ``None`` means no arguments.

    Returns:
        Whatever ``target`` returned in the subprocess.

    Raises:
        Exception: if no subprocess named ``subproc_name`` is registered.
        RuntimeError: if the subprocess exits without returning a result. Its
            registry entry is then removed so it can be started again.
        Any exception object received from the subprocess is re-raised here.
    """
    sub = subprocs.get(subproc_name)
    if sub is None or sub.get("job_queue") is None:
        raise Exception(f"Missing subproc {subproc_name}")

    process_args = {
        "target": target,
        # the worker requires a dict; forwarding None would be rejected
        "args": {} if args is None else args,
    }
    sub["last_time"] = int(time.time())
    sub["job_queue"].put(process_args)

    result_queue = sub["result_queue"]
    while True:
        # Poll instead of blocking forever, so a worker that crashed (or was
        # killed, e.g. by the OOM killer) is noticed instead of hanging us.
        try:
            result = result_queue.get(timeout=1.0)
            break
        except queue.Empty:
            pass
        if not sub["process"].is_alive():
            try:
                # the worker may have written its result just before exiting
                result = result_queue.get(timeout=1.0)
                break
            except queue.Empty:
                pass
            del subprocs[subproc_name]
            sub["job_queue"].close()
            result_queue.close()
            raise RuntimeError(f"Subproc {subproc_name} exited without returning a result")

    if isinstance(result, BaseException):
        raise result
    return result
