git clone https://github.com/IDEA-Research/Grounded-Segment-Anything.git
conda create -n env_grounded_segment_anything python==3.8.10
conda activate env_grounded_segment_anything
pip install torch==1.10.0+cu113 torchvision==0.11.0+cu113 torchaudio==0.10.0 -f https://download.pytorch.org/whl/torch_stable.html
Install the remaining dependencies listed in requirements.txt:
pip install -r requirements.txt
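Before opening the project, a quick sanity check (a minimal sketch, not part of the repo) confirms that the environment picked up the expected builds:

import torch
import torchvision

print(torch.__version__)          # expect 1.10.0+cu113
print(torchvision.__version__)    # expect 0.11.0+cu113
print(torch.cuda.is_available())  # False is fine here; the demos below run on cpu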
Open the project in PyCharm and run the grounding_dino_demo.py file.
Next, run the grounded_sam_demo.py file. Since no GPU is used here, the device argument keeps its default value of cpu. Add the following parameters:
--config GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py
--grounded_checkpoint groundingdino_swint_ogc.pth
--sam_checkpoint sam_vit_h_4b8939.pth
--input_image assets/demo1.jpg
--output_dir "outputs"
--box_threshold 0.3
--text_threshold 0.25
--text_prompt "bear"
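For reference, the two thresholds are applied to GroundingDINO's raw predictions roughly as sketched below; the function and variable names are illustrative only, not the demo's exact code:

import torch

def filter_groundingdino_output(logits, boxes, box_threshold=0.3, text_threshold=0.25):
    # logits: (num_queries, num_text_tokens) sigmoid scores; boxes: (num_queries, 4)
    scores = logits.max(dim=1).values    # best token score for each predicted box
    keep = scores > box_threshold        # --box_threshold drops low-confidence boxes
    kept_logits, kept_boxes = logits[keep], boxes[keep]
    # --text_threshold decides which text tokens end up in each surviving box's label
    token_masks = [token_scores > text_threshold for token_scores in kept_logits]
    return kept_boxes, scores[keep], token_masks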
Run the grounded_sam_simple_demo.py file.
Then run the grounded_sam_inpainting_demo.py file, which inpaints (repairs) part of the image. Add the following parameters:
--config GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py
--grounded_checkpoint groundingdino_swint_ogc.pth
--sam_checkpoint sam_vit_h_4b8939.pth
--input_image assets/inpaint_demo.jpg
--output_dir "outputs"
--box_threshold 0.3
--text_threshold 0.25
--det_prompt "bench"
--inpaint_prompt "A sofa, high quality, detailed"
Fix: download the file manually.
Re-run: an error is still reported.
Result:
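For context on what these parameters drive: the inpainting step fills the region detected for --det_prompt with content generated from --inpaint_prompt, typically through a diffusers Stable Diffusion inpainting pipeline whose weights are fetched from the Hugging Face Hub on first run, which is presumably the download that had to be done manually above. A minimal sketch of that step; the checkpoint name and the file paths are assumptions, not necessarily what the demo uses:

from diffusers import StableDiffusionInpaintPipeline
from PIL import Image

pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-inpainting"  # assumed checkpoint
)
image = Image.open("assets/inpaint_demo.jpg").convert("RGB").resize((512, 512))
mask = Image.open("outputs/mask.jpg").convert("L").resize((512, 512))  # hypothetical mask of the detected "bench"
result = pipe(prompt="A sofa, high quality, detailed", image=image, mask_image=mask).images[0]
result.save("outputs/inpainting_result.jpg")  # hypothetical output path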
Run the automatic_label_ram_demo.py file and add the following parameters:
--ram_checkpoint ram_swin_large_14m.pth
--grounded_checkpoint groundingdino_swint_ogc.pth
--sam_checkpoint sam_vit_h_4b8939.pth
--input_image assets/demo9.jpg
--output_dir "outputs"
--box_threshold 0.25
--text_threshold 0.2
--iou_threshold 0.5
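The --iou_threshold above controls the non-maximum suppression step (torchvision.ops.nms) that merges overlapping detections; a small self-contained example of its effect:

import torch
import torchvision

boxes = torch.tensor([[0., 0., 100., 100.],    # (x1, y1, x2, y2)
                      [10., 10., 110., 110.],  # overlaps the first box with IoU ~0.68
                      [200., 200., 300., 300.]])
scores = torch.tensor([0.9, 0.8, 0.7])
keep = torchvision.ops.nms(boxes, scores, iou_threshold=0.5)
print(keep)  # tensor([0, 2]): the lower-scoring overlapping box is suppressed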
Run the automatic_label_demo.py file, which automatically labels the image, and add the following parameters:
--config GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py
--grounded_checkpoint groundingdino_swint_ogc.pth
--sam_checkpoint sam_vit_h_4b8939.pth
--input_image assets/demo9.jpg
--output_dir "outputs"
--box_threshold 0.25
--text_threshold 0.2
--iou_threshold 0.5
Cause of the first failure: the model files fail to download automatically, so they have to be downloaded manually.
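Assuming the failed download is the BLIP captioning model from the Hugging Face Hub (the modified script below loads it from a local config_data/ folder), one way to fetch it manually is a sketch like this with huggingface_hub:

from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="Salesforce/blip-image-captioning-large",      # assumed source repo
    local_dir="config_data/blip-image-captioning-large",   # local path the modified script expects
)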
Re-running then reports: Resource punkt not found. Please use the NLTK Downloader to obtain the resource.
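The missing NLTK data can be fetched once from inside the same environment; if further resources are reported missing later, they can be downloaded the same way:

import nltk

nltk.download('punkt')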
Result:
To label a whole folder of images in one go, modify the if __name__ == "__main__": block of the automatic_label_demo.py file as follows:

if __name__ == "__main__":
    root_path = ''  # root directory
    images_name = 'images'  # name of the folder that holds the images to label
    images_path = os.path.join(root_path, images_name)
    images_outputs_path = os.path.join(root_path, 'grounded_segment_anything_images')
    output_json = os.path.join(images_outputs_path, 'json')
    output_orig = os.path.join(images_outputs_path, 'orig')
    output_mask = os.path.join(images_outputs_path, 'mask')
    output_automatic_label = os.path.join(images_outputs_path, 'automatic_label')
    for i in [output_json, output_mask, output_orig, output_automatic_label]:
        os.makedirs(i, exist_ok=True)
    images_list = os.listdir(images_path)

    parser = argparse.ArgumentParser("Grounded-Segment-Anything Demo", add_help=True)
    parser.add_argument("--config", type=str,
                        default='GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py',
                        help="path to config file")
    parser.add_argument("--grounded_checkpoint", type=str, default='groundingdino_swint_ogc.pth',
                        help="path to checkpoint file")
    parser.add_argument("--sam_checkpoint", type=str, default='sam_vit_h_4b8939.pth',
                        help="path to checkpoint file")
    parser.add_argument("--split", default=",", type=str, help="split for text prompt")
    parser.add_argument("--openai_key", type=str, help="key for chatgpt")
    parser.add_argument("--openai_proxy", default=None, type=str, help="proxy for chatgpt")
    parser.add_argument("--box_threshold", type=float, default=0.25, help="box threshold")
    parser.add_argument("--text_threshold", type=float, default=0.2, help="text threshold")
    parser.add_argument("--iou_threshold", type=float, default=0.5, help="iou threshold")
    parser.add_argument("--device", type=str, default="cpu", help="running on cpu only!, default=False")
    args = parser.parse_args()

    # cfg
    config_file = args.config  # change the path of the model config file
    grounded_checkpoint = args.grounded_checkpoint  # change the path of the model
    sam_checkpoint = args.sam_checkpoint
    # image_path = args.input_image
    split = args.split
    openai_key = args.openai_key
    openai_proxy = args.openai_proxy
    box_threshold = args.box_threshold
    text_threshold = args.text_threshold
    iou_threshold = args.iou_threshold
    device = args.device

    openai.api_key = openai_key
    if openai_proxy:
        openai.proxy = {"http": openai_proxy, "https": openai_proxy}

    # load model
    model = load_model(config_file, grounded_checkpoint, device=device)
    processor = BlipProcessor.from_pretrained("config_data/blip-image-captioning-large")
    if device == "cuda":
        blip_model = BlipForConditionalGeneration.from_pretrained(
            "config_data/blip-image-captioning-large", torch_dtype=torch.float16).to("cuda")
    else:
        blip_model = BlipForConditionalGeneration.from_pretrained("config_data/blip-image-captioning-large")

    for img_name in images_list:
        image_path = os.path.join(images_path, img_name)
        image_pil, image = load_image(image_path)

        # visualize raw image
        image_pil.save(os.path.join(output_orig, img_name))

        caption = generate_caption(image_pil, device=device)
        text_prompt = generate_tags(caption, split=split)
        print(f"Caption: {caption}")
        print(f"Tags: {text_prompt}")

        # run grounding dino model
        boxes_filt, scores, pred_phrases = get_grounding_output(
            model, image, text_prompt, box_threshold, text_threshold, device=device
        )

        # initialize SAM
        predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint).to(device))
        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        predictor.set_image(image)

        size = image_pil.size
        H, W = size[1], size[0]
        for i in range(boxes_filt.size(0)):
            boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
            boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
            boxes_filt[i][2:] += boxes_filt[i][:2]
        boxes_filt = boxes_filt.cpu()

        # use NMS to handle overlapped boxes
        print(f"Before NMS: {boxes_filt.shape[0]} boxes")
        nms_idx = torchvision.ops.nms(boxes_filt, scores, iou_threshold).numpy().tolist()
        boxes_filt = boxes_filt[nms_idx]
        pred_phrases = [pred_phrases[idx] for idx in nms_idx]
        print(f"After NMS: {boxes_filt.shape[0]} boxes")
        caption = check_caption(caption, pred_phrases)
        print(f"Revise caption with number: {caption}")

        transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2]).to(device)
        masks, _, _ = predictor.predict_torch(
            point_coords=None,
            point_labels=None,
            boxes=transformed_boxes.to(device),
            multimask_output=False,
        )

        # draw output image
        plt.figure(figsize=(10, 10))
        plt.imshow(image)
        for mask in masks:
            show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
        for box, label in zip(boxes_filt, pred_phrases):
            show_box(box.numpy(), plt.gca(), label)
        plt.title(caption)
        plt.axis('off')
        plt.savefig(
            os.path.join(output_automatic_label, img_name),
            bbox_inches="tight", dpi=300, pad_inches=0.0
        )

        save_mask_data(output_mask, output_json, img_name, caption, masks, boxes_filt, pred_phrases)
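With this change, every image placed under <root_path>/images is processed in a single run: the original is copied to orig/, the mask image and JSON annotations go to mask/ and json/, and the captioned, labeled overlay is written to automatic_label/ inside grounded_segment_anything_images/.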