import { findParentNode } from "@tiptap/core";
import { Node as PMNode, Schema } from "@tiptap/pm/model";
import { Selection } from "@tiptap/pm/state";
import { Rect, TableMap } from "@tiptap/pm/tables";

import type { ContentNodeWithPos } from "../../../core/types";
import { findTable } from "./find";
import { getTableNodeTypes } from "./getTableNodeTypes";
import { isSelectionType } from "./isSelectionType";
import type { CellsMap } from "./types";

/**
 * Get node and the doc level position based on its relative position to table node
 */
const getContentNodePos = (table: ContentNodeWithPos, nodePos: number) => {
  // `nodePos` is an offset inside the table (as produced by TableMap);
  // shifting it by the table's content start yields the document position.
  const absolutePos = table.start + nodePos;
  return {
    pos: absolutePos,
    // Content of the node begins one position after its opening token.
    start: absolutePos + 1,
    node: table.node.nodeAt(nodePos) as PMNode,
    // The node is expected to be a cell: table -> row -> cell.
    depth: table.depth + 2,
  };
};

/**
 * Returns an array of cells in a column(s), where `columnIndex` could
 * be a column index or an array of column indexes
 */
/**
 * Returns the cells of the given column(s) of the table that contains the
 * selection.
 *
 * @param columnIndexes - A zero-based column index, or an array of them.
 *   Indexes that are not integers, or that fall outside `[0, map.width - 1]`,
 *   are ignored.
 * @returns A function that takes the editor selection and returns the cells
 *   (with doc-level positions) that `TableMap.cellsInRect` reports for each
 *   requested column. Columns appear in the order the indexes were given, and
 *   repeating an index repeats its cells. The function returns `undefined`
 *   when the selection is not inside a table.
 */
export const getCellsInColumn =
  (columnIndexes: number | number[]) =>
  (selection: Selection): ContentNodeWithPos[] | undefined => {
    const table = findTable(selection);
    if (!table) {
      return;
    }

    const map = TableMap.get(table.node);
    const indexes = Array.isArray(columnIndexes)
      ? columnIndexes
      : [columnIndexes];

    const result: ContentNodeWithPos[] = [];
    for (const index of indexes) {
      // A fractional index would feed non-integer coordinates into
      // TableMap.cellsInRect, which then reads non-existent `map` slots.
      // Only whole, in-range column numbers are meaningful.
      if (!Number.isInteger(index) || index < 0 || index >= map.width) {
        continue;
      }
      const cellPositions = map.cellsInRect({
        left: index,
        right: index + 1,
        top: 0,
        bottom: map.height,
      });
      for (const nodePos of cellPositions) {
        result.push(getContentNodePos(table, nodePos));
      }
    }
    return result;
  };

/**
 * Returns an array of cells in a row(s), where `rowIndex` could
 * be a row index or an array of row indexes
 */
/**
 * Returns the cells of the given row(s) of the table that contains the
 * selection.
 *
 * @param rowIndex - A zero-based row index, or an array of them. Indexes that
 *   are not integers, or that fall outside `[0, map.height - 1]`, are ignored.
 * @returns A function that takes the editor selection and returns the cells
 *   (with doc-level positions) that `TableMap.cellsInRect` reports for the
 *   full-width strip of each requested row. Rows appear in the order the
 *   indexes were given, and repeating an index repeats its cells. The function
 *   returns `undefined` when the selection is not inside a table.
 */
export const getCellsInRow =
  (rowIndex: number | number[]) =>
  (selection: Selection): ContentNodeWithPos[] | undefined => {
    const table = findTable(selection);

    if (!table) {
      return;
    }

    const map = TableMap.get(table.node);
    const indexes = Array.isArray(rowIndex) ? rowIndex : [rowIndex];

    const result: ContentNodeWithPos[] = [];
    for (const index of indexes) {
      // A fractional row (e.g. 1.5) would make `row * width + col` in
      // TableMap.cellsInRect land on slots belonging to other rows, so only
      // whole, in-range row numbers are accepted.
      if (!Number.isInteger(index) || index < 0 || index >= map.height) {
        continue;
      }
      const cellPositions = map.cellsInRect({
        left: 0,
        right: map.width,
        top: index,
        bottom: index + 1,
      });
      for (const nodePos of cellPositions) {
        result.push(getContentNodePos(table, nodePos));
      }
    }
    return result;
  };

/**
 * Returns Cells Rect if there are selected cell(s)
 */
/**
 * Computes the table-map rectangle spanned by the current cell selection.
 *
 * @param map - `TableMap` of the table the selection belongs to.
 * @param selection - The editor selection.
 * @returns The `Rect` between the anchor and head cells, or `undefined` when
 *   `selection` is not a cell selection.
 */
export const getCellsRect = (
  map: TableMap,
  selection: Selection
): Rect | undefined => {
  if (isSelectionType(selection, "cell")) {
    const { $anchorCell, $headCell } = selection;
    // TableMap works with offsets relative to the table's content start,
    // which `start(-1)` of the anchor cell's resolved position provides.
    const tableStart = $anchorCell.start(-1);
    const anchorOffset = $anchorCell.pos - tableStart;
    const headOffset = $headCell.pos - tableStart;
    return map.rectBetween(anchorOffset, headOffset);
  }
  return undefined;
};

/**
 * Returns `TableMap`, selected `Rect` with its width and height and individual cells node
 * and its position. This is useful when changing cells' border
 */
/**
 * Describes the cells affected by `selection` inside its closest table:
 * the table's `TableMap`, the selected width/height, and each cell node with
 * its document positions. This is useful when changing cells' borders.
 *
 * - For a cell selection, `rect` is the selected rectangle, and `width` and
 *   `height` are derived from it.
 * - Otherwise, the single cell or header cell that contains the selection is
 *   reported with `width` and `height` of 1 and no `rect`.
 *
 * @param schema - Schema used to look up the cell and header-cell node types.
 * @param selection - The current editor selection.
 * @returns The cells map, or `undefined` in any of these cases:
 *   - the selection is not inside a table;
 *   - the selection is not inside a cell;
 *   - the selected rectangle cannot be computed.
 */
export const getCellsMap = (
  schema: Schema,
  selection: Selection
): CellsMap | undefined => {
  const table = findTable(selection);
  if (!table) {
    return undefined;
  }
  const map = TableMap.get(table.node);

  if (isSelectionType(selection, "cell")) {
    const rect = getCellsRect(map, selection);
    if (!rect) {
      return undefined;
    }
    return {
      width: rect.right - rect.left,
      height: rect.bottom - rect.top,
      rect,
      map,
      cells: map
        .cellsInRect(rect)
        .map((nodePos) => getContentNodePos(table, nodePos)),
    };
  }

  // Not a cell selection: fall back to the (header) cell around the cursor.
  const { cell: cellType, header_cell: headerCellType } =
    getTableNodeTypes(schema);
  const enclosingCell = findParentNode(
    (node) => node.type === cellType || node.type === headerCellType
  )(selection);
  if (!enclosingCell) {
    return undefined;
  }
  return {
    width: 1,
    height: 1,
    map,
    cells: [
      {
        pos: enclosingCell.pos,
        start: enclosingCell.pos + 1,
        node: enclosingCell.node,
        // table -> row -> cell
        depth: table.depth + 2,
      },
    ],
  };
};
