import { useCallback, useState } from 'react';
import { loadAccessKeyIdIfExists } from './authentication';
import {
  Conversation,
  Message,
  OpenAiFormatMessage,
  ServerGptQuery,
  Tool,
  assert,
} from './types';
import { map } from 'lodash';

/**
 * Async generator type shared by the streaming helpers in this file.
 *
 * - Yield type: a `string` chunk (the union also admits `{ done: true }`).
 * - Return type: either an error-shaped `{ text, isError }` object or the
 *   `{ done: true }` completion marker. Consumers iterating with `for await`
 *   never observe the return value; only a manual `next()`/`return()` call
 *   does.
 * - Next type: `unknown` (nothing is expected to be sent into the generator).
 */
type ResponseGenerator = AsyncGenerator<
  string | { done: true },
  | {
      text: string;
      isError: boolean;
    }
  | { done: true },
  unknown
>;

/**
 * React hook returning a callback that streams a chat response and reports
 * the accumulated text through `outputCallback`. Starting a new request
 * interrupts the one that is still in flight.
 *
 * The returned callback accepts:
 * - `currentTool` - tool forwarded to `getResponseStreamed`.
 * - `input` - the new user input.
 * - `currentConversation` - previous messages sent as context.
 * - `outputCallback` - invoked with the full text received so far after
 *   every chunk.
 * - `temperature` - defaults to `0.7`.
 * - `model` - defaults to `'gpt-4'`.
 *
 * The callback is fire-and-forget (returns `void`); streaming failures are
 * logged with `console.error` instead of surfacing as unhandled rejections.
 */
export function useGetInterruptibleResponseStreamed() {
  // Ref-like mutable holder (same idea as `useRef`): the object identity is
  // stable across renders, and writes to `.current` are visible immediately.
  // Keeping the generator in regular state would leave the callback with a
  // stale value when it is invoked twice before the next render.
  const [inFlight] = useState<{ current: ResponseGenerator | null }>(() => ({
    current: null,
  }));

  return useCallback(
    (
      currentTool: Tool,
      input: string,
      currentConversation: Conversation,
      outputCallback: (output: string) => void,
      temperature: number = 0.7,
      model: string = 'gpt-4'
    ) => {
      // A previous message is still getting answered. Stop that generator.
      const previous = inFlight.current;
      if (previous) {
        previous.return({ done: true }).catch((error: unknown) => {
          console.error('Failed to stop the previous response stream', error);
        });
      }

      const generator = getResponseStreamed(
        currentTool,
        input,
        currentConversation,
        temperature,
        model
      );
      inFlight.current = generator;

      const pump = async (): Promise<void> => {
        let output = '';
        try {
          for await (const chunk of generator) {
            // `return()` on a running async generator only takes effect at
            // its next `yield`, so one more chunk may still arrive after we
            // were superseded. Drop it rather than overwrite newer output.
            if (inFlight.current !== generator) {
              break;
            }
            if (typeof chunk === 'string') {
              output += chunk;
              outputCallback(output);
            }
          }
        } finally {
          // Only clear the slot if no newer request has taken it over.
          if (inFlight.current === generator) {
            inFlight.current = null;
          }
        }
      };

      pump().catch((error: unknown) => {
        console.error('Streaming the response failed', error);
      });
    },
    [inFlight]
  );
}

/**
 * Streams the reply for `userInput` from the `/chatStream` endpoint,
 * yielding decoded text as it arrives.
 *
 * @param tool - Tool whose prompt and optional token limit go into the query.
 * @param userInput - New user message appended after `messages`.
 * @param messages - Previous conversation messages sent as context.
 * @param temperature - Temperature value forwarded in the query.
 * @param model - Optional model identifier forwarded in the query.
 * @returns `{ done: true }` once the stream ends, or
 *   `{ text: 'No access key', isError: true }` when no access key is stored.
 *   NOTE: `for await` consumers do not see the return value, so the
 *   missing-key case is silent for them.
 * @throws Error when the response carries no body; `fetch` and stream read
 *   errors propagate unchanged.
 */
async function* getResponseStreamed(
  tool: Tool,
  userInput: string,
  messages: Array<Message>,
  temperature: number,
  model?: string
): ResponseGenerator {
  const accessKeyId = loadAccessKeyIdIfExists();

  if (!accessKeyId) {
    return { text: 'No access key', isError: true };
  }

  // Encode the key so characters such as `+`, `&` or `=` survive the query
  // string intact.
  const response = await fetch(
    `/chatStream?accessKeyId=${encodeURIComponent(accessKeyId)}`,
    {
      method: 'POST',
      headers: {
        'Content-Type': 'application/json',
      },
      body: JSON.stringify({
        query: getServerQuery(tool, messages, userInput, temperature, model),
      }),
    }
  );

  if (!response.body) {
    throw new Error('Response from /chatStream has no body');
  }

  const reader = response.body.getReader();
  const decoder = new TextDecoder();
  try {
    while (true) {
      const { done, value } = await reader.read();
      if (done) {
        // Flush bytes the decoder may still be holding back.
        const tail = decoder.decode();
        if (tail) {
          yield tail;
        }
        return { done: true };
      }

      // `stream: true` keeps a multi-byte character that is split across
      // two network chunks intact instead of emitting U+FFFD for each half.
      const text = decoder.decode(value, { stream: true });
      if (text) {
        yield text;
      }
    }
  } finally {
    // Runs on normal completion, on errors, and when the consumer stops
    // early via `return()`: cancel so the underlying request does not keep
    // streaming. Best-effort cleanup; a cancel failure must not mask the
    // original outcome.
    await reader.cancel().catch(() => undefined);
  }
}

/**
 * Converts conversation messages into OpenAI-format messages: a `'user'`
 * sender maps to role `'user'`, any other accepted sender (`'bot'`) maps to
 * role `'assistant'`; the message text becomes `content`.
 *
 * Asserts that every message sender is either `'user'` or `'bot'`.
 */
function meldMessagesToOpenAiFormat(
  messages: Array<Message>
): Array<OpenAiFormatMessage> {
  const toOpenAiFormat = (message: Message): OpenAiFormatMessage => {
    const { sender, text } = message;
    assert(sender === 'user' || sender === 'bot');

    if (sender === 'user') {
      return { role: 'user', content: text };
    }
    return { role: 'assistant', content: text };
  };

  return map(messages, toOpenAiFormat);
}

/**
 * Builds the query object sent to the server: the tool prompt as a system
 * message, followed by the converted conversation history and finally the
 * new user input.
 *
 * `maxTokens` is only included when the tool defines it; `temperature` and
 * `model` are passed through as given.
 */
function getServerQuery(
  tool: Tool,
  messages: Array<Message>,
  userInput: string,
  temperature: number,
  model?: string
): ServerGptQuery {
  const { prompt, maxTokens } = tool;

  const leadingSystemMessage: OpenAiFormatMessage = {
    role: 'system',
    content: prompt,
  };
  const trailingUserMessage: OpenAiFormatMessage = {
    role: 'user',
    content: userInput,
  };

  const allMessages: Array<OpenAiFormatMessage> = [
    leadingSystemMessage,
    ...meldMessagesToOpenAiFormat(messages),
    trailingUserMessage,
  ];

  // Leave the key out entirely when the tool sets no token limit.
  const tokenLimit = maxTokens === undefined ? {} : { maxTokens };

  return {
    messages: allMessages,
    temperature,
    ...tokenLimit,
    model,
  };
}
