// Copyright (c) 2012-2017 VideoStitch SAS
// Copyright (c) 2018 stitchEm

// RGBDiff: three int components (x, y, z) holding the chroma-only offsets that
// yuv444ToRGBDiff computes for the R, G and B channels respectively.
#define RGBDiff int3

/**
 * Computes the chroma-dependent part of a YUV -> RGB conversion.
 *
 * Only the U and V samples are needed here; the luma term is not part of the
 * result, so one return value can be shared by every pixel that uses the same
 * chroma sample.
 *
 * Reference equations (http://www.fourcc.org/fccyvrgb.php):
 *   R = 1.164(Y - 16) + 1.596(V - 128)
 *   G = 1.164(Y - 16) - 0.813(V - 128) - 0.391(U - 128)
 *   B = 1.164(Y - 16)                   + 2.018(U - 128)
 *
 * @param u Blue-difference chroma sample (Cb).
 * @param v Red-difference chroma sample (Cr).
 * @return The (V - 128) / (U - 128) terms of the equations above, stored as
 *         x = R offset, y = G offset, z = B offset.
 */
__device__ RGBDiff yuv444ToRGBDiff(unsigned char u, unsigned char v) {
  // Re-centre both chroma samples around zero (range -128..127).
  const int blueChroma = (int)u - 128;
  const int redChroma = (int)v - 128;

  // Fixed-point evaluation: coefficients are scaled by 1000 and the integer
  // division truncates toward zero.
  // FIXME (kept from the original): consider a 1024 scale instead of 1000.
  RGBDiff diff;
  diff.x = (1596 * redChroma) / 1000;
  diff.y = (-813 * redChroma - 391 * blueChroma) / 1000;
  diff.z = (2018 * blueChroma) / 1000;
  return diff;
}

/**
 * Converts an NV12 buffer to RGBA and writes the pixels to a surface.
 *
 * Layout read from src: width * height luma bytes, followed by a chroma plane
 * that is indexed with a row stride of width bytes and stores (U, V) byte
 * pairs. One chroma pair is shared by a 2x2 group of luma samples.
 *
 * Work distribution: each thread converts one 2x2 pixel group. The global x id
 * selects the column pair and the global y id selects the chroma row. Threads
 * that fall outside width x (height / 2) return without touching memory, so
 * the launch grid may be larger than the image.
 *
 * NOTE(review): the right-hand column (x + 1) has no bounds check of its own
 * and height / 2 truncates, so this presumably expects even width and height -
 * confirm against the callers.
 *
 * @param dst    Destination surface, written through nv12_surface_write.
 * @param src    NV12 input buffer.
 * @param width  Image width in pixels; also used as the byte stride of both planes.
 * @param height Image height in pixels.
 */
__global__ void convertNV12ToRGBAKernel(surface_t dst, global_mem const unsigned char* src, unsigned width,
                                        unsigned height) {
  const unsigned leftCol = 2 * get_global_id_x();
  const unsigned chromaRow = get_global_id_y();

  // Nothing to do for threads past the right edge or below the last chroma row.
  if (leftCol >= width || chromaRow >= height / 2) {
    return;
  }

  const unsigned rightCol = leftCol + 1;

  // The chroma plane starts right after the luma plane.
  global_mem const unsigned char* chromaPlane = src + width * height;
  const unsigned chromaIdx = chromaRow * width + leftCol;
  const int3 rgbDiff = yuv444ToRGBDiff(chromaPlane[chromaIdx], chromaPlane[chromaIdx + 1]);

  // Luma offsets of the left pixel on the two image rows covered by this chroma row.
  const unsigned upperIdx = (2 * chromaRow) * width + leftCol;
  const unsigned lowerIdx = (2 * chromaRow + 1) * width + leftCol;

  // Upper row of the 2x2 group.
  nv12_surface_write(YRGBDiffToRGBA(src[upperIdx], rgbDiff), dst, leftCol, 2 * chromaRow);
  nv12_surface_write(YRGBDiffToRGBA(src[upperIdx + 1], rgbDiff), dst, rightCol, 2 * chromaRow);

  // Lower row of the 2x2 group.
  nv12_surface_write(YRGBDiffToRGBA(src[lowerIdx], rgbDiff), dst, leftCol, 2 * chromaRow + 1);
  nv12_surface_write(YRGBDiffToRGBA(src[lowerIdx + 1], rgbDiff), dst, rightCol, 2 * chromaRow + 1);
}