"""upscale functions"""
import tempfile
import cv2

import reimage
import subproc
import subproc_util
import pipelines
from PIL import Image

from torchvision.utils import save_image
from basicsr import (
    FileClient,
    imfrombytes,
    img2tensor,
)
import torch
from typing import Dict
from reimage import Job
from util import tmpname

def _round_to_multiple(value: float, multiple: int = 64) -> int:
    """Round ``value`` to the nearest multiple of ``multiple``, never below ``multiple``."""
    # round() is Python's round-half-to-even, same as the original computation;
    # the max() guard keeps a tiny side from collapsing to 0.
    return max(multiple, int(round(value / multiple)) * multiple)


def resize_image_for_upscale(input_image: Image.Image, resolution: int) -> Image.Image:
    """Resize an image so its longer side is about ``resolution`` pixels.

    The aspect ratio is preserved, then each side is rounded to the nearest
    multiple of 64 so the result is a Stable-Diffusion-compatible size.
    Each side is clamped to at least 64: without that, a very thin image
    (e.g. 1000x50 resized to 512) would round one side down to 0, which
    PIL's ``resize`` rejects.

    Args:
        input_image (Image.Image): Source image; it is converted to RGB first.
        resolution (int): Target length of the longer side, in pixels.

    Returns:
        Image.Image: The resized RGB image (Lanczos resampling).
    """
    rgb_image = input_image.convert("RGB")
    width, height = rgb_image.size
    scale = float(resolution) / max(height, width)
    new_width = _round_to_multiple(width * scale)
    new_height = _round_to_multiple(height * scale)
    return rgb_image.resize((new_width, new_height), resample=Image.LANCZOS)


def accurate_upscale(image: str, scale: int) -> Dict[str, str]:
    """Run the accurate (HAT) upscaler on one image file and save the results.

    The image is loaded with basicsr helpers and run through the shared
    ``pipelines.pipe`` object using its patch-wise (tiled) inference path.
    Each image in the output batch is then written to a temporary PNG.

    Args:
        image (str): Path of the input image.
        scale (int): Upscaling factor, stored on the pipeline before inference.

    Returns:
        Dict[str, str]: Maps ``"image-{i}"`` to the temp-file path of the
        ``i``-th upscaled image in the output batch.
    """
    torch.cuda.empty_cache()
    pipe = pipelines.pipe

    raw_bytes = FileClient().get(image)
    # float32=True makes basicsr return a float array normalized to [0, 1];
    # img2tensor converts BGR -> RGB, and unsqueeze adds a batch dimension.
    img_array = imfrombytes(raw_bytes, float32=True)
    lq_tensor = img2tensor(img_array, bgr2rgb=True).unsqueeze(0)

    # Patch-wise inference. no_grad: this is pure inference, so autograd
    # should not keep activations alive for a backward pass.
    with torch.no_grad():
        pipe.img = lq_tensor
        pipe.lq = lq_tensor
        pipe.scale = scale
        pipe.pre_process()
        pipe.tile_process()
        pipe.post_process()
        output = pipe.output  # batch tensor, presumably (B, C, H, W)

    results = {}
    for i, upscaled in enumerate(output):
        image_path = tmpname(suffix=".png", prefix="upscale-acc-")
        # normalize=False: the model is fed [0, 1] input and presumably emits
        # the same range; save_image then simply clamps to [0, 1].
        # normalize=True would min-max stretch every image to its own
        # min/max, altering brightness and contrast of the upscaled result.
        save_image(upscaled, image_path, normalize=False)
        results[f"image-{i}"] = image_path

    return results


def upscale_hat(job: Job, image: str, scale: int):
    """Run the HAT (accurate) upscaler for ``image`` in a subprocess.

    Steps:

    1. Validate the arguments.
    2. Clamp the scale into 2-4, which selects pipeline ``hat_superres_2X``,
       ``hat_superres_3X`` or ``hat_superres_4X``.
    3. Reject overly large images.
    4. Launch that pipeline's subprocess and mark the job as "rendering".
    5. Execute ``accurate_upscale`` through ``subproc.subproc_call``.

    Args:
        job (Job): Job object for the current job.
        image (str): Path to the input image.
        scale (int): Requested upscaling factor. Values below 2 are raised
            to 2 and values above 4 are lowered to 4 (not rejected).

    Raises:
        Exception: If ``image`` is None, or if the image has more than
            2048 * 2048 pixels in total.
        ValueError: If ``scale`` is None or not an int.

    Returns:
        Whatever ``subproc.subproc_call`` returns for ``accurate_upscale``.
        Presumably that is its ``{"image-{i}": path}`` dict; confirm against
        the ``subproc`` module.
    """
    if image is None:
        raise Exception("image is required")
    if scale is None:
        raise ValueError("scale is required")
    if not isinstance(scale, int):
        raise ValueError(f"invalid scale: {scale}")
    scale = max(min(scale, 4), 2)  # scale must be 2-4

    # Image.open parses only the header, so .size is available without
    # decoding every pixel. The with-block closes the file handle, which
    # the previous Image.open(...).convert(...) left open.
    with Image.open(image) as image_obj:
        width, height = image_obj.size
    if width * height > 2048 * 2048:
        raise Exception(f"image too large: {width}x{height}")

    pipeline = f"hat_superres_{scale}X"
    args = {"image": image, "scale": scale}
    subproc_util.subproc_launch(pipeline, job, fn=pipelines.pipeline_get_loader(pipeline))
    reimage.update_job(job, "rendering")
    return subproc.subproc_call(pipeline, accurate_upscale, args=args)
