On C++ performance: The Evil Mr. Branch

Here's a simple problem. For some reason, you've got to choose between two values depending on whether an input is smaller than or equal to 50, or bigger than 50. And you need to do it fast, using a single core (no multithreading). How do you do it?

For our example, we'll have an array called data where all the input values are stored, and an array called results where we store an integer indicating whether each input was bigger or smaller than 50.

So, this is the straightforward way to do it:
for(int i=0;i<NUM_DATA;++i)
{
    if(data[i] <= 50)
        results[i] = SMALLERTHAN_50;
    else
        results[i] = BIGGERTHAN_50;
}
Most of us (myself included, two days ago) would say: "you won't get much faster than that!". Well, in fact, you can. How?

By not branching!

In school they teach you that a branch the processor guesses wrong (a mispredicted branch) forces it to flush the instructions already in its pipeline and start over. They also tell you that today's general-purpose processors, with their good branch predictors, don't usually take a big hit from branching, but the hit still exists (I'm talking about x86 PCs here; consoles DO have serious problems with branching).

So, how do you do this without branching? With a small trick ;-)

result = x + ((y - x) & condition)

If condition has all bits set to 0, the second term is 0, so result ends up being x. On the other hand, if condition has all bits set to 1, the and leaves (y - x) untouched, the x cancels out, and we get result = y.
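Here is a minimal sketch of the masked select on its own, assuming two's complement ints so that a mask of -1 has every bit set. The name select_masked is hypothetical. Note the inner parentheses: in C and C++ the & operator binds more loosely than +, so writing x + (y - x) & mask without them would compute (x + (y - x)) & mask, which is just y & mask.

```cpp
#include <cassert>

// Branchless select: returns x when mask is 0 (all bits clear) and y when
// mask is -1 (all bits set, assuming two's complement).
// Hypothetical helper name; y - x must not overflow.
int select_masked(int x, int y, int mask)
{
    assert(mask == 0 || mask == -1);  // any other mask would mix bits of x and y
    return x + ((y - x) & mask);      // parentheses needed: & binds looser than +
}
```

For example, select_masked(3, 7, 0) gives 3 and select_masked(3, 7, -1) gives 7.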

So we get a function like this (code shamelessly taken from here):
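int isel( int a, int x, int y )
{
    // Arithmetic shift right, splatting out the sign bit (assumes 32-bit int;
    // right-shifting a negative int is implementation-defined before C++20,
    // but GCC, Clang and MSVC all do an arithmetic shift).
    int mask = a >> 31;
    return x + ((y - x) & mask); // mask is 0xFFFFFFFF if (a < 0) and 0x00000000 otherwise.
}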
If a is zero or positive, it returns x; if a is negative, it returns y.
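The a >> 31 in isel relies on right-shifting a negative int being an arithmetic shift, which is implementation-defined before C++20 (GCC, Clang and MSVC all do it). If that bothers you, here is a sketch of a variant, under the hypothetical name isel_portable, that reads the sign bit through an unsigned (logical) shift instead and builds the mask by negating 0 or 1. It still assumes two's complement, which the static_assert checks at compile time.

```cpp
#include <cassert>
#include <climits>

static_assert(~0 == -1, "two's complement ints assumed");

// Same contract as isel: returns x if a >= 0, y if a < 0.
// Hypothetical name; avoids right-shifting a negative signed value.
int isel_portable(int a, int x, int y)
{
    // Logical shift of the unsigned bit pattern moves the sign bit to bit 0.
    const unsigned sign = static_cast<unsigned>(a) >> (sizeof(unsigned) * CHAR_BIT - 1);
    const int mask = -static_cast<int>(sign);  // 0 -> 0, 1 -> -1 (all bits set)
    assert(mask == 0 || mask == -1);
    return x + ((y - x) & mask);
}
```

Compilers typically reduce both versions to the same few instructions, so the choice is mostly about portability rather than speed.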

So, using this function, we can classify our array just like this:
for(int i=0;i<NUM_DATA;++i)
    results[i] = isel(50 - data[i], SMALLERTHAN_50, BIGGERTHAN_50); 
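A quick way to convince yourself the two loops really agree, including at the boundary (data[i] == 50 must land in SMALLERTHAN_50, since 50 - 50 = 0 is not negative), is to run both over the same input and compare. This sketch reuses isel and the enum from the post; classify_branching, classify_branchless and check_classifiers are hypothetical helper names.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

enum { SMALLERTHAN_50, BIGGERTHAN_50 };

// isel as in the post (relies on an arithmetic right shift of a negative int).
int isel(int a, int x, int y)
{
    int mask = a >> 31;
    return x + ((y - x) & mask);
}

// Reference version with an explicit branch.
void classify_branching(const int* data, int* results, std::size_t n)
{
    for (std::size_t i = 0; i < n; ++i) {
        if (data[i] <= 50)
            results[i] = SMALLERTHAN_50;
        else
            results[i] = BIGGERTHAN_50;
    }
}

// Branchless version: 50 - data[i] is negative exactly when data[i] > 50.
void classify_branchless(const int* data, int* results, std::size_t n)
{
    for (std::size_t i = 0; i < n; ++i)
        results[i] = isel(50 - data[i], SMALLERTHAN_50, BIGGERTHAN_50);
}

// Runs both versions on the same input and asserts they agree element by element.
void check_classifiers(const std::vector<int>& data)
{
    std::vector<int> a(data.size()), b(data.size());
    classify_branching(data.data(), a.data(), data.size());
    classify_branchless(data.data(), b.data(), data.size());
    for (std::size_t i = 0; i < data.size(); ++i)
        assert(a[i] == b[i]);
}
```

Feeding check_classifiers every value from 0 to 99 (the range rand() % 100 produces in the test program) covers both sides of the threshold.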
It works just as well... the only question is performance. What will run faster?

Well, on my machine (a Core 2 Duo at 2.4GHz), with NUM_DATA=10,000 and each loop repeated 200,000 times, I get:

- Using GCC with no optimizations:
Elapsed time BRANCHING:       19.156000000  sec
Elapsed time NOT BRANCHING:   17.566000000  sec
- Using GCC with -O3:
Elapsed time BRANCHING:       2.605000000  sec
Elapsed time NOT BRANCHING:   1.981000000  sec
The difference is not that big without optimizations, but with them it's clear as day: not branching gets you further. More than 20%, in this case!!
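To put those numbers in perspective, here is the arithmetic from the timings above. Each version processes 10,000 × 200,000 = 2 × 10^9 elements in total, so at 2.4 GHz the -O3 branching loop spends roughly three cycles per element and the branchless one under two and a half.

```latex
\begin{aligned}
\text{time saved (no opt.)} &= \frac{19.156 - 17.566}{19.156} = \frac{1.590}{19.156} \approx 0.083 \\
\text{time saved (-O3)} &= \frac{2.605 - 1.981}{2.605} = \frac{0.624}{2.605} \approx 0.240 \\
\text{per element, branching (-O3)} &= \frac{2.605\,\text{s}}{2 \times 10^{9}} \approx 1.30\,\text{ns} \approx 3.1 \text{ cycles at } 2.4\,\text{GHz} \\
\text{per element, branchless (-O3)} &= \frac{1.981\,\text{s}}{2 \times 10^{9}} \approx 0.99\,\text{ns} \approx 2.4 \text{ cycles at } 2.4\,\text{GHz}
\end{aligned}
```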

And this goes to show you what some people might already have told you: at high optimization levels, compilers inline the hell out of your code, even if you don't explicitly ask them to. Here, the version that calls a function inside the main loop ends up faster than the one that doesn't!

By the way, GCC is NOT generating MMX, SSE, or any instructions like that. This is the disassembly of the main loop in the branchless version (the comments on the right are my interpretation of the asm):

.text:00401170 loc_401170:
.text:00401170             mov     eax, ecx                  ## eax=50
.text:00401172             sub     eax, [ebx+edx*4]          ## eax-=data[i]
.text:00401175             shr     eax, 1Fh                  ## eax>>=31
.text:00401178             mov     [edi+edx*4], eax          ## result[i] = eax
.text:0040117B             inc     edx                       ## ++i
.text:0040117C             cmp     edx, 270Fh                ## if(i<NUM_DATA)
.text:00401182             jle     short loc_401170          ##    repeat loop
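Reading that listing back into C++: since SMALLERTHAN_50 is 0 and BIGGERTHAN_50 is 1, isel(50 - v, 0, 1) is just the sign bit of 50 - v, which a single logical shift (shr) extracts, so the mask and the and disappear entirely. The sketch below spells that out; classify_as_compiled is a hypothetical name, and it assumes 32-bit int.

```cpp
#include <cassert>
#include <climits>

// What the -O3 loop body computes for one element, read back as C++
// (hypothetical helper, assumes 32-bit int). With SMALLERTHAN_50 == 0 and
// BIGGERTHAN_50 == 1, the result is just the sign bit of 50 - v, moved to
// bit 0 by a logical shift (the shr in the listing).
int classify_as_compiled(int v)
{
    assert(v > INT_MIN + 50);  // keeps 50 - v from overflowing
    return static_cast<int>(static_cast<unsigned>(50 - v) >> 31);
}
```

This only works because the two enum values happen to be 0 and 1; with other constants the compiler would have to keep the mask-and-add sequence.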

GCC optimizes the loop heavily, but it's all there in normal, honest-to-fsm, everyday instructions.

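I still have some doubts about this, however, because for the branching loop GCC's optimizer doesn't emit a jle/jnle or jz instruction at all, but setnle. As far as I can tell, setnle ("set byte if not less or equal") isn't a jump: it just writes 1 or 0 into al based on the flags from the preceding cmp, so it shouldn't involve the branch predictor or cause a pipeline flush. That would mean the optimized "branching" loop doesn't actually branch on the data either. Anyway, the isel version still runs faster.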

For comparison, here's the optimized branching loop:
.text:004010F0 loc_4010F0:
.text:004010F0             xor     eax, eax
.text:004010F2             cmp     dword ptr [ebx+edx*4], 32h
.text:004010F6             setnle  al
.text:004010F9             mov     [esi+edx*4], eax
.text:004010FC             inc     edx
.text:004010FD             cmp     edx, 270Fh
.text:00401103             jle     short loc_4010F0
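For symmetry, here is the branching loop's body read back into C++. xor eax, eax clears the register, cmp compares data[i] with 32h (50), and setnle writes 1 into al if data[i] > 50 and 0 otherwise: the if/else has become a comparison whose result is stored directly, with no conditional jump on the data. classify_as_setcc and check_setcc_matches_if_else are hypothetical names.

```cpp
#include <cassert>

// The -O3 "branching" loop body read back as C++ (hypothetical helper):
// cmp + setnle turns the comparison into 0 or 1 with no conditional jump.
int classify_as_setcc(int v)
{
    return static_cast<int>(v > 50);
}

// Asserts that the setcc form matches the original if/else for every v in [lo, hi).
void check_setcc_matches_if_else(int lo, int hi)
{
    for (int v = lo; v < hi; ++v) {
        int expected;
        if (v <= 50)
            expected = 0;  // SMALLERTHAN_50
        else
            expected = 1;  // BIGGERTHAN_50
        assert(classify_as_setcc(v) == expected);
    }
}
```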
And, well, here's the whole testing code if anyone is curious about it:
#include <cstdlib>
#include <cstdio>
#include <ctime>

enum{ SMALLERTHAN_50, BIGGERTHAN_50};

int isel( int a, int x, int y )
{
    int mask = a >> 31; // arithmetic shift right, splat out the sign bit
    return x + ((y - x) & mask); // mask is 0xFFFFFFFF if (a < 0) and 0x00 otherwise.
}

int main ()
{
    const int NUM_DATA = 10000;
    const size_t NUM_ITERS = 200000;

    clock_t startTime, endTime;

    int * data     = new int[NUM_DATA];
    int * results  = new int[NUM_DATA];
    int * results2 = new int[NUM_DATA];
    
    // Initialize test data
    for(int i=0;i<NUM_DATA;++i)
        data[i]= rand()%100; 

    startTime = clock();
    // ------------------------------ 
    // Branch test
    for (size_t j=0;j<NUM_ITERS;++j)
        for(int i=0;i<NUM_DATA;++i)
        {
            if(data[i] <= 50)
                results[i] = SMALLERTHAN_50;
            else
                results[i] = BIGGERTHAN_50;
        }
    // ------------------------------ 
    endTime = clock();

    printf ("\tElapsed time BRANCHING:       %.9f  sec\n", 
        (float)(endTime-startTime) / CLOCKS_PER_SEC);
    
    startTime = clock();
    // ------------------------------ 
    // Branchless test
    for (size_t j=0;j<NUM_ITERS;++j)
        for(int i=0;i<NUM_DATA;++i)
            results2[i] = isel(
                50 - data[i], 
                SMALLERTHAN_50, 
                BIGGERTHAN_50);                                                                                   
    // ------------------------------ 
    endTime = clock();
 
    printf ("\tElapsed time NOT BRANCHING:   %.9f  sec\n",
        (float)(endTime-startTime) / CLOCKS_PER_SEC);

    // Check we didn't mess things up 
    for(int i=0;i<NUM_DATA;++i)
        if( results[i] != results2[i])
            printf ("ERROR in elem %i=%i, first value: %i, second value: %i\n",
                i, data[i], results[i], results2[i]);

    // Release the test buffers
    delete[] data;
    delete[] results;
    delete[] results2;
    return 0;
}
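If you want to rerun the experiment, one possible refinement is sketched below. It times each loop with std::chrono::steady_clock, a monotonic wall clock, instead of clock(), whose meaning (processor time, per the C standard) and resolution vary between C runtimes, and it keeps the same correctness check. time_seconds and benchmark_both are hypothetical helpers, and the timings you get will of course depend on your machine and compiler flags.

```cpp
#include <cassert>
#include <chrono>
#include <cstddef>
#include <cstdlib>
#include <utility>
#include <vector>

enum { SMALLERTHAN_50, BIGGERTHAN_50 };

int isel(int a, int x, int y)
{
    int mask = a >> 31;            // arithmetic shift assumed, as in the post
    return x + ((y - x) & mask);
}

// Runs body() `iters` times and returns elapsed wall-clock seconds
// measured with the monotonic std::chrono::steady_clock.
template <typename F>
double time_seconds(F body, std::size_t iters)
{
    const auto start = std::chrono::steady_clock::now();
    for (std::size_t j = 0; j < iters; ++j)
        body();
    const auto end = std::chrono::steady_clock::now();
    return std::chrono::duration<double>(end - start).count();
}

// Hypothetical harness: times both loops on the same random data and
// returns {branching seconds, branchless seconds}.
std::pair<double, double> benchmark_both(std::size_t num_data, std::size_t iters)
{
    std::vector<int> data(num_data), results(num_data), results2(num_data);
    for (int& v : data)
        v = std::rand() % 100;

    const double t_branch = time_seconds([&] {
        for (std::size_t i = 0; i < num_data; ++i) {
            if (data[i] <= 50) results[i] = SMALLERTHAN_50;
            else               results[i] = BIGGERTHAN_50;
        }
    }, iters);

    const double t_branchless = time_seconds([&] {
        for (std::size_t i = 0; i < num_data; ++i)
            results2[i] = isel(50 - data[i], SMALLERTHAN_50, BIGGERTHAN_50);
    }, iters);

    for (std::size_t i = 0; i < num_data; ++i)
        assert(results[i] == results2[i]);   // both loops must agree

    return {t_branch, t_branchless};
}
```

Calling benchmark_both(10000, 200000) reproduces the sizes used above; the first member of the returned pair is the branching time, the second the branchless one.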
