SRTP AES Optimization Revisited

Nils L. Corneliusen
27 July 2016

Introduction

In a 2010 article titled SRTP AES Optimization I presented a method to make SRTP AES run significantly quicker. Unfortunately, there were some caveats: Packet length had to be 4096 bytes or less and a multiple of 16, and the target CPU was expected to be big endian. Let's try to address these issues in a new and improved version that will run on any 32-bit CPU.

SRTP AES Recap

From RFC3711:

4.1.1 AES in Counter Mode (...) Note that the initial value, IV, is fixed for each packet and is formed by "reserving" 16 zeros in the least significant bits for the purpose of the counter. (...)

So there's a 16-bit counter and each block is 16 bytes, which gives a maximum packet size of 2^16 × 16 = 1048576 bytes. The IV for each 16-byte block inside a 4k chunk (256 × 16 = 4096 bytes) differs only in the last byte. Likewise, the first IV of each 4k block differs from that of the previous one only in the second-to-last byte.
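To make the byte positions concrete, here is a minimal sketch assuming the IV is stored as 16 bytes in network (big-endian) order, with the two least significant bytes zero as RFC 3711 requires. The helper ctr_block_iv() and its example are hypothetical, written only for illustration:

```c
#include <assert.h>
#include <stdint.h>
#include <string.h>

/* Hypothetical helper: build the AES input block for 16-byte block
 * number `idx` of a packet. `base` is the per-packet IV in network
 * (big-endian) byte order; its two least significant bytes are zero. */
static void ctr_block_iv(const uint8_t base[16], uint32_t idx, uint8_t out[16])
{
    memcpy(out, base, 16);
    out[14] = (uint8_t)(idx >> 8);   /* steps once per 4k (256 blocks) */
    out[15] = (uint8_t)(idx & 0xff); /* steps every 16-byte block */
}

/* Usage: blocks inside one 4k chunk differ only in byte 15; moving to
 * the next 4k chunk changes byte 14 (and byte 15 starts over at 0). */
static void example_ctr_block_iv(void)
{
    uint8_t base[16] = {0}, a[16], b[16];
    memset(base, 0xab, 14);            /* arbitrary IV, low 16 bits zero */

    ctr_block_iv(base, 0, a);
    ctr_block_iv(base, 255, b);        /* last block of the first 4k */
    assert(memcmp(a, b, 15) == 0 && a[15] == 0x00 && b[15] == 0xff);

    ctr_block_iv(base, 256, b);        /* first block of the second 4k */
    assert(memcmp(a, b, 14) == 0 && b[14] == 0x01 && b[15] == 0x00);

    /* 2^16 counter values x 16 bytes per block */
    assert(65536UL * 16 == 1048576UL);
}
```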

New and Improved srtp_crypt()

Endianness is handled by the ubiquitous htonl() call. The multiple-of-16 restriction is lifted by handling the remaining bytes in a simple loop after the main loop. No reads or writes will access memory outside the given length. Both input and output are still expected to be 32-bit aligned.
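The idea is that srtp_crypt() keeps the data in native 32-bit words and converts only the keystream word, which the table rounds produce as a big-endian value. Below is a sketch (hypothetical helper names) comparing that word-wise XOR with a byte-wise reference. It uses memcpy() for the loads, so unlike srtp_crypt() it has no alignment requirement:

```c
#include <arpa/inet.h>  /* htonl(), POSIX */
#include <assert.h>
#include <stdint.h>
#include <string.h>

/* Big-endian store, as OpenSSL's PUTU32 macro does. */
static void put_be32(uint8_t *p, uint32_t v)
{
    p[0] = (uint8_t)(v >> 24); p[1] = (uint8_t)(v >> 16);
    p[2] = (uint8_t)(v >> 8);  p[3] = (uint8_t)v;
}

/* Word-wise XOR in the style of srtp_crypt(): load 4 source bytes as a
 * native word, XOR with htonl(keystream word), store natively. */
static void xor_word_htonl(const uint8_t *src, uint8_t *dst, uint32_t ks)
{
    uint32_t w;
    memcpy(&w, src, 4);          /* native load, no alignment needed */
    w ^= htonl(ks);              /* keystream bytes now in memory order */
    memcpy(dst, &w, 4);
}

/* Reference: byte-wise XOR against the big-endian keystream bytes. */
static void xor_word_bytes(const uint8_t *src, uint8_t *dst, uint32_t ks)
{
    uint8_t k[4];
    put_be32(k, ks);
    for (int i = 0; i < 4; i++)
        dst[i] = src[i] ^ k[i];
}
```

Both paths give identical bytes on little- and big-endian hosts, since htonl() is a no-op on the latter.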

Like in the original code, the constant parts are fully recalculated for every 4k block. Since only the second-to-last IV byte changes from one 4k block to the next, parts of round 0 and round 1 could also be pulled outside the 4k loop, but the impact is minimal: at max packet size there are only 256 iterations. All the action is inside the inner loop, so let's keep it simple.

Let's start off by defining a function, srtp_crypt(), that tries to look vaguely similar to AES_ctr128_encrypt(). It should also easily replace srtp_encrypt_decrypt() from the original article. First, the number of 4k blocks and the number of trailing bytes (rest) are determined, and the first three IV words are byte-swapped if needed. The last word stays in memory order so its counter bytes can be incremented directly:

void srtp_crypt( const uint32_t *src, uint32_t *dst, int srclen, const AES_KEY *aeskey, const uint32_t *iv )
{
    if( srclen == 0 ) return;

    uint32_t len4096 = (srclen+4095)>>12;
    uint32_t rest = srclen&0x0f;

    uint32_t r1t0;
    uint32_t r2s0, r2s1, r2s2, r2s3;
    uint32_t s0, s1, s2, s3;
    uint32_t t0, t1, t2, t3;

    const uint32_t *key = aeskey->rd_key;

    uint8_t k3ctr = 0;
    uint8_t k3 = (uint8_t)key[3];

    uint32_t ivn[4];
    ivn[0] = htonl( iv[0] );
    ivn[1] = htonl( iv[1] );
    ivn[2] = htonl( iv[2] );
    ivn[3] =        iv[3]  ;
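The loop-count arithmetic can be checked in isolation. The function count_blocks() below is a hypothetical stand-alone copy of the len4096, len16 and rest calculations from srtp_crypt(), returning how many AES blocks are encrypted for a given length:

```c
#include <assert.h>
#include <stdint.h>

/* Mirrors the loop-count arithmetic in srtp_crypt():
 *   chunks = number of (possibly partial) 4k blocks   (len4096)
 *   per chunk: number of (possibly partial) 16-byte blocks (len16)
 *   rest   = bytes in a trailing partial block, 0 if none
 * Returns the total number of AES block encryptions. */
static uint32_t count_blocks(int srclen, uint32_t *chunks, uint32_t *rest)
{
    uint32_t total = 0;
    *chunks = (uint32_t)(srclen + 4095) >> 12;
    *rest = (uint32_t)srclen & 0x0f;
    for (uint32_t c = 0; c < *chunks; c++) {
        total += srclen > 4096 ? 4096 / 16 : (uint32_t)(srclen + 0x0f) >> 4;
        srclen -= 4096;
    }
    return total;
}
```

For 4097 bytes this gives two 4k iterations (256 blocks, then 1 block), with rest = 1 left for the tail loop.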

The main 4k loop precalculates the constant parts and determines how many 16 byte blocks are needed:

    do {
        // Everything except last byte of iv is the same for each 4k block.
        // Precalculate constant parts for this 4k block:

        // round 0: s0-s2 constant, s3 computed with the last (counter) byte = 0
          s0 =        ivn[0]   ^ key[0];
          s1 =        ivn[1]   ^ key[1];
          s2 =        ivn[2]   ^ key[2];
          s3 = htonl( ivn[3] ) ^ key[3];

        // round 1, t0-t3, t0 except last byte
        r1t0 = Te0[s0 >> 24] ^ Te1[(s1 >> 16) & 0xff] ^ Te2[(s2 >> 8) & 0xff]                  ^ key[ 4];
          t1 = Te0[s1 >> 24] ^ Te1[(s2 >> 16) & 0xff] ^ Te2[(s3 >> 8) & 0xff] ^ Te3[s0 & 0xff] ^ key[ 5];
          t2 = Te0[s2 >> 24] ^ Te1[(s3 >> 16) & 0xff] ^ Te2[(s0 >> 8) & 0xff] ^ Te3[s1 & 0xff] ^ key[ 6];
          t3 = Te0[s3 >> 24] ^ Te1[(s0 >> 16) & 0xff] ^ Te2[(s1 >> 8) & 0xff] ^ Te3[s2 & 0xff] ^ key[ 7];

        // round 2, s0-s3 without t0
        r2s0 =                 Te1[(t1 >> 16) & 0xff] ^ Te2[(t2 >> 8) & 0xff] ^ Te3[t3 & 0xff] ^ key[ 8];
        r2s1 = Te0[t1 >> 24] ^ Te1[(t2 >> 16) & 0xff] ^ Te2[(t3 >> 8) & 0xff]                  ^ key[ 9];
        r2s2 = Te0[t2 >> 24] ^ Te1[(t3 >> 16) & 0xff]                         ^ Te3[t1 & 0xff] ^ key[10];
        r2s3 = Te0[t3 >> 24]                          ^ Te2[(t1 >> 8) & 0xff] ^ Te3[t2 & 0xff] ^ key[11];

        uint32_t len16 = srclen > 4096 ? 4096/16 : (srclen+0x0f)>>4;

The inner loop processes each 16 byte block:

        do {
            // Round 0,1,2 reduced from 48 xors to 5:

            // round 0/1
            //    s3 = iv[3] ^ key[3]; t0 = Te3[s3&0xff] ^ r1t0;
            // => t0 = Te3[(iv[3] ^ key[3]) & 0xff] ^ r1t0;
            t0 = Te3[k3ctr^k3] ^ r1t0;

            // round 2
            s0 = Te0[ t0   >> 24        ] ^ r2s0;
            s1 = Te3[ t0          & 0xff] ^ r2s1;
            s2 = Te2[(t0   >>  8) & 0xff] ^ r2s2;
            s3 = Te1[(t0   >> 16) & 0xff] ^ r2s3;

            /* round 3: */
            t0 = Te0[s0 >> 24] ^ Te1[(s1 >> 16) & 0xff] ^ Te2[(s2 >>  8) & 0xff] ^ Te3[s3 & 0xff] ^ key[12];
            t1 = Te0[s1 >> 24] ^ Te1[(s2 >> 16) & 0xff] ^ Te2[(s3 >>  8) & 0xff] ^ Te3[s0 & 0xff] ^ key[13];
            t2 = Te0[s2 >> 24] ^ Te1[(s3 >> 16) & 0xff] ^ Te2[(s0 >>  8) & 0xff] ^ Te3[s1 & 0xff] ^ key[14];
            t3 = Te0[s3 >> 24] ^ Te1[(s0 >> 16) & 0xff] ^ Te2[(s1 >>  8) & 0xff] ^ Te3[s2 & 0xff] ^ key[15];
            /* round 4: */
            s0 = Te0[t0 >> 24] ^ Te1[(t1 >> 16) & 0xff] ^ Te2[(t2 >>  8) & 0xff] ^ Te3[t3 & 0xff] ^ key[16];
            s1 = Te0[t1 >> 24] ^ Te1[(t2 >> 16) & 0xff] ^ Te2[(t3 >>  8) & 0xff] ^ Te3[t0 & 0xff] ^ key[17];
            s2 = Te0[t2 >> 24] ^ Te1[(t3 >> 16) & 0xff] ^ Te2[(t0 >>  8) & 0xff] ^ Te3[t1 & 0xff] ^ key[18];
            s3 = Te0[t3 >> 24] ^ Te1[(t0 >> 16) & 0xff] ^ Te2[(t1 >>  8) & 0xff] ^ Te3[t2 & 0xff] ^ key[19];
            /* round 5: */
            t0 = Te0[s0 >> 24] ^ Te1[(s1 >> 16) & 0xff] ^ Te2[(s2 >>  8) & 0xff] ^ Te3[s3 & 0xff] ^ key[20];
            t1 = Te0[s1 >> 24] ^ Te1[(s2 >> 16) & 0xff] ^ Te2[(s3 >>  8) & 0xff] ^ Te3[s0 & 0xff] ^ key[21];
            t2 = Te0[s2 >> 24] ^ Te1[(s3 >> 16) & 0xff] ^ Te2[(s0 >>  8) & 0xff] ^ Te3[s1 & 0xff] ^ key[22];
            t3 = Te0[s3 >> 24] ^ Te1[(s0 >> 16) & 0xff] ^ Te2[(s1 >>  8) & 0xff] ^ Te3[s2 & 0xff] ^ key[23];
            /* round 6: */
            s0 = Te0[t0 >> 24] ^ Te1[(t1 >> 16) & 0xff] ^ Te2[(t2 >>  8) & 0xff] ^ Te3[t3 & 0xff] ^ key[24];
            s1 = Te0[t1 >> 24] ^ Te1[(t2 >> 16) & 0xff] ^ Te2[(t3 >>  8) & 0xff] ^ Te3[t0 & 0xff] ^ key[25];
            s2 = Te0[t2 >> 24] ^ Te1[(t3 >> 16) & 0xff] ^ Te2[(t0 >>  8) & 0xff] ^ Te3[t1 & 0xff] ^ key[26];
            s3 = Te0[t3 >> 24] ^ Te1[(t0 >> 16) & 0xff] ^ Te2[(t1 >>  8) & 0xff] ^ Te3[t2 & 0xff] ^ key[27];
            /* round 7: */
            t0 = Te0[s0 >> 24] ^ Te1[(s1 >> 16) & 0xff] ^ Te2[(s2 >>  8) & 0xff] ^ Te3[s3 & 0xff] ^ key[28];
            t1 = Te0[s1 >> 24] ^ Te1[(s2 >> 16) & 0xff] ^ Te2[(s3 >>  8) & 0xff] ^ Te3[s0 & 0xff] ^ key[29];
            t2 = Te0[s2 >> 24] ^ Te1[(s3 >> 16) & 0xff] ^ Te2[(s0 >>  8) & 0xff] ^ Te3[s1 & 0xff] ^ key[30];
            t3 = Te0[s3 >> 24] ^ Te1[(s0 >> 16) & 0xff] ^ Te2[(s1 >>  8) & 0xff] ^ Te3[s2 & 0xff] ^ key[31];
            /* round 8: */
            s0 = Te0[t0 >> 24] ^ Te1[(t1 >> 16) & 0xff] ^ Te2[(t2 >>  8) & 0xff] ^ Te3[t3 & 0xff] ^ key[32];
            s1 = Te0[t1 >> 24] ^ Te1[(t2 >> 16) & 0xff] ^ Te2[(t3 >>  8) & 0xff] ^ Te3[t0 & 0xff] ^ key[33];
            s2 = Te0[t2 >> 24] ^ Te1[(t3 >> 16) & 0xff] ^ Te2[(t0 >>  8) & 0xff] ^ Te3[t1 & 0xff] ^ key[34];
            s3 = Te0[t3 >> 24] ^ Te1[(t0 >> 16) & 0xff] ^ Te2[(t1 >>  8) & 0xff] ^ Te3[t2 & 0xff] ^ key[35];
            /* round 9: */
            t0 = Te0[s0 >> 24] ^ Te1[(s1 >> 16) & 0xff] ^ Te2[(s2 >>  8) & 0xff] ^ Te3[s3 & 0xff] ^ key[36];
            t1 = Te0[s1 >> 24] ^ Te1[(s2 >> 16) & 0xff] ^ Te2[(s3 >>  8) & 0xff] ^ Te3[s0 & 0xff] ^ key[37];
            t2 = Te0[s2 >> 24] ^ Te1[(s3 >> 16) & 0xff] ^ Te2[(s0 >>  8) & 0xff] ^ Te3[s1 & 0xff] ^ key[38];
            t3 = Te0[s3 >> 24] ^ Te1[(s0 >> 16) & 0xff] ^ Te2[(s1 >>  8) & 0xff] ^ Te3[s2 & 0xff] ^ key[39];
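The hoisting of rounds 0-2 rests on one property: only t0 of round 1 depends on the last IV byte, and in round 2 t0 feeds exactly one table lookup per output word. As a sketch under stated assumptions, the self-contained code below builds Te0..Te3 itself (in the layout used by OpenSSL's aes_core.c) and checks, for all 256 counter values, that the hoisted computation matches plain rounds 0-2. The round-key words are arbitrary test values, not a real key schedule; the identity is pure XOR algebra and does not depend on them. All function names are hypothetical:

```c
#include <assert.h>
#include <stdint.h>

static uint32_t Te0[256], Te1[256], Te2[256], Te3[256];

static uint8_t xtime(uint8_t x) { return (uint8_t)((x << 1) ^ ((x & 0x80) ? 0x1b : 0)); }

static uint8_t gf_mul(uint8_t a, uint8_t b)   /* multiply in GF(2^8) */
{
    uint8_t p = 0;
    while (b) { if (b & 1) p ^= a; a = xtime(a); b >>= 1; }
    return p;
}

static uint8_t rotl8(uint8_t x, int n) { return (uint8_t)((x << n) | (x >> (8 - n))); }

/* AES S-box: inverse in GF(2^8) computed as x^254, then the affine map. */
static uint8_t sbox(uint8_t x)
{
    uint8_t inv = x ? 1 : 0;
    for (int i = 0; x && i < 254; i++) inv = gf_mul(inv, x);
    return (uint8_t)(inv ^ rotl8(inv, 1) ^ rotl8(inv, 2) ^ rotl8(inv, 3) ^ rotl8(inv, 4) ^ 0x63);
}

/* Te0 = S.[02,01,01,03], Te1 = S.[03,02,01,01],
 * Te2 = S.[01,03,02,01], Te3 = S.[01,01,03,02] */
static void build_te(void)
{
    for (int x = 0; x < 256; x++) {
        uint32_t s = sbox((uint8_t)x), m2 = xtime((uint8_t)s), m3 = m2 ^ s;
        Te0[x] = (m2 << 24) | (s << 16) | (s << 8) | m3;
        Te1[x] = (m3 << 24) | (m2 << 16) | (s << 8) | s;
        Te2[x] = (s << 24) | (m3 << 16) | (m2 << 8) | s;
        Te3[x] = (s << 24) | (s << 16) | (m3 << 8) | m2;
    }
}

/* One full table round, same indexing as rounds 1..9 in srtp_crypt(). */
static void full_round(const uint32_t in[4], uint32_t out[4], const uint32_t *rk)
{
    for (int i = 0; i < 4; i++)
        out[i] = Te0[in[i] >> 24] ^ Te1[(in[(i + 1) & 3] >> 16) & 0xff] ^
                 Te2[(in[(i + 2) & 3] >> 8) & 0xff] ^ Te3[in[(i + 3) & 3] & 0xff] ^ rk[i];
}

/* Returns 1 if hoisted rounds 0-2 match the full rounds for every value
 * of the last counter byte. iv[] holds host-order (big-endian value) words. */
static int check_hoisting(const uint32_t iv[4], const uint32_t key[12])
{
    uint32_t s0 = iv[0] ^ key[0], s1 = iv[1] ^ key[1], s2 = iv[2] ^ key[2];
    uint32_t s3 = (iv[3] & ~0xffu) ^ key[3];           /* counter byte = 0 */
    uint32_t r1t0 = Te0[s0 >> 24] ^ Te1[(s1 >> 16) & 0xff] ^ Te2[(s2 >> 8) & 0xff] ^ key[4];
    uint32_t t1 = Te0[s1 >> 24] ^ Te1[(s2 >> 16) & 0xff] ^ Te2[(s3 >> 8) & 0xff] ^ Te3[s0 & 0xff] ^ key[5];
    uint32_t t2 = Te0[s2 >> 24] ^ Te1[(s3 >> 16) & 0xff] ^ Te2[(s0 >> 8) & 0xff] ^ Te3[s1 & 0xff] ^ key[6];
    uint32_t t3 = Te0[s3 >> 24] ^ Te1[(s0 >> 16) & 0xff] ^ Te2[(s1 >> 8) & 0xff] ^ Te3[s2 & 0xff] ^ key[7];
    uint32_t r2s0 = Te1[(t1 >> 16) & 0xff] ^ Te2[(t2 >> 8) & 0xff] ^ Te3[t3 & 0xff] ^ key[8];
    uint32_t r2s1 = Te0[t1 >> 24] ^ Te1[(t2 >> 16) & 0xff] ^ Te2[(t3 >> 8) & 0xff] ^ key[9];
    uint32_t r2s2 = Te0[t2 >> 24] ^ Te1[(t3 >> 16) & 0xff] ^ Te3[t1 & 0xff] ^ key[10];
    uint32_t r2s3 = Te0[t3 >> 24] ^ Te2[(t1 >> 8) & 0xff] ^ Te3[t2 & 0xff] ^ key[11];
    uint8_t k3 = (uint8_t)key[3];

    for (uint32_t ctr = 0; ctr < 256; ctr++) {
        /* hoisted path: only t0 depends on the counter byte */
        uint32_t t0 = Te3[(uint8_t)ctr ^ k3] ^ r1t0;
        uint32_t fast[4] = { Te0[t0 >> 24] ^ r2s0, Te3[t0 & 0xff] ^ r2s1,
                             Te2[(t0 >> 8) & 0xff] ^ r2s2, Te1[(t0 >> 16) & 0xff] ^ r2s3 };
        /* reference path: full rounds 0, 1 and 2 */
        uint32_t st[4] = { iv[0], iv[1], iv[2], (iv[3] & ~0xffu) | ctr }, a[4], b[4];
        for (int i = 0; i < 4; i++) st[i] ^= key[i];
        full_round(st, a, key + 4);
        full_round(a, b, key + 8);
        for (int i = 0; i < 4; i++)
            if (fast[i] != b[i]) return 0;
    }
    return 1;
}
```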

This bails out to the rest handling below if the last block is only partially filled:

            if( __builtin_expect( rest && len16 == 1 && len4096 == 1, 0 ) ) break;

Write encrypted data and continue:

            dst[0] = src[0] ^ htonl( ((Te2[(t0 >> 24)       ] & 0xff000000) ^
                                      (Te3[(t1 >> 16) & 0xff] & 0x00ff0000) ^
                                      (Te0[(t2 >>  8) & 0xff] & 0x0000ff00) ^
                                      (Te1[(t3      ) & 0xff] & 0x000000ff) ^ key[40]) );

            dst[1] = src[1] ^ htonl( ((Te2[(t1 >> 24)       ] & 0xff000000) ^
                                      (Te3[(t2 >> 16) & 0xff] & 0x00ff0000) ^
                                      (Te0[(t3 >>  8) & 0xff] & 0x0000ff00) ^
                                      (Te1[(t0      ) & 0xff] & 0x000000ff) ^ key[41]) );
            
            dst[2] = src[2] ^ htonl( ((Te2[(t2 >> 24)       ] & 0xff000000) ^
                                      (Te3[(t3 >> 16) & 0xff] & 0x00ff0000) ^
                                      (Te0[(t0 >>  8) & 0xff] & 0x0000ff00) ^
                                      (Te1[(t1      ) & 0xff] & 0x000000ff) ^ key[42]) );

            dst[3] = src[3] ^ htonl( ((Te2[(t3 >> 24)       ] & 0xff000000) ^
                                      (Te3[(t0 >> 16) & 0xff] & 0x00ff0000) ^
                                      (Te0[(t1 >>  8) & 0xff] & 0x0000ff00) ^
                                      (Te1[(t2      ) & 0xff] & 0x000000ff) ^ key[43]) );


            src += 4; dst += 4;
            k3ctr++;

        } while( --len16 );

When a 4k block is done, increment the 4k block counter, which is byte 14 (the second-to-last byte) of the IV:

        // next 4k block
        ((uint8_t *)ivn)[14]++;
        srclen -= 4096;

    } while( --len4096 );
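The increment works on the raw in-memory copy of the last IV word, which appears to be why ivn[3] is not byte-swapped up front and htonl() is applied when s3 is formed instead. A small sketch (hypothetical helper name) showing that bumping byte 14 adds 0x100 to the big-endian value regardless of host byte order; no carry into byte 13 is needed, since a packet has at most 256 4k blocks:

```c
#include <arpa/inet.h>  /* htonl(), ntohl(), POSIX */
#include <assert.h>
#include <stdint.h>

/* ivn3 is the last IV word exactly as it sits in memory (network byte
 * order). Memory byte 2 of this word is byte 14 of the whole IV. */
static uint32_t next_4k_word(uint32_t ivn3)
{
    ((uint8_t *)&ivn3)[2]++;    /* same operation as ((uint8_t *)ivn)[14]++ */
    return ivn3;
}
```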

If there are leftover bytes, XOR them with the final keystream block in a simple byte loop:

    if( rest ) {

        uint32_t tmp[4];

        tmp[0] = htonl( ((Te2[(t0 >> 24)       ] & 0xff000000) ^
                         (Te3[(t1 >> 16) & 0xff] & 0x00ff0000) ^
                         (Te0[(t2 >>  8) & 0xff] & 0x0000ff00) ^
                         (Te1[(t3      ) & 0xff] & 0x000000ff) ^ key[40]) );

        tmp[1] = htonl( ((Te2[(t1 >> 24)       ] & 0xff000000) ^
                         (Te3[(t2 >> 16) & 0xff] & 0x00ff0000) ^
                         (Te0[(t3 >>  8) & 0xff] & 0x0000ff00) ^
                         (Te1[(t0      ) & 0xff] & 0x000000ff) ^ key[41]) );

        tmp[2] = htonl( ((Te2[(t2 >> 24)       ] & 0xff000000) ^
                         (Te3[(t3 >> 16) & 0xff] & 0x00ff0000) ^
                         (Te0[(t0 >>  8) & 0xff] & 0x0000ff00) ^
                         (Te1[(t1      ) & 0xff] & 0x000000ff) ^ key[42]) );

        tmp[3] = htonl( ((Te2[(t3 >> 24)       ] & 0xff000000) ^
                         (Te3[(t0 >> 16) & 0xff] & 0x00ff0000) ^
                         (Te0[(t1 >>  8) & 0xff] & 0x0000ff00) ^
                         (Te1[(t2      ) & 0xff] & 0x000000ff) ^ key[43]) );

        const uint8_t *src8 = (const uint8_t *)src;
        uint8_t *dst8 = (uint8_t *)dst;
        uint8_t *tmp8 = (uint8_t *)tmp;

        do {
            *dst8++ = *src8++ ^ *tmp8++;
        } while( --rest );

    }

}
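The same tail handling can be written as a stand-alone helper. This sketch (hypothetical name xor_tail()) assumes the final keystream block is already in output byte order, as tmp[] is after the htonl() calls above, and touches only the first rest bytes of src and dst:

```c
#include <assert.h>
#include <stdint.h>
#include <string.h>

/* XOR the final `rest` bytes (1..15) of the packet with the first bytes
 * of the last keystream block. The keystream is staged in a local
 * buffer, so src and dst are never accessed past index rest-1. */
static void xor_tail(const uint8_t *src, uint8_t *dst, const uint32_t ks[4], uint32_t rest)
{
    uint8_t tmp[16];
    memcpy(tmp, ks, sizeof tmp);    /* ks already in output byte order */
    for (uint32_t i = 0; i < rest; i++)
        dst[i] = src[i] ^ tmp[i];
}
```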

Verification

Since we're not banging rocks together here, it's probably wise to make sure srtp_crypt() works properly and that there's no overrun. A function for that is provided in main.c:

static bool validate_aes( int len )

It uses both the new srtp_crypt() and AES_ctr128_encrypt() for encryption and decryption, compares the outputs and checks for overrun for a given length. All inputs are random, but deterministic via a fixed srand() seed, so it's easier to debug: every tested length gets the same IV, key and source data. The start of the source data is filled with a famous movie quote; looking for ASCII while debugging is more fun.
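main.c itself is not reproduced here. As a sketch of the same approach, assuming both implementations are wrapped behind a common byte-oriented signature (crypt_fn and validate_pair() are hypothetical, and the toy functions are stand-ins, not AES), the check runs both on identical seeded random input, compares the outputs and verifies that canary bytes after the buffer are untouched. Canaries catch writes past the end; over-reads need a tool such as Valgrind or AddressSanitizer:

```c
#include <assert.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>

#define GUARD 64  /* canary bytes after each output buffer (assumed size) */

/* Common signature both implementations are adapted to (hypothetical). */
typedef void (*crypt_fn)(const uint8_t *src, uint8_t *dst, int len);

/* Run both implementations on the same seeded random input, then require
 * identical output and untouched canary bytes. malloc() memory is
 * suitably aligned for 32-bit word access. */
static bool validate_pair(crypt_fn a, crypt_fn b, int len)
{
    uint8_t *src = malloc((size_t)len + GUARD);
    uint8_t *da = malloc((size_t)len + GUARD);
    uint8_t *db = malloc((size_t)len + GUARD);
    bool ok = src && da && db;

    if (ok) {
        srand(1234);                             /* fixed seed: same data every run */
        for (int i = 0; i < len; i++) src[i] = (uint8_t)rand();
        memset(da, 0x5a, (size_t)len + GUARD);   /* canary pattern */
        memset(db, 0x5a, (size_t)len + GUARD);
        a(src, da, len);
        b(src, db, len);
        ok = memcmp(da, db, (size_t)len) == 0;
        for (int i = len; ok && i < len + GUARD; i++)
            ok = da[i] == 0x5a && db[i] == 0x5a;
    }
    free(src); free(da); free(db);
    return ok;
}

/* Toy stand-ins for illustration only (not AES). */
static void toy_ok(const uint8_t *s, uint8_t *d, int n)
{
    for (int i = 0; i < n; i++) d[i] = s[i] ^ 0x3c;
}
static void toy_overrun(const uint8_t *s, uint8_t *d, int n)
{
    toy_ok(s, d, n);
    d[n] = 0;                                   /* deliberate one-byte overrun */
}
```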

Interesting lengths to test would be all the edge values:

    int testlens[] = { 0, 1, 15, 16, 17, 4095, 4096, 4097, MAX_LEN-1, MAX_LEN };
    int testcnt = sizeof(testlens)/sizeof(testlens[0]);

    for( int i = 0; i < testcnt; i++ ) {
        bool ok = validate_aes( testlens[i] );
        if( !ok ) return 1;
    }

Measurements

I extracted AES_ctr128_encrypt() from OpenSSL and made sure it was compiled with the same options. Linking with whatever random version was installed on this system gave disastrous results for OpenSSL.

A measurement function is provided:

static uint64_t measure_aes( int len, int cnt, bool use_ssl )

Let's try it for a range of values. The graphs show the speed difference in percent between the old and new versions:

Length 1..64 step 1:

[Graph: SRTP AES Measurement, length 1-64 step 1]

Length 96..4096 step 32:

[Graph: SRTP AES Measurement, length 96-4096 step 32]

Length 5120..65536 step 1024:

[Graph: SRTP AES Measurement, length 5120-65536 step 1024]

Source Code

Source files:

Compile and link:

[@tyrell srtp2016] gcc -o srtp -O3 -std=gnu99 -Wall srtp_crypt.c main.c -lm -lcrypto

For measurements, I suggest finding the needed OpenSSL files somewhere and compiling/linking them yourself:

[@tyrell srtp2016] gcc -o srtp -O3 -std=gnu99 -Wall srtp_crypt.c main.c aes_core.c aes_ctr.c -lm

A test run:

[@tyrell srtp2016] ./srtp
Test len       0: Ok
Test len       1: Ok
Test len      15: Ok
Test len      16: Ok
Test len      17: Ok
Test len    4095: Ok
Test len    4096: Ok
Test len    4097: Ok
Test len 1048575: Ok
Test len 1048576: Ok
      1 1000000 69607 73340 -5.089992
      2 1000000 70164 73686 -4.779741
      3 1000000 70781 74276 -4.705423
      4 1000000 72281 74993 -3.616338
      5 1000000 72164 75449 -4.353934
(...)
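The last column in the output is consistent with (third column - fourth column) / fourth column × 100; for the first row, (69607 - 73340) / 73340 × 100 ≈ -5.089992. The text does not say which column belongs to which implementation, so the sketch below only reproduces that arithmetic, plus a hypothetical timing loop based on the standard C clock() function. measure_aes() in main.c may well measure differently:

```c
#include <assert.h>
#include <stdint.h>
#include <time.h>

typedef void (*work_fn)(int len);   /* hypothetical: one operation on len bytes */

/* Processor time, in seconds, for `cnt` calls of fn(len), via clock(). */
static double time_calls(work_fn fn, int len, int cnt)
{
    clock_t start = clock();
    for (int i = 0; i < cnt; i++) fn(len);
    return (double)(clock() - start) / CLOCKS_PER_SEC;
}

/* Relative difference of a against b in percent; negative when a < b. */
static double pct_diff(double a, double b)
{
    return (a - b) / b * 100.0;
}

/* Toy workload for illustration only. */
static volatile unsigned sink;
static void toy_work(int len)
{
    for (int i = 0; i < len; i++) sink += (unsigned)i;
}
```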

Comments are always appreciated. I prefer being contacted on LinkedIn. Email is also available; you can figure it out from the front page. I switched to a disposable email address system, so it will change regularly.

Remember to appreciate this classic XKCD strip.


www.ignorantus.com