
API Reference

This section provides an auto-generated API reference for the ibbi package.

ibbi

Main initialization file for the ibbi package.

create_model(model_name, pretrained=False, **kwargs)

Creates a model from a name.

This factory function is the main entry point for users of the package. It looks up the requested model in the registry, downloads pretrained weights from the Hugging Face Hub if requested, and returns an instantiated model object.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model_name` | `str` | Name of the model to create. | *required* |
| `pretrained` | `bool` | Whether to load pretrained weights from the Hugging Face Hub. Defaults to `False`. | `False` |
| `**kwargs` | `Any` | Extra arguments to pass to the model-creating function. | `{}` |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `ModelType` | `ModelType` | An instance of the requested model (e.g., `YOLOSingleClassBeetleDetector` or `YOLOBeetleMultiClassDetector`). |

Raises:

| Type | Description |
| --- | --- |
| `KeyError` | If the requested `model_name` is not found in the model registry. |

Example
import ibbi

# Create a pretrained single-class detection model
detector = ibbi.create_model("yolov10x_bb_detect_model", pretrained=True)

# Create a pretrained multi-class detection model
multi_class_detector = ibbi.create_model("yolov10x_bb_multi_class_detect_model", pretrained=True)
Source code in src\ibbi\__init__.py
def create_model(model_name: str, pretrained: bool = False, **kwargs: Any) -> ModelType:
    """
    Creates a model from a name.

    This factory function is the main entry point for users of the package.
    It looks up the requested model in the registry, downloads pretrained
    weights from the Hugging Face Hub if requested, and returns an
    instantiated model object.

    Args:
        model_name (str): Name of the model to create.
        pretrained (bool): Whether to load pretrained weights from the Hugging Face Hub.
                            Defaults to False.
        **kwargs (Any): Extra arguments to pass to the model-creating function.

    Returns:
        ModelType: An instance of the requested model (e.g., YOLOSingleClassBeetleDetector or
                   YOLOBeetleMultiClassDetector).

    Raises:
        KeyError: If the requested `model_name` is not found in the model registry.

    Example:
        ```python
        import ibbi

        # Create a pretrained single-class detection model
        detector = ibbi.create_model("yolov10x_bb_detect_model", pretrained=True)

        # Create a pretrained multi-class detection model
        multi_class_detector = ibbi.create_model("yolov10x_bb_multi_class_detect_model", pretrained=True)
        ```
    """
    if model_name not in model_registry:
        available = ", ".join(model_registry.keys())
        raise KeyError(f"Model '{model_name}' not found. Available models: [{available}]")

    # Look up the factory function in the registry and call it
    model_factory = model_registry[model_name]
    model = model_factory(pretrained=pretrained, **kwargs)

    return model

explain_model(model, explain_dataset, background_dataset, num_explain_samples, num_background_samples, max_evals=1000, batch_size=50, image_size=(640, 640), text_prompt=None)

Generates SHAP explanations for a given model. This function is computationally intensive.
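
Example

A minimal usage sketch, not one of the package's own documented examples: it assumes the explain and background samples come from `ibbi.get_dataset()` (whose items expose an `"image"` field, as the source below expects), that a pretrained multi-class detector is being explained, and that the sample counts and `max_evals` value are purely illustrative.

```python
import ibbi

# Model to explain and a dataset to sample from (assumed model name and default dataset)
model = ibbi.create_model("yolov10x_bb_multi_class_detect_model", pretrained=True)
dataset = ibbi.get_dataset()

# Materialize a few items; each item is expected to carry an "image" field
samples = list(dataset.select(range(12)))

explanation = ibbi.explain_model(
    model=model,
    explain_dataset=samples[:2],        # images to explain
    background_dataset=samples[2:],     # background distribution for SHAP
    num_explain_samples=2,
    num_background_samples=10,
    max_evals=500,                      # fewer evaluations -> faster but coarser maps
    batch_size=25,
)

# Index the result to visualize explanations for a single image
ibbi.plot_explanations(explanation[0], model=model)
```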

Source code in src\ibbi\xai\shap.py
def explain_model(
    model: ModelType,
    explain_dataset: list,
    background_dataset: list,
    num_explain_samples: int,
    num_background_samples: int,
    max_evals: int = 1000,
    batch_size: int = 50,
    image_size: tuple = (640, 640),
    text_prompt: Optional[str] = None,
) -> shap.Explanation:
    """
    Generates SHAP explanations for a given model.
    This function is computationally intensive.
    """
    prediction_fn = _prediction_wrapper(model, text_prompt=text_prompt)

    if isinstance(model, GroundingDINOModel):
        if not text_prompt:
            raise ValueError("A 'text_prompt' is required for explaining a GroundingDINOModel.")
        output_names = [text_prompt]
    else:
        output_names = model.get_classes()

    background_pil_images = [background_dataset[i]["image"].resize(image_size) for i in range(num_background_samples)]
    background_images = [np.array(img) for img in background_pil_images]
    background_images_norm = np.stack([_prepare_image_for_shap(img) for img in background_images])

    background_summary = np.median(background_images_norm, axis=0)

    images_to_explain_pil = [explain_dataset[i]["image"].resize(image_size) for i in range(num_explain_samples)]
    images_to_explain = [np.array(img) for img in images_to_explain_pil]
    images_to_explain_norm = [_prepare_image_for_shap(img) for img in images_to_explain]
    images_to_explain_array = np.array(images_to_explain_norm)

    masker = ImageMasker(background_summary, shape=images_to_explain_array[0].shape)
    explainer = shap.Explainer(prediction_fn, masker, output_names=output_names)

    # Ignoring the arg-type error which is due to incorrect type hints in the shap library
    shap_explanation = explainer(images_to_explain_array, max_evals=max_evals, batch_size=batch_size)  # type: ignore[arg-type]
    shap_explanation.data = np.array(images_to_explain)
    return shap_explanation

get_dataset(repo_id='IBBI-bio/ibbi_test_data', split='train', **kwargs)

Loads a dataset from the Hugging Face Hub.

This function is a wrapper around datasets.load_dataset and returns the raw Dataset object, allowing for direct manipulation.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `repo_id` | `str` | The Hugging Face Hub repository ID of the dataset. Defaults to `"IBBI-bio/ibbi_test_data"`. | `'IBBI-bio/ibbi_test_data'` |
| `split` | `str` | The dataset split to use (e.g., `"train"`, `"test"`). Defaults to `"train"`. | `'train'` |
| `**kwargs` | | Additional keyword arguments passed directly to `datasets.load_dataset`. | `{}` |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `Dataset` | `Dataset` | The loaded dataset object from the Hugging Face Hub. |

Raises:

| Type | Description |
| --- | --- |
| `TypeError` | If the loaded object is not of type `Dataset`. |

Example
import ibbi

# Load the default test dataset
test_data = ibbi.get_dataset()

# Iterate through the first 5 examples
for i, example in enumerate(test_data):
    if i >= 5:
        break
    print(example['image'])
Source code in src\ibbi\utils\data.py
def get_dataset(
    repo_id: str = "IBBI-bio/ibbi_test_data",
    split: str = "train",
    **kwargs,
) -> Dataset:
    """
    Loads a dataset from the Hugging Face Hub.

    This function is a wrapper around `datasets.load_dataset` and returns
    the raw Dataset object, allowing for direct manipulation.

    Args:
        repo_id (str): The Hugging Face Hub repository ID of the dataset.
                         Defaults to "IBBI-bio/ibbi_test_data".
        split (str): The dataset split to use (e.g., "train", "test").
                         Defaults to "train".
        **kwargs: Additional keyword arguments passed directly to
                  `datasets.load_dataset`.

    Returns:
        Dataset: The loaded dataset object from the Hugging Face Hub.

    Raises:
        TypeError: If the loaded object is not of type `Dataset`.

    Example:
        ```python
        import ibbi

        # Load the default test dataset
        test_data = ibbi.get_dataset()

        # Iterate through the first 5 examples
        for i, example in enumerate(test_data):
            if i >= 5:
                break
            print(example['image'])
        ```
    """
    print(f"Loading dataset '{repo_id}' (split: '{split}') from Hugging Face Hub...")
    try:
        # Load the dataset from the hub
        dataset: Union[Dataset, DatasetDict, IterableDataset, IterableDatasetDict] = load_dataset(
            repo_id, split=split, trust_remote_code=True, **kwargs
        )

        # Ensure that the returned object is a Dataset
        if not isinstance(dataset, Dataset):
            raise TypeError(
                f"Expected a 'Dataset' object for split '{split}', but received type '{type(dataset).__name__}'."
            )

        print("Dataset loaded successfully.")
        return dataset
    except Exception as e:
        print(f"Failed to load dataset '{repo_id}'. Please check the repository ID and your connection.")
        raise e

list_models(as_df=False)

Displays available models and their key information.

Reads the model summary CSV included with the package and prints it. Can also return the data as a pandas DataFrame.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `as_df` | `bool` | If `True`, returns the model information as a pandas DataFrame. If `False` (default), prints the information to the console. | `False` |

Returns:

| Type | Description |
| --- | --- |
| `pd.DataFrame` or `None` | A DataFrame if `as_df` is `True`, otherwise `None`. |
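
Example

A short usage sketch; the columns shown depend on the bundled `ibbi_model_summary.csv` and are not reproduced here.

```python
import ibbi

# Print the summary table of available models to the console
ibbi.list_models()

# Or retrieve it as a pandas DataFrame for further filtering
models_df = ibbi.list_models(as_df=True)
print(models_df.head())
```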

Source code in src\ibbi\utils\info.py
def list_models(as_df: bool = False):
    """
    Displays available models and their key information.

    Reads the model summary CSV included with the package and prints it.
    Can also return the data as a pandas DataFrame.

    Args:
        as_df (bool): If True, returns the model information as a pandas DataFrame.
                      If False (default), prints the information to the console.

    Returns:
        pd.DataFrame or None: A DataFrame if as_df is True, otherwise None.
    """
    try:
        # Find the path to the data file within the package
        with resources.files("ibbi.data").joinpath("ibbi_model_summary.csv").open("r") as f:
            df = pd.read_csv(f)

        if as_df:
            return df
        else:
            print("Available IBBI Models:")
            print(df.to_string())

    except FileNotFoundError:
        print("Error: Model summary file not found.")
        return None

plot_explanations(shap_explanation_for_single_image, model, top_k=5, text_prompt=None)

Plots SHAP explanations for a SINGLE image.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `shap_explanation_for_single_image` | `Explanation` | A SHAP `Explanation` object for a SINGLE image. To get this, index the output of `explain_model` (e.g., `shap_explanation[0]`). | *required* |
| `model` | `ModelType` | The model that was explained. | *required* |
| `top_k` | `int` | The number of top predicted classes to visualize. | `5` |
| `text_prompt` | `Optional[str]` | The text prompt, if a `GroundingDINOModel` was used. | `None` |
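
Example

A minimal sketch assuming `explanation` and `model` come from the `explain_model` example above; note that a single-image slice is passed, not the whole batch.

```python
import ibbi

# explanation = ibbi.explain_model(...)  # produced earlier, as in the explain_model example
# Plot SHAP maps for the top 3 predicted classes of the first explained image
ibbi.plot_explanations(explanation[0], model=model, top_k=3)
```
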
Source code in src\ibbi\xai\shap.py
def plot_explanations(
    shap_explanation_for_single_image: shap.Explanation,
    model: ModelType,
    top_k: int = 5,
    text_prompt: Optional[str] = None,
) -> None:
    """
    Plots SHAP explanations for a SINGLE image.

    Args:
        shap_explanation_for_single_image: A SHAP Explanation object for a SINGLE image.
                                           To get this, index the output of explain_model (e.g., `shap_explanation[0]`).
        model: The model that was explained.
        top_k: The number of top predicted classes to visualize.
        text_prompt: The text prompt, if a GroundingDINOModel was used.
    """
    print("\n--- Generating Explanations for Image ---")

    image_for_plotting = shap_explanation_for_single_image.data
    shap_values = shap_explanation_for_single_image.values
    class_names = np.array(shap_explanation_for_single_image.output_names)

    prediction_fn = _prediction_wrapper(model, text_prompt=text_prompt)
    image_norm = _prepare_image_for_shap(np.array(image_for_plotting))
    prediction_scores_batch = prediction_fn(np.expand_dims(image_norm, axis=0))
    prediction_scores = prediction_scores_batch[0]

    top_indices = np.argsort(prediction_scores)[-top_k:][::-1]

    plt.figure(figsize=(5, 5))
    plt.imshow(image_for_plotting)
    plt.title("Original Image")
    plt.axis("off")
    plt.show()

    for class_idx in top_indices:
        if prediction_scores[class_idx] > 0:
            class_name = class_names[class_idx]
            score = prediction_scores[class_idx]
            print(f"Explanation for '{class_name}' (Prediction Score: {score:.3f})")

            # FIX 2: Added `# type: ignore` to suppress the incorrect slicing error from pyright.
            # The slicing logic is correct for the shape of the shap_values array.
            shap_values_for_class = shap_values[:, :, :, class_idx]  # type: ignore[misc]

            shap.image_plot(
                shap_values=[shap_values_for_class],
                pixel_values=np.array(image_for_plotting),
                show=True,
            )