I have trouble properly understanding numpy.where() despite reading the doc, this post and this other post.
Can someone provide step-by-step commented examples with 1D and 2D arrays?
Answers:
Thank you for visiting the Q&A section on Magenaut.
Method 1
After fiddling around for a while, I figured things out, and am posting them here hoping it will help others.
Intuitively, np.where is like asking “tell me where in this array entries satisfy a given condition”.
>>> import numpy as np
>>> a = np.arange(5,10)
>>> np.where(a < 8)       # tell me where in a, entries are < 8
(array([0, 1, 2]),)       # answer: entries indexed by 0, 1, 2
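A minimal sketch of what happens in that call, assuming NumPy is imported as np: the condition is evaluated first into a boolean array, and np.where with a single argument returns the positions of the True entries as a tuple with one index array per dimension. The NumPy docs describe this one-argument form as a shorthand for nonzero(), which the last line checks.

```python
import numpy as np

a = np.arange(5, 10)       # array([5, 6, 7, 8, 9])
mask = a < 8               # elementwise comparison: array([ True,  True,  True, False, False])
result = np.where(mask)    # tuple with one index array per dimension of a

indices = result[0]        # a is 1d, so the tuple holds exactly one array
print(mask)
print(result)
print(indices)

# One-argument np.where gives the same arrays as np.nonzero on the mask
assert all(np.array_equal(r, n) for r, n in zip(result, np.nonzero(mask)))
```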
np.where() can also be used to get the entries of the array that satisfy the condition:
>>> a[np.where(a < 8)]
array([5, 6, 7])          # selects from a entries 0, 1, 2
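For plain selection like this, indexing with the boolean mask directly gives the same values without building the index tuple first. A small comparison, assuming the same 1d a as above:

```python
import numpy as np

a = np.arange(5, 10)

via_where = a[np.where(a < 8)]   # integer indexing with the tuple of indices
via_mask = a[a < 8]              # boolean indexing with the mask itself

print(via_where)
print(via_mask)
assert np.array_equal(via_where, via_mask)
```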
When a is a 2d array, np.where() returns an array of row indices and an array of column indices:
>>> a = np.arange(4,10).reshape(2,3)
>>> a
array([[4, 5, 6],
       [7, 8, 9]])
>>> np.where(a > 8)
(array([1]), array([2]))
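The two arrays are parallel: the k-th row index belongs with the k-th column index. A sketch that pairs them into (row, col) coordinates, using a > 6 here (an arbitrary choice) so that more than one entry matches; np.argwhere returns the same positions stacked as an (N, 2) array, one row per match.

```python
import numpy as np

a = np.arange(4, 10).reshape(2, 3)   # [[4, 5, 6], [7, 8, 9]]
rows, cols = np.where(a > 6)         # one array per dimension: row indices, column indices

# Pair the k-th row index with the k-th column index
coords = list(zip(rows.tolist(), cols.tolist()))
print(coords)

# np.argwhere holds the same positions, one (row, col) pair per row
print(np.argwhere(a > 6))
assert np.array_equal(np.argwhere(a > 6), np.column_stack((rows, cols)))
```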
As in the 1d case, we can use np.where() to get entries in the 2d array that satisfy the condition:
>>> a[np.where(a > 8)]    # selects from a the entry at row 1, column 2
array([9])
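Indexing with the returned tuple is equivalent to passing the row and column arrays as two separate indices, which picks a[rows[k], cols[k]] for each k. A minimal sketch with the same 2d a, again using a > 6 only so that several entries match:

```python
import numpy as np

a = np.arange(4, 10).reshape(2, 3)
rows, cols = np.where(a > 6)

selected = a[rows, cols]   # same as a[np.where(a > 6)]
# Explicit loop doing the same pairing, for comparison
looped = np.array([a[r, c] for r, c in zip(rows, cols)])

print(selected)
assert np.array_equal(selected, a[np.where(a > 6)])
assert np.array_equal(selected, looped)
```

Note that the result is always 1d, whatever the shape of a.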
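Note that np.where() always returns a tuple with one index array per dimension. When a is 1d, the tuple holds a single array of indices (there are no column indices at all), which is why the 1d result above prints as (array([0, 1, 2]),) with a trailing comma.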
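One way to check this is to compare the length of the returned tuple with a.ndim for arrays of different shapes. The 3d array below is only an illustration:

```python
import numpy as np

arrays = {
    "1d": np.arange(5, 10),
    "2d": np.arange(4, 10).reshape(2, 3),
    "3d": np.arange(24).reshape(2, 3, 4),
}

for name, arr in arrays.items():
    result = np.where(arr > 6)
    print(name, "ndim =", arr.ndim, "tuple length =", len(result))
    # One index array per dimension...
    assert len(result) == arr.ndim
    # ...and all of them have the same length (the number of matches)
    assert len({len(idx) for idx in result}) == 1
```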
Method 2
Here is a little more fun. I’ve found that very often NumPy does exactly what I wish it would do; sometimes it’s faster for me to just try things than to read the docs. Actually, a mixture of both is best.
I think your answer is fine (and it’s OK to accept it if you like). This is just “extra”.
import numpy as np
a = np.arange(4,10).reshape(2,3)
wh = np.where(a>7)
gt = a>7
x = np.where(gt)
print("wh:", wh)
print("gt:", gt)
print("x:", x)
gives:
wh: (array([1, 1]), array([1, 2]))
gt: [[False False False]
 [False  True  True]]
x: (array([1, 1]), array([1, 2]))
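wh and x come out identical because np.where(a>7) evaluates a>7 first and then works on the resulting boolean array, exactly as np.where(gt) does. A sketch confirming this, and that both also match the mask's own nonzero() method:

```python
import numpy as np

a = np.arange(4, 10).reshape(2, 3)
gt = a > 7                 # boolean array with the same shape as a
wh = np.where(a > 7)       # condition evaluated inline
x = np.where(gt)           # same condition, precomputed
nz = gt.nonzero()          # ndarray.nonzero() on the mask

# Every variant yields the same row and column index arrays
for other in (x, nz):
    assert all(np.array_equal(p, q) for p, q in zip(wh, other))
print(wh)
```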
… but:
print("a[wh]:", a[wh])
print("a[gt]", a[gt])
print("a[x]:", a[x])
gives:
a[wh]: [8 9]
a[gt] [8 9]
a[x]: [8 9]
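a[wh] and a[gt] agree because the index tuple and the boolean mask describe the same set of positions, so each can be rebuilt from the other. A minimal sketch of that round trip; the names mask_from_wh and wh_from_mask are just illustrative:

```python
import numpy as np

a = np.arange(4, 10).reshape(2, 3)
gt = a > 7
wh = np.where(gt)

# Index tuple -> boolean mask: start all False, then set the listed positions to True
mask_from_wh = np.zeros(a.shape, dtype=bool)
mask_from_wh[wh] = True

# Boolean mask -> index tuple: one-argument np.where
wh_from_mask = np.where(mask_from_wh)

assert np.array_equal(mask_from_wh, gt)
assert all(np.array_equal(p, q) for p, q in zip(wh_from_mask, wh))
print(a[mask_from_wh], a[wh_from_mask])
```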
All methods were sourced from stackoverflow.com or stackexchange.com and are licensed under CC BY-SA 2.5, CC BY-SA 3.0 and CC BY-SA 4.0.