import * as tf from '@tensorflow/tfjs-core';

import {
  resizeCanvas,
  getPixelSourceDimensions,
} from 'gelato/frontend/src/lib/canvas';
import releaseCanvas from 'gelato/frontend/src/lib/releaseCanvas';
import Detector from 'gelato/frontend/src/ML/detectors/Detector';

export default class LightDetector extends Detector<
  [HTMLCanvasElement],
  number
> {
  public readonly detectorName = 'LightDetector';

  /**
   * Determines what luminance values are considered dark.
   *
   * This is a float value in the range [0, 1]. It is passed as the threshold
   * value to `tf.image.threshold`, which compares it against each pixel's
   * grayscale value.
   */
  luminanceUpperBound: number;

  /**
   * The max dimension of given images. If an image has a
   * dimension larger than this, it will be resized proportionally
   * before analysis, limiting the number of pixels processed.
   */
  maxDimension: number;

  constructor(luminanceUpperBound: number, maxDimension: number) {
    super();
    this.luminanceUpperBound = luminanceUpperBound;
    this.maxDimension = maxDimension;
  }

  // This detector only runs tensor ops on its input; there is nothing to
  // load or build ahead of time.
  protected async _build() {}

  /**
   * Downscales `pixelSource` so its larger side is at most `maxDimension`.
   *
   * @param pixelSource the canvas to check
   * @returns `pixelSource` itself when it already fits; otherwise a new canvas
   *   created by `resizeCanvas`, which the caller owns and must release.
   */
  private resizeIfNeeded(pixelSource: HTMLCanvasElement): HTMLCanvasElement {
    const {sourceWidth, sourceHeight} = getPixelSourceDimensions(pixelSource);
    return Math.max(sourceWidth, sourceHeight) > this.maxDimension
      ? resizeCanvas(pixelSource, this.maxDimension)
      : pixelSource;
  }

  /**
   * Returns a score indicating how dark an image is. The score
   * is the proportion of pixels in the image whose luminance
   * value (ie. grayscale value) is at or below the luminance upper bound
   * (inverted binary thresholding via `tf.image.threshold`).
   *
   * The temporary resized canvas (if any) and the result tensor are released
   * even when tensor evaluation throws or rejects.
   *
   * @param pixelSource an image
   * @returns a score in the range [0, 1]
   */
  protected async _detect(pixelSource: HTMLCanvasElement): Promise<number> {
    // If `resizedPixelSource` is different from `pixelSource`, it was created
    // by this call and must be released in the `finally` below.
    const resizedPixelSource = this.resizeIfNeeded(pixelSource);
    try {
      // Compute darknessScore
      const resultTensor = tf.tidy(() => {
        const image = tf.browser.fromPixels(resizedPixelSource);
        // Inverted binary threshold: dark pixels become 255, others 0.
        const darkPixelMask = tf.image.threshold(
          image,
          'binary',
          true,
          this.luminanceUpperBound,
        );
        // Map {0, 255} -> {0, 1} so that summing counts dark pixels.
        const binarizedDarkPixelMask = tf.div(darkPixelMask, 255);
        const numDarkPixels = tf.sum(binarizedDarkPixelMask);
        // The mask is a single grayscale channel, so `size` is the pixel count.
        return tf.div(numDarkPixels, darkPixelMask.size);
      });

      try {
        const resultArray = await resultTensor.data();
        return resultArray[0];
      } finally {
        // `tf.tidy` keeps the tensor it returns alive; dispose it ourselves,
        // including when `data()` rejects.
        resultTensor.dispose();
      }
    } finally {
      if (resizedPixelSource !== pixelSource) {
        // `resizedPixelSource` was created locally, clear it.
        releaseCanvas(resizedPixelSource);
      }
    }
  }
}
