SAM (Segment Anything Model) is one of the first attempts, by Meta, at a foundation model for computer vision tasks.
Foundation Model
What makes a model foundational? The defining feature is its ability to perform well on tasks and data it has never seen before, i.e. zero-shot or few-shot learning.
In LLMs this shows up as the ability to classify, translate, summarise, etc., despite being trained on only one task: predicting the next token given a sequence of input tokens.
Pre-training Task
What would be an analogous pre-training task for foundational vision models?
Inspired by LLMs, which allow zero-shot and few-shot learning on new datasets and tasks through prompting, Meta defines promptable segmentation as the pre-training task. Given a prompt, the model must produce a valid segmentation mask. The prompt can be a point, a bounding box, a mask or even text. The requirement of a valid segmentation even for ambiguous prompts (a point on a shirt could refer to either the shirt or the person wearing it) is believed to force the model to generalise.
Model Architecture
A model suitable for this has to combine an image and a prompt and generate a mask. This gives rise to 3 components: an image encoder, a prompt encoder and a mask decoder. Since in an interactive scenario there can be many prompts on the same image, the prompt encoder and mask decoder need to be fast and light, whereas the image encoder is allowed to be heavy because it runs only once per image.
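The "heavy encoder once, light decoder many times" split can be sketched as below. This is a toy illustration, not SAM's actual API: the class `PromptableSegmenter`, its methods and the placeholder "embedding" are all hypothetical names made up for this example.

```python
class PromptableSegmenter:
    """Toy stand-in for the three-component layout (all names hypothetical)."""

    def __init__(self):
        self.image_encoder_calls = 0
        self._embedding = None

    def set_image(self, image):
        # Heavy step: run the image encoder once and cache the result.
        self.image_encoder_calls += 1
        self._embedding = sum(sum(row) for row in image)  # placeholder embedding

    def predict(self, point):
        # Light step: encode the prompt and decode from the cached embedding.
        if self._embedding is None:
            raise RuntimeError("call set_image() before predict()")
        return {"prompt": point, "embedding": self._embedding}


segmenter = PromptableSegmenter()
segmenter.set_image([[0, 1], [1, 0]])

# Three prompts on the same image reuse the single cached embedding.
results = [segmenter.predict(p) for p in [(0, 0), (1, 1), (0, 1)]]
print(segmenter.image_encoder_calls, len(results))
```

However many prompts arrive, the expensive step runs once per image, which is what makes interactive use practical.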
Image Encoder used is an MAE (Masked Autoencoder) pre-trained ViT that takes a 1024 x 1024 x 3 image and outputs a 64 x 64 x 256 embedding.
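The shapes fit together if the ViT splits the image into 16 x 16 pixel patches. The patch size is an assumption here (it is not stated above); the input size and the 256 output channels come from the text.

```python
image_size = 1024  # input height and width, from the text
patch_size = 16    # assumption: ViT patch size
embed_dim = 256    # output channels, from the text

grid = image_size // patch_size        # patches per side
embedding_shape = (grid, grid, embed_dim)
downscale = image_size // grid         # pixels covered by one embedding cell per side

print(embedding_shape, downscale)      # (64, 64, 256) 16
```

So each cell of the 64 x 64 embedding summarises a 16 x 16 pixel region of the input image.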
Prompt Encoder outputs 256-dim embeddings for point, bounding box and text prompts.
- A point is represented as the sum of a positional encoding of the point’s location and a learned foreground/background embedding.
- A box is represented as a pair of embeddings: the positional encodings of its 2 corner points, each summed with a learned embedding representing “top-left” or “bottom-right”.
- Text is represented by the embedding produced by the text encoder from CLIP.
A prompt can also be a mask, in which case it is downscaled to the same spatial dimensions as the image embedding using convolutional layers and added to it element-wise.
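The point and box encodings described above can be sketched in plain Python. This is a rough illustration under stated assumptions: the positional encoding is a random Fourier feature encoding (one possible choice), the "learned" embeddings are random stand-ins, and every name in the block is hypothetical.

```python
import math
import random

EMBED_DIM = 256
rng = random.Random(0)

# Assumption: fixed random frequencies for a Fourier-feature positional encoding.
FREQS = [(rng.gauss(0, 1), rng.gauss(0, 1)) for _ in range(EMBED_DIM // 2)]

# Random stand-ins for embeddings that would be learned during training.
LEARNED = {
    name: [rng.gauss(0, 0.02) for _ in range(EMBED_DIM)]
    for name in ("foreground", "background", "top_left", "bottom_right")
}


def positional_encoding(x, y, image_size=1024):
    u, v = x / image_size, y / image_size  # normalise to [0, 1]
    angles = [2 * math.pi * (fx * u + fy * v) for fx, fy in FREQS]
    return [math.sin(a) for a in angles] + [math.cos(a) for a in angles]


def add(a, b):
    return [p + q for p, q in zip(a, b)]


def encode_point(x, y, is_foreground=True):
    kind = LEARNED["foreground" if is_foreground else "background"]
    return add(positional_encoding(x, y), kind)


def encode_box(x0, y0, x1, y1):
    # A box becomes two tokens, one per corner.
    return [
        add(positional_encoding(x0, y0), LEARNED["top_left"]),
        add(positional_encoding(x1, y1), LEARNED["bottom_right"]),
    ]


point_token = encode_point(100, 200)
box_tokens = encode_box(50, 60, 300, 400)
print(len(point_token), len(box_tokens))
```

Every prompt ends up as one or more 256-dim tokens, so the mask decoder can treat points and boxes uniformly.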
Mask Decoder is built with self-attention (between prompt embeddings), cross-attention blocks (between image and prompt embeddings), some MLP layers and transposed convolution layers.
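The difference between the two attention types is only in where the keys and values come from. Below is a bare-bones sketch of single-head scaled dot-product attention; it deliberately leaves out the learned projections, multiple heads, MLPs and upsampling of a real decoder, and the tiny token lists are made-up values.

```python
import math


def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def attention(queries, keys, values):
    """Single-head scaled dot-product attention without learned projections."""
    scale = math.sqrt(len(keys[0]))
    out = []
    for q in queries:
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / scale for k in keys]
        weights = softmax(scores)
        out.append([
            sum(w * v[d] for w, v in zip(weights, values))
            for d in range(len(values[0]))
        ])
    return out


prompt_tokens = [[1.0, 0.0], [0.0, 1.0]]
image_tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]

# Self-attention: prompt tokens attend to each other.
self_attended = attention(prompt_tokens, prompt_tokens, prompt_tokens)

# Cross-attention: prompt tokens query the image tokens.
cross_attended = attention(prompt_tokens, image_tokens, image_tokens)
```

In both cases the output has one row per query token, so the prompt tokens keep their count while picking up information from the image.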
A linear combination of Dice loss and Focal loss was used to train the model.
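A minimal sketch of such a combined loss on flattened masks follows. The `gamma` and `alpha` defaults are the usual focal-loss values, and the 20:1 weighting of focal to Dice is an assumption on my part; the text above only says the combination is linear.

```python
import math


def dice_loss(probs, targets, eps=1.0):
    # 1 - Dice coefficient, with eps smoothing to avoid division by zero.
    intersection = sum(p * t for p, t in zip(probs, targets))
    total = sum(probs) + sum(targets)
    return 1.0 - (2.0 * intersection + eps) / (total + eps)


def focal_loss(probs, targets, gamma=2.0, alpha=0.25):
    # Cross-entropy down-weighted for pixels that are already well classified.
    losses = []
    for p, t in zip(probs, targets):
        p = min(max(p, 1e-7), 1 - 1e-7)  # clamp for numerical safety
        p_t = p if t == 1 else 1 - p
        alpha_t = alpha if t == 1 else 1 - alpha
        losses.append(-alpha_t * (1 - p_t) ** gamma * math.log(p_t))
    return sum(losses) / len(losses)


def mask_loss(probs, targets, focal_weight=20.0, dice_weight=1.0):
    # Assumption: weights are illustrative, not confirmed by the text above.
    return (focal_weight * focal_loss(probs, targets)
            + dice_weight * dice_loss(probs, targets))


targets = [1, 0, 1, 0]
good = mask_loss([0.9, 0.1, 0.8, 0.2], targets)
bad = mask_loss([0.2, 0.9, 0.3, 0.8], targets)
print(good < bad)
```

Dice looks at overlap of the whole mask while Focal works per pixel, so the two terms penalise different kinds of mistakes.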
Training Data
For LLMs, the huge amount of text available online and the auto-regressive (self-supervised) pre-training task meant there was no lack of training data.
No comparably large collection of segmentation masks existed, so Meta got creative and built a data engine. The engine works in 3 gears.
- Model provides an initial mask prediction and annotators refine it.
- Model predicts refined masks for most objects. Annotators then add new masks for difficult objects that the model missed.
- Model predicts refined masks given a uniform grid of points as prompts.
The engine changes gears sequentially as the model gets better, and the data generated in each gear is fed back into the model for training. A total of 1.1B masks have been generated this way and released as the SA-1B dataset.
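The uniform grid of point prompts used in the last gear is easy to picture in code. This is a small sketch, assuming 32 points per side on a 1024-pixel image with each point at the centre of its grid cell; the function name and both parameters are illustrative.

```python
def point_grid(points_per_side, image_size=1024):
    """Return (x, y) prompts at the centres of a regular grid of cells."""
    step = image_size / points_per_side
    coords = [(i + 0.5) * step for i in range(points_per_side)]
    return [(x, y) for y in coords for x in coords]


prompts = point_grid(32)
print(len(prompts), prompts[0], prompts[-1])  # 1024 (16.0, 16.0) (1008.0, 1008.0)
```

Each of these points is fed to the model as a separate prompt, so a single image yields many candidate masks without any annotator input.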
Handling Ambiguity
Segmentation ambiguity occurs with nested objects (for a person wearing a shirt, the shirt mask is nested within the person mask). A point prompt on a nested object could refer to either the nested object (the shirt) or the enclosing object (the person).
To help deal with this, the model predicts 3 masks for every prompt, since nested masks are usually at most 3 levels deep (whole, part and subpart). During training, only the mask with the lowest loss is backpropagated. The mask decoder also predicts a confidence score (an estimated IoU) for each mask, which is used to rank the masks.
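A minimal sketch of how the 3 candidate masks might be handled, assuming (as I read the SAM paper) that training keeps only the lowest-loss candidate and that the predicted confidence is what picks a mask at inference. The function names and the example numbers are made up for illustration.

```python
def training_loss(per_mask_losses):
    # Only the best-matching candidate contributes to the gradient.
    return min(per_mask_losses)


def best_mask(masks, confidences):
    # At inference there is no ground truth, so rank by predicted confidence.
    index = max(range(len(masks)), key=lambda i: confidences[i])
    return masks[index]


loss = training_loss([0.9, 0.2, 0.5])
chosen = best_mask(["subpart", "part", "whole"], [0.6, 0.8, 0.7])
print(loss, chosen)  # 0.2 part
```

Taking the minimum means the model is never punished for predicting the shirt when the annotation was the person, as long as one of its 3 outputs matches.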
Is it actually foundational?
Meta evaluated the model zero-shot on several downstream tasks, with competitive results.
- Edge detection: a Sobel filter was applied to the predicted segmentation masks to produce edges (BSDS500 dataset).
- Object proposals: the predicted segmentation masks were used as object proposals (LVIS dataset).
- Instance segmentation: an object detector provides bounding boxes, which are used as prompts to produce instance masks (COCO and LVIS datasets).
- Text to mask: SAM is trained with CLIP image embeddings as prompts so that, at inference, CLIP text embeddings can be used as prompts to produce segmentation masks.
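For the edge-detection item in the list above, the Sobel step on its own can be sketched in plain Python. This shows only the filter applied to a toy mask; whatever else Meta's evaluation pipeline does around it is not covered, and the function name is hypothetical.

```python
def sobel_edges(mask):
    """Sobel gradient magnitude of a 2D list; border pixels are left at 0."""
    kx = [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]
    ky = [[-1, -2, -1], [0, 0, 0], [1, 2, 1]]
    h, w = len(mask), len(mask[0])
    edges = [[0.0] * w for _ in range(h)]
    for y in range(1, h - 1):
        for x in range(1, w - 1):
            gx = sum(kx[j][i] * mask[y + j - 1][x + i - 1]
                     for j in range(3) for i in range(3))
            gy = sum(ky[j][i] * mask[y + j - 1][x + i - 1]
                     for j in range(3) for i in range(3))
            edges[y][x] = (gx * gx + gy * gy) ** 0.5
    return edges


# Toy mask: background on the left half, object on the right half.
mask = [[0, 0, 0, 1, 1, 1] for _ in range(6)]
edges = sobel_edges(mask)
print(edges[2])  # [0.0, 0.0, 4.0, 4.0, 0.0, 0.0]
```

The response is non-zero only along the object boundary, which is why a good segmentation mask doubles as an edge detector.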
Meta has suggested that there’s room for improvement in how we pose the pre-training task for foundation models but this is a great start.