// Copyright (c) 2012-2017 VideoStitch SAS
// Copyright (c) 2018 stitchEm

/**
 * CUDA implementation of 2-site generalized voronoi diagram computation using Jump Flooding.
 * See e.g. RONG GUODONG's PhD thesis:
 *   JUMP FLOODING ALGORITHM ON GRAPHICS HARDWARE AND ITS APPLICATIONS
 * for a mathematical description of the algorithm.
 */

// Sentinel "infinitely far" distance. extractDist_Wrap_distSqr returns it for buffer entries
// whose top bit (0x80000000) is clear, i.e. entries that carry no site coordinates.
#define MAX_FLOAT_DISTANCE 1.e18f

/**
 * Take a buffer and initialize it in place for use by voronoiComputeKernel.
 * After initialization, each classified pixel contains its own coordinates (it is its own
 * nearest site) as well as its color (for fast categorization afterwards).
 * The two most significant bits are used for color:
 *  00 = unclassified
 *  10 = in black image
 *  11 = in white image
 *  01 = [unused]
 * The remaining 30 bits are used for coordinates (row-major index, y * width + x), so the
 * index of every pixel must fit in 30 bits.
 * A pixel is unclassified when it is either in both images or in neither of them.
 * Otherwise it has the color of the image in which it falls.
 *
 * Threads whose global id lies outside width x height do nothing, so the launch grid may be
 * rounded up past the buffer dimensions without aliasing neighbouring rows.
 *
 * On entry, @a buf should contain the relevant portion of the setup image (bit i set if pixel in image i).
 * On return, @a buf contains a buffer to be used with voronoiComputeKernel.
 *
 * @param buf Setup image on entry, initialized voronoi buffer on return (width * height elements).
 * @param width Row length of @a buf, in pixels.
 * @param height Number of rows of @a buf.
 * @param blackMask Bits of the setup image that identify the "black" image.
 * @param whiteMask Bits of the setup image that identify the "white" image.
 */
__global__ void voronoiInitKernel(global_mem uint32_t* buf, uint32_t width, uint32_t height, uint32_t blackMask,
                                  uint32_t whiteMask) {
  const uint32_t x = get_global_id_x();
  const uint32_t y = get_global_id_y();

  // The update is in place: an out-of-range thread would alias a pixel of another row and
  // could re-interpret an already initialized value as a setup value.
  if (x < width && y < height) {
    const uint32_t coords = y * width + x;
    const uint32_t v = buf[coords];
    const uint32_t inBlackImage = !!(v & blackMask);  // 0 or 1
    const uint32_t inWhiteImage = !!(v & whiteMask);  // 0 or 1
    // 1 when the pixel is in exactly one of the two images, 0 otherwise
    const uint32_t inOnlyOneImage = (inBlackImage + inWhiteImage) & 1;

    // Multiplying by inOnlyOneImage sets the entry to all 0s if unclassified.
    // Bit 31 marks the entry as classified, bit 30 is the white flag.
    buf[coords] = inOnlyOneImage * ((uint32_t)0x80000000 | (inWhiteImage << 30) | coords);
  }
}

/**
 * Convert view pixel coordinates into the value produced by ErectToSphere
 * (presumably a 3D point on the sphere; ErectToSphere is defined outside this file).
 *
 * @param coords Pixel coordinates; x may lie past the pano width and is folded back with a modulo,
 *               y is offset by region.viewTop.
 * @param region Supplies the pano dimensions and scale as well as the top of the view.
 */
__device__ float3 toSphere(int2 coords, PanoRegion region) {
  // updateIfBetter may look past the pano width: fold x back into the pano
  const int wrappedX = coords.x % region.panoDim.width;
  const int shiftedY = coords.y + region.viewTop;

  // Horizontal component: move the origin to the pano center, then undo the scale
  float u = (float)wrappedX;
  u -= region.panoDim.width / 2;
  u /= region.panoDim.scaleX;

  // Vertical component, same treatment
  float w = (float)shiftedY;
  w -= region.panoDim.height / 2;
  w /= region.panoDim.scaleY;

  float2 uv = {u, w};
  return ErectToSphere(uv);
}

/**
 * Distance between two pixels, measured after mapping both through toSphere.
 * The result is length_vs of the difference of the two mapped points
 * (presumably the Euclidean chord length; length_vs is defined outside this file).
 */
__device__ float compute_distSphere(int x1, int y1, int x2, int y2, PanoRegion region) {
  const int2 first = make_int2(x1, y1);
  const int2 second = make_int2(x2, y2);
  const float3 from = toSphere(first, region);
  const float3 to = toSphere(second, region);
  return length_vs(to - from);
}

// voronoiExtract.gpu.incl is included once per distance function, with DistFn selecting the variant.
// NOTE(review): presumably the include pastes DistFn into the names it defines (for instance
// extractDist_NoWrap_distSphere, which is called below but not defined in this file) — confirm
// against the .incl file.

// Variant 1: DistFn = distSqr
#define DistFn distSqr
#include "voronoiExtract.gpu.incl"
#undef DistFn

// Variant 2: DistFn = distSphere
#define DistFn distSphere
#include "voronoiExtract.gpu.incl"
#undef DistFn

/**
 * Distance from pixel (x, y) to the site stored in the voronoi entry @a v, for a view that wraps
 * horizontally.
 *
 * @param x, y Pixel coordinates in the view.
 * @param v Voronoi entry: bit 31 set when a site is stored, low 30 bits hold its row-major index.
 * @param region Supplies viewWidth, used both to decode the index and as the wrap period.
 * @return MAX_FLOAT_DISTANCE when no site is stored; otherwise the smallest compute_distSqr among
 *         the pixel itself and its copies shifted by +viewWidth and -viewWidth.
 */
__device__ float extractDist_Wrap_distSqr(int32_t x, int32_t y, uint32_t v, PanoRegion region) {
  if (v & 0x80000000) {
    // Decode the row-major site index into (column, row)
    const uint32_t siteIndex = v & 0x3fffffff;
    const int32_t siteRow = siteIndex / region.viewWidth;
    const int32_t siteCol = siteIndex - region.viewWidth * siteRow;

    // The site may be closer across the left/right seam: also try x shifted by one view width each way
    return min(min(compute_distSqr(x, y, siteCol, siteRow, region),
                   compute_distSqr(x + region.viewWidth, y, siteCol, siteRow, region)),
               compute_distSqr(x - region.viewWidth, y, siteCol, siteRow, region));
  }

  // No site recorded for this entry
  return MAX_FLOAT_DISTANCE;  // max dist
}

/**
 * Wrapping flavour of the sphere-distance extraction.
 * It simply forwards to the non-wrapping flavour.
 * NOTE(review): presumably no seam handling is needed because toSphere already folds x back
 * into the pano — confirm.
 */
__device__ float extractDist_Wrap_distSphere(int32_t x, int32_t y, uint32_t v, PanoRegion region) {
  const float dist = extractDist_NoWrap_distSphere(x, y, v, region);
  return dist;
}

/**
 * Signed offset to add to @a x in order to move @a step pixels to the left in a row of
 * @a width pixels that wraps around.
 *
 * @return -step when x - step stays inside the row, width - step when the move crosses the left
 *         edge once and re-enters from the right, 0 when step is too large even with one wrap.
 */
__device__ int32_t stepDown_Wrap(const uint32_t x, const uint32_t step, const uint32_t width) {
  // Plain move: the target is still at or after column 0
  if (x >= step) {
    return (int32_t)(-step);
  }
  // One wrap around the left edge is enough
  if (x + width >= step) {
    return (int32_t)(width - step);
  }
  // Out of reach: do not move
  return 0;
}

/**
 * Signed offset to add to @a x in order to move @a step pixels to the right in a row of
 * @a width pixels that wraps around.
 *
 * @return step when x + step stays inside the row, step - width when the move crosses the right
 *         edge once and re-enters from the left, 0 when step is too large even with one wrap.
 */
__device__ int32_t stepUp_Wrap(const uint32_t x, const uint32_t step, const uint32_t width) {
  const uint32_t target = x + step;
  // Plain move: the target is still inside the row
  if (target < width) {
    return (int32_t)step;
  }
  // One wrap around the right edge is enough
  if (target < 2 * width) {
    return (int32_t)(step - width);
  }
  // Out of reach: do not move
  return 0;
}

/**
 * Signed offset to add to @a x in order to move @a step pixels to the left, without wrapping.
 * @a width is not used; the parameter mirrors the signature of stepDown_Wrap.
 *
 * @return -step when x - step does not go below column 0, 0 (no move) otherwise.
 */
__device__ int32_t stepDown_NoWrap(const uint32_t x, const uint32_t step, const uint32_t width) {
  // Moving would leave the row on the left: stay in place
  if (x < step) {
    return 0;
  }
  return (int32_t)(-step);
}

/**
 * Signed offset to add to @a x in order to move @a step pixels to the right, without wrapping.
 *
 * @return step when x + step is still a valid column (< width), 0 (no move) otherwise.
 */
__device__ int32_t stepUp_NoWrap(const uint32_t x, const uint32_t step, const uint32_t width) {
  // Moving would leave the row on the right: stay in place
  if (!(x + step < width)) {
    return 0;
  }
  return (int32_t)step;
}

// voronoi.gpu.incl is included four times, once for every combination of
//   Wraps  in { Wrap, NoWrap }      (matches the suffix of the stepDown_*/stepUp_* helpers above)
//   DistFn in { distSphere, distSqr }
// NOTE(review): presumably the include pastes both macros into the names of the functions/kernels
// it defines and into the helpers it calls — confirm against the .incl file.

// Wrap + distSphere
#define Wraps Wrap
#define DistFn distSphere
#include "voronoi.gpu.incl"
#undef Wraps
#undef DistFn

// NoWrap + distSphere
#define Wraps NoWrap
#define DistFn distSphere
#include "voronoi.gpu.incl"
#undef Wraps
#undef DistFn

// Wrap + distSqr
#define Wraps Wrap
#define DistFn distSqr
#include "voronoi.gpu.incl"
#undef Wraps
#undef DistFn

// NoWrap + distSqr
#define Wraps NoWrap
#define DistFn distSqr
#include "voronoi.gpu.incl"
#undef Wraps
#undef DistFn

/**
 * Build a base mask (finest scale) from a voronoi diagram.
 * Each output pixel is 255 when bit 30 of the matching voronoi entry is set and 0 otherwise.
 * With the encoding written by voronoiInitKernel, bit 30 is set only for the "white" color (11).
 *
 * Threads whose global id lies outside width x height do nothing, so the launch grid may be
 * rounded up past the buffer dimensions.
 *
 * On entry, @a src should contain a buffer generated by voronoiComputeKernel.
 * On return, @a dst contains the mask.
 *
 * @param dst Output mask, width * height bytes.
 * @param src Voronoi buffer, width * height entries; only read.
 * @param width Row length of both buffers, in pixels.
 * @param height Number of rows of both buffers.
 */
__global__ void voronoiMakeMaskKernel(global_mem uint8_t* __restrict__ dst, global_mem uint32_t* __restrict__ src,
                                      uint32_t width, uint32_t height) {
  const uint32_t x = get_global_id_x();
  const uint32_t y = get_global_id_y();

  // Without this guard an out-of-range thread would touch a pixel of another row, or run past the buffers
  if (x < width && y < height) {
    const uint32_t coords = y * width + x;
    // Extract bit 30 as 0/1 and stretch it to 0/255
    dst[coords] = 255 * ((src[coords] >> 30) & 1);
  }
}

/**
 * Write a binary mask telling, for each pixel, whether it belongs to the masked image(s).
 * @a dst receives 255 where the setup value in @a buf has at least one bit of @a mask set, 0 elsewhere.
 *
 * Threads whose global id lies outside width x height do nothing, so the launch grid may be
 * rounded up past the buffer dimensions.
 *
 * @param dst Output mask, width * height bytes.
 * @param buf Setup image (bit i set if pixel in image i), width * height entries.
 * @param width Row length of both buffers, in pixels.
 * @param height Number of rows of both buffers.
 * @param mask Bits of the setup image to test.
 */
__global__ void edtReflexiveKernel(global_mem uint8_t* __restrict__ dst, const global_mem uint32_t* __restrict__ buf,
                                   uint32_t width, uint32_t height, uint32_t mask) {
  const uint32_t x = get_global_id_x();
  const uint32_t y = get_global_id_y();

  // Same guard as edtInit: an out-of-range thread would touch a pixel of another row, or run past the buffers
  if (x < width && y < height) {
    const uint32_t coords = y * width + x;
    const uint32_t v = buf[coords];
    const bool inImage = !!(v & mask);

    dst[coords] = (inImage) ? 255 : 0;
  }
}

/**
 * Prepare a buffer for voronoiComputeKernel in EDT mode.
 * Every entry of @a dst is a uint32_t. When its high order bit is set, the remaining bits hold the
 * (row-major) coordinates of the nearest pixel of the masked site; when the bit is clear the
 * entry carries no site. This kernel seeds each pixel that lies in the masked image and not in
 * the other image with its own coordinates, and clears every other pixel to 0.
 *
 * Threads whose global id lies outside width x height do nothing.
 *
 * On entry, @a buf should contain the relevant portion of the setup image (bit i set if pixel in image i)
 * On return, @a dst contains a buffer to be used with edtComputeKernel.
 * @a dst should have the same size as @a buf.
 */

__global__ void edtInit(global_mem uint32_t* dst, const global_mem uint32_t* buf, uint32_t width, uint32_t height,
                        uint32_t mask, uint32_t otherMask) {
  const uint32_t col = get_global_id_x();
  const uint32_t row = get_global_id_y();

  // Nothing to do for threads of the launch grid that fall outside the image
  if (col >= width || row >= height) {
    return;
  }

  const uint32_t index = row * width + col;
  const uint32_t setup = buf[index];
  const bool isSeed = ((setup & mask) != 0) && ((setup & otherMask) == 0);

  if (isSeed) {
    // High order bit flags a valid site; the pixel is its own nearest site
    dst[index] = (uint32_t)0x80000000 | index;
  } else {
    dst[index] = 0x00000000;
  }
}