Say I have an array of distances x=[1,2,1,3,3,2,1,5,1,1].
I want to get the indices of x where the cumsum reaches 10; in this case, idx=[4,9].
So the cumsum restarts after the condition is met.
I can do it with a loop, but loops are slow for large arrays and I was wondering if I could do it in a vectorized way.
Answers:
Method 1
Here’s one with numba and array-initialization –
import numpy as np
from numba import njit

@njit
def cumsum_breach_numba2(x, target, result):
    # Write breach indices into the preallocated `result`; return how many were written
    total = 0
    iterID = 0
    for i, x_i in enumerate(x):
        total += x_i
        if total >= target:
            result[iterID] = i
            iterID += 1
            total = 0
    return iterID

def cumsum_breach_array_init(x, target):
    x = np.asarray(x)
    # At most one breach per element, so len(x) slots are always enough
    result = np.empty(len(x), dtype=np.uint64)
    idx = cumsum_breach_numba2(x, target, result)
    return result[:idx]
Timings
Including @piRSquared's solutions and using the benchmarking setup from the same post –
In [58]: np.random.seed([3, 1415])
...: x = np.random.randint(100, size=1000000).tolist()
# @piRSquared soln1
In [59]: %timeit list(cumsum_breach(x, 10))
10 loops, best of 3: 73.2 ms per loop
# @piRSquared soln2
In [60]: %timeit cumsum_breach_numba(np.asarray(x), 10)
10 loops, best of 3: 69.2 ms per loop
# From this post
In [61]: %timeit cumsum_breach_array_init(x, 10)
10 loops, best of 3: 39.1 ms per loop
Numba: appending vs. array initialization
For a closer look at how much the array initialization helps (it seems to be the big difference between the two Numba implementations), let's time both on array data, since creating the array was itself heavy on runtime and both depend on it –
In [62]: x = np.array(x)
In [63]: %timeit cumsum_breach_numba(x, 10) # with appending
10 loops, best of 3: 31.5 ms per loop
In [64]: %timeit cumsum_breach_array_init(x, 10)
1000 loops, best of 3: 1.8 ms per loop
To give the output its own memory, we can make a copy. That doesn't change things much, though –
In [65]: %timeit cumsum_breach_array_init(x, 10).copy()
100 loops, best of 3: 2.67 ms per loop
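Why a copy is worth considering at all: `result[:idx]` is a NumPy view, so the small returned array keeps the whole `len(x)`-sized buffer alive. A NumPy-only sketch of that (the buffer size and the two stored indices are made up for illustration, standing in for what `cumsum_breach_array_init` allocates and fills):

```python
import numpy as np

buffer = np.empty(1_000_000, dtype=np.uint64)  # stands in for the preallocated `result`
buffer[:2] = [4, 9]                             # pretend two breaches were found

view = buffer[:2]           # like the `result[:idx]` slice that is returned
owned = buffer[:2].copy()   # like calling .copy() on that return value

print(view.base is buffer)             # True: the view pins the full buffer
print(np.shares_memory(view, buffer))  # True
print(owned.base is None)              # True: the copy owns its data
print(view.nbytes, buffer.nbytes)      # 16 8000000
```

So the copy costs a little time but lets the large buffer be freed once only the indices are needed.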
Method 2
A fun method
import numpy as np
sumlm = np.frompyfunc(lambda a, b: a + b if a < 10 else b, 2, 1)
newx = sumlm.accumulate(x, dtype=object)
newx
array([1, 3, 4, 7, 10, 2, 3, 8, 9, 10], dtype=object)
np.nonzero(newx >= 10)  # >= also catches a total that overshoots 10
(array([4, 9]),)
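The threshold is hard-coded in the lambda above. A sketch that takes the target as a parameter (`breach_indices_frompyfunc` is a hypothetical name, not from the answer). Because a running total can overshoot the target (for example 8 followed by 5), it compares with `>=` rather than testing for equality, so every breach is caught:

```python
import numpy as np

def breach_indices_frompyfunc(x, target):
    # Binary ufunc: keep adding while the running total is below target,
    # otherwise restart the total from the current element
    step = np.frompyfunc(lambda a, b: a + b if a < target else b, 2, 1)
    totals = step.accumulate(np.asarray(x, dtype=object), dtype=object)
    # >= rather than == so totals that overshoot the target still count
    return np.nonzero(totals >= target)[0]

print(breach_indices_frompyfunc([1, 2, 1, 3, 3, 2, 1, 5, 1, 1], 10))  # [4 9]
print(breach_indices_frompyfunc([8, 5, 1, 9], 10))                     # [1 3]
```

It still calls a Python lambda per element, so it is a neat one-liner rather than a speedup over a plain loop.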
Method 3
Loops are not always bad (especially when you need one). Also, no tool or algorithm will make this faster than O(n), since every element has to be read at least once. So let's just make a good loop.
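The O(n) bound still leaves room to shrink the Python-level work. Assuming all distances are non-negative and the target is positive (so the plain cumsum never decreases), each reset is equivalent to searching the ordinary cumsum for the first value at least `target` above the cumsum at the last breach, which `np.searchsorted` does in O(log n). A sketch of that idea (not from the original answers; `cumsum_breach_searchsorted` is a hypothetical name). The Python loop runs once per breach instead of once per element, while the `np.cumsum` pass is still O(n). With floats, cumsum rounding can differ slightly from a reset running total.

```python
import numpy as np

def cumsum_breach_searchsorted(x, target):
    # Assumes non-negative values and target > 0, so the cumsum is non-decreasing
    c = np.cumsum(x)
    result = []
    base = 0
    while True:
        # First index whose cumulative sum is at least `target` past the last breach
        i = np.searchsorted(c, base + target, side="left")
        if i == len(c):
            break
        result.append(int(i))
        base = c[i]  # the "reset": measure the next run from this cumulative value
    return result

print(cumsum_breach_searchsorted([1, 2, 1, 3, 3, 2, 1, 5, 1, 1], 10))  # [4, 9]
```

This pays off when breaches are sparse relative to the array length (a large target); with a small target it loops almost as often as the plain version.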
Generator Function
def cumsum_breach(x, target):
    # Lazily yield each index where the running total reaches target, then reset it
    total = 0
    for i, y in enumerate(x):
        total += y
        if total >= target:
            yield i
            total = 0
list(cumsum_breach(x, 10))
[4, 9]
Just In Time compiling with Numba
Numba is a third-party library that needs to be installed.
Numba can be persnickety about which features are supported, but this works.
Also, as pointed out by Divakar, Numba performs better with arrays.
from numba import njit

@njit
def cumsum_breach_numba(x, target):
    total = 0
    result = []  # grows by appending, unlike the preallocated array in Method 1
    for i, y in enumerate(x):
        total += y
        if total >= target:
            result.append(i)
            total = 0
    return result

cumsum_breach_numba(x, 10)
Testing the Two
Because I felt like it ¯\_(ツ)_/¯
Setup
import numpy as np
np.random.seed([3, 1415])
x0 = np.random.randint(100, size=1_000_000)
x1 = x0.tolist()
Accuracy
i0 = cumsum_breach_numba(x0, 200_000)
i1 = list(cumsum_breach(x1, 200_000))
assert i0 == i1
Time
%timeit cumsum_breach_numba(x0, 200_000)
582 µs ± 40.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit list(cumsum_breach(x1, 200_000))
64.3 ms ± 5.66 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
Numba was on the order of 100 times faster.
For a truer apples-to-apples test, I include the list-to-NumPy-array conversion in the Numba timing:
%timeit cumsum_breach_numba(np.array(x1), 200_000)
43.1 ms ± 202 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit list(cumsum_breach(x1, 200_000))
62.8 ms ± 327 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
That brings them roughly even.
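Going by the timings above (582 µs for the Numba loop on an existing array vs. 43.1 ms once `np.array(x1)` is included), roughly 42 ms of the second figure is the list-to-array conversion itself. A sketch for timing that conversion alone with the standard `timeit` module on the same setup (absolute numbers will vary by machine):

```python
import timeit
import numpy as np

np.random.seed([3, 1415])
x0 = np.random.randint(100, size=1_000_000)
x1 = x0.tolist()

# Time only the conversion that the apples-to-apples benchmark includes
runs = 5
convert = timeit.timeit(lambda: np.array(x1), number=runs) / runs
print(f"np.array(list of 1M ints): {convert * 1e3:.1f} ms per call")
```

If the data already lives in a NumPy array, that cost disappears and the Numba version keeps its large lead.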
All methods were sourced from stackoverflow.com or stackexchange.com and are licensed under CC BY-SA 2.5, CC BY-SA 3.0 and CC BY-SA 4.0.