import * as Sentry from '@sentry/browser';
import * as tf from '@tensorflow/tfjs-core';

import {getConfigValue} from 'gelato/frontend/src/lib/config';
import {isIOS15OrGreater, isIOS} from 'gelato/frontend/src/lib/device';
import {reportMetric} from 'gelato/frontend/src/lib/metricsBatcher';
import {getPlatform} from 'gelato/frontend/src/lib/userAgent';

import type {Point} from 'gelato/frontend/src/ML/IDDetectorAPI';

// Placeholder alias for a TensorFlow.js tensor; currently untyped (`any`).
// TODO: import from @tensorflow/tfjs
export type Tensor = any;

// Grayscale channel weights. Left unset at module load and created on first
// use by prepareImageForModel.
// @ts-expect-error - TS7034 - Variable 'RGB_COEF' implicitly has type 'any' in some locations where its type cannot be determined.
let RGB_COEF;

// Using the HTML5 canvas element to verify WebGL support.
// A single detached canvas is created at module load and reused by supportsWebGL.
const WEBGL_TEST_CANVAS = document.createElement('canvas');

// Host that static assets are served from; see getAssetPath / getModelPath.
const prefix = getConfigValue('ASSETS_CDN_PREFIX');
// use window.location.origin if prefix is empty (i.e. local dev)
const staticHost = prefix !== '' ? prefix : window.location.origin;

// Minimal shape of a loaded model as far as this module's consumers are
// concerned: anything exposing a `predict` method. Input and output are
// untyped (`any`).
export type TensorFlowModel = {
  predict: (arg1?: any) => any;
};

// Which detector a measurement belongs to. Used as the `model_type` metric tag
// and in the breadcrumb message emitted by maybeReportTFLogger.
export type ModelType = 'face' | 'id';

// Arguments accepted by maybeReportTFLogger.
type ReportTFUsageParams<T> = {
  // The detection routine to run; it receives `pixelSource` and resolves to the result.
  detectionFunction: (arg1: HTMLCanvasElement) => Promise<T>;
  // Canvas handed to `detectionFunction` unchanged.
  pixelSource: HTMLCanvasElement;
  // Model the measurement is attributed to.
  type: ModelType;
  // When false, `detectionFunction` is called without any memory measurement or reporting.
  shouldReportMetrics: boolean;
};

/**
 * Always runs `detectionFunction` on `pixelSource` and returns its result.
 *
 * When `shouldReportMetrics` is true, the TFJS memory counters are sampled
 * immediately before and after the call and the difference is reported:
 * the byte delta as a histogram metric, and both the byte and tensor deltas
 * as a Sentry breadcrumb. Nothing is reported if `detectionFunction` rejects,
 * because the rejection propagates before the second sample is taken.
 */
export async function maybeReportTFLogger<T>({
  detectionFunction,
  pixelSource,
  type,
  shouldReportMetrics,
}: ReportTFUsageParams<T>): Promise<T> {
  // Fast path: no measurement requested.
  if (!shouldReportMetrics) {
    return detectionFunction(pixelSource);
  }

  const memoryBefore = tf.memory();
  const detection = await detectionFunction(pixelSource);
  const memoryAfter = tf.memory();

  // Whatever is still allocated after the call counts as "left over".
  const numLeakedMemory = memoryAfter.numBytes - memoryBefore.numBytes;
  const numLeakedTensors = memoryAfter.numTensors - memoryBefore.numTensors;

  // Report to SignalFX
  const tags = [
    {key: 'model_type', value: type},
    {key: 'platform', value: getPlatform()},
  ];
  reportMetric({
    metric: 'gelato_frontend_tfjs_leftover_memory_bytes',
    operation: 'histogram',
    value: numLeakedMemory,
    tags,
  });

  // Report to Sentry
  Sentry.addBreadcrumb({
    category: 'TFJSMemory',
    message: `reporting TFUsage for ${type}`,
    level: Sentry.Severity.Info,
    data: {numLeakedMemory, numLeakedTensors},
  });

  return detection;
}

/**
 * Reports whether the browser offers a WebGL setup that TFJS can use.
 *
 * TFJS needs WebGL2, or WebGL1 together with the OES_texture_float extension.
 * We test this here because TFJS will load the WebGL engine even if only
 * WebGL1 is detected and OES_texture_float is not supported.
 *
 * @returns `true` when a WebGL2 context can be created, or when a WebGL1
 *   context can be created and exposes OES_texture_float; `false` otherwise.
 *   Always a real boolean (never the extension object or `null`).
 */
export const supportsWebGL = (): boolean => {
  const gl2Ctx = WEBGL_TEST_CANVAS.getContext('webgl2');
  if (gl2Ctx) {
    return true;
  }

  const glCtx = WEBGL_TEST_CANVAS.getContext('webgl');

  // getExtension returns the extension object or null; coerce so the declared
  // boolean return type is actually honoured.
  return !!glCtx && !!glCtx.getExtension('OES_texture_float');
};

/**
 * Loads the TFJS WebGL backend, activates it and runs a small probe op to
 * detect broken GPU shaders.
 *
 * @returns `true` when the backend was set and the probe completed; `false`
 *   when `tf.setBackend('webgl')` reported failure or anything threw (the
 *   error is recorded as a Sentry breadcrumb and not rethrown).
 */
const setWebGLBackend = async (): Promise<boolean> => {
  try {
    // Explicitly load webgl async. Webpack sometime removes it from the main bundle anyways and
    // might as well make the initial load faster.
    await import('@tensorflow/tfjs-backend-webgl');
    const webGLSupported = await tf.setBackend('webgl');
    if (!webGLSupported) {
      return false;
    }

    // HACK to try and detect GPU shader issue: https://github.com/tensorflow/tfjs/issues/952#issuecomment-468364914
    // This will explode and we will try WASM instead
    const probeTensors: Array<ReturnType<typeof tf.zeros>> = [];
    try {
      const zeros = tf.zeros([1, 65536, 1, 2]);
      probeTensors.push(zeros);
      const halves = tf.split(zeros, 2, 3);
      probeTensors.push(...halves);
      await halves[0].data();
    } finally {
      // Release the source tensor and BOTH split halves, also when the probe
      // fails; previously only the first half was disposed, and only on success.
      tf.dispose(probeTensors);
    }

    return true;
  } catch (error: unknown) {
    // A rejection is not guaranteed to be an Error instance; avoid touching
    // `.message` on arbitrary values (null/undefined would throw here).
    const reason = error instanceof Error ? error.message : String(error);
    Sentry.addBreadcrumb({
      category: 'ML',
      level: Sentry.Severity.Info,
      message: `Error loading WEBGL ${reason}`,
    });
    return false;
  }
};

/**
 * Loads the TFJS WASM backend, registers the locations of its three `.wasm`
 * binaries and activates the backend.
 *
 * The elapsed time (imports + `tf.setBackend`) is reported as the
 * `gelato_frontend_tfjs_load_time` timing metric whenever `tf.setBackend`
 * resolves, regardless of whether it resolved to true or false.
 *
 * @returns the boolean result of `tf.setBackend('wasm')`.
 * @throws whatever the dynamic imports or `tf.setBackend` reject with.
 */
const loadWasm = async (): Promise<boolean> => {
  const loadStartTime = Date.now();

  // https://github.com/tensorflow/tfjs/blob/master/tfjs-backend-wasm/starter/webpack/README.md
  // The backend module and the three binaries are independent of each other,
  // so request them together instead of one after another.
  const [wasm, wasmPath, wasmSimdPath, wasmSimdThreadedPath] =
    await Promise.all([
      import('@tensorflow/tfjs-backend-wasm'),
      import('gelato/frontend/src/ML/tfjs-wasm/tfjs-backend-wasm.wasm'),
      import('gelato/frontend/src/ML/tfjs-wasm/tfjs-backend-wasm-simd.wasm'),
      import(
        'gelato/frontend/src/ML/tfjs-wasm/tfjs-backend-wasm-threaded-simd.wasm'
      ),
    ]);

  wasm.setWasmPaths({
    'tfjs-backend-wasm.wasm': wasmPath.default,
    'tfjs-backend-wasm-simd.wasm': wasmSimdPath.default,
    'tfjs-backend-wasm-threaded-simd.wasm': wasmSimdThreadedPath.default,
  });

  const result = await tf.setBackend('wasm');

  reportMetric({
    metric: 'gelato_frontend_tfjs_load_time',
    operation: 'timing',
    value: Date.now() - loadStartTime,
  });

  return result;
};

/**
 * Initializes TensorFlow.js and sets the backend to WASM or WebGL.
 *
 * Initialization is attempted only while `window.tfjsInitialized` is
 * undefined. The flag is set to `true` afterwards whether the attempt
 * succeeded or failed, so later calls do nothing.
 *
 * Backend order:
 * - WASM preferred (everything except iOS below 15): WASM, then WebGL.
 * - Otherwise: WebGL, then WASM.
 *
 * Metrics: `gelato_frontend_webgl_supported` when the WebGL backend was set,
 * `gelato_frontend_tfjs_supported` when any backend was set, and
 * `gelato_frontend_tfjs_not_supported` when initialization failed.
 *
 * @returns `true` if the WASM backend was set, `false` if WebGL is used
 *   instead, and `undefined` if `window.tfjsInitialized` was already set
 *   (no work is done in that case).
 * @throws Error('WebGL+WASM not supported') when neither backend could be
 *   set; errors thrown while loading a backend are rethrown as-is.
 */
export const initializeTensorflow = async (): Promise<
  boolean | null | undefined
> => {
  // Wasm is slower on iPhone but faster on Desktop&Android with lower startup latency.
  const preferWasm = isIOS15OrGreater() || !isIOS();
  // @ts-expect-error - TS2339 - Property 'tfjsInitialized' does not exist on type 'Window & typeof globalThis'.
  if (typeof window.tfjsInitialized === 'undefined') {
    let hasWasm = false;
    try {
      if (preferWasm) {
        hasWasm = await loadWasm();
      }
      if (!hasWasm) {
        const hasWebGL = supportsWebGL() && (await setWebGLBackend());

        if (hasWebGL) {
          // Emit only when the WebGL backend was actually selected. Previously
          // this metric fired when WASM had loaded and WebGL was never probed.
          reportMetric({
            metric: 'gelato_frontend_webgl_supported',
            operation: 'count',
            value: 1,
          });
        } else if (!preferWasm) {
          // WebGL was tried first and failed; fall back to WASM.
          hasWasm = await loadWasm();
        }

        if (!hasWasm && !hasWebGL) {
          throw Error('WebGL+WASM not supported');
        }
      }
      reportMetric({
        metric: 'gelato_frontend_tfjs_supported',
        operation: 'count',
        value: 1,
      });
    } catch (error: unknown) {
      reportMetric({
        metric: 'gelato_frontend_tfjs_not_supported',
        operation: 'count',
        value: 1,
      });

      throw error;
    } finally {
      // Mark as initialized even on failure so the attempt is not repeated.
      // @ts-expect-error - TS2339 - Property 'tfjsInitialized' does not exist on type 'Window & typeof globalThis'.
      window.tfjsInitialized = true;
      if (process.env.PUPPETEER) {
        // Hack to expose TFJS memory function for testing.
        // @ts-expect-error - TS2339 - Property 'tfjsMemoryFunction' does not exist on type 'Window & typeof globalThis'.
        window.tfjsMemoryFunction = tf.memory;
      }
    }
    return hasWasm;
  }

  // Already initialized (or a previous attempt failed): nothing to do.
  return undefined;
};

/**
 * Adjusts the point `p` to account for the image offset when a non-square
 * input is fed into a model.
 *
 * We paint the input image into the center of the model's square input, so
 * the coordinates returned by the model need to be shifted by the padding
 * added along the smaller dimension.
 *
 * @param width - width of the original image
 * @param height - height of the original image
 * @param p - point in the square's coordinate space
 * @returns the point with half of `|width - height|` subtracted from the
 *   coordinate of the shorter side (y when wider than tall, x otherwise)
 */
export function adjustForScaleToSquare(
  width: number,
  height: number,
  p: Point,
): Point {
  const [x, y] = p;
  const delta = width - height;
  // delta is <= 0 in the second branch, so adding it moves x towards zero.
  return delta > 0 ? [x, y - delta / 2] : [x + delta / 2, y];
}
/**
 * Maps a point from one image size to another, rounding each coordinate to
 * the nearest integer.
 *
 * @param p - point expressed in `currentImageDims` coordinates
 * @param currentImageDims - `[width, height]` the point is currently relative to
 * @param targetImageDims - `[width, height]` to express the point in
 * @returns the rescaled, rounded point
 */
export function scaleCoords(
  p: Point,
  currentImageDims: Point,
  targetImageDims: Point,
): Point {
  const [x, y] = p;
  const [sourceWidth, sourceHeight] = currentImageDims;
  const [destWidth, destHeight] = targetImageDims;

  const scaledX = Math.round((x * destWidth) / sourceWidth);
  const scaledY = Math.round((y * destHeight) / sourceHeight);
  return [scaledX, scaledY];
}

// Pixel value normalization schemes understood by prepareImageForModel:
// - 'mobilenet': value / 127.5 - 1
// - 'imagenet': (value / 255 - per-channel mean) / per-channel stddev
export type PixelValueScaling = 'mobilenet' | 'imagenet';

/**
 * Turns an image tensor into a float32 model input with a leading batch
 * dimension of 1, optionally converting to grayscale and rescaling values.
 *
 * NOTE(review): the grayscale path sums over axis 2, so `image` is presumably
 * laid out as [height, width, channels] — confirm against callers.
 *
 * @param image - the image tensor to prepare
 * @param scaling_type - 'mobilenet' or 'imagenet' normalization; any other
 *   value (including null/undefined) only casts to float32
 * @param grayscale - when truthy, channels are weighted by RGB_COEF and summed
 *   into a single trailing channel before batching
 * @returns the prepared tensor; everything runs inside `tf.tidy`
 */
export function prepareImageForModel(
  image: Tensor,
  scaling_type?: PixelValueScaling | null,
  grayscale: boolean | null = false,
) {
  // The grayscale weights are created once, outside of tf.tidy, and reused by
  // every later call.
  // @ts-expect-error - TS7005 - Variable 'RGB_COEF' implicitly has an 'any' type.
  if (!RGB_COEF) {
    RGB_COEF = tf.tensor1d([0.2989, 0.587, 0.114]);
  }

  return tf.tidy(() => {
    let source = image;
    if (grayscale) {
      // @ts-expect-error - TS7005 - Variable 'RGB_COEF' implicitly has an 'any' type.
      const weighted = tf.mul(image, RGB_COEF);
      source = tf.expandDims(tf.sum(weighted, 2), -1);
    }

    // Expand the outer most dimension so we have a batch size of 1.
    const batched = tf.expandDims(source, 0);
    const asFloat = tf.cast(batched, 'float32');

    if (scaling_type === 'mobilenet') {
      // Normalize the image between -1 and 1. The image comes in between 0-255,
      // so we divide by 127.5 and subtract 1.
      // This matches the preprocessing we do in keras
      // https://github.com/keras-team/keras-applications/blob/43ac53e491fab09b9d938dadeee1e82c56d5d25c/keras_applications/imagenet_utils.py#L121
      return tf.sub(tf.div(asFloat, 127.5), 1);
    }

    if (scaling_type === 'imagenet') {
      const channelMean = tf.tensor1d([0.485, 0.456, 0.406]);
      const channelStdDev = tf.tensor1d([0.229, 0.224, 0.225]);
      const unitRange = tf.div(asFloat, 255);
      return tf.div(tf.sub(unitRange, channelMean), channelStdDev);
    }

    // No scaling requested: hand back the batched float image as-is.
    return asFloat;
  });
}

/**
 * Builds the absolute URL of a static asset.
 *
 * @param assetRelativePath - path of the asset below the `/assets/` directory
 * @returns `<staticHost>/assets/<assetRelativePath>`
 */
export function getAssetPath(assetRelativePath: string): string {
  const assetsRoot = `${staticHost}/assets`;
  return `${assetsRoot}/${assetRelativePath}`;
}

/**
 * Builds the absolute URL of a model's `model.json` manifest.
 *
 * @param folder - asset folder that contains the model
 * @returns `<staticHost>/assets/<folder>/model.json`
 */
export function getModelPath(folder: string): string {
  // A model manifest is just another asset, so reuse the asset URL builder.
  return getAssetPath(`${folder}/model.json`);
}

/**
 * Pushes `newValue` onto the front of `values` (mutating it) and, if the
 * array is then longer than `maxEntries`, drops the single oldest entry from
 * the back.
 *
 * At most one entry is removed per call, so an array that was already over
 * the limit stays over the limit.
 *
 * @param values - the queue to update in place, newest entry first
 * @param newValue - the value to insert at index 0
 * @param maxEntries - length above which the last entry is removed
 */
export function addToLifoQueue(
  values: Array<number>,
  newValue: number,
  maxEntries: number,
) {
  // unshift returns the array's new length.
  const lengthAfterInsert = values.unshift(newValue);
  if (lengthAfterInsert > maxEntries) {
    values.pop();
  }
}

/**
 * Computes the population variance (mean of squared deviations from the
 * mean, dividing by the number of values) of `values`.
 *
 * @param values - the samples
 * @returns the variance, or 0 when fewer than two values are given
 */
export function variance(values: Array<number>): number {
  const count = values.length;
  if (count < 2) {
    return 0;
  }

  let total = 0;
  values.forEach((value) => {
    total += value;
  });
  const mean = total / count;

  let squaredDeviationTotal = 0;
  values.forEach((value) => {
    const deviation = value - mean;
    squaredDeviationTotal += deviation * deviation;
  });

  return squaredDeviationTotal / count;
}
