import os
import tempfile

import cv2
import numpy as np
from PIL import Image


class OperationCheckpoint:
    """Record intermediate image results of an operation under named keys.

    Images are written to temporary PNG files and their paths are collected
    in ``self.state`` (a ``key -> image path`` dict). Unnamed and "extra"
    checkpoints are numbered with a running counter, ``self.save_idx``.
    """

    def __init__(self):
        self.save_idx = 0  # running counter used to number "extra-NN" keys
        self.state = {}  # checkpoint key -> path of the saved image file

    def save_image(
        self,
        img_arr,
        tmp_prefix="paint-",
        ckpt_name=None,
        is_extra=True,
        update_idx=True,
        convert_from_lab=False,
    ):
        """Write ``img_arr`` to a new temporary PNG file and record its path.

        Args:
            img_arr: image array accepted by ``PIL.Image.fromarray``; it is
                converted to RGB mode before saving. When ``convert_from_lab``
                is set it is first passed through ``cv2.cvtColor`` with
                ``COLOR_LAB2RGB`` (presumably a float32 or uint8 Lab image in
                OpenCV's conventions — confirm against callers).
            tmp_prefix: filename prefix of the temporary file.
            ckpt_name: checkpoint name, forwarded to ``save_result`` as ``name``.
            is_extra: forwarded to ``save_result``.
            update_idx: forwarded to ``save_result``.
            convert_from_lab: convert the array from Lab to RGB before saving.

        Returns:
            Path of the written PNG file. The file is not deleted here; the
            caller owns it.
        """
        if convert_from_lab:
            img_arr = self._rgb_to_uint8(cv2.cvtColor(img_arr, cv2.COLOR_LAB2RGB))
        img = Image.fromarray(img_arr).convert("RGB")
        image_path = self._make_temp_png_path(tmp_prefix)
        img.save(image_path)
        self.save_result(
            image_path, name=ckpt_name, is_extra=is_extra, update_idx=update_idx
        )
        return image_path

    @staticmethod
    def _rgb_to_uint8(rgb):
        """Return ``rgb`` as uint8, scaling float data from [0, 1] to [0, 255]."""
        if np.issubdtype(rgb.dtype, np.floating):
            # Clip before casting: out-of-range float -> uint8 casts wrap or are
            # undefined in NumPy. ``astype`` truncates toward zero.
            return np.clip(255 * rgb, 0, 255).astype(np.uint8)
        # Integer output (cvtColor on uint8 Lab) is already in 0..255;
        # multiplying it by 255 would overflow and wrap around.
        return rgb

    @staticmethod
    def _make_temp_png_path(prefix):
        """Create an empty temporary ``.png`` file and return its path."""
        fd, path = tempfile.mkstemp(suffix=".png", prefix=prefix)
        # mkstemp hands back an open OS-level descriptor; close it so repeated
        # saves do not leak file descriptors (the image is written by path).
        os.close(fd)
        return path

    def save_result(self, image_path, name=None, is_extra=True, update_idx=True):
        """Record ``image_path`` in ``state`` under a generated key.

        Key scheme (``NN`` is ``save_idx`` zero-padded to two digits):
          * falsy ``name``: ``extra-NN``
          * ``name`` with ``is_extra``: ``extra-NN-<name>``
          * ``name`` without ``is_extra``: ``<name>`` verbatim

        An existing entry with the same key is overwritten. When
        ``update_idx`` is true the counter advances afterwards, even if the
        key did not include it.
        """
        idx = self.save_idx
        if not name:
            key = f"extra-{idx:02}"
        elif is_extra:
            key = f"extra-{idx:02}-{name}"
        else:
            key = name
        self.state[key] = image_path
        if update_idx:
            self.save_idx += 1

    def merge(self, results, prefix=None):
        """Copy entries of the ``results`` mapping into ``state``.

        When ``prefix`` is non-empty, only keys starting with it are copied;
        otherwise every entry is copied. Existing keys are overwritten.
        """
        for key, value in results.items():
            if not prefix or key.startswith(prefix):
                self.state[key] = value

    def get_state(self):
        """Return the live ``key -> image path`` dict (not a copy)."""
        return self.state
