import {GraphModel, loadGraphModel} from '@tensorflow/tfjs-converter';
import * as tf from '@tensorflow/tfjs-core';
import {produce} from 'immer';

import {
  Feedback,
  InputMethod,
} from 'gelato/frontend/src/controllers/states/DocumentState';
import getDocumentTypeByDetectionResult from 'gelato/frontend/src/lib/getDocumentTypeByDetectionResult';
import ImageFrame from 'gelato/frontend/src/lib/ImageFrame';
import BaseInspector from 'gelato/frontend/src/ML/detectors/BaseInspector';
import {
  IMAGE_SIZE,
  PROBABILTY_THRESHOLD,
  MINIMUM_PROBABILTY_THRESHOLD,
  MINIMUM_PASSPORT_FRONT_PROBABILTY_THRESHOLD,
  STRONG_IDDETECTOR_PREDICTION_THRESHOLD,
} from 'gelato/frontend/src/ML/lib/constants';
import {getModelPath, prepareImageForModel} from 'gelato/frontend/src/ML/utils';

import type {
  FeedbackValue,
  InspectionState,
  IDInspectorResult,
} from 'gelato/frontend/src/controllers/states/DocumentState';
import type {ApplicationState} from 'gelato/frontend/src/controllers/types';
import type {
  ModelProbabilities,
  Rectangle,
  Point,
} from 'gelato/frontend/src/ML/IDDetectorAPI';

// The model folder, passed to `getModelPath` when the model is loaded.
// It should contain the following files:
// gelato/frontend/public/assets/id_detectors/iddetectorssd_apr23_wrapped/model.json
const MODEL_FOLDER = 'id_detectors/iddetectorssd_apr23_wrapped';

// The minimum probability threshold for a document to be considered valid.
// Note that this value must not be lower than PROBABILTY_THRESHOLD, which
// is used to determine if a document is present regardless of its validity;
// the `Math.max` below enforces that even if PROBABILTY_THRESHOLD is raised.
// Based on this Hubble query:
// https://hubble.corp.stripe.com/queries/hedger/f3cf39cc/simple-pivot-percent-barchart
// We pick a threshold of 0.69 to optimize for the best tradeoff between the
// "IDP verified" rate and the "web completion" rate.
export const MIN_VALID_PROBABILITY = Math.max(PROBABILTY_THRESHOLD, 0.69);

// The singleton instance. Created lazily by `IDInspector.getInstance()` and
// reset to null when that instance is disposed.
let instance: IDInspector | null = null;

/**
 * The ID Inspector.
 * It inspects the source image and identifies the document type and location.
 *
 * The inspector is used as a lazily created singleton (see `getInstance`).
 * All tensors created while warming up or inspecting are released before the
 * corresponding method settles, including when an intermediate step throws.
 */
export default class IDInspector extends BaseInspector<
  [Readonly<ApplicationState>, Readonly<InspectionState>],
  Readonly<InspectionState>
> {
  static displayName = 'IDInspector';

  /**
   * The ID Probability Model. Null until `buildImpl` has loaded it, and null
   * again once `disposeImpl` has released it.
   */
  private _model: GraphModel | null = null;

  /**
   * Whether the inspector is supported in the current environment.
   */
  static isSupported(): boolean {
    // IDInspector uses "tfjs" which requires WebGL.
    return BaseInspector.isWebGLSupported();
  }

  /**
   * Returns the shared inspector, creating it on first use (or after the
   * previous instance has been disposed).
   */
  static getInstance(): IDInspector {
    if (!instance) {
      instance = new IDInspector();
    }
    return instance;
  }

  constructor() {
    super(IDInspector.displayName);
  }

  /**
   * Sets up TensorFlow and loads the graph model from `MODEL_FOLDER`.
   * @implements {BaseInspector}
   */
  protected async buildImpl(): Promise<void> {
    await BaseInspector.setUpTensorflow();
    this._model = await loadGraphModel(getModelPath(MODEL_FOLDER));
  }

  /**
   * Runs the model once on an all-zeros input of the expected shape and
   * discards the result.
   * @implements {BaseInspector}
   */
  protected async warmUpImpl(): Promise<void> {
    const model = this._model!;

    // The input tensor is owned by this method; release it even if the model
    // execution fails so that warming up never leaks memory.
    const input = tf.zeros([1, IMAGE_SIZE, IMAGE_SIZE, 3]);
    try {
      const output = await model.executeAsync(input);

      if (output instanceof tf.Tensor) {
        output.dispose();
      } else if (Array.isArray(output)) {
        output.forEach((t) => t.dispose());
      }
    } finally {
      input.dispose();
    }
  }

  /**
   * Inspects `inspectionState.inputImage` and returns a new inspection state
   * with `idInspectorResult` set to the outcome.
   * @implements {BaseInspector}
   */
  protected async detectImpl(
    appState: Readonly<ApplicationState>,
    inspectionState: Readonly<InspectionState>,
  ): Promise<Readonly<InspectionState>> {
    const {inputImage} = inspectionState;
    const result = await this._inspect(appState, inputImage);

    return produce(inspectionState, (draft) => {
      draft.idInspectorResult = result;
    });
  }

  /**
   * Releases the model and, if this is the shared instance, clears the
   * singleton so that `getInstance` creates a fresh inspector next time.
   * @implements {BaseInspector}
   */
  protected async disposeImpl(): Promise<void> {
    this._model?.dispose();
    this._model = null;

    if (instance === this) {
      // The instance is being disposed and no longer usable. Clear the
      // singleton instance.
      instance = null;
    }
  }

  /**
   * Inspect the source image and return the result.
   * @param appState The application state; supplies the working side and is
   *   forwarded to `computeFeedback`.
   * @param sourceImage The source image to inspect.
   * @returns The result of the inspection. This is never null: when the
   *   image is a placeholder or has no pixel source, the result carries null
   *   `documentType`, `feedback`, `location` and `probability` with
   *   `isValid` set to false.
   * @throws Error if the model output is not a single tensor.
   */
  private async _inspect(
    appState: Readonly<ApplicationState>,
    sourceImage: ImageFrame,
  ): Promise<IDInspectorResult> {
    const model = this._model!;

    // For the sake of model needs, we'll convert the image into a square
    // IMAGE_SIZE x IMAGE_SIZE image and inspect it.

    // 1. Obtain the source image captured by the camera or file upload.
    // ┌──────────────────┐
    // │                  │
    // │                  │
    // │                  │
    // └──────────────────┘
    // 2. Scale the image down so that it fits into the model input size.
    // ┌────────┐
    // │        │
    // └────────┘
    // 3. Place the image in the center of a square box.
    // ┌────────┐
    // │ padding│
    // ├────────┤
    // │        │
    // ├────────┤
    // │ padding│
    // └────────┘
    // 4. Compute the location of the document in the square box.
    // 5. Convert the location to the original image coordinates.

    const inspectedImage = await sourceImage.clone();

    // Compute the scale to fit the image into the model input size.
    const scale = inspectedImage.computeFitToScale(IMAGE_SIZE, IMAGE_SIZE);

    // Get the image that could fit into the model input size.
    const scaledImage = await inspectedImage.scaleTo(scale);

    // Get the square image that the model expects.
    const squareImage = await scaledImage.toSquare();

    // Capture the geometry needed for the coordinate conversion below while
    // both frames are still alive, so that the math never depends on what a
    // disposed frame reports.
    const {width: scaledWidth, height: scaledHeight} = scaledImage;
    const {width: sw, height: sh} = squareImage;
    scaledImage.dispose();

    const pixelSource = squareImage.getSource();

    if (squareImage.placeholder || !pixelSource) {
      // Do not inspect the placeholder or empty image.
      squareImage.dispose();
      return {
        documentType: null,
        feedback: null,
        inspectedImage,
        isValid: false,
        location: null,
        probability: null,
      };
    }

    const tfImage = tf.browser.fromPixels(pixelSource);
    squareImage.dispose();

    // Both helpers take ownership of the tensor they are given and release it
    // (and everything they create) before settling.
    const coordinateLogitsBatch = await this._runDetectorModel(model, tfImage);
    const {coordinatesNormed, probabilities} = await this._readDetectorOutput(
      coordinateLogitsBatch,
    );

    const topLeft: Point = [coordinatesNormed[0], coordinatesNormed[1]];
    const dimensions: Point = [coordinatesNormed[2], coordinatesNormed[3]];

    const probability: ModelProbabilities = {
      noDocument: probabilities[0],
      frontPassport: probabilities[1],
      frontCard: probabilities[2],
      back: probabilities[3],
      invalid: probabilities[4],
    };

    const documentType =
      getDocumentTypeByDetectionResult({
        probability,
      }) || null;

    // Convert the normalized coordinates to the original image coordinates.
    let [lx, ly] = topLeft;
    let [dw, dh] = dimensions;

    // Convert the percentage to the actual pixel values.
    lx *= sw;
    ly *= sh;
    dw *= sw;
    dh *= sh;

    // Remove the padding that was added to make the image square.
    const paddingY = (sh - scaledHeight) / 2;
    const paddingX = (sw - scaledWidth) / 2;
    lx -= paddingX;
    ly -= paddingY;

    // Scale the coordinates back to the original image size.
    const reverseScale = 1 / scale;
    lx *= reverseScale;
    ly *= reverseScale;
    dw *= reverseScale;
    dh *= reverseScale;

    const location: Rectangle = {
      topLeft: [Math.floor(lx), Math.floor(ly)],
      dimensions: [Math.ceil(dw), Math.ceil(dh)],
    };

    const feedback = computeFeedback(appState, probability);

    const side = appState.document.workingSide;

    // The document is only valid when there is no feedback to show AND the
    // class matching the working side clears the stricter validity threshold.
    let isValid = false;
    if (!feedback) {
      if (side === 'front') {
        isValid =
          probability.frontCard >= MIN_VALID_PROBABILITY ||
          probability.frontPassport >= MIN_VALID_PROBABILITY;
      } else {
        isValid = probability.back >= MIN_VALID_PROBABILITY;
      }
    }

    return {
      documentType,
      feedback,
      inspectedImage,
      isValid,
      location,
      probability,
    };
  }

  /**
   * Prepares the pixel tensor for the model and executes the model on it.
   * Takes ownership of `tfImage`: it and the prepared model input are
   * disposed before this method settles, whether it succeeds or throws.
   * @param model The loaded graph model.
   * @param tfImage The pixel tensor of the square image.
   * @returns The single output tensor; the caller must dispose it.
   * @throws Error if the model does not return a single tensor. Any tensors
   *   in an unexpected array output are disposed before the error propagates.
   */
  private async _runDetectorModel(
    model: GraphModel,
    tfImage: tf.Tensor3D,
  ): Promise<tf.Tensor> {
    try {
      const modelImage = prepareImageForModel(tfImage, 'mobilenet');
      try {
        const output = await model.executeAsync(modelImage);
        if (output instanceof tf.Tensor) {
          return output;
        }

        try {
          // This should never happen.
          throw new Error(`Expected Tensor, got ${output}`);
        } finally {
          // The message above is built first; only then release the tensors.
          if (Array.isArray(output)) {
            output.forEach((t) => t.dispose());
          }
        }
      } finally {
        modelImage.dispose();
      }
    } finally {
      tfImage.dispose();
    }
  }

  /**
   * Splits the model output into box coordinates and class probabilities and
   * downloads both as plain numeric arrays. Takes ownership of
   * `coordinateLogitsBatch`: it and the intermediate slices are disposed
   * before this method settles, whether it succeeds or throws.
   * @param coordinateLogitsBatch The output tensor of the model.
   * @returns `coordinatesNormed` (the first 4 values) and `probabilities`
   *   (the following 5 values).
   */
  private async _readDetectorOutput(coordinateLogitsBatch: tf.Tensor) {
    try {
      const {coordinatesNormedTensor, probabilitiesTensor} = tf.tidy(() => {
        // The batch dimension has only one component; squeeze drops every
        // dimension of size 1 so the values can be sliced as a flat vector.
        const coordinateLogits = tf.squeeze(coordinateLogitsBatch);

        // Current model coordinates are normalized to [0, 1]
        const coordinatesNormedTensor = tf.slice(coordinateLogits, [0], [4]);

        // Probability is last 5 outputs:
        // 0 = No Document,
        // 1 = passport
        // 2 = ID card front
        // 3 = ID card back
        // 4 = Invalid ID
        const classes = 5;
        const probabilitiesTensor = tf.slice(coordinateLogits, [4], [classes]);

        return {coordinatesNormedTensor, probabilitiesTensor};
      });

      try {
        const coordinatesNormed = await coordinatesNormedTensor.data();
        const probabilities = await probabilitiesTensor.data();
        return {coordinatesNormed, probabilities};
      } finally {
        coordinatesNormedTensor.dispose();
        probabilitiesTensor.dispose();
      }
    } finally {
      coordinateLogitsBatch.dispose();
    }
  }
}

/**
 * Decide which feedback, if any, to show for a set of model probabilities.
 *
 * The score of the side currently being captured is `frontCard +
 * frontPassport` for the front and `back` for the back. When that score is
 * below `PROBABILTY_THRESHOLD`, the most specific explanation is returned (no
 * document, invalid document, wrong side). Otherwise the only remaining check
 * is the passport-only allowlist.
 *
 * @param appState The application state; the working side, the working input
 *   method and the session's document type allowlist are read from it.
 * @param modelProbabilities The per-class probabilities produced by the model.
 * @returns The feedback. Null if no feedback is needed.
 */
export function computeFeedback(
  appState: Readonly<ApplicationState>,
  modelProbabilities: ModelProbabilities,
): FeedbackValue | null {
  const expectsFront = appState.document.workingSide === 'front';
  const frontScore =
    modelProbabilities.frontCard + modelProbabilities.frontPassport;

  // HACK - we allow "invalid" backs b/c there are a lot of weird international
  // documents that we have no data for the backs. So for now allow invalid
  // backs.
  const expectedSideScore = expectsFront
    ? frontScore
    : modelProbabilities.back;

  if (expectedSideScore < PROBABILTY_THRESHOLD) {
    // The expected side was not detected; work out the most helpful reason.

    // In auto capture mode, tell user to move the document into view when no
    // document is detected.
    let fallback;
    if (appState.document.workingInputMethod === InputMethod.cameraAutoCapture) {
      fallback = Feedback.moveIntoView;
    } else {
      fallback = Feedback.noDocument;
    }

    // First check if there's energy in "noDocument".
    if (modelProbabilities.noDocument >= MINIMUM_PROBABILTY_THRESHOLD) {
      return fallback;
    }

    if (modelProbabilities.invalid > PROBABILTY_THRESHOLD) {
      return Feedback.invalidDocument;
    }

    if (expectsFront) {
      // Capturing the front but the model sees a back.
      return modelProbabilities.back >= MINIMUM_PROBABILTY_THRESHOLD
        ? Feedback.useBackSide
        : fallback;
    }

    // Capturing the back from here on.
    if (frontScore >= STRONG_IDDETECTOR_PREDICTION_THRESHOLD) {
      // When model is very confident that this is a front, then hard block.
      return Feedback.useFrontSideBlock;
    }

    if (frontScore >= MINIMUM_PROBABILTY_THRESHOLD) {
      return Feedback.useFrontSide;
    }

    // Default to the no-document feedback if we are in some edge case where
    // energy is distributed ~evenly.
    return fallback;
  }

  const allowlist = appState.session?.documentTypeAllowlist;
  const isPassportOnly = allowlist?.length === 1 && allowlist[0] === 'passport';

  // TODO(hedger): Maybe we should run this check only with a separate
  // experiment since we don't know how it will affect the conversion rate.
  // if we block passports that are not detected with high confidence.
  if (
    isPassportOnly &&
    modelProbabilities.frontPassport <
      MINIMUM_PASSPORT_FRONT_PROBABILTY_THRESHOLD
  ) {
    return Feedback.usePassport;
  }

  return null;
}
