循环并不一定是坏事。如果性能是个问题,可以考虑使用 numba:在没有实质性改变逻辑的情况下,速度提升了大约 330 倍:
import numpy as np
from numba import njit
# Reproducible benchmark input: 10000 integers drawn from [-100, 100),
# then divided by 100 so that `b` holds floats in [-1, 1).
np.random.seed(0)
a = np.random.randint(low=-100, high=100, size=10000)
b = a / 100
@njit
def cumsum_with_limits_nb(values):
    """Return per-step increments whose running sum stays within [-1, 1].

    Walks ``values`` in order while tracking a running total. A value is
    passed through unchanged when adding it keeps the total inside
    ``[-1, 1]`` (limits included); otherwise it is replaced by the
    remainder that brings the total exactly to the limit it would cross.

    Args:
        values: 1-D sequence of numbers, indexable by position.

    Returns:
        Float array of the same length holding the increment actually
        applied at each position. A NaN input yields NaN at that position
        and in the running total, as in ``cumsum_with_limits_nb2``.
    """
    n = len(values)
    res = np.empty(n)
    sum_val = 0.0  # float from the start so the accumulator never changes type
    for i in range(n):
        x = values[i]
        next_sum = sum_val + x
        if next_sum > 1:
            step = 1 - sum_val  # remainder up to the upper limit
        elif next_sum < -1:
            step = -1 - sum_val  # remainder down to the lower limit
        else:
            # In range, or NaN. The explicit else guarantees res[i] is
            # always written; np.empty does not initialise its contents.
            step = x
        res[i] = step
        sum_val += step
    return res
# Sanity check: the numba version must agree element-wise (within floating
# point tolerance) with cumsum_with_limits, which is defined outside this snippet.
assert np.allclose(cumsum_with_limits(b), cumsum_with_limits_nb(b))
如果你不介意牺牲一些性能,你可以更简洁地重写这个循环:
@njit
def cumsum_with_limits_nb2(values):
    """Return per-step increments whose running sum stays within [-1, 1].

    Compact variant: whenever the prospective total reaches or passes a
    limit in absolute value, the increment is replaced by the distance
    from the current total to that limit.

    Args:
        values: 1-D sequence of numbers, indexable by position.

    Returns:
        Float array of the same length holding the increment actually
        applied at each position.
    """
    n = len(values)
    res = np.empty(n)
    sum_val = 0.0  # float from the start so the accumulator never changes type
    for i in range(n):
        x = values[i]
        next_sum = sum_val + x
        if np.abs(next_sum) >= 1:
            # The sign of the prospective total is the limit being hit
            # (+1 or -1), so one expression covers both directions.
            x = np.sign(next_sum) - sum_val
        res[i] = x
        sum_val += x
    return res
下面是另一种写法,性能与 nb2 相近(感谢 @jdehesa):
@njit
def cumsum_with_limits_nb3(values):
    """Return per-step increments whose running sum stays within [-1, 1].

    Clamp-based variant: the prospective total is clamped to ``[-1, 1]``
    and the increment is the difference between that clamped total and
    the current one.

    Args:
        values: 1-D sequence of numbers, indexable by position.

    Returns:
        Float array of the same length holding the increment actually
        applied at each position.
    """
    n = len(values)
    res = np.empty(n)
    sum_val = 0.0  # float from the start so the accumulator never changes type
    for i in range(n):
        # Float limits keep min/max operating on a single numeric type.
        clamped = min(max(sum_val + values[i], -1.0), 1.0)
        x = clamped - sum_val
        res[i] = x
        sum_val += x
    return res
性能比较:
# Check that each numba variant matches cumsum_with_limits on the sample
# input `b` (element-wise, within np.isclose's default tolerances).
# NOTE(review): cumsum_with_limits is not defined in this snippet; presumably
# it is the original pure-Python implementation being optimised. Confirm.
# These calls also run every jitted function once before it is timed below;
# presumably this keeps numba's first-call compilation out of the timings.
assert np.isclose(cumsum_with_limits(b), cumsum_with_limits_nb(b)).all()
assert np.isclose(cumsum_with_limits(b), cumsum_with_limits_nb2(b)).all()
assert np.isclose(cumsum_with_limits(b), cumsum_with_limits_nb3(b)).all()
# IPython %timeit magics (not plain Python); the trailing comments are the
# timings recorded by the author.
%timeit cumsum_with_limits(b) # 12.5 ms per loop
%timeit cumsum_with_limits_nb(b) # 40.9 µs per loop
%timeit cumsum_with_limits_nb2(b) # 54.7 µs per loop
%timeit cumsum_with_limits_nb3(b) # 54 µs per loop