import { Curve, CurveChannels, IFiltersValues, Point } from './interfaces';
import { clamp } from './mathUtils';
import { times } from './utils';

/**
 * Computes the second derivative of a cubic spline at each control point.
 *
 * The spline is "natural": the second derivative is pinned to zero at the
 * first and last point. The interior values come from solving the usual
 * tridiagonal system with one forward and one backward elimination sweep.
 *
 * Points are read in array order; two neighbouring points with the same `x`
 * lead to a division by zero when the right-hand side is built.
 *
 * @param P control points of the curve
 * @returns one second-derivative value per entry of `P`
 */
function secondDerivative(P: Point[]) {
  const count = P.length;

  // Column positions inside each stored row of the tridiagonal matrix.
  const LOWER = 0;
  const DIAG = 1;
  const UPPER = 2;

  const bands = times(count, () => new Float32Array(3));
  const rhs = new Float32Array(count);

  // Boundary rows: 1 * y2 = 0 at both ends of the curve.
  bands[0][DIAG] = 1;
  bands[count - 1][DIAG] = 1;

  // Interior rows couple each point with its two neighbours.
  for (let row = 1; row < count - 1; row++) {
    const before = P[row - 1];
    const here = P[row];
    const after = P[row + 1];
    const gapBefore = here.x - before.x;
    const gapAfter = after.x - here.x;

    bands[row][LOWER] = gapBefore / 6;
    bands[row][DIAG] = (after.x - before.x) / 3;
    bands[row][UPPER] = gapAfter / 6;
    rhs[row] = (after.y - here.y) / gapAfter - (here.y - before.y) / gapBefore;
  }

  // Forward sweep: eliminate the sub-diagonal, top to bottom.
  for (let row = 1; row < count; row++) {
    const above = bands[row - 1];
    const current = bands[row];
    const factor = current[LOWER] / above[DIAG];

    current[DIAG] -= factor * above[UPPER];
    current[LOWER] = 0;
    rhs[row] -= factor * rhs[row - 1];
  }

  // Backward sweep: eliminate the super-diagonal, bottom to top.
  for (let row = count - 2; row >= 0; row--) {
    const below = bands[row + 1];
    const current = bands[row];
    const factor = current[UPPER] / below[DIAG];

    current[DIAG] -= factor * below[LOWER];
    current[UPPER] = 0;
    rhs[row] -= factor * rhs[row + 1];
  }

  // Only the diagonal is left, so each unknown is rhs / diagonal.
  return rhs.map((value, row) => value / bands[row][DIAG]);
}

/**
 * Evaluates the cubic spline on the segment between `points[i]` and
 * `points[i + 1]` at horizontal position `x`.
 *
 * @param points control points of the curve
 * @param i index of the segment's left control point
 * @param x position at which to evaluate the segment
 * @param sd second derivative of the spline at every control point
 * @returns the interpolated value, limited to the range [0, 1] via `clamp`
 */
function calcY(points: Point[], i: number, x: number, sd: Float32Array) {
  const left = points[i];
  const right = points[i + 1];
  const width = right.x - left.x;

  // Relative position inside the segment: 0 at `left`, 1 at `right`.
  const towardsRight = (x - left.x) / width;
  const towardsLeft = 1 - towardsRight;

  // Straight-line blend of the two end points...
  const linearPart = towardsLeft * left.y + towardsRight * right.y;
  // ...plus the cubic correction driven by the second derivatives.
  const bend =
    (towardsLeft * towardsLeft * towardsLeft - towardsLeft) * sd[i] +
    (towardsRight * towardsRight * towardsRight - towardsRight) * sd[i + 1];

  return clamp(linearPart + (width * width / 6) * bend, 0, 1);
}

/**
 * Builds a 256-entry lookup table by sampling the spline through `points`.
 *
 * Entry `i` is sampled at `x = i / 256`. Left of the first point and right of
 * the last point the curve is held flat at that point's `y`. The sampled `y`
 * is stored inverted, as `round((1 - y) * 256)` limited to 0..255.
 *
 * @param points control points of the curve; at least one is required
 * @returns the lookup table, one byte per sample position
 */
export function calculateCurvesWeights(points: Point[]) {
  const tableSize = 256;
  const weights = new Uint8Array(tableSize);
  const sd = secondDerivative(points);

  // Curve value at a single horizontal position.
  const sampleCurve = (x: number): number => {
    const first = points[0];
    const last = points[points.length - 1];

    if (x <= first.x) return first.y;
    if (x >= last.x) return last.y;

    // Walk forward to the segment whose right end lies beyond x.
    let segment = 0;
    while (x >= points[segment + 1].x) segment += 1;
    return calcY(points, segment, x, sd);
  };

  for (let i = 0; i < tableSize; i += 1) {
    const y = sampleCurve(i / tableSize);
    weights[i] = clamp(Math.round((1 - y) * tableSize), 0, tableSize - 1);
  }

  return weights;
}


/**
 * Looks up the output value for `inputColor` in a curve lookup table.
 *
 * A table entry of 0 is a valid result and is returned as-is. The input is
 * passed through unchanged only when the table has no entry at that index.
 *
 * @param colorValues lookup table indexed by input colour value
 * @param inputColor colour value to translate
 * @returns the table entry for `inputColor`, or `inputColor` itself when the
 *   index is outside the table
 */
function getOutputColor(colorValues: Uint8Array, inputColor: number) {
  const mapped = colorValues[inputColor];
  // Compare against undefined rather than relying on truthiness, so that a
  // mapped value of 0 is not mistaken for a missing entry.
  return mapped === undefined ? inputColor : mapped;
}

/**
 * Builds the lookup tables for the first four curves in `curvePoints`.
 *
 * @param curvePoints curves to convert; entries 0 to 3 must be present
 * @returns four lookup tables, in the same order as the input curves
 */
export function getCurveValues(curvePoints: Curve[]) {
  const curveIndices = [0, 1, 2, 3];
  return curveIndices.map((curveIndex) => calculateCurvesWeights(curvePoints[curveIndex].points));
}


/**
 * Applies the colour curves in `values.curvePoints` to RGBA pixel data.
 *
 * Each of the three colour components first goes through the RGB curve and
 * then through its own RED, GREEN or BLUE curve. The fourth component of
 * every pixel is copied unchanged.
 *
 * When `values.curvePoints` is missing or empty there is nothing to apply,
 * so the pixels are copied from `source` to `destination` untouched.
 *
 * @param source input pixels, four values per pixel
 * @param destination receives the adjusted pixels
 * @param values filter settings; only `curvePoints` is read here
 */
export function applyCurves(source: Uint8ClampedArray, destination: Uint8ClampedArray, values: IFiltersValues) {
  const { curvePoints = [] } = values;

  // No curves configured: pass the image through instead of failing while
  // building lookup tables from an empty list.
  if (curvePoints.length === 0) {
    const length = Math.min(source.length, destination.length);
    destination.set(source.subarray(0, length));
    return;
  }

  const curveValues = getCurveValues(curvePoints);
  // Resolve each channel's table once instead of once per pixel.
  const rgbCurve = curveValues[CurveChannels.RGB];
  const redCurve = curveValues[CurveChannels.RED];
  const greenCurve = curveValues[CurveChannels.GREEN];
  const blueCurve = curveValues[CurveChannels.BLUE];

  for (let index = 0; index < source.length; index += 4) {
    // applying RGB Curve
    destination[index] = getOutputColor(rgbCurve, source[index]);
    destination[index + 1] = getOutputColor(rgbCurve, source[index + 1]);
    destination[index + 2] = getOutputColor(rgbCurve, source[index + 2]);
    // applying RED, GREEN, BLUE Curves
    destination[index] = getOutputColor(redCurve, destination[index]);
    destination[index + 1] = getOutputColor(greenCurve, destination[index + 1]);
    destination[index + 2] = getOutputColor(blueCurve, destination[index + 2]);
    destination[index + 3] = source[index + 3];
  }
}