// Copyright (c) 2012-2017 VideoStitch SAS
// Copyright (c) 2018 stitchEm

/**
* Inverse mapping functions that hold all the transform stack logic.
* For pixel coordinates in the given input space, compute the corresponding coordinates in output space.
*/

/**
 * Maps a coordinate of the input image to a 3D point using the configured
 * input-to-sphere transform (fromInputToSphere).
 *
 * @param uv           Input coordinate; modified locally step by step.
 * @param inputScale   Divisor applied to go from pixel space to meter space.
 * @param distortion   Distortion parameters forwarded to both undistortion transforms.
 * @param centerShift  Principal point offset (relative to the image center), subtracted first.
 * @return The result of fromInputToSphere() on the undistorted coordinate. If either
 *         undistortion step flags the coordinate as invalid, the function returns early
 *         through RETURN_3D_IF_INVALID_INVERSE_2D (presumably an "invalid" marker value;
 *         see that macro's definition).
 */
__device__ float3 FUNCTION_NAME_4(mapInputToCameraSphere,
                                  fromInputToSphere,
                                  inverseDistortionMetersTransform,
                                  inverseRadialPixelsTransform)(float2 uv,
                                                                const float2 inputScale,
                                                                const vsDistortion distortion,
                                                                const float2 centerShift) const_member {
  // Step 1: express the coordinate relative to the principal point.
  uv -= centerShift;

  // Step 2: remove the radial distortion, still working in pixel units.
  uv = inverseRadialPixelsTransform(uv, distortion);
  RETURN_3D_IF_INVALID_INVERSE_2D(uv)

  // Step 3: pixels -> meters.
  uv /= inputScale;

  // Step 4: remove the distortion that is modelled in meter space.
  uv = inverseDistortionMetersTransform(uv, distortion);
  RETURN_3D_IF_INVALID_INVERSE_2D(uv)

  // Step 5: lift the undistorted 2D coordinate into spherical space.
  return fromInputToSphere(uv);
}

/**
 * Projects a point onto the rig sphere by ray-tracing from the camera center.
 *
 * When poseInverse carries no translation the point is returned untouched.
 * Otherwise the ray starting at the camera center (expressed in rig space) and
 * passing through pt is intersected with the origin-centered sphere of radius
 * rigSphereRadius, and the intersection point is returned.
 *
 * @param pt               Point to trace, in rig space.
 * @param poseInverse      3x4 transform; its 4th column is tested for a non-zero translation.
 * @param rigSphereRadius  Radius of the sphere the point is traced onto.
 * @return The traced point, or pt unchanged when there is no translation or when the
 *         ray does not reach the sphere (negative discriminant).
 */
__device__ float3 FUNCTION_NAME_4(tracePointToRigSphere,
                                  fromInputToSphere,
                                  inverseDistortionMetersTransform,
                                  inverseRadialPixelsTransform)(float3 pt,
                                                                const vsfloat3x4 poseInverse,
                                                                const float rigSphereRadius) const_member {

  // Does the pose have a translation component?
#ifndef GPU_CL_ARGS_WORKAROUND
  const bool hasTranslation = poseInverse.values[0][3] != 0.0f ||
                              poseInverse.values[1][3] != 0.0f ||
                              poseInverse.values[2][3] != 0.0f;
#else
  // NOTE(review): s3/s7/sb are presumably the same translation entries when the
  // matrix is passed as a packed vector -- confirm against the vsfloat3x4 definition.
  const bool hasTranslation = poseInverse.s3 != 0.0f ||
                              poseInverse.s7 != 0.0f ||
                              poseInverse.sb != 0.0f;
#endif

  if (hasTranslation) {
    // Camera center expressed in spherical (rig) space.
    const float3 rigCamCenter = transformSphere(make_float3(0.f, 0.f, 0.f), poseInverse);

    // Direction of the ray leaving the camera center through pt.
    const float3 dir = pt - rigCamCenter;

    // Look for lambda > 0 such that norm(rigCamCenter + lambda * dir) == rigSphereRadius,
    // i.e. a root of
    //   lambda^2 * dir.dir + 2 * lambda * center.dir + center.center - rigSphereRadius^2 = 0
    const float quadA = dot(dir, dir);
    const float quadB = 2 * dot(rigCamCenter, dir);
    const float quadC = dot(rigCamCenter, rigCamCenter) - rigSphereRadius * rigSphereRadius;
    const float discriminant = quadB * quadB - 4 * quadA * quadC;

    // TODODEPTH investigate why this triggers in sphere sweep, re-enable
    // assert(discriminant >= 0.f /* if failing, it is likely that radius is smaller than getInputMinimumRigSphereRadius() */);

    // Take the larger root (positive when the camera center lies inside the sphere).
    // With no real root, lambda = 1 leaves pt where it was.
    float lambda = 1.f;
    if (discriminant >= 0.f) {
      lambda = (-quadB + sqrtf_vs(discriminant)) / (2 * quadA);
    }
    pt = rigCamCenter + lambda * dir;
  }

  // TODODEPTH investigate why this triggers in sphere sweep, re-enable
  // assert(
  //     std::abs(length(pt) - rigSphereRadius) < 1e-3f ||
  //     (pt.x == INVALID_INVERSE_DISTORTION && pt.y == INVALID_INVERSE_DISTORTION && pt.z == INVALID_INVERSE_DISTORTION));

  return pt;
}

/**
 * Maps an input coordinate to a 3D point in rig space, traced onto the rig sphere.
 *
 * Chains three steps: mapInputToCameraSphere (input -> camera-space point),
 * transformSphere with poseInverse (camera -> rig), then tracePointToRigSphere.
 *
 * @param uv               Input coordinate.
 * @param poseInverse      Transform applied to move the point from camera to rig space.
 * @param rigSphereRadius  Radius of the sphere the point is traced onto.
 * @param inputScale       Forwarded to mapInputToCameraSphere.
 * @param distortion       Forwarded to mapInputToCameraSphere.
 * @param centerShift      Forwarded to mapInputToCameraSphere.
 * @return The traced rig-space point. When the camera-space point is flagged invalid,
 *         the function returns early through RETURN_3D_IF_INVALID_INVERSE_3D
 *         (see that macro's definition for the returned value).
 */
__device__ float3 FUNCTION_NAME_4(mapInputToRigSpherical,
                                  fromInputToSphere,
                                  inverseDistortionMetersTransform,
                                  inverseRadialPixelsTransform)(float2 uv,
                                                                const vsfloat3x4 poseInverse,
                                                                const float rigSphereRadius,
                                                                const float2 inputScale,
                                                                const vsDistortion distortion,
                                                                const float2 centerShift) const_member {
  // Input coordinate -> point in camera space.
  const float3 cameraPt =
      FUNCTION_NAME_4(mapInputToCameraSphere, fromInputToSphere, inverseDistortionMetersTransform,
                      inverseRadialPixelsTransform)(uv, inputScale, distortion, centerShift);
  RETURN_3D_IF_INVALID_INVERSE_3D(cameraPt)

  // Camera space -> rig space.
  const float3 rigPt = transformSphere(cameraPt, poseInverse);

  // Bring the point onto the rig sphere.
  return FUNCTION_NAME_4(tracePointToRigSphere, fromInputToSphere, inverseDistortionMetersTransform,
                         inverseRadialPixelsTransform)(rigPt, poseInverse, rigSphereRadius);
}

/**
 * Maps an input coordinate all the way to a scaled panorama (output) coordinate.
 *
 * Runs mapInputToRigSpherical, converts the resulting 3D point with
 * fromSphereToOutput, and multiplies the 2D result by panoScale.
 *
 * @param uv               Input coordinate.
 * @param panoScale        Per-axis factor applied to the output-space coordinate.
 * @param poseInverse      Forwarded to mapInputToRigSpherical.
 * @param rigSphereRadius  Forwarded to mapInputToRigSpherical.
 * @param inputScale       Forwarded to mapInputToRigSpherical.
 * @param distortion       Forwarded to mapInputToRigSpherical.
 * @param centerShift      Forwarded to mapInputToRigSpherical.
 * @return The scaled output coordinate. When the rig-space point is flagged invalid,
 *         the function returns early through RETURN_2D_IF_INVALID_INVERSE_3D
 *         (see that macro's definition for the returned value).
 */
__device__ float2 FUNCTION_NAME_4(mapInputToPanorama,
                                  fromInputToSphere,
                                  inverseDistortionMetersTransform,
                                  inverseRadialPixelsTransform)(float2 uv,
                                                                const float2 panoScale,
                                                                const vsfloat3x4 poseInverse,
                                                                const float rigSphereRadius,
                                                                const float2 inputScale,
                                                                const vsDistortion distortion,
                                                                const float2 centerShift) const_member {
  // Input coordinate -> point on the rig sphere.
  const float3 rigPt =
      FUNCTION_NAME_4(mapInputToRigSpherical, fromInputToSphere, inverseDistortionMetersTransform,
                      inverseRadialPixelsTransform)(uv, poseInverse, rigSphereRadius, inputScale,
                                                    distortion, centerShift);
  RETURN_2D_IF_INVALID_INVERSE_3D(rigPt)

  // Spherical space -> output projection space.
  float2 panoUv = fromSphereToOutput(rigPt);

  // Output space -> normalized space.
  panoUv *= panoScale;
  return panoUv;
}