import dagre from '@dagrejs/dagre'
import { Position } from '@vue-flow/core'
import { ref } from 'vue'

/**
 * Composable to run the layout algorithm on the graph.
 * It uses the `dagre` library to calculate the layout of the nodes and edges.
 *
 * @param {Function} findNode - called with a node id; expected to return the internal
 *   node (the `GraphNode` type) whose `dimensions.width` / `dimensions.height` are used
 *   as the node size. A missing result or zero dimensions fall back to a default size.
 * @returns {{ graph: object, layout: Function, previousDirection: object }}
 *   `graph` is a ref holding the dagre graph of the most recent `layout` call,
 *   `layout` computes node positions, and `previousDirection` is a ref holding the
 *   direction passed to the most recent `layout` call (initially `'LR'`).
 */
export function useLayout(findNode) {
  // size handed to dagre while a node has no (non-zero) measured dimensions
  const DEFAULT_NODE_WIDTH = 180
  const DEFAULT_NODE_HEIGHT = 160

  const graph = ref(new dagre.graphlib.Graph())

  const previousDirection = ref('LR')

  /**
   * Lays out the given nodes with dagre.
   *
   * @param {Array<object>} nodes - nodes to position; each needs an `id`.
   * @param {Array<object>} edges - edges with `source` and `target` node ids. Edges whose
   *   source or target is not among `nodes` are ignored.
   * @param {string} direction - passed to dagre as `rankdir`; `'LR'` is treated as
   *   horizontal, anything else as vertical, when choosing the handle positions.
   * @returns {Array<object>} shallow copies of the nodes (sources before their targets
   *   when the graph is acyclic) with `targetPosition`, `sourcePosition` and `position` set.
   */
  function layout(nodes, edges, direction) {
    // Only keep edges that connect two of the nodes being laid out. A dangling edge would
    // otherwise put an `undefined` entry into the sorted list, and (as far as graphlib's
    // `setEdge` is documented) make dagre create a placeholder node for the missing id.
    const knownIds = new Set(nodes.map(node => node.id))
    const connectedEdges = edges.filter(edge => knownIds.has(edge.source) && knownIds.has(edge.target))

    const orderedNodes = sortNodes(nodes, connectedEdges)

    // we create a new graph instance, in case some nodes/edges were removed, otherwise dagre would act as if they were still there
    const dagreGraph = new dagre.graphlib.Graph()

    graph.value = dagreGraph

    dagreGraph.setDefaultEdgeLabel(() => ({}))

    const isHorizontal = direction === 'LR'
    dagreGraph.setGraph({ rankdir: direction, ranksep: 120 })

    previousDirection.value = direction

    for (const node of orderedNodes) {
      // if you need width+height of nodes for your layout, you can use the dimensions property of the internal node (`GraphNode` type)
      const graphNode = findNode(node.id)
      // the lookup can come back empty (e.g. for a node that is not registered yet); use the defaults then
      const dimensions = (graphNode && graphNode.dimensions) || {}

      dagreGraph.setNode(node.id, {
        // `||` on purpose: a width/height of 0 should also fall back to the default size
        width: dimensions.width || DEFAULT_NODE_WIDTH,
        height: dimensions.height || DEFAULT_NODE_HEIGHT
      })
    }

    for (const edge of connectedEdges) {
      dagreGraph.setEdge(edge.source, edge.target)
    }

    dagre.layout(dagreGraph)

    // set nodes with updated positions
    // NOTE(review): dagre documents x/y as the node centre; they are passed through unchanged here - confirm that is the intended anchor
    return orderedNodes.map(node => {
      const nodeWithPosition = dagreGraph.node(node.id)

      return {
        ...node,
        targetPosition: isHorizontal ? Position.Left : Position.Top,
        sourcePosition: isHorizontal ? Position.Right : Position.Bottom,
        position: { x: nodeWithPosition.x, y: nodeWithPosition.y }
      }
    })
  }

  /**
   * Orders the nodes with a depth-first topological sort (reverse post-order).
   *
   * @param {Array<object>} nodes - nodes to sort; for duplicate ids the last node wins.
   * @param {Array<object>} edges - edges with `source` and `target` node ids.
   * @returns {Array<object>} the nodes, each exactly once; for an acyclic graph every
   *   edge's source comes before its target. Ids that only appear in `edges` are skipped.
   */
  function sortNodes(nodes, edges) {
    const nodesMap = new Map(nodes.map(node => [node.id, node]))

    // index the outgoing targets once instead of scanning all edges for every visited node
    const targetsBySource = new Map()

    for (const edge of edges) {
      const targets = targetsBySource.get(edge.source)

      if (targets) {
        targets.push(edge.target)
      } else {
        targetsBySource.set(edge.source, [edge.target])
      }
    }

    // Perform topological sort to determine the order of nodes
    const sortedNodes = []
    const visited = new Set()

    const visit = nodeId => {
      // unknown ids are skipped so the result never contains `undefined`
      if (visited.has(nodeId) || !nodesMap.has(nodeId)) {
        return
      }

      visited.add(nodeId)

      for (const targetId of targetsBySource.get(nodeId) || []) {
        visit(targetId)
      }

      sortedNodes.push(nodesMap.get(nodeId))
    }

    for (const nodeId of nodesMap.keys()) {
      visit(nodeId)
    }

    return sortedNodes.reverse()
  }

  return { graph, layout, previousDirection }
}
