"""grounded sam utils"""

from collections import defaultdict, deque
from contextlib import contextmanager
from enum import Enum

import cv2
import numpy as np
import torch
import torchvision.transforms as transforms
import yaml

from PIL import Image
from skimage import morphology
from torchvision.ops import nms

import models


class Segment:
    r"""
    Segment class for segmentation pipeline.

    A segment owns a positive and a negative mask (plus the optional lists of
    submasks they were built from) and can be combined with other segments via
    `join`. Masks are either supplied up front or produced by `generate_mask`.

    Args:
        dino_model (`str`, *optional*, defaults to `base`):
            Forwarded to `sam_masks` as `dino_model` when detecting boxes from `classes`.
        segment_id (`int`, *optional*, defaults to -1):
            The segment's unique id
        name (`str`, *optional*, defaults to `""`):
            The segment's name
        classes (`List[str]`, *optional*):
            Optional input classes for DINO. Used to generate bounding boxes for SAM.
        boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes_per_image, 4)`, *optional*):
            Optional input boxes for SAM.
        points (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_points_per_image, 2)`, *optional*):
            Optional input points for SAM.
        box_nms (`bool`, *optional*, defaults to False):
            Whether to run NMS on detected boxes.
        box_nms_iou_threshold (`float`, *optional*, defaults to 0.5):
            NMS discards all overlapping boxes with IoU > box_nms_iou_threshold
        threshold (`float`, *optional*, defaults to 0.2):
            Score threshold to keep DINO object detection predictions.
        text_threshold (`float`, *optional*, defaults to 0.3):
            Score threshold to keep DINO text detection predictions.
        mask_threshold (`float`, *optional*, defaults to 0):
            The threshold to use for binarizing SAM masks.
        min_obj_size (`int`, *optional*, defaults to 0):
            The smallest allowable object size in the output mask.
        padding (`int`, *optional*, defaults to 0):
            The padding size for the mask image. Falsy values (e.g. `None`) are stored as 0.
        min_pos_mask_iou (`float`, *optional*, defaults to 0.0):
            The minimum allowable mask iou between this segment's pos_filter_mask and another segment's pos submask.
        max_pos_mask_iou (`float`, *optional*, defaults to 1.0):
            The maximum allowable mask iou between this segment's pos_filter_mask and another segment's pos submask.
        min_neg_mask_iou (`float`, *optional*, defaults to 0.0):
            The minimum allowable mask iou between this segment's neg_filter_mask and another segment's neg submask.
        max_neg_mask_iou (`float`, *optional*, defaults to 1.0):
            The maximum allowable mask iou between this segment's neg_filter_mask and another segment's neg submask.
        mask_type (`str`, *optional*, defaults to `pos`):
            The segment's mask type (pos|neg).
        join_type (`str`, *optional*, defaults to `left`):
            The segment's join type (left|right)
        filter_mask_type (`str`, *optional*, defaults to `pos`):
            The segment's filter mask type (pos|neg|both) to calculate iou with another segment's submasks.
        dependencies (`List[Segment]`, *optional*):
            The segment's dependencies or parent segments.
        max_num_masks (`int`, *optional*, defaults to `None`):
            The maximum number of submasks to keep from SAM output. If None, all submasks are used.
        sort (`str`, *optional*, defaults to `asc`):
            The sort order (asc|desc) of the segment's submasks.
        sort_by (`str`, *optional*, defaults to `size`):
            The dimension (size|score) to sort the segment's submasks by.
        outpaint (`bool`, *optional*, defaults to False):
            Whether to invert the segment's mask.
        pos_mask (`torch.Tensor`, *optional*):
            Precomputed combined positive mask.
        neg_mask (`torch.Tensor`, *optional*):
            Precomputed combined negative mask.
        pos_masks (`List[torch.Tensor]`, *optional*):
            Precomputed positive submasks.
        neg_masks (`List[torch.Tensor]`, *optional*):
            Precomputed negative submasks.
        masks (*optional*):
            Accepted for backward compatibility; the constructor does not store it.
        version (`int`, *optional*, defaults to 3):
            The segment's version number, forwarded to `sam_masks`.
    """

    def __init__(
        self,
        dino_model="base",
        segment_id=-1,
        name="",
        classes=None,
        boxes=None,
        points=None,
        box_nms=False,
        box_nms_iou_threshold=0.5,
        threshold=0.2,
        text_threshold=0.3,
        mask_threshold=0,
        min_obj_size=0,
        padding=0,
        min_pos_mask_iou=0.0,
        max_pos_mask_iou=1.0,
        min_neg_mask_iou=0.0,
        max_neg_mask_iou=1.0,
        mask_type="pos",
        join_type="left",
        filter_mask_type="pos",
        dependencies=None,
        max_num_masks=None,
        sort="asc",
        sort_by="size",
        outpaint=False,
        pos_mask=None,
        neg_mask=None,
        pos_masks=None,
        neg_masks=None,
        masks=None,
        version=3,
    ):
        self.dino_model = dino_model
        self.segment_id = segment_id
        self.name = name
        self.boxes = boxes
        self.classes = classes
        self.points = points
        self.box_nms = box_nms
        self.box_nms_iou_threshold = box_nms_iou_threshold
        self.threshold = threshold
        self.text_threshold = text_threshold
        self.mask_threshold = mask_threshold
        # Normalize None (e.g. an empty yaml value) to 0.
        self.padding = padding if padding else 0
        self.min_pos_mask_iou = min_pos_mask_iou
        self.max_pos_mask_iou = max_pos_mask_iou
        self.min_neg_mask_iou = min_neg_mask_iou
        self.max_neg_mask_iou = max_neg_mask_iou
        self.mask_type = mask_type
        self.join_type = join_type
        self.filter_mask_type = filter_mask_type
        self.dependencies = dependencies if dependencies else []
        self.max_num_masks = max_num_masks
        self.min_obj_size = min_obj_size
        self.sort = sort
        self.sort_by = sort_by
        self.outpaint = outpaint
        self.pos_mask = pos_mask
        self.neg_mask = neg_mask
        self.pos_masks = pos_masks
        self.neg_masks = neg_masks
        self.version = version

    def join(self, other):
        r"""
        Join one segment with the other.

        Each submask of the other segment must have min_iou <= iou <= max_iou.
        If join_type='left' and filter_mask_type='pos', the other segment's submasks are filtered using current segment's pos_mask.
        If join_type='left' and filter_mask_type='neg', the other segment's submasks are filtered using current segment's neg_mask.
        If join_type='left' and filter_mask_type='both', the other segment's pos submasks are filtered using current segment's pos_mask,
        and the other segment's neg submasks are filtered using current segment's neg_mask.
        For join_type='right', the same logic is applied but reversing the roles of current and other segment.

        The join type is always read from the current segment (`self.join_type`);
        the filter mask type and iou bounds are read from whichever segment acts
        as the base (self for 'left', other for 'right'). Neither input segment
        is modified: the masks are cloned and the submask lists are copied.

        Args:
            other (`Segment`):
                The segment to join with.

        Returns:
            Segment: a new segment holding the merged masks, with
            `dependencies=[self, other]`. Its mask type is `neg` only when both
            inputs are `neg`, otherwise `pos`.

        Raises:
            ValueError: if `self.join_type` is unsupported or a segment has no mask.
        """
        join_type = self.join_type
        if join_type == "left":
            base, incoming = self, other
        elif join_type == "right":
            base, incoming = other, self
        else:
            raise ValueError(f"unsupported join type: {join_type}")
        if not (self.has_mask() and other.has_mask()):
            raise ValueError("at least one segment does not have a mask")

        # Work on clones/copies so the base segment's own state is never mutated.
        pos_mask = base.pos_mask.detach().clone()
        neg_mask = base.neg_mask.detach().clone()
        pos_masks = list(base.pos_masks) if base.has_pos_submasks() else [base.pos_mask]
        neg_masks = list(base.neg_masks) if base.has_neg_submasks() else [base.neg_mask]
        incoming_pos_masks = incoming.pos_masks if incoming.has_pos_submasks() else [incoming.pos_mask]
        incoming_neg_masks = incoming.neg_masks if incoming.has_neg_submasks() else [incoming.neg_mask]

        filter_mask_type = base.filter_mask_type
        if filter_mask_type == "pos":
            pos_filter_mask = neg_filter_mask = pos_mask
        elif filter_mask_type == "neg":
            pos_filter_mask = neg_filter_mask = neg_mask
        else:
            pos_filter_mask, neg_filter_mask = pos_mask, neg_mask

        kept, keep_mask = self._filter_masks(
            pos_filter_mask, incoming_pos_masks, min_iou=base.min_pos_mask_iou, max_iou=base.max_pos_mask_iou
        )
        pos_mask[keep_mask > 0] = 1
        pos_masks.extend(kept)
        # NOTE: with filter_mask_type == "pos" the neg filter is the same tensor
        # as pos_mask, so it already contains the pos submasks merged just above.
        kept, keep_mask = self._filter_masks(
            neg_filter_mask, incoming_neg_masks, min_iou=base.min_neg_mask_iou, max_iou=base.max_neg_mask_iou
        )
        neg_mask[keep_mask > 0] = 1
        neg_masks.extend(kept)

        both_negative = self.mask_type == "neg" and other.mask_type == "neg"
        return Segment(
            dino_model=self.dino_model,
            mask_type="neg" if both_negative else "pos",
            pos_mask=pos_mask,
            neg_mask=neg_mask,
            pos_masks=pos_masks,
            neg_masks=neg_masks,
            dependencies=[self, other],
        )

    def _filter_masks(self, mask, masks, min_iou=0, max_iou=1):
        r"""
        Keep the submasks whose iou with `mask` lies in `[min_iou, max_iou]`.

        Args:
            mask (`torch.Tensor`):
                The reference (filter) mask. Converted with `.numpy()`, so it is
                expected to live on the CPU.
            masks (`List[torch.Tensor]`):
                Candidate submasks.
            min_iou (`float`, *optional*, defaults to 0):
                Inclusive lower iou bound.
            max_iou (`float`, *optional*, defaults to 1):
                Inclusive upper iou bound.

        Returns:
            Tuple of the list of kept submasks and a tensor shaped like `mask`
            that is 1 wherever any kept submask is positive.
        """
        keep = []
        keep_mask = torch.zeros_like(mask)
        for other_mask in masks:
            iou = get_mask_iou(mask.numpy() > 0, other_mask.numpy() > 0)
            if min_iou <= iou <= max_iou:
                keep_mask[other_mask > 0] = 1
                keep.append(other_mask)
        return keep, keep_mask

    def generate_mask(self, image):
        r"""
        Generate mask with the input image.

        Does nothing when the segment already has a mask. Otherwise runs
        `sam_masks` with this segment's settings, then generates the masks of
        every dependency, joins the dependencies in order, joins the result with
        this segment, and finally builds `self.mask_image`.

        Args:
            image (`PIL.Image.Image`):
                The image to segment.
        """

        if self.has_mask():
            return
        combined_mask, masks = sam_masks(
            image,
            dino_model=self.dino_model,
            input_boxes=self.boxes,
            input_classes=self.classes,
            input_points=self.points,
            box_nms=self.box_nms,
            box_nms_iou_threshold=self.box_nms_iou_threshold,
            threshold=self.threshold,
            text_threshold=self.text_threshold,
            mask_threshold=self.mask_threshold,
            outpaint=self.outpaint,
            max_num_masks=self.max_num_masks,
            min_obj_size=self.min_obj_size,
            sort=self.sort,
            sort_by=self.sort_by,
            version=self.version,
        )
        width, height = image.size
        if combined_mask is None:
            # Nothing detected: both masks are empty and no submasks are recorded.
            self.pos_mask = torch.zeros((height, width))
            self.neg_mask = torch.zeros((height, width))
        elif self.mask_type == "pos":
            self.pos_mask = combined_mask
            self.neg_mask = torch.zeros((height, width))
            self.pos_masks = masks
        else:
            self.pos_mask = torch.zeros((height, width))
            self.neg_mask = combined_mask
            self.neg_masks = masks
        # Fold the dependencies together from first to last.
        prev_segment = None
        for i, segment in enumerate(self.get_dependencies()):
            segment.generate_mask(image)
            if i == 0:
                prev_segment = segment
            else:
                prev_segment = prev_segment.join(segment)
        if prev_segment is not None:
            new_segment = prev_segment.join(self)
            self.pos_mask = new_segment.pos_mask
            self.pos_masks = new_segment.pos_masks
            self.neg_mask = new_segment.neg_mask
            self.neg_masks = new_segment.neg_masks
        self._get_mask_image()

    def _get_mask_image(self):
        r"""
        Build, store and return the segment's mask image.

        For a `pos` segment the image is pos_mask with neg_mask removed; for any
        other mask type it is neg_mask with pos_mask removed. The result is
        cleaned with `clean_mask` and, when `padding > 0`, dilated with a
        `padding x padding` kernel for two iterations.

        Returns:
            `PIL.Image.Image` in mode `L`, also stored as `self.mask_image`.

        Raises:
            ValueError: if the segment has no mask yet.
        """
        if not self.has_mask():
            raise ValueError("no mask available: must call generate_mask()")
        mask = torch.zeros_like(self.pos_mask)
        if self.mask_type == "pos":
            mask[self.pos_mask > 0] = 1
            mask[self.neg_mask > 0] = 0
        else:
            mask[self.neg_mask > 0] = 1
            mask[self.pos_mask > 0] = 0
        mask = transforms.functional.to_pil_image(mask)
        mask = np.array(mask)
        mask, _ = clean_mask(mask, min_obj_size=self.min_obj_size)
        padding = self.padding
        if padding > 0:
            kernel = np.ones((padding, padding), np.uint8)
            mask = cv2.dilate(mask, kernel, iterations=2)
        self.mask_image = Image.fromarray(mask).convert("L")
        return self.mask_image

    def has_mask(self):
        """Return True when a positive or a negative mask is set."""
        return self.pos_mask is not None or self.neg_mask is not None

    def has_pos_submasks(self):
        """Return True when a list of positive submasks is set (it may be empty)."""
        return self.pos_masks is not None

    def has_neg_submasks(self):
        """Return True when a list of negative submasks is set (it may be empty)."""
        return self.neg_masks is not None

    def add_dependency(self, segment):
        """Append `segment` to this segment's dependencies."""
        self.dependencies.append(segment)

    def get_dependencies(self):
        """Return the list of dependency segments."""
        return self.dependencies

    def __repr__(self):
        attrs = {}
        attrs["segment_id"] = self.segment_id
        attrs["name"] = self.name
        attrs["mask_type"] = self.mask_type
        attrs["depends"] = [s.segment_id for s in self.get_dependencies()]
        attrs = " ".join(f"{k}={v}" for k, v in attrs.items())
        return f"<Segment {attrs}>"


class SegmentationPipeline:
    """
    Pipeline for Grounded SAM based segmentation.

    The pipeline holds a list of segments whose dependencies form a graph.
    Calling the pipeline orders the segments topologically and generates the
    mask of the last (output) segment.
    """

    def __init__(self, segments=None, version=3):
        self.segments = segments if segments else []
        self.id_to_segment_map = {s.segment_id: s for s in self.segments}
        self.version = version

    @classmethod
    def from_config(cls, filepath, version=3):
        r"""
        Instantiate a SegmentationPipeline from a yaml config file.

        Every entry under the top-level `segments` key must provide `id`, `name`
        and `mask_type`; all other keys are optional. An entry's `depends` list
        names the ids of its parent segments.

        Args:
            filepath (`str`):
                A path to a segments yaml config file
            version (`int`):
                The version number.

        Returns:
            SegmentationPipeline

        Raises:
            ValueError: if the config is empty or a `depends` id is unknown.
        """
        with open(filepath, "r") as f:
            # NOTE(security): CLoader is PyYAML's full loader, not the safe one.
            # Only load config files from trusted sources, or consider
            # yaml.CSafeLoader if configs may come from elsewhere.
            config = yaml.load(f, Loader=yaml.CLoader)
        if not config:
            raise ValueError("received empty config file")
        segments = []
        segment_edges = []
        id_to_segment_map = {}
        for s in config["segments"]:
            segment_id = s["id"]
            segment = Segment(
                dino_model=s.get("dino_model", "base"),
                version=version,
                segment_id=segment_id,
                name=s["name"],
                mask_type=s["mask_type"],
                classes=s.get("classes"),
                box_nms=s.get("box_nms", False),
                box_nms_iou_threshold=s.get("box_nms_iou_threshold", 0.5),
                threshold=s.get("threshold", 0.2),
                text_threshold=s.get("text_threshold", 0.3),
                mask_threshold=s.get("mask_threshold", 0),
                min_obj_size=s.get("min_obj_size", 0),
                padding=s.get("padding", 0),
                min_pos_mask_iou=s.get("min_pos_mask_iou", 0),
                max_pos_mask_iou=s.get("max_pos_mask_iou", 1),
                min_neg_mask_iou=s.get("min_neg_mask_iou", 0),
                max_neg_mask_iou=s.get("max_neg_mask_iou", 1),
                max_num_masks=s.get("max_num_masks"),
                sort=s.get("sort", "asc"),
                sort_by=s.get("sort_by", "size"),
                # The yaml key is "join", not "join_type".
                join_type=s.get("join", "left"),
                filter_mask_type=s.get("filter_mask_type", "pos"),
            )
            # Parents may be declared later in the file, so only record the
            # edges here and resolve them once every segment exists.
            if "depends" in s:
                segment_edges.extend((parent, segment_id) for parent in s["depends"])
            segments.append(segment)
            id_to_segment_map[segment_id] = segment
        for parent, child in segment_edges:
            parent_segment = id_to_segment_map.get(parent)
            if parent_segment is None:
                raise ValueError(f"invalid segmentation graph: parent {parent} does not exist")
            id_to_segment_map[child].add_dependency(parent_segment)
        return cls(segments, version=version)

    def add_segment(self, segment):
        r"""
        Add a segment to the pipeline.

        The segment's `version` is overwritten with the pipeline's version.

        Args:
            segment (`Segment`)
        """
        segment.version = self.version
        self.segments.append(segment)
        self.id_to_segment_map[segment.segment_id] = segment

    def add_segments(self, segments):
        r"""
        Add a list of segments to the pipeline.

        Args:
            segments (`List[Segment]`)
        """
        for segment in segments:
            self.add_segment(segment)

    def _get_segment_ordering(self):
        r"""
        Return the pipeline's segments in topological (dependency-first) order.

        Returns:
            `List[Segment]`: parents always precede their children. The list is
            empty when the pipeline has no segments.

        Raises:
            ValueError: if a dependency is not part of the pipeline, if more
            than one segment has no children, or if the graph contains a cycle.
        """
        id_to_segment_map = self.id_to_segment_map
        # build graph: parent id -> child ids, plus the number of parents per child
        graph = defaultdict(set)
        indegree = defaultdict(int)
        for segment in self.segments:
            child = segment.segment_id
            for dep in segment.get_dependencies():
                parent = dep.segment_id
                if id_to_segment_map.get(parent) is None:
                    raise ValueError(f"invalid segmentation graph: parent {parent} does not exist")
                graph[parent].add(child)
                indegree[child] += 1
        segment_ids = [s.segment_id for s in self.segments]
        # check no more than one sink/output node
        # (must run before the sort below: reading graph[sid] inserts keys)
        sink_nodes = [s for s in segment_ids if s not in graph]
        if len(sink_nodes) > 1:
            raise ValueError("invalid segmentation graph: multiple sink nodes")
        # topological sort
        order = []
        q = deque([s for s in segment_ids if indegree[s] == 0])
        while q:
            sid = q.popleft()
            order.append(id_to_segment_map[sid])
            for child in graph[sid]:
                indegree[child] -= 1
                if indegree[child] == 0:
                    q.append(child)
        # check if cycle exists: segments on a cycle never reach indegree 0
        if len(order) != len(segment_ids):
            raise ValueError("invalid segmentation graph: cycle found")
        return order

    @contextmanager
    def prepare_gpu(self, version):
        r"""
        Context manager that moves the SAM model for `version` to the GPU.

        On exit the model is moved back to the CPU and the CUDA cache is
        emptied, also when the body raises.

        Args:
            version (`int`):
                1 selects `models.sam_model`, 3 selects `models.sam3_processor.model`.

        Raises:
            ValueError: for any other version.
        """
        if version == 1:
            model = models.sam_model
        elif version == 2:
            raise ValueError(f"invalid version: {version}")
        elif version == 3:
            model = models.sam3_processor.model
        else:
            raise ValueError(f"invalid version: {version}")
        try:
            model = model.to("cuda")
            yield model
        finally:
            model = model.to("cpu")
            torch.cuda.empty_cache()

    def __call__(self, image):
        r"""
        The call function to the pipeline for generation.

        Args:
            image (`PIL.Image.Image`):
                SegmentationPipeline accepts a PIL image as input.
        Returns:
            SegmentationPipelineOutput

        Raises:
            ValueError: if the pipeline has no segments or its graph is invalid.
        """
        # Validate the graph before touching the GPU so that an invalid
        # pipeline fails fast without moving the model.
        segments = self._get_segment_ordering()
        if not segments:
            raise ValueError("invalid segmentation graph: no segments")
        with self.prepare_gpu(self.version):
            # The last segment in topological order is the single sink; its
            # generate_mask() recursively generates every dependency.
            segments[-1].generate_mask(image)
            return SegmentationPipelineOutput(segments)


class SegmentationPipelineOutput:
    """
    Output class for SegmentationPipeline

    Attributes:
        output_segment: the last entry of `segments`.
        segments: the segment list exactly as it was passed in.
        mask_image: the `mask_image` of the output segment.
    """

    def __init__(self, segments):
        # The final entry of the list is treated as the pipeline's output segment.
        self.output_segment = segments[-1]
        self.segments = segments
        self.mask_image = self.output_segment.mask_image


def bounding_boxes(
    image,
    dino_model,
    classes,
    threshold=0.25,
    text_threshold=0.3,
    box_nms=False,
    box_nms_iou_threshold=0.5,
):
    r"""
    Detect bounding boxes for the given classes with Grounding DINO.

    Args:
        image (`PIL.Image.Image`):
            The image to run detection on.
        dino_model (`str`):
            Key into `models.grounding_dino_processor` / `models.grounding_dino_model`.
            A falsy value selects `base`.
        classes (`List[str]`):
            Class names; each becomes `"<name>."` in the space-joined text prompt.
        threshold (`float`, *optional*, defaults to 0.25):
            Forwarded to the processor's post-processing as `threshold`.
        text_threshold (`float`, *optional*, defaults to 0.3):
            Forwarded to the processor's post-processing as `text_threshold`.
        box_nms (`bool`, *optional*, defaults to False):
            Whether to run NMS on the detected boxes.
        box_nms_iou_threshold (`float`, *optional*, defaults to 0.5):
            IoU threshold passed to NMS.

    Returns:
        The detected boxes as a CPU tensor.

    Raises:
        ValueError: if `dino_model` is not a known model key.
    """
    prompt = " ".join(f"{cls}." for cls in classes)
    if not dino_model:
        dino_model = "base"
    elif dino_model not in models.grounding_dino_processor:
        raise ValueError(f"invalid dino model: {dino_model}")
    processor = models.grounding_dino_processor[dino_model]
    inputs = processor(images=image, text=prompt, return_tensors="pt").to("cuda")
    with torch.no_grad():
        outputs = models.grounding_dino_model[dino_model](**inputs)
    detections = processor.post_process_grounded_object_detection(
        outputs,
        inputs.input_ids,
        threshold=threshold,
        text_threshold=text_threshold,
        # PIL reports (width, height); reversed here for target_sizes
        target_sizes=[image.size[::-1]],
    )
    boxes = detections[0]["boxes"].cpu()
    if not box_nms:
        return boxes
    # nms prioritizing smallest boxes, can support score-based nms in future if needed
    # (area is computed treating each box as (x0, y0, x1, y1))
    sizes = torch.tensor([(x1 - x0) * (y1 - y0) for x0, y0, x1, y1 in boxes])
    # Inverse area as the score: smaller boxes rank higher.
    keep = nms(boxes, 1 / sizes, iou_threshold=box_nms_iou_threshold)
    return boxes[keep]


def sam1(
    image,
    input_boxes=None,
    input_points=None,
    mask_threshold=0,
):
    r"""
    Run the v1 SAM model (`models.sam_model`) on box or point prompts.

    Args:
        image (`PIL.Image.Image`):
            The image to segment.
        input_boxes (*optional*):
            Box prompts, passed to the processor as-is.
        input_points (*optional*):
            Point prompts. Only used when `input_boxes` is None, in which case
            they are wrapped in a list before being passed to the processor.
        mask_threshold (`float`, *optional*, defaults to 0):
            Forwarded to the processor's `post_process_masks`.

    Returns:
        Tuple `(masks, scores)` of CPU tensors: the post-processed masks of the
        first image and the model's iou scores with dimension 0 squeezed.
    """
    # Boxes take precedence: points are dropped whenever boxes are supplied.
    point_prompt = None if input_boxes is not None else [input_points]
    inputs = models.sam_processor(
        image,
        input_boxes=input_boxes,
        input_points=point_prompt,
        return_tensors="pt",
    ).to("cuda")
    with torch.no_grad():
        outputs = models.sam_model(**inputs)
    predicted = outputs.pred_masks.cpu()
    original_sizes = inputs["original_sizes"].cpu()
    reshaped_sizes = inputs["reshaped_input_sizes"].cpu()
    per_image_masks = models.sam_processor.image_processor.post_process_masks(
        predicted,
        original_sizes,
        reshaped_sizes,
        mask_threshold=mask_threshold,
    )
    masks = per_image_masks[0].cpu()
    scores = outputs.iou_scores.cpu()
    # free up any SAM tensors in gpu
    del inputs, outputs
    torch.cuda.empty_cache()
    return masks, scores.squeeze(0)


def sam3(
    image,
    input_boxes=None,
    input_points=None,
    mask_threshold=0,
    classes=None,
):
    r"""
    Run the SAM3 processor (`models.sam3_processor`) with a single kind of prompt.

    The prompt is chosen by priority: boxes, then points, then a text prompt
    built from `classes`.

    Args:
        image (`PIL.Image.Image`):
            The image to segment.
        input_boxes (*optional*):
            Box prompts; only the first entry (`input_boxes[0]`) is used, with
            coordinates in pixels.
        input_points (*optional*):
            Point prompts in pixels.
        mask_threshold (`float`, *optional*, defaults to 0):
            Set as the processor's confidence threshold.
        classes (`List[str]`, *optional*):
            Class names for the text prompt; required when neither boxes nor
            points are given.

    Returns:
        Tuple `(masks, scores)` read from the processor state and moved to the CPU.
    """
    processor = models.sam3_processor
    state = processor.set_image(image)
    processor.reset_all_prompts(state)
    processor.set_confidence_threshold(mask_threshold, state)
    if input_boxes is not None:
        first_boxes = input_boxes[0]
        # normalize boxes to 0-1
        box_scale = torch.tensor([image.size[0], image.size[1], image.size[0], image.size[1]])
        normalized_boxes = first_boxes / box_scale
        # Each prompt gets the label 1.
        box_labels = [1] * len(normalized_boxes)
        state = processor.add_multiple_box_prompts(normalized_boxes, box_labels, state)
    elif input_points is not None:
        # NOTE: the original author flagged point prompts here as giving bad
        # results and noted that points use v1 (slimSAM) instead.
        # normalize points to 0-1
        point_scale = torch.tensor([image.size[0], image.size[1]])
        normalized_points = torch.tensor(input_points) / point_scale
        point_labels = [1] * len(normalized_points)
        state = processor.add_point_prompt(normalized_points, point_labels, state)
    else:
        text_prompt = " ".join(f"{cls}." for cls in classes)
        state = processor.set_text_prompt(text_prompt, state)
    masks = state.get("masks").to("cpu")
    scores = state.get("scores").to("cpu")
    return masks, scores


def sam_masks(
    image,
    dino_model="base",
    input_boxes=None,
    input_classes=None,
    input_points=None,
    box_nms=False,
    box_nms_iou_threshold=0.5,
    threshold=0.2,
    text_threshold=0.3,
    mask_threshold=0,
    min_obj_size=0,
    outpaint=False,
    combine_masks=True,
    max_num_masks=None,
    sort="asc",
    sort_by="size",
    version=3,
):
    r"""
    Segment an image with SAM and merge the resulting submasks.

    For version 1, or when no classes are given, the prompt is resolved first:
    boxes are detected from `input_classes` with `bounding_boxes` when no boxes
    were supplied, otherwise the supplied boxes are wrapped in a list, otherwise
    the points are used. Version 1 runs `sam1`; any other accepted version runs
    `sam3`, which also receives the classes.

    Args:
        image (`PIL.Image.Image`):
            The image to segment.
        dino_model (`str`, *optional*, defaults to `base`):
            Forwarded to `bounding_boxes`.
        input_boxes (*optional*):
            Box prompts.
        input_classes (`List[str]`, *optional*):
            Class names used for box detection and/or the text prompt.
        input_points (*optional*):
            Point prompts.
        box_nms (`bool`, *optional*, defaults to False):
            Forwarded to `bounding_boxes`.
        box_nms_iou_threshold (`float`, *optional*, defaults to 0.5):
            Forwarded to `bounding_boxes`.
        threshold (`float`, *optional*, defaults to 0.2):
            Forwarded to `bounding_boxes`.
        text_threshold (`float`, *optional*, defaults to 0.3):
            Forwarded to `bounding_boxes`.
        mask_threshold (`float`, *optional*, defaults to 0):
            Forwarded to the SAM call, and also used as the minimum score a
            submask needs to be kept.
        min_obj_size (`int`, *optional*, defaults to 0):
            Currently unused by this function.
        outpaint (`bool`, *optional*, defaults to False):
            Whether to invert the combined mask.
        combine_masks (`bool`, *optional*, defaults to True):
            Currently unused by this function.
        max_num_masks (`int`, *optional*, defaults to `None`):
            Stop after this many submasks (in sorted order). `None`, or a count
            that is never reached (e.g. 0), keeps all of them.
        sort (`str`, *optional*, defaults to `asc`):
            `desc` sorts descending; anything else sorts ascending.
        sort_by (`str`, *optional*, defaults to `size`):
            `size` sorts by nonzero pixel count; anything else sorts by score.
        version (`int`, *optional*, defaults to 3):
            The SAM version to use.

    Returns:
        Tuple `(single_mask, all_masks)`: the union of the kept submasks as a
        float tensor of shape `(height, width)`, and the kept submasks as a
        list of float tensors. `(None, None)` when there is nothing to prompt
        with or no boxes were detected.

    Raises:
        ValueError: if `version` is 2.
    """
    if version == 2:
        raise ValueError(f"invalid version: {version}")
    if version == 1 or input_classes is None:
        if input_boxes is None and input_classes:
            boxes = bounding_boxes(
                image,
                dino_model=dino_model,
                classes=input_classes,
                threshold=threshold,
                text_threshold=text_threshold,
                box_nms=box_nms,
                box_nms_iou_threshold=box_nms_iou_threshold,
            )
            if boxes.shape[0] == 0:
                return None, None
            input_boxes = boxes.unsqueeze(0)
        elif input_boxes is not None:
            input_boxes = [input_boxes]
        elif input_points is None:
            return None, None
    if version == 1:
        masks, scores = sam1(
            image,
            input_boxes=input_boxes,
            input_points=input_points,
            mask_threshold=mask_threshold,
        )
    else:
        masks, scores = sam3(
            image,
            input_boxes=input_boxes,
            input_points=input_points,
            mask_threshold=mask_threshold,
            classes=input_classes,
        )
    # Accumulator for the union of the kept submasks.
    single_mask = torch.zeros(1, *masks.shape[-2:])
    all_masks = []
    # Collapse the two leading dimensions so masks and scores line up one-to-one.
    if len(masks.shape) == 4:
        masks = torch.flatten(masks, start_dim=0, end_dim=1)
    if len(scores.shape) == 2:
        scores = torch.flatten(scores, start_dim=0, end_dim=1)
    sizes = torch.tensor([int(torch.count_nonzero(mask)) for mask in masks])
    keep = scores >= mask_threshold
    masks = masks[keep]
    sizes = sizes[keep]
    scores = scores[keep]
    sort_contents = sizes if sort_by == "size" else scores
    max_num_masks = len(masks) if max_num_masks is None else max_num_masks
    # +1 for ascending, -1 for descending
    sign = 1 - 2 * int(sort == "desc")
    for i, (mask, _) in enumerate(sorted(zip(masks, sort_contents), key=lambda x: sign * x[1])):
        single_mask = torch.logical_or(single_mask, mask)
        all_masks.append(mask.float())
        if i + 1 == max_num_masks:
            break
    # Drop the leading singleton dimension. No squeeze(): it would also drop a
    # 1-pixel height or width and break the (height, width) contract.
    single_mask = single_mask.float().reshape(single_mask.shape[-2:])
    if outpaint:
        # Invert: 1 where the combined mask is empty, 0 elsewhere.
        single_mask = (single_mask == 0).float()
    return single_mask, all_masks


def clean_mask(mask, min_obj_size=100, hole_threshold=100):
    r"""
    Remove small objects and small holes from a mask.

    Args:
        mask (`np.ndarray`):
            The mask; values > 0 count as foreground.
        min_obj_size (`int`, *optional*, defaults to 100):
            Passed to `morphology.remove_small_objects` as `min_size`.
        hole_threshold (`int`, *optional*, defaults to 100):
            Passed to `morphology.remove_small_holes` as `area_threshold`.

    Returns:
        Tuple `(mask, emptied)`. When removing small objects would leave
        nothing, the input mask is returned unchanged with `emptied=True`;
        otherwise the cleaned mask is returned as uint8 with values 0/255 and
        `emptied=False`.
    """
    foreground = mask > 0
    without_specks = morphology.remove_small_objects(foreground, min_size=min_obj_size)
    if not np.count_nonzero(without_specks):
        # Everything would be removed: hand back the untouched input instead.
        return mask, True
    filled = morphology.remove_small_holes(without_specks, area_threshold=hole_threshold)
    return (255 * filled).astype(np.uint8), False


def get_mask_iou(mask1, mask2):
    r"""
    Compute the intersection over union of two masks.

    Args:
        mask1, mask2:
            Masks of the same shape. They are combined with `*` and `|`, so
            boolean arrays are the expected input.

    Returns:
        `0` when the masks do not intersect, otherwise intersection / union.
    """
    overlap = (mask1 * mask2).sum()
    # Short-circuit disjoint masks; this also avoids dividing by an empty union.
    if overlap == 0:
        return 0
    combined = (mask1 | mask2).sum()
    return overlap / combined
