import { getLength, trimTokens } from '../../../../util/tokenizer';

/**
 * Build a single-turn question-answering prompt using the `[INST]` / `<<SYS>>`
 * chat markup (Llama-2 style).
 *
 * @param {string} context - summary text the answer should be grounded in.
 * @param {string} question - the user's current question.
 * @returns {string} the assembled prompt, ending with an open `>>ANSWER<<` slot.
 */
const qa_prompt_template = (context, question) => {
  // The instruction line keeps its two-space indent: it is part of the prompt text.
  // `${...}` is used (not the raw values) so undefined/null render as text,
  // exactly as an interpolated template literal would.
  const lines = [
    '<s> [INST] <<SYS>>',
    '  Answer the question as truthfully as possible using the provided summary, and if the answer is not contained within the summary below, ask the user if they could be more specific.',
    '<</SYS>>',
    '>>SUMMARY<<',
    `${context}`,
    '>>QUESTION<<',
    `${question} [/INST]`,
    '>>ANSWER<<',
  ];
  return lines.join('\n').trim();
};
/**
 * Build a multi-turn question-answering prompt using the `[INST]` / `<<SYS>>`
 * chat markup (Llama-2 style). The system instruction is emitted as its own
 * closed turn, followed by the pre-rendered history and then the current turn.
 *
 * @param {string} chat_history - previously rendered conversation turns.
 * @param {string} context - summary text the answer should be grounded in.
 * @param {string} question - the user's current question.
 * @returns {string} the assembled prompt, ending with an open `>>ANSWER<<` slot.
 */
const chat_qa_prompt_template = (chat_history, context, question) => {
  const system_turn = [
    '<s> [INST] <<SYS>>',
    '  Answer the question as truthfully as possible using the provided summary, and if the answer is not contained within the summary below, ask the user if they could be more specific.',
    '<</SYS>> [/INST] </s>',
  ].join('\n');
  // Stringify in the same order the original template literal did.
  const history_text = `${chat_history}`;
  // Note the trailing space after "[INST]" on the first line: it is part of the prompt.
  const current_turn = [
    '<s> [INST] ',
    '>>SUMMARY<<',
    `${context}`,
    '>>QUESTION<<',
    `${question} [/INST]`,
    '>>ANSWER<<',
  ].join('\n');
  return `${system_turn}\n${history_text}\n${current_turn}`.trim();
};

/**
 * Render as much of the prior conversation as fits in the token budget left
 * after the prompt template, the current question and the context.
 *
 * `chat_history` is assumed to alternate user/assistant messages
 * (`{ role, content }`), oldest first, starting with a user message — TODO
 * confirm against callers. Question/answer pairs are *selected* newest-first,
 * so the most recent turns win when the budget is tight. They are *rendered*
 * oldest-first, so the model reads the conversation in chronological order.
 * A trailing unpaired message (odd-length history) is skipped rather than
 * mis-pairing the remaining messages.
 *
 * @param {object} model_config - reads `tokenizer`, `max_input_length` and
 *   `prompt_template_tokens`.
 * @param {Array<{role: string, content: string}>} chat_history - prior messages.
 * @param {string} prompt - the current question.
 * @param {string} context - the (already trimmed) context for this prompt.
 * @returns {Promise<string>} the rendered history, or `''` when there is no
 *   history or nothing fits.
 */
const buildChatHistory = async (
  model_config,
  chat_history,
  prompt,
  context
) => {
  if (!chat_history?.length) {
    return '';
  }
  const { tokenizer } = model_config;
  const prompt_tokens = await getLength(tokenizer, prompt);
  const context_tokens = await getLength(tokenizer, context);
  const max_chat_tokens =
    model_config.max_input_length -
    (model_config.prompt_template_tokens + prompt_tokens + context_tokens);

  // With fewer than 5 tokens to spare, no useful turn can fit.
  if (max_chat_tokens < 5) {
    return '';
  }

  let remaining_tokens = max_chat_tokens;
  const chats = [];
  // Math.floor keeps the index integral for odd-length histories. A fractional
  // index would pair an assistant message with the following question.
  const pair_count = Math.floor(chat_history.length / 2);
  for (let pair_index = pair_count - 1; pair_index >= 0; pair_index--) {
    const question = chat_history[pair_index * 2];
    const answer = chat_history[pair_index * 2 + 1];

    // NOTE(review): only the raw message text is counted; the per-turn
    // markup added below is not included in the budget.
    const chat_pair_length = await getLength(
      tokenizer,
      `${question.content}${answer.content}`
    );
    if (chat_pair_length > remaining_tokens) {
      break;
    }
    remaining_tokens -= chat_pair_length;
    // Walking newest -> oldest, so prepend each pair (kept intact) to end up
    // with chronological order.
    chats.unshift(question, answer);
  }

  return chats
    .map((message) => {
      if (message.role === 'user') {
        // Trailing space after "[INST]" is intentional prompt text.
        return `<s> [INST] \n>>QUESTION<<\n${message.content} [/INST]`;
      }
      if (message.role === 'assistant') {
        return `>>ANSWER<<\n${message.content} </s>`;
      }
      // Messages with any other role render as an empty line.
      return '';
    })
    .join('\n');
};

/**
 * Concatenate the retrieved documents and trim the result to the token budget
 * that remains after the prompt template and the question.
 *
 * @param {object} model_config - reads `tokenizer`, `max_input_length` and
 *   `prompt_template_tokens`.
 * @param {Array<{zContent: string}>} docs - retrieved documents; `zContent` is
 *   presumably the document text — confirm against the retriever.
 * @param {string} prompt - the current question.
 * @returns {Promise<string>} the joined documents, trimmed by `trimTokens`.
 */
const buildContext = async (model_config, docs, prompt) => {
  const question_tokens = await getLength(model_config.tokenizer, prompt);

  // NOTE(review): this can go negative for very long prompts; what happens
  // then depends on trimTokens.
  const budget =
    model_config.max_input_length -
    (model_config.prompt_template_tokens + question_tokens);

  const joined_docs = docs.map(({ zContent }) => zContent).join('\n\n');
  return trimTokens(model_config.tokenizer, joined_docs, budget);
};

/**
 * Assemble the final model prompt for `prompt`.
 *
 * The document context is built first; the chat history is then fitted into
 * whatever budget remains. If the rendered history is effectively empty
 * (length <= 1), the single-turn QA template is used; otherwise the
 * multi-turn template is used.
 *
 * @param {object} model_config - model limits and tokenizer.
 * @param {Array<{role: string, content: string}>} chat_history - prior messages.
 * @param {string} prompt - the current question.
 * @param {Array<object>} docs - retrieved documents used as context.
 * @returns {Promise<string>} the prompt to send to the model.
 */
const buildPrompt = async (model_config, chat_history, prompt, docs) => {
  // Order matters: the history budget depends on the size of the built context.
  const context = await buildContext(model_config, docs, prompt);
  const rendered_history = await buildChatHistory(
    model_config,
    chat_history,
    prompt,
    context
  );

  return rendered_history.length <= 1
    ? qa_prompt_template(context, prompt)
    : chat_qa_prompt_template(rendered_history, context, prompt);
};
export default buildPrompt;
