没啥好说的,低质量复刻SAM官网 https://segment-anything.com/
需要提一点:所有生成embedding和mask的操作都是python后端做的,计算mask不是onnxruntime-web实现的,前端只负责了把rle编码的mask解码后画到canvas上,会有几十毫秒的网络传输延迟。我不会react和typescript,官网F12里的源代码太难懂了,生成的svg总是与期望的不一样
对请求做了节流(throttle),延迟设置为150ms。把代码里的throttle delay调小,反应会更快些,但我觉得150ms已经够用,没必要再改。
蓝色前景,红色背景,对应clickType分别为1和0
随便做了下,实在做不出官网的效果,可能模型也有问题 ,我用的vit_b,懒得试了,这功能对我来说没卵用
前端使用了Vue3+ElementPlus(https://element-plus.org/zh-CN/#/zh-CN)+axios+lz-string,npm安装一下。
后端是fastapi(https://fastapi.tiangolo.com/),FastAPI 依赖 Python 3.8 及更高版本。
安装 FastAPI
pip install fastapi
另外我们还需要一个 ASGI 服务器,生产环境可以使用 Uvicorn 或者 Hypercorn:
pip install "uvicorn[standard]"
@/util/request.js
import axios from "axios";
import { ElMessage } from "element-plus";
// Request side: the config is forwarded untouched and request errors are
// propagated as rejected promises.
const forwardConfig = (config) => config;
const rejectWith = (error) => Promise.reject(error);
axios.interceptors.request.use(forwardConfig, rejectWith);

// Response side: resolve with the response body, unless the body carries an
// explicit falsy `success` flag, in which case reject with that body.
const unwrapBody = (response) => {
  const body = response.data;
  const hasFailed = body.success != null && !body.success;
  return hasFailed ? Promise.reject(body) : response.data;
};
// Transport/HTTP errors are logged, surfaced through ElMessage and re-rejected.
// NOTE(review): the toast text is a single space — presumably a placeholder; confirm.
const reportError = (error) => {
  console.log('error: ', error);
  ElMessage.error(' ');
  return Promise.reject(error);
};
axios.interceptors.response.use(unwrapBody, reportError);
export default axios;
然后在main.js中绑定
import axios from './util/request.js'
// Base URL used for every request issued through this axios instance.
// NOTE(review): the uvicorn command shown later in this document starts the
// backend with --port 8006 — confirm that the two ports actually match.
axios.defaults.baseURL = 'http://localhost:9000'
// Default Content-Type header applied to POST requests.
axios.defaults.headers.post['Content-Type'] = 'application/x-www-form-urlencoded';
// Expose axios on the app's global properties as $http.
// Assumes `app` is the Vue application instance created elsewhere in main.js.
app.config.globalProperties.$http = axios
@/util/throttle.js
/**
 * Creates a trailing-edge throttled wrapper around `func`.
 *
 * The first call schedules `func` to run after `delay` milliseconds with that
 * call's `this` and arguments. Calls made while a run is pending are dropped.
 * The pending marker is always cleared once `func` has run, even when `func`
 * throws, so one failing invocation cannot disable the wrapper for good.
 *
 * @param {Function} func  function to throttle
 * @param {number} delay   wait time in milliseconds before `func` runs
 * @returns {Function} throttled wrapper (returns nothing itself)
 */
function throttle(func, delay) {
  let timer = null; // non-null while an invocation is scheduled or running
  return function (...args) {
    if (timer !== null) {
      return; // an invocation is already pending: drop this call
    }
    const context = this;
    timer = setTimeout(() => {
      try {
        func.apply(context, args);
      } finally {
        // Reset in `finally` so an exception in `func` does not leave the
        // throttle permanently locked.
        timer = null;
      }
    }, delay);
  };
}
export default throttle
@/util/mask_utils.js
/**
 * Decodes a compressed RLE counts string into an array of run lengths.
 *
 * Each count is stored as one or more characters (char code minus 48): the low
 * five bits carry payload, bit 0x20 marks "more characters follow", and bit
 * 0x10 of the last character marks a negative value (sign extension). From the
 * fourth count onwards, values are stored as a difference to the count two
 * positions earlier.
 *
 * @param {string} input compressed counts string
 * @returns {number[]} decoded run lengths
 */
export const rleFrString = (input) => {
  const counts = [];
  let pos = 0;
  while (pos < input.length) {
    let value = 0;
    let shift = 0;
    let hasMore = true;
    while (hasMore) {
      const bits = input.charCodeAt(pos) - 48;
      pos += 1;
      value |= (bits & 0x1f) << shift;
      shift += 5;
      hasMore = (bits & 0x20) !== 0;
      if (!hasMore && (bits & 0x10) !== 0) {
        // last chunk has its sign bit set: extend the sign upwards
        value |= -1 << shift;
      }
    }
    if (counts.length > 2) {
      // undo the delta encoding against the count two positions back
      value += counts[counts.length - 2];
    }
    counts.push(value);
  }
  return counts;
};
/**
 * Expands run-length counts into a flat binary mask.
 *
 * Runs alternate between 0 and 1, starting with 0. Writes past the end of the
 * `rows * cols` buffer are ignored by the typed array.
 *
 * @param {number[]} shape  [rows, cols] of the mask
 * @param {Iterable<number>} counts run lengths
 * @returns {Uint8Array} mask of length rows * cols
 */
export const decodeRleCounts = ([rows, cols], counts) => {
  const mask = new Uint8Array(rows * cols);
  let cursor = 0;
  let bit = 0;
  for (const run of counts) {
    for (let n = 0; n < run; n += 1) {
      mask[cursor] = bit;
      cursor += 1;
    }
    bit = 1 - bit;
  }
  return mask;
};
/**
 * Expands "everything mode" counts into a flat category mask.
 *
 * `counts` is read as consecutive [runLength, category] pairs; each category
 * is written `runLength` times. Values are stored in a Uint8Array, so
 * categories above 255 wrap modulo 256.
 *
 * @param {number[]} shape  [rows, cols] of the mask
 * @param {number[]} counts flat list of [runLength, category] pairs
 * @returns {Uint8Array} per-pixel category array of length rows * cols
 */
export const decodeEverythingMask = ([rows, cols], counts) => {
  const categories = new Uint8Array(rows * cols);
  let offset = 0;
  let pair = 0;
  while (pair < counts.length) {
    const category = counts[pair + 1];
    for (let n = 0; n < counts[pair]; n += 1) {
      categories[offset] = category;
      offset += 1;
    }
    pair += 2;
  }
  return categories;
};
/**
 * Returns the color assigned to `category`, creating one on first use.
 *
 * A new color is a random {r, g, b} triple (each 0-255) that differs from
 * every color already stored in `colorMap`; it is stored in `colorMap` before
 * being returned. The own-property check goes through
 * Object.prototype.hasOwnProperty, so prototype-less maps
 * (Object.create(null)) and maps with a shadowed `hasOwnProperty` key work.
 *
 * @param {string|number} category key identifying the mask category
 * @param {Object} colorMap category -> {r, g, b}; mutated when a color is added
 * @returns {{r: number, g: number, b: number}} the color for `category`
 */
export const getUniqueColor = (category, colorMap) => {
  if (Object.prototype.hasOwnProperty.call(colorMap, category)) {
    return colorMap[category];
  }
  // The map does not change while we search, so collect its colors once.
  const existingColors = Object.values(colorMap);
  const isTaken = (candidate) => existingColors.some((existing) => (
    candidate.r === existing.r && candidate.g === existing.g && candidate.b === existing.b
  ));
  const randomChannel = () => Math.floor(Math.random() * 256);
  let color;
  do {
    color = { r: randomChannel(), g: randomChannel(), b: randomChannel() };
  } while (isTaken(color));
  colorMap[category] = color;
  console.log("生成唯一颜色", category, colorMap[category])
  return color;
};
/**
 * Draws `image` scaled to w x h on a fresh off-screen canvas and reads its pixels.
 * Private helper shared by the cutOut* functions.
 * @returns {{canvas, ctx, imageData}} the canvas, its 2d context and its pixel data
 */
const drawSourceImage = ({w, h}, image) => {
  const canvas = document.createElement('canvas');
  const ctx = canvas.getContext('2d', {willReadFrequently: true});
  canvas.width = w;
  canvas.height = h;
  ctx.drawImage(image, 0, 0, w, h);
  return {canvas, ctx, imageData: ctx.getImageData(0, 0, w, h)};
};

/**
 * Writes `imageData` back to `canvas`, crops it to the bounding box of all
 * pixels whose alpha is non-zero and passes the crop to `callback` as a PNG blob.
 * When no pixel is left visible there is nothing to crop and `callback`
 * receives `null` instead.
 * Private helper shared by the cutOut* functions.
 */
const cropVisibleRegion = ({w, h}, {canvas, ctx, imageData}, callback) => {
  const pixels = imageData.data;
  let minX = w;
  let minY = h;
  let maxX = -1;
  let maxY = -1;
  for (let y = 0; y < h; y++) {
    for (let x = 0; x < w; x++) {
      if (pixels[(y * w + x) * 4 + 3] !== 0) {
        if (x < minX) minX = x;
        if (x > maxX) maxX = x;
        if (y < minY) minY = y;
        if (y > maxY) maxY = y;
      }
    }
  }
  if (maxX < 0) {
    // Every pixel is transparent: a bounding box would have negative size.
    if (callback) {
      callback(null);
    }
    return;
  }
  const width = maxX - minX + 1;
  const height = maxY - minY + 1;
  ctx.putImageData(imageData, 0, 0);
  // A second canvas holds only the cropped region.
  const croppedCanvas = document.createElement('canvas');
  croppedCanvas.width = width;
  croppedCanvas.height = height;
  croppedCanvas.getContext('2d').drawImage(canvas, minX, minY, width, height, 0, 0, width, height);
  croppedCanvas.toBlob((blob) => {
    if (callback) {
      callback(blob);
    }
  }, 'image/png');
};

/**
 * Cut out specific area of image uncovered by mask.
 * Pixels where the mask canvas has non-zero alpha are made transparent; the
 * remainder is cropped to its bounding box.
 * @param w image's natural width
 * @param h image's natural height
 * @param image source image
 * @param canvas mask canvas
 * @param callback receives the PNG blob, or null when the mask covers everything
 */
export const cutOutImage = ({w, h}, image, canvas, callback) => {
  const source = drawSourceImage({w, h}, image);
  const maskCtx = canvas.getContext('2d', {willReadFrequently: true});
  const maskPixels = maskCtx.getImageData(0, 0, w, h).data;
  const imagePixels = source.imageData.data;
  // Remove the part covered by the mask.
  for (let i = 0; i < maskPixels.length; i += 4) {
    if (maskPixels[i + 3] !== 0) {
      imagePixels[i + 3] = 0;
    }
  }
  cropVisibleRegion({w, h}, source, callback);
};

/**
 * Cut out specific area of image covered by target color mask.
 * Pixels whose mask RGB differs from `color` are made transparent.
 * PS (original author): this function was reported not to work — the color
 * comparison claimed the target color was absent even though the mask canvas
 * visibly contained it — so cutOutImageWithCategory is used instead.
 * NOTE(review): presumably caused by canvas alpha blending / premultiplied-alpha
 * rounding changing the RGB values read back by getImageData — unverified.
 * @param w image's natural width
 * @param h image's natural height
 * @param image source image
 * @param canvas mask canvas
 * @param color target color ({r, g, b})
 * @param callback receives the PNG blob, or null when no pixel has the color
 */
export const cutOutImageWithMaskColor = ({w, h}, image, canvas, color, callback) => {
  const source = drawSourceImage({w, h}, image);
  const maskCtx = canvas.getContext('2d', {willReadFrequently: true});
  const maskPixels = maskCtx.getImageData(0, 0, w, h).data;
  const imagePixels = source.imageData.data;
  for (let i = 0; i < maskPixels.length; i += 4) {
    // Loose comparison is kept on purpose so string-valued channels still match.
    if (maskPixels[i] != color.r || maskPixels[i + 1] != color.g || maskPixels[i + 2] != color.b) {
      imagePixels[i + 3] = 0;
    }
  }
  cropVisibleRegion({w, h}, source, callback);
};

/**
 * Cut out specific area whose category is target category.
 * Pixels whose entry in `arr` differs from `category` are made transparent.
 * @param w image's natural width
 * @param h image's natural height
 * @param image source image
 * @param arr original mask array that stores all pixel's category (row-major, one entry per pixel)
 * @param category target category
 * @param callback receives the PNG blob, or null when no pixel has the category
 */
export const cutOutImageWithCategory = ({w, h}, image, arr, category, callback) => {
  const source = drawSourceImage({w, h}, image);
  const imagePixels = source.imageData.data;
  const pixelCount = w * h;
  for (let p = 0; p < pixelCount; p++) {
    // Loose comparison is kept on purpose so '1' and 1 name the same category.
    if (category != arr[p]) {
      imagePixels[p * 4 + 3] = 0;
    }
  }
  cropVisibleRegion({w, h}, source, callback);
};
首先从github上下载SAM的代码https://github.com/facebookresearch/segment-anything
然后下载模型文件,保存到项目根目录/checkpoints中,
可选的模型有:
- default 或 vit_h:ViT-H SAM model
- vit_l:ViT-L SAM model
- vit_b:ViT-B SAM model
在项目根目录下创建main.py
main.py
import os
import time
from PIL import Image
import numpy as np
import io
import base64
from segment_anything import SamPredictor, SamAutomaticMaskGenerator, sam_model_registry
from pycocotools import mask as mask_utils
import lzstring
def init(checkpoint="checkpoints/sam_vit_b_01ec64.pth", model_type="vit_b", device=None):
    """Load the SAM model and build the two inference helpers.

    Args:
        checkpoint: path of the model weights file (your model path).
        model_type: key into ``sam_model_registry``; must match the checkpoint.
        device: device the model is moved to. When None, "cuda" is used if
            torch reports it as available, otherwise "cpu".

    Returns:
        (predictor, mask_generator): a ``SamPredictor`` and a
        ``SamAutomaticMaskGenerator`` that share the same model instance.
    """
    if device is None:
        # Imported here so the module's import block stays untouched.
        import torch

        device = "cuda" if torch.cuda.is_available() else "cpu"
    sam = sam_model_registry[model_type](checkpoint=checkpoint)
    sam.to(device=device)
    predictor = SamPredictor(sam)
    mask_generator = SamAutomaticMaskGenerator(sam)
    return predictor, mask_generator


# Built once at import time and shared by all request handlers.
predictor, mask_generator = init()
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
app = FastAPI()
# Allow cross-origin requests from the front-end dev server.
# allow_origins is passed as a list of origins; the previous bare string "*"
# only worked because of how a membership test on a string behaves.
# NOTE(review): FastAPI's CORS documentation says wildcards should not be
# combined with allow_credentials=True — confirm whether credentials are needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Path of the image whose embedding is currently loaded in the predictor.
last_image = ""
# Low-resolution logits of the previous prediction, fed back as mask_input.
last_logit = None
@app.post("/segment")
def process_image(body: dict):
    """Segment the image at ``body["path"]`` from the clicks in ``body["clicks"]``.

    Each click is a dict with ``x``, ``y`` and ``clickType`` (the point label).
    The image embedding is recomputed only when the path differs from the
    previous request; otherwise the logits of the previous prediction are fed
    back as ``mask_input``.

    Returns:
        {"shape": mask shape, "mask": RLE counts string of the best mask,
        compressed with lz-string ``compressToEncodedURIComponent``}.
    """
    global last_image, last_logit
    print("start processing image", time.time())
    # NOTE(security): the path comes from the client and is opened as-is;
    # restrict it to an allowed directory if the service is ever exposed.
    path = body["path"]
    is_first_segment = False
    # Check whether the previous request segmented this same image.
    if path != last_image:  # different image: rebuild the image embedding
        # Close the file promptly and always pass RGB, whatever the file's
        # mode is (RGBA, L, P, ...).
        with Image.open(path) as pil_image:
            np_image = np.array(pil_image.convert("RGB"))
        predictor.set_image(np_image)
        last_image = path
        last_logit = None  # logits of another image must not be reused
        is_first_segment = True
        print("第一次识别该图片,获取embedding")
    # Collect the prompt points and their labels.
    clicks = body["clicks"]
    input_points = [[click["x"], click["y"]] for click in clicks]
    input_labels = [click["clickType"] for click in clicks]
    print("input_points:{}, input_labels:{}".format(input_points, input_labels))
    input_points = np.array(input_points)
    input_labels = np.array(input_labels)
    # Prior logits exist only after a successful prediction on this image.
    use_prior = (not is_first_segment) and last_logit is not None
    masks, scores, logits = predictor.predict(
        point_coords=input_points,
        point_labels=input_labels,
        mask_input=last_logit[None, :, :] if use_prior else None,
        multimask_output=not use_prior,  # without a prior, request several candidates and keep the best
    )
    # Remember the best logits as mask_input for the next request.
    best = np.argmax(scores)
    last_logit = logits[best, :, :]
    masks = masks[best, :, :]
    # The mask is handed to pycocotools as a Fortran-ordered uint8 array.
    rle = mask_utils.encode(np.asfortranarray(masks, dtype=np.uint8))
    source_mask = rle["counts"].decode("utf-8")
    lzs = lzstring.LZString()
    encoded = lzs.compressToEncodedURIComponent(source_mask)
    print("process finished", time.time())
    return {"shape": masks.shape, "mask": encoded}
@app.get("/everything")
def segment_everything(path: str):
    """Segment everything in the image at ``path`` and return one label image.

    Masks are painted from largest to smallest area, each with its rank as the
    label, so smaller masks overwrite larger ones. Pixels covered by no mask
    keep label 0, which is also the label of the largest mask.

    Returns:
        {"shape": (height, width), "mask": flat [count, label, ...] list
        produced by ``my_compress``}.
    """
    start_time = time.time()
    print("start segment_everything", start_time)
    # Close the file promptly and always pass RGB, whatever the file's mode is.
    with Image.open(path) as pil_image:
        np_image = np.array(pil_image.convert("RGB"))
    masks = mask_generator.generate(np_image)
    sorted_anns = sorted(masks, key=(lambda x: x['area']), reverse=True)
    # Sized from the image itself so an empty mask list no longer raises.
    # int32 instead of uint8: more than 255 masks would overflow a uint8 label.
    img = np.zeros(np_image.shape[:2], dtype=np.int32)
    for idx, ann in enumerate(sorted_anns, 0):
        img[ann['segmentation']] = idx
    # Run-length compress the label image.
    result = my_compress(img)
    end_time = time.time()
    print("finished segment_everything", end_time)
    print("time cost", end_time - start_time)
    return {"shape": img.shape, "mask": result}
@app.get('/automatic_masks')
def automatic_masks(path: str):
    """Generate all automatic masks for the image at ``path``.

    Returns:
        A list, ordered by decreasing mask area, of dicts with
        ``encodedMask`` (RLE counts string compressed with lz-string
        ``compressToEncodedURIComponent``) and ``point_coord`` (the first
        entry of the annotation's ``point_coords``).
    """
    # Close the file promptly and always pass RGB, whatever the file's mode is.
    with Image.open(path) as pil_image:
        np_image = np.array(pil_image.convert("RGB"))
    mask = mask_generator.generate(np_image)
    sorted_anns = sorted(mask, key=(lambda x: x['area']), reverse=True)
    lzs = lzstring.LZString()
    res = []
    for ann in sorted_anns:
        # Same conversion as in process_image: pycocotools is given a
        # Fortran-ordered uint8 array.
        m = np.asfortranarray(ann['segmentation'], dtype=np.uint8)
        source_mask = mask_utils.encode(m)['counts'].decode("utf-8")
        encoded = lzs.compressToEncodedURIComponent(source_mask)
        res.append({
            "encodedMask": encoded,
            "point_coord": ann['point_coords'][0],
        })
    return res
# Counts runs of consecutive equal values and appends [count, value] to the
# result for each run, similar to run-length encoding.
# e.g. [[1,1,1,2,3,2,2,4,4],[3,3,4...]]
# gives [3,1, 1,2, 1,3, 2,2, 2,4, 2,3,...]
def my_compress(img):
    """Run-length compress a 2-D label image.

    The image is scanned row by row as one flat sequence, so a run may
    continue across a row boundary.

    Args:
        img: 2-D array (or nested list) of integer labels.

    Returns:
        Flat list of Python ints ``[count, value, count, value, ...]``;
        an empty list for an empty image.
    """
    flat = np.asarray(img).ravel()
    if flat.size == 0:
        return []
    # Index of the first element of every run: position 0 plus every position
    # whose value differs from its predecessor.
    starts = np.concatenate(([0], np.flatnonzero(flat[1:] != flat[:-1]) + 1))
    lengths = np.diff(np.concatenate((starts, [flat.size])))
    result = []
    for count, value in zip(lengths.tolist(), flat[starts].tolist()):
        result.append(count)
        result.append(int(value))
    return result
在cmd或者pycharm终端,cd到项目根目录下,输入uvicorn main:app --port 8006,启动服务器。注意:这里的端口需要与前端main.js中axios.defaults.baseURL的端口保持一致(上文示例写的是9000),否则前端请求不到后端。
展示图像
未进行抠图
左键设置区域为前景
右键设置区域为背景
{if (!this.clicks.length&&!this.isEverything) this.canvasVisible = false}"/>
关于上面所提rle,可以在项目根目录/notebooks/predictor_example.ipynb中产生mask的位置添加代码,自行观察它编码的rle。它只支持元素为0或1的矩阵,counts的第一个位置永远是0的个数,不管矩阵是不是以0开头。例如:
[0,0,1,1,0,1,0] 的rle counts是[2(两个0), 2(两个1), 1(一个0), 1(一个1), 1(一个0)];
[1,1,1,1,1,0] 的rle counts是[0(零个0), 5(五个1), 1(一个0)]。
下面的decode_rle可以把编码后的字符串还原成counts数组:
def decode_rle(rle_string):
    """Decode a pycocotools compressed ``counts`` string into the counts list.

    This yields the run-length counts, not the original matrix. Each count is
    stored as one or more characters (code point minus 48): the low five bits
    carry payload, bit 0x20 means "more characters follow", and bit 0x10 of
    the last character marks a negative value. From the fourth count onwards
    the stored value is a difference to the count two positions earlier.

    Args:
        rle_string: the ``counts`` value, as ``str`` or as raw ``bytes``.

    Returns:
        List of integer run lengths.
    """
    if isinstance(rle_string, (bytes, bytearray)):
        # Accept the raw bytes so callers need not decode them first.
        rle_string = rle_string.decode("ascii")
    result = []
    char_index = 0
    while char_index < len(rle_string):
        value = 0
        k = 0
        more = 1
        while more:
            c = ord(rle_string[char_index]) - 48
            value |= (c & 0x1f) << (5 * k)
            more = c & 0x20
            char_index += 1
            k += 1
            if not more and c & 0x10:
                # sign bit set on the last chunk: extend the sign upwards
                value |= -1 << (5 * k)
        if len(result) > 2:
            # undo the delta encoding against the count two positions back
            value += result[-2]
        result.append(value)
    return result
from pycocotools import mask as mask_utils
import numpy as np
# Demo: encode a small binary mask with pycocotools and print the counts
# recovered by decode_rle next to the transposed / flattened mask.
mask = np.array([[1,1,0,1,1,0],[1,1,1,1,1,1],[0,1,1,1,0,0],[1,1,1,1,1,1]])
# Convert to a Fortran-ordered uint8 array before encoding.
mask = np.asfortranarray(mask, dtype=np.uint8)
print("原mask:\n{}".format(mask))
res = mask_utils.encode(mask)
print("encode:{}".format(res))
print("rle counts:{}".format(decode_rle(res["counts"].decode("utf-8"))))
# The transposed mask is easier to compare with the counts
print("转置:{}".format(mask.transpose()))
# The flattened transposed mask is easier still
print("flatten:{}".format(mask.transpose().flatten()))
#numpy_array = np.frombuffer(res["counts"], dtype=np.uint8)
# Print the numpy array in Uint8Array format
#print("Uint8Array([" + ", ".join(map(str, numpy_array)) + "])")
输出: