
#include <opencv2/imgproc/imgproc.hpp>
#include <statistical_edges.h>

#include "pdfs.h"

using namespace statistical_edges;

/**
 * Detects depth edges between every pixel q and its WEST and NORTH neighbour p.
 *
 * The inputs are copied into padded working buffers (the arguments themselves
 * are not modified). Inside the image area every non-zero depth z is replaced
 * by 1/z and its noise value u by u/z^2; a depth of exactly 0 is treated as
 * "no measurement". For each pixel and each of the two directions the
 * posterior probability that p and q lie on the same surface is evaluated
 * with pdf_2_z()/pdf_3_z(), using either the triple (o, p, q) or (p, q, r),
 * where r is `distance` pixels beyond q and o is `distance` pixels beyond p.
 * An edge is reported when that posterior is below `detection_threshold`.
 *
 * Side effects: initializes the detector if needed, calls
 * setup_locations_scales_pointers() and advances the `locations`/`scales`
 * pointers by one element per pixel and direction, and prints the elapsed
 * time on stdout.
 *
 * NOTE(review): the SSE path processes four columns per iteration and uses
 * aligned loads (_mm_load_ps) on the padded rows and on `locations`/`scales`.
 * It therefore appears to require an image width that is a multiple of 4 and
 * 16-byte aligned row starts - confirm with the callers.
 * NOTE(review): if the image has no valid depth, or all valid depths are
 * equal, the reciprocal-distribution normalisation below is not finite.
 *
 * @param depth     Depth map; 0 marks pixels without a measurement.
 * @param noise_std Per-pixel standard deviation of the depth noise, same size
 *                  as `depth`.
 * @return 8-bit edge map of the size of `depth`. For a detected edge the pixel
 *         q receives direction_flags[i][0] (WEST or NORTH) and, if p is inside
 *         the image, p receives direction_flags[i][1] (EAST or SOUTH). The
 *         same flag shifted left by 4 is additionally set on p when the
 *         inverse depth of q is greater than that of p, and on q otherwise
 *         (presumably marking the occluding side - TODO confirm).
 */
cv::Mat1b StaticticalDepthEdgeDetector::detect1(cv::Mat1f &depth, cv::Mat1f &noise_std) {
    if (!initialized())
        initialize(size, intrinsic_matrix);

    const double edge_detection_start = static_cast<double>(cv::getTickCount());

    cv::Mat1b edges = cv::Mat1b::zeros(depth.size());

    // The "o" samples of the NORTH direction are read (1 + distance) rows above
    // q, so the top border needs one more row than the spatial support
    // `distance`; with only `distance` rows the first image row would read in
    // front of the buffer. The left border is kept at `distance` so that the
    // column offset of the image area (and with it the alignment of the row
    // pointers used by the aligned SSE loads) does not change; the WEST "o"
    // sample of column 0 then falls on the last border element of the previous
    // padded row, which is still inside the buffer.
    const int top_border = distance + 1;

    // We create a border around the depth map with values set at the maximum
    // possible float. This will make surface characteristics estimated on the
    // border more unlikely than characteristics estimated on the actual
    // image depth values.
    cv::Mat1f tmp(depth.rows + top_border + distance, depth.cols + 2 * distance);
    cv::copyMakeBorder(depth, tmp,
                       top_border, distance, distance, distance,
                       cv::BORDER_CONSTANT, cv::Scalar(FLT_MAX));
    cv::Rect roi(distance, top_border, depth.cols, depth.rows);
    cv::Mat1f depth_padded = tmp(roi);

    // Standard deviation of the noise padded with 0.
    cv::Mat1f tmp2(depth.rows + top_border + distance, depth.cols + 2 * distance);
    cv::copyMakeBorder(noise_std, tmp2,
                       top_border, distance, distance, distance,
                       cv::BORDER_CONSTANT, cv::Scalar(0));
    cv::Mat1f noise_padded = tmp2(roi);

    // Compute min/max of the valid depths to determine the distribution of the
    // depth of a single pixel, and convert the image area in place to inverse
    // depth (z -> 1/z) with the matching noise scaling (u -> u / z^2).
    // Pixels with depth 0 are left untouched.
    float z_min = FLT_MAX;
    float z_max = -FLT_MAX;
    for (int row = 0; row < depth.rows; row++) {
        float *z_ptr = depth_padded.ptr<float>(row);
        float *u_ptr = noise_padded.ptr<float>(row);

        for (int col = 0; col < depth.cols; col++, z_ptr++, u_ptr++) {
            if (*z_ptr == 0.0f)
                continue;

            if (*z_ptr < z_min) z_min = *z_ptr;
            if (*z_ptr > z_max) z_max = *z_ptr;

            *u_ptr = *u_ptr / powf(*z_ptr, 2);
            *z_ptr = 1.0f / *z_ptr;
        }
    }

#ifdef __SSE__
    __m128 zero_const = _mm_set1_ps(0.0f);
    __m128 max_const = _mm_set1_ps(FLT_MAX);
    __m128 p_single_z = _mm_set1_ps(1.0f / (fasterlog(z_max) - fasterlog(z_min))); // Reciprocal distribution between min and max
    __m128 edge_prior_v = _mm_set1_ps(edge_prior);
    __m128 Spq_prior_v = _mm_set1_ps(Spq_prior);
    __m128 Jqr_prior_v = _mm_set1_ps(Jqr_prior);
    __m128 Sqr_prior_v = _mm_set1_ps(Sqr_prior);
    __m128 distance_v = _mm_set1_ps(distance);
    __m128 detection_threshold_v = _mm_set1_ps(detection_threshold);
#else
    // Reciprocal distribution between min and max.
    float p_single_z = 1.0f / (fasterlog(z_max) - fasterlog(z_min));
#endif

    /**
     * Flags written to the edge map, per direction:
     * [i][0] is set on the current pixel q, [i][1] on its neighbour p.
     * Direction 0 is WEST, direction 1 is NORTH.
     */
    static const uint8_t direction_flags[2][2] = {
        { WEST, EAST },
        { NORTH, SOUTH }
    };

    // Detection loop

    // Pointers into the edge map: the current pixel and its WEST / NORTH
    // neighbour. The neighbour pointers start in front of the buffer; they are
    // only dereferenced when the neighbour is inside the image.
    uint8_t *edge_ptr = edges.data;
    uint8_t *edge2_ptr[2] = {edge_ptr - 1, edge_ptr - edges.step1()};

    setup_locations_scales_pointers();

    for (int row = 0; row < depth.rows; row++) {

        // Pointers to row data (inverse depth z and scaled noise u) of q
        float *z_q_ptr = depth_padded.ptr<float>(row);
        float *u_q_ptr = noise_padded.ptr<float>(row);

        // p: the WEST and NORTH neighbours of q
        float *z_p_ptr[2] = {z_q_ptr - 1, z_q_ptr - depth_padded.step1()};
        float *u_p_ptr[2] = {u_q_ptr - 1, u_q_ptr - noise_padded.step1()};

        // r: `distance` pixels beyond q; o: `distance` pixels beyond p
        float *z_r_ptr[2] = {z_q_ptr + distance, z_q_ptr + distance * depth_padded.step1()};
        float *z_o_ptr[2] = {z_q_ptr - 1 - distance, z_q_ptr - (1 + distance) * depth_padded.step1()};
        float *u_r_ptr[2] = {u_q_ptr + distance, u_q_ptr + distance * noise_padded.step1()};
        float *u_o_ptr[2] = {u_q_ptr - 1 - distance, u_q_ptr - (1 + distance) * noise_padded.step1()};

#ifndef __SSE__
        // Unoptimized version
        for (int col = 0; col < depth.cols; col++) {
            float zq = *(z_q_ptr++);
            float uq = *(u_q_ptr++);

            // For each direction, we compute the surface and edge probability and apply the decision rule
            for (int i = 0; i < 2; i++) {
                float zp = *(z_p_ptr[i]++);
                float up = *(u_p_ptr[i]++);

                float p_SS = 1.0f;
                float p_SJ = p_single_z;
                float p_JS = p_single_z;
                float p_JJ = p_single_z * p_single_z;

                if (zq == 0.0f || zp == 0.0f) {
                    // Special case if we do not have any depth information! There is an edge only if
                    // exactly one of the two pixels has a depth.
                    if (zp == zq) { p_SS = FLT_MAX; } // We are on a surface
                    else { p_SS = p_SJ = 0.0f; } // We are not on the same surface!
                }
                else {
                    // We use the range zo or zr that is the closest to zp or zq resp.
                    if (fabs(zq - *z_r_ptr[i]) < fabs(zp - *z_o_ptr[i])) {
                        // We chose (zp, zq, zr)
                        float zr = *(z_r_ptr[i]);
                        float ur = *(u_r_ptr[i]);

                        // p_SS = p(zp,zq|zr) p(zr)
                        p_SS *= pdf_3_z(zp, zq, zr, up, uq, ur, distance, *locations[i][p_r], *scales[i][p_r]);

                        // p_JS = p(zp) p(zq|zr) p(zr)
                        p_JS *= (1 / zp) * pdf_2_z(zq, zr, uq, ur, *locations[i][q_r], *scales[i][q_r]);

                        // p_JJ = p(zp) p(zq) p(zr)
                        p_JJ /= (zp * zq);
                    }
                    else {
                        // We chose (zo, zp, zq)
                        float zo = *(z_o_ptr[i]);
                        float uo = *(u_o_ptr[i]);

                        // p_SS = p(zp, zq | zo) p(zo)
                        p_SS *= pdf_3_z(zq, zp, zo, uq, up, uo, distance, *locations[i][q_o], *scales[i][q_o]);

                        // p_JS = p(zq) p(zp|zo) p(zo)
                        p_JS *= pdf_2_z(zp, zo, up, uo, *locations[i][p_o], *scales[i][p_o]) / zq;

                        // p_JJ = p(zp) p(zq) p(zo)
                        p_JJ /= (zp * zq);
                    }

                    // p_SJ = [ p(zp|zq) p(zq) ] p(z0 or zr)
                    p_SJ *= pdf_2_z(zp, zq, up, uq, *locations[i][p_q], *scales[i][p_q]) / zq;
                }

                float f_Spq = (p_SS * Sqr_prior + p_SJ * Jqr_prior) * Spq_prior;
                float f_Jpq = (p_JS * Sqr_prior + p_JJ * Jqr_prior) * edge_prior;

                float p_Spq = f_Spq / (f_Spq + f_Jpq);

                // Bayes decision rule
                if (p_Spq < detection_threshold) {
                    // Edge detected!
                    *edge_ptr |= direction_flags[i][0];

                    // If the other pixel is within image boundaries, we set its corresponding
                    // flags into the edge map.
                    if (!((i == 0 && col == 0) || (i == 1 && row == 0))) {
                        *edge2_ptr[i] |= direction_flags[i][1];

                        if (zq > zp) {
                            // Larger inverse depth at q: the shifted flag goes on p.
                            *edge2_ptr[i] |= direction_flags[i][1] << 4;
                        }
                        else {
                            // Otherwise the shifted flag goes on q.
                            *edge_ptr |= direction_flags[i][0] << 4;
                        }
                    }
                }
                edge2_ptr[i]++;

                for (int j = 0; j < 6; j++) {
                    locations[i][j]++;
                    scales[i][j]++;
                }

                z_o_ptr[i]++;
                z_r_ptr[i]++;
                u_o_ptr[i]++;
                u_r_ptr[i]++;
            }

            edge_ptr++;
        }
#else
        // Optimized SSE version: four consecutive columns per iteration
        for (int col = 0; col < depth.cols; col += 4) {
            __m128 zq = _mm_load_ps(z_q_ptr);
            __m128 uq = _mm_load_ps(u_q_ptr);
            __m128 zq_valid = _mm_cmpneq_ps(zq, zero_const);
            z_q_ptr += 4;
            u_q_ptr += 4;

            // For each direction
            for (int i = 0; i < 2; i++) {
                __m128 zp = _mm_loadu_ps(z_p_ptr[i]); z_p_ptr[i] += 4;
                __m128 up = _mm_loadu_ps(u_p_ptr[i]); u_p_ptr[i] += 4;
                __m128 zp_valid = _mm_cmpneq_ps(zp, zero_const);
                // Both depths available / neither depth available
                __m128 zpzq_valid = _mm_and_ps(zp_valid, zq_valid);
                __m128 zpzq_nonvalid = _mm_cmpeq_ps(_mm_or_ps(zp_valid, zq_valid), zero_const);

                __m128 zo = _mm_loadu_ps(z_o_ptr[i]); z_o_ptr[i] += 4;
                __m128 uo = _mm_loadu_ps(u_o_ptr[i]); u_o_ptr[i] += 4;
                __m128 zr = _mm_loadu_ps(z_r_ptr[i]); z_r_ptr[i] += 4;
                __m128 ur = _mm_loadu_ps(u_r_ptr[i]); u_r_ptr[i] += 4;

                // Per lane, select the triple (zo, zp, zq) or (zp, zq, zr);
                // the two masks are complementary for ordered values.
                __m128 zo_closest = _mm_cmplt_ps(abs_ps(_mm_sub_ps(zp, zo)), abs_ps(_mm_sub_ps(zr, zq)));
                __m128 zr_closest = _mm_cmpge_ps(abs_ps(_mm_sub_ps(zp, zo)), abs_ps(_mm_sub_ps(zr, zq)));

                // Compute some needed pdfs (the terms f(zr) or f(zo) are not computed
                // because we cancel it in the decision rule)
                // f(zi) i = p, q
                __m128 f_p = p_single_z * _mm_rcp_ps(zp);
                __m128 f_q = p_single_z * _mm_rcp_ps(zq);

                __m128 f_i = _mm_or_ps(_mm_and_ps(zr_closest, f_p), _mm_and_ps(zo_closest, f_q));

                // f(zp, zq | S_{pq}) = f(zp | zq, S_{pq}) f(zq)
                __m128 f_pq = pdf_2_z(zp, zq, up, uq, _mm_load_ps(locations[i][p_q]), _mm_load_ps(scales[i][p_q])) * f_q;

                // f(zp, zo | S_{po}) = f(zp | zo, S_{po}) f(zo)
                __m128 f_xi = _mm_and_ps(zo_closest, pdf_2_z(zp, zo, up, uo, _mm_load_ps(locations[i][p_o]), _mm_load_ps(scales[i][p_o])));
                // f(zq, zr | S_{qr}) = f(zq | zr, S_{qr}) f(zr)
                f_xi = _mm_or_ps(f_xi, _mm_and_ps(zr_closest, pdf_2_z(zq, zr, uq, ur, _mm_load_ps(locations[i][q_r]), _mm_load_ps(scales[i][q_r]))));

                // f(zo, zp, zq | S_{po}, S_{pq}) = f(zp, zq | zo, S_{pq}) f(zo)
                __m128 f_pqi = _mm_and_ps(zo_closest, pdf_3_z(zq, zp, zo, uq, up, uo, distance_v, _mm_load_ps(locations[i][q_o]), _mm_load_ps(scales[i][q_o])));

                // f(zp, zq, zr | S_{pq}, S_{qr}) = f(zo) f(zp, zq | zr, S_{pr}) f(zr)
                f_pqi = _mm_or_ps(f_pqi, _mm_and_ps(zr_closest, pdf_3_z(zp, zq, zr, up, uq, ur, distance_v, _mm_load_ps(locations[i][p_r]), _mm_load_ps(scales[i][p_r]))));

                // Compute f(zp,zq,zr|S_{pq}) P(S_{pq}) (or for zo, zp, zq)
                //  = P(S_{pq}) *
                //    [ f(zp,zq,zr|S_{pq}, S_{qr}) P(S_{qr}) +
                //      f(zp,zq,zr|S_{pq}, J_{qr}) P(J_{qr}) ]
                __m128 f_Spq = Spq_prior_v *
                        (Sqr_prior_v * f_pqi +
                         Jqr_prior_v * f_pq);

                // Compute f(zp,zq,zr|J_{pq}) P(J_{pq}) (or for zo, zp, zq)
                //  = P(J_{pq}) *
                //    [ f(zp,zq,zr|J_{pq}, S_{qr}) P(S_{qr}) +
                //      f(zp,zq,zr|J_{pq}, J_{qr}) P(J_{qr}) ]
                __m128 f_Jpq = edge_prior_v *
                         (Sqr_prior_v * f_xi * f_i +
                          Jqr_prior_v * f_p * f_q);

                // P(S_{pq} | zo, zp, zq, zr) = f_Spq / (f_Spq + f_Jpq)
                __m128 p_Spq = f_Spq / (f_Spq + f_Jpq);

                // Missing depth: the posterior becomes FLT_MAX (never an edge)
                // when neither pixel has a depth and 0 (always an edge) when
                // exactly one of them has.
                p_Spq = _mm_and_ps(p_Spq, zpzq_valid);
                p_Spq = _mm_or_ps(p_Spq, _mm_and_ps(max_const, zpzq_nonvalid));

                // Bayes decision rule
                __m128 comp = _mm_cmplt_ps(_mm_sub_ps(p_Spq, detection_threshold_v), zero_const);
                __m128 depth_comp = _mm_cmpgt_ps(zq, zp);

                // A "true" lane of an SSE comparison is all ones, which is a NaN
                // when read back as a float and can therefore not be tested with
                // == / != against an integer constant. Gather the lane sign bits
                // into integer bit masks instead (bit j <-> lane j).
                const int edge_mask = _mm_movemask_ps(comp);
                const int depth_mask = _mm_movemask_ps(depth_comp);

                for (int j = 0; j < 4; j++) {
                    if ((edge_mask & (1 << j)) == 0)
                        continue;

                    // Edge detected!
                    edge_ptr[j] |= direction_flags[i][0];

                    // The neighbour is outside the image only for the very first
                    // column (WEST) or the first row (NORTH); the column test has
                    // to be done per lane, not per block of four.
                    if (!((i == 0 && col + j == 0) || (i == 1 && row == 0))) {
                        edge2_ptr[i][j] |= direction_flags[i][1];
                        if (depth_mask & (1 << j)) {
                            edge2_ptr[i][j] |= direction_flags[i][1] << 4;
                        }
                        else {
                            edge_ptr[j] |= direction_flags[i][0] << 4;
                        }
                    }
                }

                for (int j = 0; j < 6; j++) {
                    locations[i][j] += 4;
                    scales[i][j] += 4;
                }

                edge2_ptr[i] += 4;
            }

            edge_ptr += 4;
        }
#endif
    }

    double edge_detection_time = (static_cast<double>(cv::getTickCount()) - edge_detection_start) / cv::getTickFrequency();
    printf ("Edge detection took %fms\n", edge_detection_time * 1000);

    return edges;
}