import {GraphModel, loadGraphModel} from '@tensorflow/tfjs-converter';
import * as tf from '@tensorflow/tfjs-core';

import Detector from 'gelato/frontend/src/ML/detectors/Detector';
import {getModelPath} from 'gelato/frontend/src/ML/utils';

import type {Point} from 'gelato/frontend/src/ML/IDDetectorAPI';

// Folder of the ID detector model; passed to getModelPath when the graph model
// is loaded.
const DETECTOR_MODEL_FOLDER = 'id_detectors/iddetectorssd_apr23_wrapped';

/** Output of a single IDProbabilityDetector detection. */
type Results = {
  // Per-class values read from the model output, in model order:
  // no document, passport, ID card front, ID card back, invalid ID.
  probabilities: Float32Array | Int32Array | Uint8Array;
  // First two coordinate outputs of the model (normalized to [0, 1]).
  // NOTE(review): presumably [x, y] of the box's top-left corner — confirm.
  topLeft: Point;
  // Third and fourth coordinate outputs of the model (normalized to [0, 1]).
  // NOTE(review): presumably [width, height] of the box — confirm.
  dimensions: Point;
};

/**
 * Detector that runs the ID-document graph model on a prepared image tensor
 * and returns the predicted box coordinates together with per-class
 * probabilities.
 *
 * The model is loaded by `_build`, exercised once by `_warmup`, and queried by
 * `_detect`. NOTE(review): these hooks are presumably invoked by the
 * `Detector` base class — confirm against its implementation.
 */
export default class IDProbabilityDetector extends Detector<
  [tf.Tensor<tf.Rank>],
  Results
> {
  public readonly detectorName = 'IDProbabilityDetector';

  private idProbabilityModel: GraphModel | undefined;

  private readonly imageSize: number;

  /**
   * @param imageSize Side length of the square, 3-channel input used to warm
   *   up the model.
   */
  constructor(imageSize: number) {
    super();
    this.imageSize = imageSize;
  }

  /**
   * Runs the model once on an all-zero input of shape
   * `[1, imageSize, imageSize, 3]` and releases both the input and the output.
   *
   * @throws Error if the model has not been loaded yet.
   */
  protected async _warmup() {
    if (!this.idProbabilityModel) {
      throw new Error(`Warmup called before ${this.detectorName} was loaded`);
    }

    // Keep a handle on the input so it can be released; creating it inline in
    // the executeAsync call would leak it.
    const warmupInput = tf.zeros([1, this.imageSize, this.imageSize, 3]);
    try {
      const output = await this.idProbabilityModel.executeAsync(warmupInput);
      // tf.dispose accepts a single tensor as well as an array of tensors.
      tf.dispose(output);
    } finally {
      warmupInput.dispose();
    }
  }

  /** Loads the graph model from the detector's model folder. */
  protected async _build() {
    this.idProbabilityModel = await loadGraphModel(
      getModelPath(DETECTOR_MODEL_FOLDER),
    );
  }

  /**
   * Runs the model on `preparedImage` and extracts coordinates and class
   * probabilities from its output.
   *
   * Once inference has been attempted, `preparedImage` is disposed by this
   * method, whether or not inference succeeded. Every intermediate tensor is
   * released even when an error is thrown part-way through.
   *
   * @param preparedImage Input tensor handed to the model.
   * @returns The class probabilities plus the `topLeft` and `dimensions`
   *   points read from the first four model outputs.
   * @throws Error if the model has not been loaded yet.
   */
  protected async _detect(preparedImage: tf.Tensor<tf.Rank>) {
    const model = this.idProbabilityModel;
    if (!model) {
      throw new Error(`Detect called before ${this.detectorName} was loaded`);
    }

    let coordinateLogitsBatch: tf.Tensor<tf.Rank>;
    try {
      coordinateLogitsBatch = (await model.executeAsync(
        preparedImage,
      )) as tf.Tensor<tf.Rank>;
    } finally {
      // The input is no longer needed, even if inference failed.
      preparedImage.dispose();
    }

    try {
      const {coordinatesNormedTensor, probabilitiesTensor} = tf.tidy(() => {
        // The batch dimension has only one component; squeeze drops it.
        const coordinateLogits = tf.squeeze(coordinateLogitsBatch);

        // Current model coordinates are normalized to [0, 1]
        const coordinateCount = 4;
        const coordinatesNormedTensor = tf.slice(
          coordinateLogits,
          [0],
          [coordinateCount],
        );

        // Probability is last 5 outputs:
        // 0 = No Document,
        // 1 = passport
        // 2 = ID card front
        // 3 = ID card back
        // 4 = Invalid ID
        const classes = 5;
        const probabilitiesTensor = tf.slice(
          coordinateLogits,
          [coordinateCount],
          [classes],
        );

        return {coordinatesNormedTensor, probabilitiesTensor};
      });

      try {
        const coordinatesNormed = await coordinatesNormedTensor.data();
        const topLeft: Point = [coordinatesNormed[0], coordinatesNormed[1]];
        const dimensions: Point = [coordinatesNormed[2], coordinatesNormed[3]];

        const probabilities = await probabilitiesTensor.data();

        return {probabilities, topLeft, dimensions};
      } finally {
        // Tensors returned from tf.tidy are not cleaned up automatically.
        coordinatesNormedTensor.dispose();
        probabilitiesTensor.dispose();
      }
    } finally {
      coordinateLogitsBatch.dispose();
    }
  }
}
