// Copyright (c) 2012-2017 VideoStitch SAS
// Copyright (c) 2018 stitchEm

#pragma once

/* Procedure used to generate the inverse warping function.
 * Terminology:
 * - Input: image captured by the camera
 * - Intermediate: projection of "Input" onto an intermediate sphere of radius 1
 * - Output: an equirectangular projection of the Intermediate sphere onto a cylinder
 */

/* Defines the __global__ kernel
 *   warpCoordInputToOutputKernel_<fromInputToSphere>_<isWithin>_<inverseDistortionMetersTransform>_<inverseRadialPixelsTransform>
 * which maps each coordinate read from g_idata through
 *   mapInputToPanorama_<fromInputToSphere>_<inverseDistortionMetersTransform>_<inverseRadialPixelsTransform>
 * and stores the result at the same index of g_odata.
 *
 * Macro parameters:
 * - fromInputToSphere, inverseDistortionMetersTransform, inverseRadialPixelsTransform:
 *   tokens pasted into the kernel name and into the name of the mapping function that is called.
 * - isWithin: predicate invoked as isWithin(uv, inputImgWidth, inputImgHeight, cropImgLeft, cropImgRight,
 *   cropImgTop, cropImgBottom), with the integer arguments cast to float; also pasted into the kernel name.
 *
 * Launch layout: 2D. Thread (x, y) handles element y * inputWidth + x of g_idata / g_iMask / g_odata;
 * threads with x >= inputWidth or y >= inputHeight exit immediately.
 *
 * An element of g_odata is written only when all of the following hold:
 *   1. g_iMask[element] == (1u << id),
 *   2. isWithin accepts the coordinate read from g_idata,
 *   3. isWithinRect accepts the mapped coordinate for (outputWidth, outputHeight).
 * Every other element of g_odata is left untouched (callers presumably pre-fill the buffer - TODO confirm).
 *
 * panoScale, poseInverse, rigSphereRadius, inputScale, distortion and centerShift are forwarded unchanged
 * to the mapping function.
 */
#define WARPCOORD_INPUT_TO_OUTPUT_KERNEL(fromInputToSphere, isWithin, inverseDistortionMetersTransform,                                    \
                                         inverseRadialPixelsTransform)                                                                     \
  __global__ void                                                                                                                          \
      warpCoordInputToOutputKernel_##fromInputToSphere##_##isWithin##_##inverseDistortionMetersTransform##_##inverseRadialPixelsTransform( \
          float2* g_odata, int outputWidth, int outputHeight, int inputImgWidth, int inputImgHeight, int cropImgLeft,                      \
          int cropImgRight, int cropImgTop, int cropImgBottom, int inputWidth, int inputHeight, const int id,                              \
          const float2* g_idata, const uint32_t* g_iMask, const float2 panoScale, const vsfloat3x4 poseInverse,                            \
          const float rigSphereRadius, const float2 inputScale, const vsDistortion distortion,                                             \
          const float2 centerShift) {                                                                                                      \
    /* Coordinates of the buffer element handled by this thread. */                                                                        \
    const int x = blockIdx.x * blockDim.x + threadIdx.x;                                                                                   \
    const int y = blockIdx.y * blockDim.y + threadIdx.y;                                                                                   \
    if (x >= inputWidth || y >= inputHeight) return;                                                                                       \
                                                                                                                                           \
    const int idx = y * inputWidth + x;                                                                                                    \
    /* Process only elements whose mask equals bit number id. Unsigned arithmetic keeps the */                                             \
    /* shift well defined for id == 31 and avoids narrowing the uint32_t mask to int. */                                                   \
    const uint32_t mask = g_iMask[idx];                                                                                                    \
    if (mask != (1u << id)) {                                                                                                              \
      return;                                                                                                                              \
    }                                                                                                                                      \
    float2 uv = g_idata[idx];                                                                                                              \
    if (isWithin(uv, (float)inputImgWidth, (float)inputImgHeight, (float)cropImgLeft, (float)cropImgRight,                                 \
                 (float)cropImgTop, (float)cropImgBottom)) {                                                                               \
      /* Add 0.5 to both components, then make uv relative to the input image center. */                                                   \
      uv += make_float2(0.5f, 0.5f);                                                                                                       \
      uv.x -= inputImgWidth / 2.0f;                                                                                                        \
      uv.y -= inputImgHeight / 2.0f;                                                                                                       \
      uv =                                                                                                                                 \
          mapInputToPanorama_##fromInputToSphere##_##inverseDistortionMetersTransform##_##inverseRadialPixelsTransform(                    \
              uv, panoScale, poseInverse, rigSphereRadius, inputScale, distortion, centerShift);                                           \
                                                                                                                                           \
      /* Shift by half the output size, mirroring the centering applied to the input above; */                                             \
      /* presumably this converts center-relative coordinates to corner-relative ones. */                                                  \
      uv.x += outputWidth / 2.0f;                                                                                                          \
      uv.y += outputHeight / 2.0f;                                                                                                         \
                                                                                                                                           \
      /* Store the result only when isWithinRect accepts it for the output size. */                                                        \
      if (isWithinRect(uv, (float)outputWidth, (float)outputHeight)) {                                                                     \
        g_odata[idx] = uv;                                                                                                                 \
      }                                                                                                                                    \
    }                                                                                                                                      \
  }

/* Instantiates WARPCOORD_INPUT_TO_OUTPUT_KERNEL once per input projection token (RectToSphere,
 * ErectToSphere, FisheyeToSphere, ExternalToSphere, StereoToSphere) for a single crop predicate
 * CROP_TEST and the given pair of inverse transform tokens.
 */
#define WARPCOORD_INPUT_TO_OUTPUT_KERNELS_FOR_CROP(CROP_TEST, INV_RADIAL1, INV_RADIAL2)   \
  WARPCOORD_INPUT_TO_OUTPUT_KERNEL(RectToSphere, CROP_TEST, INV_RADIAL1, INV_RADIAL2)     \
  WARPCOORD_INPUT_TO_OUTPUT_KERNEL(ErectToSphere, CROP_TEST, INV_RADIAL1, INV_RADIAL2)    \
  WARPCOORD_INPUT_TO_OUTPUT_KERNEL(FisheyeToSphere, CROP_TEST, INV_RADIAL1, INV_RADIAL2)  \
  WARPCOORD_INPUT_TO_OUTPUT_KERNEL(ExternalToSphere, CROP_TEST, INV_RADIAL1, INV_RADIAL2) \
  WARPCOORD_INPUT_TO_OUTPUT_KERNEL(StereoToSphere, CROP_TEST, INV_RADIAL1, INV_RADIAL2)

/* Defines all ten kernels for one (INV_RADIAL1, INV_RADIAL2) pair: the five input projections
 * combined with the isWithinCropRect predicate, followed by the same five combined with
 * isWithinCropCircle.
 */
#define WARPCOORD_INPUT_TO_OUTPUT_KERNEL2(INV_RADIAL1, INV_RADIAL2)                      \
  WARPCOORD_INPUT_TO_OUTPUT_KERNELS_FOR_CROP(isWithinCropRect, INV_RADIAL1, INV_RADIAL2) \
  WARPCOORD_INPUT_TO_OUTPUT_KERNELS_FOR_CROP(isWithinCropCircle, INV_RADIAL1, INV_RADIAL2)