import * as tf from '@tensorflow/tfjs';
import { DocTypes, INVC_VAT, INVC_VAT_NEXT, PAYMENT_CONFIRMATION } from './docTypes';
import { tokenizeWithWordIndex } from './tokenizer';
import { wordIndex } from './wordIndex';

/**
 * Maps the arg-max class index of the multi-class model's output to a document type.
 *
 * Previously both branches tested `highestClassIndex` for truthiness. Every non-zero index
 * returned PAYMENT_CONFIRMATION, and INVC_VAT_NEXT was unreachable.
 *
 * NOTE(review): the index -> label order below is inferred, not confirmed. It assumes the
 * multi-class model keeps the binary model's encoding (0 = INVC_VAT, 1 = PAYMENT_CONFIRMATION)
 * and adds INVC_VAT_NEXT as class 2. Verify against the training label encoder.
 *
 * @param highestClassIndex index of the highest-scoring class in the model output
 * @returns the matching document type; any unrecognised index (e.g. -1) falls back to INVC_VAT
 */
const getClassNameMulti = (highestClassIndex: number) => {
  switch (highestClassIndex) {
    case 1:
      return PAYMENT_CONFIRMATION;
    case 2:
      return INVC_VAT_NEXT;
    default:
      return INVC_VAT;
  }
};

/**
 * Classifies a document with the multi-class model and returns its document type.
 *
 * The text is tokenized against `wordIndex` (the 250 argument is presumably the max sequence
 * length; confirm in `tokenizeWithWordIndex`). It is then run through the model loaded from
 * `model/model.json`, and the index of the highest score is mapped via `getClassNameMulti`.
 *
 * All tensors and the loaded model are disposed before returning. The original never freed
 * them, so each call leaked the input tensor, the prediction tensor and the model weights.
 *
 * NOTE(review): this loads the same `model/model.json` as `classifyDocBinary`. Confirm both
 * classifiers are really meant to share one model file.
 * NOTE(review): the model is re-loaded on every call; consider caching it if this is hot.
 *
 * @param doc raw document text
 * @returns the predicted document type
 */
export const classifyDocMulti = async (doc: string): Promise<DocTypes> => {
  const words = tokenizeWithWordIndex(doc, wordIndex, 250);
  const model = await tf.loadLayersModel('model/model.json');
  let predictionTensor: tf.Tensor2D | undefined;
  try {
    // tidy() frees the intermediate input tensor; only the returned prediction survives it.
    predictionTensor = tf.tidy(() => model.predict(tf.tensor2d([words])) as tf.Tensor2D);
    const [scores = []] = await predictionTensor.array();
    const highestClassIndex = scores.indexOf(Math.max(...scores));
    return getClassNameMulti(highestClassIndex);
  } finally {
    predictionTensor?.dispose();
    model.dispose();
  }
};




/**
 * Classifies a document with a single-score model. If the first output value is >= 0.5 the
 * document is PAYMENT_CONFIRMATION; otherwise it is INVC_VAT.
 *
 * The text is tokenized against `wordIndex` (the 250 argument is presumably the max sequence
 * length; confirm in `tokenizeWithWordIndex`) and run through the model loaded from
 * `model/model.json`.
 *
 * All tensors and the loaded model are disposed before returning. The original never freed
 * them, so each call leaked the input tensor, the prediction tensor and the model weights.
 *
 * NOTE(review): this loads the same `model/model.json` as `classifyDocMulti`. Confirm both
 * classifiers are really meant to share one model file.
 * NOTE(review): the model is re-loaded on every call; consider caching it if this is hot.
 *
 * @param doc raw document text
 * @returns PAYMENT_CONFIRMATION or INVC_VAT
 */
export const classifyDocBinary = async (doc: string): Promise<DocTypes> => {
  const words = tokenizeWithWordIndex(doc, wordIndex, 250);
  const model = await tf.loadLayersModel('model/model.json');
  let predictionTensor: tf.Tensor2D | undefined;
  try {
    // tidy() frees the intermediate input tensor; only the returned prediction survives it.
    predictionTensor = tf.tidy(() => model.predict(tf.tensor2d([words])) as tf.Tensor2D);
    const [firstRow] = await predictionTensor.array();
    // Presumably a sigmoid probability of PAYMENT_CONFIRMATION; confirm against the model.
    const score = firstRow?.[0];
    return score !== undefined && score >= 0.5 ? PAYMENT_CONFIRMATION : INVC_VAT;
  } finally {
    predictionTensor?.dispose();
    model.dispose();
  }
};

