"""Testing for comfy-service"""
# pylint: disable=bare-except
# pylint: disable=broad-exception-raised
import os
import sys
import argparse
import signal
import traceback
import io
import torch
from service import launch_comfy, terminate_comfy
from test_cases import create_sets

# Default seed list passed to create_sets(); parse_args() replaces it with
# list(range(--seed)).
SEED_LIST=[1]

# If quick mode, only run the first test case from each suite.
QUICK_MODE=False

# Extra GPU memory (in megabytes) to hold as a safety margin that mimics
# real-world memory pressure. Kept equal to the --margin default used by
# parse_args (512) so this constant does not advertise a different value than
# the one actually applied.
MARGIN=512

# Substring of suite names to exclude (None = exclude nothing).
EXCLUDE=None

# The output buffer for the comfy subprocess stdout.
comfy_output_buffer = io.StringIO()



def parse_args(argv):
    '''Parse command-line arguments and apply them to the module settings.

    Recognised options:
        -s/--seed N     SEED_LIST becomes ``list(range(N))`` (default N=1).
        -m/--margin MB  megabytes of GPU memory to allocate on cuda:0 and hold
                        for the rest of this process (skipped when <= 0).
        -e/--exclude S  EXCLUDE becomes S (suites containing S are skipped).
        -q/--quick      sets QUICK_MODE.

    Args:
        argv: argument list without the program name (e.g. ``sys.argv[1:]``).

    Returns:
        Everything from the first positional argument onward (argparse
        REMAINDER); the caller treats these as suite-name substrings.
    '''
    # pylint: disable=global-statement
    global SEED_LIST, QUICK_MODE, EXCLUDE, MARGIN, _GPU_MARGIN_TENSOR

    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument("-s", "--seed", type=int, default=1)
    parser.add_argument("-m", "--margin", type=int, default=512)
    parser.add_argument("-e", "--exclude", type=str, default=None)
    parser.add_argument("-q", "--quick", action='store_true')
    parser.add_argument("otherthings", nargs=argparse.REMAINDER)
    options = parser.parse_args(argv)

    SEED_LIST = list(range(options.seed))
    QUICK_MODE = options.quick
    EXCLUDE = options.exclude
    MARGIN = options.margin

    if options.margin > 0:
        device = torch.device("cuda:0")
        # float32 elements are 4 bytes each, so this many elements fill MB megabytes.
        num_elements = options.margin * 1024 * 1024 // 4
        print(f"Allocating {options.margin} MB ({num_elements} float32 elements) extra on the GPU...")
        # Keep a module-level reference so the margin stays allocated for the
        # whole run. A local variable would be released when this function
        # returns; the memory would then only stay reserved as a side effect
        # of PyTorch's caching allocator (and be lost on torch.cuda.empty_cache()).
        _GPU_MARGIN_TENSOR = torch.zeros(num_elements, dtype=torch.float32, device=device)

    return options.otherthings


def exit_cleanup(exit_code=0):
    '''Terminate the comfy subprocess and hard-exit this process.

    Args:
        exit_code: status passed to ``os._exit``. Defaults to 0 so existing
            callers keep their behavior; pass non-zero to report failure.
    '''
    try:
        terminate_comfy()
    except:
        # Best effort: still exit even if the subprocess cannot be stopped.
        traceback.print_exc(file=sys.stdout)
    # os._exit skips normal interpreter shutdown, including flushing Python's
    # stdio buffers. Flush explicitly so buffered output (e.g. the final
    # summary line when stdout is piped) is not lost.
    for stream in (sys.stdout, sys.stderr):
        try:
            stream.flush()
        except (OSError, ValueError):
            pass  # stream closed or broken; nothing more we can do
    # pylint: disable=protected-access
    os._exit(exit_code)


def signal_handler(signum, _frame):
    '''Clean up and exit on SIGINT/SIGTERM.

    Exits with the conventional ``128 + signum`` status so an interrupted or
    killed run is not reported as a success.
    '''
    exit_cleanup(128 + signum)


def main():
    '''Run the selected test suites against a freshly launched comfy service.

    Positional command-line arguments are substrings; a suite runs when its
    name contains any of them (no arguments runs every suite). Never returns
    normally: always ends in ``exit_cleanup``, with status 0 only when every
    selected suite passed and 1 on a failing suite or an unexpected exception.
    '''
    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    sets_to_run = parse_args(sys.argv[1:])
    if not sets_to_run:
        sets_to_run = [""]  # "" is a substring of every name: run all suites
    sets = create_sets(SEED_LIST)

    test_count = 0
    exit_code = 1  # pessimistic until every selected suite has passed
    try:
        # Launch inside the try so a failed launch still goes through cleanup.
        launch_comfy(output=comfy_output_buffer)
        for testset in sets:
            if not any(sub in testset.name for sub in sets_to_run):
                continue
            if EXCLUDE is not None and EXCLUDE in testset.name:
                print(f"Skipping excluded test {testset.name}")
                continue
            success, count = testset.run(comfy_output_buffer, quick=QUICK_MODE)
            test_count += count
            if not success:
                # Stop at the first failing suite and report failure.
                exit_cleanup(1)
        print(f"ALL TESTS PASSED {test_count}")
        exit_code = 0
    except:
        # Deliberately broad: whatever goes wrong, comfy must be terminated.
        traceback.print_exc(file=sys.stdout)

    exit_cleanup(exit_code)



if __name__ == "__main__":
    main()
