import { getTextFromMessage } from "@/lib/messages";
import { ModelConfig, ModelConfigWithPromptTools } from "@/services/configs.service";
import { ModelEndpoint } from "@/services/playground.service";
import { JsonSchema, LinkedTool, EditorTool } from "@/services/v4.service";
import { Prompt, TemplateLanguage } from "@/types/app/prompt";
import { EditorState, Modifier, SelectionState } from "draft-js";
import _ from "lodash";
import { parser, nodes } from "nunjucks";
import { diffWords } from "diff";
import { ToolFunction } from "@/types/app/tool";

// Regex looking for {{variable}}, allowing at most one whitespace character on each side
// of the name inside the braces.
// The name may contain letters, digits, underscores, dots and square brackets
// (e.g. `user.name` or `items[0]`). Because the name group uses `*`, an empty name
// (`{{}}`) also matches.
// The matching variable name is available as match[1], while the full matched text
// (i.e. including curly braces) is available as match[0].
// TODO: Disallow variable names that start with a number
// JAB ^ is that required? We're using dicts everywhere I'd hope.
// https://regexr.com/7ou70
// NB: This regex needs to be kept in sync with the one in the backend in interfaces/model.py
export const INPUT_REGEX = /{{[\s]?([a-zA-Z_0-9\.\[\]]*)[\s]?}}/g;

// Regex looking for {{tool_name("arg1", "arg2")}}
// match[0] is the full matched text (including curly braces), match[1] is the tool name
// (letters, digits, underscores and hyphens) and match[2] is the raw argument list, which
// then needs to be parsed (split + trim).
// https://regexr.com/7oo32
// {{ serp_api(query) }}             -> match[1] = `serp_api`, match[2] = `query`
// {{serp_api("my string", my_var)}} -> match[1] = `serp_api`, match[2] = `"my string", my_var`
// NB: This regex needs to be kept in sync with the one in the backend in interfaces/model.py
// This regex is slightly different in that it has different capture groups
export const TOOL_REGEX = /{{[\s]?([a-zA-Z_0-9-]+)\(([\"\w\s,-]*)\)[\s]?}}/g;

// TODO: Remove in favour of `extractInputsFromPromptVersion` when migration to v5 is complete
/**
 * Extract the input variable names used by a model config's template.
 *
 * The endpoint defaults to "chat" and the template language to "default" when the
 * config leaves them unset. Chat configs read `chat_template`, every other endpoint
 * reads `prompt_template`; a missing template is forwarded as `null`.
 *
 * @param modelConfig - The model config whose template should be scanned.
 * @returns The input names reported by `extractInputsFromPromptVersion`.
 */
export const extractInputsFromModelConfig = (modelConfig: ModelConfigWithPromptTools): string[] => {
  // TODO(v5): Update.
  const endpoint = modelConfig.endpoint ?? "chat";
  const rawTemplate = endpoint === "chat" ? modelConfig.chat_template : modelConfig.prompt_template;

  return extractInputsFromPromptVersion({
    endpoint,
    template_language: modelConfig.template_language ?? "default",
    template: rawTemplate ?? null,
  });
};

/**
 * Extract the unique input variable names from a prompt version's template.
 *
 * - "chat": the template must be an array of messages; inputs are gathered from the
 *   text of every message and de-duplicated.
 * - "complete": the template must be a string.
 *
 * A template of the wrong shape is logged with `console.error` and yields `[]`.
 *
 * @param promptVersion - The endpoint, template and template language to inspect.
 * @returns The input names found in the template.
 * @throws Error when the endpoint is neither "chat" nor "complete".
 */
export const extractInputsFromPromptVersion = (
  promptVersion: Pick<Prompt, "endpoint" | "template" | "template_language">,
): string[] => {
  const { endpoint, template, template_language: language } = promptVersion;

  if (endpoint === "chat") {
    if (!Array.isArray(template)) {
      console.error("Chat template is not an array", template);
      return [];
    }
    const inputsPerMessage = template.map((message) =>
      extractInputsFromPrompt(getTextFromMessage(message.content), language),
    );
    return _.uniq(inputsPerMessage.flat());
  }

  if (endpoint === "complete") {
    if (typeof template !== "string") {
      console.error("Prompt template is not a string", template);
      return [];
    }
    return extractInputsFromPrompt(template, language);
  }

  throw new Error("Unknown model endpoint");
};

/**
 * Extract the list of input variable names from the prompt, extended with the
 * variable arguments of any tool calls.
 *
 * Template inputs are listed first, followed by tool-call variables; duplicates are removed.
 * > extractInputsFromPrompt("{{action}} a poem by {{wiki(author)}} on {{topic}}:")
 * [ 'action', 'topic', 'author' ]
 *
 * @param prompt - The prompt template text.
 * @param templateLanguage - "jinja" selects the jinja extractor; anything else uses the
 *   `{{variable}}` regex extractor.
 * @returns The unique variable names.
 */
export const extractInputsFromPrompt = (prompt: string, templateLanguage?: TemplateLanguage | null): string[] => {
  const extractHighlights =
    templateLanguage === "jinja" ? extractJinjaHighlightsFromPrompt : extractInputIndicesFromPrompt;

  const templateInputs = extractHighlights(prompt)
    .filter((highlight) => highlight.type === "input")
    .map((highlight) => highlight.text);

  // Tool calls are always located with the tool regex, whatever the template language.
  const toolVariables = extractToolIndicesFromPrompt(prompt).flatMap((tool) =>
    tool.args.filter((toolArg) => toolArg.type === "variable").map((toolArg) => toolArg.arg),
  );

  return _.uniq([...templateInputs, ...toolVariables]);
};

// A single argument parsed out of a tool call's argument list.
interface ToolIndexArg {
  // Offset at which the argument match begins.
  // NOTE(review): extractToolIndicesFromPrompt derives this from the tool call's own text
  // (name length plus brace/parenthesis characters), not from the tool's position in the
  // prompt - confirm that callers expect a tool-relative offset.
  start: number;
  // Offset at which the argument match ends; the match includes any trailing
  // whitespace and comma that follow the argument.
  end: number;
  // The argument text as written; quoted string literals keep their double quotes.
  arg: string;
  // "literal" when the argument starts with a digit or a quote, otherwise "variable".
  type: "literal" | "variable";
}

// Location and parsed contents of one tool call found in a prompt.
export interface ToolIndex {
  // character index in the prompt where the tool call match begins
  start: number;
  // character index in the prompt just past the end of the match (non-inclusive)
  end: number;
  // the function name portion of the signature - e.g. `serp_api` for serp_api("my string", my_var)
  call: string;
  // the full text of the tool call - e.g. `{{serp_api(my_var)}}`
  text: string;
  // the arguments parsed from between the parentheses
  args: ToolIndexArg[];
  type: "tool";
}

/**
 * Locate every `{{variable}}` placeholder (as matched by INPUT_REGEX) in the prompt.
 *
 * > extractInputIndicesFromPrompt("{{action}} a poem by {{author}} on {{topic}}:")
 * [
 *   { start: 0, end: 10, text: 'action', type: 'input' },
 *   { start: 21, end: 31, text: 'author', type: 'input' },
 *   { start: 35, end: 44, text: 'topic', type: 'input' }
 * ]
 *
 * @param prompt - The prompt template text; an empty prompt yields `[]`.
 * @returns One extraction per placeholder; `start`/`end` span the full match including
 *   the curly braces, while `text` is the variable name only.
 * @throws Error if a match is reported without an index.
 */
export const extractInputIndicesFromPrompt = (prompt: string): Extraction[] => {
  if (!prompt) {
    return [];
  }

  const placeholders: Extraction[] = [];
  for (const found of prompt.matchAll(INPUT_REGEX)) {
    const { index } = found;
    if (index === undefined) {
      throw new Error(`matchAll returned undefined index for x ${found}`);
    }
    placeholders.push({
      start: index,
      end: index + found[0].length,
      text: found[1],
      type: "input",
    });
  }
  return placeholders;
};

/**
 * Locate every tool call (as matched by TOOL_REGEX) in the prompt and parse its arguments.
 *
 * `start`/`end` of each tool are character indices into the prompt. `start`/`end` of each
 * argument are offsets within the tool call's own matched text (`text`), and `end`
 * includes any trailing whitespace and comma that follow the argument.
 *
 * @param prompt - The prompt template text; an empty prompt yields `[]`.
 * @returns One entry per tool call, in order of appearance.
 * @throws Error if a match is reported without an index.
 */
export const extractToolIndicesFromPrompt = (prompt: string): ToolIndex[] => {
  if (!prompt) {
    return [];
  }
  return [...prompt.matchAll(TOOL_REGEX)].map((x) => {
    if (x.index === undefined) {
      throw new Error(`matchAll returned undefined index for x ${x}`);
    }

    // Offset of the first character after the opening parenthesis within the matched text.
    // Found by searching rather than assuming `{{name(`, because TOOL_REGEX allows a
    // whitespace character between `{{` and the tool name. The name cannot contain `(`,
    // so the first `(` is always the one that opens the argument list.
    const firstArgOffset = x[0].indexOf("(") + 1;

    const args: ToolIndexArg[] = [];
    // Each argument is either a double-quoted string or a bare word, optionally followed
    // by whitespace and a comma.
    const regex = /(?<name>"[^"]+"|\w+)(\s*,)?/g;
    let match;
    while ((match = regex.exec(x[2]))) {
      if (match.groups === undefined) {
        continue;
      }
      const name = match.groups.name;
      // If it starts with a number or a double/single quote, it's a literal,
      // otherwise it's a variable.
      const type = /^[0-9"']/.test(name) ? "literal" : "variable";
      args.push({
        start: firstArgOffset + match.index,
        end: firstArgOffset + match.index + match[0].length,
        arg: name,
        type: type,
      });
    }
    return {
      start: x.index,
      end: x.index + x[0].length,
      call: x[1],
      text: x[0],
      args: args,
      type: "tool",
    };
  });
};

/**
 * List the distinct tools called in a prompt, keeping the first occurrence of each.
 *
 * For every entry, `name` is the full matched tool-call text (including the curly braces)
 * and `call` is the function name; entries are de-duplicated on `call`.
 * NOTE(review): `name` holding the full text and `call` the function name looks swapped -
 * confirm against callers before changing.
 *
 * @param prompt - The prompt template text.
 * @returns One `{ name, call }` pair per distinct function name.
 */
export const extractToolsFromPrompt = (
  prompt: string,
): {
  name: string;
  call: string;
}[] => {
  const calls = extractToolIndicesFromPrompt(prompt).map((tool) => {
    return { name: tool.text, call: tool.call };
  });
  // Unique by call
  return _.uniqBy(calls, (entry) => entry.call);
};

/**
 * Insert `text` at the editor's current selection and place the caret after it.
 *
 * The selection is replaced with `text`, the editor's current inline styles are applied
 * to the range starting at the original selection and extended by `text.length`, and the
 * change is pushed as an "insert-characters" edit.
 *
 * @param text - The text to insert.
 * @param editorState - The editor state to build on (not mutated).
 * @returns A new editor state with the text inserted and the selection forced past it.
 */
export const addTextAtCursor = (text: string, editorState: EditorState): EditorState => {
  const selection = editorState.getSelection();
  const insertedLength = text.length;

  // Range covering the inserted text: same anchor, focus pushed forward by the text length.
  const insertedRange = selection.set(
    "focusOffset",
    selection.getFocusOffset() + insertedLength,
  ) as SelectionState;

  let content = Modifier.replaceText(editorState.getCurrentContent(), selection, text);
  editorState.getCurrentInlineStyle().forEach((style) => {
    content = Modifier.applyInlineStyle(content, insertedRange, style || "");
  });

  // Moving the anchor forward as well leaves the selection at the end of the inserted range.
  const caretAfterInsert = insertedRange.set(
    "anchorOffset",
    insertedRange.getAnchorOffset() + insertedLength,
  ) as SelectionState;

  const pushed = EditorState.push(editorState, content, "insert-characters");
  return EditorState.forceSelection(pushed, caretAfterInsert);
};

/**
 * Delete up to `characters` characters immediately before the selection's anchor offset.
 *
 * The removed range ends at the anchor offset and starts `characters` earlier, clamped
 * to 0 so that asking for more characters than precede the anchor never produces a
 * negative offset. The change is pushed as a "remove-range" edit.
 *
 * @param characters - Number of characters to remove before the anchor offset.
 * @param editorState - The editor state to build on (not mutated).
 * @returns A new editor state with the range removed.
 */
export const deleteTextAtCursor = (characters: number, editorState: EditorState): EditorState => {
  const currentContent = editorState.getCurrentContent();
  const currentSelection = editorState.getSelection();

  const rangeEnd = currentSelection.getAnchorOffset();
  // Offsets below zero are not valid positions in a block, so stop at the block start.
  const rangeStart = Math.max(0, rangeEnd - characters);

  const newContent = Modifier.removeRange(
    currentContent,
    currentSelection.merge({
      anchorOffset: rangeStart,
      focusOffset: rangeEnd,
    }),
    "backward",
  );

  return EditorState.push(editorState, newContent, "remove-range");
};

// Empty object schema: an "object" type that declares no properties and no required fields.
export const DEFAULT_TOOL_PARAMETER_SCHEMA = {
  type: "object",
  properties: {},
  required: [],
} as JsonSchema;

/**
 * Decide whether two inline tools are equivalent.
 *
 * Only the name, description and parameters are compared (parameters by deep equality).
 * If either tool carries an `id` it is treated as a linked tool and the result is
 * always `false`: only inline tools are compared.
 *
 * @param tool1 - First tool.
 * @param tool2 - Second tool.
 * @returns `true` when both are inline tools with matching name, description and parameters.
 */
export const areToolsEqual = (tool1: ToolFunction, tool2: ToolFunction): boolean => {
  // Linked tools are never considered equal here.
  if ("id" in tool1 || "id" in tool2) {
    return false;
  }
  if (tool1.name !== tool2.name) {
    return false;
  }
  if (tool1.description !== tool2.description) {
    return false;
  }
  return _.isEqual(tool1.parameters, tool2.parameters);
};

// A highlighted span of text: an input variable, a block, or a tool call.
export interface Extraction {
  start: number; // start character index in the prompt
  end: number; // end character index (non-inclusive)
  text: string; // the highlighted text (either a block or an input)
  call?: string; // the tool call signature
  type: "input" | "block" | "tool"; // We don't yet use block, but we might want to
}

/**
 * Get the character offset into `text` for a line and column number provided by jinja nodes.
 *
 * Both `lineno` and `colnum` are treated as zero-based: the offset is the combined length
 * of the first `lineno` lines (each plus its newline) plus `colnum`.
 *
 * @param text - The full text the position refers to.
 * @param lineno - Zero-based line number.
 * @param colnum - Zero-based column within that line.
 * @returns The character offset, or `text.length` when `lineno` is past the last line.
 */
function getCharOffset(text: string, lineno: number, colnum: number): number {
  const lines = text.split("\n");
  // The last valid line index is `lines.length - 1`; anything beyond it maps to the end
  // of the text instead of running past it.
  if (lineno >= lines.length) {
    return text.length;
  }
  let offset = 0;
  for (let i = 0; i < lineno; i++) {
    // +1 accounts for the newline removed by split.
    offset += lines[i].length + 1;
  }
  return offset + colnum;
}

/**
 * Collect defined symbol names and function-call nodes by traversing the jinja AST.
 *
 * A `defined` symbol is one that does not require a user input to render the string,
 * e.g. one in a set statement, or the local variable in a for loop.
 * These are not extracted as inputs in our Editor. Names of called functions, filters,
 * macros and imports are recorded as defined as well.
 *
 * Unfortunately Symbol nodes do not contain any information that we can
 * use to determine whether it requires user input or not - this is computed on the fly
 * within the jinja compiler it seems (not 100% certain on this, but it's what I
 * gathered from a brief scan through their source code).
 *
 * Therefore, we need to define logic which figures this out based on the node type and
 * context, case-by-case, which is annoying. Using the AST is a much easier representation
 * to do this with vs trying to build a tokeniser against the raw tokens (see the
 * attempt `extractHighlightsFromPrompt0` in previous commits that still misses some
 * edge cases, but had the advantage of also parsing blocks. It uses a
 * stack to track the different possible nested blocks).
 *
 * Every field a node declares in `fields` is traversed, so definitions and calls nested
 * in conditions, expressions, argument lists and `else` branches are found too. Nodes
 * without a `fields` list fall back to their `children` and `body`.
 *
 * Defined names are collected for the whole template; scoping is not tracked.
 *
 * @param ast - Root node of the parsed template.
 * @returns `defined`: the names that need no user input; `functionCalls`: every FunCall node.
 */
function collectDefinedSymbolNodes(ast: Node): { defined: Set<string>; functionCalls: Set<Node> } {
  const defined = new Set<string>();
  const functionCalls = new Set<Node>();

  // Record a name, ignoring anything that is not a string (e.g. the missing `value`
  // of a callee that is not a plain symbol, such as `foo.bar`).
  function define(name: unknown) {
    if (typeof name === "string") {
      defined.add(name);
    }
  }

  // Record the node's value when the node is a Symbol.
  function defineIfSymbol(candidate: nodes.Node) {
    if (candidate && candidate.typename === "Symbol") {
      define(candidate.value);
    }
  }

  function traverse(node: nodes.Node) {
    if (!node || typeof node !== "object") return;
    if (Array.isArray(node)) {
      node.forEach((item: nodes.Node) => traverse(item));
      return;
    }

    switch (node.typename) {
      case "FunCall":
        functionCalls.add(node);
        define(node.name?.value);
        break;
      case "Filter":
      case "Macro":
        define(node.name?.value);
        break;
      case "Set":
        node.targets.forEach(defineIfSymbol);
        break;
      case "For":
        if (node.name.typename === "Symbol") {
          define(node.name.value);
        } else if (node.name.typename === "Array") {
          // Unpacking loop, e.g. `for key, value in items`.
          node.name.children.forEach(defineIfSymbol);
        }
        // Add the special 'loop' variable that's implicitly available in for-loops
        defined.add("loop");
        break;
      case "Import":
      case "FromImport":
        defineIfSymbol(node.target);
        if (node.names && node.names.children) {
          node.names.children.forEach((child: nodes.Node) => {
            if (child.typename === "Pair") {
              // For a Pair, the name recorded is the Pair's `value` symbol.
              defineIfSymbol(child.value);
            } else {
              defineIfSymbol(child);
            }
          });
        }
        break;
      default:
        break;
    }

    // Descend into every declared field so that nothing nested in an expression,
    // condition, argument list or else-branch is skipped.
    const fieldNames: string[] = Array.isArray(node.fields) ? node.fields : ["children", "body"];
    fieldNames.forEach((fieldName) => traverse(node[fieldName]));
  }

  traverse(ast);
  return { defined, functionCalls };
}

/**
 * Extract all the input and tool call variables from a jinja prompt template.
 *
 * We pay the cost of traversing the AST twice (once to determine which symbols
 * do not require user input, and once to get all symbols). Although less performant,
 * this allows for simpler logic in the AST parsing.
 * TODO: improve this to use a single pass instead of 2.
 *
 * Inputs are listed first (every Symbol that is not a defined name), followed by one
 * "tool" extraction per function call that has a plain symbol name.
 *
 * @param prompt - The jinja template text.
 * @returns The extractions, or `[]` when the template cannot be parsed or processed.
 */
export function extractJinjaHighlightsFromPrompt(prompt: string): Extraction[] {
  try {
    const ast = parser.parse(prompt);
    const { defined, functionCalls } = collectDefinedSymbolNodes(ast);
    const extractions: Extraction[] = [];

    ast.findAll(nodes.Symbol).forEach((sym: nodes.Node) => {
      if (defined.has(sym.value)) {
        return;
      }
      const startOffset = getCharOffset(prompt, sym.lineno, sym.colno);
      extractions.push({
        start: startOffset,
        end: startOffset + sym.value.length,
        text: sym.value,
        type: "input",
      });
    });

    functionCalls.forEach((funCall: nodes.Node) => {
      const calleeName = funCall.name?.value;
      // A call whose callee is not a plain symbol (e.g. the method call `foo.bar()`) has
      // no name to highlight. Skip it rather than throwing, which would discard every
      // highlight for the whole prompt.
      if (typeof calleeName !== "string") {
        return;
      }
      // funcCall line and col numbers cover the signature of function, so the name ends
      // at that position and starts `calleeName.length` characters earlier.
      const endOffset = getCharOffset(prompt, funCall.lineno, funCall.colno);
      extractions.push({
        start: endOffset - calleeName.length,
        end: endOffset,
        text: calleeName,
        type: "tool",
      });
    });
    return extractions;
  } catch (e) {
    // Best effort: any failure while parsing or walking the template yields no highlights.
    // console.error(e);
    return [];
  }
}

/**
 * Diff a template against its rendered output (word by word) and report which spans of
 * the render were inserted or changed.
 *
 * @param template - The template text before rendering.
 * @param render - The rendered output.
 * @returns One "block" extraction per added part, with indices into `render`.
 */
export function extractRenderHighlights(template: string, render: string): Extraction[] {
  const highlights: Extraction[] = [];
  let cursor = 0; // current position within the rendered output

  for (const { added, removed, value } of diffWords(template, render)) {
    // Removed parts exist only in the template, so they do not move the render cursor.
    if (!added && removed) {
      continue;
    }
    const next = cursor + value.length;
    if (added) {
      // Added parts exist only in the render: record where they sit.
      highlights.push({ start: cursor, end: next, text: value, type: "block" });
    }
    // Added and unchanged parts both advance the cursor.
    cursor = next;
  }

  return highlights;
}
