
#include "fast_math_funcs.h"

#ifdef __SSE__
#include <xmmintrin.h>
#endif

#include <Eigen/Dense>

// 1 / sqrt(2 * pi): normalisation factor used by the Gaussian densities below.
#define SQRT2PI_INV 0.3989422804014327f
// Yields whichever of a and b has the smaller absolute value (b when they tie).
// NOTE: this is a macro, so the selected argument is evaluated twice; do not
// pass expressions with side effects.
#define absmin(a, b) (fabs(a) < fabs(b) ? (a) : (b))

/**
 * @brief Pseudo-Voigt profile evaluated at a single point.
 *
 * The result is a weighted sum of a Lorentzian term (weight eta) and a
 * Gaussian term (weight 1 - eta) that share one common width. The width is
 * built from both @p gamma and @p sigma, and eta is the ratio of @p gamma to
 * that width, clamped to at most 1.
 *
 * @param x     Point at which the profile is evaluated.
 * @param x0    Centre of the profile.
 * @param sigma Gaussian width parameter.
 * @param gamma Lorentzian width parameter.
 * @return Value of the profile at @p x.
 *
 * @note Relies on fasterpow() / fasterexp() from fast_math_funcs.h
 *       (presumably fast approximations; see that header).
 */
static inline float pseudo_voigt(float x, float x0, float sigma, float gamma) {
    // Reciprocal of the combined width of the profile.
    const float scale = 1.0f / (0.5346f * gamma
                                + fasterpow(0.2166f * gamma * gamma + 1.3863f * sigma * sigma, 0.5));
    // Mixing weight of the Lorentzian part, never larger than 1.
    const float eta = std::min(1.0f, gamma * scale);

    // Squared distance to the centre, expressed in units of the width.
    const float offset = (x - x0) * scale;
    const float d2 = offset * offset;
// Gaussian amplitude and exponent coefficients (also used by pseudo_voigt_sse).
#define COEF_GAUSSIAN_1 0.33883037580155256f
#define COEF_GAUSSIAN_2 0.36067376022224085f
    // M_PI is a double, so the Lorentzian term is evaluated in double precision.
    const double lorentzian = eta * scale / (M_PI * (1.0f + d2));
    return lorentzian + (1.0f - eta) * COEF_GAUSSIAN_1 * scale * fasterexp(-d2 * COEF_GAUSSIAN_2);
}

/**
 * @brief Cauchy probability density function.
 *
 * @param x        Point at which the density is evaluated.
 * @param location Location (centre) of the distribution.
 * @param scale    Scale of the distribution; the caller must pass a non-zero value.
 * @return 1 / (pi * scale * (1 + ((x - location) / scale)^2)).
 */
static inline float cauchy(float x, float location, float scale) {
    const float z = (x - location) / scale;
    // M_PI is a double, so the denominator is accumulated in double precision.
    return 1.0f / (M_PI * scale * (1 + z * z));
}

/**
 * @brief Gaussian probability density function.
 *
 * @param x       Point at which the density is evaluated.
 * @param mean    Mean of the distribution.
 * @param inv_std Reciprocal of the standard deviation (1 / sigma).
 * @return (inv_std / sqrt(2 * pi)) * exp(-0.5 * (inv_std * (x - mean))^2),
 *         with the exponential evaluated through fasterexp().
 */
static inline float gaussian(float x, float mean, float inv_std) {
    const float delta = x - mean;
    // Plain products are used for the square instead of a pow() call.
    return (SQRT2PI_INV * inv_std) * fasterexp(-0.5f * inv_std * inv_std * delta * delta);
}


/**
 * @brief Density for a pair of ranges, modelled as a pseudo-Voigt profile.
 *
 * @param zp    First range; the point at which the profile is evaluated.
 * @param zq    Second range; scales both the centre and the gamma of the profile.
 * @param up    Uncertainty attached to @p zp.
 * @param uq    Uncertainty attached to @p zq.
 * @param loc   Location factor: the profile is centred on zq * loc.
 * @param scale Scale factor: the profile gamma is zq * scale.
 * @return pseudo_voigt(zp, zq * loc, sqrt(up^2 + uq^2), zq * scale), with the
 *         square root evaluated through fasterpow().
 */
static inline float pdf_2_z(float zp, float zq, float up, float uq, float loc, float scale) {
    const float centre = zq * loc;
    const float sigma = fasterpow(up * up + uq * uq, 0.5f);
    const float gamma = zq * scale;
    return pseudo_voigt(zp, centre, sigma, gamma);
}

static inline float pdf_3_z(float zp, float zq, float zr, float up, float uq, float ur, int distance, float loc, float scale) {
    // Expansion and simplication of the matrix form for 3 ranges

    // 1. Compute zr from zp and zq and the noise on this estimation
    float zr_pq = zq + distance * (zq - zp);
    float ur_pq = powf((1 + distance) * uq, 2)+ powf(distance * up, 2);

    // 2. Compute zp from zq and the previous estimation of zr and the noise on
    //    this estimation
    float zp_qr = zq + (zq - zr_pq) / distance;
    float up_qr = powf ((1 + 1.0f/distance) * up,2);

    // 3. Weighted average of zp and its previous estimation
    float zp_pqr = (up_qr * zp + powf(up,2) * zp_qr) / (up_qr + powf(up,2));

    if (zp_pqr <= 0 || zr_pq <= 0 || zr == 0 || !std::isfinite(zp_pqr))
        return 0;

    // 4. Compute the sigma that will be used in the Voigt
    float u_voigt = (up_qr * (up * up) / (up_qr + (up * up))) + (((ur * ur) * ur_pq) / (ur * ur + ur_pq));

    // 5. The pdf is a Gaussian multiplied by a Voigt
    float val = distance * (1 + 1.0f / distance) * gaussian(zr, zr_pq, fasterpow(ur * ur + ur_pq, 0.5)) *
        pseudo_voigt(zp_pqr, zr * loc, fasterpow(u_voigt, 0.5), zr * scale);

    //printf ("[Classic] zr_pq = %f, ur_pq = %f, zp_qr = %f, zp_pqr = %f, u_voigt = %f, gaussian = %f, voigt = %f\n", zr_pq, ur_pq, zp_qr, zp_pqr, u_voigt, gaussian(zr, zr_pq, fasterpow(ur * ur + ur_pq, 0.5)), pseudo_voigt(zp_pqr, zr * loc, fasterpow(u_voigt, 0.5), zr * scale));

    return val;
}

/**
 * @brief Density for N ranges, based on a weighted least-squares fit of two
 *        unknowns.
 *
 * The rows of @p A are weighted by @p noises, the 2x2 normal equations are
 * solved in closed form (adjugate divided by determinant), and the result is
 * the product of a Gaussian factor on the fit residual and a pseudo-Voigt
 * factor on the two fitted values.
 *
 * @tparam N     Number of ranges.
 * @param ranges Measured ranges.
 * @param noises Per-row weights (pdf_4_z passes the reciprocals of the
 *               uncertainties).
 * @param loc    Location factor: the Voigt is centred on (second fitted value) * loc.
 * @param scale  Scale factor: the Voigt gamma is (second fitted value) * scale.
 * @param A      N x 2 design matrix of the fit.
 * @return The density, or 0 when the second fitted value is negative.
 *
 * @note No guard is applied against a zero determinant of the normal matrix.
 */
template <int N>
static inline float pdf_N_z(Eigen::Ref<Eigen::Matrix<float, N, 1> > ranges, Eigen::Ref<Eigen::Matrix<float, N, 1> > noises, float loc, float scale, Eigen::Ref<Eigen::Matrix<float, N, 2> > A) {
    // Weighted design matrix and the matching 2x2 normal matrix.
    const Eigen::Matrix<float, N, 2> An = noises.asDiagonal() * A;
    const Eigen::Matrix2f Lambda = An.transpose() * An;
    const float det_Lambda = Lambda(0, 0) * Lambda(1, 1) - Lambda(1, 0) * Lambda(0, 1);

    // Adjugate of Lambda: dividing by the determinant afterwards gives the
    // inverse without calling a general inversion routine.
    Eigen::Matrix2f Lambda_adj;
    Lambda_adj <<  Lambda(1, 1), -Lambda(0, 1),
                  -Lambda(1, 0),  Lambda(0, 0);

    // Least-squares estimate of the two unknowns.
    const Eigen::Vector2f z0N_ols = (Lambda_adj * An.transpose() * noises.asDiagonal() * ranges) / det_Lambda;

    if (z0N_ols(1) < 0.0f)
        return 0.0f;

    // Squared norm of the weighted residual of the fit.
    const float gaussian_argument = ((noises.asDiagonal() * ranges - An * z0N_ols)).squaredNorm();

// SQRT2PI_INV squared, i.e. 1 / (2 * pi).
#define SQRT2PI_INV2 0.15915494309189535f
    const float gaussian_factor = SQRT2PI_INV2 * fasterpow(det_Lambda, -0.5) * noises.prod()  * fasterexp(-0.5 * gaussian_argument);

    return gaussian_factor *
        pseudo_voigt(z0N_ols(0), z0N_ols(1) * loc, fasterpow((1.0f / Lambda(0,0)) + (loc * Lambda(0, 0) + Lambda(0, 1)) / det_Lambda, 0.5), z0N_ols(1) * scale);
}

static inline float pdf_4_z(float zo, float zp, float zq, float zr, float uo, float up, float uq, float ur, float loc, float scale, Eigen::Ref<Eigen::Matrix<float, 4, 2> > A) {
    Eigen::Vector4f ranges(zo, zp, zq, zr);
    Eigen::Vector4f noises(1.0f/uo, 1.0f/up, 1.0f/uq, 1.0f/ur);

    return pdf_N_z<4>(ranges, noises, loc, scale, A);
}

#ifdef __SSE__
/**
 * @brief Pseudo voigt implementation (SSE), evaluating four lanes at once.
 *
 * Lane-wise counterpart of pseudo_voigt(): a Lorentzian term weighted by eta
 * plus a Gaussian term weighted by (1 - eta), both sharing one width.
 *
 * @param x     Points at which the profile is evaluated.
 * @param x0    Centres of the profile.
 * @param sigma Gaussian width parameters.
 * @param gamma Lorentzian width parameters.
 * @return Profile value for each lane.
 *
 * @note The reciprocal of the width is taken with _mm_rcp_ps, which is an
 *       approximate reciprocal, so results differ slightly from the scalar
 *       function.
 */
static inline __m128 pseudo_voigt_sse(__m128 x, __m128 x0, __m128 sigma, __m128 gamma) {
    // Plain locals rather than function-local statics: no guarded
    // initialisation on a hot path, and the compiler can fold them.
    const __m128 one_const = _mm_set1_ps(1.0f);
    const __m128 half_const = _mm_set1_ps(0.5f);
    const __m128 pi_const = _mm_set1_ps(M_PI);
    const __m128 gaussian_coef_1_const = _mm_set1_ps(COEF_GAUSSIAN_1);
    const __m128 gaussian_coef_2_const = _mm_set1_ps(-COEF_GAUSSIAN_2);
    const __m128 scale_coef_1_const = _mm_set1_ps(0.5346f);
    const __m128 scale_coef_2_const = _mm_set1_ps(0.2166f);
    const __m128 scale_coef_3_const = _mm_set1_ps(1.3863f);

    // Combined width: 0.5346 * gamma + sqrt(0.2166 * gamma^2 + 1.3863 * sigma^2).
    const __m128 radicand = _mm_add_ps(
        _mm_mul_ps(scale_coef_2_const, _mm_mul_ps(gamma, gamma)),
        _mm_mul_ps(scale_coef_3_const, _mm_mul_ps(sigma, sigma)));
    const __m128 width = _mm_add_ps(
        _mm_mul_ps(scale_coef_1_const, gamma),
        vfasterpow(radicand, half_const));

    // Approximate reciprocal of the width.
    const __m128 scale = _mm_rcp_ps(width);
    // Lorentzian mixing weight, clamped to at most 1.
    const __m128 eta = _mm_min_ps(one_const, _mm_mul_ps(gamma, scale));

    // Squared distance to the centre in units of the width.
    const __m128 offset = _mm_mul_ps(_mm_sub_ps(x, x0), scale);
    const __m128 d2 = _mm_mul_ps(offset, offset);

    const __m128 lorentzian = _mm_div_ps(
        _mm_mul_ps(eta, scale),
        _mm_mul_ps(pi_const, _mm_add_ps(one_const, d2)));
    const __m128 gaussian_amplitude = _mm_mul_ps(
        _mm_mul_ps(_mm_sub_ps(one_const, eta), gaussian_coef_1_const),
        scale);
    const __m128 gaussian_part = _mm_mul_ps(
        gaussian_amplitude,
        vfasterexp(_mm_mul_ps(gaussian_coef_2_const, d2)));

    return _mm_add_ps(lorentzian, gaussian_part);
}

/**
 * @brief Cauchy probability density function (SSE), four lanes at once.
 *
 * @param x        Points at which the density is evaluated.
 * @param location Locations (centres) of the distribution.
 * @param scale    Scales of the distribution.
 * @return Approximation of 1 / (pi * scale * (1 + z^2)) per lane, where
 *         z = (x - location) / scale; the final reciprocal uses _mm_rcp_ps
 *         and is therefore not exact.
 */
static inline __m128 cauchy_sse(__m128 x, __m128 location, __m128 scale) {
    const __m128 z = _mm_div_ps(_mm_sub_ps(x, location), scale);
    const __m128 one_plus_z2 = _mm_add_ps(_mm_set1_ps(1.0f), _mm_mul_ps(z, z));
    const __m128 denominator = _mm_mul_ps(_mm_set1_ps(M_PI), _mm_mul_ps(scale, one_plus_z2));
    return _mm_rcp_ps(denominator);
}

/**
 * @brief Gaussian probability density function (SSE), four lanes at once.
 *
 * @param x       Points at which the density is evaluated.
 * @param mean    Means of the distribution.
 * @param inv_std Reciprocals of the standard deviation (1 / sigma).
 * @return (inv_std / sqrt(2 * pi)) * exp(-0.5 * inv_std^2 * (x - mean)^2) per
 *         lane, with the exponential evaluated through vfasterexp().
 */
static inline __m128 gaussian_sse(__m128 x, __m128 mean, __m128 inv_std) {
    const __m128 delta = _mm_sub_ps(x, mean);
    const __m128 delta_sq = _mm_mul_ps(delta, delta);
    // inv_std^2 * delta^2, multiplied in the same order as the scalar helper's SSE original.
    const __m128 scaled_sq = _mm_mul_ps(inv_std, _mm_mul_ps(inv_std, delta_sq));
    const __m128 exponent = _mm_mul_ps(_mm_set1_ps(-0.5f), scaled_sq);
    const __m128 amplitude = _mm_mul_ps(_mm_set1_ps(SQRT2PI_INV), inv_std);
    return _mm_mul_ps(amplitude, vfasterexp(exponent));
}

/**
 * @brief Absolute value in SSE
 *
 * Clears the sign bit of every lane: -0.0f has only the sign bit set, so an
 * and-not with it keeps all the other bits of @p x untouched.
 *
 * @param x Four packed floats.
 * @return |x| for each lane.
 */
inline __m128 abs_ps(__m128 x) {
    return _mm_andnot_ps(_mm_set1_ps(-0.f), x);
}

/**
 * @brief Density for a pair of ranges (SSE), four lanes at once.
 *
 * Lane-wise counterpart of the scalar pdf_2_z(): a pseudo-Voigt profile
 * evaluated at @p zp, centred on zq * loc, with sigma = sqrt(up^2 + uq^2)
 * (through vfasterpow) and gamma = zq * scale.
 */
static inline v4sf pdf_2_z(v4sf zp, v4sf zq, v4sf up, v4sf uq, v4sf loc, v4sf scale) {
    const v4sf centre = zq * loc;
    const v4sf sigma = vfasterpow(up * up + uq * uq, v4sfl(0.5));
    const v4sf gamma = zq * scale;
    return pseudo_voigt_sse(zp, centre, sigma, gamma);
}

/**
 * @brief Density for three ranges (SSE), four lanes at once.
 *
 * Expansion and simplification of the matrix form for 3 ranges, following the
 * same steps as the scalar pdf_3_z().
 *
 * @note Unlike the scalar overload, no lane is forced to 0 when an
 *       intermediate estimate is non-positive or not finite; callers get
 *       whatever the arithmetic produces for such lanes.
 * @note 1 / distance is taken with _mm_rcp_ps, an approximate reciprocal.
 */
static inline v4sf pdf_3_z(v4sf zp, v4sf zq, v4sf zr, v4sf up, v4sf uq, v4sf ur, v4sf distance, v4sf loc, v4sf scale) {
    const v4sf one_plus_dist = _mm_set1_ps(1.0f) + distance;
    const v4sf one_plus_inv_dist = _mm_set1_ps(1.0f) + _mm_rcp_ps(distance);

    const v4sf up_sq = up * up;
    const v4sf ur_sq = ur * ur;

    // 1. Compute zr from zp and zq and the noise on this estimation
    const v4sf zr_pq = zq + distance * (zq - zp);
    const v4sf ur_pq_q = one_plus_dist * uq;
    const v4sf ur_pq_p = distance * up;
    const v4sf ur_pq = ur_pq_q * ur_pq_q + ur_pq_p * ur_pq_p;

    // 2. Compute zp from zq and the previous estimation of zr and the noise on
    //    this estimation
    const v4sf zp_qr = zq + (zq - zr_pq) / distance;
    const v4sf up_qr_root = one_plus_inv_dist * up;
    const v4sf up_qr = up_qr_root * up_qr_root;

    // 3. Weighted average of zp and its previous estimation
    const v4sf zp_pqr = (up_qr * zp + up_sq * zp_qr) / (up_qr + up_sq);

    // 4. Compute the sigma that will be used in the Voigt
    const v4sf u_voigt = (up_qr * up_sq / (up_qr + up_sq)) + ((ur_sq * ur_pq) / (ur_sq + ur_pq));

    // 5. The pdf is a Gaussian multiplied by a Voigt
    const v4sf gauss = gaussian_sse(zr, zr_pq, vfasterpow(ur_sq + ur_pq, _mm_set1_ps(0.5)));
    const v4sf voigt = pseudo_voigt_sse(zp_pqr, zr * loc, vfasterpow(u_voigt, _mm_set1_ps(0.5)), zr * scale);

    return distance * one_plus_inv_dist * gauss * voigt;
}

/**
 * @brief Density for four ranges (SSE), four independent problems at once.
 *
 * For every lane a 4 x 2 weighted least-squares fit is solved with Eigen
 * (closed-form 2x2 inverse via adjugate / determinant); the per-lane results
 * are then combined with vectorised pow / exp / pseudo-Voigt evaluations.
 *
 * @param zo,zp,zq,zr Measured ranges, one problem per lane.
 * @param uo,up,uq,ur Uncertainties of the ranges, one problem per lane.
 * @param loc         Location factor: the Voigt is centred on (second fitted value) * loc.
 * @param scale       Scale factor: the Voigt gamma is (second fitted value) * scale.
 * @param A           4 x 2 design matrix shared by all lanes.
 * @return Density for each lane.
 *
 * @note Unlike pdf_N_z(), lanes whose second fitted value is negative are not
 *       forced to 0 here.
 */
static inline v4sf pdf_4_z(v4sf zo, v4sf zp, v4sf zq, v4sf zr, v4sf uo, v4sf up, v4sf uq, v4sf ur, v4sf loc, v4sf scale, Eigen::Ref<Eigen::Matrix<float, 4, 2> > A) {
    // Ranges divided by their uncertainties (exact division).
    v4sf zo_ratios = zo / uo;
    v4sf zp_ratios = zp / up;
    v4sf zq_ratios = zq / uq;
    v4sf zr_ratios = zr / ur;

    // Approximate reciprocals, only used for the product of the weights.
    v4sf uo_inv = _mm_rcp_ps(uo);
    v4sf up_inv = _mm_rcp_ps(up);
    v4sf uq_inv = _mm_rcp_ps(uq);
    v4sf ur_inv = _mm_rcp_ps(ur);

    // Per-lane outputs of the scalar stage, filled in one lane at a time.
    v4sf gaussian_arguments;
    v4sf noise_determinants;
    v4sf zo_ols;
    v4sf zr_ols;
    v4sf lambdas_00;
    v4sf lambdas_01;
    v4sf lambda_determinants;

    for (int lane = 0; lane < 4; ++lane) {
        noise_determinants[lane] = uo_inv[lane] * up_inv[lane] * uq_inv[lane] * ur_inv[lane];

        const Eigen::Vector4f ranges(zo_ratios[lane], zp_ratios[lane], zq_ratios[lane], zr_ratios[lane]);

        // Design matrix with each row divided by the uncertainty of its range.
        const float row_noise[4] = { uo[lane], up[lane], uq[lane], ur[lane] };
        Eigen::Matrix<float, 4, 2> An;
        for (int row = 0; row < 4; ++row) {
            An(row, 0) = A(row, 0) / row_noise[row];
            An(row, 1) = A(row, 1) / row_noise[row];
        }

        // Normal matrix, its determinant and its adjugate.
        const Eigen::Matrix2f Lambda = An.transpose() * An;
        const float det_Lambda = Lambda(0, 0) * Lambda(1, 1) - Lambda(1, 0) * Lambda(0, 1);

        Eigen::Matrix2f Lambda_adj;
        Lambda_adj <<  Lambda(1, 1), -Lambda(0, 1),
                      -Lambda(1, 0),  Lambda(0, 0);

        lambda_determinants[lane] = det_Lambda;
        lambdas_00[lane] = Lambda(0, 0);
        lambdas_01[lane] = Lambda(0, 1);

        // Least-squares estimate of the two unknowns for this lane.
        const Eigen::Vector2f zor_ols = (Lambda_adj * An.transpose() * ranges) / det_Lambda;
        zo_ols[lane] = zor_ols(0);
        zr_ols[lane] = zor_ols(1);

        // Exponent of the Gaussian factor: -0.5 * squared residual norm.
        gaussian_arguments[lane] = -0.5 * (ranges - An * zor_ols).squaredNorm();
    }

// SQRT2PI_INV squared, i.e. 1 / (2 * pi).
#define SQRT2PI_INV2 0.15915494309189535f
    const v4sf gaussian_factor = v4sfl(SQRT2PI_INV2) * vfasterpow(lambda_determinants, v4sfl(-0.5)) * noise_determinants  * vfasterexp(gaussian_arguments);

    return gaussian_factor *
        pseudo_voigt_sse(zo_ols, zr_ols * loc, vfasterpow((_mm_rcp_ps(lambdas_00)) + (loc * lambdas_00 + lambdas_01) / lambda_determinants, v4sfl(0.5)), zr_ols * scale);
}

#endif

/**
 * @brief Indices of the seven Cauchy parameter sets.
 *
 * The enumerators follow the same order as the entries of the `locations` and
 * `scales` tables declared by setup_locations_scales_pointers()
 * (po, pq, qr, qo, pr, or, rr), so they can serve as the second index of
 * those tables.
 */
enum CauchyParameters {
    p_o = 0,  // loc_po / scale_po
    p_q = 1,  // loc_pq / scale_pq
    q_r = 2,  // loc_qr / scale_qr
    q_o = 3,  // loc_qo / scale_qo
    p_r = 4,  // loc_pr / scale_pr
    o_r = 5,  // loc_or / scale_or
    r_r = 6   // loc_rr / scale_rr
};

// Declares two local lookup tables, `locations` and `scales`, each of type
// float *[2][7]:
//   - the first index (0 or 1) selects element [0] or [1] of the underlying
//     loc_* / scale_* objects;
//   - the second index follows the order po, pq, qr, qo, pr, or, rr, i.e. the
//     order of the CauchyParameters enumerators.
// The expansion site must have loc_po ... loc_rr and scale_po ... scale_rr in
// scope, each indexable with [0] and [1] and exposing a `.data` member whose
// value is cast to float * (presumably image/matrix buffers holding floats -
// confirm at the call site).
// The macro expands to two declarations, so it can only be used where
// declarations are allowed and at most once per scope.
#define setup_locations_scales_pointers() float *locations[2][7] = { \
        { \
            (float *)loc_po[0].data, \
            (float *)loc_pq[0].data, \
            (float *)loc_qr[0].data, \
            (float *)loc_qo[0].data, \
            (float *)loc_pr[0].data, \
            (float *)loc_or[0].data, \
            (float *)loc_rr[0].data \
        }, \
        { \
            (float *)loc_po[1].data, \
            (float *)loc_pq[1].data, \
            (float *)loc_qr[1].data, \
            (float *)loc_qo[1].data, \
            (float *)loc_pr[1].data, \
            (float *)loc_or[1].data, \
            (float *)loc_rr[1].data \
        } \
    }; \
    float *scales[2][7] = { \
        { \
            (float *)scale_po[0].data, \
            (float *)scale_pq[0].data, \
            (float *)scale_qr[0].data, \
            (float *)scale_qo[0].data, \
            (float *)scale_pr[0].data, \
            (float *)scale_or[0].data, \
            (float *)scale_rr[0].data \
        }, \
        { \
            (float *)scale_po[1].data, \
            (float *)scale_pq[1].data, \
            (float *)scale_qr[1].data, \
            (float *)scale_qo[1].data, \
            (float *)scale_pr[1].data, \
            (float *)scale_or[1].data, \
            (float *)scale_rr[1].data \
        } \
    };
