// Fast Lanczos-2 scaler for ARM in C and Neon intrinsics.
// Written by Nils Liaaen Corneliusen 2023.
// License: CC BY 4.0 (https://creativecommons.org/licenses/by/4.0/)
// Read the 2023 article here: https://www.ignorantus.com/pages/neon_filters/
// Read the 2018 article here: https://www.ignorantus.com/pages/image_transformation/

// Only intrinsics listed as "v7/A32/A64" are used.
// A 32-register Neon implementation is recommended.
// The assembler version is 38% faster on 16-register Neon; see the 2023 article.

#include <stdio.h>
#include <stdint.h>
#include <malloc.h>
#include <alloca.h>
#include <arm_neon.h>

#include "coeffs.h"

// Short aliases for Neon reinterpret casts, named <result type>_<source type>;
// a leading 'q' selects the 128-bit (quad register) variant.
#define u8_u64   vreinterpret_u8_u64
#define s8_u64   vreinterpret_s8_u64
#define u64_u8   vreinterpret_u64_u8
#define qu8_u64  vreinterpretq_u8_u64
#define qs16_u16 vreinterpretq_s16_u16
#define qu32_u8  vreinterpretq_u32_u8

// Stack-allocate `size` bytes whose address is aligned to `align`, which must be a
// power of two: `align` extra bytes are requested and the address is rounded up.
// Like alloca() itself, the memory lives only until the calling function returns.
#define alloca_aligned(size, align) ((void *)((((size_t)alloca((size)+(align)))+(align)-1)&(~((align)-1))))
// Byte offset of pixel p within one row of the work buffer before transposition:
// pixels are kept in groups of 8 and consecutive groups lie 64 bytes apart, so 8 rows
// stored 8 bytes apart form one row-major 8x8 block per group.
#define GETOFFSET(p) ( ((p)>>3)*64+((p)&0x07) )

// Transpose consecutive 64-byte blocks in place.
//
// Each block is read as 8 rows of 8 bytes (row r at byte offset r*8) and written back
// as 8 columns of 8 bytes, so afterwards the 8 bytes at offset k*8 hold column k of the
// block, rows 0..7 in order.
//
// src : start of the first block; blocks follow each other contiguously.
// cnt : pixel count; blocks are consumed 8 pixels at a time, i.e. ceil(cnt/8) blocks.
//       At least one block is always transposed, even when cnt <= 0.
static void h2v( uint8_t *src, int cnt )
{
    const uint8_t *rd = src;
    uint32_t      *wr = (uint32_t *)src;

    for( ;; ) {
        // Load the whole block before storing anything: the transpose is done in place.
        uint8x16_t rows01 = vld1q_u8( rd +  0 );
        uint8x16_t rows23 = vld1q_u8( rd + 16 );
        uint8x16_t rows45 = vld1q_u8( rd + 32 );
        uint8x16_t rows67 = vld1q_u8( rd + 48 );
        rd += 64;

        // Two byte-level zip rounds gather, per column, rows 0-3 (lo) and rows 4-7 (hi)
        // into 32-bit words...
        uint8x16x2_t lo = vzipq_u8( rows01, rows23 );
        uint8x16x2_t hi = vzipq_u8( rows45, rows67 );
        lo = vzipq_u8( lo.val[0], lo.val[1] );
        hi = vzipq_u8( hi.val[0], hi.val[1] );

        // ...and a 32-bit zip joins the two halves into complete 8-row columns.
        uint32x4x2_t cols0123 = vzipq_u32( qu32_u8(lo.val[0]), qu32_u8(hi.val[0]) );
        uint32x4x2_t cols4567 = vzipq_u32( qu32_u8(lo.val[1]), qu32_u8(hi.val[1]) );

        vst1q_u32( wr +  0, cols0123.val[0] ); // columns 0,1
        vst1q_u32( wr +  4, cols0123.val[1] ); // columns 2,3
        vst1q_u32( wr +  8, cols4567.val[0] ); // columns 4,5
        vst1q_u32( wr + 12, cols4567.val[1] ); // columns 6,7
        wr += 16;

        cnt -= 8;
        if( cnt <= 0 )
            break;
    }
}

// Limit v to the range [lo, hi]. The lower bound is tested first, so lo wins if lo > hi.
static inline int clamp( int v, int lo, int hi )
{
    if( v < lo )
        return lo;
    if( v > hi )
        return hi;
    return v;
}

// Vertically filter ONE output row with a 4-tap kernel and store it in the work buffer.
//
// src        : source plane, src_stride bytes per row.
// srcw       : source width in pixels. Pixels are handled 32 per iteration, so up to
//              round_up(srcw, 32) pixels are read from each source row and written to dst.
// srch       : index of the LAST valid source row (scale_plane passes srch-1); the four
//              source row indices are clamped to [0, srch].
// src_stride : bytes between source rows.
// dst        : this row's slot in the work buffer. Pixel p is stored at dst + GETOFFSET(p),
//              i.e. 8-pixel groups are 64 bytes apart, leaving room for 7 more rows at
//              dst+8 ... dst+56 to complete each 8x8 block.
// yco        : 64 coefficient sets indexed by phase, each uint32_t packing 4 int8 taps.
// yy         : 16.16 fixed-point position; yy>>16 is the first of the 4 source rows read.
//
// Afterwards pixels -8..-1 and srcw, srcw+1 of this row hold copies of its edge pixels.
static void vert_interpolate_4to1_8( uint8_t *src, int srcw, int srch, int src_stride, uint8_t *dst, uint32_t *yco, int yy )
{
    int16x8_t rv = vdupq_n_s16( COEFFS_ROUNDVAL );

    // Set up row pointers for rows ypos..ypos+3, clamped so that rows outside the
    // plane repeat the edge row.
    int ypos = yy>>16;
    uint64_t *sp0 = (uint64_t *)(src + src_stride * clamp(ypos,   0, srch));
    uint64_t *sp1 = (uint64_t *)(src + src_stride * clamp(ypos+1, 0, srch));
    uint64_t *sp2 = (uint64_t *)(src + src_stride * clamp(ypos+2, 0, srch));
    uint64_t *sp3 = (uint64_t *)(src + src_stride * clamp(ypos+3, 0, srch));

    uint8_t *dst0 = dst;

    // Pick the coefficient set for this phase (top 6 bits of the 16-bit fraction) and
    // widen its 4 packed int8 taps to int16 lanes 0..3.
    int16x4_t vco0w = vget_low_s16( vmovl_s8( s8_u64(vcreate_u64( *(yco + ((((uint32_t)yy)>>10)&63)) )) ) );

    // 32 pixels (4 groups of 8) per iteration.
    for( int x = 0; x < srcw; x += 32 ) {
        int16x8_t res0, res1, res2, res3;

        // Row 0-3: fetch
        uint64x1x4_t in0 = vld4_u64( sp0 ); sp0 += 4;
        uint64x1x4_t in1 = vld4_u64( sp1 ); sp1 += 4;
        uint64x1x4_t in2 = vld4_u64( sp2 ); sp2 += 4;
        uint64x1x4_t in3 = vld4_u64( sp3 ); sp3 += 4;

        // Row 0: Expand & multiply
        res0 = vmulq_lane_s16( qs16_u16( vmovl_u8( u8_u64(in0.val[0]) ) ), vco0w, 0 );
        res1 = vmulq_lane_s16( qs16_u16( vmovl_u8( u8_u64(in0.val[1]) ) ), vco0w, 0 );
        res2 = vmulq_lane_s16( qs16_u16( vmovl_u8( u8_u64(in0.val[2]) ) ), vco0w, 0 );
        res3 = vmulq_lane_s16( qs16_u16( vmovl_u8( u8_u64(in0.val[3]) ) ), vco0w, 0 );

        // Row 1: Expand & fma
        res0 = vmlaq_lane_s16( res0, qs16_u16( vmovl_u8( u8_u64(in1.val[0]) ) ), vco0w, 1 );
        res1 = vmlaq_lane_s16( res1, qs16_u16( vmovl_u8( u8_u64(in1.val[1]) ) ), vco0w, 1 );
        res2 = vmlaq_lane_s16( res2, qs16_u16( vmovl_u8( u8_u64(in1.val[2]) ) ), vco0w, 1 );
        res3 = vmlaq_lane_s16( res3, qs16_u16( vmovl_u8( u8_u64(in1.val[3]) ) ), vco0w, 1 );

        // Row 2: Etc.
        res0 = vmlaq_lane_s16( res0, qs16_u16( vmovl_u8( u8_u64(in2.val[0]) ) ), vco0w, 2 );
        res1 = vmlaq_lane_s16( res1, qs16_u16( vmovl_u8( u8_u64(in2.val[1]) ) ), vco0w, 2 );
        res2 = vmlaq_lane_s16( res2, qs16_u16( vmovl_u8( u8_u64(in2.val[2]) ) ), vco0w, 2 );
        res3 = vmlaq_lane_s16( res3, qs16_u16( vmovl_u8( u8_u64(in2.val[3]) ) ), vco0w, 2 );

        // Row 3
        res0 = vmlaq_lane_s16( res0, qs16_u16( vmovl_u8( u8_u64(in3.val[0]) ) ), vco0w, 3 );
        res1 = vmlaq_lane_s16( res1, qs16_u16( vmovl_u8( u8_u64(in3.val[1]) ) ), vco0w, 3 );
        res2 = vmlaq_lane_s16( res2, qs16_u16( vmovl_u8( u8_u64(in3.val[2]) ) ), vco0w, 3 );
        res3 = vmlaq_lane_s16( res3, qs16_u16( vmovl_u8( u8_u64(in3.val[3]) ) ), vco0w, 3 );

        // Round and pack: add the rounding term, shift right by 6, saturate to 0..255.
        uint8x8_t r0 = vqshrun_n_s16( vaddq_s16( res0, rv ), 6 );
        uint8x8_t r1 = vqshrun_n_s16( vaddq_s16( res1, rv ), 6 );
        uint8x8_t r2 = vqshrun_n_s16( vaddq_s16( res2, rv ), 6 );
        uint8x8_t r3 = vqshrun_n_s16( vaddq_s16( res3, rv ), 6 );

        // Store result: consecutive 8-pixel groups are 64 bytes apart (see GETOFFSET).
        vst1_u8( dst,       r0 );
        vst1_u8( dst +  64, r1 );
        vst1_u8( dst + 128, r2 );
        vst1_u8( dst + 192, r3 );
        dst += 256;
    }

    // Pad only this row. Each call produces a single row and the rows after it have not
    // been filtered yet, so padding neighbouring rows (dst0 +/- 8) from here would copy
    // stale data and, for the first/last row, write outside the 8-row group.

    // Pad left: replicate pixel 0 into pixels -8..-1 (the 8 bytes at dst0-64).
    vst1_u64( (uint64_t *)(dst0-64), u64_u8(vdup_n_u8( dst0[0] )) );

    // Pad right: replicate the last pixel into pixels srcw and srcw+1.
    uint8_t pix = *(dst0+GETOFFSET(srcw-1));
    *(dst0+GETOFFSET(srcw+0)) = pix;
    *(dst0+GETOFFSET(srcw+1)) = pix;
}

// Filter 8 rows of one output column. in.val[0..3] are 4 consecutive transposed source
// columns (8 bytes = 8 rows each); the low 4 bytes of co are the 4 int8 taps.
static inline int16x8_t hfilter_col( uint64x1x4_t in, uint32_t co )
{
    int16x4_t cw  = vget_low_s16( vmovl_s8( s8_u64(vcreate_u64( co )) ) );
    int16x8_t acc = vmulq_lane_s16(      qs16_u16( vmovl_u8( u8_u64(in.val[0]) ) ), cw, 0 );
    acc           = vmlaq_lane_s16( acc, qs16_u16( vmovl_u8( u8_u64(in.val[1]) ) ), cw, 1 );
    acc           = vmlaq_lane_s16( acc, qs16_u16( vmovl_u8( u8_u64(in.val[2]) ) ), cw, 2 );
    acc           = vmlaq_lane_s16( acc, qs16_u16( vmovl_u8( u8_u64(in.val[3]) ) ), cw, 3 );
    return acc;
}

// Horizontally filter the transposed work buffer into 8 destination rows,
// 8 output dst cols at a time.
//
// src        : work buffer after h2v; the 8 bytes at src + p*8 hold source pixel p for
//              all 8 rows. Small negative p (left padding) are read near the left edge.
// xadd       : 16.16 fixed-point source step per destination column.
// xoffset    : 16.16 fixed-point source position of destination column 0.
// xcoeffs    : 64 coefficient sets indexed by phase, each uint32_t packing 4 int8 taps.
// width      : destination columns. Columns are produced 8 at a time, so
//              round_up(width, 8) columns are written in each row.
// dst        : first of 8 destination rows spaced dst_stride bytes apart; all 8 rows
//              are always written.
static void horiz_interpolate_8_8( uint8_t *src, int xadd, int xoffset, uint32_t *xcoeffs, int width, uint8_t *dst, int dst_stride )
{
    // Start one column left so the 4 taps cover columns xpos-1..xpos+2. Written as
    // -(1<<16) because left-shifting a negative value, as in (-1<<16), is undefined in C.
    int xx = xoffset - (1<<16);

    int16x8_t rv = vdupq_n_s16( COEFFS_ROUNDVAL );

    for( int x = 0; x < width; x += 8 ) {
        int16x8_t res[8];

        for( int i = 0; i < 8; i += 4 ) {

            // Fetch 4 source columns and the phase coefficients for each of 4 output
            // columns. The column index can be negative at the left edge, so it is
            // scaled with *8 rather than <<3 (undefined for negative values).
            uint64x1x4_t in0 = vld4_u64( (uint64_t *)(src + (xx>>16)*8) ); uint32_t hco0 = xcoeffs[(((uint32_t)xx)>>10)&63]; xx += xadd;
            uint64x1x4_t in1 = vld4_u64( (uint64_t *)(src + (xx>>16)*8) ); uint32_t hco1 = xcoeffs[(((uint32_t)xx)>>10)&63]; xx += xadd;
            uint64x1x4_t in2 = vld4_u64( (uint64_t *)(src + (xx>>16)*8) ); uint32_t hco2 = xcoeffs[(((uint32_t)xx)>>10)&63]; xx += xadd;
            uint64x1x4_t in3 = vld4_u64( (uint64_t *)(src + (xx>>16)*8) ); uint32_t hco3 = xcoeffs[(((uint32_t)xx)>>10)&63]; xx += xadd;

            // Cols 0-3 (first pass) / 4-7 (second pass)
            res[i+0] = hfilter_col( in0, hco0 );
            res[i+1] = hfilter_col( in1, hco1 );
            res[i+2] = hfilter_col( in2, hco2 );
            res[i+3] = hfilter_col( in3, hco3 );
        }

        // Round and pack: add the rounding term, shift right by 6, saturate to 0..255.
        // Each 16-byte vector holds two output columns of 8 rows.
        uint8x16_t r01 = vcombine_u8( vqshrun_n_s16( vaddq_s16( res[0], rv ), 6 ), vqshrun_n_s16( vaddq_s16( res[1], rv ), 6 ) );
        uint8x16_t r23 = vcombine_u8( vqshrun_n_s16( vaddq_s16( res[2], rv ), 6 ), vqshrun_n_s16( vaddq_s16( res[3], rv ), 6 ) );
        uint8x16_t r45 = vcombine_u8( vqshrun_n_s16( vaddq_s16( res[4], rv ), 6 ), vqshrun_n_s16( vaddq_s16( res[5], rv ), 6 ) );
        uint8x16_t r67 = vcombine_u8( vqshrun_n_s16( vaddq_s16( res[6], rv ), 6 ), vqshrun_n_s16( vaddq_s16( res[7], rv ), 6 ) );

        // Undo transpose: turn 8 columns of 8 rows back into 8 rows of 8 columns.
        uint8x16x2_t tr0 = vzipq_u8( r01, r23 );
        uint8x16x2_t tr1 = vzipq_u8( r45, r67 );
        uint8x16x2_t tr2 = vzipq_u8( tr0.val[0], tr0.val[1] );
        uint8x16x2_t tr3 = vzipq_u8( tr1.val[0], tr1.val[1] );
        uint32x4x2_t tr4 = vzipq_u32( qu32_u8(tr2.val[0]), qu32_u8(tr3.val[0]) );
        uint32x4x2_t tr5 = vzipq_u32( qu32_u8(tr2.val[1]), qu32_u8(tr3.val[1]) );

        // Store final result: 8 bytes into each of the 8 destination rows.
        uint8_t *tmp = dst;
        vst1_u32( (uint32_t *)tmp, vget_low_u32(  tr4.val[0] ) ); tmp += dst_stride;
        vst1_u32( (uint32_t *)tmp, vget_high_u32( tr4.val[0] ) ); tmp += dst_stride;
        vst1_u32( (uint32_t *)tmp, vget_low_u32(  tr4.val[1] ) ); tmp += dst_stride;
        vst1_u32( (uint32_t *)tmp, vget_high_u32( tr4.val[1] ) ); tmp += dst_stride;
        vst1_u32( (uint32_t *)tmp, vget_low_u32(  tr5.val[0] ) ); tmp += dst_stride;
        vst1_u32( (uint32_t *)tmp, vget_high_u32( tr5.val[0] ) ); tmp += dst_stride;
        vst1_u32( (uint32_t *)tmp, vget_low_u32(  tr5.val[1] ) ); tmp += dst_stride;
        vst1_u32( (uint32_t *)tmp, vget_high_u32( tr5.val[1] ) );

        dst += 8;
    }
}

// Scale one 8-bit plane from srcw x srch to dstw x dsth with separable 4-tap filters.
//
// src, srcstr  : source plane and its row stride in bytes.
// dst, dststr  : destination plane and its row stride in bytes.
// xco, yco     : 64 horizontal / vertical coefficient sets, each uint32_t packing 4 int8
//                taps, indexed by the top 6 bits of the 16-bit sub-pixel fraction.
//
// Every 8 destination rows are produced as one batch: 8 vertically filtered rows are
// written into a stack work buffer, transposed in 8x8 blocks, then filtered horizontally.
// NOTE(review): the destination is written in whole 8x8 tiles, so dst must have room for
// round_up(dsth, 8) rows of round_up(dstw, 8) bytes; each source row is also read in
// 32-pixel chunks, i.e. up to round_up(srcw, 32) bytes. Confirm callers allocate for this.
void scale_plane( uint8_t *src, int srcw, int srch, int srcstr,
                  uint8_t *dst, int dstw, int dsth, int dststr,
                  uint32_t *xco, uint32_t *yco )
{
    // Nothing to scale; this also prevents the divisions by zero below.
    if( srcw <= 0 || srch <= 0 || dstw <= 0 || dsth <= 0 )
        return;

    // 16.16 fixed-point source step per destination column / row.
    int xadd = (srcw<<16)/dstw;
    int yadd = (srch<<16)/dsth;

    // Source position of destination pixel 0's centre: xadd/2 - 0.5 in 16.16.
    int xoffset = (xadd>>1)-(1<<15);
    int yoffset = (yadd>>1)-(1<<15);

    // Start one row up so the 4 vertical taps cover rows ypos-1..ypos+2. Written as
    // -(1<<16) because left-shifting a negative value, as in (-1<<16), is undefined in C.
    int yy = yoffset - (1<<16);

    // Work buffer width in pixels, excluding the 8-pixel left pad. The vertical pass
    // writes whole 32-pixel groups, so round srcw up to 32 and leave room for the right
    // padding and h2v's 8-pixel blocks. Never smaller than the original srcstr-based size.
    int bufw = ((srcw + 31) & ~31) + 32;
    if( bufw < srcstr + 8 )
        bufw = srcstr + 8;

    uint8_t *buf = (uint8_t *)alloca_aligned( (bufw+8)*8, 16 );
    buf += 64;   // 8 pixels x 8 rows of left padding precede pixel 0

    for( int y = 0; y < dsth; y += 8 ) {

        // 8 vertically filtered rows, row i stored 8 bytes after row i-1.
        for( int i = 0; i < 8; i++ ) {

            vert_interpolate_4to1_8( src, srcw, srch-1, srcstr, buf + i*8, yco, yy );

            yy += yadd;

        }

        // Transpose pixels -8 .. srcw+7 so each source pixel's 8 rows are contiguous.
        h2v( buf-64, srcw+16 );

        horiz_interpolate_8_8( buf, xadd, xoffset, xco, dstw, dst + y * dststr, dststr );

    }

}
