import * as ort from 'onnxruntime-web';
import { Tensor } from 'onnxruntime-web';
import { EventEmitter } from 'events';
import PostAndReceiveWorker from 'services/PostAndReceiveWorker';
import logger from 'services/logger';
// eslint-disable-next-line import/no-webpack-loader-syntax
import onnxWorker from 'workerize-loader!./onnx.worker';
import isOffscreenCanvasSupported from 'utils/isOffscreenCanvasSupported';
import { getMessageFromError } from 'utils/errorMessage';
import imageDataToTensor from './utils/imageDataToTensor';

/** Image sources accepted by `resizeImage` and `AgePredictor` (anything drawable onto a 2D canvas). */
type InputImage =
  | HTMLImageElement
  | HTMLCanvasElement
  | OffscreenCanvas;

/** Shape that `AgePredictor` casts its prediction results to. */
export interface AgePredictionResult {
  // Presumably the first value of the model output named `pred_a` — confirm against the model export.
  pred_a: number;
}

/**
 * Draws `image` stretched onto a freshly created `width` x `height` canvas and
 * reads the resulting pixels back.
 *
 * The aspect ratio of the source is not preserved: the image is scaled to fill
 * the whole target area.
 *
 * @param image - Source image to scale.
 * @param width - Target width in pixels (defaults to 96).
 * @param height - Target height in pixels (defaults to 96).
 * @returns The resized pixel data, or `undefined` if a 2D context could not be
 *   obtained from the canvas.
 */
async function resizeImage(
  image: InputImage,
  width = 96,
  height = 96,
): Promise<ImageData | undefined> {
  // A detached canvas is created on every call; nothing needs to exist in the page.
  const canvas = document.createElement('canvas');
  canvas.width = width;
  canvas.height = height;

  const ctx = canvas.getContext('2d');
  if (!ctx) {
    return undefined;
  }

  ctx.drawImage(image, 0, 0, width, height);

  // Read back exactly the drawn area. The previous version passed `width` twice,
  // which returned the wrong region for non-square targets.
  return ctx.getImageData(0, 0, width, height);
}

/**
 * Age predictor backed by an ONNX model.
 *
 * Inference can run in a web worker (`init` / `predictAge`) or on the main
 * thread (`initWithoutWebWorker` / `predictAgeWithoutWebWorker`). `init` picks
 * the main-thread path when OffscreenCanvas is not supported.
 */
class AgePredictor extends EventEmitter {
  // Wrapper around the ONNX web worker; used by init, predictAge and terminate.
  onnxWorker: PostAndReceiveWorker<any>;
  // Main-thread inference session; null until initWithoutWebWorker has created it.
  session: ort.InferenceSession | null;
  // Result of the OffscreenCanvas feature check, captured once at construction.
  isOffscreenCanvasSupported: boolean;

  constructor() {
    super();
    this.onnxWorker = new PostAndReceiveWorker({ worker: onnxWorker as Worker, name: 'Onnx' });
    this.session = null;
    this.isOffscreenCanvasSupported = isOffscreenCanvasSupported();
  }

  /**
   * Loads the model.
   *
   * Without OffscreenCanvas support the model is loaded on the main thread via
   * `initWithoutWebWorker`. Otherwise a `loadModel` message is posted to the worker.
   *
   * @returns Whatever the chosen path resolves with (the worker's reply, or
   *   `undefined` for the main-thread path).
   */
  async init() {
    if (!this.isOffscreenCanvasSupported) {
      logger.info('Initializing Onnx AgePredictor without web worker');
      return this.initWithoutWebWorker();
    }
    logger.info('Initializing Onnx AgePredictor');
    return this.onnxWorker.postAndReceiveMessage({
      type: 'loadModel',
    });
  }

  /** Terminates the underlying web worker. The main-thread `session` is left untouched. */
  terminate() {
    logger.info('Terminating Onnx AgePredictor');
    return this.onnxWorker.worker.terminate();
  }

  /**
   * Replaces the worker wrapper with a new one and re-runs `init`.
   *
   * NOTE(review): the current worker is not terminated here; presumably callers
   * invoke `terminate()` first — confirm.
   */
  async freshInit() {
    this.onnxWorker = new PostAndReceiveWorker({ worker: onnxWorker as Worker, name: 'Onnx' });
    return this.init();
  }

  /**
   * Predicts age for `image` using the web worker.
   *
   * The image is resized (96x96 by default) on the main thread before being posted.
   *
   * @returns The prediction, or `null` if resizing fails or the worker reports an error.
   */
  async predictAge(image: InputImage) {
    const resizedImage = await resizeImage(image);

    // Same guard as predictAgeWithoutWebWorker: without a 2D context there is no
    // pixel data, so don't post `undefined` to the worker.
    if (!resizedImage) {
      logger.error('Error resizing image');
      return null;
    }

    const { results, time, error } = await this.onnxWorker.postAndReceiveMessage({
      type: 'predictAge',
      imageData: resizedImage,
    });
    if (error) {
      logger.error('Error predicting age:', error);
      return null;
    }
    logger.debug('Age prediction results:', { results, time });
    return results as AgePredictionResult;
  }

  /**
   * Creates a main-thread inference session (WebGL execution provider) and
   * warms it up with a single inference on an all-zero 96x96 image.
   */
  async initWithoutWebWorker() {
    const pathToModel = 'onnx/ssrnet_2pass_83c05c_-_threshold 27_-_1.65_ybpo_7.35_obpy.ort';
    const s = await ort.InferenceSession.create(pathToModel, {
      executionProviders: ['webgl'],
    });

    this.session = s;

    // Create a blank (zero-filled) ImageData instance to warm up the model.
    const pixelData = new Uint8ClampedArray(96 * 96 * 4);
    const imageData = new ImageData(pixelData, 96, 96);
    const imageTensor = imageDataToTensor(imageData);
    await this.runInference(imageTensor);
  }

  /**
   * Runs the main-thread session on `preprocessedData`, bound to the model's first input.
   *
   * @param preprocessedData - Tensor fed as the model's first input.
   * @returns A tuple of:
   *   - a map from each output name to the first element of that output's data;
   *   - the elapsed inference time in seconds.
   * @throws Error if the model has not been loaded, or if it declares no inputs.
   */
  async runInference(preprocessedData: Tensor): Promise<[any, number]> {
    // Capture the session once so the whole run uses the same instance across the await.
    const { session } = this;
    if (session === null) {
      throw new Error('Model is not loaded yet');
    }
    const inputName = session.inputNames[0];
    if (inputName === undefined) {
      throw new Error('Model has no inputs');
    }

    const feeds: ort.InferenceSession.FeedsType = { [inputName]: preprocessedData };

    // performance.now() is monotonic and sub-millisecond, unlike Date.
    const start = performance.now();
    const outputData = await session.run(feeds);
    const inferenceTime = (performance.now() - start) / 1000;

    // Keep only the first element of each output tensor.
    const output: Record<string, any> = {};
    for (const outputName of session.outputNames) {
      output[outputName] = outputData[outputName]?.data?.at(0);
    }

    return [output, inferenceTime];
  }

  /**
   * Predicts age for `image` on the main thread (requires `initWithoutWebWorker`).
   *
   * @returns The prediction, or `null` if resizing or inference fails. Errors are
   *   logged and not rethrown.
   */
  async predictAgeWithoutWebWorker(image: InputImage) {
    try {
      const resizedImage = await resizeImage(image);

      if (!resizedImage) {
        logger.error('Error resizing image');
        return null;
      }

      const imageTensor = imageDataToTensor(resizedImage);
      const [results, time] = await this.runInference(imageTensor);
      logger.debug('Age prediction results without worker:', { results, time });
      return results as AgePredictionResult;
    } catch (error) {
      const errorMessage = getMessageFromError(error);
      logger.error('Error in predicting age without worker', { errorMessage });
      return null;
    }
  }
}

export default AgePredictor;
