
I want to remove all columns from a CSR matrix that are 0 in every row. In the following example, the last 4 columns of the CSR matrix X are 0 for all rows. How can I delete them?

>>> X
<10x9 sparse matrix of type '<type 'numpy.int32'>'
        with 18 stored elements in Compressed Sparse Row format>
>>> X.toarray()
array([[1, 0, 0, 0, 1, 0, 0, 0, 0],
       [0, 0, 0, 1, 0, 0, 0, 0, 0],
       [0, 1, 1, 1, 0, 0, 0, 0, 0],
       [0, 1, 0, 0, 1, 0, 0, 0, 0],
       [0, 0, 1, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 1, 0, 0, 0, 0, 0],
       [0, 1, 0, 0, 1, 0, 0, 0, 0],
       [0, 0, 1, 1, 0, 0, 0, 0, 0],
       [0, 1, 1, 1, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 1, 0, 0, 0, 0]])
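To try the answer outside the original session, X can be rebuilt from the dense array printed above. This is a minimal sketch that assumes NumPy and SciPy are installed; the name `dense` is chosen here for illustration only.

```python
import numpy as np
from scipy.sparse import csr_matrix

# Dense copy of the matrix printed in the question.
dense = np.array([[1, 0, 0, 0, 1, 0, 0, 0, 0],
                  [0, 0, 0, 1, 0, 0, 0, 0, 0],
                  [0, 1, 1, 1, 0, 0, 0, 0, 0],
                  [0, 1, 0, 0, 1, 0, 0, 0, 0],
                  [0, 0, 1, 0, 0, 0, 0, 0, 0],
                  [0, 0, 0, 1, 0, 0, 0, 0, 0],
                  [0, 1, 0, 0, 1, 0, 0, 0, 0],
                  [0, 0, 1, 1, 0, 0, 0, 0, 0],
                  [0, 1, 1, 1, 0, 0, 0, 0, 0],
                  [0, 0, 0, 0, 1, 0, 0, 0, 0]], dtype=np.int32)

# csr_matrix keeps only the nonzero entries (18 of them here).
X = csr_matrix(dense)
print(X.shape, X.nnz)
```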
 

1 Answer

Best answer

There may be several ways to do this. One way is to use the sparse matrix's sum() method together with numpy.where(). Here is an example.

>>> import numpy as np
>>> X
<10x9 sparse matrix of type '<type 'numpy.int32'>'
        with 18 stored elements in Compressed Sparse Row format>
>>> X.toarray()
array([[1, 0, 0, 0, 1, 0, 0, 0, 0],
       [0, 0, 0, 1, 0, 0, 0, 0, 0],
       [0, 1, 1, 1, 0, 0, 0, 0, 0],
       [0, 1, 0, 0, 1, 0, 0, 0, 0],
       [0, 0, 1, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 1, 0, 0, 0, 0, 0],
       [0, 1, 0, 0, 1, 0, 0, 0, 0],
       [0, 0, 1, 1, 0, 0, 0, 0, 0],
       [0, 1, 1, 1, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 1, 0, 0, 0, 0]])
>>> i = np.where(X.sum(axis=0) != 0)[1]
>>> i
array([0, 1, 2, 3, 4], dtype=int64)
>>> X[:,i]
<10x5 sparse matrix of type '<type 'numpy.int32'>'
        with 18 stored elements in Compressed Sparse Row format>
>>> X[:,i].toarray()
array([[1, 0, 0, 0, 1],
       [0, 0, 0, 1, 0],
       [0, 1, 1, 1, 0],
       [0, 1, 0, 0, 1],
       [0, 0, 1, 0, 0],
       [0, 0, 0, 1, 0],
       [0, 1, 0, 0, 1],
       [0, 0, 1, 1, 0],
       [0, 1, 1, 1, 0],
       [0, 0, 0, 0, 1]])
>>>
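The same idea as a standalone script, offered as a minimal sketch rather than the answer's exact code. `drop_zero_sum_columns` is a hypothetical helper name. It flattens the column totals with `np.asarray(...).ravel()`, so it should also cope with `sum(axis=0)` returning a 1-D array (as SciPy's newer sparse array classes do) instead of a 1 x n `np.matrix`; `np.flatnonzero` then plays the role of `np.where(...)[1]`.

```python
import numpy as np
from scipy.sparse import csr_matrix


def drop_zero_sum_columns(X):
    """Return (X without zero-sum columns, kept column indices).

    Hypothetical helper name; it applies the sum()/where() idea from the answer.
    """
    # Column totals as a flat 1-D array (works for a 1 x n np.matrix or a 1-D array).
    col_sums = np.asarray(X.sum(axis=0)).ravel()
    # Indices of the columns whose total is nonzero, like np.where(...)[1].
    keep = np.flatnonzero(col_sums)
    return X[:, keep], keep


dense = np.array([[1, 0, 0, 0, 1, 0, 0, 0, 0],
                  [0, 0, 0, 1, 0, 0, 0, 0, 0],
                  [0, 1, 1, 1, 0, 0, 0, 0, 0],
                  [0, 1, 0, 0, 1, 0, 0, 0, 0],
                  [0, 0, 1, 0, 0, 0, 0, 0, 0],
                  [0, 0, 0, 1, 0, 0, 0, 0, 0],
                  [0, 1, 0, 0, 1, 0, 0, 0, 0],
                  [0, 0, 1, 1, 0, 0, 0, 0, 0],
                  [0, 1, 1, 1, 0, 0, 0, 0, 0],
                  [0, 0, 0, 0, 1, 0, 0, 0, 0]], dtype=np.int32)
X = csr_matrix(dense)

Y, keep = drop_zero_sum_columns(X)
print(keep)      # [0 1 2 3 4]
print(Y.shape)   # (10, 5)
```

Column indexing with an integer array returns a new CSR matrix, so X itself is left unchanged.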

The sum() method with axis=0 adds up each column over all rows, giving a 1 x 9 matrix of column totals. numpy.where() then finds the columns whose total is nonzero; because that result is 2-D, index [1] picks out the column indices.

>>> X.sum(axis=0)
matrix([[1, 4, 4, 5, 4, 0, 0, 0, 0]])
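One caveat: a zero column sum only means "all zeros" when entries cannot cancel, as with the non-negative counts in this example. If X may contain negative values, counting the nonzeros in each column is safer. The sketch below is an alternative to the sum() approach, not part of the original answer; `drop_all_zero_columns` is a hypothetical helper, not a SciPy function. It converts the input to a CSR copy, where `X.indices` holds the column index of every stored entry, and cleans up duplicates and explicitly stored zeros first so they are not counted.

```python
import numpy as np
from scipy.sparse import csr_matrix


def drop_all_zero_columns(X):
    """Drop columns with no nonzero entry (hypothetical helper, not a SciPy API)."""
    X = X.tocsr(copy=True)  # work on a CSR copy; the next two calls modify it in place
    X.sum_duplicates()      # merge repeated (row, col) entries
    X.eliminate_zeros()     # drop explicitly stored zeros
    # In CSR format, X.indices holds the column index of every stored entry,
    # so bincount yields the number of nonzeros in each column.
    counts = np.bincount(X.indices, minlength=X.shape[1])
    keep = np.flatnonzero(counts)
    return X[:, keep]


# Column 1 sums to 0 (1 + -1) without being all zero; column 2 is all zero.
A = csr_matrix(np.array([[1, 1, 0],
                         [0, -1, 0]]))
col_sums = np.asarray(A.sum(axis=0)).ravel()
print(np.flatnonzero(col_sums))            # the sum-based test keeps only column 0
print(drop_all_zero_columns(A).toarray())  # the count-based test keeps columns 0 and 1
```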


...