import {GraphModel, loadGraphModel} from '@tensorflow/tfjs-converter';
import * as tf from '@tensorflow/tfjs-core';

import {cropCanvas} from 'gelato/frontend/src/lib/canvas';
import ImageFrame from 'gelato/frontend/src/lib/ImageFrame';
import releaseCanvas from 'gelato/frontend/src/lib/releaseCanvas';
import Detector from 'gelato/frontend/src/ML/detectors/Detector';
import {getModelPath, prepareImageForModel} from 'gelato/frontend/src/ML/utils';

import type {DocumentSide} from '@stripe-internal/data-gelato/schema/types';
import type {Point} from 'gelato/frontend/src/ML/IDDetectorAPI';

// Model folder names for the MicroBlink plausibility models; each is resolved
// to a loadable path with `getModelPath`. `MicroBlinkPlausibilityDetector`
// uses the front folder when its side is 'front' and the back folder otherwise.
const MB_PLAUSIBILITY_BACK_FOLDER =
  'microblink_plausibility_back/1608157922_9361_mb_back_barcode_with_onboarding_bigger';
const MB_PLAUSIBILITY_FRONT_FOLDER =
  'microblink_plausibility_front/1608825585_2931_200_conv_small_front_patch_mid_grayscale_20201221b';

/**
 * Runs a MicroBlink plausibility graph model on a cropped patch of a document
 * image and returns the first value of the model's (squeezed) output.
 *
 * One model is loaded per document side: `'front'` selects the front model
 * folder, every other side selects the back model folder.
 */
export default class MicroBlinkPlausibilityDetector extends Detector<
  [HTMLCanvasElement | ImageFrame, Point, Point, boolean],
  number
> {
  /** Identifier for this instance, suffixed with the document side. */
  public readonly detectorName: string;

  /** Loaded graph model; `undefined` until `_build` has resolved. */
  private mbPlausibilityModel: GraphModel | undefined;

  /** Model folder name, turned into a path with `getModelPath` in `_build`. */
  private readonly mbPlausibilityModelPath: string;

  /** Side length of the square, single-channel dummy input used by `_warmup`. */
  private readonly imageSize: number;

  /**
   * @param side - Document side; `'front'` selects the front model, any other
   *   value selects the back model.
   * @param imageSize - Side length of the `[1, imageSize, imageSize, 1]` zero
   *   tensor fed to the model during warmup.
   */
  constructor(side: DocumentSide, imageSize: number) {
    super();
    this.mbPlausibilityModelPath =
      side === 'front'
        ? MB_PLAUSIBILITY_FRONT_FOLDER
        : MB_PLAUSIBILITY_BACK_FOLDER;
    this.imageSize = imageSize;
    this.detectorName = `MicroBlinkPlausibilityDetector_${side}`;
  }

  /**
   * Runs a single prediction on an all-zero input (presumably so that the
   * first real `_detect` call does not pay one-time backend setup costs).
   *
   * @throws Error if called before `_build` has loaded the model.
   */
  protected async _warmup(): Promise<void> {
    const model = this.mbPlausibilityModel;
    if (!model) {
      throw new Error(`Warmup called before ${this.detectorName} was loaded`);
    }
    // `tf.tidy` disposes every tensor allocated inside the callback: the
    // zero-filled input as well as the model output, whether `predict`
    // returns a single tensor, an array, or a named map.
    tf.tidy(() => {
      model.predict(tf.zeros([1, this.imageSize, this.imageSize, 1]), {
        batchSize: 1,
      });
    });
  }

  /** Loads the graph model for this detector's document side. */
  protected async _build(): Promise<void> {
    this.mbPlausibilityModel = await loadGraphModel(
      getModelPath(this.mbPlausibilityModelPath),
    );
  }

  /**
   * Crops a patch out of `pixelSource`, optionally rotates it, prepares it
   * with `prepareImageForModel(..., 'mobilenet', true)` and runs the model.
   *
   * @param pixelSource - Image the patch is cropped from.
   * @param topLeft - Patch origin; `topLeft[0]` is used as x, `topLeft[1]` as y.
   * @param dimensions - Patch size; `dimensions[0]` is used as height,
   *   `dimensions[1]` as width.
   * @param rotate - When true, the patch is rotated by π/2 with
   *   `tf.image.rotateWithOffset` before inference.
   * @returns The first element of the squeezed model output (presumably a
   *   plausibility score — confirm its range with the model owners).
   * @throws Error if called before `_build` has loaded the model.
   */
  protected async _detect(
    pixelSource: HTMLCanvasElement | ImageFrame,
    topLeft: Point,
    dimensions: Point,
    rotate: boolean,
  ): Promise<number> {
    const model = this.mbPlausibilityModel;
    if (!model) {
      throw new Error(`Detect called before ${this.detectorName} was loaded`);
    }

    // Temporary canvas holding only the patch. It is released in the outer
    // `finally` so it is freed even when inference or read-back fails.
    // NOTE(review): x/y come from topLeft[0]/topLeft[1] while height/width come
    // from dimensions[0]/dimensions[1]; confirm this axis order matches callers.
    const mbPatchCanvas = cropCanvas({
      pixelSource,
      sHeight: dimensions[0],
      sWidth: dimensions[1],
      sx: topLeft[0],
      sy: topLeft[1],
      dHeight: dimensions[0],
      dWidth: dimensions[1],
    });
    try {
      const pred = tf.tidy(() => {
        const image = tf.browser.fromPixels(mbPatchCanvas);
        // rotateWithOffset needs a 4D batch, so add one and then squeeze only
        // that batch axis back out; a patch with a 1-pixel side therefore
        // keeps rank 3, like the unrotated branch.
        // NOTE(review): rotateWithOffset keeps the input shape, so a non-square
        // patch presumably loses its corners when rotated — confirm intended.
        const maybeRotated = rotate
          ? tf.squeeze(
              tf.image.rotateWithOffset(
                // @ts-expect-error - TS2345 - Argument of type 'Tensor<Rank>' is not assignable to parameter of type 'string | number | boolean | Float32Array | Int32Array | Uint8Array | Uint8Array[] | RecursiveArray<boolean> | RecursiveArray<...> | Tensor<...> | RecursiveArray<...>'.
                tf.expandDims(tf.cast(image, 'float32')),
                Math.PI / 2,
              ),
              [0],
            )
          : image;
        const preparedImage = prepareImageForModel(
          maybeRotated,
          'mobilenet',
          true,
        );
        const batchPred = model.predict(preparedImage, {
          batchSize: 1,
        }) as tf.Tensor<tf.Rank>;
        return tf.squeeze(batchPred);
      });
      try {
        // Data returns Float[1] so pull out the actual float score and return.
        const resultArray = await pred.data();
        return resultArray[0];
      } finally {
        // Dispose the escaped result tensor even if the read-back rejects.
        pred.dispose();
      }
    } finally {
      releaseCanvas(mbPatchCanvas);
    }
  }
}
