Segment Anything Model(SAM)是 Facebook 的一个 AI 模型,旨在推广分割技术。在我们之前的文章中,我们讨论了 SAM 的一般信息,现在让我们深入了解其技术细节。SAM 模型的结构如下图所示,图像经过编码器得到其嵌入,并且任何掩码都可以被实现。提示可以以文本、边界框或自由形式的点的形式出现。我们对提示进行编码,并将其与图像嵌入一起传递给解码器,该解码器生成我们的掩码。
SAM 最有趣的特性之一是其轻量级的编码器和解码器,可以实现实时性能。您可以使用在 GitHub 上可用的打包版本在 Python 中使用 SAM:https://github.com/kadirnar/segment-anything-video
然而,如果您在使用过程中遇到问题,您可以使用原始 GitHub 页面上提供的 Colab 文件。以下是您可以开始使用的步骤:
using_colab = True if using_colab: import torch import torchvision print("PyTorch version:", torch.__version__) print("Torchvision version:", torchvision.__version__) print("CUDA is available:", torch.cuda.is_available()) import sys !{sys.executable} -m pip install opencv-python matplotlib !{sys.executable} -m pip install 'git+https://github.com/facebookresearch/segment-anything.git' !wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
首先,导入 Torch 和 Torchvision,这是项目所必需的,然后使用 pip 安装 Segment Anything。
从https://github.com/facebookresearch/segment-anything#model-checkpoints 下载模型检查点,稍后我们将使用它。
接下来,创建一个图像目录,您可以将测试图像放在其中。您还可以通过替换以下命令中的 URL 来使用自己的图像:
!mkdir images !wget -O images/image.jpg https://live.staticflickr.com/65535/49894878561_14a39c6c35_b.jpg
一旦您拥有了图像,您可以导入必要的包,包括 numpy、Torch、Matplotlib 和 OpenCV。
import numpy as np import torch import matplotlib.pyplot as plt import cv2
您可以使用以下函数绘制注释:
def show_anns(anns): if len(anns) == 0: return sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True) ax = plt.gca() ax.set_autoscale_on(False) polygons = [] color = [] for ann in sorted_anns: m = ann['segmentation'] img = np.ones((m.shape[0], m.shape[1], 3)) color_mask = np.random.random((1, 3)).tolist()[0] for i in range(3): img[:,:,i] = color_mask[i] ax.imshow(np.dstack((img, m*0.35)))
现在,使用 OpenCV 读取图像并将通道从 BGR 更改为 RGB。然后,使用 Matplotlib 显示图像。
image = cv2.imread('images/image.jpg') image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) plt.figure(figsize=(20,20)) plt.imshow(image) plt.axis('off') plt.show()
要创建您的掩码生成器,您需要定义您的 SAM 模型并使用 SamAutomaticMaskGenerator:
import sys sys.path.append("..") from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor sam_checkpoint = "sam_vit_h_4b8939.pth" model_type = "vit_h" device = "cuda" sam = sam_model_registry[model_type](checkpoint=sam_checkpoint) sam.to(device=device) mask_generator = SamAutomaticMaskGenerator(sam)
我们的 SAM 模型注册表需要检查点和模型类型以提供模型,记得将您的运行时设置为 GPU。SamAutomaticMaskGenerator 使用您的模型来制作您的掩码生成器。您所需要做的就是将您的输入传递给这个函数以获取您的掩码。
masks = mask_generator.generate(image)
掩码对象包含关于区域和稳定性分数的多个信息,稍后会将标签添加到此掩码对象中。让我们来看一下输出:
plt.figure(figsize=(20,20)) plt.imshow(image) show_anns(masks) plt.axis('off') plt.show()
您还可以通过更改以下变量来调整遮罩生成器的参数:
mask_generator_2 = SamAutomaticMaskGenerator( model=sam, points_per_side=32, pred_iou_thresh=0.86, stability_score_thresh=0.92, crop_n_layers=1, crop_n_points_downscale_factor=2, min_mask_region_area=100, # Requires open-cv to run post-processing ) masks2 = mask_generator_2.generate(image) plt.figure(figsize=(20,20)) plt.imshow(image) show_anns(masks2) plt.axis('off') plt.show()
不同参数的另一个输出
您可以在这里找到源代码:
https://colab.research.google.com/github/facebookresearch/segment-anything/blob/main/notebooks/automatic_mask_generator_example.ipynb