import {GraphModel, loadGraphModel} from '@tensorflow/tfjs-converter';
import * as tf from '@tensorflow/tfjs-core';

import releaseCanvas from 'gelato/frontend/src/lib/releaseCanvas';
import Detector from 'gelato/frontend/src/ML/detectors/Detector';
import {resizeCanvasForProcessing} from 'gelato/frontend/src/ML/lib/findIdInImage';
import {prepareImageForModel, getModelPath} from 'gelato/frontend/src/ML/utils';

// Model identifier passed to `getModelPath` in `_build` to locate the
// MicroBlink support graph model. NOTE(review): the trailing date looks like a
// model version stamp — confirm before changing it.
const MB_SUPPORT_PATH = 'microblink_support_model/mbsupport_2022-10-30';

/**
 * Scores a canvas with the MicroBlink "support" graph model.
 *
 * The input canvas is resized to a square of `imageSize` pixels, converted to
 * a tensor, prepared with `prepareImageForModel(image, 'imagenet', false)`,
 * and run through the model with a batch size of 1. `_detect` resolves to the
 * first value of the squeezed model output.
 */
export default class MicroBlinkSupportDetector extends Detector<
  [HTMLCanvasElement],
  number
> {
  public readonly detectorName: string = 'MicroBlinkSupportDetector';

  /** Loaded graph model; `undefined` until `_build` has completed. */
  private mbSupportModel: GraphModel | undefined;

  /** Side length, in pixels, of the square input fed to the model. */
  private readonly imageSize: number;

  /**
   * @param imageSize - Side length, in pixels, of the square image used for
   *   both warmup and detection.
   */
  constructor(imageSize: number) {
    super();
    this.imageSize = imageSize;
  }

  /**
   * Runs a single prediction on an all-zeros `[1, imageSize, imageSize, 3]`
   * tensor and discards the result (presumably so the first real `_detect`
   * call does not pay one-time backend setup costs).
   *
   * @throws Error if called before `_build` has loaded the model.
   */
  protected async _warmup(): Promise<void> {
    // Local binding keeps the non-undefined narrowing inside the closure below.
    const model = this.mbSupportModel;
    if (!model) {
      throw new Error(`Warmup called before ${this.detectorName} was loaded`);
    }
    // `tf.tidy` disposes every tensor created inside it — both the zeros input
    // (previously leaked) and whatever `predict` returns — even if it throws.
    tf.tidy(() => {
      model.predict(tf.zeros([1, this.imageSize, this.imageSize, 3]), {
        batchSize: 1,
      });
    });
  }

  /** Loads the graph model from the location resolved by `getModelPath`. */
  protected async _build(): Promise<void> {
    this.mbSupportModel = await loadGraphModel(getModelPath(MB_SUPPORT_PATH));
  }

  /**
   * Scores `pixelSource` with the support model.
   *
   * @param pixelSource - Canvas to score; it is passed through
   *   `resizeCanvasForProcessing` to a square of `imageSize` pixels first.
   * @returns The first element of the squeezed model output.
   * @throws Error if called before `_build` has loaded the model, or if the
   *   model output contains no values.
   */
  protected async _detect(pixelSource: HTMLCanvasElement): Promise<number> {
    const model = this.mbSupportModel;
    if (!model) {
      throw new Error(`Detect called before ${this.detectorName} was loaded`);
    }

    const squareCanvas = resizeCanvasForProcessing(pixelSource, this.imageSize);
    try {
      // Intermediate tensors (pixels, prepared image, batched prediction) are
      // released by tidy; only the squeezed result escapes it.
      const pred = tf.tidy(() => {
        const image = tf.browser.fromPixels(squareCanvas);
        const preparedImage = prepareImageForModel(image, 'imagenet', false);
        const batchPred = model.predict(preparedImage, {
          batchSize: 1,
        }) as tf.Tensor<tf.Rank>;
        return tf.squeeze(batchPred);
      });
      try {
        // Data is expected to be Float[1]; pull out the actual float score.
        const resultArray = await pred.data();
        const score = resultArray[0];
        if (score === undefined) {
          throw new Error(`${this.detectorName} produced an empty prediction`);
        }
        return score;
      } finally {
        // Free backend memory even if the readback rejects.
        pred.dispose();
      }
    } finally {
      // Release the processing canvas even if inference or readback throws.
      releaseCanvas(squareCanvas);
    }
  }
}
