Visualize Predictions

The plot() function visualizes predictions from a detection or segmentation model.

If you run inference with a detection model (e.g. Grounding DINO), plot() draws a bounding box for each prediction.

If you run inference with a segmentation model (e.g. Grounded SAM), plot() draws a segmentation mask for each prediction.
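
The choice between boxes and masks follows from the check on `detections.mask` in the source listing under Code Reference. The sketch below mirrors only that decision. `FakeDetections` and `annotation_kind` are hypothetical names used for illustration and are not part of autodistill or supervision.

```python
from dataclasses import dataclass
from typing import Optional, Sequence


@dataclass
class FakeDetections:
    """Hypothetical stand-in for a detections object; only `mask` matters here."""

    mask: Optional[Sequence] = None


def annotation_kind(detections: FakeDetections) -> str:
    """Return which kind of overlay would be drawn for these detections."""
    # Mirrors the `detections.mask is not None` check made inside plot().
    return "mask" if detections.mask is not None else "box"


print(annotation_kind(FakeDetections()))               # box
print(annotation_kind(FakeDetections(mask=[[True]])))  # mask
```

In other words, you do not tell plot() which style to use; it is inferred from whether the predictions carry masks.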

Here is an example that uses plot() to annotate predictions from a Grounding DINO model:

from autodistill_grounding_dino import GroundingDINO
from autodistill.detection import CaptionOntology
from autodistill.utils import plot
import cv2

ontology = CaptionOntology(
    {
        "dog": "dog",
    }
)

model = GroundingDINO(ontology=ontology)

result = model.predict("./dog.jpeg")

plot(
    image=cv2.imread("./dog.jpeg"),
    classes=model.ontology.classes(),
    detections=result
)

Bounding Box

Here is an example that uses plot() to annotate predictions from a Grounded SAM model:

from autodistill_grounded_sam import GroundedSAM
from autodistill.detection import CaptionOntology
from autodistill.utils import plot
import cv2

ontology = CaptionOntology(
    {
        "dog": "dog",
    }
)

model = GroundedSAM(ontology=ontology)

result = model.predict("./dog.jpeg")

plot(
    image=cv2.imread("./dog.jpeg"),
    classes=model.ontology.classes(),
    detections=result
)

Segmentation Mask

Code Reference

Plot bounding boxes or segmentation masks on an image.

Parameters:

Name        Type        Description                                               Default
image       np.ndarray  The image to plot on                                      required
detections  -           The detections to plot                                    required
classes     List[str]   The classes to plot                                       required
raw         -           Whether to return the raw image or plot it interactively  False

Returns:

The raw image (np.ndarray) if raw=True, otherwise None (the image is plotted interactively).
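
Passing raw=True is useful when you want to save the annotated image instead of opening an interactive plot. The following is a minimal sketch, assuming a model object that exposes `predict()` and `ontology.classes()` as in the examples above. The helper name `save_annotated` and its parameters are hypothetical.

```python
def save_annotated(model, image_path: str, output_path: str) -> bool:
    """Annotate one image with plot(raw=True) and write the result to disk."""
    # Imports live inside the function so the sketch can be defined even
    # where these packages are not installed.
    import cv2
    from autodistill.utils import plot

    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread returns None when the file cannot be read.
        raise FileNotFoundError(f"Could not read image: {image_path}")

    detections = model.predict(image_path)

    # raw=True returns the annotated array instead of displaying it.
    annotated = plot(
        image=image,
        classes=model.ontology.classes(),
        detections=detections,
        raw=True,
    )

    # cv2.imwrite returns True when the file was written successfully.
    return cv2.imwrite(output_path, annotated)
```

With the Grounding DINO model from the first example, a call would look like `save_annotated(model, "./dog.jpeg", "./dog_annotated.jpeg")`.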

Source code in autodistill/utils.py
from typing import List

import numpy as np
import supervision as sv


def plot(image: np.ndarray, detections, classes: List[str], raw=False):
    """
    Plot bounding boxes or segmentation masks on an image.

    Args:
        image: The image to plot on
        detections: The detections to plot
        classes: The classes to plot
        raw: Whether to return the raw image or plot it interactively

    Returns:
        The raw image (np.ndarray) if raw=True, otherwise None (the image is plotted interactively).
    """
    # TODO: When we have a classification annotator
    # in supervision, we can add it here
    if detections.mask is not None:
        annotator = sv.MaskAnnotator()
    else:
        annotator = sv.BoxAnnotator()

    label_annotator = sv.LabelAnnotator()

    labels = [
        f"{classes[class_id]} {confidence:0.2f}"
        for _, _, confidence, class_id, _ in detections
    ]

    annotated_frame = annotator.annotate(scene=image.copy(), detections=detections)
    annotated_frame = label_annotator.annotate(
        scene=annotated_frame, labels=labels, detections=detections
    )

    if raw:
        return annotated_frame

    sv.plot_image(annotated_frame, size=(8, 8))
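
The label text drawn next to each prediction comes from the f-string in the listing above: the class name looked up by `class_id`, followed by the confidence rounded to two decimals. The sketch below reproduces only that formatting step on plain tuples. `build_labels` and the sample values are hypothetical and stand in for real detections.

```python
from typing import List, Sequence, Tuple


def build_labels(
    classes: List[str], predictions: Sequence[Tuple[int, float]]
) -> List[str]:
    """Format one "<class name> <confidence>" label per prediction."""
    return [
        # Same format string as in plot(): name, space, two-decimal confidence.
        f"{classes[class_id]} {confidence:0.2f}"
        for class_id, confidence in predictions
    ]


labels = build_labels(["dog", "cat"], [(0, 0.8765), (1, 0.5)])
print(labels)  # ['dog 0.88', 'cat 0.50']
```

Because labels are looked up by index, the `classes` list passed to plot() needs to be in the same order the model used when assigning class IDs, which is why the examples pass `model.ontology.classes()`.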