// Copyright (c) 2012-2017 VideoStitch SAS
// Copyright (c) 2018 stitchEm

// templated kernel:
// need to #define PhotoCorrectionT
// before including this file

// Samples the texture at `uv` and returns the photo-corrected color as a float4.
//
// Processing steps, in order:
//   1. read the texel at `uv` from `texCL` and scale it by 255
//   2. apply the `corr` function of PhotoCorrectionT to the color channels
//   3. multiply each channel by its `rgbMult` gain and by the vignetting gain
//   4. apply the `invCorr` function of PhotoCorrectionT
//
// Vignetting gain: let q be the squared distance between `uv` and the point
// (texWidth / 2 + vigCenterX, texHeight / 2 + vigCenterY), multiplied by
// `inverseDemiDiagonalSquared`. The gain is
//   1 / (vigCoeff0 + q * (vigCoeff1 + q * (vigCoeff2 + q * vigCoeff3)))
// i.e. the polynomial is evaluated in the normalized *squared* radius.
//
// Parameters:
//   uv                          sampling position, also used for the vignetting radius
//   texCL                       texture to read from
//   texWidth, texHeight         texture dimensions, used to locate the texture center
//   rgbMult                     per-channel multiplier (x, y, z applied to the three color channels)
//   photoParam, lutPtr          forwarded untouched to the PhotoCorrectionT corr / invCorr functions
//   vigCenterX, vigCenterY      offset of the vignetting center from the texture center
//   inverseDemiDiagonalSquared  normalization factor applied to the squared radius
//   vigCoeff0..vigCoeff3        vignetting polynomial coefficients (vigCoeff0 is the constant term)
//
// Returns the corrected channels in x, y, z; w is the scaled texel's w component,
// passed through without any correction. The result is neither rounded nor clamped.
__device__ float4 FUNCTION_NAME_2(photoCorrectionFunction4, PhotoCorrectionT)(
    const float2 uv,
    read_only image2d_t texCL,
    const int texWidth, const int texHeight,
    const float3 rgbMult,
    const float photoParam, const lut_ptr float* lutPtr,
    const float vigCenterX, const float vigCenterY,
    const float inverseDemiDiagonalSquared,
    const float vigCoeff0, const float vigCoeff1,
    const float vigCoeff2, const float vigCoeff3) {

  // Fetch the source texel.
  // NOTE(review): the 255 scale presumably maps a normalized [0, 1] read to the
  // 8-bit range; this depends on what read_texture_vs returns -- confirm.
  const float4 texel = read_texture_vs(texCL, uv) * 255.0f;

  // Offset of the sampling position from the vignetting center.
  const float halfWidth = (float)texWidth / 2.0f;
  const float halfHeight = (float)texHeight / 2.0f;
  const float offsetX = uv.x - halfWidth - vigCenterX;
  const float offsetY = uv.y - halfHeight - vigCenterY;

  // Normalized squared radius.
  const float q = (offsetX * offsetX + offsetY * offsetY) * inverseDemiDiagonalSquared;

  // Horner evaluation of c0 + q * (c1 + q * (c2 + q * c3)), one step per statement.
  float falloff = q * vigCoeff3;
  falloff += vigCoeff2;
  falloff *= q;
  falloff += vigCoeff1;
  falloff *= q;
  falloff += vigCoeff0;

  // The correction divides by the polynomial.
  // No guard is applied against a zero-valued polynomial.
  const float vigGain = 1.0f / falloff;

  // Forward photo correction on the color channels only.
  float3 channels = get_xyz(texel);
  channels = FUNCTION_NAME_2(PhotoCorrectionT, corr) (channels, photoParam, lutPtr);

  // Per-channel gain combined with the vignetting compensation.
  channels.x = rgbMult.x * channels.x * vigGain;
  channels.y = rgbMult.y * channels.y * vigGain;
  channels.z = rgbMult.z * channels.z * vigGain;

  // Inverse photo correction.
  channels = FUNCTION_NAME_2(PhotoCorrectionT, invCorr) (channels, photoParam, lutPtr);

  // The fourth component is carried over from the (scaled) texel as-is.
  return make_float4(channels.x, channels.y, channels.z, texel.w);
}

// Packed-integer variant of photoCorrectionFunction4.
//
// Forwards every argument unchanged to the float4 variant for the same
// PhotoCorrectionT, then converts the result to a packed 32-bit pixel:
//   - x, y, z are rounded with __float2int_rn and passed through Image_clamp8
//   - w is converted with a plain cast (truncation) and is NOT passed through
//     Image_clamp8
//
// Parameters: identical to photoCorrectionFunction4.
// Returns: the value produced by Image_RGBA_pack for the four converted components.
__device__ uint32_t FUNCTION_NAME_2(photoCorrectionFunction, PhotoCorrectionT)(
    const float2 uv,
    read_only image2d_t texCL,
    const int texWidth, const int texHeight,
    const float3 rgbMult,
    const float photoParam, const lut_ptr float* lutPtr,
    const float vigCenterX, const float vigCenterY,
    const float inverseDemiDiagonalSquared,
    const float vigCoeff0, const float vigCoeff1,
    const float vigCoeff2, const float vigCoeff3) {

  // Floating-point result of the full correction chain.
  const float4 corrected = FUNCTION_NAME_2(photoCorrectionFunction4, PhotoCorrectionT)(
      uv, texCL, texWidth, texHeight,
      rgbMult,
      photoParam, lutPtr,
      vigCenterX, vigCenterY,
      inverseDemiDiagonalSquared,
      vigCoeff0, vigCoeff1, vigCoeff2, vigCoeff3);

  // Color channels: round to the nearest integer.
  const int32_t red = __float2int_rn(corrected.x);
  const int32_t green = __float2int_rn(corrected.y);
  const int32_t blue = __float2int_rn(corrected.z);

  // Alpha: cast only, no rounding.
  const int32_t alpha = (int32_t)corrected.w;

  // NOTE(review): Image_clamp8 presumably limits each channel to the 8-bit
  // range before packing -- confirm against its definition.
  return Image_RGBA_pack(Image_clamp8(red), Image_clamp8(green), Image_clamp8(blue), alpha);
}