import argparse
import cv2
import glob
import os
import sys
import signal
import requests
import time
import traceback
import tempfile
import setproctitle
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils.download_util import load_file_from_url

from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
import torch
import reimage

# Global models, populated by main() and used by process().
# Largest upscale factor for which main() creates a GFPGAN face enhancer.
max_scale = 10
# Real-ESRGAN upsampler; set by main().
upsampler = None
# face_enhancers[i] is the GFPGAN enhancer for upscale factor i. Sized from
# max_scale so main() can fill indices 2..max_scale without overrunning the
# list; slots that are never loaded (0 and 1) stay None.
face_enhancers = [None] * (max_scale + 1)

def process(job_id, image, face_enhance=True, scale=4):
    """Upscale one image for a broker job and report the outcome.

    Reads ``image`` from disk, runs it through either the GFPGAN face
    enhancer registered for ``scale`` or the plain Real-ESRGAN ``upsampler``,
    writes the result to a new ``reimage-*.png`` temporary file and reports
    the path via ``reimage.update_job(..., 'complete', ...)``. Failures are
    reported as ``'error'`` job updates.

    Args:
        job_id: Broker job identifier passed through to ``reimage.update_job``.
        image: Path of the input image.
        face_enhance: If true, use ``face_enhancers[scale]``; otherwise use
            ``upsampler`` directly.
        scale: Output upscale factor. With ``face_enhance`` it must index a
            loaded (non-None) entry of ``face_enhancers``.

    On a CUDA out-of-memory error the whole process exits via ``os._exit(1)``
    so that the supervising service restarts it.
    """
    print(f"Processing {image} ...", flush=True)

    reimage.update_job(job_id, 'upscaling')

    t1 = time.time()
    output = None
    try:
        img = cv2.imread(image, cv2.IMREAD_UNCHANGED)
        if img is None:
            # cv2.imread signals a missing/unreadable file by returning None
            # rather than raising; fail here with a clear message instead of
            # deep inside the model.
            raise ValueError(f"Could not read image {image}")

        if face_enhance:
            # Bounds-check explicitly: a negative scale would otherwise wrap
            # around and silently pick the enhancer for a different scale.
            enhancer = face_enhancers[scale] if 0 <= scale < len(face_enhancers) else None
            if enhancer is None:
                reimage.update_job(job_id, 'error', detail="Invalid scale")
                return
            _, _, output = enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
        else:
            output, _ = upsampler.enhance(img, outscale=scale)
    except Exception as exc:
        traceback.print_exc()
        # Newer torch raises torch.cuda.OutOfMemoryError; older releases raise a
        # RuntimeError whose message mentions "CUDA out of memory".
        out_of_memory = ("OutOfMemoryError" in str(type(exc))
                         or "CUDA out of memory" in str(exc))
        try:
            reimage.update_job(job_id, 'error', detail=f"Exception: {type(exc)} {exc}")
        finally:
            # If CUDA ran out of memory, exit so the service will be restarted,
            # even if reporting the error above failed.
            if out_of_memory:
                os._exit(1)
    else:
        fd, path = tempfile.mkstemp(suffix=".png", prefix="reimage-")
        # Only the path is needed (cv2 opens the file itself); close our
        # descriptor now so it cannot leak if writing or reporting raises.
        os.close(fd)
        if cv2.imwrite(path, output):
            reimage.update_job(job_id, 'complete', results={"image": path})
        else:
            # cv2.imwrite reports failure by returning False rather than
            # raising; never hand back an empty file as a completed result.
            try:
                os.remove(path)
            except OSError:
                pass  # best-effort cleanup of the empty temp file
            reimage.update_job(job_id, 'error', detail=f"Failed to write output image {path}")
            output = None

    t2 = time.time()
    if output is None:
        print("Processed  failed ({0:0.2f}s)".format(t2-t1), flush=True)
    else:
        print("Processed  {0} ({1:0.2f}s)".format(image, t2-t1), flush=True)


def signal_handler(sig, _):
    """Shut the worker down in response to a signal (main() registers it for SIGINT and SIGTERM).

    Puts the active job back on the queue and posts a "quitting" notification
    naming the process title and the signal. Whether or not those calls
    succeed, the process then terminates immediately with status 1 via
    ``os._exit``, which skips normal interpreter cleanup.

    Args:
        sig: Number of the delivered signal.
        _: Interrupted stack frame (unused).
    """
    try:
        reimage.requeue_active_job()
        worker_title = setproctitle.getproctitle()
        signal_label = signal.Signals(sig).name
        reimage.notify(f"{worker_title} worker quitting! [{signal_label}]")
    finally:
        os._exit(1)


def main():
    """Run the Real-ESRGAN upscaling worker.

    Parses command-line options, points the ``reimage`` client at the broker,
    builds the Real-ESRGAN upsampler for ``--model_name`` (using
    ``weights/<model_name>.pth`` if present, otherwise fetching the release
    weights into ``/mldata/real-esrgan/``), creates one GFPGAN face enhancer
    per upscale factor 2..``max_scale``, then loops forever pulling
    ``real-esrgan`` jobs and handing them to ``process``. SIGINT and SIGTERM
    are routed to ``signal_handler``.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--host",
        type=str,
        default="api.reimage.io",
        help="The broker host",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="",
        help="Which device to use",
    )
    parser.add_argument(
        "--priority",
        type=int,
        default=1,
        help="The priority for requesting jobs (0 is highest)",
    )
    parser.add_argument(
        "--timeout",
        type=int,
        default=10,
        help="The timeout for requesting jobs",
    )
    parser.add_argument(
        "--insecure",
        action='store_true',
        help="If specified HTTPS cert is not verified (for development use only)",
    )
    parser.add_argument(
        '-n',
        '--model_name',
        type=str,
        default='RealESRGAN_x4plus',
        help=('Model names: RealESRGAN_x4plus | RealESRNet_x4plus | RealESRGAN_x4plus_anime_6B | RealESRGAN_x2plus | '
              'realesr-animevideov3 | realesr-general-x4v3')
    )
    # Read below when blending the realesr-general-x4v3 weights; it was
    # previously never defined, so that model crashed with AttributeError.
    parser.add_argument(
        '-dn',
        '--denoise_strength',
        type=float,
        default=0.5,
        help=('Denoise strength. 0 for weak denoise (keep noise), 1 for strong denoise ability. '
              'Only used for the realesr-general-x4v3 model'))
    parser.add_argument(
        '--fp32', action='store_true', help='Use fp32 precision during inference. Default: fp16 (half precision).')
    parser.add_argument('--tile_pad', type=int, default=10, help='Tile padding')
    parser.add_argument('--pre_pad', type=int, default=0, help='Pre padding size at each border')
    parser.add_argument('-t', '--tile', type=int, default=1024, help='Tile size, 0 for no tile during testing')

    opt = parser.parse_args()

    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    reimage.BASE_URL = f"https://{opt.host}"
    reimage.VERIFY = not opt.insecure
    if opt.insecure:
        requests.packages.urllib3.disable_warnings()

    # determine models according to model names (a trailing extension such as
    # ".pth" is tolerated and stripped)
    opt.model_name = opt.model_name.split('.')[0]
    if opt.model_name == 'RealESRGAN_x4plus':  # x4 RRDBNet model
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
        netscale = 4
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth']
    elif opt.model_name == 'RealESRNet_x4plus':  # x4 RRDBNet model
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
        netscale = 4
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth']
    elif opt.model_name == 'RealESRGAN_x4plus_anime_6B':  # x4 RRDBNet model with 6 blocks
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
        netscale = 4
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth']
    elif opt.model_name == 'RealESRGAN_x2plus':  # x2 RRDBNet model
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
        netscale = 2
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth']
    elif opt.model_name == 'realesr-animevideov3':  # x4 VGG-style model (XS size)
        model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
        netscale = 4
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth']
    elif opt.model_name == 'realesr-general-x4v3':  # x4 VGG-style model (S size)
        model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
        netscale = 4
        file_url = [
            'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth',
            'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth'
        ]
    else:
        # Without this, model/netscale/file_url would be unbound below and the
        # worker would die with an UnboundLocalError.
        parser.error(f"Unknown model name: {opt.model_name}")

    # determine model paths
    model_path = os.path.join('weights', opt.model_name + '.pth')
    if not os.path.isfile(model_path):
        model_dir = "/mldata/real-esrgan/"
        for url in file_url:
            # model_path will be updated; it ends up pointing at the last URL's
            # file, which is the model's own weights (the wdn file comes first).
            model_path = load_file_from_url(
                url=url, model_dir=model_dir, progress=True, file_name=None)

    # use dni to control the denoise strength: blend the general model with
    # its "wdn" counterpart in proportion denoise_strength : 1 - denoise_strength
    dni_weight = None
    if opt.model_name == 'realesr-general-x4v3' and opt.denoise_strength != 1:
        wdn_model_path = model_path.replace('realesr-general-x4v3', 'realesr-general-wdn-x4v3')
        model_path = [model_path, wdn_model_path]
        dni_weight = [opt.denoise_strength, 1 - opt.denoise_strength]

    # An empty --device leaves the choice to RealESRGANer (gpu_id=None).
    device = None
    if opt.device != "":
        device = int(opt.device)

    # restorer
    global upsampler
    upsampler = RealESRGANer(
        scale=netscale,
        model_path=model_path,
        dni_weight=dni_weight,
        model=model,
        tile=opt.tile,
        tile_pad=opt.tile_pad,
        pre_pad=opt.pre_pad,
        half=not opt.fp32,
        gpu_id=device)

    # One GFPGAN enhancer per supported output scale; each uses the
    # Real-ESRGAN upsampler for the non-face background. Only item
    # assignment is done, so no `global` declaration is needed.
    from gfpgan import GFPGANer
    for i in range(2, max_scale+1):
        face_enhancers[i] = GFPGANer(
            model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
            upscale=i,
            arch='clean',
            channel_multiplier=2,
            bg_upsampler=upsampler)

    reimage.notify(f"{setproctitle.getproctitle()} worker online! (priority: {opt.priority})")
    print("Waiting for jobs...", flush=True)
    while True:
        with torch.no_grad():
            reimage.dojob("real-esrgan", process, opt.priority, opt.timeout)
        # Return cached GPU memory between jobs.
        torch.cuda.empty_cache()


if __name__ == "__main__":
    # Name the process before starting; main() and signal_handler() read this
    # title back via setproctitle.getproctitle() for their notifications.
    process_title = 'real-esrgan'
    setproctitle.setproctitle(process_title)
    main()
