// Copyright (c) 2012-2017 VideoStitch SAS
// Copyright (c) 2018 stitchEm

#pragma once

#include "types.hpp"

/**
 * Computes AB = transpose(A) * B for square DIM x DIM matrices.
 *
 * AB, A and B are expressions exposing a `values[DIM][DIM]` member.
 * AB must not alias A or B: its entries are written while A and B are still being read.
 *
 * Every argument is parenthesized so that expressions such as `*ptr` or `n + 1` expand correctly,
 * and the body is wrapped in do { } while (0) so the macro behaves as one statement
 * (the call site supplies the terminating ';'). Loop variables carry an `mtm` prefix to
 * reduce the chance of clashing with names used in the argument expressions.
 */
#define MAT_TRANS_MULT_MAT(AB, A, B, DIM)                                \
  do {                                                                   \
    for (int mtmRow = 0; mtmRow < (DIM); ++mtmRow) {                     \
      for (int mtmCol = 0; mtmCol < (DIM); ++mtmCol) {                   \
        float mtmSum = 0.0f;                                             \
        for (int mtmK = 0; mtmK < (DIM); ++mtmK) {                       \
          mtmSum += (A).values[mtmK][mtmRow] * (B).values[mtmK][mtmCol]; \
        }                                                                \
        (AB).values[mtmRow][mtmCol] = mtmSum;                            \
      }                                                                  \
    }                                                                    \
  } while (0)

/**
 * Accessor for coefficient `a` of the distortion parameter pack `p`.
 *
 * - Default build: indexes the `values` array of `p`.
 * - GPU_CL_ARGS_WORKAROUND build: token-pastes `a` onto `s` to select a component (s0, s1, ...),
 *   so `a` must be a literal index token there, not a variable or an expression.
 *
 * `p` is parenthesized so that expressions such as `*ptr` expand correctly.
 */
#ifdef GPU_CL_ARGS_WORKAROUND
#define radial_values(p, a) ((p).s##a)
#else
#define radial_values(p, a) ((p).values[a])
#endif

/**
 * Multiplies a 3x3 matrix by a 3D point: returns R * pt.
 * @param pt point to transform
 * @param R 3x3 matrix; read through `values[row][col]`, or through the flat components
 *          s0..s8 (three per output coordinate) when GPU_CL_ARGS_WORKAROUND is defined
 * @return the transformed point
 */
INL_ROTATION_TRANSFORM_FN(matMultPoint)(float3 pt, const vsfloat3x3 R) {
  // Keep the input coordinates aside: pt is overwritten component by component below.
  const float3 src = pt;
#ifdef GPU_CL_ARGS_WORKAROUND
  pt.x = R.s0 * src.x + R.s1 * src.y + R.s2 * src.z;
  pt.y = R.s3 * src.x + R.s4 * src.y + R.s5 * src.z;
  pt.z = R.s6 * src.x + R.s7 * src.y + R.s8 * src.z;
#else
  pt.x = R.values[0][0] * src.x + R.values[0][1] * src.y + R.values[0][2] * src.z;
  pt.y = R.values[1][0] * src.x + R.values[1][1] * src.y + R.values[1][2] * src.z;
  pt.z = R.values[2][0] * src.x + R.values[2][1] * src.y + R.values[2][2] * src.z;
#endif
  return pt;
}

/**
 * Applies the 3x3 matrix R to pt. Forwards directly to matMultPoint.
 * @param pt point to rotate
 * @param R 3x3 matrix
 * @return R * pt
 */
INL_ROTATION_TRANSFORM_FN(rotateSphere)(float3 pt, const vsfloat3x3 R) {
  const float3 rotated = INL_FUNCTION_NAME(matMultPoint)(pt, R);
  return rotated;
}

/**
 * Applies a 3x4 matrix to a 3D point: the first three entries of each row multiply the point,
 * and the fourth entry of the row is added as an offset.
 * @param pt point to transform
 * @param R 3x4 matrix; read through `values[row][col]`, or through the flat components
 *          s0..sb (four per output coordinate) when GPU_CL_ARGS_WORKAROUND is defined
 * @return the transformed point
 */
INL_PERSPECTIVE_TRANSFORM_FN(transformSphere)(float3 pt, const vsfloat3x4 R) {
  // Keep the input coordinates aside: pt is overwritten component by component below.
  const float3 src = pt;
#ifdef GPU_CL_ARGS_WORKAROUND
  pt.x = R.s0 * src.x + R.s1 * src.y + R.s2 * src.z + R.s3;
  pt.y = R.s4 * src.x + R.s5 * src.y + R.s6 * src.z + R.s7;
  pt.z = R.s8 * src.x + R.s9 * src.y + R.sa * src.z + R.sb;
#else
  pt.x = R.values[0][0] * src.x + R.values[0][1] * src.y + R.values[0][2] * src.z + R.values[0][3];
  pt.y = R.values[1][0] * src.x + R.values[1][1] * src.y + R.values[1][2] * src.z + R.values[1][3];
  pt.z = R.values[2][0] * src.x + R.values[2][1] * src.y + R.values[2][2] * src.z + R.values[2][3];
#endif
  return pt;
}

/**
 * Identity distortion: returns the input coordinates untouched.
 * The parameter pack is accepted only so the signature matches the other distortion transforms.
 */
INL_DISTORTION_TRANSFORM_FN(noopDistortionTransform)(float2 val, const vsDistortion dummyParamName) {
  // dummyParamName is intentionally unused.
  return val;
}

/**
 * Radial-only forward distortion.
 *
 * Scales uv by the cubic polynomial p3*r^3 + p2*r^2 + p1*r + p0 of its length r.
 * When r is not below the limit p4, the fixed factor LARGE_ROOT is used instead.
 * @param uv coordinates to distort
 * @param p distortion parameters (coefficients 0..3: polynomial, coefficient 4: radius limit)
 * @return the scaled coordinates
 */
INL_DISTORTION_TRANSFORM_FN(radialScaled)(float2 uv, const vsDistortion p) {
  const float radius = length_vs(uv);

  float factor;
  if (radius < radial_values(p, 4)) {
    // Horner evaluation of the radial polynomial.
    factor = ((radial_values(p, 3) * radius + radial_values(p, 2)) * radius + radial_values(p, 1)) * radius +
             radial_values(p, 0);
  } else {
    factor = LARGE_ROOT;
  }

  uv *= factor;
  return uv;
}

/**
 * Forward distortion transform
 *
 * Keeping PTGui's radial distortion terms, and adopting OpenCV's tangential, thin-prism and Scheimpflug distortions
 *
 * Parameter layout read by this function:
 *   coefficients 0..3  radial polynomial, coefficient 4 radius limit,
 *   values[5..6]       tangential terms (used when TANGENTIAL_DISTORTION_BIT is set),
 *   values[7..10]      thin-prism terms (used when THIN_PRISM_DISTORTION_BIT is set),
 *   scheimpflugForward 3x3 matrix (used when SCHEIMPFLUG_DISTORTION_BIT is set).
 *
 * NOTE(review): when GPU_CL_ARGS_WORKAROUND is defined the radial factor is computed but never
 * applied, so uv is returned unchanged in that configuration -- confirm this is intended.
 */
INL_DISTORTION_TRANSFORM_FN(distortionScaled)(float2 uv, const vsDistortion p) {
  const float radius = length_vs(uv);

  // Radial factor: cubic polynomial in the radius inside the limit, LARGE_ROOT otherwise.
  float radialFactor;
  if (radius < radial_values(p, 4)) {
    radialFactor = ((radial_values(p, 3) * radius + radial_values(p, 2)) * radius + radial_values(p, 1)) * radius +
                   radial_values(p, 0);
  } else {
    radialFactor = LARGE_ROOT;
  }

#ifndef GPU_CL_ARGS_WORKAROUND
  const float radiusSq = radius * radius;

  // Sum of the non-radial contributions, added after the radial scaling.
  float2 offset = make_float2(0.f, 0.f);

  if (p.distortionBitFlag & TANGENTIAL_DISTORTION_BIT) {
    const float tan1 = p.values[5];
    const float tan2 = p.values[6];
    offset.x = 2 * tan1 * uv.x * uv.y + tan2 * (radiusSq + 2 * uv.x * uv.x);
    offset.y = tan1 * (radiusSq + 2 * uv.y * uv.y) + 2 * tan2 * uv.x * uv.y;
  }

  if (p.distortionBitFlag & THIN_PRISM_DISTORTION_BIT) {
    const float prism1 = p.values[7];
    const float prism2 = p.values[8];
    const float prism3 = p.values[9];
    const float prism4 = p.values[10];
    offset.x += (prism2 * radiusSq + prism1) * radiusSq;
    offset.y += (prism4 * radiusSq + prism3) * radiusSq;
  }

  // Combine radial scaling and additive terms.
  uv = uv * radialFactor + offset;

  if (p.distortionBitFlag & SCHEIMPFLUG_DISTORTION_BIT) {
    // Apply the forward Scheimpflug matrix to (u, v, 1), then divide by the third component.
    float3 tilted = make_float3(uv.x, uv.y, 1.0f);
    tilted = INL_FUNCTION_NAME(matMultPoint)(tilted, p.scheimpflugForward);
    if (tilted.z != 0.f) {
      uv.x = tilted.x / tilted.z;
      uv.y = tilted.y / tilted.z;
    }
    // tilted.z == 0 is not expected; uv is left as it was before the Scheimpflug step.
  }
#endif
  return uv;
}

/**
 * Angular distance (angle) between two points on a sphere
 *
 * The angle is recovered from asin(|p1 x p2|), which equals the angle only when both
 * inputs have unit length, then mirrored around pi/2 when the dot product is negative.
 *
 * @param p1 first point (expected to be unit length)
 * @param p2 second point (expected to be unit length)
 * @return angle between p1 and p2, in [0, pi]
 */
INL_TRANSFORM_FN_1D(angularDistance)(const float3 p1, const float3 p2) {
  float sine = length_vs(cross(p1, p2));
  // Rounding can push |p1 x p2| marginally above 1 for (near-)perpendicular unit vectors;
  // arcsine is undefined there, so clamp to keep the result finite.
  if (sine > 1.0f) {
    sine = 1.0f;
  }
  float angle = asinf_vs(sine);
  if (dot(p1, p2) < 0.0f) {
    // Obtuse case: asin only covers [0, pi/2].
    angle = (float)PI_F_VS - angle;
  }
  return angle;
}

/**
 * Inverse distortion transform.
 *
 * Keeping PTGui's radial distortion terms, and adopting OpenCV's tangential, thin-prism and Scheimpflug distortions
 *
 * @param uv distorted coordinates
 * @param p distortion parameters: coefficients 0..3 radial polynomial, coefficient 4 radius limit,
 *          values[5..6] tangential, values[7..10] thin-prism, scheimpflugInverse 3x3 matrix
 *          (the non-radial parts are only read when GPU_CL_ARGS_WORKAROUND is not defined)
 * @return the undistorted coordinates, or (INVALID_INVERSE_DISTORTION, INVALID_INVERSE_DISTORTION)
 *         when no valid inverse is found
 */
INL_DISTORTION_TRANSFORM_FN(inverseDistortionScaled)(float2 uv, const vsDistortion p) {
  /**
   * Undistort:
   *   undistort Scheimpflug first,
   *   then radial distortion to get an approximate value,
   *   then iterative (Gauss-Newton) refinement with all the other radial and non-radial distortion parameters
   */

#ifndef GPU_CL_ARGS_WORKAROUND
  /**
   * Scheimpflug distortion
   */
  if (p.distortionBitFlag & SCHEIMPFLUG_DISTORTION_BIT) {
    float3 point = make_float3(uv.x, uv.y, 1.0f);
    point = INL_FUNCTION_NAME(matMultPoint)(point, p.scheimpflugInverse);
    if (point.z != 0.0f) {
      uv.x = point.x / point.z;
      uv.y = point.y / point.z;
    } else {
      // do not know what to do but should not happen
      return make_float2(INVALID_INVERSE_DISTORTION, INVALID_INVERSE_DISTORTION);
    }
  }
#endif

  /**
   * Inverse radial distortion
   */

  /**
   * We are looking for xy such that:
   *   xy * (p0 + p1 * length(xy) + p2 * length(xy)^2 + p3 * length(xy)^3) == uv (1)
   *
   * For simplicity: denote x = xy.x, y = xy.y, u = uv.x and v = uv.y
   * Since uv is xy multiplied by a scalar, xy and uv are collinear.
   * Consider case u != 0 (without loss of generality u = 0 && v != 0 is treated similarly):
   * Let s = y/x = v/u we have y = x*s (2) and length(xy) = |x|(1 + s*s)^(1/2) (3)
   * Plug (2),(3) into (1), we can solve the following equation for x
   *  a*x*|x|^3 + b*x^3 + c*|x|*x + d*x + e = 0
   * where a = p3*(1 + s*s)^(3/2)
   *       b = p2*(1 + s*s)^(2/2)
   *       c = p1*(1 + s*s)^(1/2)
   *       d = p0
   *       e = -u
   *       s = v/u
   * This quartic function might contain up to 4 simple solutions.
   * Need to pick the one that satisfies the condition length(xy) < p4
   * For now, the solution with the minimum length satisfying length(xy) < p4 is picked,
   *          otherwise, the input coordinate is non-invertible
   */
  float2 undistorted = make_float2(INVALID_INVERSE_DISTORTION, INVALID_INVERSE_DISTORTION);

  // Outside the polynomial's domain the forward transform multiplies by LARGE_ROOT; undo exactly that.
  const float r_large = length_vs(make_float2(uv.x / LARGE_ROOT, uv.y / LARGE_ROOT));
  if (r_large >= radial_values(p, 4)) {
    return make_float2(uv.x / LARGE_ROOT, uv.y / LARGE_ROOT);
  }
  // The origin maps to the origin; handling it here also avoids a 0/0 slope below.
  if (fabsf_vs(uv.x) < 1e-10f && fabsf_vs(uv.y) < 1e-10f) {
    return make_float2(0.0f, 0.0f);
  }
  const bool solveX =
      fabsf_vs(uv.x) > 1e-4f;  // This threshold is used to decide whether to solve for x or y. For numerical stability
                               // of the quartic solver, it should stay in the range [1e-4f..1e-6f]
  const float s = (solveX) ? uv.y / uv.x : uv.x / uv.y;
  const float a = radial_values(p, 3) * powf_vs((1.0f + s * s), 3.0f / 2.0f);
  const float b = radial_values(p, 2) * (1.0f + s * s);
  const float c = radial_values(p, 1) * sqrtf_vs(1.0f + s * s);
  const float d = radial_values(p, 0);
  const float e = (solveX) ? -uv.x : -uv.y;

  // Radius of the best root found so far; negative means "no admissible root yet".
  float rMin = -1;
  // Next, collect all the positive solutions (for x >= 0, |x| == x)
  const vsQuarticSolution rtsPositive = solveQuartic_vs(a, b, c, d, e);
  for (int i = 0; i < rtsPositive.solutionCount; i++) {
    if (rtsPositive.x[i] >= 0) {
      const float r = sqrtf_vs(rtsPositive.x[i] * rtsPositive.x[i] * (1 + s * s));
      if (r < rMin || rMin < 0) {
        rMin = r;
        undistorted = (solveX) ? make_float2(rtsPositive.x[i], rtsPositive.x[i] * s)
                               : make_float2(rtsPositive.x[i] * s, rtsPositive.x[i]);
      }
    }
  }

  // Collect all negative solutions (for x < 0, |x| == -x, which flips the sign of the a and c terms)
  const vsQuarticSolution rtsNegative = solveQuartic_vs(-a, b, -c, d, e);
  for (int i = 0; i < rtsNegative.solutionCount; i++) {
    if (rtsNegative.x[i] < 0) {
      const float r = sqrtf_vs(rtsNegative.x[i] * rtsNegative.x[i] * (1 + s * s));
      if (r < rMin || rMin < 0) {
        rMin = r;
        undistorted = (solveX) ? make_float2(rtsNegative.x[i], rtsNegative.x[i] * s)
                               : make_float2(rtsNegative.x[i] * s, rtsNegative.x[i]);
      }
    }
  }

  // Reject when no admissible root was found at all (rMin still at its negative sentinel), or when
  // the closest root lies outside the polynomial's domain. The first condition matters for the
  // refinement below: without it the iteration would start from the INVALID sentinel coordinates.
  if (rMin < 0 || rMin >= radial_values(p, 4)) {
    return make_float2(INVALID_INVERSE_DISTORTION, INVALID_INVERSE_DISTORTION);
  }

#ifndef GPU_CL_ARGS_WORKAROUND
  /**
   * Undistort the other parameters, if necessary, by iterative refinement (Gauss-Newton steps)
   * distorted (x',y') from undistorted (x,y) is given by
   * r = sqrt(x^2 + y^2)
   * x' = x*(((A * r + B) * r + C) * r + D) + 2*P1*x*y + P2*(r^2 + 2*x^2) + S1*r^2 + S2*r^4
   * y' = y*(((A * r + B) * r + C) * r + D) + P1*(r^2 + 2*y^2) + 2*P2*x*y + S3*r^2 + S4*r^4
   */
  if ((p.distortionBitFlag & (TANGENTIAL_DISTORTION_BIT | THIN_PRISM_DISTORTION_BIT)) &&
      length_vs(undistorted) >= 1.e-6f) {
    const float P1 = p.values[5], P2 = p.values[6];
    const float S1 = p.values[7], S2 = p.values[8], S3 = p.values[9], S4 = p.values[10];
    const float A = p.values[3], B = p.values[2], C = p.values[1], D = p.values[0];
    int iters = 100;

    // Pre-decrement: at most 99 refinement steps are performed.
    while (--iters) {
      const float r = length_vs(undistorted);
      const float r2 = r * r;
      const float r4 = r2 * r2;
      const float scale = (((A * r + B) * r + C) * r + D);

      // Residual between the target and the forward distortion of the current estimate.
      float2 error;
      error.x = uv.x - (undistorted.x * scale + 2 * P1 * undistorted.x * undistorted.y +
                        P2 * (r2 + 2 * undistorted.x * undistorted.x) + S1 * r2 + S2 * r4);
      error.y = uv.y - (undistorted.y * scale + P1 * (r2 + 2 * undistorted.y * undistorted.y) +
                        2 * P2 * undistorted.x * undistorted.y + S3 * r2 + S4 * r4);

      // Break the loop when error is small
      // NOTE(review): 1e-12 is far below single-precision resolution for coordinates of ordinary
      // magnitude, so this exit is probably rarely taken -- confirm the intended tolerance.
      if (fabsf_vs(error.x) < 1.e-12f && fabsf_vs(error.y) < 1e-12f) {
        break;
      }

      // Compute Jacobian J
      // NOTE(review): r is only checked against 1e-6 before the loop; C * x / r is not guarded
      // against the estimate reaching the origin during iteration.
      const float dscale_dx = 3 * A * undistorted.x * r + 2 * B * undistorted.x + C * undistorted.x / r;
      const float dscale_dy = 3 * A * undistorted.y * r + 2 * B * undistorted.y + C * undistorted.y / r;
      vsfloat2x2 J;

      // dx'_dx
      J.values[0][0] =
          undistorted.x * dscale_dx + scale +
          2 * (P1 * undistorted.y + 3 * P2 * undistorted.x + S1 * undistorted.x + 2 * S2 * undistorted.x * r2);
      // dx'_dy
      J.values[0][1] = undistorted.x * dscale_dy +
                       2 * (P1 * undistorted.x + P2 * undistorted.y + S1 * undistorted.y + 2 * S2 * undistorted.y * r2);
      // dy'_dx
      J.values[1][0] = undistorted.y * dscale_dx +
                       2 * (P1 * undistorted.x + P2 * undistorted.y + S3 * undistorted.x + 2 * S4 * undistorted.x * r2);
      // dy'_dy
      J.values[1][1] =
          undistorted.y * dscale_dy + scale +
          2 * (3 * P1 * undistorted.y + P2 * undistorted.x + S3 * undistorted.y + 2 * S4 * undistorted.y * r2);

      // Update step p += (JtJ)^-1.Jt.error
      vsfloat2x2 JtJ;
      MAT_TRANS_MULT_MAT(JtJ, J, J, 2);
      const float JtJdeterminant = JtJ.values[0][0] * JtJ.values[1][1] - JtJ.values[0][1] * JtJ.values[1][0];
      if (JtJdeterminant > 1.e-6f) {
        const float invdet = 1.0f / JtJdeterminant;
        vsfloat2x2 JtJinverse;

        // Closed-form inverse of a 2x2 matrix: swap the diagonal, negate the off-diagonal, divide by det.
        JtJinverse.values[0][0] = invdet * JtJ.values[1][1];
        JtJinverse.values[0][1] = -invdet * JtJ.values[0][1];
        JtJinverse.values[1][0] = -invdet * JtJ.values[1][0];
        JtJinverse.values[1][1] = invdet * JtJ.values[0][0];

        // J.transpose() * error;
        const float2 JtError = make_float2(J.values[0][0] * error.x + J.values[1][0] * error.y,
                                           J.values[0][1] * error.x + J.values[1][1] * error.y);
        // update = JtJinverse * JtError
        const float2 update = make_float2(JtJinverse.values[0][0] * JtError.x + JtJinverse.values[0][1] * JtError.y,
                                          JtJinverse.values[1][0] * JtError.x + JtJinverse.values[1][1] * JtError.y);

        undistorted += update;
      } else {
        // should not get here: (near-)singular normal matrix, no usable step
        return make_float2(INVALID_INVERSE_DISTORTION, INVALID_INVERSE_DISTORTION);
      }
    }
  }
#endif

  return undistorted;
}

/**
 * Projects a 3D point by dividing x and y by z.
 * Points with z below 0.0001 cannot be projected this way and yield the marker value (-100, -100).
 */
INL_CONVERT_3D_2D_FN(SphereToRect)(const float3 pt) {
  if (pt.z < 0.0001f) {
    return make_float2(-100.0f, -100.0f);
  }
  return make_float2(pt.x / pt.z, pt.y / pt.z);
}

/**
 * Converts a 3D point to a pair of angles:
 *   x = atan2(pt.x, pt.z)
 *   y = asin(pt.y * invLength3(pt))   (invLength3 is presumably 1 / |pt| -- defined elsewhere)
 * Both components are then clamped to [-PI + FLT_EPSILON, PI - FLT_EPSILON].
 */
INL_CONVERT_3D_2D_FN(SphereToErect)(const float3 pt) {
  const float2 angles = make_float2(atan2f_vs(pt.x, pt.z), asinf_vs(pt.y * invLength3(pt)));
  return clampf_vs(angles, -PI_F_VS + FLT_EPSILON, PI_F_VS - FLT_EPSILON);
}

/**
 * Converts a 3D point to 2D polar-style coordinates:
 * the radius is theta = acos(pt.z * invLength3(pt)) and the direction is phi = atan2(pt.y, pt.x).
 */
INL_CONVERT_3D_2D_FN(SphereToFisheye)(const float3 pt) {
  const float radius = acosf_vs(pt.z * invLength3(pt));
  const float azimuth = atan2f_vs(pt.y, pt.x);
  return make_float2(cosf_vs(azimuth) * radius, sinf_vs(azimuth) * radius);
}

/**
 * Converts a 3D point to 2D by dividing x and y by (pt.z + |pt|).
 * NOTE(review): the divisor is zero for points on the negative z axis; no guard is applied here.
 */
INL_CONVERT_3D_2D_FN(SphereToExternal)(const float3 pt) {
  const float divisor = pt.z + length_vs(pt);
  return make_float2(pt.x / divisor, pt.y / divisor);
}

/**
 * Converts 2D polar-style coordinates to a 3D point:
 * theta = |uv| is the angle from the z axis, phi = atan2(uv.y, uv.x) the direction around it.
 * Result: (sin(theta) cos(phi), sin(theta) sin(phi), cos(theta)).
 */
INL_CONVERT_2D_3D_FN(FisheyeToSphere)(const float2 uv) {
  const float azimuth = atan2f_vs(uv.y, uv.x);
  const float polar = length_vs(uv);
  const float sinPolar = sinf_vs(polar);
  return make_float3(sinPolar * cosf_vs(azimuth), sinPolar * sinf_vs(azimuth), cosf_vs(polar));
}

/**
 * Converts a pair of angles to a 3D point.
 * Both angles are first clamped to [-PI + FLT_EPSILON, PI - FLT_EPSILON], then
 * the result is (cos(y) sin(x), sin(y), cos(y) cos(x)).
 */
INL_CONVERT_2D_3D_FN(ErectToSphere)(const float2 uv) {
  const float2 angles = clampf_vs(uv, -PI_F_VS + FLT_EPSILON, PI_F_VS - FLT_EPSILON);
  const float cosY = cosf_vs(angles.y);
  return make_float3(cosY * sinf_vs(angles.x), sinf_vs(angles.y), cosY * cosf_vs(angles.x));
}

/**
 * Lifts 2D coordinates to the 3D point (uv.x, uv.y, 1) scaled by invLength3 of that vector
 * (invLength3 is presumably the reciprocal length -- defined elsewhere).
 */
INL_CONVERT_2D_3D_FN(RectToSphere)(const float2 uv) {
  const float3 lifted = make_float3(uv.x, uv.y, 1.0f);
  const float invLen = invLength3(lifted);
  // The z component of the lifted vector is 1, so its scaled value is invLen itself.
  return make_float3(uv.x * invLen, uv.y * invLen, invLen);
}

/**
 * Converts 2D coordinates to a 3D point using
 *   (4u, 4v, 4 - (u^2 + v^2)) / (u^2 + v^2 + 4).
 */
INL_CONVERT_2D_3D_FN(StereoToSphere)(const float2 uv) {
  const float uSquared = uv.x * uv.x;
  const float vSquared = uv.y * uv.y;
  const float radiusSq = uSquared + vSquared;
  const float invDenom = 1.0f / (radiusSq + 4.0f);
  return make_float3(4.0f * uv.x * invDenom, 4.0f * uv.y * invDenom, -(radiusSq - 4.0f) * invDenom);
}

/**
 * Converts a 3D point to 2D by scaling x and y with 2 / (|pt| + pt.z).
 * NOTE(review): the divisor is zero for points on the negative z axis; no guard is applied here.
 */
INL_CONVERT_3D_2D_FN(SphereToStereo)(const float3 pt) {
  const float factor = 2.0f / (length_vs(pt) + pt.z);
  return make_float2(factor * pt.x, factor * pt.y);
}

/**
 * Converts 2D coordinates to a 3D point using
 *   (2u, 2v, 1 - (u^2 + v^2)) / (u^2 + v^2 + 1).
 */
INL_CONVERT_2D_3D_FN(ExternalToSphere)(const float2 uv) {
  const float uSquared = uv.x * uv.x;
  const float vSquared = uv.y * uv.y;
  const float denom = uSquared + vSquared + 1.0f;
  return make_float3(2.0f * uv.x / denom, 2.0f * uv.y / denom, -(uSquared + vSquared - 1.0f) / denom);
}

// MAT_TRANS_MULT_MAT is a helper for this header only; undefine it so it does not leak into including files.
#undef MAT_TRANS_MULT_MAT