I want to find the index of the largest element in each row and column of a CSR matrix. What function should I use?

1 Answer

Best answer

The argmax() method of csr_matrix returns the index of the largest element in each row or column of a CSR matrix. Use the axis argument to choose the direction:

axis=0: for each column, the index of the row that holds its maximum value

axis=1: for each row, the index of the column that holds its maximum value
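One detail worth knowing: the search also covers implicit (unstored) zeros, so in a row whose stored entries are all negative, the result points at an unstored position. Here is a minimal sketch with a made-up 2x3 matrix. The np.asarray(...).ravel() call only turns the 2-D np.matrix that csr_matrix.argmax returns into a flat array:

```python
import numpy as np
from scipy.sparse import csr_matrix

# Row 0 stores only negative values; its middle entry (column 1) is an implicit zero.
# Building from a dense array drops the zeros, so only 4 values are stored.
X = csr_matrix(np.array([[-1, 0, -3],
                         [ 4, 0,  2]]))

# In row 0 the unstored zero at column 1 is the largest value;
# in row 1 the stored 4 at column 0 wins.
row_argmax = np.asarray(X.argmax(axis=1)).ravel()
print(row_argmax)  # [1 0]
```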

Here is an example:

>>> import numpy as np
>>> from scipy.sparse import csr_matrix
>>> # Build a 3x3 CSR matrix from (data, (row, col)) triplets
>>> row = np.array([0, 0, 1, 2, 2, 2])
>>> col = np.array([0, 2, 2, 0, 1, 2])
>>> data = np.array([1, 2, 5, 3, 1, 4])
>>> X = csr_matrix((data, (row, col)), shape=(3, 3))
>>> X.toarray()
array([[1, 0, 2],
       [0, 0, 5],
       [3, 1, 4]])
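>>> # axis=0: row index of the maximum in each column
>>> X.argmax(axis=0)
matrix([[2, 2, 1]])
>>> # axis=1: column index of the maximum in each row
>>> X.argmax(axis=1)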
matrix([[2],
        [2],
        [2]])
>>>
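If you need plain 1-D index arrays rather than np.matrix, and want to confirm the axis semantics, a sketch like the following compares the sparse result with NumPy's dense argmax. The shuffled 5x4 test matrix is an arbitrary choice; its values are all distinct, so ties cannot affect the comparison:

```python
import numpy as np
from scipy.sparse import csr_matrix

# Distinct values -10..9 shuffled into a 5x4 grid (no ties);
# the single 0 is not stored and becomes an implicit zero in the CSR matrix.
rng = np.random.default_rng(0)
dense = rng.permutation(np.arange(-10, 10)).reshape(5, 4)
X = csr_matrix(dense)

# argmax on a csr_matrix returns a 2-D np.matrix; flatten it to a 1-D ndarray.
col_argmax = np.asarray(X.argmax(axis=0)).ravel()  # row index of each column's max
row_argmax = np.asarray(X.argmax(axis=1)).ravel()  # column index of each row's max

# Both should agree with NumPy's argmax on the dense array.
print(np.array_equal(col_argmax, dense.argmax(axis=0)))
print(np.array_equal(row_argmax, dense.argmax(axis=1)))
```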
 


...