Recently I found a delightfully simple way to efficiently count the 1s (the “active bits”) in the binary representation of an integer. The method repeatedly replaces a number n with the bitwise AND of n and n-1 until n == 0. The number of iterations it takes is the number of active bits. That’s sort of hard to follow in English. It’s much easier to understand in Python:
n = 1000  # the non-negative integer whose bits we want to count
num_bits = 0
while n != 0:
    n &= n - 1  # clears the lowest 1 bit of n
    num_bits += 1
Let’s take 1000, or 1111101000 in binary. Subtracting 1 gives us 999, or 1111100111 in binary. The bitwise AND of these two numbers yields 1111100000, or 992 in decimal. We continue doing this until n is 0. In our case this takes six iterations, which is the number of active bits in 1000.
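To make the snippet reusable, it can be wrapped in a function. What follows is only a minimal sketch: the name count_set_bits and the guard against negative input are my own additions for illustration. The guard is there because Python integers have arbitrary precision, so for a negative n the loop would never reach 0.

```python
def count_set_bits(n: int) -> int:
    """Count the 1 bits in a non-negative integer (hypothetical helper name)."""
    if n < 0:
        # Assumption: reject negatives, because in Python n & (n - 1)
        # never reaches 0 when n is negative.
        raise ValueError("n must be non-negative")
    num_bits = 0
    while n != 0:
        n &= n - 1  # drop the lowest 1 bit
        num_bits += 1
    return num_bits


# The first step of the worked example: 1000 & 999 is 992.
assert 1000 & 999 == 992
# Cross-check against counting the "1" characters in bin().
assert count_set_bits(1000) == bin(1000).count("1") == 6
```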
Adding some print statements to the code (see the appendix) lets us follow along at every step:
1000 0b1111101000
&
999 0b1111100111
=
992 0b1111100000
iterations: 1
992 0b1111100000
&
991 0b1111011111
=
960 0b1111000000
iterations: 2
960 0b1111000000
&
959 0b1110111111
=
896 0b1110000000
iterations: 3
896 0b1110000000
&
895 0b1101111111
=
768 0b1100000000
iterations: 4
768 0b1100000000
&
767 0b1011111111
=
512 0b1000000000
iterations: 5
512 0b1000000000
&
511 0b111111111
=
0 0b0
iterations: 6
number of bits: 6
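As you can see, each bitwise AND of n and n-1 clears exactly one bit: the least significant 1 bit of n. Subtracting 1 turns that bit into a 0 and every 0 below it into a 1, so the AND wipes out that bit and everything below it while leaving the higher bits untouched.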
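One way to convince yourself of this is to isolate the lowest 1 bit and check that the AND removes exactly that bit and nothing else. The sketch below uses n & -n to pick out that bit. That is a common two's-complement idiom and not part of the snippet above, and both helper names are made up for this example.

```python
def lowest_set_bit(n: int) -> int:
    """Return the value of the lowest 1 bit of n (0 if n is 0)."""
    return n & -n


def clear_lowest_set_bit(n: int) -> int:
    """The single step used by the counting loop."""
    return n & (n - 1)


for n in range(1, 1025):
    # Clearing the lowest 1 bit is the same as subtracting its value.
    assert clear_lowest_set_bit(n) == n - lowest_set_bit(n)

print(bin(lowest_set_bit(1000)))        # 0b1000
print(bin(clear_lowest_set_bit(1000)))  # 0b1111100000
```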
I love how the time complexity scales directly with the number of active bits, the very thing we are measuring. A ridiculously large number like 33554432 would only need one iteration, because its binary representation is 10000000000000000000000000. By contrast, 31 would take five times as long, because its binary representation is 11111. There is something really elegant about it.
This algorithm is frequently attributed to Brian Kernighan. I came across this while working on a Zig implementation of the VDB data structure, but that’s a story for another blog post.
Appendix: print statements added to the Python code
n = 1000
num_bits = 0
while n != 0:
    m = n-1
    print(n, bin(n))
    print(" &")
    print(m, bin(m))
    print(" =")
    n &= m
    result = n
    num_bits += 1
    print(result, bin(result))
    print("iterations: ", num_bits)
    print("\n")
print("number of bits: ", num_bits)