// Copyright (c) 2012-2017 VideoStitch SAS
// Copyright (c) 2018 stitchEm

// templated kernels:
// need to `#define Wraps` to Wrap or NoWrap
// need to provide extractDist, stepUp, stepDown functions
// in _Wrap and _NoWrap variants
// NOTE(review): FUNCTION_NAME_2 (used by WRAPS_FN below) is not defined in this file;
// it must already be defined by whatever includes this file.
// NOTE(review): WRAPS_DIST_FN also pastes a `DistFn` token into every generated name;
// presumably the including code defines `DistFn` as a macro as well -- confirm.

// Token-pasting helpers: PASTER3 joins its three arguments with underscores (x_y_z).
// EVALUATOR3 adds one level of indirection so that macro arguments such as `Wraps`
// are expanded before they are pasted.
#define PASTER3(x,y,z) x ## _ ## y ## _ ## z
#define EVALUATOR3(x,y,z)  PASTER3(x,y,z)

// WRAPS_FN(fn): name of `fn` specialised on `Wraps` only (used for stepUp / stepDown).
// WRAPS_DIST_FN(fn): name of `fn` specialised on both `Wraps` and `DistFn`
// (used for extractDist and for every function and kernel defined in this file).
#define WRAPS_FN(fn) FUNCTION_NAME_2(fn, Wraps)
#define WRAPS_DIST_FN(fn) EVALUATOR3(fn, Wraps, DistFn)

/**
 * Candidate test used by the voronoi pass.
 * Computes the distance associated with candidate value @a v at pixel (@a x, @a y) via extractDist and,
 * if it is strictly smaller than the best distance seen so far, records it.
 * @param x pixel column passed to extractDist.
 * @param y pixel row passed to extractDist.
 * @param v candidate value (as stored in the voronoi buffer).
 * @param minDist in/out: best distance so far; overwritten when the candidate is better.
 * @param minV in/out: value that produced @a minDist; overwritten when the candidate is better.
 * @param region region forwarded unchanged to extractDist.
 */
__device__ void WRAPS_DIST_FN(updateIfBetter)(int32_t x, int32_t y,
                                              uint32_t v,
                                              float* minDist, uint32_t* minV,
                                              PanoRegion region) {
  const float candidateDist = WRAPS_DIST_FN(extractDist)(x, y, v, region);

  // Strict comparison: on a tie the previously recorded candidate is kept.
  // Written as a negated `<` so that an unordered (NaN) distance is rejected too.
  if (!(candidateDist < *minDist)) {
    return;
  }

  *minV = v;
  *minDist = candidateDist;
}

// Type used for the region parameter in kernel signatures.
// By default kernels take a PanoRegion directly. When GPU_CL_ARGS_WORKAROUND is defined
// the parameter is declared as a float8 instead; the kernels in this file then reinterpret
// it as a PanoRegion with `*(PanoRegion*)(&_region)`.
// NOTE(review): this relies on PanoRegion fitting into / being layout-compatible with a float8 -- confirm.
#ifndef GPU_CL_ARGS_WORKAROUND
#define _PanoRegion PanoRegion
#else
#define _PanoRegion float8
#endif
/**
 * One pass of the voronoi diagram computation, sampling neighbours at distance @a step.
 * @a src must have been prepared with voronoiInitKernel.
 *
 * For every pixel (x, y) of the view, up to nine candidate values are read from @a src:
 * the pixel itself and its neighbours at row offsets -step / 0 / +step combined with the
 * column offsets returned by stepDown / 0 / stepUp. Rows that would fall outside the view
 * are skipped. The candidate with the smallest extractDist is written to @a dst.
 *
 * @param dst output buffer, indexed as y * viewWidth + x.
 * @param src input buffer, same indexing.
 * @param _region view region (see _PanoRegion); viewWidth / viewHeight give the processed extent.
 * @param step sampling distance in pixels.
 *
 * All reads from @a src happen before the barrier and the write to @a dst after it.
 * NOTE(review): this was described as "in-place", i.e. @a dst may alias @a src; the barrier
 * presumably only orders threads of the same block / work-group -- confirm that cross-block
 * aliasing is acceptable for the caller.
 */
__global__ void WRAPS_DIST_FN(voronoiCompute)(global_mem uint32_t* dst,
                                              const global_mem uint32_t* src,
                                              _PanoRegion _region,
                                              uint32_t step) {
  const uint32_t x = get_global_id_x();
  const uint32_t y = get_global_id_y();

  //TODO: (perf improvements; minor since only called once)
  // * use shared memory
  // * divergence
  // * syncthreads(between reads ?)
  PanoRegion region = *(PanoRegion*)(&_region);
  const uint32_t viewW = region.viewWidth;
  const uint32_t viewH = region.viewHeight;
  const uint32_t center = y * viewW + x;
  const bool insideView = (x < viewW) && (y < viewH);

  uint32_t bestValue = 0;

  if (insideView) {
    float bestDist = MAX_FLOAT_DISTANCE;
    const int32_t offsetDown = WRAPS_FN(stepDown)(x, step, viewW);
    const int32_t offsetUp = WRAPS_FN(stepUp)(x, step, viewW);
    const uint32_t rowJump = step * viewW;

    // Candidates are visited row by row (above, current, below) and, within a row,
    // in the order down / centre / up. The order matters because ties keep the first candidate.
    if (y >= step) {
      // Note: if offsetDown == 0, we'll do several times the job, but it avoids thread divergence.
      const uint32_t above = center - rowJump;
      WRAPS_DIST_FN(updateIfBetter)(x, y, src[above + offsetDown], &bestDist, &bestValue, region);
      WRAPS_DIST_FN(updateIfBetter)(x, y, src[above], &bestDist, &bestValue, region);
      WRAPS_DIST_FN(updateIfBetter)(x, y, src[above + offsetUp], &bestDist, &bestValue, region);
    }

    WRAPS_DIST_FN(updateIfBetter)(x, y, src[center + offsetDown], &bestDist, &bestValue, region);
    WRAPS_DIST_FN(updateIfBetter)(x, y, src[center], &bestDist, &bestValue, region);
    WRAPS_DIST_FN(updateIfBetter)(x, y, src[center + offsetUp], &bestDist, &bestValue, region);

    if (y + step < viewH) {
      const uint32_t below = center + rowJump;
      WRAPS_DIST_FN(updateIfBetter)(x, y, src[below + offsetDown], &bestDist, &bestValue, region);
      WRAPS_DIST_FN(updateIfBetter)(x, y, src[below], &bestDist, &bestValue, region);
      WRAPS_DIST_FN(updateIfBetter)(x, y, src[below + offsetUp], &bestDist, &bestValue, region);
    }
  }

  // The barrier sits outside the bounds check so that every thread reaches it,
  // including threads that fall outside the view.
  sync_threads(CLK_GLOBAL_MEM_FENCE);

  if (insideView) {
    dst[center] = bestValue;
  }
}

/**
 * Build a transition mask from two EDT (distance) maps.
 * @param dst output mask, one byte per pixel, indexed as y * viewWidth + x.
 * @param srcBlack distance map to black image, generated by edtComputeKernel.
 * @param srcWhite distance map to white image, generated by edtComputeKernel.
 * @param _region view region (see _PanoRegion); viewWidth / viewHeight give the processed extent.
 * @param maxTransitionDistance maximum width of the transition / overlay. A negative value disables the limit.
 * @param power parameter of the p-norm that's used to calculate the transition. Should be >= 2.0 to use at least L2. Steeper transition with larger power.
 * On return, @a dst contains the mask:
 *  - 0 where both distances are still at MAX_FLOAT_DISTANCE (no usable data),
 *  - 255 / 0 outside the transition band on the black / white side respectively,
 *  - a blend in [0, 255] inside the band (127 where both distances are equal, including both zero).
 */
__global__ void WRAPS_DIST_FN(buildTransitionMask)(global_mem uint8_t * __restrict__ dst,
                                                   const global_mem uint32_t * __restrict__ srcBlack,
                                                   const global_mem uint32_t * __restrict__ srcWhite,
                                                   _PanoRegion _region,
                                                   float maxTransitionDistance,
                                                   float power) {
  const uint32_t x = get_global_id_x();
  const uint32_t y = get_global_id_y();
  PanoRegion region = *(PanoRegion*)(&_region);
  const uint32_t width = (uint32_t)region.viewWidth;
  const uint32_t height = (uint32_t)region.viewHeight;

  if (x < width && y < height) {
    const uint32_t coords = y * width + x;

    float blackDistance = WRAPS_DIST_FN(extractDist)(x, y, srcBlack[coords], region);
    float whiteDistance = WRAPS_DIST_FN(extractDist)(x, y, srcWhite[coords], region);

    // Unable to compute a proper transition and mask when black and white image
    // don't have any pixels set (thus still at max distance) or voronoi failed
    if (whiteDistance >= MAX_FLOAT_DISTANCE && blackDistance >= MAX_FLOAT_DISTANCE) {
      dst[coords] = 0;
      return;
    }

    const float totalTransitionDistance = (blackDistance + whiteDistance);

    if (maxTransitionDistance >= 0 && totalTransitionDistance > maxTransitionDistance) {
      float emptySpace = totalTransitionDistance - maxTransitionDistance;
      emptySpace /= 2.0f; /* both sides, have transition in middle */

      if (blackDistance < emptySpace) {
        dst[coords] = 255;
        return;
      } else if (whiteDistance < emptySpace) {
        dst[coords] = 0;
        return;
      }
      blackDistance -= emptySpace;
      whiteDistance -= emptySpace;
    }

    /**
     * Blend function f(b, w) = (w^p - b^p) / (w^p + b^p), with values in [-1, 1]:
     *  f(0, w) = 1
     *  f(b, 0) = -1
     *  f(inf, w) = -1
     *  f(b, inf) = 1
     *
     * It is evaluated on the ratio r = near / far (in [0, 1]) instead of on the raw powers:
     *  |f| = (1 - r^p) / (1 + r^p)
     * which is the same function, but cannot overflow to inf / inf for large distances or
     * steep powers, and lets the b == w == 0 case be handled explicitly instead of producing 0 / 0.
     */
    const int whiteIsNearer = whiteDistance < blackDistance;
    const float nearDistance = whiteIsNearer ? whiteDistance : blackDistance;
    const float farDistance = whiteIsNearer ? blackDistance : whiteDistance;

    float fBW = 0.0f; /* both distances zero: pixel is equidistant, use the midpoint */
    if (farDistance > 0.0f) {
      const float ratioPow = pow(nearDistance / farDistance, power);
      const float magnitude = (1.0f - ratioPow) / (1.0f + ratioPow);
      fBW = whiteIsNearer ? -magnitude : magnitude;
    }

    dst[coords] = (uint8_t) (127.5f * (fBW + 1.0f));
  }
}

/**
 * Convert the distances encoded in @a srcWhite into an 8-bit image.
 * For each pixel the distance is obtained with extractDist (using a region whose view is
 * (0, 0, width, height) and whose panoDim is all zeros), then:
 *  1. distances above @a maxTransitionDistance become 255, others are scaled linearly to [0, 255];
 *  2. the scaled value is raised to @a power;
 *  3. the result is clamped to 255 and stored as a byte.
 * @param dst output image, one byte per pixel, indexed as y * width + x.
 * @param srcWhite input buffer, same indexing.
 * @param width image width in pixels.
 * @param height image height in pixels.
 * @param maxTransitionDistance distance mapped to 255.
 * @param power exponent applied to the scaled value.
 * NOTE(review): the exponent is applied after scaling to [0, 255], so for power > 1 most
 * values saturate at 255 -- confirm this is intended.
 */
__global__ void WRAPS_DIST_FN(extractDistKernel)(global_mem uint8_t * __restrict__ dst,
                                                 global_mem const uint32_t * __restrict__ srcWhite,
                                                 uint32_t width, uint32_t height,
                                                 float maxTransitionDistance,
                                                 float power) {

  uint32_t col = get_global_id_x();
  uint32_t row = get_global_id_y();

  // Threads outside the image have nothing to do.
  if (col >= width || row >= height) {
    return;
  }

  // Region covering the whole image, anchored at the origin.
  PanoDimensions zeroDim = {0, 0, 0, 0};
  PanoRegion fullView;
  fullView.panoDim = zeroDim;
  fullView.viewLeft = 0;
  fullView.viewTop = 0;
  fullView.viewWidth = width;
  fullView.viewHeight = height;

  const uint32_t pixelIndex = row * width + col;
  const float rawDistance = WRAPS_DIST_FN(extractDist)(col, row, srcWhite[pixelIndex], fullView);

  // Saturate beyond the maximum distance, otherwise scale linearly to [0, 255].
  const float scaledDistance = (rawDistance > maxTransitionDistance)
                                   ? 255.0f
                                   : (rawDistance / maxTransitionDistance) * 255.0f;

  const float poweredDistance = pow(scaledDistance, power);
  dst[pixelIndex] = (uint8_t) (min(poweredDistance, 255.0f));
}

// Remove the name-mangling helper macros at the end of the file.
// NOTE(review): PASTER2 and EVALUATOR2 are not defined in this file; presumably they come
// from the header that provides FUNCTION_NAME_2 -- confirm that undefining them here does not
// break a later use of FUNCTION_NAME_2 by the including code.
#undef PASTER2
#undef EVALUATOR2
#undef PASTER3
#undef EVALUATOR3
#undef WRAPS_FN
#undef WRAPS_DIST_FN