## Most efficient way to sort an array into bins specified by an index array?

```python
import numpy as np

data = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9])
idx  = np.array([2, 0, 1, 1, 2, 0, 1, 1, 2])
```

Expected result:

```python
binned = np.array([2, 6, 3, 4, 7, 8, 1, 5, 9])
```
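A deliberately naive sketch that pins down the intended result: it masks out one bin at a time, which keeps the original order inside each bin but costs `O(n*k)`, so it is only meant as a correctness reference, not a candidate solution. The helper name `sort_to_bins_reference` is hypothetical.

```python
import numpy as np

data = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9])
idx = np.array([2, 0, 1, 1, 2, 0, 1, 1, 2])

def sort_to_bins_reference(idx, data):
    # Hypothetical helper: naive O(n*k) definition of the operation.
    # Boolean masking keeps the original order within each bin (stable).
    k = idx.max() + 1
    return np.concatenate([data[idx == b] for b in range(k)])

binned = sort_to_bins_reference(idx, data)
print(binned)  # [2 6 3 4 7 8 1 5 9]
```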

Constraints:

• Should be fast.

• Should be `O(n+k)`, where n is the length of `data` and k is the number of bins.

• Should be stable, i.e. order within bins is preserved.

Obvious solution

```python
data[np.argsort(idx, kind='stable')]
```

is `O(n log n)`.

`O(n+k)` solution

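```python
import numpy as np

def sort_to_bins(idx, data, mx=-1):
    # mx is the number of bins; -1 means "infer it from idx"
    if mx == -1:
        mx = idx.max() + 1
    # counting pass: cnts[b+1] = number of items in bin b
    cnts = np.zeros(mx + 1, int)
    for i in range(idx.size):
        cnts[idx[i] + 1] += 1
    # prefix sum: cnts[b] = first output slot of bin b
    for i in range(1, cnts.size):
        cnts[i] += cnts[i-1]
    # scatter pass in input order, so order within bins is preserved
    res = np.empty_like(data)
    for i in range(data.size):
        res[cnts[idx[i]]] = data[i]
        cnts[idx[i]] += 1
    return res
```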

is loopy and therefore slow in pure Python.
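The first two loops of the counting-sort solution are just a histogram and an exclusive prefix sum, and both have compiled numpy equivalents. A sketch on the example `idx` (the names `sizes` and `starts` are my own):

```python
import numpy as np

idx = np.array([2, 0, 1, 1, 2, 0, 1, 1, 2])
k = idx.max() + 1

# First loop: number of items per bin, O(n + k)
sizes = np.bincount(idx, minlength=k)                   # [2 4 3]

# Second loop: exclusive prefix sum = first output slot of each bin, O(k)
starts = np.concatenate(([0], np.cumsum(sizes)[:-1]))   # [0 2 6]
```

The remaining scatter loop, which needs a running write offset per bin, is the part that does not map onto an obvious single numpy call, so it is the loop that stays slow.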

Is there a better method, preferably in pure `numpy`, otherwise in `scipy`, `pandas` or `numba`/`pythran` (in that order of preference)?

Here are a few solutions:

1. Use `np.argsort` anyway; after all, it is fast compiled code.

2. Use `np.bincount` to get the bin sizes and `np.argpartition`, which is `O(n)` for a fixed number of bins. Downside: no stable partition algorithm is currently available, so each bin has to be sorted afterwards.

3. Use `scipy.ndimage.labeled_comprehension`. This does roughly what is required, but I have no idea how it is implemented.

4. Use `pandas`. I'm a complete `pandas` noob, so what I cobbled together here using `groupby` may be suboptimal.

5. Use `scipy.sparse`: switching between the compressed sparse row (CSR) and compressed sparse column (CSC) formats happens to implement exactly the operation we are looking for.

6. Use `pythran` (I'm sure `numba` works as well) on the loopy code in the question. All that is required is to insert, at the top of the file after the numpy import,

```python
#pythran export sort_to_bins(int[:], float[:], int)
```

and then compile (with the loopy code saved as `stb_pthr.py`)

```shell
pythran stb_pthr.py
```
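The claim that `numba` works as well is not part of the benchmarks; as a sketch, the same loop compiled with `numba.njit` might look like the following. The name `sort_to_bins_numba` is hypothetical, `mx` is made a required argument to keep the compiled function simple, and the `try`/`except` fallback only exists so the sketch still runs (at pure Python speed) when numba is not installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Fallback so the sketch still runs, uncompiled, without numba
    def njit(func):
        return func

@njit
def sort_to_bins_numba(idx, data, mx):
    # Hypothetical name; same counting sort as the loopy code above
    cnts = np.zeros(mx + 1, np.int64)
    for i in range(idx.size):
        cnts[idx[i] + 1] += 1
    for i in range(1, cnts.size):
        cnts[i] += cnts[i - 1]
    res = np.empty_like(data)
    for i in range(data.size):
        res[cnts[idx[i]]] = data[i]
        cnts[idx[i]] += 1
    return res

idx = np.array([2, 0, 1, 1, 2, 0, 1, 1, 2])
data = np.arange(1.0, 10.0)
binned = sort_to_bins_numba(idx, data, idx.max() + 1)
```

With numba installed, the first call includes compilation time, so any timing should discard it.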

Benchmarks, 100 bins, variable number of items: [log-log plot of time per call (ms) vs. number of items for each method; image not included]

Take home:

If you are OK with `numba`/`pythran`, that is the way to go; if not, `scipy.sparse` scales rather well.
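To see why the `scipy.sparse` route works, it helps to inspect the intermediate CSC arrays on the toy example (a sketch; `perm` and `bounds` are my own names). Each item becomes one row whose single entry sits in column `idx[i]`; converting CSR to CSC regroups the entries by column while keeping the row order within each column (the stability assertion in the benchmark code relies on this), so the CSC arrays also expose the stable permutation and the bin boundaries.

```python
import numpy as np
from scipy import sparse

data = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9])
idx = np.array([2, 0, 1, 1, 2, 0, 1, 1, 2])
n, k = len(idx), idx.max() + 1

# One row per item, with its single stored entry in column idx[i]
m = sparse.csr_matrix((data, idx, np.arange(n + 1)), shape=(n, k))
csc = m.tocsc()

binned = csc.data    # values grouped by column (bin), row order kept
perm = csc.indices   # row numbers = original positions of those values
bounds = csc.indptr  # bin b occupies binned[bounds[b]:bounds[b+1]]
```

`perm` is the same permutation that `idx.argsort(kind='stable')` returns, which is handy if other arrays need to be reordered the same way.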

Code:

```python
import numpy as np
import pandas as pd
from scipy import sparse
from scipy.ndimage import labeled_comprehension
# compiled pythran module built from the loopy code (stb_pthr.py)
from stb_pthr import sort_to_bins as sort_to_bins_pythran

def sort_to_bins_pandas(idx, data, mx=-1):
    df = pd.DataFrame(data)
    out = np.empty_like(data)
    j = 0
    # groups maps each bin label to the row labels (= positions) in that bin
    for grp in df.groupby(idx).groups.values():
        out[j:j+len(grp)] = data[np.sort(grp)]
        j += len(grp)
    return out

def sort_to_bins_ndimage(idx, data, mx=-1):
    if mx == -1:
        mx = idx.max() + 1
    out = np.empty_like(data)
    j = 0
    def collect(vals, pos):
        # order within a label is not guaranteed, so restore it from the
        # linear positions passed in via pass_positions=True
        nonlocal j
        out[j:j+len(vals)] = vals[np.argsort(pos, kind='stable')]
        j += len(vals)
        return 0
    # the per-label return values are discarded, so an int dummy output is used
    labeled_comprehension(data, idx, np.arange(mx), collect, int, 0,
                          pass_positions=True)
    return out

def sort_to_bins_partition(idx, data, mx=-1):
    # unstable: argpartition does not preserve order within bins
    if mx == -1:
        mx = idx.max() + 1
    return data[np.argpartition(idx, np.bincount(idx, None, mx)[:-1].cumsum())]

def sort_to_bins_partition_stable(idx, data, mx=-1):
    if mx == -1:
        mx = idx.max() + 1
    split = np.bincount(idx, None, mx)[:-1].cumsum()
    srt = np.argpartition(idx, split)
    # np.split returns views, so sorting each chunk in place fixes up srt
    for chunk in np.split(srt, split):
        chunk.sort()
    return data[srt]

def sort_to_bins_sparse(idx, data, mx=-1):
    if mx == -1:
        mx = idx.max() + 1
    # one row per item with its single entry in column idx[i];
    # converting to CSC regroups the entries by column, keeping row order
    return sparse.csr_matrix((data, idx, np.arange(len(idx)+1)), (len(idx), mx)).tocsc().data

def sort_to_bins_argsort(idx, data, mx=-1):
    return data[idx.argsort(kind='stable')]

from timeit import timeit
# 16 test cases: 100 bins, 2**5 ... 2**20 items
exmpls = [np.random.randint(0, K, (N,)) for K, N in np.c_[np.full(16, 100), 1<<np.arange(5, 21)]]

timings = {}
for idx in exmpls:
    data = np.arange(len(idx), dtype=float)
    ref = None
    # time every sort_to_bins_* function and check it against the first result
    for x, f in (*globals().items(),):
        if x.startswith('sort_to_bins_'):
            # total seconds for 10 calls * 100 = milliseconds per call
            timings.setdefault(x.replace('sort_to_bins_', '').replace('_', ' '), []).append(
                timeit('f(idx, data, -1)', globals={'f': f, 'idx': idx, 'data': data}, number=10)*100)
            if x == 'sort_to_bins_partition':
                continue  # unstable, so its output differs from the others
            if ref is None:
                ref = f(idx, data, -1)
            else:
                assert np.all(f(idx, data, -1) == ref)

import matplotlib.pyplot as plt
for k, v in timings.items():
    plt.loglog(1<<np.arange(5, 21), v, label=k)
plt.xlabel('#items')
plt.ylabel('time [ms]')
plt.legend()
plt.show()
```
Posted by Paul Panzer