API Reference

This section provides an auto-generated API reference for the ibbi package.

ibbi

Main initialization file for the ibbi package.

This file serves as the primary entry point for the ibbi library. It exposes the most important high-level functions and classes, making them directly accessible to the user under the ibbi namespace. This includes the core model creation factory (create_model), the main workflow classes (Evaluator, Explainer), and key utility functions for accessing datasets and managing the cache.

The goal of this top-level __init__.py is to provide a clean and intuitive API, simplifying the user experience by abstracting away the underlying module structure.

ModelType = TypeVar('ModelType', YOLOSingleClassBeetleDetector, RTDETRSingleClassBeetleDetector, YOLOBeetleMultiClassDetector, RTDETRBeetleMultiClassDetector, GroundingDINOModel, YOLOWorldModel, UntrainedFeatureExtractor, HuggingFaceFeatureExtractor) module-attribute

A generic TypeVar for representing any of the model wrapper classes in the ibbi package.

This is used for type hinting in functions and methods that can accept or return any of the available model types, providing flexibility while maintaining static type safety.
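As a minimal sketch of how a constrained TypeVar like this behaves, the example below uses two stand-in classes. `DetectorStub`, `ExtractorStub`, `StubModelType`, and `passthrough` are hypothetical names for illustration only and are not part of ibbi.

```python
from typing import TypeVar


# Hypothetical stand-ins for two of the real wrapper classes.
class DetectorStub:
    def predict(self, image):
        return {"boxes": [], "labels": []}


class ExtractorStub:
    def extract_features(self, image):
        return [0.0, 0.0]


# A constrained TypeVar: it can only be bound to one of the listed classes.
StubModelType = TypeVar("StubModelType", DetectorStub, ExtractorStub)


def passthrough(model: StubModelType) -> StubModelType:
    """Return the model unchanged; a type checker keeps the concrete type."""
    return model


detector = passthrough(DetectorStub())    # a checker infers DetectorStub
extractor = passthrough(ExtractorStub())  # a checker infers ExtractorStub
```

The constraint is only enforced by static type checkers; at runtime the function accepts any object.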

Evaluator

A unified evaluator for assessing IBBI models on various tasks.

This class provides a streamlined interface for evaluating the performance of models on tasks such as object classification and embedding quality. It handles the boilerplate code for iterating through datasets, making predictions, and calculating a comprehensive suite of metrics for a holistic model assessment.

The Evaluator is initialized with a model instance from the ibbi package. It provides methods to run different types of evaluations, returning detailed performance reports.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| model | ModelType | The instantiated ibbi model to be evaluated. |

Source code in src\ibbi\evaluate\__init__.py
class Evaluator:
    """A unified evaluator for assessing IBBI models on various tasks.

    This class provides a streamlined interface for evaluating the performance of
    models on tasks such as object classification and embedding quality.
    It handles the boilerplate code for iterating through datasets, making predictions,
    and calculating a comprehensive suite of metrics for a holistic model assessment.

    The `Evaluator` is initialized with a model instance from the `ibbi` package.
    It provides methods to run different types of evaluations, returning detailed
    performance reports.

    Attributes:
        model (ModelType): The instantiated `ibbi` model to be evaluated.
    """

    def __init__(self, model: ModelType):
        """Initializes the Evaluator with a specific model.

        Args:
            model (ModelType): The model to be evaluated. This should be an instance of one of
                               the model wrapper classes covered by `ModelType`; the evaluation
                               methods call its `predict` and `extract_features` methods.
        """
        self.model = model

    def object_classification(
        self, dataset, iou_thresholds: Union[float, list[float]] = 0.5, predict_kwargs: Optional[dict[str, Any]] = None, **kwargs
    ):
        """Runs a comprehensive object detection and classification performance analysis.

        This method assesses the model's ability to both accurately localize and correctly
        classify objects within a dataset. It iterates through the provided dataset, gathering
        ground truth information and generating model predictions. These are then passed to the
        `object_classification_performance` function to compute a detailed suite of metrics.

        The evaluation provides a holistic view of performance, combining traditional object
        detection metrics (like mAP) with a full suite of classification metrics for each IoU
        threshold.

        Args:
            dataset (iterable): An iterable dataset where each item is a dictionary-like object
                                containing at least an 'image' key. For evaluation, items should
                                also contain an 'objects' key, which is a dictionary with 'bbox'
                                and 'category' keys.
            iou_thresholds (Union[float, list[float]], optional): The IoU threshold(s) at which
                to compute mAP and classification metrics. Can be a single float or a list of floats.
                Defaults to 0.5.
            predict_kwargs (Optional[dict[str, Any]], optional): A dictionary of keyword arguments
                to be passed directly to the model's `predict` method during evaluation.
                This is useful for model-specific parameters like `text_prompt` for zero-shot models.
                Defaults to None.
            **kwargs: Additional keyword arguments to be passed to the underlying
                      `object_classification_performance` function (e.g., `average`, `zero_division`).

        Returns:
            dict: A dictionary containing a comprehensive set of object detection and
                  classification metrics, including per-IoU-threshold classification performance,
                  and a detailed object-level performance table.
        """
        if predict_kwargs is None:
            predict_kwargs = {}

        # Set classes for GroundingDINO before evaluation
        if isinstance(self.model, GroundingDINOModel):
            if "text_prompt" in predict_kwargs:
                self.model.set_classes(predict_kwargs["text_prompt"])

        print("Running object classification evaluation...")

        if isinstance(self.model, (HuggingFaceFeatureExtractor, UntrainedFeatureExtractor)):
            print("Warning: Object classification evaluation is not supported for pure feature extractors.")
            return {}

        if isinstance(self.model, (GroundingDINOModel, YOLOWorldModel)):
            if "text_prompt" not in predict_kwargs and not self.model.get_classes():
                print("Warning: Zero-shot model has no classes set. Please provide a 'text_prompt' in 'predict_kwargs'.")
                return {}

        gt_boxes, gt_labels, gt_image_ids, gt_label_names = [], [], [], []
        pred_results_with_probs = []  # Full prediction result per image
        # Initialize model_classes before the loop.
        model_classes: list[str] = []
        if isinstance(self.model, (GroundingDINOModel)):
            if hasattr(self.model, "get_classes") and callable(self.model.get_classes):
                raw_model_classes = self.model.get_classes()
                if isinstance(raw_model_classes, dict):
                    model_classes = list(raw_model_classes.values())
                else:
                    model_classes = raw_model_classes
        class_name_to_idx: dict[str, int] = {}
        idx_to_name: dict[int, str] = {}

        print("Extracting ground truth and making predictions...")
        predict_kwargs_for_call = {**predict_kwargs, "include_full_probabilities": True}

        for i, item in enumerate(tqdm(dataset)):
            # Make the first prediction to set classes for YOLOWorld
            results = self.model.predict(item["image"], verbose=False, **predict_kwargs_for_call)
            pred_results_with_probs.append(results)

            if not model_classes:
                if not hasattr(self.model, "get_classes") or not callable(self.model.get_classes):
                    print("Warning: Model does not have a 'get_classes' method for class mapping. Skipping evaluation.")
                    return {}

                raw_model_classes = self.model.get_classes()
                if isinstance(raw_model_classes, dict):
                    model_classes = list(raw_model_classes.values())
                else:
                    model_classes = raw_model_classes

                if not model_classes:
                    print("Warning: Model returned an empty class list. Cannot proceed with classification-dependent metrics.")
                    return {}

                class_name_to_idx = {v: k for k, v in enumerate(model_classes)}
                idx_to_name = dict(enumerate(model_classes))

            # --- Extract Ground Truth ---
            if "objects" in item and "bbox" in item["objects"] and "category" in item["objects"]:
                for j in range(len(item["objects"]["category"])):
                    label_name = item["objects"]["category"][j]
                    gt_label_names.append(label_name)
                    bbox = item["objects"]["bbox"][j]
                    x1, y1, w, h = bbox
                    x2 = x1 + w
                    y2 = y1 + h
                    gt_boxes.append([x1, y1, x2, y2])
                    gt_labels.append(class_name_to_idx.get(label_name, -1))

                    gt_image_ids.append(i)

        # The GT and raw prediction data is prepared. Now run the core evaluation logic.
        performance_results = object_classification_performance(
            np.array(gt_boxes),
            gt_labels,
            gt_image_ids,
            pred_results_with_probs,
            gt_label_names=gt_label_names,
            iou_thresholds=iou_thresholds,
            model_classes=model_classes,
            idx_to_name=idx_to_name,
            **kwargs,
        )

        # Apply naming to the mAP results
        if "per_class_AP_at_last_iou" in performance_results:
            class_aps = performance_results["per_class_AP_at_last_iou"]
            named_class_aps = {idx_to_name.get(class_id, class_id): ap for class_id, ap in class_aps.items()}
            performance_results["per_class_AP_at_last_iou"] = named_class_aps

        return performance_results

    def embeddings(
        self,
        dataset,
        evaluation_level: str = "image",
        use_umap: bool = True,
        extract_kwargs: Optional[dict[str, Any]] = None,
        batch_size: int = 32,
        **kwargs,
    ):
        """Evaluates the quality of the model's feature embeddings.

        This method extracts feature embeddings from the provided dataset. It can operate
        at two levels: 'image' (extracting one embedding per image) or 'object'
        (extracting an embedding for each annotated object in each image). The quality of
        these embeddings is then assessed using clustering algorithms and a suite of
        internal and external validation metrics.

        Args:
            dataset (iterable): An iterable dataset where each item contains an 'image' key.
                                For 'object' level evaluation, items should also contain 'objects'
                                with 'bbox' and 'category' keys.
            evaluation_level (str, optional): The level at which to evaluate embeddings.
                                              Can be "image" or "object". Defaults to "image".
            use_umap (bool, optional): If True, applies UMAP for dimensionality reduction
                                       before clustering. Defaults to True.
            extract_kwargs (Optional[dict[str, Any]], optional): Keyword arguments to be passed
                to the model's `extract_features` method. Defaults to None.
            batch_size (int, optional): The batch size for GPU distance matrix calculation.
                                        Defaults to 32.
            **kwargs: Additional keyword arguments to be passed to the `EmbeddingEvaluator`.
                      See `ibbi.evaluate.embeddings.EmbeddingEvaluator` for more details.

        Returns:
            dict: A dictionary containing the results of the embedding evaluation, including
                  clustering metrics and optionally, correlation with external data.
        """
        if extract_kwargs is None:
            extract_kwargs = {}
        if evaluation_level not in ["image", "object"]:
            raise ValueError("evaluation_level must be either 'image' or 'object'.")

        print(f"Extracting embeddings for evaluation at the '{evaluation_level}' level...")
        embeddings_list = []
        true_labels = []
        valid_indices = []

        # Pre-calculate label mappings for efficiency
        unique_labels_lst = list(set(cat for item in dataset for cat in item.get("objects", {}).get("category", [])))
        unique_labels = sorted(unique_labels_lst)
        name_to_idx = {name: i for i, name in enumerate(unique_labels)}
        idx_to_name = dict(enumerate(unique_labels))

        for i, item in enumerate(tqdm(dataset)):
            if evaluation_level == "image":
                embedding = self.model.extract_features(item["image"], **extract_kwargs)
                if embedding is not None:
                    embeddings_list.append(embedding)
                    if "objects" in item and "category" in item["objects"] and item["objects"]["category"]:
                        label_name = item["objects"]["category"][0]
                        if label_name in name_to_idx:
                            true_labels.append(name_to_idx[label_name])
                            valid_indices.append(len(embeddings_list) - 1)

            elif evaluation_level == "object":
                if "objects" not in item or "bbox" not in item["objects"] or "category" not in item["objects"]:
                    continue

                original_image = item["image"]
                for j, bbox in enumerate(item["objects"]["bbox"]):
                    x, y, w, h = bbox
                    if w > 0 and h > 0:
                        cropped_image = original_image.crop((x, y, x + w, y + h))
                        embedding = self.model.extract_features(cropped_image, **extract_kwargs)
                        if embedding is not None:
                            embeddings_list.append(embedding)
                            label_name = item["objects"]["category"][j]
                            if label_name in name_to_idx:
                                true_labels.append(name_to_idx[label_name])
                                valid_indices.append(len(embeddings_list) - 1)

        if not embeddings_list:
            print("Warning: Could not extract any valid embeddings from the dataset.")
            return {}

        embeddings = np.array([emb.cpu().numpy().flatten() for emb in embeddings_list])
        evaluator = EmbeddingEvaluator(embeddings, use_umap=use_umap, **kwargs)

        results = {}
        results["internal_cluster_validation"] = evaluator.evaluate_cluster_structure()

        if true_labels:
            true_labels = np.array(true_labels)
            results["external_cluster_validation"] = evaluator.evaluate_against_truth(true_labels)
            results["sample_results"] = evaluator.get_sample_results(true_labels, label_map=idx_to_name)

            try:
                if len(np.unique(true_labels)) >= 3:
                    valid_embeddings = embeddings[valid_indices]
                    evaluator_for_mantel = EmbeddingEvaluator(valid_embeddings, use_umap=False)
                    mantel_corr, p_val, n, per_class_df = evaluator_for_mantel.compare_to_distance_matrix(
                        true_labels, label_map=idx_to_name, batch_size=batch_size
                    )
                    results["mantel_correlation"] = {"r": mantel_corr, "p_value": p_val, "n_items": n}
                    results["per_class_centroids"] = per_class_df
                else:
                    print("Not enough unique labels in the dataset subset to run the Mantel test.")
            except (ImportError, FileNotFoundError, ValueError) as e:
                print(f"Could not run Mantel test: {e}")
        else:
            print("Dataset does not have the required 'objects' and 'category' fields for external validation.")
            results["sample_results"] = evaluator.get_sample_results()

        return results

__init__(model)

Initializes the Evaluator with a specific model.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model | ModelType | The model to be evaluated. This should be an instance of one of the model wrapper classes covered by ModelType; the evaluation methods call its predict and extract_features methods. | required |
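The following sketch shows one way to check up front which evaluations a model object could serve, based on the two methods named above. `supported_evaluations` and `FeatureOnlyStub` are hypothetical and not part of ibbi; the real Evaluator performs its own checks on the model's class.

```python
def supported_evaluations(model) -> list[str]:
    """Hypothetical helper: list the Evaluator methods a model object could serve."""
    supported = []
    if callable(getattr(model, "predict", None)):
        supported.append("object_classification")
    if callable(getattr(model, "extract_features", None)):
        supported.append("embeddings")
    return supported


class FeatureOnlyStub:
    """Stand-in for a pure feature extractor (not a real ibbi class)."""

    def extract_features(self, image):
        return [0.1, 0.2, 0.3]


print(supported_evaluations(FeatureOnlyStub()))
```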
Source code in src\ibbi\evaluate\__init__.py
def __init__(self, model: ModelType):
    """Initializes the Evaluator with a specific model.

    Args:
        model (ModelType): The model to be evaluated. This should be an instance of one of
                           the model wrapper classes covered by `ModelType`; the evaluation
                           methods call its `predict` and `extract_features` methods.
    """
    self.model = model

embeddings(dataset, evaluation_level='image', use_umap=True, extract_kwargs=None, batch_size=32, **kwargs)

Evaluates the quality of the model's feature embeddings.

This method extracts feature embeddings from the provided dataset. It can operate at two levels: 'image' (extracting one embedding per image) or 'object' (extracting an embedding for each annotated object in each image). The quality of these embeddings is then assessed using clustering algorithms and a suite of internal and external validation metrics.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| dataset | iterable | An iterable dataset where each item contains an 'image' key. For 'object' level evaluation, items should also contain 'objects' with 'bbox' and 'category' keys. | required |
| evaluation_level | str | The level at which to evaluate embeddings. Can be "image" or "object". Defaults to "image". | 'image' |
| use_umap | bool | If True, applies UMAP for dimensionality reduction before clustering. Defaults to True. | True |
| extract_kwargs | Optional[dict[str, Any]] | Keyword arguments to be passed to the model's extract_features method. Defaults to None. | None |
| batch_size | int | The batch size for GPU distance matrix calculation. Defaults to 32. | 32 |
| **kwargs | | Additional keyword arguments to be passed to the EmbeddingEvaluator. See ibbi.evaluate.embeddings.EmbeddingEvaluator for more details. | {} |
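To illustrate the dataset layout and the two evaluation levels, here is a small sketch that mirrors the label mapping and the bounding-box filter from the source listing. The toy items, the category names, and `count_embeddings` are made up for illustration, and the counts assume every feature extraction succeeds.

```python
# Toy dataset items in the layout described above; the "image" values are
# placeholders because no pixels are needed for this illustration.
dataset = [
    {
        "image": "img0",
        "objects": {
            "bbox": [[0, 0, 10, 10], [5, 5, 0, 8]],  # second box has zero width
            "category": ["b_species", "a_species"],
        },
    },
    {"image": "img1", "objects": {"bbox": [[2, 2, 4, 4]], "category": ["a_species"]}},
    {"image": "img2"},  # no annotations
]

# Sorted unique category names mapped to integer indices, as in the source listing.
unique_labels = sorted(
    {cat for item in dataset for cat in item.get("objects", {}).get("category", [])}
)
name_to_idx = {name: i for i, name in enumerate(unique_labels)}


def count_embeddings(items, evaluation_level="image"):
    """Hypothetical helper: how many embeddings each level would request."""
    if evaluation_level not in ("image", "object"):
        raise ValueError("evaluation_level must be either 'image' or 'object'.")
    if evaluation_level == "image":
        return len(items)  # one embedding per image
    # One embedding per annotated box with positive width and height.
    return sum(
        1
        for item in items
        for (_x, _y, w, h) in item.get("objects", {}).get("bbox", [])
        if w > 0 and h > 0
    )
```

At the 'image' level the source listing takes the label from the first entry of an item's 'category' list, so images with several differently labelled objects are represented by that first label only.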

Returns:

| Type | Description |
| --- | --- |
| dict | A dictionary containing the results of the embedding evaluation, including clustering metrics and, optionally, correlation with external data. |
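The returned dictionary does not always contain the same sections: according to the source listing, some keys appear only when ground-truth labels are available or when the Mantel test runs. The sketch below reads such a result defensively. `summarize_embedding_results` is a hypothetical helper, and the example values are placeholders rather than real output.

```python
def summarize_embedding_results(results: dict) -> list[str]:
    """Hypothetical helper: describe which sections an embeddings() result contains."""
    if not results:
        # The method returns {} when no embeddings could be extracted.
        return ["no embeddings could be extracted"]
    lines = []
    for key in (
        "internal_cluster_validation",
        "external_cluster_validation",
        "sample_results",
        "per_class_centroids",
    ):
        if key in results:
            lines.append(f"{key}: present")
    mantel = results.get("mantel_correlation")
    if mantel is not None:
        lines.append(
            f"mantel r={mantel['r']:.2f}, p={mantel['p_value']:.3f}, n={mantel['n_items']}"
        )
    return lines


# Made-up values purely to exercise the helper.
example = {
    "internal_cluster_validation": {},
    "sample_results": {},
    "mantel_correlation": {"r": 0.5, "p_value": 0.01, "n_items": 12},
}
summary = summarize_embedding_results(example)
```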

Source code in src\ibbi\evaluate\__init__.py
def embeddings(
    self,
    dataset,
    evaluation_level: str = "image",
    use_umap: bool = True,
    extract_kwargs: Optional[dict[str, Any]] = None,
    batch_size: int = 32,
    **kwargs,
):
    """Evaluates the quality of the model's feature embeddings.

    This method extracts feature embeddings from the provided dataset. It can operate
    at two levels: 'image' (extracting one embedding per image) or 'object'
    (extracting an embedding for each annotated object in each image). The quality of
    these embeddings is then assessed using clustering algorithms and a suite of
    internal and external validation metrics.

    Args:
        dataset (iterable): An iterable dataset where each item contains an 'image' key.
                            For 'object' level evaluation, items should also contain 'objects'
                            with 'bbox' and 'category' keys.
        evaluation_level (str, optional): The level at which to evaluate embeddings.
                                          Can be "image" or "object". Defaults to "image".
        use_umap (bool, optional): If True, applies UMAP for dimensionality reduction
                                   before clustering. Defaults to True.
        extract_kwargs (Optional[dict[str, Any]], optional): Keyword arguments to be passed
            to the model's `extract_features` method. Defaults to None.
        batch_size (int, optional): The batch size for GPU distance matrix calculation.
                                    Defaults to 32.
        **kwargs: Additional keyword arguments to be passed to the `EmbeddingEvaluator`.
                  See `ibbi.evaluate.embeddings.EmbeddingEvaluator` for more details.

    Returns:
        dict: A dictionary containing the results of the embedding evaluation, including
              clustering metrics and optionally, correlation with external data.
    """
    if extract_kwargs is None:
        extract_kwargs = {}
    if evaluation_level not in ["image", "object"]:
        raise ValueError("evaluation_level must be either 'image' or 'object'.")

    print(f"Extracting embeddings for evaluation at the '{evaluation_level}' level...")
    embeddings_list = []
    true_labels = []
    valid_indices = []

    # Pre-calculate label mappings for efficiency
    unique_labels_lst = list(set(cat for item in dataset for cat in item.get("objects", {}).get("category", [])))
    unique_labels = sorted(unique_labels_lst)
    name_to_idx = {name: i for i, name in enumerate(unique_labels)}
    idx_to_name = dict(enumerate(unique_labels))

    for i, item in enumerate(tqdm(dataset)):
        if evaluation_level == "image":
            embedding = self.model.extract_features(item["image"], **extract_kwargs)
            if embedding is not None:
                embeddings_list.append(embedding)
                if "objects" in item and "category" in item["objects"] and item["objects"]["category"]:
                    label_name = item["objects"]["category"][0]
                    if label_name in name_to_idx:
                        true_labels.append(name_to_idx[label_name])
                        valid_indices.append(len(embeddings_list) - 1)

        elif evaluation_level == "object":
            if "objects" not in item or "bbox" not in item["objects"] or "category" not in item["objects"]:
                continue

            original_image = item["image"]
            for j, bbox in enumerate(item["objects"]["bbox"]):
                x, y, w, h = bbox
                if w > 0 and h > 0:
                    cropped_image = original_image.crop((x, y, x + w, y + h))
                    embedding = self.model.extract_features(cropped_image, **extract_kwargs)
                    if embedding is not None:
                        embeddings_list.append(embedding)
                        label_name = item["objects"]["category"][j]
                        if label_name in name_to_idx:
                            true_labels.append(name_to_idx[label_name])
                            valid_indices.append(len(embeddings_list) - 1)

    if not embeddings_list:
        print("Warning: Could not extract any valid embeddings from the dataset.")
        return {}

    embeddings = np.array([emb.cpu().numpy().flatten() for emb in embeddings_list])
    evaluator = EmbeddingEvaluator(embeddings, use_umap=use_umap, **kwargs)

    results = {}
    results["internal_cluster_validation"] = evaluator.evaluate_cluster_structure()

    if true_labels:
        true_labels = np.array(true_labels)
        results["external_cluster_validation"] = evaluator.evaluate_against_truth(true_labels)
        results["sample_results"] = evaluator.get_sample_results(true_labels, label_map=idx_to_name)

        try:
            if len(np.unique(true_labels)) >= 3:
                valid_embeddings = embeddings[valid_indices]
                evaluator_for_mantel = EmbeddingEvaluator(valid_embeddings, use_umap=False)
                mantel_corr, p_val, n, per_class_df = evaluator_for_mantel.compare_to_distance_matrix(
                    true_labels, label_map=idx_to_name, batch_size=batch_size
                )
                results["mantel_correlation"] = {"r": mantel_corr, "p_value": p_val, "n_items": n}
                results["per_class_centroids"] = per_class_df
            else:
                print("Not enough unique labels in the dataset subset to run the Mantel test.")
        except (ImportError, FileNotFoundError, ValueError) as e:
            print(f"Could not run Mantel test: {e}")
    else:
        print("Dataset does not have the required 'objects' and 'category' fields for external validation.")
        results["sample_results"] = evaluator.get_sample_results()

    return results

object_classification(dataset, iou_thresholds=0.5, predict_kwargs=None, **kwargs)

Runs a comprehensive object detection and classification performance analysis.

This method assesses the model's ability to both accurately localize and correctly classify objects within a dataset. It iterates through the provided dataset, gathering ground truth information and generating model predictions. These are then passed to the object_classification_performance function to compute a detailed suite of metrics.

The evaluation provides a holistic view of performance, combining traditional object detection metrics (like mAP) with a full suite of classification metrics for each IoU threshold.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| dataset | iterable | An iterable dataset where each item is a dictionary-like object containing at least an 'image' key. For evaluation, items should also contain an 'objects' key, which is a dictionary with 'bbox' and 'category' keys. | required |
| iou_thresholds | Union[float, list[float]] | The IoU threshold(s) at which to compute mAP and classification metrics. Can be a single float or a list of floats. Defaults to 0.5. | 0.5 |
| predict_kwargs | Optional[dict[str, Any]] | A dictionary of keyword arguments to be passed directly to the model's predict method during evaluation. This is useful for model-specific parameters like text_prompt for zero-shot models. Defaults to None. | None |
| **kwargs | | Additional keyword arguments to be passed to the underlying object_classification_performance function (e.g., average, zero_division). | {} |
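The sketch below illustrates two ideas behind these parameters: the conversion of ground-truth boxes from [x, y, width, height] to corner format, which the source listing performs, and the comparison of an overlap against one or more IoU thresholds. The helper names are hypothetical, and the matching rule (IoU greater than or equal to the threshold) is an assumption about object_classification_performance, not something this page confirms.

```python
from typing import Union


def xywh_to_xyxy(bbox):
    """Convert [x, y, width, height] to [x1, y1, x2, y2], as the source listing does."""
    x1, y1, w, h = bbox
    return [x1, y1, x1 + w, y1 + h]


def iou(box_a, box_b):
    """Intersection over union of two [x1, y1, x2, y2] boxes."""
    ix1, iy1 = max(box_a[0], box_b[0]), max(box_a[1], box_b[1])
    ix2, iy2 = min(box_a[2], box_b[2]), min(box_a[3], box_b[3])
    inter = max(0, ix2 - ix1) * max(0, iy2 - iy1)
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def matches_at(iou_value, iou_thresholds: Union[float, list[float]] = 0.5):
    """Per threshold, report whether an overlap counts as a match (assumed rule)."""
    if isinstance(iou_thresholds, (int, float)):
        thresholds = [iou_thresholds]
    else:
        thresholds = list(iou_thresholds)
    return {t: iou_value >= t for t in thresholds}


gt = xywh_to_xyxy([10, 10, 20, 20])  # ground-truth box in corner format
pred = [15, 15, 35, 35]              # made-up predicted box
overlap = iou(gt, pred)              # 225 / 575, roughly 0.39
per_threshold = matches_at(overlap, [0.25, 0.5])
```

With these made-up boxes the prediction would count as a match at a threshold of 0.25 but not at 0.5, which is why passing a list of thresholds can give a fuller picture than a single value.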

Returns:

| Type | Description |
| --- | --- |
| dict | A dictionary containing a comprehensive set of object detection and classification metrics, including per-IoU-threshold classification performance, and a detailed object-level performance table. |
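The bookkeeping around class names can be sketched in isolation. The snippet below mirrors three steps from the source listing: mapping ground-truth names to class indices (unknown names become -1), renaming per-class AP entries from indices to names, and merging caller-supplied predict_kwargs with the flag the evaluator adds. The class names, AP values, and prompt string are made up for illustration.

```python
model_classes = ["species_a", "species_b"]  # hypothetical class names
idx_to_name = dict(enumerate(model_classes))
class_name_to_idx = {name: idx for idx, name in enumerate(model_classes)}

# Ground-truth names not known to the model map to -1, as in the listing.
gt_names = ["species_b", "species_c"]
gt_labels = [class_name_to_idx.get(name, -1) for name in gt_names]

# Made-up AP values keyed by class index; index 7 has no name and is kept as-is.
per_class_ap = {0: 0.9, 1: 0.4, 7: 0.0}
named_ap = {idx_to_name.get(class_id, class_id): ap for class_id, ap in per_class_ap.items()}

# Caller-supplied predict kwargs are merged with the flag the evaluator adds.
predict_kwargs = {"text_prompt": "beetle"}  # placeholder prompt value
call_kwargs = {**predict_kwargs, "include_full_probabilities": True}
```

Because the method returns an empty dictionary when evaluation is not possible (for example, for pure feature extractors), callers may want to check for an empty result before reading keys such as per_class_AP_at_last_iou.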

Source code in src\ibbi\evaluate\__init__.py
def object_classification(
    self, dataset, iou_thresholds: Union[float, list[float]] = 0.5, predict_kwargs: Optional[dict[str, Any]] = None, **kwargs
):
    """Runs a comprehensive object detection and classification performance analysis.

    This method assesses the model's ability to both accurately localize and correctly
    classify objects within a dataset. It iterates through the provided dataset, gathering
    ground truth information and generating model predictions. These are then passed to the
    `object_classification_performance` function to compute a detailed suite of metrics.

    The evaluation provides a holistic view of performance, combining traditional object
    detection metrics (like mAP) with a full suite of classification metrics for each IoU
    threshold.

    Args:
        dataset (iterable): An iterable dataset where each item is a dictionary-like object
                            containing at least an 'image' key. For evaluation, items should
                            also contain an 'objects' key, which is a dictionary with 'bbox'
                            and 'category' keys.
        iou_thresholds (Union[float, list[float]], optional): The IoU threshold(s) at which
            to compute mAP and classification metrics. Can be a single float or a list of floats.
            Defaults to 0.5.
        predict_kwargs (Optional[dict[str, Any]], optional): A dictionary of keyword arguments
            to be passed directly to the model's `predict` method during evaluation.
            This is useful for model-specific parameters like `text_prompt` for zero-shot models.
            Defaults to None.
        **kwargs: Additional keyword arguments to be passed to the underlying
                  `object_classification_performance` function (e.g., `average`, `zero_division`).

    Returns:
        dict: A dictionary containing a comprehensive set of object detection and
              classification metrics, including per-IoU-threshold classification performance,
              and a detailed object-level performance table.
    """
    if predict_kwargs is None:
        predict_kwargs = {}

    # Set classes for GroundingDINO before evaluation
    if isinstance(self.model, GroundingDINOModel):
        if "text_prompt" in predict_kwargs:
            self.model.set_classes(predict_kwargs["text_prompt"])

    print("Running object classification evaluation...")

    if isinstance(self.model, (HuggingFaceFeatureExtractor, UntrainedFeatureExtractor)):
        print("Warning: Object classification evaluation is not supported for pure feature extractors.")
        return {}

    if isinstance(self.model, (GroundingDINOModel, YOLOWorldModel)):
        if "text_prompt" not in predict_kwargs and not self.model.get_classes():
            print("Warning: Zero-shot model has no classes set. Please provide a 'text_prompt' in 'predict_kwargs'.")
            return {}

    gt_boxes, gt_labels, gt_image_ids, gt_label_names = [], [], [], []
    pred_results_with_probs = []  # Full prediction result per image
    # Initialize model_classes before the loop.
    model_classes: list[str] = []
    if isinstance(self.model, (GroundingDINOModel)):
        if hasattr(self.model, "get_classes") and callable(self.model.get_classes):
            raw_model_classes = self.model.get_classes()
            if isinstance(raw_model_classes, dict):
                model_classes = list(raw_model_classes.values())
            else:
                model_classes = raw_model_classes
    class_name_to_idx: dict[str, int] = {}
    idx_to_name: dict[int, str] = {}

    print("Extracting ground truth and making predictions...")
    predict_kwargs_for_call = {**predict_kwargs, "include_full_probabilities": True}

    for i, item in enumerate(tqdm(dataset)):
        # Make the first prediction to set classes for YOLOWorld
        results = self.model.predict(item["image"], verbose=False, **predict_kwargs_for_call)
        pred_results_with_probs.append(results)

        if not model_classes:
            if not hasattr(self.model, "get_classes") or not callable(self.model.get_classes):
                print("Warning: Model does not have a 'get_classes' method for class mapping. Skipping evaluation.")
                return {}

            raw_model_classes = self.model.get_classes()
            if isinstance(raw_model_classes, dict):
                model_classes = list(raw_model_classes.values())
            else:
                model_classes = raw_model_classes

            if not model_classes:
                print("Warning: Model returned an empty class list. Cannot proceed with classification-dependent metrics.")
                return {}

            class_name_to_idx = {v: k for k, v in enumerate(model_classes)}
            idx_to_name = dict(enumerate(model_classes))

        # --- Extract Ground Truth ---
        if "objects" in item and "bbox" in item["objects"] and "category" in item["objects"]:
            for j in range(len(item["objects"]["category"])):
                label_name = item["objects"]["category"][j]
                gt_label_names.append(label_name)
                bbox = item["objects"]["bbox"][j]
                x1, y1, w, h = bbox
                x2 = x1 + w
                y2 = y1 + h
                gt_boxes.append([x1, y1, x2, y2])
                gt_labels.append(class_name_to_idx.get(label_name, -1))

                gt_image_ids.append(i)

    # The GT and raw prediction data is prepared. Now run the core evaluation logic.
    performance_results = object_classification_performance(
        np.array(gt_boxes),
        gt_labels,
        gt_image_ids,
        pred_results_with_probs,
        gt_label_names=gt_label_names,
        iou_thresholds=iou_thresholds,
        model_classes=model_classes,
        idx_to_name=idx_to_name,
        **kwargs,
    )

    # Apply naming to the mAP results
    if "per_class_AP_at_last_iou" in performance_results:
        class_aps = performance_results["per_class_AP_at_last_iou"]
        named_class_aps = {idx_to_name.get(class_id, class_id): ap for class_id, ap in class_aps.items()}
        performance_results["per_class_AP_at_last_iou"] = named_class_aps

    return performance_results
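
The ground-truth loop in the listing above converts each box from [x, y, width, height] to corner format and maps each category name to a class index, falling back to -1 for names the model does not know. A minimal standalone sketch of that step follows; the helper name `to_ground_truth` and the sample data are hypothetical and are not part of the ibbi API.

```python
def to_ground_truth(objects, class_name_to_idx):
    """Convert [x, y, w, h] boxes to [x1, y1, x2, y2] and map names to indices.

    Hypothetical helper for illustration only; not part of ibbi.
    """
    boxes, labels = [], []
    for bbox, name in zip(objects["bbox"], objects["category"]):
        x1, y1, w, h = bbox
        boxes.append([x1, y1, x1 + w, y1 + h])
        # Category names missing from the model's class list map to -1.
        labels.append(class_name_to_idx.get(name, -1))
    return boxes, labels


# Made-up class list and annotations shaped like item["objects"].
model_classes = ["species_a", "species_b"]
class_name_to_idx = {name: idx for idx, name in enumerate(model_classes)}
objects = {
    "bbox": [[10, 20, 30, 40], [0, 0, 5, 5]],
    "category": ["species_b", "species_c"],
}

boxes, labels = to_ground_truth(objects, class_name_to_idx)
print(boxes)
print(labels)
```

A label of -1 marks a ground-truth object whose category the model cannot predict, so it is worth checking for such entries before interpreting per-class metrics.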

Explainer

A wrapper for LIME and SHAP explainability methods.

This class provides a simple interface to generate model explanations using either LIME (Local Interpretable Model-agnostic Explanations) or SHAP (SHapley Additive exPlanations). It is designed to work with any model created using ibbi.create_model.

Parameters:

Name Type Description Default
model ModelType

An instantiated model from ibbi.create_model.

required
Source code in src\ibbi\explain\__init__.py
class Explainer:
    """A wrapper for LIME and SHAP explainability methods.

    This class provides a simple interface to generate model explanations using
    either LIME (Local Interpretable Model-agnostic Explanations) or SHAP
    (SHapley Additive exPlanations). It is designed to work with any model
    created using `ibbi.create_model`.

    Args:
        model (ModelType): An instantiated model from `ibbi.create_model`.
    """

    def __init__(self, model: ModelType):
        """A wrapper for LIME and SHAP explainability methods.

        This class provides a simple interface to generate model explanations using
        either LIME (Local Interpretable Model-agnostic Explanations) or SHAP
        (SHapley Additive exPlanations). It is designed to work with any model
        created using `ibbi.create_model`.

        Args:
            model (ModelType): An instantiated model from `ibbi.create_model`.
        """
        self.model = model

    def with_lime(self, image, **kwargs):
        """Generates a LIME explanation for a single image.

        LIME provides a local, intuitive explanation by showing which parts of an image
        contributed most to a specific prediction. This method is a wrapper around
        the `explain_with_lime` function.

        Args:
            image (PIL.Image.Image): The single image to be explained.
            **kwargs: Additional keyword arguments to be passed to the underlying
                      `ibbi.explain.lime.explain_with_lime` function. Common arguments
                      include `image_size`, `batch_size`, `num_samples`, `top_labels`,
                      and `num_features`.

        Returns:
            tuple[lime_image.ImageExplanation, PIL.Image.Image]: A tuple containing the LIME
            explanation object and the original image. The explanation object can be
            visualized using `ibbi.plot_lime_explanation`.
        """
        return explain_with_lime(self.model, image, **kwargs)

    def with_shap(self, explain_dataset, background_dataset, **kwargs):
        """Generates SHAP explanations for a set of images.

        SHAP (SHapley Additive exPlanations) provides robust, theoretically-grounded
        explanations by attributing a model's prediction to its input features. This
        method is a wrapper around the `explain_with_shap` function and requires a
        background dataset to integrate out features.

        Args:
            explain_dataset (list): A list of dictionaries, where each dictionary
                                    represents an image to be explained (e.g., `[{'image': img1}, {'image': img2}]`).
            background_dataset (list): A list of dictionaries representing a background dataset,
                                       used by SHAP to simulate feature absence.
            **kwargs: Additional keyword arguments to be passed to the underlying
                      `ibbi.explain.shap.explain_with_shap` function. Common arguments
                      include `num_explain_samples`, `max_evals`, `image_size`, and `text_prompt`.

        Returns:
            shap.Explanation: A SHAP Explanation object containing the SHAP values for each
                              image and each class. This object can be visualized using
                              `ibbi.plot_shap_explanation`.
        """
        return explain_with_shap(self.model, explain_dataset, background_dataset, **kwargs)
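
Both methods are thin wrappers: they store nothing beyond the model and pass every keyword argument through unchanged. The sketch below shows that delegation pattern with stand-ins; `ExplainerSketch` and `explain_with_lime_stub` are hypothetical names, and the stub does not run LIME.

```python
def explain_with_lime_stub(model, image, **kwargs):
    """Stand-in for the real explanation function; it only echoes its inputs."""
    return {"model": model, "image": image, "options": kwargs}


class ExplainerSketch:
    """Mirrors the delegation pattern of Explainer; not the real class."""

    def __init__(self, model):
        self.model = model

    def with_lime(self, image, **kwargs):
        # Keyword arguments reach the underlying function unmodified.
        return explain_with_lime_stub(self.model, image, **kwargs)


explainer = ExplainerSketch(model="dummy-model")
result = explainer.with_lime("dummy-image", num_samples=500, top_labels=1)
print(result["options"])
```

Because nothing is validated in the wrapper, a misspelled keyword argument would only surface inside the underlying function.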

__init__(model)

A wrapper for LIME and SHAP explainability methods.

This class provides a simple interface to generate model explanations using either LIME (Local Interpretable Model-agnostic Explanations) or SHAP (SHapley Additive exPlanations). It is designed to work with any model created using ibbi.create_model.

Parameters:

Name Type Description Default
model ModelType

An instantiated model from ibbi.create_model.

required
Source code in src\ibbi\explain\__init__.py
def __init__(self, model: ModelType):
    """A wrapper for LIME and SHAP explainability methods.

    This class provides a simple interface to generate model explanations using
    either LIME (Local Interpretable Model-agnostic Explanations) or SHAP
    (SHapley Additive exPlanations). It is designed to work with any model
    created using `ibbi.create_model`.

    Args:
        model (ModelType): An instantiated model from `ibbi.create_model`.
    """
    self.model = model

with_lime(image, **kwargs)

Generates a LIME explanation for a single image.

LIME provides a local, intuitive explanation by showing which parts of an image contributed most to a specific prediction. This method is a wrapper around the explain_with_lime function.

Parameters:

Name Type Description Default
image Image

The single image to be explained.

required
**kwargs

Additional keyword arguments to be passed to the underlying ibbi.explain.lime.explain_with_lime function. Common arguments include image_size, batch_size, num_samples, top_labels, and num_features.

{}

Returns:

Type Description

tuple[ImageExplanation, Image]

A tuple containing the LIME explanation object and the original image. The explanation object can be visualized using ibbi.plot_lime_explanation.

Source code in src\ibbi\explain\__init__.py
def with_lime(self, image, **kwargs):
    """Generates a LIME explanation for a single image.

    LIME provides a local, intuitive explanation by showing which parts of an image
    contributed most to a specific prediction. This method is a wrapper around
    the `explain_with_lime` function.

    Args:
        image (PIL.Image.Image): The single image to be explained.
        **kwargs: Additional keyword arguments to be passed to the underlying
                  `ibbi.explain.lime.explain_with_lime` function. Common arguments
                  include `image_size`, `batch_size`, `num_samples`, `top_labels`,
                  and `num_features`.

    Returns:
        tuple[lime_image.ImageExplanation, PIL.Image.Image]: A tuple containing the LIME
        explanation object and the original image. The explanation object can be
        visualized using `ibbi.plot_lime_explanation`.
    """
    return explain_with_lime(self.model, image, **kwargs)

with_shap(explain_dataset, background_dataset, **kwargs)

Generates SHAP explanations for a set of images.

SHAP (SHapley Additive exPlanations) provides robust, theoretically-grounded explanations by attributing a model's prediction to its input features. This method is a wrapper around the explain_with_shap function and requires a background dataset to integrate out features.

Parameters:

Name Type Description Default
explain_dataset list

A list of dictionaries, where each dictionary represents an image to be explained (e.g., [{'image': img1}, {'image': img2}]).

required
background_dataset list

A list of dictionaries representing a background dataset, used by SHAP to simulate feature absence.

required
**kwargs

Additional keyword arguments to be passed to the underlying ibbi.explain.shap.explain_with_shap function. Common arguments include num_explain_samples, max_evals, image_size, and text_prompt.

{}

Returns:

Type Description

Explanation

A SHAP Explanation object containing the SHAP values for each image and each class. This object can be visualized using ibbi.plot_shap_explanation.

Source code in src\ibbi\explain\__init__.py
def with_shap(self, explain_dataset, background_dataset, **kwargs):
    """Generates SHAP explanations for a set of images.

    SHAP (SHapley Additive exPlanations) provides robust, theoretically-grounded
    explanations by attributing a model's prediction to its input features. This
    method is a wrapper around the `explain_with_shap` function and requires a
    background dataset to integrate out features.

    Args:
        explain_dataset (list): A list of dictionaries, where each dictionary
                                represents an image to be explained (e.g., `[{'image': img1}, {'image': img2}]`).
        background_dataset (list): A list of dictionaries representing a background dataset,
                                   used by SHAP to simulate feature absence.
        **kwargs: Additional keyword arguments to be passed to the underlying
                  `ibbi.explain.shap.explain_with_shap` function. Common arguments
                  include `num_explain_samples`, `max_evals`, `image_size`, and `text_prompt`.

    Returns:
        shap.Explanation: A SHAP Explanation object containing the SHAP values for each
                          image and each class. This object can be visualized using
                          `ibbi.plot_shap_explanation`.
    """
    return explain_with_shap(self.model, explain_dataset, background_dataset, **kwargs)

clean_cache()

Removes the entire ibbi cache directory.

This function will permanently delete all downloaded models and datasets associated with the ibbi package's cache. This can be useful for forcing a fresh download of all assets or for freeing up disk space.

Source code in src\ibbi\utils\cache.py
def clean_cache():
    """Removes the entire ibbi cache directory.

    This function will permanently delete all downloaded models and datasets
    associated with the `ibbi` package's cache. This can be useful for forcing
    a fresh download of all assets or for freeing up disk space.
    """
    cache_dir = get_cache_dir()
    if cache_dir.exists():
        print(f"Removing cache directory: {cache_dir}")
        shutil.rmtree(cache_dir)
        print("Cache cleaned successfully.")
    else:
        print("Cache directory not found. Nothing to clean.")

create_model(model_name, pretrained=False, **kwargs)

Creates a model from a name or a task-based alias.

This function is the main entry point for instantiating models within the ibbi package. It uses a model registry to look up and create a model instance based on the provided model_name. Users can either specify the exact name of a model or use a convenient, task-based alias (e.g., "species_classifier").

When pretrained=True, the function will download the model's weights from the Hugging Face Hub and cache them locally for future use.

Parameters:

Name Type Description Default
model_name str

The name or alias of the model to create. A list of available model names and aliases can be obtained using ibbi.list_models().

required
pretrained bool

If True, loads pretrained weights for the model. Defaults to False.

False
**kwargs Any

Additional keyword arguments that will be passed to the underlying model's factory function. This allows for advanced customization.

{}

Returns:

Name Type Description
ModelType ModelType

An instantiated model object ready for prediction or feature extraction.

Raises:

Type Description
KeyError

If the provided model_name or its resolved alias is not found in the model registry.

Source code in src\ibbi\__init__.py
def create_model(model_name: str, pretrained: bool = False, **kwargs: Any) -> ModelType:
    """Creates a model from a name or a task-based alias.

    This function is the main entry point for instantiating models within the `ibbi`
    package. It uses a model registry to look up and create a model instance based on
    the provided `model_name`. Users can either specify the exact name of a model
    or use a convenient, task-based alias (e.g., "species_classifier").

    When `pretrained=True`, the function will download the model's weights from the
    Hugging Face Hub and cache them locally for future use.

    Args:
        model_name (str): The name or alias of the model to create. A list of available
                          model names and aliases can be obtained using `ibbi.list_models()`.
        pretrained (bool, optional): If True, loads pretrained weights for the model.
                                     Defaults to False.
        **kwargs (Any): Additional keyword arguments that will be passed to the underlying
                        model's factory function. This allows for advanced customization.

    Returns:
        ModelType: An instantiated model object ready for prediction or feature extraction.

    Raises:
        KeyError: If the provided `model_name` or its resolved alias is not found in the
                  model registry.
    """
    # Resolve alias if used
    if model_name in MODEL_ALIASES:
        model_name = MODEL_ALIASES[model_name]

    if model_name not in model_registry:
        available = ", ".join(model_registry.keys())
        aliases = ", ".join(MODEL_ALIASES.keys())
        raise KeyError(f"Model '{model_name}' not found. Available models: [{available}]. Available aliases: [{aliases}].")

    model_factory = model_registry[model_name]
    model = model_factory(pretrained=pretrained, **kwargs)
    return model
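
The lookup order in the listing above (alias first, then registry, then a KeyError) can be tried in isolation. The following is a minimal sketch under stated assumptions: the registry contents, the model name `demo_multiclass_model`, and the factory are invented for illustration, and only the alias `species_classifier` comes from the docstring.

```python
from typing import Any, Callable, Dict


def demo_factory(pretrained: bool = False, **kwargs: Any) -> dict:
    """Hypothetical factory; real factories return model wrapper objects."""
    return {"name": "demo_multiclass_model", "pretrained": pretrained, **kwargs}


# Hypothetical registry and alias table standing in for the package's own.
model_registry: Dict[str, Callable[..., Any]] = {"demo_multiclass_model": demo_factory}
MODEL_ALIASES = {"species_classifier": "demo_multiclass_model"}


def create_model_sketch(model_name: str, pretrained: bool = False, **kwargs: Any) -> Any:
    # Resolve an alias to its registered name, leaving other names untouched.
    model_name = MODEL_ALIASES.get(model_name, model_name)
    if model_name not in model_registry:
        raise KeyError(f"Model '{model_name}' not found.")
    return model_registry[model_name](pretrained=pretrained, **kwargs)


model = create_model_sketch("species_classifier", pretrained=True)
print(model)
```

Since the alias is resolved before the registry check, the error message for a broken alias reports the resolved name rather than the alias the caller typed.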

get_cache_dir()

Gets the cache directory for the ibbi package.

This function determines the appropriate directory for storing cached files, such as downloaded model weights and datasets. It first checks for a custom path set by the IBBI_CACHE_DIR environment variable. If the variable is not set, it defaults to a standard user cache location (~/.cache/ibbi).

The function also ensures that the cache directory exists, creating it if necessary.

Returns:

Name Type Description
Path Path

A pathlib.Path object representing the path to the cache directory.

Source code in src\ibbi\utils\cache.py
def get_cache_dir() -> Path:
    """Gets the cache directory for the ibbi package.

    This function determines the appropriate directory for storing cached files,
    such as downloaded model weights and datasets. It first checks for a custom path
    set by the `IBBI_CACHE_DIR` environment variable. If the variable is not set,
    it defaults to a standard user cache location (`~/.cache/ibbi`).

    The function also ensures that the cache directory exists, creating it if
    necessary.

    Returns:
        Path: A `pathlib.Path` object representing the path to the cache directory.
    """
    # Check for the custom environment variable
    cache_env_var = os.getenv("IBBI_CACHE_DIR")
    if cache_env_var:
        cache_dir = Path(cache_env_var)
    else:
        # Default to a user's home cache directory
        cache_dir = Path.home() / ".cache" / "ibbi"

    # Create the directory if it doesn't exist
    cache_dir.mkdir(parents=True, exist_ok=True)
    return cache_dir
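
To see the environment-variable override in action without touching the real cache, the lookup can be re-implemented and pointed at a temporary directory. This is a sketch only: `resolve_cache_dir` is a hypothetical re-implementation of the logic shown above, not the package function.

```python
import os
import tempfile
from pathlib import Path


def resolve_cache_dir() -> Path:
    """Re-implements the lookup order shown above, for illustration."""
    override = os.getenv("IBBI_CACHE_DIR")
    cache_dir = Path(override) if override else Path.home() / ".cache" / "ibbi"
    cache_dir.mkdir(parents=True, exist_ok=True)
    return cache_dir


with tempfile.TemporaryDirectory() as tmp:
    previous = os.environ.get("IBBI_CACHE_DIR")
    os.environ["IBBI_CACHE_DIR"] = str(Path(tmp) / "custom_cache")
    try:
        resolved = resolve_cache_dir()
        created = resolved.is_dir()
    finally:
        # Restore the environment so the example leaves no side effects.
        if previous is None:
            del os.environ["IBBI_CACHE_DIR"]
        else:
            os.environ["IBBI_CACHE_DIR"] = previous

print(resolved.name, created)
```

The variable is read on every call, so setting it before the first download is enough to redirect the cache.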

get_dataset(repo_id='IBBI-bio/ibbi_test_data', local_dir='ibbi_test_data', split='train', **kwargs)

Downloads and loads a dataset from the Hugging Face Hub.

This function facilitates the use of datasets hosted on the Hugging Face Hub by handling the download and caching process. It downloads the dataset to a local directory, and on subsequent calls, it will load the data directly from the local cache to save time and bandwidth.

Parameters:

Name Type Description Default
repo_id str

The repository ID of the dataset on the Hugging Face Hub. Defaults to "IBBI-bio/ibbi_test_data".

'IBBI-bio/ibbi_test_data'
local_dir str

The name of the local directory where the dataset will be stored. Defaults to "ibbi_test_data".

'ibbi_test_data'
split str

The name of the dataset split to load (e.g., "train", "test", "validation"). Defaults to "train".

'train'
**kwargs

Additional keyword arguments that will be passed directly to the datasets.load_dataset function. This allows for advanced customization of the data loading process.

{}

Returns:

Name Type Description
Dataset Dataset

The loaded dataset as a datasets.Dataset object.

Raises:

Type Description
TypeError

If the object loaded for the specified split is not of type datasets.Dataset.

Source code in src\ibbi\utils\data.py
def get_dataset(
    repo_id: str = "IBBI-bio/ibbi_test_data",
    local_dir: str = "ibbi_test_data",
    split: str = "train",
    **kwargs,
) -> Dataset:
    """Downloads and loads a dataset from the Hugging Face Hub.

    This function facilitates the use of datasets hosted on the Hugging Face Hub by
    handling the download and caching process. It downloads the dataset to a local
    directory, and on subsequent calls, it will load the data directly from the local
    cache to save time and bandwidth.

    Args:
        repo_id (str, optional): The repository ID of the dataset on the Hugging Face Hub.
                                 Defaults to "IBBI-bio/ibbi_test_data".
        local_dir (str, optional): The name of the local directory where the dataset will be stored.
                                   Defaults to "ibbi_test_data".
        split (str, optional): The name of the dataset split to load (e.g., "train", "test", "validation").
                               Defaults to "train".
        **kwargs: Additional keyword arguments that will be passed directly to the
                  `datasets.load_dataset` function. This allows for advanced customization
                  of the data loading process.

    Returns:
        Dataset: The loaded dataset as a `datasets.Dataset` object.

    Raises:
        TypeError: If the object loaded for the specified split is not of type `datasets.Dataset`.
    """
    dataset_path = Path(local_dir)

    if not dataset_path.exists():
        print(f"Dataset not found locally. Downloading from '{repo_id}' to '{dataset_path}'...")
        snapshot_download(repo_id=repo_id, repo_type="dataset", local_dir=str(dataset_path))
        print("Download complete.")
    else:
        print(f"Found cached dataset at '{dataset_path}'. Loading from disk.")

    try:
        dataset: Union[Dataset, DatasetDict, IterableDataset, IterableDatasetDict] = load_dataset(
            str(dataset_path), split=split, trust_remote_code=True, **kwargs
        )

        if not isinstance(dataset, Dataset):
            raise TypeError(f"Expected a 'Dataset' object for split '{split}', but received type '{type(dataset).__name__}'.")

        print("Dataset loaded successfully.")
        return dataset
    except Exception:
        print(f"Failed to load dataset from '{dataset_path}'. Please check the path and your connection.")
        raise
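
The caching behavior in the listing above depends only on whether the local directory exists. A minimal sketch of that download-if-missing check follows; `ensure_local_copy` and `fake_download` are hypothetical stand-ins, and no network access takes place.

```python
import tempfile
from pathlib import Path
from typing import Callable


def ensure_local_copy(local_dir: Path, download: Callable[[Path], None]) -> bool:
    """Run `download` only when `local_dir` does not exist yet.

    Returns True when a download was triggered.
    """
    if local_dir.exists():
        return False
    download(local_dir)
    return True


def fake_download(target: Path) -> None:
    # Stand-in for the real Hub download: it only creates a placeholder file.
    target.mkdir(parents=True)
    (target / "data.txt").write_text("placeholder")


with tempfile.TemporaryDirectory() as tmp:
    dataset_path = Path(tmp) / "ibbi_test_data"
    first_call = ensure_local_copy(dataset_path, fake_download)
    second_call = ensure_local_copy(dataset_path, fake_download)

print(first_call, second_call)
```

Because only the existence of the directory is checked, an incomplete or outdated directory is treated as a valid copy; deleting it forces a fresh download on the next call.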

get_ood_dataset(repo_id='IBBI-bio/ibbi_ood_data', local_dir='ibbi_ood_data', split='train', **kwargs)

Downloads and loads the out-of-distribution (OOD) dataset from the Hugging Face Hub.

This function handles the download and caching of the OOD dataset. On subsequent calls, it will load the data directly from the local cache.

Parameters:

Name Type Description Default
repo_id str

The repository ID of the OOD dataset on the Hugging Face Hub. Defaults to "IBBI-bio/ibbi_ood_data".

'IBBI-bio/ibbi_ood_data'
local_dir str

The name of the local directory where the dataset will be stored. Defaults to "ibbi_ood_data".

'ibbi_ood_data'
split str

The name of the dataset split to load. Defaults to "train".

'train'
**kwargs

Additional keyword arguments for the datasets.load_dataset function.

{}

Returns:

Name Type Description
Dataset Dataset

The loaded OOD dataset as a datasets.Dataset object.

Source code in src\ibbi\utils\data.py
def get_ood_dataset(
    repo_id: str = "IBBI-bio/ibbi_ood_data",
    local_dir: str = "ibbi_ood_data",
    split: str = "train",
    **kwargs,
) -> Dataset:
    """Downloads and loads the out-of-distribution (OOD) dataset from the Hugging Face Hub.

    This function handles the download and caching of the OOD dataset. On subsequent
    calls, it will load the data directly from the local cache.

    Args:
        repo_id (str, optional): The repository ID of the OOD dataset on the Hugging Face Hub.
                                 Defaults to "IBBI-bio/ibbi_ood_data".
        local_dir (str, optional): The name of the local directory where the dataset will be stored.
                                   Defaults to "ibbi_ood_data".
        split (str, optional): The name of the dataset split to load. Defaults to "train".
        **kwargs: Additional keyword arguments for the `datasets.load_dataset` function.

    Returns:
        Dataset: The loaded OOD dataset as a `datasets.Dataset` object.
    """
    return get_dataset(repo_id=repo_id, local_dir=local_dir, split=split, **kwargs)

get_shap_background_dataset(image_size=(224, 224))

Downloads, unzips, and loads the default IBBI SHAP background dataset.

This function is specifically designed to fetch the background dataset required for the SHAP (SHapley Additive exPlanations) explainability method. It handles the download of a zip archive from the Hugging Face Hub, extracts its contents, and loads the images into memory. The data is stored in the package's central cache directory to avoid re-downloads.

Parameters:

Name Type Description Default
image_size tuple[int, int]

The target size (width, height) to which the background images will be resized. This should match the input size expected by the model being explained. Defaults to (224, 224).

(224, 224)

Returns:

Type Description
list[dict]

A list of dictionaries, where each dictionary has an "image" key with a resized PIL Image object. This format is ready to be used with the ibbi.Explainer.with_shap method.

Source code in src\ibbi\utils\data.py
def get_shap_background_dataset(image_size: tuple[int, int] = (224, 224)) -> list[dict]:
    """Downloads, unzips, and loads the default IBBI SHAP background dataset.

    This function is specifically designed to fetch the background dataset required for the
    SHAP (SHapley Additive exPlanations) explainability method. It handles the download of a
    zip archive from the Hugging Face Hub, extracts its contents, and loads the images into
    memory. The data is stored in the package's central cache directory to avoid re-downloads.

    Args:
        image_size (tuple[int, int], optional): The target size (width, height) to which the
                                                background images will be resized. This should
                                                match the input size expected by the model being
                                                explained. Defaults to (224, 224).

    Returns:
        list[dict]: A list of dictionaries, where each dictionary has an "image" key with a
                    resized PIL Image object. This format is ready to be used with the
                    `ibbi.Explainer.with_shap` method.
    """
    repo_id = "IBBI-bio/ibbi_shap_dataset"
    filename = "ibbi_shap_dataset.zip"
    cache_dir = get_cache_dir()
    unzip_dir = cache_dir / "unzipped_shap_data"
    image_dir = unzip_dir / "shap_dataset" / "images" / "train"

    if not image_dir.exists() or not any(image_dir.iterdir()):
        print(f"SHAP background data not found in cache. Downloading from '{repo_id}'...")
        downloaded_zip_path = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset", cache_dir=str(cache_dir))

        print("Decompressing SHAP background dataset...")
        unzip_dir.mkdir(exist_ok=True)
        with zipfile.ZipFile(downloaded_zip_path, "r") as zip_ref:
            zip_ref.extractall(unzip_dir)
    else:
        print("Found cached SHAP background data. Loading from disk.")

    background_images = []
    print(f"Loading and resizing SHAP background images to {image_size}...")
    image_paths = list(image_dir.glob("*"))

    for img_path in image_paths:
        with Image.open(img_path) as img:
            resized_img = img.resize(image_size)
            background_images.append({"image": resized_img.copy()})

    print("SHAP background dataset loaded and resized successfully.")
    return background_images
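
The extraction step in the listing above runs only when the image directory is missing or empty. The sketch below reproduces that check with the standard library on a tiny archive built on the fly; `extract_if_empty` is a hypothetical helper and the archive contents are placeholders.

```python
import tempfile
import zipfile
from pathlib import Path


def extract_if_empty(zip_path: Path, unzip_dir: Path, image_dir: Path) -> bool:
    """Extract the archive only when `image_dir` is missing or empty."""
    if image_dir.exists() and any(image_dir.iterdir()):
        return False
    unzip_dir.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        zip_ref.extractall(unzip_dir)
    return True


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    zip_path = root / "demo.zip"
    # Build a small archive whose layout mirrors the one used above.
    with zipfile.ZipFile(zip_path, "w") as zf:
        zf.writestr("shap_dataset/images/train/a.txt", "placeholder")

    unzip_dir = root / "unzipped_shap_data"
    image_dir = unzip_dir / "shap_dataset" / "images" / "train"
    extracted_first = extract_if_empty(zip_path, unzip_dir, image_dir)
    extracted_again = extract_if_empty(zip_path, unzip_dir, image_dir)
    file_names = sorted(p.name for p in image_dir.iterdir())

print(extracted_first, extracted_again, file_names)
```

The sketch passes `parents=True` to `mkdir` as a defensive choice; the listing above can omit it because the cache directory has already been created by `get_cache_dir`.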

list_models(as_df=False)

Displays or returns a summary of available models and their key information.

This function reads the model summary CSV file included with the package, which contains a comprehensive list of all available models, their tasks, and key performance metrics. It can either print this information to the console in a human-readable format or return it as a pandas DataFrame for programmatic access.

Parameters:

Name Type Description Default
as_df bool

If True, the function returns the model information as a pandas DataFrame. If False (the default), it prints the information directly to the console.

False

Returns:

Type Description

DataFrame or None

If as_df is True, a pandas DataFrame containing the model summary is returned. Otherwise, the function returns None.

Source code in src\ibbi\utils\info.py
def list_models(as_df: bool = False):
    """Displays or returns a summary of available models and their key information.

    This function reads the model summary CSV file included with the package, which
    contains a comprehensive list of all available models, their tasks, and key
    performance metrics. It can either print this information to the console in a
    human-readable format or return it as a pandas DataFrame for programmatic access.

    Args:
        as_df (bool, optional): If True, the function returns the model information as a
                                pandas DataFrame. If False (the default), it prints the
                                information directly to the console.

    Returns:
        pd.DataFrame or None: If `as_df` is True, a pandas DataFrame containing the model
                              summary is returned. Otherwise, the function returns None.
    """
    try:
        # Find the path to the data file within the package
        with resources.files("ibbi.data").joinpath("ibbi_model_summary.csv").open("r") as f:
            df = pd.read_csv(f)

        if as_df:
            return df
        else:
            print("Available IBBI Models:")
            print(df.to_string())

    except FileNotFoundError:
        print("Error: Model summary file not found.")
        return None

plot_lime_explanation(explanation, image, top_k=1, alpha=0.6)

Plots a detailed LIME explanation with a red-to-green overlay.

This function visualizes the output of explain_with_lime. It overlays the original image with a heatmap where green areas indicate features that positively contributed to the prediction, and red areas indicate negative contributions.

Parameters:

Name Type Description Default
explanation ImageExplanation

The explanation object generated by explain_with_lime.

required
image Image

The original image that was explained.

required
top_k int

The number of top classes to display explanations for. Defaults to 1.

1
alpha float

The transparency of the color overlay. Defaults to 0.6.

0.6
Source code in src\ibbi\explain\lime.py
def plot_lime_explanation(explanation: lime_image.ImageExplanation, image: Image.Image, top_k: int = 1, alpha: float = 0.6) -> None:
    """Plots a detailed LIME explanation with a red-to-green overlay.

    This function visualizes the output of `explain_with_lime`. It overlays the original
    image with a heatmap where green areas indicate features that positively contributed
    to the prediction, and red areas indicate negative contributions.

    Args:
        explanation (lime_image.ImageExplanation): The explanation object generated by `explain_with_lime`.
        image (Image.Image): The original image that was explained.
        top_k (int, optional): The number of top classes to display explanations for. Defaults to 1.
        alpha (float, optional): The transparency of the color overlay. Defaults to 0.6.
    """
    plt.figure(figsize=(5, 5))
    plt.imshow(image)
    plt.title("Original Image")
    plt.axis("off")
    plt.show()

    segments = explanation.segments

    for label in explanation.top_labels[:top_k]:  # type: ignore[attr-defined]
        print(f"\n--- Explanation for Class Index: {label} ---")

        exp_for_label = explanation.local_exp.get(label)
        if not exp_for_label:
            print(f"No explanation available for class {label}.")
            continue

        weight_map = np.zeros(segments.shape, dtype=np.float32)
        for feature, weight in exp_for_label:
            weight_map[segments == feature] = weight

        max_abs_weight = np.max(np.abs(weight_map))
        if max_abs_weight == 0:
            print(f"No significant features found for class {label}.")
            fig, ax = plt.subplots(figsize=(6, 6))
            ax.imshow(image)
            ax.set_title(f"LIME: No features for class {label}")
            ax.axis("off")
            plt.show()
            continue

        norm = mcolors.Normalize(vmin=-max_abs_weight, vmax=max_abs_weight)
        cmap = plt.cm.RdYlGn  # type: ignore[attr-defined]

        colored_overlay_rgba = cmap(norm(weight_map))
        original_size = image.size
        colored_overlay_resized = resize(
            colored_overlay_rgba,
            (original_size[1], original_size[0]),
            anti_aliasing=True,
            mode="constant",
        )

        fig, ax = plt.subplots(figsize=(7, 6))
        ax.imshow(image)
        ax.imshow(colored_overlay_resized, alpha=alpha)  # type: ignore[arg-type]

        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
        sm.set_array([])

        cbar = fig.colorbar(sm, ax=ax, fraction=0.046, pad=0.04)
        cbar.set_label("Feature Weight (Green: Positive, Red: Negative)", rotation=270, labelpad=20)

        ax.set_title(f"LIME Explanation for Class Index: {label}")
        ax.axis("off")
        plt.show()
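
The core of the overlay above is the per-segment weight map and the symmetric normalization around zero. The following is a minimal sketch of just that step, assuming only NumPy and Matplotlib; the helper name `build_weight_map`, the 4x4 segmentation, and the weights are hypothetical values chosen for illustration and are not part of the `ibbi` package.

```python
import numpy as np
import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt


def build_weight_map(segments, local_exp):
    """Paint each superpixel with its weight; segments not listed stay at 0."""
    weight_map = np.zeros(segments.shape, dtype=np.float32)
    for feature, weight in local_exp:
        weight_map[segments == feature] = weight
    return weight_map


# Hypothetical segmentation with three superpixels (ids 0, 1, 2).
segments = np.array(
    [
        [0, 0, 1, 1],
        [0, 0, 1, 1],
        [2, 2, 2, 2],
        [2, 2, 2, 2],
    ]
)
# Hypothetical (segment id, weight) pairs, shaped like explanation.local_exp[label].
local_exp = [(0, 0.5), (1, -0.25)]

weight_map = build_weight_map(segments, local_exp)

# Symmetric limits keep a weight of 0 at the midpoint of the diverging colormap.
max_abs = float(np.max(np.abs(weight_map)))
norm = mcolors.Normalize(vmin=-max_abs, vmax=max_abs)
overlay_rgba = plt.cm.RdYlGn(norm(weight_map))  # shape: (rows, cols, 4)
```

Using `vmin=-max_abs` and `vmax=max_abs` means positive and negative weights share one scale, so a weak negative contribution is not stretched to look as strong as the largest positive one.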

plot_shap_explanation(shap_explanation_for_single_image, model, top_k=5, text_prompt=None)

Plots SHAP explanations for a SINGLE image.

This function is designed to visualize the output of explain_with_shap for one image. It uses SHAP's built-in image plotting capabilities to show which parts of the image contributed to the model's predictions for the top-k classes.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| shap_explanation_for_single_image | Explanation | A SHAP Explanation object for a single image. | required |
| model | ModelType | The ibbi model that was explained. | required |
| top_k | int | The number of top class explanations to plot. Defaults to 5. | 5 |
| text_prompt | Optional[str] | The text prompt used for explaining a zero-shot model. Defaults to None. | None |

Source code in src\ibbi\explain\shap.py
def plot_shap_explanation(
    shap_explanation_for_single_image: shap.Explanation,
    model: ModelType,
    top_k: int = 5,
    text_prompt: Optional[str] = None,
) -> None:
    """Plots SHAP explanations for a SINGLE image.

    This function is designed to visualize the output of `explain_with_shap` for one image.
    It uses SHAP's built-in image plotting capabilities to show which parts of the image
    contributed to the model's predictions for the top-k classes.

    Args:
        shap_explanation_for_single_image (shap.Explanation): A SHAP Explanation object for a single image.
        model (ModelType): The `ibbi` model that was explained.
        top_k (int, optional): The number of top class explanations to plot. Defaults to 5.
        text_prompt (Optional[str], optional): The text prompt used for explaining a zero-shot model. Defaults to None.
    """
    print("\n--- Generating Explanations for Image ---")

    image_for_plotting = shap_explanation_for_single_image.data
    shap_values = shap_explanation_for_single_image.values
    class_names = np.array(shap_explanation_for_single_image.output_names)

    prediction_fn = _prediction_wrapper(model, text_prompt=text_prompt)
    image_norm = _prepare_image_for_shap(np.array(image_for_plotting))

    prediction_scores = prediction_fn(image_norm[np.newaxis, ...])[0]

    if len(prediction_scores) > 1:
        top_indices = np.argsort(prediction_scores)[-top_k:][::-1]
    else:
        top_indices = [0]

    plt.figure(figsize=(5, 5))
    plt.imshow(image_for_plotting)
    plt.title("Original Image")
    plt.axis("off")
    plt.show()
    shap_values_for_plot = shap_values[..., top_indices]  # type: ignore
    class_names_for_plot = class_names[top_indices]

    if np.all(shap_values == 0):
        print("⚠️  Warning: SHAP values are all zero. The plot will be empty.")
        print("   This can happen if the model's prediction is not sensitive to the masking.")

    shap.image_plot(
        shap_values=[shap_values_for_plot] if isinstance(shap_values_for_plot, np.ndarray) else shap_values_for_plot,
        pixel_values=image_for_plotting,
        labels=np.array([class_names_for_plot]),
        show=True,
    )
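
The top-k selection in the function above can be tried in isolation. This is a minimal sketch assuming only NumPy; the helper name `top_k_indices`, the score vector, and the array shape are hypothetical and stand in for the real prediction scores and SHAP values.

```python
import numpy as np


def top_k_indices(scores, top_k):
    """Return indices of the top_k highest scores, best first."""
    scores = np.asarray(scores)
    if scores.size > 1:
        # argsort is ascending, so take the last top_k entries and reverse them.
        return np.argsort(scores)[-top_k:][::-1]
    # A single-output model has only one class to explain.
    return np.array([0])


# Hypothetical scores for four classes.
scores = np.array([0.10, 0.70, 0.05, 0.15])
idx = top_k_indices(scores, top_k=2)

# Hypothetical SHAP values with classes on the last axis: (height, width, channels, classes).
shap_values = np.zeros((8, 8, 3, 4))
selected = shap_values[..., idx]

# Hypothetical class names, indexed the same way as the scores.
class_names = np.array(["a", "b", "c", "d"])
selected_names = class_names[idx]
```

Because the slice `[-top_k:]` never reaches past the start of the array, asking for more classes than exist returns every class rather than raising an error.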