import * as tf from '@tensorflow/tfjs-core';

import {
  cropCanvas,
  getPixelSourceDimensions,
} from 'gelato/frontend/src/lib/canvas';
import ImageFrame from 'gelato/frontend/src/lib/ImageFrame';
import releaseCanvas from 'gelato/frontend/src/lib/releaseCanvas';
import IDProbabilityDetector from 'gelato/frontend/src/ML/detectors/IDProbabilityDetector';
import {prepareImageForModel} from 'gelato/frontend/src/ML/utils';

import type {
  ModelProbabilities,
  Rectangle,
} from 'gelato/frontend/src/ML/IDDetectorAPI';

/**
 * Draws `pixelSource` onto a new `imageSize` x `imageSize` canvas, preserving
 * the source's aspect ratio.
 *
 * The longer side of the source is scaled to `imageSize`; the shorter side is
 * scaled by the same factor. The drawn image is centered and the remaining
 * area is filled with black.
 *
 * @param pixelSource - Canvas or frame to draw from.
 * @param imageSize - Side length, in pixels, of the square output canvas.
 * @returns A newly created square canvas (from `cropCanvas`).
 */
export function resizeCanvasForProcessing(
  pixelSource: HTMLCanvasElement | ImageFrame,
  imageSize: number,
): HTMLCanvasElement {
  const {sourceWidth, sourceHeight} = getPixelSourceDimensions(pixelSource);

  // Strictly wider sources pin the width; everything else (including square
  // sources) pins the height.
  const isWiderThanTall = sourceWidth > sourceHeight;
  const dWidth = isWiderThanTall
    ? imageSize
    : (imageSize * sourceWidth) / sourceHeight;
  const dHeight = isWiderThanTall
    ? (imageSize * sourceHeight) / sourceWidth
    : imageSize;

  return cropCanvas({
    backgroundColor: 'black',
    pixelSource,
    dHeight,
    dWidth,
    height: imageSize,
    width: imageSize,
    center: true,
  });
}

/**
 * Converts a canvas into the model input tensor expected by the
 * `'mobilenet'` preprocessing in `prepareImageForModel`.
 *
 * The intermediate pixel tensor is always disposed, even if preprocessing
 * throws. The caller owns (and must dispose) the returned tensor.
 */
function canvasToModelInput(canvas: HTMLCanvasElement) {
  const image = tf.browser.fromPixels(canvas);
  try {
    return prepareImageForModel(image, 'mobilenet');
  } finally {
    image.dispose();
  }
}

/**
 * Runs the ID probability model on `pixelSource`.
 *
 * The source is letterboxed to an `imageSize` x `imageSize` canvas, converted
 * to a tensor, and passed to `idProbabilityDetector.detect`. All temporary
 * resources are released on every path: the resized canvas, the raw pixel
 * tensor and the prepared model input.
 *
 * @param pixelSource - Canvas or frame to analyze.
 * @param idProbabilityDetector - Detector that produces class probabilities
 *   and a bounding box.
 * @param imageSize - Side length of the square canvas fed to the model.
 * @returns The per-class probabilities and the detected location. The
 *   probabilities are mapped by index as follows:
 *   - 0 → `noDocument`
 *   - 1 → `frontPassport`
 *   - 2 → `frontCard`
 *   - 3 → `back`
 *   - 4 → `invalid`
 */
export async function findIdInImage(
  pixelSource: HTMLCanvasElement | ImageFrame,
  idProbabilityDetector: IDProbabilityDetector,
  imageSize: number,
): Promise<{
  probability: ModelProbabilities;
  location: Rectangle;
}> {
  const resizedPixelSource = resizeCanvasForProcessing(pixelSource, imageSize);

  let preparedImage: ReturnType<typeof canvasToModelInput>;
  try {
    preparedImage = canvasToModelInput(resizedPixelSource);
  } finally {
    // The canvas is only needed to build the tensor; free it even on failure.
    releaseCanvas(resizedPixelSource);
  }

  try {
    const {probabilities, topLeft, dimensions} =
      await idProbabilityDetector.detect(preparedImage);

    return {
      probability: {
        noDocument: probabilities[0],
        frontPassport: probabilities[1],
        frontCard: probabilities[2],
        back: probabilities[3],
        invalid: probabilities[4],
      },
      location: {topLeft, dimensions},
    };
  } finally {
    // Without this the model input leaks backend (e.g. GPU) memory on every
    // call. tf.dispose skips tensors that are already disposed, so this is
    // safe even if detect() disposes its input itself.
    tf.dispose(preparedImage);
  }
}

export default findIdInImage;
