class SegmentationClassConfig:
    """ComfyUI node that packs one segmentation class's settings into a dict.

    The emitted ``CLASS_CONFIG`` value is a plain dict whose keys mirror this
    node's inputs. The only transformation applied is to
    ``diff_target_colors``, which is converted from a comma-separated string
    into a list of whitespace-stripped, non-empty entries.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node's widgets (all inputs are required)."""
        required_inputs = {
            "class_name": ("STRING", {"default": "", "multiline": False}),
            "delta_e_threshold": ("FLOAT", {"default": 20.0}),
            "color_diff_method": (["delta_e", "color_range"], {"default": "delta_e"}),
            "diff_target_colors": ("STRING", {"default": "red,blue", "multiline": False}),
            "use_bbox": ("BOOLEAN", {"default": False}),
            "use_expand_mask": ("BOOLEAN", {"default": True}),
            "expand_mask_pixel": ("INT", {"default": 15, "min": 0}),
            "use_blur_mask": ("BOOLEAN", {"default": False}),
            "positive": ("BOOLEAN", {"default": True}),
        }
        return {"required": required_inputs}

    RETURN_TYPES = ("CLASS_CONFIG",)
    RETURN_NAMES = ("class_config",)
    FUNCTION = "create"
    CATEGORY = "ReImage AI"

    def create(self, class_name, delta_e_threshold, color_diff_method, diff_target_colors,
               use_bbox, use_expand_mask, expand_mask_pixel, use_blur_mask, positive):
        """Build the CLASS_CONFIG dict for a single class.

        Args:
            class_name: Name of the class to segment.
            delta_e_threshold: Threshold value forwarded unchanged.
            color_diff_method: Either ``"delta_e"`` or ``"color_range"``.
            diff_target_colors: Comma-separated colour names; blank entries
                are dropped and surrounding whitespace is removed.
            use_bbox, use_expand_mask, use_blur_mask: Flags forwarded unchanged.
            expand_mask_pixel: Mask expansion amount forwarded unchanged.
            positive: Whether the class is positive (True) or negative (False).

        Returns:
            A one-element tuple holding the config dict, as ComfyUI expects.
        """
        target_colors = []
        for token in diff_target_colors.split(","):
            cleaned = token.strip()
            if cleaned:
                target_colors.append(cleaned)

        config = dict(
            class_name=class_name,
            delta_e_threshold=delta_e_threshold,
            color_diff_method=color_diff_method,
            diff_target_colors=target_colors,
            use_bbox=use_bbox,
            use_expand_mask=use_expand_mask,
            expand_mask_pixel=expand_mask_pixel,
            use_blur_mask=use_blur_mask,
            positive=positive,
        )
        return (config,)

class SegmentationClassStack:
    """ComfyUI node that gathers up to ten CLASS_CONFIG inputs into a list.

    Slot 1 is required; slots 2-10 are optional. Unconnected slots (absent or
    ``None``) are skipped, and the remaining configs keep slot order.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Declare one required slot and nine optional CLASS_CONFIG slots."""
        optional_slots = {f"class_config_{idx}": ("CLASS_CONFIG",) for idx in range(2, 11)}
        return {
            "required": {"class_config_1": ("CLASS_CONFIG",)},
            "optional": optional_slots,
        }

    RETURN_TYPES = ("LIST",)
    RETURN_NAMES = ("class_list",)
    FUNCTION = "stack"
    CATEGORY = "ReImage AI"

    def stack(self, **kwargs):
        """Return the connected class configs, ordered by slot number.

        Keyword arguments other than ``class_config_1`` .. ``class_config_10``
        are ignored.

        Returns:
            A one-element tuple holding the list of configs.
        """
        slot_names = (f"class_config_{idx}" for idx in range(1, 11))
        collected = [kwargs[name] for name in slot_names if kwargs.get(name) is not None]
        return (collected,)

class SegmentationClassAggregator:
    """ComfyUI node that turns a list of CLASS_CONFIG dicts into one
    SEGMENTOR_CONFIG dict.

    Each class config is routed by its ``positive`` flag into either the
    ``pos_*`` or the ``neg_*`` column lists, so index ``i`` of every
    ``pos_*`` list describes the same positive class (likewise for ``neg_*``).
    """

    # (key read from each CLASS_CONFIG, suffix used for the output lists).
    # Order here defines the output key order. "class_name" maps to
    # "class_names" to keep the existing output keys unchanged.
    _FIELDS = (
        ("class_name", "class_names"),
        ("delta_e_threshold", "delta_e_threshold"),
        ("color_diff_method", "color_diff_method"),
        ("diff_target_colors", "diff_target_colors"),
        ("use_bbox", "use_bbox"),
        ("use_expand_mask", "use_expand_mask"),
        ("expand_mask_pixel", "expand_mask_pixel"),
        ("use_blur_mask", "use_blur_mask"),
    )

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the single required ``class_list`` input."""
        return {
            "required": {
                "class_list": ("LIST",),
            }
        }

    RETURN_TYPES = ("SEGMENTOR_CONFIG",)
    RETURN_NAMES = ("segmentor_config",)
    FUNCTION = "concat"
    CATEGORY = "ReImage AI"

    def concat(self, class_list):
        """Split class configs into parallel positive/negative column lists.

        Args:
            class_list: Iterable of CLASS_CONFIG dicts (each must contain
                ``positive`` plus every key in ``_FIELDS``). ``None`` is
                treated as an empty list.

        Returns:
            A one-element tuple holding a dict with 16 keys: ``pos_<suffix>``
            for every field, followed by ``neg_<suffix>`` for every field.
            Each value is a list with one entry per matching class, in input
            order.

        Raises:
            KeyError: If a config dict is missing ``positive`` or a field key.
        """
        # Pre-create every output list so both sides are always present,
        # even when no class falls on that side.
        config = {
            f"{prefix}_{suffix}": []
            for prefix in ("pos", "neg")
            for _, suffix in self._FIELDS
        }
        for class_cfg in class_list or ():
            prefix = "pos" if class_cfg["positive"] else "neg"
            for field, suffix in self._FIELDS:
                config[f"{prefix}_{suffix}"].append(class_cfg[field])
        return (config,)