Combine Models

You can combine detection, segmentation, and classification models to leverage the strengths of each model.

For example, consider a scenario where you want to build a logo detection model that identifies popular logos. You could use a detection model (e.g., Grounding DINO) to find logos, then a classification model (e.g., CLIP) to distinguish between brands such as Microsoft and Apple.

To combine models, you need to choose:

  1. A detection or a segmentation model, and
  2. A classification model.

Let's walk through an example that combines Grounded SAM (Grounding DINO plus SAM) for detection with CLIP for logo classification.

import cv2
import supervision as sv

from autodistill.core.composed_detection_model import ComposedDetectionModel
from autodistill.detection import CaptionOntology
from autodistill_clip import CLIP
from autodistill_grounded_sam import GroundedSAM

classes = ["McDonalds", "Burger King"]


SAMCLIP = ComposedDetectionModel(
    detection_model=GroundedSAM(
        CaptionOntology({"logo": "logo"})
    ),
    classification_model=CLIP(
        CaptionOntology({k: k for k in classes})
    )
)

IMAGE = "logo.jpg"

results = SAMCLIP.predict(IMAGE)

image = cv2.imread(IMAGE)

annotator = sv.MaskAnnotator()
label_annotator = sv.LabelAnnotator()

labels = [
    f"{classes[class_id]} {confidence:0.2f}"
    for class_id, confidence in zip(results.class_id, results.confidence)
]

annotated_frame = annotator.annotate(
    scene=image.copy(), detections=results
)
annotated_frame = label_annotator.annotate(
    scene=annotated_frame, labels=labels, detections=results
)

sv.plot_image(annotated_frame, size=(8, 8))

Here are the results:

SAMCLIP Example
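
If you are running in a non-interactive environment, you can save the annotated image instead of plotting it, and print the classifier-assigned labels directly. A small sketch using the same results object from the example above:

cv2.imwrite("annotated_logo.jpg", annotated_frame)

for class_id, confidence in zip(results.class_id, results.confidence):
    print(f"{classes[class_id]}: {confidence:.2f}")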

Code Reference

Bases: DetectionBaseModel

Run inference with a detection model then run inference with a classification model on the detected regions.
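
Conceptually, the class runs a two-stage loop: detect regions, crop each one, classify the crop, and overwrite the detection's class_id with the classifier's top prediction. A simplified sketch of that path (the full source below also includes an optional set-of-marks branch):

from PIL import Image

def classify_regions(detection_model, classification_model, image_path):
    detections = detection_model.predict(image_path)
    image = Image.open(image_path)

    for i, (x0, y0, x1, y1) in enumerate(detections.xyxy):
        # crop the detected region and save it, since autodistill
        # models take file paths as input
        image.crop((x0, y0, x1, y1)).save("temp.jpeg")

        result = classification_model.predict("temp.jpeg")

        # skip regions the classifier could not label
        if len(result.class_id) == 0:
            continue

        # overwrite the detection's class id with the top-1 prediction
        detections.class_id[i] = result.get_top_k(1)[0][0]

    return detections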

Source code in autodistill/core/composed_detection_model.py
class ComposedDetectionModel(DetectionBaseModel):
    """
    Run inference with a detection model then run inference with a classification model on the detected regions.
    """

    def __init__(
        self,
        detection_model,
        classification_model,
        set_of_marks=None,
        set_of_marks_annotator=DEFAULT_LABEL_ANNOTATOR,
    ):
        self.detection_model = detection_model
        self.classification_model = classification_model
        self.set_of_marks = set_of_marks
        self.set_of_marks_annotator = set_of_marks_annotator
        self.ontology = self.classification_model.ontology

    def predict(self, image: str) -> sv.Detections:
        """
        Run inference with a detection model then run inference with a classification model on the detected regions.

        Args:
            image: The image to run inference on
            annotator: The annotator to use to annotate the image

        Returns:
            detections (sv.Detections)
        """
        opened_image = Image.open(image)

        detections = self.detection_model.predict(image)

        if self.set_of_marks is not None:
            labels = [f"{num}" for num in range(len(detections.xyxy))]

            opened_image = np.array(opened_image)

            annotated_frame = self.set_of_marks_annotator.annotate(
                scene=opened_image, labels=labels, detections=detections
            )

            opened_image = Image.fromarray(annotated_frame)

            opened_image.save("temp.jpeg")

            if not hasattr(self.classification_model, "set_of_marks"):
                raise Exception(
                    f"The set classification model does not have a set_of_marks method. Supported models: {SET_OF_MARKS_SUPPORTED_MODELS}"
                )

            result = self.classification_model.set_of_marks(
                input=image, masked_input="temp.jpeg", classes=labels, masks=detections
            )

            return detections

        for pred_idx, bbox in enumerate(detections.xyxy):
            # extract region from image
            region = opened_image.crop((bbox[0], bbox[1], bbox[2], bbox[3]))

            # save as tempfile
            region.save("temp.jpeg")

            result = self.classification_model.predict("temp.jpeg")

            if len(result.class_id) == 0:
                continue

            result = result.get_top_k(1)[0][0]

            detections.class_id[pred_idx] = result

        return detections

predict(image)

Run inference with a detection model then run inference with a classification model on the detected regions.

Parameters:

Name        Type   Description                                   Default
image       str    The image to run inference on                 required
annotator          The annotator to use to annotate the image    required

Returns:

Type            Description
sv.Detections   detections (sv.Detections)
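
Putting it together, a minimal usage sketch that reuses the SAMCLIP model constructed in the example above:

detections = SAMCLIP.predict("logo.jpg")

print(detections.xyxy)      # bounding boxes from the detection model
print(detections.class_id)  # class ids reassigned by the classification model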
