// Copyright (c) 2012-2017 VideoStitch SAS
// Copyright (c) 2018 stitchEm

// poly.cpp : solution of cubic and quartic equation
// (c) Khashin S.I. http://math.ivanovo.ac.ru/dalgebra/Khashin/index.html
// khash2 (at) gmail.com
//
#pragma once

#include "types.hpp"

#define TwoPi 6.28318530717958648f

INL_FN_QUARTIC_SOL(setQuarticSolution)(const int solutionCount, const solve_float_t x[4]) {
  vsQuarticSolution s;
  s.solutionCount = solutionCount;
  for (int i = 0; i < 4; i++) {
    s.x[i] = convert_float(x[i]);
  }
  return s;
}

INL_FN_QUARTIC_SOL(solveP1)(const solve_float_t a) {
  solve_float_t x[4] = {0.0f, 0.0f, 0.0f, 0.0f};
  x[0] = -a;
  return FN_NAME(setQuarticSolution)(1, x);
}

INL_FN_QUARTIC_SOL(solveP2)(const solve_float_t a, const solve_float_t b) {
  solve_float_t x[4] = {0.0f, 0.0f, 0.0f, 0.0f};
  const solve_float_t delta = a * a - 4.0f * b;
  if (delta < 0.0f) {
    return FN_NAME(setQuarticSolution)(0, x);
  }
  const solve_float_t sqrtDelta = sqrt_vs(delta);
  x[0] = 0.5f * (-a + sqrtDelta);
  if (delta <= 1e-14f) {
    return FN_NAME(setQuarticSolution)(1, x);
  }
  x[1] = 0.5f * (-a - sqrtDelta);
  return FN_NAME(setQuarticSolution)(2, x);
}

//---------------------------------------------------------------------------
// x - array of size 3
// In case 3 real roots: => x[0], x[1], x[2], return 3
//         2 real roots: x[0], x[1],          return 2
//         1 real root : x[0], x[1] ± i*x[2], return 1
INL_FN_QUARTIC_SOL(solveP3)
(const solve_float_t a, const solve_float_t b, const solve_float_t c) {  // solve cubic equation x^3 + a*x^2 + b*x + c
  solve_float_t x[4] = {0.0f, 0.0f, 0.0f, 0.0f};
  const solve_float_t a2 = a * a;
  const solve_float_t a_3 = a / 3.0f;
  solve_float_t q = (a2 - 3.0f * b) / 9.0f;
  const solve_float_t r = (a * (2.0f * a2 - 9.0f * b) + 27.0f * c) / 54.0f;
  const solve_float_t r2 = r * r;
  const solve_float_t q3 = q * q * q;
  if (r2 < q3) {
    solve_float_t t = r / sqrt_vs(q3);
    if (t < -1.0f) {
      t = -1.0f;
    }
    if (t > 1.0f) {
      t = 1.0f;
    }
    t = acos_vs(t);
    q = -2.0f * sqrt_vs(q);
    x[0] = q * cos_vs(t / 3.0f) - a_3;
    x[1] = q * cos_vs((t + TwoPi) / 3.0f) - a_3;
    x[2] = q * cos_vs((t - TwoPi) / 3.0f) - a_3;
    return FN_NAME(setQuarticSolution)(3, x);
  } else {
    solve_float_t A, B;
    A = -pow_vs(fabs_vs(r) + sqrt_vs(r2 - q3), (solve_float_t)(1.0f / 3.0f));
    if (r < 0.0f) {
      A = -A;
    }
    B = (A == 0.0f) ? 0.0f : q / A;
    x[0] = (A + B) - a_3;
    x[1] = -0.5f * (A + B) - a_3;
    x[2] = 0.5f * sqrt_vs(3.0f) * (A - B);
    if (fabs_vs(x[2]) <= 1e-14f) {
      x[2] = x[1];
      return FN_NAME(setQuarticSolution)(2, x);
    }
    return FN_NAME(setQuarticSolution)(1, x);
  }
}  // solveP3(solve_float_t a,solve_float_t b,solve_float_t c) {

//---------------------------------------------------------------------------
// a>=0!
INL_FN_FLOAT2(cSqrt)(const solve_float_t x, const solve_float_t y) {  // returns:  a+i*s = sqrt(x+i*y)
  solve_float_t a, b;
  solve_float_t r = sqrt_vs(x * x + y * y);
  if (fabs_vs(y) <= 1e-14f) {
    r = sqrt_vs(r);
    if (x >= 0.0f) {
      a = r;
      b = 0.0f;
    } else {
      a = 0.0f;
      b = r;
    }
  } else {  // y != 0
    a = sqrt_vs(0.5f * (x + r));
    b = 0.5f * y / a;
  }
  return make_solve_float_t2(a, b);
}

//---------------------------------------------------------------------------
// solve equation x^4 + b*x^2 + d = 0
INL_FN_QUARTIC_SOL(solveP4Bi)(const solve_float_t b, const solve_float_t d) {
  solve_float_t x[4] = {0.0f, 0.0f, 0.0f, 0.0f};
  const solve_float_t D = b * b - 4 * d;
  if (D >= 0.0f) {
    solve_float_t sD = sqrt_vs(D);
    solve_float_t x1 = (-b + sD) / 2;
    solve_float_t x2 = (-b - sD) / 2;  // x2 <= x1
    if (x2 >= 0.0f) {                  // 0 <= x2 <= x1, 4 real roots
      solve_float_t sx1 = sqrt_vs(x1);
      solve_float_t sx2 = sqrt_vs(x2);
      x[0] = -sx1;
      x[1] = sx1;
      x[2] = -sx2;
      x[3] = sx2;
      return FN_NAME(setQuarticSolution)(4, x);
    }
    if (x1 < 0.0f) {  // x2 <= x1 < 0, two pair of imaginary roots
      solve_float_t sx1 = sqrt_vs(-x1);
      solve_float_t sx2 = sqrt_vs(-x2);
      x[0] = 0.0f;
      x[1] = sx1;
      x[2] = 0.0f;
      x[3] = sx2;
      return FN_NAME(setQuarticSolution)(0, x);
    }
    // now x2 < 0 <= x1 , two real roots and one pair of imginary root
    solve_float_t sx1 = sqrt_vs(x1);
    solve_float_t sx2 = sqrt_vs(-x2);
    x[0] = -sx1;
    x[1] = sx1;
    x[2] = 0.0f;
    x[3] = sx2;
    return FN_NAME(setQuarticSolution)(2, x);
  } else {  // if( D < 0 ), two pair of compex roots
    solve_float_t sD2 = 0.5f * sqrt_vs(-D);
    solve_float_t2 sol0 = FN_NAME(cSqrt)(-0.5f * b, sD2);
    x[0] = sol0.x;
    x[1] = sol0.y;
    solve_float_t2 sol1 = FN_NAME(cSqrt)(-0.5f * b, -sD2);
    x[2] = sol1.x;
    x[3] = sol1.y;
    return FN_NAME(setQuarticSolution)(0, x);
  }  // if( D>=0 )
}  // solveP4Bi(solve_float_t b, solve_float_t d)	// solve equation x^4 + b*x^2 d

//---------------------------------------------------------------------------
INL_FN_FLOAT3(dblSort3)(solve_float_t a, solve_float_t b, solve_float_t c) {  // make: a <= b <= c
  solve_float_t t;
  if (a > b) {
    t = b;
    b = a;
    a = t;  // now a<=b
  }
  if (c < b) {
    t = b;
    b = c;
    c = t;  // now a<=b, b<=c
    if (a > b) {
      t = b;
      b = a;
      a = t;  // now a<=b
    }
  }
  return make_solve_float_t3(a, b, c);
}
//---------------------------------------------------------------------------
INL_FN_QUARTIC_SOL(solveP4De)
(const solve_float_t b, const solve_float_t c, const solve_float_t d) {  // solve equation x^4 + b*x^2 + c*x + d
  // if( c==0 ) return solveP4Bi(b,d); // After that, c!=0
  solve_float_t x[4] = {0.0f, 0.0f, 0.0f, 0.0f};
  if (fabs_vs(c) < 1e-14f * (fabs_vs(b) + fabs_vs(d))) {
    return FN_NAME(solveP4Bi)(b, d);  // After that, c!=0
  }
  const vsQuarticSolution s = FN_NAME(solveP3)(2 * b, b * b - 4 * d, -c * c);
  x[0] = s.x[0];
  x[1] = s.x[1];
  x[2] = s.x[2];
  x[3] = s.x[3];
  int res3 = s.solutionCount;  // solve resolvent
  // by Viet theorem:  x1*x2*x3=-c*c not equals to 0, so x1!=0, x2!=0, x3!=0
  if (res3 > 1) {                                                       // 3 real roots,
    const solve_float_t3 sorted = FN_NAME(dblSort3)(x[0], x[1], x[2]);  // sort roots to x[0] <= x[1] <= x[2]
    x[0] = sorted.x;
    x[1] = sorted.y;
    x[2] = sorted.z;
    // Note: x[0]*x[1]*x[2]= c*c > 0
    if (x[0] > 0.0f) {  // all roots are positive
      const solve_float_t sz1 = sqrt_vs(x[0]);
      const solve_float_t sz2 = sqrt_vs(x[1]);
      const solve_float_t sz3 = sqrt_vs(x[2]);
      // Note: sz1*sz2*sz3= -c (and not equal to 0)
      if (c > 0.0f) {
        x[0] = (-sz1 - sz2 - sz3) / 2;
        x[1] = (-sz1 + sz2 + sz3) / 2;
        x[2] = (+sz1 - sz2 + sz3) / 2;
        x[3] = (+sz1 + sz2 - sz3) / 2;
        return FN_NAME(setQuarticSolution)(4, x);
      }
      // now: c<0
      x[0] = (-sz1 - sz2 + sz3) / 2;
      x[1] = (-sz1 + sz2 - sz3) / 2;
      x[2] = (+sz1 - sz2 - sz3) / 2;
      x[3] = (+sz1 + sz2 + sz3) / 2;
      return FN_NAME(setQuarticSolution)(4, x);
    }  // if( x[0] > 0) // all roots are positive
    // now x[0] <= x[1] < 0, x[2] > 0
    // two pair of complex roots
    const solve_float_t sz1 = sqrt_vs(-x[0]);
    const solve_float_t sz2 = sqrt_vs(-x[1]);
    const solve_float_t sz3 = sqrt_vs(x[2]);

    if (c > 0) {  // sign = -1
      x[0] = -sz3 / 2;
      x[1] = (sz1 - sz2) / 2;  // x[0]±i*x[1]
      x[2] = sz3 / 2;
      x[3] = (-sz1 - sz2) / 2;  // x[2]±i*x[3]
      return FN_NAME(setQuarticSolution)(0, x);
    }
    // now: c<0 , sign = +1
    x[0] = sz3 / 2;
    x[1] = (-sz1 + sz2) / 2;
    x[2] = -sz3 / 2;
    x[3] = (sz1 + sz2) / 2;
    return FN_NAME(setQuarticSolution)(0, x);
  }  // if( res3>1 )	// 3 real roots,
  // now resoventa have 1 real and pair of compex roots
  // x[0] - real root, and x[0]>0,
  // x[1]±i*x[2] - complex roots,
  const solve_float_t sz1 = sqrt_vs(x[0]);
  const solve_float_t2 res = FN_NAME(cSqrt)(x[1], x[2]);
  // (szr+i*szi)^2 = x[1]+i*x[2]
  const solve_float_t szr = res.x;
  const solve_float_t szi = res.y;
  if (c > 0) {              // sign = -1
    x[0] = -sz1 / 2 - szr;  // 1st real root
    x[1] = -sz1 / 2 + szr;  // 2nd real root
    x[2] = sz1 / 2;
    x[3] = szi;
    return FN_NAME(setQuarticSolution)(2, x);
  }
  // now: c<0 , sign = +1
  x[0] = sz1 / 2 - szr;  // 1st real root
  x[1] = sz1 / 2 + szr;  // 2nd real root
  x[2] = -sz1 / 2;
  x[3] = szi;
  return FN_NAME(setQuarticSolution)(2, x);
}  // solveP4De(solve_float_t b, solve_float_t c, solve_float_t d)	// solve equation x^4 + b*x^2 + c*x + d

//-----------------------------------------------------------------------------
INL_FN_FLOAT(n4Step)
(const solve_float_t x, const solve_float_t a, const solve_float_t b, const solve_float_t c,
 const solve_float_t d) {                                           // one Newton step for x^4 + a*x^3 + b*x^2 + c*x + d
  const solve_float_t fxs = ((4 * x + 3 * a) * x + 2 * b) * x + c;  // f'(x)
  if (fabs_vs(fxs) <= 1e-14f) {
    return 1e20f;
  }
  solve_float_t fx = (((x + a) * x + b) * x + c) * x + d;  // f(x)
  return x - fx / fxs;
}

//-----------------------------------------------------------------------------
// x - array of size 4
// return 4: 4 real roots x[0], x[1], x[2], x[3], possible multiple roots
// return 2: 2 real roots x[0], x[1] and complex x[2]±i*x[3],
// return 0: two pair of complex roots: x[0]±i*x[1],  x[2]±i*x[3],
INL_FN_QUARTIC_SOL(solveP4)
(const solve_float_t a, const solve_float_t b, const solve_float_t c,
 const solve_float_t d) {  // solve equation x^4 + a*x^3 + b*x^2 + c*x + d by Dekart-Euler method
  // move to a=0:
  solve_float_t x[4] = {0.0f, 0.0f, 0.0f, 0.0f};
  const solve_float_t d1 = d + 0.25f * a * (0.25f * b * a - 3.0f / 64.0f * a * a * a - c);
  const solve_float_t c1 = c + 0.5f * a * (0.25f * a * a - b);
  const solve_float_t b1 = b - 0.375f * a * a;
  const vsQuarticSolution s = FN_NAME(solveP4De)(b1, c1, d1);
  x[0] = s.x[0];
  x[1] = s.x[1];
  x[2] = s.x[2];
  x[3] = s.x[3];
  const int res = s.solutionCount;
  if (res == 4) {
    x[0] -= a / 4;
    x[1] -= a / 4;
    x[2] -= a / 4;
    x[3] -= a / 4;
  } else if (res == 2) {
    x[0] -= a / 4;
    x[1] -= a / 4;
    x[2] -= a / 4;
  } else {
    x[0] -= a / 4;
    x[2] -= a / 4;
  }
  // one Newton step for each real root:
  if (res > 0) {
    x[0] = FN_NAME(n4Step)(x[0], a, b, c, d);
    x[1] = FN_NAME(n4Step)(x[1], a, b, c, d);
  }
  if (res > 2) {
    x[2] = FN_NAME(n4Step)(x[2], a, b, c, d);
    x[3] = FN_NAME(n4Step)(x[3], a, b, c, d);
  }
  return FN_NAME(setQuarticSolution)(res, x);
}

INL_FN_QUARTIC_SOL(solveQuartic)(const float a, const float b, const float c, const float d, const float e) {
  if (fabsf_vs(a) > 1e-14f) {
    return FN_NAME(solveP4)(b / a, c / a, d / a, e / a);
  } else if (fabsf_vs(b) > 1e-14f) {
    return FN_NAME(solveP3)(c / b, d / b, e / b);
  } else if (fabsf_vs(c) > 1e-14f) {
    return FN_NAME(solveP2)(d / c, e / c);
  } else if (fabsf_vs(d) > 1e-14f) {
    return FN_NAME(solveP1)(e / d);
  }
  solve_float_t x[4] = {0.0f, 0.0f, 0.0f, 0.0f};
  return FN_NAME(setQuarticSolution)(0, x);
}