import * as tf from '@tensorflow/tfjs-core';
import {produce} from 'immer';

import {Feedback} from 'gelato/frontend/src/controllers/states/DocumentState';
import ImageFrame from 'gelato/frontend/src/lib/ImageFrame';
import BaseInspector from 'gelato/frontend/src/ML/detectors/BaseInspector';
import {
  BLUR_THRESHOLD,
  BLUR_IMAGE_SIZE,
} from 'gelato/frontend/src/ML/lib/constants';

import type {InspectionState} from 'gelato/frontend/src/controllers/states/DocumentState';
import type {ApplicationState} from 'gelato/frontend/src/controllers/types';
import type {Tensor} from 'gelato/frontend/src/ML/utils';

// Minimum blur score a document image must reach; `detectImpl()` flags images
// scoring below it as too blurry.
// Based on this Hubble query:
// https://hubble.corp.stripe.com/queries/hedger/ee147edf/simple-pivot-percent-barchart
// We pick a threshold of 3.0 for the blur score to optimize for the best
// tradeoff between "IDP verified" rate and "web completion" rate.
const MIN_BLUR_SCORE = 3;

// The singleton instance. Created lazily by `BlurInspector.getInstance()` and
// reset to `null` when that instance is disposed.
let instance: BlurInspector | null = null;

/**
 * The inspector that inspects the blur score for the document.
 * The score ranges from 0 to 250, where 0 means the blur score isn't
 * available and 250 means the document is the least blurry.
 */
export default class BlurInspector extends BaseInspector<
  [Readonly<ApplicationState>, Readonly<InspectionState>],
  Readonly<InspectionState>
> {
  // 3x3 Laplacian kernel stored as a rank-4 convolution filter. Created in
  // `buildImpl()` and released in `disposeImpl()`; `null` outside that window.
  private _laplacianConvFilter: Tensor | null = null;

  static displayName = 'BlurInspector';

  /**
   * Whether the inspector is supported in the current environment.
   */
  static isSupported(): boolean {
    // BlurInspector uses "tfjs" which requires WebGL.
    return BaseInspector.isWebGLSupported();
  }

  /**
   * Returns the shared inspector, creating it on first use. A new instance is
   * created again after the previous one has been disposed.
   */
  static getInstance(): BlurInspector {
    if (!instance) {
      instance = new BlurInspector();
    }
    return instance;
  }

  constructor() {
    super(BlurInspector.displayName);
  }

  /**
   * Sets up TensorFlow and allocates the Laplacian convolution filter.
   * @implements {BaseInspector}
   */
  protected async buildImpl(): Promise<void> {
    await BaseInspector.setUpTensorflow();

    const laplacianKernel = [0, 1, 0, 1, -4, 1, 0, 1, 0];
    // Conv filter should be of rank 4
    this._laplacianConvFilter = tf.tensor(
      laplacianKernel,
      [3, 3, 1, 1],
      'float32',
    );
  }

  /**
   * @implements {BaseInspector}
   */
  protected async warmUpImpl(): Promise<void> {
    // Do nothing.
  }

  /**
   * Computes the blur score of the located document and, when the score is
   * below `MIN_BLUR_SCORE`, marks an otherwise valid document as too blurry.
   *
   * @param appState The application state (unused by this inspector).
   * @param inspectionState The inspection state to read from.
   * @returns The input state itself for valid testmode QR results, otherwise
   *   a new state carrying the updated `documentBlurScore` (and feedback).
   * @implements {BaseInspector}
   */
  protected async detectImpl(
    appState: Readonly<ApplicationState>,
    inspectionState: Readonly<InspectionState>,
  ): Promise<Readonly<InspectionState>> {
    const {inputImage, documentLocation} = inspectionState;

    // Testmode QR doesn't need it
    if (inspectionState.testmodeJsQRDetectionResult?.isValid) {
      return inspectionState;
    }

    if (!documentLocation) {
      // It's likely that the document is not ready yet. Reset the blur
      // score.
      return produce(inspectionState, (draft) => {
        draft.documentBlurScore = 0;
      });
    }

    // Compute the blur score on the document region, grown by a small
    // padding on every side.
    const {topLeft, dimensions} = documentLocation;
    const padding = 2;
    // NOTE(review): the cropped frame is not disposed here. Confirm whether
    // `ImageFrame.crop()` allocates a frame that the caller must release.
    const sourceImage = await inputImage.crop(
      topLeft[0] - padding,
      topLeft[1] - padding,
      dimensions[0] + padding * 2,
      dimensions[1] + padding * 2,
    );
    const blurScore = await this.computeBlurScore(sourceImage);

    return produce(inspectionState, (draft) => {
      draft.documentBlurScore = blurScore;

      // If the blur score is too low, reject the image. Only do so when the
      // document is currently valid and no other feedback has been set, so
      // that existing feedback is not overwritten.
      if (
        blurScore < MIN_BLUR_SCORE &&
        !draft.feedback &&
        draft.documentIsValid &&
        draft.documentImage &&
        draft.documentLocation
      ) {
        // Reject the image that isn't sharp enough.
        draft.documentIsValid = false;
        draft.feedback = Feedback.tooBlurry;
      }
    });
  }

  /**
   * Releases the convolution filter and clears the singleton if it is this
   * instance.
   * @implements {BaseInspector}
   */
  protected async disposeImpl(): Promise<void> {
    this._laplacianConvFilter?.dispose();
    this._laplacianConvFilter = null;

    if (instance === this) {
      // The instance is being disposed and no longer usable. Clear the
      // singleton instance.
      instance = null;
    }
  }

  /**
   * Compute the blur score for the given image.
   *
   * The score is the mean of the squared, mean-centered Laplacian response of
   * the greyscale image, divided by `BLUR_THRESHOLD`.
   *
   * @param sourceImage The image to score.
   * @returns The blur score, or 0 when it cannot be computed (no pixel
   *   source, no values, or a non-finite intermediate sum).
   * @throws {Error} If the inspector has not been built (no Laplacian filter).
   */
  async computeBlurScore(sourceImage: ImageFrame): Promise<number> {
    // Center the image in a 224x224 frame as a cover, cropping overflow.
    // This aligns with legacy BlurDetector behavior.
    // Future consideration: Maybe switch to `fitToContain()` to avoid cropping.
    const sampleImage = await sourceImage.fitToCover(
      BLUR_IMAGE_SIZE,
      BLUR_IMAGE_SIZE,
    );

    const centered = this._computeCenteredLaplacian(sampleImage);
    if (!centered) {
      return 0;
    }

    try {
      // Compute the variance from centered series.
      const values = await centered.data();
      // Something is going on with Blur.
      // On some platforms a non-trivial % of uploads have "0" blur, despite the other models
      // running fine.
      // An empty series would otherwise divide by zero and yield NaN.
      if (!values || values.length === 0) {
        return 0;
      }

      let sum = 0;
      for (const value of values) {
        sum += value * value;
        if (!isFinite(sum)) {
          return 0;
        }
      }
      return sum / BLUR_THRESHOLD / values.length;
    } finally {
      // Release the tensor on every exit path, including the early returns
      // above and a rejected `data()` call.
      centered.dispose();
    }
  }

  /**
   * Applies the Laplacian filter to the greyscale version of `sampleImage`
   * and subtracts the mean of the response.
   *
   * Takes ownership of `sampleImage`: it is disposed before this method
   * returns or throws.
   *
   * @param sampleImage The resized image to convolve.
   * @returns The flattened, mean-centered response (the caller must dispose
   *   it), or `null` when the image has no pixel source.
   * @throws {Error} If the Laplacian filter has not been built.
   */
  private _computeCenteredLaplacian(sampleImage: ImageFrame): tf.Tensor | null {
    try {
      const source = sampleImage.getSource();
      if (!source) {
        return null;
      }

      const laplacianFilter = this._laplacianConvFilter;
      if (!laplacianFilter) {
        throw new Error(
          'BlurInspector: the Laplacian filter is not available; the inspector must be built before computing a blur score.',
        );
      }

      return tf.tidy(() => {
        const tfImage = tf.browser.fromPixels(source);

        // Convert to grey scale
        const imageGreyscale = tf.expandDims(tf.mean(tfImage, 2), 2);
        tfImage.dispose();

        // Add batch dimension
        const imageGreyscaleBatched = tf.expandDims(imageGreyscale, 0);
        imageGreyscale.dispose();

        // Now compute the blur index - apply convolution, squeeze
        // NOTE(review): `arraySync()` round-trips the pixels through the CPU
        // before the convolution. Presumably this is deliberate; confirm
        // before passing the tensor to `conv2d` directly.
        const imageConvolved = tf.reshape(
          tf.squeeze(
            tf.conv2d(
              imageGreyscaleBatched.arraySync(),
              laplacianFilter,
              [1, 1],
              'valid',
            ),
          ),
          [-1],
        );
        imageGreyscaleBatched.dispose();

        // and find standard deviation
        // Now normalize: for non blurry it is > 1.0
        // HACK- this is insane, but computing the variance in TF on iOS Safari
        // is unstable and yields infinities and crashes Mobile Safari.
        // See: https://github.com/tensorflow/tfjs/issues/2475
        // So only the centering happens in TF; the caller sums the squares.
        const mean = tf.mean(imageConvolved);
        const rank = tf.sub(imageConvolved, mean);
        imageConvolved.dispose();

        return rank;
      });
    } finally {
      // The pixels are no longer needed once `tf.tidy` has run (or when the
      // work is skipped or fails).
      sampleImage.dispose();
    }
  }
}
