#!/usr/bin/python3
import os
import signal
import sys
import time
import requests
import subprocess
import json
import traceback
import datetime
import argparse
import math

# Passed as the ``verify`` argument of the stats request.
VERIFY = True
# Base URL of the stats API; endpoint names are appended directly to it.
BASE_URL = "https://api.reimage.io/"

# Seconds during which no additional workers are started after one was started.
lockout_seconds = 3 * 60
# Amount bid on top of an offer's minimum, so instances are outbid less often.
over_bid = 0.04
# Lowest bid ever placed.
min_bid = 0.12
# Offers whose minimum bid reaches this value are excluded from the search.
max_bid = 0.95
# Minimum driver version.
min_driver_version = 525
# Desired available load.
avail_target = 8.0
# Longest time (seconds) an instance should need to start up.
max_startup_sec = 45 * 60
# Reference age (seconds) used to classify a stop as a "quick stop".
quick_stop_sec = 60 * 60

# Query handed to `vastai search offers`, built from its individual filters.
base_params = " ".join((
    "disk_space>30",
    "gpu_ram>23",
    "inet_down>300",
    "inet_up>100",
    "inet_up_cost<.02",
    "inet_down_cost<.01",
    f"min_bid<{max_bid}",
    "verified=true",
))

# Names of the workers this script monitors.
worker_names = ["diffusers"]

# Accepted GPU models, keyed by the offer's gpu_name.
#   priority:  value passed as `priority` when formatting the worker command
#   bid_boost: amount subtracted from an offer's rank (lower rank is better),
#              so a larger boost makes that GPU model more attractive
gpu_specs = {
    name: {"priority": priority, "bid_boost": bid_boost}
    for name, priority, bid_boost in (
        ("RTX 3090", 5, 0.0),
        ("RTX 4090", 4, 0.10),
        ("RTX 3090 Ti", 5, 0.0),
        ("RTX 4090 Ti", 4, 0.20),
        ("RTX A6000", 5, 0.0),
        ("RTX 6000Ada", 4, 0.05),
        ("RTX 5000Ada", 4, 0.05),
        ("A100 SXM4", 5, 0.0),
        ("A100 PCIE", 5, 0.0),
        ("H100 PCIE", 5, 0.10),
        ("GH200 SXM", 5, 0.10),
    )
}

# Host ids whose offers are always skipped.
banned_host_id = [
    # Texas, US - ROME EPYC very slow
    60400,
    # North Carolina, US - ROME EPYC very slow
    39244,
    # Nevada, US - network issues, DNS issues
    8199,
    # Hong Kong HK - bad docker cache (not updated)
    134693,
]

# Machine ids whose offers are always skipped.
banned_machine_id = [
    # this machine is garbage (300+ second renders consistently)
    15825,
    # many cuda errors
    15393,
]

# Experimental: per-host counters meant to help spot problematic hosts.
# Maps host_id -> {counter name: count}; entries and counters are created on
# first use by host_stats_increment. Counters written in this file:
# "starts", "stops", "quick_stops", "failures", "failure_bad_status" and
# "failure_to_launch". The stated goal is to eventually disregard problematic
# hosts when starting a new instance.
# Example entry:
#     1234: {"quick_stops": 3, "stops": 3, "failures": 1, "starts": 10}
host_stats = {}

# Mutable per-worker state. 'last_start_time' holds the time.time() value of
# the most recent start that reset the lockout; 0 means no start yet.
worker_variables = {"diffusers": {"last_start_time": 0}}

# Immutable per-worker settings:
#   image:               docker image the instance is created from
#   cmd:                 start command template; formatted with `priority`
#                        and `ident` when an instance is created
#   disk_size:           value handed to the CLI's --disk option
#   expected_disk_usage: only referenced by disabled checks at the moment
# Earlier variants of the diffusers cmd, kept for reference:
#   python3 /usr/local/src/diffuser-models/src/service.py --download --priority {priority}
#   python3 /usr/local/src/diffuser-models/src/service.py --priority {priority} --ident {ident}
worker_properties = {
    "diffusers": {
        "image": "reimage/diffusers:cloud",
        "cmd": "python3 /usr/local/src/diffuser-models/src/service.py --auto-priority --ident {ident}",
        "disk_size": 5,
        "expected_disk_usage": 0.0,
    },
}


def vast_instance_rank(instance):
    """Return the sort key for an offer; a lower value means a better offer.

    The key starts as the offer price (``min_bid`` if present, otherwise
    ``dph_total``, otherwise 0.50), is lowered by the ``bid_boost`` configured
    for the offer's GPU model, and is raised by 0.10 for EPYC CPUs on ROME
    motherboards.
    """
    # Price: min_bid takes precedence over dph_total, 0.50 is the fallback.
    score = instance.get('min_bid')
    if score is None:
        score = instance.get('dph_total')
    if score is None:
        score = 0.50

    # Prefer better GPUs by subtracting their configured boost.
    gpu_name = instance.get('gpu_name')
    spec = gpu_specs.get(gpu_name) if gpu_name is not None else None
    boost = spec.get('bid_boost') if spec is not None else None
    if boost is not None:
        score = score - boost

    # These EPYC/ROME machines perform poorly, so push them down the list.
    cpu_name = str(instance.get('cpu_name'))
    mobo_name = str(instance.get('mobo_name'))
    if "EPYC" in cpu_name and "ROME" in mobo_name:
        score = score + 0.10

    return score


def vast_instance_cmd(instance):
    """Return the start command recorded on an instance.

    The first entry of ``image_args`` wins when it is a non-empty list whose
    first element is not None; otherwise ``onstart`` is used. An empty string
    is returned when neither is available. The full command is returned;
    callers that only want its leading words must trim it themselves.
    """
    image_args = instance.get("image_args")
    if isinstance(image_args, list) and image_args and image_args[0] is not None:
        return image_args[0]

    onstart = instance.get("onstart")
    return "" if onstart is None else onstart


def vast_search(interruptable=False):
    """Search vast offers and return the best acceptable one.

    Runs ``vastai search offers`` with ``base_params`` and drops offers whose
    GPU is not in ``gpu_specs``, whose machine we already rent, whose host or
    machine is banned, or whose driver is older than ``min_driver_version``.
    The remaining offers are ordered by ``vast_instance_rank``.

    Args:
        interruptable: search interruptable (bid) offers instead of on-demand.

    Returns:
        The best offer (dict), or None when the CLI call fails, nothing
        suitable is found, or the response cannot be processed.
    """
    instances = vast_instances()
    # Machines we already have an instance on; never rent the same one twice.
    machine_ids = set()
    if instances is not None:
        machine_ids = {i["machine_id"] for i in instances if "machine_id" in i}

    args = ['vastai', 'search', 'offers']
    if interruptable:
        args.append('-i')
    args.extend(['--raw', base_params])

    print(f"Searching for suitable {'interruptable' if interruptable else 'on-demand'} instances")
    result = subprocess.run(args, stdout=subprocess.PIPE, check=False)
    if result.returncode != 0 or result.stdout is None:
        print(f"Failed vastai process: {result.returncode} {result.stdout}")
        return None
    try:
        options = []
        resp = json.loads(result.stdout.decode("utf-8"))
        print(f"Evaluating {len(resp)} options...")
        for r in resp:
            if r['gpu_name'] not in gpu_specs:
                print(f"Ignoring GPU {r['gpu_name']}")
                continue
            if r["machine_id"] in machine_ids:
                print(f"Ignoring ID {r['id']} - We already have this machine")
                continue
            if r["host_id"] in banned_host_id:
                print(f"Ignoring Host {r['host_id']} - Banned host")
                continue
            if r["machine_id"] in banned_machine_id:
                print(f"Ignoring Machine {r['machine_id']} - Banned machine")
                continue
            # Offers without a reported driver version are accepted as-is.
            driver_version = r.get("driver_version")
            if driver_version is not None and int(str(driver_version).split(".")[0]) < min_driver_version:
                print(f"Ignoring ID {r['id']} - Driver too old: {driver_version}")
                continue
            options.append(r)

        if not options:
            print("No vast options found!")
            return None

        print(f"Evaluating {len(options)} options.")
        # Best (lowest rank) offer first.
        options.sort(key=vast_instance_rank)

        for x in options:
            print(f"Found| ID: {x['id']} Host: {x['host_id']} Price: ${x['dph_total']:.2f}/hr GPU: {x['gpu_name']} CPU: {x['cpu_name']} Location: {x['geolocation']} Host:{x['host_id']}")

        best = options[0]
        print(f"Best | ID: {best['id']} Host: {best['host_id']} Price: ${best['dph_total']:.2f}/hr GPU: {best['gpu_name']} CPU: {best['cpu_name']} Location: {best['geolocation']}")

        print("")
        return best
    except Exception:
        # Best effort: a malformed response must not take down the caller.
        print("Failed to vast search")
        traceback.print_exc()
    return None


def vast_create_instance(worker_name, vast_id, ident, bid_price=None, priority=6, reset_time=True, host_id=None):
    """Create a vast instance for *worker_name* on offer *vast_id*.

    Args:
        worker_name: key into ``worker_properties`` / ``worker_variables``.
        vast_id: id of the offer to rent.
        ident: identifier substituted into the worker start command.
        bid_price: optional price passed via ``--price``.
        priority: value made available to the start command template.
        reset_time: when true, record the start time used by the lockout.
        host_id: when given, its "starts" counter is incremented on success.

    Returns:
        The ``new_contract`` value from the CLI response, or None when the
        CLI call fails, the response reports no success, or it cannot be
        parsed.
    """
    print(f"Starting {worker_name} instance...")
    props = worker_properties[worker_name]
    # SECURITY NOTE(review): this registry credential is hard-coded in source.
    # It should be rotated and loaded from the environment or a secret store.
    login = '-u reimage -p dckr_pat_-O_CRA9-WcBNtlZV7S942RZPXPI docker.io'
    args = ['vastai', 'create', 'instance', str(vast_id),
            '--image', props['image'],
            '--login', login,
            '--raw',
            '--disk', str(props['disk_size']),
            ]
    if bid_price is not None:
        args.extend(['--price', str(bid_price)])
    # --args must come last
    args.append('--args')
    args.append(props['cmd'].format(priority=priority, ident=ident))

    result = subprocess.run(args, stdout=subprocess.PIPE, check=False)
    if result.returncode != 0 or result.stdout is None:
        print(f"Failed vastai process: {result.returncode} {result.stdout}")
        return None
    # The lockout starts as soon as the CLI call itself succeeded, even if
    # the response below reports a failure.
    if reset_time:
        worker_variables[worker_name]['last_start_time'] = time.time()
    try:
        resp = json.loads(result.stdout.decode("utf-8"))
        if resp.get('success') is not True:
            print(f"Failed to start {worker_name} {resp}")
            return None
    except Exception:
        print(f"Error: vastai create instance: {result.stdout}")
        return None
    if host_id is not None:
        host_stats_increment(host_id, "starts")
    # Log the command line, but never the registry credential.
    print(['<redacted>' if idx > 0 and args[idx - 1] == '--login' else arg
           for idx, arg in enumerate(args)])
    print(f"Started  {worker_name}.")
    # .get(): a successful response without a contract id must not raise.
    return resp.get('new_contract')


def _destroy_and_report(instance, reason, age, status_msg):
    """Print a diagnostic summary for *instance* and ask vast to destroy it."""
    print(f"Destroying instance: {instance.get('id')} ({reason})\n"
          f"  age:           {age/60:.1f}m\n"
          f"  actual_status: {instance.get('actual_status')}\n"
          f"  status_msg:    \"{status_msg}\"\n"
          f"  current_state: {instance.get('cur_state')}\n"
          f"  next_state:    {instance.get('next_state')}\n")
    vast_destroy_instance(instance.get('id'))


def _record_stop(host_id, age):
    """Count a stop for *host_id*; stops before quick_stop_sec count as quick."""
    host_stats_increment(host_id, "stops")
    if age < quick_stop_sec:
        host_stats_increment(host_id, "quick_stops")
    host_stats_print(host_id)


def vast_destroy_stopped_instances(worker_name):
    """Destroy instances of *worker_name* that failed, stalled or stopped.

    Each instance is tested against the following rules in order; the first
    matching rule destroys it and updates the host statistics:

    1. ``status_msg`` contains a known fatal message.
    2. Older than 60 minutes and ``actual_status`` is not "running".
    3. Older than 2 minutes and not loading/running, or no longer intended
       to run.
    4. Older than ``max_startup_sec`` and neither the current nor the next
       state is "running".

    Args:
        worker_name: worker whose instances are inspected.

    Returns:
        Number of instances a destroy was attempted for (0 when the instance
        list is unavailable). The count does not depend on whether the
        destroy call succeeded.
    """
    instances = vast_instances(worker_name)
    if instances is None:
        return 0

    # status_msg fragments that mark an instance as broken beyond recovery
    bad_fragments = ("driver failed", "pull access denied")

    destroyed = 0
    for instance in instances:
        # Without a start date the age is unknown; treat it as 0 so only the
        # status_msg rule can apply instead of failing the whole pass.
        start_date = instance.get("start_date")
        age = round(time.time() - start_date) if start_date is not None else 0

        actual_status = instance.get('actual_status')
        intended_status = instance.get('intended_status')
        cur_state = instance.get('cur_state')
        next_state = instance.get('next_state')
        host_id = instance.get('host_id')
        # str() also covers a missing status_msg (becomes "None").
        status_msg = str(instance.get('status_msg')).replace("\n", "")

        # Rule 1: known bad status_msg. Handled once per instance, and the
        # remaining rules are skipped so it is not destroyed/counted twice.
        if any(bad in status_msg for bad in bad_fragments):
            _destroy_and_report(instance, "bad status_msg", age, status_msg)
            destroyed += 1
            host_stats_increment(host_id, "failures")
            host_stats_increment(host_id, "failure_bad_status")
            host_stats_print(host_id)
            continue

        # Rule 2: after 60 minutes an instance should be done loading and
        # report actual_status "running"; otherwise consider it failed.
        if age > 60*60 and actual_status != "running":
            _destroy_and_report(instance, "actual_status != running", age, status_msg)
            destroyed += 1
            if not status_msg.startswith('success'):
                host_stats_increment(host_id, "failures")
                host_stats_increment(host_id, "failure_to_launch")
                host_stats_print(host_id)
            continue

        # Rule 3: after 2 minutes it must be loading or running, and still
        # be intended to run.
        if age > 2*60 and (actual_status not in ["running", "loading"] or intended_status != "running"):
            _destroy_and_report(instance, "actual_status not loading or running", age, status_msg)
            destroyed += 1
            _record_stop(host_id, age)
            continue

        # Rule 4: past the startup allowance the state must be running.
        if age > max_startup_sec and cur_state != "running" and next_state != "running":
            _destroy_and_report(instance, "state != running", age, status_msg)
            destroyed += 1
            _record_stop(host_id, age)
            continue

        # NOTE: disk_usage based checks (no disk usage after 5 minutes, less
        # than expected_disk_usage after 25 minutes) used to live here and
        # are currently disabled.

    return destroyed


def vast_destroy_instance(instance_id):
    """Destroy the vast instance *instance_id* via the CLI.

    Args:
        instance_id: id of the instance; None is rejected without calling
            the CLI.

    Returns:
        True when the CLI reports success, False otherwise (invalid id,
        failed process, unsuccessful or unparsable response).
    """
    # Check the argument itself (not the builtin `id`, which is never None).
    if instance_id is None:
        print(f"Invalid ID: {instance_id}")
        return False
    args = ['vastai', 'destroy', 'instance', str(instance_id), '--raw']
    result = subprocess.run(args, stdout=subprocess.PIPE, check=False)
    if result.returncode != 0 or result.stdout is None:
        print(f"Failed vastai process: {result.returncode} {result.stdout}")
        return False
    try:
        resp = json.loads(result.stdout.decode("utf-8"))
        if resp.get('success') is not True:
            print(f"Failed to destroy instance {instance_id}")
            return False
    except Exception:
        print(f"Error: vastai destroy instance: {result.stdout}")
        return False
    print(f"Destroyed  instance: {instance_id}")
    return True


def vast_instances(worker=None):
    """List our vast instances, optionally restricted to one worker.

    Args:
        worker: None to return every instance reported by the CLI, or a key
            of ``worker_properties`` to return only the instances whose
            start command begins with the same two words as that worker's
            ``cmd``.

    Returns:
        A list of instance dicts. For a specific worker the list is sorted
        from most to least expensive (``dph_total``). An empty list is
        returned for unmanaged workers and on CLI/parse errors; None is
        returned when the CLI response is not a list.
    """
    # A worker we do not manage has no instances; no need to call the CLI.
    # worker=None means "all instances" and must not be caught by this check.
    if worker is not None and worker not in worker_properties:
        return []

    args = ['vastai', 'show', 'instances', '--raw']
    result = subprocess.run(args, stdout=subprocess.PIPE, check=False)
    if result.returncode != 0 or result.stdout is None:
        print(f"Failed vastai process: {result.returncode} {result.stdout}")
        return []
    try:
        instances = json.loads(result.stdout.decode("utf-8"))
        if not isinstance(instances, list):
            print(f"Invalid response to show instances: {instances}")
            return None
        if worker is None:
            return instances
        # Only the first two words identify the worker; the rest of the
        # command contains per-instance values.
        cmd_prefix = worker_properties[worker]["cmd"].split()[:2]
        worker_instances = [
            instance for instance in instances
            if str(vast_instance_cmd(instance)).split()[:2] == cmd_prefix
        ]
        # Most expensive first; a missing price sorts as 0.
        worker_instances.sort(key=lambda k: k.get('dph_total') or 0, reverse=True)
        return worker_instances
    except Exception:
        traceback.print_exc(file=sys.stdout)
        print(f"Error: vastai show instances: {result.stdout}")
        return []


def current_worker_count(worker_name):
    """Return how many instances of *worker_name* count as up.

    An instance counts when its current and next state are both "running",
    its actual status is "running" or "loading", and its status message
    starts with "success". Returns 0 when the instance list is unavailable.
    """
    instances = vast_instances(worker_name)
    if instances is None:
        return 0

    def is_up(instance):
        status_msg = instance.get("status_msg")
        return (
            instance.get("cur_state") == "running"
            and instance.get("next_state") == "running"
            and instance.get("actual_status") in ["running", "loading"]
            and status_msg is not None
            and status_msg.startswith("success")
        )

    return sum(1 for instance in instances if is_up(instance))


def host_stats_increment(host_id, field):
    """Add one to counter *field* of *host_id* in ``host_stats``.

    The host entry and the counter are created on first use.
    """
    print(f"Incrementing host stat {field} for {host_id}.")
    entry = host_stats.get(host_id)
    if entry is None:
        entry = host_stats[host_id] = {}

    previous = entry.get(field)
    entry[field] = 1 if previous is None else previous + 1


def host_stats_print(host_id=None):
    """Print the counters of one host, or of all hosts.

    All of ``host_stats`` is printed when *host_id* is None or has no entry;
    otherwise only that host's counters are printed. A blank line follows.
    """
    entry = None if host_id is None else host_stats.get(host_id)
    if entry is None:
        header, payload = "Host stats:", host_stats
    else:
        header, payload = f"Host {host_id} stats:", entry
    print(header)
    print(json.dumps(payload, indent=4))
    print("")


def time_to_reduce_workers():
    """Return True while the local time is within the daily cooldown window.

    The window runs from 21:00:00 to 21:01:00 inclusive, as reported by
    ``datetime.datetime.now()``.
    """
    window_start = datetime.time(hour=21, minute=0)
    window_end = datetime.time(hour=21, minute=1)
    current = datetime.datetime.now().time()
    return window_start <= current and current <= window_end


def get_stats(timeout=30):
    """Fetch the queue statistics from the stats API.

    Args:
        timeout: seconds to wait for the server before giving up. Without a
            timeout a stalled connection would block the caller forever.

    Returns:
        The decoded JSON body, or None when the response status is not 200.

    Raises:
        requests.exceptions.RequestException: on connection errors/timeouts.
        ValueError: when the body is not valid JSON.
    """
    resp = requests.get(f"{BASE_URL}stats", verify=VERIFY, timeout=timeout)
    if resp.status_code != 200:
        print(f"Invalid response {resp.status_code}")
        return None
    return resp.json()


def worker_lockout(worker_name):
    """Return True while *worker_name* is still locked out from new starts.

    The lockout lasts ``lockout_seconds`` from the recorded last start time;
    the remaining time is printed while it is active.
    """
    elapsed = time.time() - worker_variables[worker_name]['last_start_time']
    if elapsed >= lockout_seconds:
        return False
    print(f"{worker_name} on lockout - {(lockout_seconds - elapsed):.1f} seconds remaining")
    return True


def add_worker(worker_name, reset_time=True):
    """Find the best offer and start one *worker_name* instance on it.

    Returns what ``vast_create_instance`` returns, or None when the search
    produced no offer.
    """
    interruptable = True
    offer = vast_search(interruptable=interruptable)
    if offer is None:
        return None

    # Bid over_bid above the offer's base price, but never less than min_bid:
    # anything lower gets interrupted quickly.
    base_price = offer['min_bid'] if interruptable else offer['dph_base']
    bid_price = max(base_price + over_bid, min_bid)

    # Use the GPU model's configured priority when there is one.
    priority = 6
    gpu_name = offer.get('gpu_name')
    if gpu_name is not None:
        spec = gpu_specs.get(gpu_name)
        if spec is not None and spec.get('priority') is not None:
            priority = spec.get('priority')

    return vast_create_instance(
        worker_name,
        offer['id'],
        f"machine-{offer['machine_id']}-host-{offer['host_id']}",
        bid_price=bid_price,
        priority=priority,
        reset_time=reset_time,
        host_id=offer['host_id'],
    )


def signal_handler(sig, _):
    """Terminate the process immediately with exit status 0.

    Both handler arguments are ignored.
    """
    exit_status = 0
    # pylint: disable=protected-access
    os._exit(exit_status)


def main():
    """Entry point: add workers once, or run the scaling loop forever.

    With ``--add N WORKER`` (N > 0) it attempts to start N instances of
    WORKER and exits. Otherwise it polls the stats API every 30 seconds and,
    for every monitored worker, destroys stopped instances, replaces them
    when load is short, starts additional workers when needed, and removes
    a worker during the daily cooldown window.
    """
    signal.signal(signal.SIGINT, signal_handler)
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--add",
        nargs=2,
        metavar=('num_workers', 'workers'),
        default=[0, 'none'],
        help="workers to add",
    )

    opt = parser.parse_args()

    num_workers, worker = opt.add
    if num_workers is not None and worker is not None:
        added = False
        for _ in range(int(num_workers)):
            added = True
            add_worker(worker)
        if added:
            # os._exit skips the normal interpreter shutdown, so flush the
            # buffered output ourselves.
            sys.stdout.flush()
            # pylint: disable=protected-access
            os._exit(os.EX_OK)

    first_pass = True
    while True:
        try:
            # Wait between passes. A dedicated flag is used so that nothing
            # inside the pass can accidentally cancel the wait.
            if not first_pass:
                time.sleep(30)
            first_pass = False

            stats = get_stats()
            if stats is None:
                continue

            print(f"----------- {datetime.datetime.now().isoformat(sep=' ', timespec='seconds')} -----------")

            for k, v in stats.items():
                print(f"{k:12} | queue [{v['queueDescription']:25}] | queue: {v['queueLength']:4} | queue_load: {float(v['queueLoad1Min']):5.2f} | avail_load: {float(v['availableLoad1Min']):5.2f} | processed: {v['count']:8} ")

            for w in worker_names:
                queuestats = stats.get(w)
                if queuestats is None:
                    print(f"stats missing {w}")
                    continue
                avail_load = queuestats.get('availableLoad1Min')
                if avail_load is None:
                    print(f"{w} stats missing availableLoad1Min")
                    continue
                queue_load = queuestats.get('queueLoad1Min')
                if queue_load is None:
                    print(f"{w} stats missing queueLoad1Min")
                    continue
                # Same coercion as in the summary line above, so numeric
                # strings from the API compare and format correctly.
                avail_load = float(avail_load)
                queue_load = float(queue_load)

                worker_count = current_worker_count(w)
                print(f"{w:12} | vast_workers: {worker_count}")

                destroyed = vast_destroy_stopped_instances(w)
                if destroyed > 0 and avail_load < avail_target:
                    # If we're low on workers restart these immediately
                    print(f"Attempting to replace {destroyed} {w} workers.")
                    for _ in range(destroyed):
                        add_worker(w, reset_time=False)
                    print(f"Attempted to restart {destroyed} {w} workers.")

                if avail_load < avail_target or queue_load > 1.0:
                    needed = min(15, math.floor((queue_load/2)+(avail_target-avail_load)))
                    if needed > 0:
                        print(f"{w} needs {needed} more workers. avail_load: {avail_load:.2f}  queue_load: {queue_load:.2f} current_workers: {worker_count}")
                        if not worker_lockout(w):
                            for _ in range(needed):
                                add_worker(w)
                            print(f"Attempted to start {needed} {w} workers.")

                if time_to_reduce_workers():
                    print("\n\nExcess workers reduction:")
                    # NOTE(review): the reduction is fixed at one instance per
                    # pass; the load-based value int(avail_load - avail_target)
                    # was computed and then overridden. Confirm the intent.
                    reduction = 1
                    print(f"Reduction: {reduction}")
                    if reduction > 0:
                        # The list is ordered most expensive first and may be
                        # None when the CLI response was invalid.
                        instances = vast_instances(w) or []
                        for inst in instances[:reduction]:
                            vast_destroy_instance(inst.get('id'))
        except Exception:
            # Keep the loop alive; the failure is logged and retried later.
            traceback.print_exc(file=sys.stdout)
        finally:
            # Flush on every pass, including skipped and failed ones.
            sys.stdout.flush()


# Run main() only when the file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
