"""A ComfyApiWrapper subclass adding synchronous queue-and-wait and /free helpers."""
import asyncio
import json
import uuid
import requests
import websockets
from requests.compat import urljoin
from comfy_api_simplified import ComfyApiWrapper
from comfy_api_simplified import ComfyWorkflowWrapper

class CustomComfyApiWrapper(ComfyApiWrapper):
    """ComfyApiWrapper subclass with queue-and-wait and memory-free helpers."""

    def __init__(self, path: str):
        super().__init__(path)
        # Not read anywhere in this class; kept for external callers.
        self.text = ""

    def queue_and_wait_results(
            self,
            prompt: ComfyWorkflowWrapper,
            loop: "asyncio.AbstractEventLoop | None" = None,
    ):
        """Queue ``prompt``, block until it finishes and return its history entry.

        Args:
            prompt: The workflow to execute.
            loop: Optional event loop used to drive the wait. When omitted, the
                wait runs on a fresh loop created for this call by ``asyncio.run``.

        Returns:
            The entry ``self.get_history(prompt_id)[prompt_id]`` for the prompt.

        Raises:
            Exception: Propagated from ``queue_prompt_and_wait`` on an
                execution error reported for this prompt.
            KeyError: If the history response has no entry for the prompt.
        """
        # The default used to be ``asyncio.get_event_loop()``, evaluated once at
        # import time: every call was tied to that one loop, Python 3.12+ warns
        # when no loop is set, and 3.14 raises instead. Resolve it per call.
        waiter = self.queue_prompt_and_wait(prompt)
        if loop is None:
            prompt_id = asyncio.run(waiter)
        else:
            prompt_id = loop.run_until_complete(waiter)

        history = self.get_history(prompt_id)
        return history[prompt_id]

    async def queue_prompt_and_wait(self, prompt: dict) -> str:
        """Queue a prompt and wait on the websocket until it is done.

        Args:
            prompt (dict): The prompt to be executed.

        Returns:
            str: The prompt ID.

        Raises:
            Exception: If an ``execution_error`` message for this prompt arrives.
        """
        client_id = str(uuid.uuid4())
        resp = self.queue_prompt(prompt, client_id)
        prompt_id = resp["prompt_id"]
        async with websockets.connect(uri=self.ws_url.format(client_id)) as websocket:
            while True:
                out = await websocket.recv()
                # Only JSON text frames are inspected; binary frames are skipped.
                if not isinstance(out, str):
                    continue
                message = json.loads(out)
                msg_type = message.get("type")

                # Any other type (e.g. the frequent "crystools.monitor" frames)
                # matches no branch below and is simply ignored.
                if msg_type == "execution_error":
                    data = message["data"]
                    if data.get("prompt_id") == prompt_id:
                        raise Exception(
                            data.get("exception_message") or "Execution error occurred."
                        )
                elif msg_type == "status":
                    exec_info = message["data"]["status"]["exec_info"]
                    # An empty server queue is taken to mean our prompt is done
                    # (possibly before this socket connected).
                    # NOTE(review): assumes queue_remaining counts the running
                    # prompt too — confirm against the server.
                    if exec_info["queue_remaining"] == 0:
                        return prompt_id
                elif msg_type == "executing":
                    data = message["data"]
                    # node == None for our prompt marks the end of its execution.
                    if data.get("node") is None and data.get("prompt_id") == prompt_id:
                        return prompt_id

    def free(self):
        """POST to the server's ``/free`` endpoint to unload models and free memory.

        Whether the server actually releases memory is not verified here.

        Returns:
            int: The HTTP status code (200) on success.

        Raises:
            Exception: If the server responds with any status other than 200.
        """
        payload = {
            "unload_models": True,
            "free_memory": True,
        }
        # NOTE: the absolute "/free" replaces any path component of self.url.
        url = urljoin(self.url, "/free")
        # json= serializes the payload and sets Content-Type: application/json.
        resp = requests.post(url, auth=self.auth, json=payload, timeout=5)
        if resp.status_code != 200:
            raise Exception(
                f"Request failed with status code {resp.status_code}: {resp.reason}"
            )
        print("Memory freed")
        return resp.status_code
