GPU version:
step 1: model = model.to('cuda:0') or model = model.to('cuda')
step 2: audio_inputs = processor(audios=audio, return_tensors="pt").to('cuda:0') or audio_inputs = processor(audios=audio, return_tensors="pt").to('cuda')
CPU version:
step 1: not needed on CPU, so comment this line out: #model = model.to('cuda:0') or model = model.to('cuda')
step 2: audio_inputs = processor(audios=audio, return_tensors="pt")
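The GPU and CPU variants above can be folded into one device-agnostic helper. A minimal sketch (the `select_device` helper is my own naming, not part of the original code):

```python
def select_device(cuda_available, index=None):
    """Return a torch-style device string.

    cuda_available -- the result of torch.cuda.is_available()
    index          -- optional GPU index; None means the default 'cuda'
    """
    if not cuda_available:
        return "cpu"
    return "cuda" if index is None else f"cuda:{index}"
```

With this, both versions collapse to device = select_device(torch.cuda.is_available()), followed by model = model.to(device) and .to(device) on the processor output; on a CPU-only machine the .to('cpu') call is a harmless no-op.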
records:
Inference performance for 3 s of audio-to-audio translation, rtx4090 vs cpu 11870:
rtx4090: 0.5s
cpu 11870: 3.25s
conclusion: the rtx4090 is about 6.5x faster (3.25 / 0.5)
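The speedup arithmetic and a simple way to reproduce such timings can be sketched as follows (the `time_call` helper is illustrative; for real CUDA timings you would also call torch.cuda.synchronize() around the timed region, since GPU kernels run asynchronously):

```python
import time

def time_call(fn, *args, repeat=5):
    """Run fn several times and return the best wall-clock time in seconds."""
    best = float("inf")
    for _ in range(repeat):
        start = time.perf_counter()
        fn(*args)
        best = min(best, time.perf_counter() - start)
    return best

def speedup(cpu_seconds, gpu_seconds):
    """How many times faster the GPU run is than the CPU run."""
    return cpu_seconds / gpu_seconds

print(speedup(3.25, 0.5))  # 6.5 -> the rtx4090 run is 6.5x faster
```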
To deploy behind gunicorn + FastAPI, the following changes are needed:
step 1, modify main.py:
import torch
from fastapi import FastAPI
from transformers import AutoProcessor, SeamlessM4Tv2Model

app = FastAPI()
model = None
processor = None
torch_device = None
preload_app = True  # Reducing Memory Usage and Boosting Performance among Gunicorn Workers
sample_rate = 16_000

# This piece of code will be executed only when the server starts
@app.on_event("startup")
def on_startup():
    global model
    global processor
    global torch_device
    torch.backends.cudnn.enabled = True
    torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print("Device ", torch_device)
    processor = AutoProcessor.from_pretrained("facebook/seamless-m4t-v2-large")
    model = SeamlessM4Tv2Model.from_pretrained("facebook/seamless-m4t-v2-large")
    if torch_device == 'cuda':
        print(torch.cuda.get_device_name(0))
        #print(torch.cuda.get_device_name(1))
    model = model.to(torch_device)
step 2, add myconfig.py:
import os
import time

try:
    import pynvml
    pynvml.nvmlInit()
    gpuDeviceCount = pynvml.nvmlDeviceGetCount()
except Exception:
    gpuDeviceCount = 1

gpuDevicePool = []

def pre_fork(server, worker):
    # Reuse a GPU id freed by a dead worker, else assign round-robin by worker age.
    try:
        gid = gpuDevicePool.pop(0)
    except IndexError:
        gid = (worker.age - 1) % gpuDeviceCount
    worker.gid = gid

def post_fork(server, worker):
    # Stagger worker start-up, then pin this worker to its assigned GPU.
    time.sleep(worker.age % server.cfg.workers)
    os.environ['CUDA_VISIBLE_DEVICES'] = str(worker.gid)
    server.log.info(f'worker(age:{worker.age}, pid:{worker.pid}, cuda:{worker.gid})')

def child_exit(server, worker):
    # Return the GPU id to the pool for the replacement worker.
    gpuDevicePool.append(worker.gid)
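The GPU-assignment logic in pre_fork/child_exit can be exercised without gunicorn. A standalone sketch (the worker ages and two-GPU count are made up for illustration):

```python
gpuDeviceCount = 2   # pretend pynvml reported two GPUs
gpuDevicePool = []   # ids returned by child_exit are handed out first

def assign_gpu(worker_age):
    """Mirror of pre_fork: reuse a freed GPU id, else round-robin by age."""
    try:
        return gpuDevicePool.pop(0)
    except IndexError:
        return (worker_age - 1) % gpuDeviceCount

# four workers on two GPUs alternate between device 0 and 1
print([assign_gpu(age) for age in range(1, 5)])  # [0, 1, 0, 1]

# an id freed by a dying worker goes back to the pool and is reused first
gpuDevicePool.append(1)
print(assign_gpu(5))  # 1
```

The sleep in post_fork then staggers model loading so the workers do not all hit the GPU (or the model download) at the same instant.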
step3:
gunicorn -c myconfig.py --workers 1 --preload --worker-class=uvicorn.workers.UvicornWorker main:app
issues:
model.to('cuda:1') raised the error below, probably because that GPU ran out of memory:
t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)
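One way to avoid this class of failure is to check free GPU memory before choosing a device. A hypothetical helper (`pick_device` and the 12 GiB footprint are assumptions; the real free-memory readings would come from pynvml.nvmlDeviceGetMemoryInfo(handle).free, using the same pynvml already imported in myconfig.py):

```python
def pick_device(free_bytes_per_gpu, required_bytes):
    """Return 'cuda:<i>' for the first GPU with enough free memory, else 'cpu'.

    free_bytes_per_gpu -- list of free-memory readings, one per GPU
    required_bytes     -- rough model footprint (assumed known in advance)
    """
    for i, free in enumerate(free_bytes_per_gpu):
        if free >= required_bytes:
            return f"cuda:{i}"
    return "cpu"

GiB = 1 << 30
# cuda:0 is nearly full, so the model falls through to cuda:1
print(pick_device([1 * GiB, 20 * GiB], 12 * GiB))  # cuda:1
```

Falling back to 'cpu' trades the 6.5x slowdown measured above for not crashing the worker.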