最近看到了Segment Anything Model,发现不需要配置太多的东西就能跑起来
介绍说明的网址Segment Anything Model
我已经下载了源代码和最小的模型,有需要的话:百度云链接
接下来是加载最小的那个模型,检测自己的图像
需要运行的是amg.py这个文件。我按照官方给的提示对这个文件进行了一点修改,没有修改之前不会在图像上显示检测的结果,这个修改后的文件也放在百度云链接中了
--input 检测图片的路径 --output 检测结果的输出路径 --model-type vit_b --checkpoint 权重文件的存放路径
举个例子
--input /home/image/test.png --output /home/image --model-type vit_b --checkpoint home/image/sam_vit_b_01ec64.pth
# for mask
# base = os.path.basename(t)
# base = os.path.splitext(base)[0]
# save_base = os.path.join(args.output, base)
# if output_mode == "binary_mask":
# os.makedirs(save_base, exist_ok=False)
# write_masks_to_folder(masks, save_base)
# else:
# save_file = save_base + ".json"
# with open(save_file, "w") as f:
# json.dump(masks, f)
不注释掉的话检测得到的掩码二值图像都会保存下来
下面的代码是用于处理一批图像,并保存下来
def main(args: argparse.Namespace) -> None:
print("Loading model...")
sam = sam_model_registry[args.model_type](checkpoint=args.checkpoint)
_ = sam.to(device=args.device)
# output_mode = "coco_rle" if args.convert_to_rle else "binary_mask"
output_mode = 'binary_mask'
amg_kwargs = get_amg_kwargs(args)
generator = SamAutomaticMaskGenerator(sam, output_mode=output_mode, **amg_kwargs)
floader_path = "/home/jin/图片/dete" # 这里为一批图像所在的文件夹
file_path = os.listdir(floader_path)
for im in file_path:
args.input = os.path.join(floader_path, im)
if not os.path.isdir(args.input):
targets = [args.input]
else:
targets = [
f for f in os.listdir(args.input) if not os.path.isdir(os.path.join(args.input, f))
]
targets = [os.path.join(args.input, f) for f in targets]
os.makedirs(args.output, exist_ok=True)
for t in targets:
print(f"Processing '{t}'...")
image = cv2.imread(t)
if image is None:
print(f"Could not load '{t}' as an image, skipping...")
continue
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
masks = generator.generate(image)
plt.imshow(image) # 这里自行选择加不加
for mask in masks:
show_mask(mask['segmentation'], plt.gca())
# show_box(mask['bbox'],plt.gca())
# plt.axis('off') 保存轴线或不保存
plt.savefig('/home/jin/图片/segment-anything-main/scripts/'+im) # 这里要替换为自己的路径
plt.close()
print("Done!")