Skip to content

Composed Model

Bases: DetectionBaseModel

Run inference with a detection model, then run a classification model on each detected region.

Source code in autodistill/core/composed_detection_model.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
class ComposedDetectionModel(DetectionBaseModel):
    """
    Run inference with a detection model, then run a classification model on the detected regions.

    The detection model proposes boxes and the classification model decides the
    class of each box. The composed model exposes the classification model's
    ontology as its own.
    """

    def __init__(
        self,
        detection_model,
        classification_model,
        set_of_marks=None,
        set_of_marks_annotator=DEFAULT_LABEL_ANNOTATOR,
    ):
        """
        Args:
            detection_model: Model whose ``predict(image)`` produces the
                detections to classify.
            classification_model: Model used to classify the detected regions;
                its ``ontology`` becomes this model's ontology.
            set_of_marks: When not None, ``predict`` uses set-of-marks
                prompting instead of per-region classification (this class only
                compares it against None).
            set_of_marks_annotator: Annotator that draws the numbered marks onto
                the image in set-of-marks mode.
        """
        self.detection_model = detection_model
        self.classification_model = classification_model
        self.set_of_marks = set_of_marks
        self.set_of_marks_annotator = set_of_marks_annotator
        self.ontology = self.classification_model.ontology

    def predict(self, image: str) -> sv.Detections:
        """
        Run inference with a detection model, then classify the detected regions.

        Without set-of-marks, each detected box is cropped from the image and
        passed to ``classification_model.predict``. When the classifier returns
        at least one class, that detection's ``class_id`` is overwritten with the
        classifier's top-1 class. Boxes for which the classifier returns nothing
        (or whose area is zero) keep the class id from the detection model.

        With set-of-marks enabled (``self.set_of_marks`` is not None), the
        detections are drawn onto the image with numeric labels ("0", "1", ...).
        The annotated image is then passed to
        ``classification_model.set_of_marks``, and the detections are returned
        unchanged.

        Args:
            image: Path to the image to run inference on.

        Returns:
            detections (sv.Detections): the detection model's output, with
            ``class_id`` updated in place as described above.

        Raises:
            Exception: if set-of-marks is enabled but the classification model
                has no ``set_of_marks`` method.
        """
        import os
        import tempfile

        use_set_of_marks = self.set_of_marks is not None

        # Fail fast on a misconfiguration before doing any inference or I/O.
        if use_set_of_marks and not hasattr(self.classification_model, "set_of_marks"):
            raise Exception(
                f"The set classification model does not have a set_of_marks method. Supported models: {SET_OF_MARKS_SUPPORTED_MODELS}"
            )

        # Load the pixels and close the file handle right away. Converting to
        # RGB also guarantees the intermediates can be written as JPEG, which
        # cannot store alpha or palette modes.
        with Image.open(image) as source:
            opened_image = source.convert("RGB")

        detections = self.detection_model.predict(image)

        # Write intermediates into a private temporary directory rather than a
        # shared "temp.jpeg" in the working directory. Concurrent calls then
        # cannot overwrite each other, and nothing is left behind.
        with tempfile.TemporaryDirectory() as tmp_dir:
            temp_path = os.path.join(tmp_dir, "temp.jpeg")

            if use_set_of_marks:
                labels = [f"{num}" for num in range(len(detections.xyxy))]

                annotated_frame = self.set_of_marks_annotator.annotate(
                    scene=np.array(opened_image), labels=labels, detections=detections
                )
                Image.fromarray(annotated_frame).save(temp_path)

                # NOTE(review): the value returned by set_of_marks is discarded
                # and `detections` is returned unchanged. Confirm whether the
                # result should be merged into the detections.
                self.classification_model.set_of_marks(
                    input=image, masked_input=temp_path, classes=labels, masks=detections
                )

                return detections

            for pred_idx, bbox in enumerate(detections.xyxy):
                # extract the detected region from the image
                region = opened_image.crop((bbox[0], bbox[1], bbox[2], bbox[3]))

                # a zero-area crop cannot be encoded as an image file
                if region.width == 0 or region.height == 0:
                    continue

                region.save(temp_path)

                result = self.classification_model.predict(temp_path)

                # keep the detector's class when the classifier returns nothing
                if len(result.class_id) == 0:
                    continue

                # get_top_k(1)[0] holds the top class ids; take the best one
                detections.class_id[pred_idx] = result.get_top_k(1)[0][0]

        return detections

predict(image)

Run inference with a detection model, then run a classification model on each detected region.

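Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `image` | `str` | Path to the image to run inference on. | *required* |

Note: `predict` takes no `annotator` argument. The annotator used in set-of-marks mode is configured via the constructor's `set_of_marks_annotator` argument.

Returns:

| Type | Description |
|------|-------------|
| `sv.Detections` | The detections from the detection model. Outside set-of-marks mode, each region's `class_id` is replaced by the classifier's top-1 class whenever the classifier returns a prediction. |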

Source code in autodistill/core/composed_detection_model.py
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
def predict(self, image: str) -> sv.Detections:
    """
    Run inference with a detection model, then classify the detected regions.

    Without set-of-marks, each detected box is cropped from the image and
    passed to ``classification_model.predict``. When the classifier returns at
    least one class, that detection's ``class_id`` is overwritten with the
    classifier's top-1 class. Boxes for which the classifier returns nothing
    (or whose area is zero) keep the class id from the detection model.

    With set-of-marks enabled (``self.set_of_marks`` is not None), the
    detections are drawn onto the image with numeric labels ("0", "1", ...).
    The annotated image is then passed to ``classification_model.set_of_marks``,
    and the detections are returned unchanged.

    Args:
        image: Path to the image to run inference on.

    Returns:
        detections (sv.Detections): the detection model's output, with
        ``class_id`` updated in place as described above.

    Raises:
        Exception: if set-of-marks is enabled but the classification model has
            no ``set_of_marks`` method.
    """
    import os
    import tempfile

    use_set_of_marks = self.set_of_marks is not None

    # Fail fast on a misconfiguration before doing any inference or I/O.
    if use_set_of_marks and not hasattr(self.classification_model, "set_of_marks"):
        raise Exception(
            f"The set classification model does not have a set_of_marks method. Supported models: {SET_OF_MARKS_SUPPORTED_MODELS}"
        )

    # Load the pixels and close the file handle right away. Converting to RGB
    # also guarantees the intermediates can be written as JPEG, which cannot
    # store alpha or palette modes.
    with Image.open(image) as source:
        opened_image = source.convert("RGB")

    detections = self.detection_model.predict(image)

    # Write intermediates into a private temporary directory rather than a
    # shared "temp.jpeg" in the working directory. Concurrent calls then cannot
    # overwrite each other, and nothing is left behind.
    with tempfile.TemporaryDirectory() as tmp_dir:
        temp_path = os.path.join(tmp_dir, "temp.jpeg")

        if use_set_of_marks:
            labels = [f"{num}" for num in range(len(detections.xyxy))]

            annotated_frame = self.set_of_marks_annotator.annotate(
                scene=np.array(opened_image), labels=labels, detections=detections
            )
            Image.fromarray(annotated_frame).save(temp_path)

            # NOTE(review): the value returned by set_of_marks is discarded and
            # `detections` is returned unchanged. Confirm whether the result
            # should be merged into the detections.
            self.classification_model.set_of_marks(
                input=image, masked_input=temp_path, classes=labels, masks=detections
            )

            return detections

        for pred_idx, bbox in enumerate(detections.xyxy):
            # extract the detected region from the image
            region = opened_image.crop((bbox[0], bbox[1], bbox[2], bbox[3]))

            # a zero-area crop cannot be encoded as an image file
            if region.width == 0 or region.height == 0:
                continue

            region.save(temp_path)

            result = self.classification_model.predict(temp_path)

            # keep the detector's class when the classifier returns nothing
            if len(result.class_id) == 0:
                continue

            # get_top_k(1)[0] holds the top class ids; take the best one
            detections.class_id[pred_idx] = result.get_top_k(1)[0][0]

    return detections