import os

import pytest
import torch
import torchvision.transforms.functional as F
import numpy as np
from PIL import Image

from grounded_sam import bounding_boxes, sam_masks, Segment, SegmentationPipeline

# Fixture directories resolved relative to this test module. `abspath` guards against
# `os.path.dirname(__file__)` being "" (module loaded via a bare relative filename), which
# with plain string concatenation would yield a path rooted at "/".
_TEST_DIR = os.path.dirname(os.path.abspath(__file__))
DEFAULT_CONFIG_DIR = os.path.join(_TEST_DIR, "configs", "segmentation", "v1")
DEFAULT_MASK_DIR = os.path.join(_TEST_DIR, "masks")


class TestSegmentation:
    """Tests for ``bounding_boxes``/``sam_masks`` and the ``Segment``/``SegmentationPipeline`` graph.

    Objects under ``models`` are patched with mocks by ``grounded_sam_fixture``, and
    pipeline tests patch ``grounded_sam.sam_masks`` directly, so no real model is run.
    """

    @classmethod
    def setup_class(cls):
        """Load the fixture masks once and build the synthetic SAM outputs shared by all tests."""
        cls.input_image = Image.new("RGB", (1024, 1024))
        cls.square_pos_mask = cls._load_mask_tensor("square_512x512.png")
        cls.rectangle_pos_mask = cls._load_mask_tensor("rectangle_right_256x512.png")
        cls.bottom_left_neg_mask = cls._load_mask_tensor("bottom_left_fill.png")
        cls.bottom_right_neg_mask = cls._load_mask_tensor("bottom_right_fill.png")
        # Expected single output mask: a 512x512 square centred in the 1024x1024 image.
        expected_mask = torch.zeros((1, 1024, 1024))
        expected_mask[:, 256:768, 256:768] = 1
        # Synthetic SAM output with two masks: index 0 is the 512x512 square,
        # index 1 a smaller 256x256 square nested inside it.
        expected_masks = torch.zeros((1, 2, 1024, 1024))
        expected_masks[0, 0, 256:768, 256:768] = 1
        expected_masks[0, 1, 384:640, 384:640] = 1
        cls.expected_mask = expected_mask
        cls.expected_masks = expected_masks
        # Orderings by mask area.
        cls.expected_masks_sorted_asc = [expected_masks[0, 1], expected_masks[0, 0]]
        cls.expected_masks_sorted_desc = [expected_masks[0, 0], expected_masks[0, 1]]
        # IoU scores: the larger mask scores 0.5, the smaller one 0.9.
        cls.expected_masks_scores = torch.tensor([[[0.5, 0.9]]])

    @pytest.fixture
    def grounded_sam_fixture(self, mocker):
        """Patch the Grounding DINO, SAM 1 and SAM 2 model objects under ``models`` with mocks."""
        mocker.patch("models.grounding_dino_processor", mocker.MagicMock())
        mocker.patch("models.grounding_dino_model", mocker.MagicMock())
        # SAM 1: calling the model yields the synthetic IoU scores, and the processor's
        # post-processing yields the synthetic masks.
        mock_sam_output = mocker.MagicMock()
        mock_sam_output.iou_scores = self.expected_masks_scores
        mock_sam_model = mocker.MagicMock()
        mock_sam_model.return_value = mock_sam_output
        # `.to(device)` must return the same mock so the configured return value survives.
        mock_sam_model.to.return_value = mock_sam_model
        mocker.patch("models.sam_model", mock_sam_model)
        mock_image_processor = mocker.Mock()
        mock_image_processor.post_process_masks.return_value = [self.expected_masks]
        mock_sam_processor = mocker.MagicMock(image_processor=mock_image_processor)
        mocker.patch("models.sam_processor", mock_sam_processor)
        # SAM 2: predict() returns a 3-tuple of (masks, scores, None) as numpy arrays.
        mock_sam2_predictor = mocker.MagicMock()
        mock_sam2_predictor.predict.return_value = (
            self.expected_masks.numpy(),
            self.expected_masks_scores[0].numpy(),
            None,
        )
        mocker.patch("models.sam2_predictor", mock_sam2_predictor)

    # ------------------------------------------------------------------ sam_masks / boxes

    def test_bounding_boxes_nms(self, mocker, grounded_sam_fixture):
        """Two identical detections collapse to one box when ``box_nms=True``."""
        mock_dino_processor = mocker.MagicMock()
        mock_dino_processor.post_process_grounded_object_detection.return_value = [
            {"boxes": torch.FloatTensor([[10, 20, 30, 40], [10, 20, 30, 40]])}
        ]
        # Replace the generic processor mock installed by the fixture.
        mocker.patch("models.grounding_dino_processor", mock_dino_processor)
        boxes = bounding_boxes(self.input_image, classes=["test"], box_nms=True)
        assert len(boxes) == 1

    def test_sam_masks_classes(self, grounded_sam_fixture):
        """Class prompts return the expected mask plus both masks sorted by ascending size."""
        actual_mask, actual_masks = sam_masks(self.input_image, input_classes=["test"])
        assert len(actual_masks) == 2
        self._assert_masks_equal([self.expected_mask] + self.expected_masks_sorted_asc, [actual_mask] + actual_masks)

    def test_sam_masks_classes_version_2(self, grounded_sam_fixture):
        """Same as ``test_sam_masks_classes`` but routed through the SAM 2 predictor."""
        actual_mask, actual_masks = sam_masks(self.input_image, input_classes=["test"], version=2)
        assert len(actual_masks) == 2
        self._assert_masks_equal([self.expected_mask] + self.expected_masks_sorted_asc, [actual_mask] + actual_masks)

    def test_sam_masks_boxes(self, grounded_sam_fixture):
        """Box prompts produce the same result as class prompts."""
        actual_mask, actual_masks = sam_masks(self.input_image, input_boxes=[[0, 0, 0, 0]])
        assert len(actual_masks) == 2
        self._assert_masks_equal([self.expected_mask] + self.expected_masks_sorted_asc, [actual_mask] + actual_masks)

    def test_sam_masks_points(self, grounded_sam_fixture):
        """Point prompts produce the same result as class prompts."""
        actual_mask, actual_masks = sam_masks(self.input_image, input_points=[[0, 0]])
        assert len(actual_masks) == 2
        self._assert_masks_equal([self.expected_mask] + self.expected_masks_sorted_asc, [actual_mask] + actual_masks)

    def test_sam_masks_max_num_masks(self, grounded_sam_fixture):
        """``max_num_masks=1`` keeps only the smallest mask, which is also the returned mask."""
        actual_mask, actual_masks = sam_masks(self.input_image, input_classes=["test"], max_num_masks=1)
        assert len(actual_masks) == 1
        smallest_mask = self.expected_masks_sorted_asc[0]
        self._assert_masks_equal([smallest_mask, smallest_mask], [actual_mask] + actual_masks)

    def test_sam_masks_outpaint(self, grounded_sam_fixture):
        """``outpaint=True`` inverts the returned mask but leaves the individual masks untouched."""
        actual_mask, actual_masks = sam_masks(self.input_image, input_classes=["test"], outpaint=True)
        assert len(actual_masks) == 2
        self._assert_masks_equal(
            [1 - self.expected_mask] + self.expected_masks_sorted_asc, [actual_mask] + actual_masks
        )

    def test_sam_masks_sort_size_desc(self, grounded_sam_fixture):
        """``sort="desc"`` orders masks from largest to smallest area."""
        actual_mask, actual_masks = sam_masks(self.input_image, input_classes=["test"], sort="desc")
        assert len(actual_masks) == 2
        self._assert_masks_equal([self.expected_mask] + self.expected_masks_sorted_desc, [actual_mask] + actual_masks)

    def test_sam_masks_sort_score_desc(self, grounded_sam_fixture):
        """Descending score order puts the small mask (score 0.9) before the large one (0.5)."""
        actual_mask, actual_masks = sam_masks(self.input_image, input_classes=["test"], sort="desc", sort_by="score")
        assert len(actual_masks) == 2
        self._assert_masks_equal([self.expected_mask] + self.expected_masks_sorted_asc, [actual_mask] + actual_masks)

    def test_sam_masks_no_inputs(self, grounded_sam_fixture):
        """Without any prompt, ``sam_masks`` returns ``(None, None)``."""
        actual_mask, actual_masks = sam_masks(self.input_image)
        assert actual_mask is None and actual_masks is None

    def test_sam_masks_empty_boxes(self, mocker, grounded_sam_fixture):
        """If detection finds no boxes, ``sam_masks`` returns ``(None, None)``."""
        mock_dino_processor = mocker.MagicMock()
        mock_dino_processor.post_process_grounded_object_detection.return_value = [{"boxes": torch.zeros((0, 4))}]
        mocker.patch("models.grounding_dino_processor", mock_dino_processor)
        actual_mask, actual_masks = sam_masks(self.input_image, input_classes=["test"])
        assert actual_mask is None and actual_masks is None

    # ------------------------------------------------------------------ pipeline validation

    def test_init_from_empty_file(self):
        """An empty config file is rejected when loading the pipeline."""
        with pytest.raises(ValueError) as exc_info:
            SegmentationPipeline.from_config(self._config_path("empty.yaml"))
        assert str(exc_info.value) == "received empty config file"

    def test_raise_exception_on_missing_parent(self, grounded_sam_fixture):
        """A segment depending on an undefined parent is reported."""
        config_path = self._config_path("missing_parent.yaml")
        # Either loading or running may perform validation, so both are inside the context.
        with pytest.raises(ValueError) as exc_info:
            pipe = SegmentationPipeline.from_config(config_path)
            pipe(self.input_image)
        assert str(exc_info.value) == "invalid segmentation graph: parent 2 does not exist"

    def test_raise_exception_on_multiple_sink_nodes(self, grounded_sam_fixture):
        """A graph with more than one sink node is reported."""
        config_path = self._config_path("multiple_sinks.yaml")
        with pytest.raises(ValueError) as exc_info:
            pipe = SegmentationPipeline.from_config(config_path)
            pipe(self.input_image)
        assert str(exc_info.value) == "invalid segmentation graph: multiple sink nodes"

    def test_raise_exception_on_graph_cycle(self, grounded_sam_fixture):
        """A cyclic dependency graph is reported."""
        config_path = self._config_path("cycle.yaml")
        with pytest.raises(ValueError) as exc_info:
            pipe = SegmentationPipeline.from_config(config_path)
            pipe(self.input_image)
        assert str(exc_info.value) == "invalid segmentation graph: cycle found"

    # ------------------------------------------------------------------ pipeline execution
    # In these tests the mocked `sam_masks` results are consumed in the order the pipeline calls it.

    def test_single_pos_stage(self, mocker, grounded_sam_fixture):
        """A single positive stage yields its mask, and a second call returns the same result."""
        self._patch_sam_masks(mocker, [(self.square_pos_mask, [self.square_pos_mask])])
        expected_mask = F.to_pil_image(self.square_pos_mask)
        pipe = SegmentationPipeline.from_config(self._config_path("single_stage_pos.yaml"))
        result = pipe(self.input_image)
        self._assert_tensor_equal(self.square_pos_mask, result.output_segment.pos_mask)
        self._assert_image_equal(expected_mask, result.mask_image)
        # Only one mocked result is available, so the second call must not re-run sam_masks.
        result = pipe(self.input_image)
        self._assert_tensor_equal(self.square_pos_mask, result.output_segment.pos_mask)
        self._assert_image_equal(expected_mask, result.mask_image)

    def test_single_pos_stage_with_padding(self, mocker, grounded_sam_fixture):
        """Padding configured on a stage is applied to the final mask image, not the raw mask."""
        self._patch_sam_masks(mocker, [(self.square_pos_mask, [self.square_pos_mask])])
        expected_mask = self._load_mask("square_512x512_padding_10.png")
        pipe = SegmentationPipeline.from_config(self._config_path("single_stage_pos_padding.yaml"))
        result = pipe(self.input_image)
        self._assert_tensor_equal(self.square_pos_mask, result.output_segment.pos_mask)
        self._assert_image_equal(expected_mask, result.mask_image)

    def test_single_pos_stage_version_2(self, mocker, grounded_sam_fixture):
        """Same as ``test_single_pos_stage`` with ``version=2``."""
        self._patch_sam_masks(mocker, [(self.square_pos_mask, [self.square_pos_mask])])
        expected_mask = F.to_pil_image(self.square_pos_mask)
        pipe = SegmentationPipeline.from_config(self._config_path("single_stage_pos.yaml"), version=2)
        result = pipe(self.input_image)
        self._assert_tensor_equal(self.square_pos_mask, result.output_segment.pos_mask)
        self._assert_image_equal(expected_mask, result.mask_image)
        # Only one mocked result is available, so the second call must not re-run sam_masks.
        result = pipe(self.input_image)
        self._assert_tensor_equal(self.square_pos_mask, result.output_segment.pos_mask)
        self._assert_image_equal(expected_mask, result.mask_image)

    def test_single_pos_and_neg_stage(self, mocker, grounded_sam_fixture):
        """The output segment combines one positive and one negative stage."""
        self._patch_sam_masks(
            mocker,
            [
                (None, None),
                (self.square_pos_mask, [self.square_pos_mask]),
                (self.bottom_left_neg_mask, [self.bottom_left_neg_mask]),
            ],
        )
        expected_mask = self._load_mask("single_stage_pos_neg.png")
        pipe = SegmentationPipeline.from_config(self._config_path("single_stage_pos_neg.yaml"))
        result = pipe(self.input_image)
        assert len(result.segments) == 3
        self._assert_tensor_equal(self.square_pos_mask, result.output_segment.pos_mask)
        self._assert_tensor_equal(self.bottom_left_neg_mask, result.output_segment.neg_mask)
        self._assert_image_equal(expected_mask, result.mask_image)

    def test_multi_stage(self, mocker, grounded_sam_fixture):
        """Intermediate segments aggregate pos/neg masks from their parents."""
        self._patch_sam_masks(
            mocker,
            [
                (None, None),
                (None, None),
                (self.square_pos_mask, [self.square_pos_mask]),
                (self.rectangle_pos_mask, [self.rectangle_pos_mask]),
                (None, None),
                (self.bottom_left_neg_mask, [self.bottom_left_neg_mask]),
                (self.bottom_right_neg_mask, [self.bottom_right_neg_mask]),
            ],
        )
        # Expected outputs: binary unions of the square+rectangle and of both bottom fills.
        expected_mask = self._load_mask("multi_stage.png")
        empty_mask = torch.zeros_like(self.square_pos_mask)
        pos_mask = torch.zeros_like(self.square_pos_mask)
        pos_mask[self.square_pos_mask > 0] = 1
        pos_mask[self.rectangle_pos_mask > 0] = 1
        neg_mask = torch.zeros_like(self.bottom_left_neg_mask)
        neg_mask[self.bottom_left_neg_mask > 0] = 1
        neg_mask[self.bottom_right_neg_mask > 0] = 1
        pipe = SegmentationPipeline.from_config(self._config_path("multi_stage.yaml"))
        result = pipe(self.input_image)
        assert len(result.segments) == 7
        # The square + rect segment has the combined pos mask and no neg mask.
        self._assert_tensor_equal(pos_mask, result.segments[4].pos_mask)
        self._assert_tensor_equal(empty_mask, result.segments[4].neg_mask)
        # The bottom-half segment has the combined neg mask and no pos mask.
        self._assert_tensor_equal(empty_mask, result.segments[5].pos_mask)
        self._assert_tensor_equal(neg_mask, result.segments[5].neg_mask)
        self._assert_image_equal(expected_mask, result.mask_image)

    def test_multi_stage_neg(self, mocker, grounded_sam_fixture):
        """Multi-stage graph from ``multi_stage_neg.yaml`` produces the reference mask image."""
        self._patch_sam_masks(
            mocker,
            [
                (None, None),
                (None, None),
                (self.square_pos_mask, [self.square_pos_mask]),
                (self.rectangle_pos_mask, [self.rectangle_pos_mask]),
                (None, None),
                (self.bottom_left_neg_mask, [self.bottom_left_neg_mask]),
                (self.bottom_right_neg_mask, [self.bottom_right_neg_mask]),
            ],
        )
        expected_mask = self._load_mask("multi_stage_neg.png")
        pipe = SegmentationPipeline.from_config(self._config_path("multi_stage_neg.yaml"))
        result = pipe(self.input_image)
        assert len(result.segments) == 7
        self._assert_image_equal(expected_mask, result.mask_image)

    def test_multi_stage_join_left_pos(self, mocker, grounded_sam_fixture):
        """Left join on positive masks produces the reference mask image."""
        self._patch_sam_masks(
            mocker,
            [
                (None, None),
                (None, None),
                (self.square_pos_mask, [self.square_pos_mask]),
                (self.rectangle_pos_mask, [self.rectangle_pos_mask]),
                (self.bottom_left_neg_mask, [self.bottom_left_neg_mask]),
            ],
        )
        expected_mask = self._load_mask("multi_stage_join_left_pos.png")
        pipe = SegmentationPipeline.from_config(self._config_path("multi_stage_join_left_pos.yaml"))
        result = pipe(self.input_image)
        assert len(result.segments) == 5
        self._assert_image_equal(expected_mask, result.mask_image)

    def test_multi_stage_join_left_neg(self, mocker, grounded_sam_fixture):
        """Left join on negative masks produces the reference mask image."""
        self._patch_sam_masks(
            mocker,
            [
                (None, None),
                (None, None),
                (self.square_pos_mask, [self.square_pos_mask]),
                (self.rectangle_pos_mask, [self.rectangle_pos_mask]),
                (self.bottom_left_neg_mask, [self.bottom_left_neg_mask]),
            ],
        )
        expected_mask = self._load_mask("multi_stage_join_left_neg.png")
        pipe = SegmentationPipeline.from_config(self._config_path("multi_stage_join_left_neg.yaml"))
        result = pipe(self.input_image)
        assert len(result.segments) == 5
        self._assert_image_equal(expected_mask, result.mask_image)

    def test_multi_stage_join_left_both(self, mocker, grounded_sam_fixture):
        """Left join on both pos and neg masks produces the reference mask image."""
        self._patch_sam_masks(
            mocker,
            [
                (None, None),
                (None, None),
                (self.square_pos_mask, [self.square_pos_mask]),
                (self.bottom_right_neg_mask, [self.bottom_right_neg_mask]),
                (None, None),
                (self.rectangle_pos_mask, [self.rectangle_pos_mask]),
                (self.rectangle_pos_mask, [self.rectangle_pos_mask]),
            ],
        )
        expected_mask = self._load_mask("multi_stage_join_left_both.png")
        pipe = SegmentationPipeline.from_config(self._config_path("multi_stage_join_left_both.yaml"))
        result = pipe(self.input_image)
        assert len(result.segments) == 7
        self._assert_image_equal(expected_mask, result.mask_image)

    def test_multi_stage_join_right_pos(self, mocker, grounded_sam_fixture):
        """Right join on positive masks produces the reference mask image."""
        self._patch_sam_masks(
            mocker,
            [
                (None, None),
                (None, None),
                (self.square_pos_mask, [self.square_pos_mask]),
                (self.rectangle_pos_mask, [self.rectangle_pos_mask]),
                (self.bottom_left_neg_mask, [self.bottom_left_neg_mask]),
            ],
        )
        expected_mask = self._load_mask("multi_stage_join_right_pos.png")
        pipe = SegmentationPipeline.from_config(self._config_path("multi_stage_join_right_pos.yaml"))
        result = pipe(self.input_image)
        assert len(result.segments) == 5
        self._assert_image_equal(expected_mask, result.mask_image)

    def test_multi_stage_join_right_neg(self, mocker, grounded_sam_fixture):
        """Right join on negative masks produces the reference mask image."""
        self._patch_sam_masks(
            mocker,
            [
                (None, None),
                (None, None),
                (self.square_pos_mask, [self.square_pos_mask]),
                (self.rectangle_pos_mask, [self.rectangle_pos_mask]),
                (self.bottom_left_neg_mask, [self.bottom_left_neg_mask]),
            ],
        )
        expected_mask = self._load_mask("multi_stage_join_right_neg.png")
        pipe = SegmentationPipeline.from_config(self._config_path("multi_stage_join_right_neg.yaml"))
        result = pipe(self.input_image)
        assert len(result.segments) == 5
        self._assert_image_equal(expected_mask, result.mask_image)

    def test_multi_stage_join_right_both(self, mocker, grounded_sam_fixture):
        """Right join on both pos and neg masks produces the reference mask image."""
        self._patch_sam_masks(
            mocker,
            [
                (None, None),
                (None, None),
                (self.square_pos_mask, [self.square_pos_mask]),
                (self.bottom_right_neg_mask, [self.bottom_right_neg_mask]),
                (None, None),
                (self.rectangle_pos_mask, [self.rectangle_pos_mask]),
                (self.rectangle_pos_mask, [self.rectangle_pos_mask]),
            ],
        )
        expected_mask = self._load_mask("multi_stage_join_right_both.png")
        pipe = SegmentationPipeline.from_config(self._config_path("multi_stage_join_right_both.yaml"))
        result = pipe(self.input_image)
        assert len(result.segments) == 7
        self._assert_image_equal(expected_mask, result.mask_image)

    # ------------------------------------------------------------------ programmatic graphs

    def test_add_segments(self, mocker, grounded_sam_fixture):
        """A graph assembled with ``add_segments`` runs like one loaded from config."""
        self._patch_sam_masks(
            mocker,
            [
                (None, None),
                (self.square_pos_mask, [self.square_pos_mask]),
                (self.rectangle_pos_mask, [self.rectangle_pos_mask]),
            ],
        )
        expected_mask = self._load_mask("square_rect.png")
        pipe = SegmentationPipeline()
        segment0 = Segment(
            segment_id=0,
            name="segment0",
            mask_type="pos",
            classes=["square"],
        )
        segment1 = Segment(
            segment_id=1,
            name="segment1",
            mask_type="pos",
            classes=["rectangle"],
        )
        output = Segment(
            segment_id=2,
            name="output",
            mask_type="pos",
        )
        output.add_dependency(segment0)
        output.add_dependency(segment1)
        pipe.add_segments([segment0, segment1, output])
        result = pipe(self.input_image)
        assert len(result.segments) == 3
        self._assert_image_equal(expected_mask, result.mask_image)

    def test_add_segments_missing_parent(self, grounded_sam_fixture):
        """Running a graph whose dependency was never added is reported."""
        pipe = SegmentationPipeline()
        segment0 = Segment(
            segment_id=0,
            name="segment0",
            mask_type="pos",
            classes=["square"],
        )
        segment1 = Segment(
            segment_id=1,
            name="segment1",
            mask_type="pos",
            classes=["rectangle"],
        )
        output = Segment(
            segment_id=2,
            name="output",
            mask_type="pos",
        )
        output.add_dependency(segment0)
        output.add_dependency(segment1)
        # segment1 is deliberately left out of the pipeline.
        pipe.add_segments([segment0, output])
        with pytest.raises(ValueError) as exc_info:
            pipe(self.input_image)
        assert str(exc_info.value) == "invalid segmentation graph: parent 1 does not exist"

    # ------------------------------------------------------------------ Segment API

    def test_no_mask_available(self):
        """Requesting a mask image before ``generate_mask()`` fails."""
        segment = Segment()
        with pytest.raises(ValueError) as exc_info:
            segment._get_mask_image()
        assert str(exc_info.value) == "no mask available: must call generate_mask()"

    def test_join_no_mask(self, mocker):
        """Joining with a segment that has no mask fails."""
        self._patch_sam_masks(mocker, [(self.square_pos_mask, [self.square_pos_mask])])
        segment0 = Segment()
        segment1 = Segment()
        segment0.generate_mask(self.input_image)
        with pytest.raises(ValueError) as exc_info:
            segment0.join(segment1)
        assert str(exc_info.value) == "at least one segment does not have a mask"

    def test_join_unknown_type(self, mocker):
        """An unrecognised ``join_type`` is rejected."""
        self._patch_sam_masks(
            mocker,
            [
                (self.square_pos_mask, [self.square_pos_mask]),
                (self.rectangle_pos_mask, [self.rectangle_pos_mask]),
            ],
        )
        segment0 = Segment()
        segment1 = Segment()
        segment0.generate_mask(self.input_image)
        segment1.generate_mask(self.input_image)
        segment0.join_type = "new_join"
        with pytest.raises(ValueError) as exc_info:
            segment0.join(segment1)
        assert str(exc_info.value) == "unsupported join type: new_join"

    def test_segment_repr(self):
        """``repr(Segment)`` lists id, name, mask type and dependencies."""
        segment = Segment(segment_id=0, name="test", classes=["test"])
        expected_repr = "<Segment segment_id=0 name=test mask_type=pos depends=[]>"
        assert expected_repr == repr(segment)

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _load_mask(filename):
        """Open ``filename`` from ``DEFAULT_MASK_DIR`` as a grayscale ("L") PIL image.

        The source file is closed before returning; ``convert`` yields an in-memory copy.
        """
        with Image.open(os.path.join(DEFAULT_MASK_DIR, filename)) as image:
            return image.convert("L")

    @classmethod
    def _load_mask_tensor(cls, filename):
        """Load a fixture mask as a tensor with the leading channel dimension squeezed away."""
        return F.to_tensor(cls._load_mask(filename)).squeeze(0)

    @staticmethod
    def _config_path(filename):
        """Return the path of a pipeline config file under ``DEFAULT_CONFIG_DIR``."""
        return os.path.join(DEFAULT_CONFIG_DIR, filename)

    @staticmethod
    def _patch_sam_masks(mocker, results):
        """Patch ``grounded_sam.sam_masks`` to return ``results`` one per call, in order."""
        mock_sam_masks = mocker.patch("grounded_sam.sam_masks")
        mock_sam_masks.side_effect = list(results)
        return mock_sam_masks

    def _assert_image_equal(self, img1, img2):
        """Assert two images (or array-likes) are element-wise close."""
        assert np.allclose(np.array(img1), np.array(img2))

    def _assert_tensor_equal(self, t1, t2):
        """Assert two tensors are element-wise close."""
        assert torch.allclose(t1, t2)

    def _assert_masks_equal(self, expected_masks, actual_masks):
        """Assert two mask sequences have the same length and pairwise-close elements.

        The length check matters: ``zip`` alone would silently ignore missing or extra masks.
        """
        expected_masks = list(expected_masks)
        actual_masks = list(actual_masks)
        assert len(expected_masks) == len(actual_masks), (
            f"expected {len(expected_masks)} masks, got {len(actual_masks)}"
        )
        for expected, actual in zip(expected_masks, actual_masks):
            self._assert_tensor_equal(expected, actual)
