import '@tensorflow/tfjs-backend-webgl';
import * as tf from '@tensorflow/tfjs-core';

import {cropCanvas} from 'gelato/frontend/src/lib/canvas';
import ImageFrame from 'gelato/frontend/src/lib/ImageFrame';
import releaseCanvas from 'gelato/frontend/src/lib/releaseCanvas';
import Detector from 'gelato/frontend/src/ML/detectors/Detector';

import type {Point, Rectangle} from 'gelato/frontend/src/ML/IDDetectorAPI';
import type {Tensor} from 'gelato/frontend/src/ML/utils';

/**
 * Scores how sharp (non-blurry) an image region is using the variance of its
 * Laplacian response.
 *
 * The region described by `location` is cropped out of the pixel source and
 * resampled to `patchDimensions`. It is then converted to greyscale and
 * convolved with a 3x3 Laplacian kernel. The population variance of that
 * response is divided by `blurThreshold`. Higher scores mean more
 * high-frequency detail (edges), i.e. a sharper image.
 */
export default class BlurDetector extends Detector<
  [HTMLCanvasElement | ImageFrame, Rectangle, Point],
  number
> {
  public readonly detectorName = 'BlurDetector';

  /** Rank-4 Laplacian filter of shape [3, 3, 1, 1]; allocated in `_build`. */
  private laplacianConvFilter: Tensor | undefined;

  /** Divisor applied to the raw Laplacian variance to normalize the score. */
  private readonly blurThreshold: number;

  /** Forwarded to `cropCanvas` when the region is resampled. */
  private readonly smoothingEnabled: boolean;

  /**
   * @param blurThreshold - Divisor for the raw Laplacian variance; the
   *   returned score is `variance / blurThreshold`.
   * @param smoothingEnabled - Passed through to `cropCanvas` as its
   *   `smoothingEnabled` option.
   */
  constructor(blurThreshold: number, smoothingEnabled: boolean) {
    super();
    this.blurThreshold = blurThreshold;
    this.smoothingEnabled = smoothingEnabled;
  }

  /** Allocates the Laplacian convolution filter used by `_detect`. */
  protected async _build() {
    // 4-neighbour Laplacian:
    //   0  1  0
    //   1 -4  1
    //   0  1  0
    const laplacianKernel = [0, 1, 0, 1, -4, 1, 0, 1, 0];
    // tf.conv2d expects a rank-4 filter:
    // [filterHeight, filterWidth, inChannels, outChannels].
    this.laplacianConvFilter = tf.tensor(
      laplacianKernel,
      [3, 3, 1, 1],
      'float32',
    );
  }

  /**
   * Computes the normalized blur score for one region of `pixelSource`.
   *
   * @param pixelSource - Canvas or frame to read pixels from.
   * @param location - Region to crop. `dimensions` is read as
   *   [height, width] while `topLeft` is read as [x, y].
   *   NOTE(review): the two tuples use opposite axis orders here — confirm
   *   this matches how `Rectangle` is produced upstream.
   * @param patchDimensions - Size the crop is resampled to, read as
   *   [height, width].
   * @returns The Laplacian variance divided by `blurThreshold`, or 0 when the
   *   response is empty (patch smaller than 3x3) or not finite.
   * @throws Error if `_build` has not populated the convolution filter.
   */
  protected async _detect(
    pixelSource: HTMLCanvasElement | ImageFrame,
    location: Rectangle,
    patchDimensions: Point,
  ): Promise<number> {
    const filter = this.laplacianConvFilter;
    if (!filter) {
      throw new Error('BlurDetector: _build() must complete before _detect().');
    }

    // tf.tidy disposes every intermediate tensor; only the returned one
    // survives and must be disposed explicitly below.
    const centered = tf.tidy(() => {
      const patch = cropCanvas({
        pixelSource,
        sHeight: location.dimensions[0],
        sWidth: location.dimensions[1],
        sx: location.topLeft[0],
        sy: location.topLeft[1],
        dHeight: patchDimensions[0],
        dWidth: patchDimensions[1],
        smoothingEnabled: this.smoothingEnabled,
      });
      try {
        const image = tf.browser.fromPixels(patch);
        // Greyscale = mean over the channel axis; keepDims keeps a trailing
        // channel dimension of size 1.
        const greyscale = tf.mean(image, 2, true);
        // conv2d expects a batched input, so add a leading batch dimension.
        const batched = tf.expandDims(greyscale, 0);
        const convolved = tf.conv2d(
          // @ts-expect-error - TS2345 - Argument of type 'Tensor<Rank>' is not assignable to parameter of type 'TensorLike | Tensor3D | Tensor4D'.
          batched,
          filter,
          [1, 1],
          'valid',
        );
        // Flatten the Laplacian response to 1D.
        const response = tf.reshape(convolved, [-1]);
        // HACK: only center the response on the GPU. Computing the variance in
        // TF on iOS Safari is unstable, yields infinities and crashes Mobile
        // Safari, so the sum of squares is done in JS below.
        // See: https://github.com/tensorflow/tfjs/issues/2475
        return tf.sub(response, tf.mean(response));
      } finally {
        // Free the temporary crop canvas whether or not the tensor ops succeed.
        releaseCanvas(patch);
      }
    });

    try {
      const values = await centered.data();
      // Defensive guard: on some platforms a non-trivial share of uploads has
      // reported a blur of exactly 0 while other models ran fine. An empty
      // response (patch smaller than the 3x3 kernel) would otherwise divide
      // 0 by 0 and yield NaN.
      if (!values || values.length === 0) {
        return 0;
      }
      let sumOfSquares = 0;
      for (let i = 0; i < values.length; i++) {
        sumOfSquares += values[i] * values[i];
      }
      // A sum of squares that overflowed or hit NaN stays non-finite, so a
      // single check after the loop is equivalent to checking every step.
      if (!Number.isFinite(sumOfSquares)) {
        return 0;
      }
      // Population variance of the centered response, normalized.
      return sumOfSquares / this.blurThreshold / values.length;
    } finally {
      // Dispose on every path (early returns and rejections included) to
      // avoid leaking backend (e.g. WebGL) memory.
      centered.dispose();
    }
  }
}
