from typing import Optional, Tuple

import numpy as np
import torch

from segment_anything import SamPredictor
from segment_anything.modeling import Sam


class SamPredictorHQ(SamPredictor):

    def __init__(
        self,
        sam_model: Sam,
        sam_is_hq: bool = False,
    ) -> None:
        """
        Uses SAM to calculate the image embedding for an image, and then
        allows repeated, efficient mask prediction given prompts.

        Arguments:
          sam_model (Sam): The model to use for mask prediction.
          sam_is_hq (bool): Whether sam_model is an HQ-SAM model, whose image
            encoder also returns intermediate features for the HQ mask decoder.
        """
        super().__init__(sam_model=sam_model)
        self.is_hq = sam_is_hq
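
    # A minimal construction sketch, assuming the standard segment_anything
    # registry; the checkpoint path is hypothetical:
    #
    #   from segment_anything import sam_model_registry
    #   sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h.pth")
    #   predictor = SamPredictorHQ(sam, sam_is_hq=False)
    #
    # For an HQ-SAM checkpoint, build the model with the registry from the
    # sam-hq package instead and pass sam_is_hq=True.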
    def set_torch_image(
        self,
        transformed_image: torch.Tensor,
        original_image_size: Tuple[int, ...],
    ) -> None:
        """
        Calculates the image embeddings for the provided image, allowing
        masks to be predicted with the 'predict' method. Expects the input
        image to be already transformed to the format expected by the model.

        Arguments:
          transformed_image (torch.Tensor): The input image, with shape
            1x3xHxW, which has been transformed with ResizeLongestSide.
          original_image_size (tuple(int, int)): The size of the image
            before transformation, in (H, W) format.
        """
        assert (
            len(transformed_image.shape) == 4
            and transformed_image.shape[1] == 3
            and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size
        ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}."
        self.reset_image()

        self.original_size = original_image_size
        self.input_size = tuple(transformed_image.shape[-2:])
        input_image = self.model.preprocess(transformed_image)
        if self.is_hq:
            # The HQ-SAM image encoder also returns intermediate ViT features,
            # which the HQ mask decoder consumes as interm_embeddings.
            self.features, self.interm_features = self.model.image_encoder(input_image)
        else:
            self.features = self.model.image_encoder(input_image)
        self.is_image_set = True
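
    # Sketch of preparing the input this method expects, assuming an HxWx3
    # uint8 RGB array `image` and the `predictor` from the sketch above
    # (`device` is wherever the model lives):
    #
    #   from segment_anything.utils.transforms import ResizeLongestSide
    #   transform = ResizeLongestSide(sam.image_encoder.img_size)
    #   resized = transform.apply_image(image)  # np.ndarray, HxWx3
    #   tensor = torch.as_tensor(resized, device=device)
    #   tensor = tensor.permute(2, 0, 1).contiguous()[None, :, :, :]  # 1x3xHxW
    #   predictor.set_torch_image(tensor, image.shape[:2])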
    def predict_torch(
        self,
        point_coords: Optional[torch.Tensor],
        point_labels: Optional[torch.Tensor],
        boxes: Optional[torch.Tensor] = None,
        mask_input: Optional[torch.Tensor] = None,
        multimask_output: bool = True,
        return_logits: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Predict masks for the given input prompts, using the currently set image.
        Input prompts are batched torch tensors and are expected to already be
        transformed to the input frame using ResizeLongestSide.

        Arguments:
          point_coords (torch.Tensor or None): A BxNx2 array of point prompts to
            the model. Each point is in (X,Y) in pixels.
          point_labels (torch.Tensor or None): A BxN array of labels for the
            point prompts. 1 indicates a foreground point and 0 indicates a
            background point.
          boxes (torch.Tensor or None): A Bx4 array giving a box prompt to the
            model, in XYXY format.
          mask_input (torch.Tensor or None): A low resolution mask input to the
            model, typically coming from a previous prediction iteration. Has
            form Bx1xHxW, where for SAM, H=W=256. Masks returned by a previous
            iteration of the predict method do not need further transformation.
          multimask_output (bool): If true, the model will return three masks.
            For ambiguous input prompts (such as a single click), this will often
            produce better masks than a single prediction. If only a single
            mask is needed, the model's predicted quality score can be used
            to select the best mask. For non-ambiguous prompts, such as multiple
            input prompts, multimask_output=False can give better results.
          return_logits (bool): If true, returns un-thresholded mask logits
            instead of a binary mask.

        Returns:
          (torch.Tensor): The output masks in BxCxHxW format, where C is the
            number of masks, and (H, W) is the original image size.
          (torch.Tensor): An array of shape BxC containing the model's
            predictions for the quality of each mask.
          (torch.Tensor): An array of shape BxCxHxW, where C is the number
            of masks and H=W=256. These low res logits can be passed to
            a subsequent iteration as mask input.
        """
        if not self.is_image_set:
            raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")

        if point_coords is not None:
            points = (point_coords, point_labels)
        else:
            points = None

        # Embed prompts
        sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
            points=points,
            boxes=boxes,
            masks=mask_input,
        )

        # Predict masks
        if self.is_hq:
            # The HQ decoder additionally takes the intermediate encoder
            # features; hq_token_only=False combines the HQ mask output with
            # the standard SAM mask rather than using the HQ token alone.
            low_res_masks, iou_predictions = self.model.mask_decoder(
                image_embeddings=self.features,
                image_pe=self.model.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=multimask_output,
                hq_token_only=False,
                interm_embeddings=self.interm_features,
            )
        else:
            low_res_masks, iou_predictions = self.model.mask_decoder(
                image_embeddings=self.features,
                image_pe=self.model.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=multimask_output,
            )

        # Upscale the masks to the original image resolution
        masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size)

        if not return_logits:
            masks = masks > self.model.mask_threshold

        return masks, iou_predictions, low_res_masks
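
# A usage sketch for predict_torch with a single box prompt, assuming the
# predictor, transform, image, and device from the sketches above; the box
# coordinates are hypothetical and given in original-image pixels:
#
#   box = torch.tensor([[50.0, 50.0, 200.0, 200.0]], device=device)  # Bx4, XYXY
#   box = transform.apply_boxes_torch(box, image.shape[:2])
#   masks, scores, low_res_logits = predictor.predict_torch(
#       point_coords=None,
#       point_labels=None,
#       boxes=box,
#       multimask_output=False,
#   )
#   # masks: Bx1xHxW boolean tensor at the original image size.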