Skip to content

Utilities

Learn about utility functions available for use with Autodistill.

Plot an Image with Predictions

Plot bounding boxes or segmentation masks on an image.

Parameters:

Name Type Description Default
image np.ndarray

The image to plot on

required
detections

The detections to plot

required
classes List[str]

The classes to plot

required
raw

Whether to return the raw image or plot it interactively

False

Returns:

Type Description

The raw image (np.ndarray) if raw=True, otherwise None (the image is plotted interactively).

Source code in autodistill/utils.py
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
def plot(image: np.ndarray, detections, classes: List[str], raw=False):
    """
    Plot bounding boxes or segmentation masks on an image.

    Masks are drawn when ``detections.mask`` is set; otherwise bounding boxes
    are drawn. Each detection is labelled with its class name and, when a
    confidence value is available, the confidence rounded to two decimals.
    The input image is not modified; annotation happens on a copy.

    Args:
        image: The image to plot on
        detections: The detections to plot
        classes: The classes to plot (indexed by each detection's class id)
        raw: Whether to return the raw image or plot it interactively

    Returns:
        The raw image (np.ndarray) if raw=True, otherwise None (the image is
        plotted interactively)
    """
    # TODO: When we have a classification annotator
    # in supervision, we can add it here
    if detections.mask is not None:
        annotator = sv.MaskAnnotator()
    else:
        annotator = sv.BoxAnnotator()

    label_annotator = sv.LabelAnnotator()

    labels = []
    for _, _, confidence, class_id, _, _ in detections:
        class_name = classes[class_id]
        # Confidence is optional on detections; formatting None with
        # ":0.2f" would raise TypeError, so fall back to the bare class name.
        if confidence is None:
            labels.append(f"{class_name}")
        else:
            labels.append(f"{class_name} {confidence:0.2f}")

    # Annotate a copy so the caller's image is left untouched.
    annotated_frame = annotator.annotate(scene=image.copy(), detections=detections)
    annotated_frame = label_annotator.annotate(
        scene=annotated_frame, labels=labels, detections=detections
    )

    if raw:
        return annotated_frame

    sv.plot_image(annotated_frame, size=(8, 8))

Compare Models

Compare the predictions of multiple models on multiple images.

Parameters:

Name Type Description Default
models list

The models to compare

required
images List[str]

The images to compare

required

Returns:

Type Description

A grid of images with the predictions of each model on each image.

Source code in autodistill/utils.py
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
def compare(models: list, images: List[str]):
    """
    Run every model on every image and show the annotated results in a grid.

    Images are processed model by model, so the grid has one row per model
    and one column per image. Each cell is titled with the class name of the
    model that produced it.

    Args:
        models: The models to compare
        images: The images to compare (paths readable by ``cv2.imread``)

    Returns:
        None. The grid is displayed via ``sv.plot_images_grid``.
    """
    annotated_frames = []
    frame_titles = []

    for candidate in models:
        # The model's class name doubles as the title of each of its cells.
        title = candidate.__class__.__name__

        for image_path in images:
            predictions = candidate.predict(image_path)
            frame = cv2.imread(image_path)

            annotated_frames.append(
                plot(
                    frame,
                    predictions,
                    classes=candidate.ontology.prompts(),
                    raw=True,
                )
            )
            frame_titles.append(title)

    grid_shape = (len(models), len(images))
    sv.plot_images_grid(
        annotated_frames,
        grid_size=grid_shape,
        titles=frame_titles,
        size=(16, 16),
    )

Load an Image

Load an image from a file path, URI, PIL image, or numpy array.

This function is for use by Autodistill modules. You don't need to use it directly.

Parameters:

Name Type Description Default
image Any

The image to load

required
return_format

The format to return the image in

'cv2'

Returns:

Type Description
Any

The image in the specified format

Source code in autodistill/helpers.py
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
def load_image(
    image: Any,
    return_format="cv2",
) -> Any:
    """
    Load an image from a file path, URI, PIL image, or numpy array.

    This function is for use by Autodistill modules. You don't need to use it directly.

    Numpy array inputs are treated as already being in cv2 (BGR) channel
    order: they are returned unchanged for "cv2" and "numpy", and converted
    BGR -> RGB for "PIL". Every other source (PIL image, URL, file path) is
    first obtained as a PIL image and then converted the same way, so "cv2"
    always yields BGR regardless of where the image came from.

    Args:
        image: The image to load
        return_format: The format to return the image in ("PIL", "cv2" or "numpy")

    Returns:
        The image in the specified format

    Raises:
        ValueError: If return_format is not in ACCEPTED_RETURN_FORMATS, or if
            image is neither a PIL image, a numpy array, an http(s) URL nor
            an existing file path.
    """
    if return_format not in ACCEPTED_RETURN_FORMATS:
        raise ValueError(f"return_format must be one of {ACCEPTED_RETURN_FORMATS}")

    if isinstance(image, np.ndarray):
        if return_format == "PIL":
            # arrays are assumed to be BGR (cv2 convention); PIL expects RGB
            return Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        return image

    if not isinstance(image, Image.Image):
        # Resolve string sources to a PIL image so one conversion path below
        # serves PIL inputs, URLs and file paths alike.
        if isinstance(image, str) and image.startswith("http"):
            response = requests.get(image)
            image = Image.open(BytesIO(response.content))
        elif os.path.isfile(image):
            image = Image.open(image)
        else:
            raise ValueError(f"{image} is not a valid file path or URI")

    if return_format == "PIL":
        return image
    if return_format == "cv2":
        # channels need to be reversed for cv2 (previously skipped for URLs,
        # which returned RGB data under the "cv2" format)
        return cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    # "numpy": plain array in PIL's channel order
    return np.array(image)

Split Video Frames

Split a video into frames and save them to a directory.

Parameters:

Name Type Description Default
video_path str

The path to the video

required
output_dir str

The directory to save the frames to

required
stride int

The stride to use when splitting the video into frames

required

Returns:

Type Description
None

None

Source code in autodistill/helpers.py
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
def split_video_frames(video_path: str, output_dir: str, stride: int) -> None:
    """
    Split one or more videos into frames and save them to a directory.

    If ``video_path`` is a file, only that video is split. Otherwise it is
    treated as a directory and every file in it with a mov/mp4 extension is
    split. Frames are written to ``output_dir`` as
    ``<video file stem>-<frame index, 5 digits>.jpg``.

    Args:
        video_path: The path to a video file, or to a directory of videos
        output_dir: The directory to save the frames to
        stride: The stride to use when splitting the video into frames

    Returns:
        None
    """
    # Imported locally so the function does not rely on module-level imports.
    import os
    from pathlib import Path

    if os.path.isfile(video_path):
        video_paths = [Path(video_path)]
    else:
        video_paths = sv.list_files_with_extensions(
            directory=video_path, extensions=["mov", "mp4", "MOV", "MP4"]
        )

    for video in tqdm(video_paths):
        video = Path(video)
        # The listed entries are Path objects, so "path + str" would raise
        # TypeError; build the pattern from the file stem instead. Braces in
        # the stem are escaped because the pattern is later passed to
        # str.format with the frame index.
        stem = video.stem.replace("{", "{{").replace("}", "}}")
        image_name_pattern = stem + "-{:05d}.jpg"
        with sv.ImageSink(
            target_dir_path=output_dir, image_name_pattern=image_name_pattern
        ) as sink:
            # Read frames from this specific video, not from the directory
            # that was passed in.
            for image in sv.get_video_frames_generator(
                source_path=str(video), stride=stride
            ):
                sink.save_image(image=image)