import React, { useEffect, useRef } from "react";
import * as d3 from "d3";
import {
  nodesShareTag,
  nodeHasTag,
  CURVE_STROKE_WIDTH,
  FADED_NODE_OPACITY,
  getCurveStyle,
  NODE_HIGHLIGHT_OPACITY,
  NODE_OPACITY,
  NODE_RADIUS,
  SELECTED_NODE_RADIUS,
  generateNodesAndLinks,
} from "./utils";

/**
 * Draws a smoothed outline around one subnetwork and, optionally, its
 * conversation label.
 *
 * The outline is the convex hull of the subnetwork's node positions, pushed
 * outward from its centroid by a fixed margin and drawn as a closed basis
 * curve. Anything previously drawn for the same `label` is removed first, so
 * this is safe to call on every simulation tick.
 *
 * @param {object} g - d3 selection the outline <path> is appended to.
 * @param {object} networkLabelsGroup - d3 selection the label <text>
 *   elements are appended to.
 * @param {Array<object>} nodes - Subnetwork nodes with numeric `x`/`y`; the
 *   first node's `conversationId` selects the label text.
 * @param {string} label - Suffix for the `network-outline-*` and
 *   `network-label-*` class names.
 * @param {object} [conversationNames] - Optional conversationId -> name
 *   lookup; a name containing "\n" is drawn as several lines.
 * @param {boolean} showConversationLabels - Whether to draw the label.
 */
const drawNetworkOutline = (
  g,
  networkLabelsGroup,
  nodes,
  label,
  conversationNames,
  showConversationLabels
) => {
  // Remove existing outlines and labels
  g.selectAll(`.network-outline-${label}`).remove();
  networkLabelsGroup.selectAll(`.network-label-${label}`).remove();

  // d3.polygonHull returns null when given fewer than three points.
  const hull = d3.polygonHull(nodes.map((d) => [d.x, d.y]));
  if (!hull) return;

  const expansionFactor = 20; // Amount by which to expand the hull

  // polygonCentroid divides by the hull's signed area, so a zero-area hull
  // (collinear or coincident nodes) yields a non-finite centroid. Fall back
  // to the mean vertex so no NaN coordinates reach the path or the label.
  let centroid = d3.polygonCentroid(hull);
  if (!centroid.every(Number.isFinite)) {
    const [sumX, sumY] = hull.reduce(
      ([accX, accY], [x, y]) => [accX + x, accY + y],
      [0, 0]
    );
    centroid = [sumX / hull.length, sumY / hull.length];
  }
  const [cx, cy] = centroid;

  // Expand the convex hull outward from the centroid
  const expandedHull = hull.map(([x, y]) => {
    const dx = x - cx;
    const dy = y - cy;
    const dist = Math.sqrt(dx * dx + dy * dy);
    // A vertex sitting on the centroid has no outward direction, and dividing
    // by a zero distance would produce NaN, so leave it where it is.
    if (dist === 0) return [x, y];
    const scale = (dist + expansionFactor) / dist;
    return [cx + dx * scale, cy + dy * scale];
  });

  // Create a smooth closed curve around the expanded convex hull
  const outlinePathGenerator = d3
    .line()
    .x((d) => d[0])
    .y((d) => d[1])
    .curve(d3.curveBasisClosed);

  g.append("path")
    .datum(expandedHull)
    .attr("class", `network-outline network-outline-${label}`)
    .attr("d", outlinePathGenerator)
    .attr("fill", "white")
    .attr("fill-opacity", 0.15);

  if (!showConversationLabels) return;

  const { conversationId } = nodes[0];
  const networkName =
    conversationNames?.[conversationId] ?? `Conversation ${conversationId}`;

  // Anchor the label at the hull's highest point (smallest y in SVG space),
  // horizontally centred on the centroid.
  const top = expandedHull.reduce((acc, cur) => (acc[1] > cur[1] ? cur : acc));

  const networkNameSplit = networkName.split("\n");
  // One <text> element per line of the name
  networkNameSplit.forEach((token, index) => {
    networkLabelsGroup
      .append("text")
      .attr("class", `network-label network-label-${label}`)
      .attr("x", cx)
      .attr("y", top[1] + index * 15 - (networkNameSplit.length - 1) * 10)
      .attr("text-anchor", "middle")
      .attr("fill", "white")
      .attr("opacity", 0.8)
      .attr("font-size", "10px")
      .attr("font-weight", "bold")
      .text(token);
  });
};

/**
 * Seeds starting coordinates before the force simulation runs (mutates nodes).
 *
 * - Subnetwork hubs (ids matching "Subnetwork-<digits>") are laid out on a
 *   near-square grid that evenly divides `width` x `height`.
 * - Each hub's members (ids starting with "<hubId> - Node") are spaced evenly
 *   on a 50px circle around their hub.
 * - The node with id "Central Node", if present, goes to the centre.
 *
 * @param {Array<object>} nodes - Graph nodes; only `id`, `x`, `y` are touched.
 * @param {number} width - Layout width in px.
 * @param {number} height - Layout height in px.
 */
const initializeNodePositions = (nodes, width, height) => {
  const hubIdPattern = /^Subnetwork-\d+$/;
  const orbitRadius = 50;

  const hubs = nodes.filter((candidate) => hubIdPattern.test(candidate.id));
  const columns = Math.ceil(Math.sqrt(hubs.length));
  const cellWidth = width / (columns + 1);
  const cellHeight = height / (columns + 1);

  for (const [hubIndex, hub] of hubs.entries()) {
    hub.x = cellWidth * ((hubIndex % columns) + 1);
    hub.y = cellHeight * (Math.floor(hubIndex / columns) + 1);

    // Spread this hub's members around it like spokes on a wheel.
    const memberPrefix = `${hub.id} - Node`;
    const members = nodes.filter((candidate) =>
      candidate.id.startsWith(memberPrefix)
    );
    for (const [memberIndex, member] of members.entries()) {
      const theta = (2 * Math.PI * memberIndex) / members.length;
      member.x = hub.x + orbitRadius * Math.cos(theta);
      member.y = hub.y + orbitRadius * Math.sin(theta);
    }
  }

  const center = nodes.find((candidate) => candidate.id === "Central Node");
  if (!center) return;
  center.x = width / 2;
  center.y = height / 2;
};

/**
 * Visually selects `node`: draws animated curves to every highlighted node
 * that shares a tag with it, enlarges `node` and those connected nodes, and
 * fades all other nodes.
 *
 * @param {object} node - The selected node datum.
 * @param {{nodes: Array<object>}} data - Graph data; candidate targets are
 *   taken from `data.nodes`.
 * @param {object} gCurves - d3 selection that holds the `.curved-lines` group.
 * @param {object} nodes - d3 selection of the node <circle> elements.
 * @param {{current: *}} selectedTagRef - Ref to the currently selected tag.
 * @param {{current: number}} simulationIDRef - Ref to the latest curve
 *   simulation ID; bumped here so older curve animations stop.
 */
const showConnections = (
  node,
  data,
  gCurves,
  nodes,
  selectedTagRef,
  simulationIDRef
) => {
  // Increment simulation ID to cancel previous simulations
  simulationIDRef.current += 1;
  const currentSimulationID = simulationIDRef.current;

  const targetNodes = data.nodes.filter(
    (n) => n.highlight && nodesShareTag(node, n)
  );
  // Set lookup keeps the per-node styling below linear in the node count.
  const connectedNodes = new Set(targetNodes);

  if (targetNodes.length === 0) {
    console.log("No connections found for node", node.highlight);
    // Don't leave a previous selection's curves on screen as if they were
    // this node's connections.
    gCurves.selectAll(".curved-lines").remove();
  } else {
    drawCurvedLines(
      gCurves,
      node,
      targetNodes,
      selectedTagRef,
      simulationIDRef,
      currentSimulationID
    );
  }

  const isInFocus = (n) => n.id === node.id || connectedNodes.has(n);

  // Enlarge the current node and its connected nodes, fade others. This also
  // runs when there are no connections so the previous selection's styling
  // is replaced.
  nodes
    .transition()
    .duration(500)
    .attr("r", (n) => (isInFocus(n) ? SELECTED_NODE_RADIUS : NODE_RADIUS))
    .attr("opacity", (n) =>
      isInFocus(n) && nodeHasTag(n, selectedTagRef.current)
        ? NODE_OPACITY
        : FADED_NODE_OPACITY
    );
};

/**
 * Collects the points stored in `quadtree` that lie strictly within `radius`
 * of `point` (excluding `point` itself), skipping quadrants that cannot
 * contain any.
 *
 * Coordinates are read live from the point objects, while quadrant bounds
 * are those from when the quadtree was built. Points that have moved since
 * are therefore matched approximately, which is acceptable for the bundling
 * heuristic in drawCurvedLines.
 */
const collectNeighbors = (quadtree, point, radius) => {
  const neighbors = [];
  const radiusSquared = radius * radius;

  quadtree.visit((quad, x1, y1, x2, y2) => {
    if (!quad.length) {
      // Leaf node: points with identical coordinates are chained via `next`.
      for (let leaf = quad; leaf; leaf = leaf.next) {
        const other = leaf.data;
        if (other !== point) {
          const dx = other.x - point.x;
          const dy = other.y - point.y;
          if (dx * dx + dy * dy < radiusSquared) neighbors.push(other);
        }
      }
    }
    // Returning true skips this quadrant's children. Do so when its bounds
    // lie entirely outside the square enclosing the search circle.
    return (
      x1 > point.x + radius ||
      x2 < point.x - radius ||
      y1 > point.y + radius ||
      y2 < point.y - radius
    );
  });

  return neighbors;
};

/**
 * Draws one curve from `sourceNode` to each of `targetNodes`. The curves are
 * animated outward from the source, then a few steps of a simple
 * edge-bundling simulation pull nearby control points toward each other.
 *
 * @param {object} svg - d3 selection to draw into; any existing
 *   `.curved-lines` group inside it is removed first.
 * @param {object} sourceNode - Node with `x`/`y` the curves start from.
 * @param {Array<object>} targetNodes - Nodes with `x`/`y` the curves end at.
 * @param {{current: *}} selectedTagRef - Ref whose `.current` is passed to
 *   getCurveStyle.
 * @param {{current: number}} simulationIDRef - Ref holding the latest
 *   simulation ID.
 * @param {number} simulationID - ID of this run. Bundling stops at the next
 *   step once `simulationIDRef.current` no longer equals it.
 */
const drawCurvedLines = (
  svg,
  sourceNode,
  targetNodes,
  selectedTagRef,
  simulationIDRef,
  simulationID
) => {
  // Clear existing curves
  svg.selectAll(".curved-lines").remove();

  // Parameters for the simulation
  const MIN_SUBDIVISIONS = 2; // Minimum number of subdivisions for the shortest edges
  const MAX_SUBDIVISIONS = 7; // Maximum number of subdivisions for the longest edges
  const K = 7.1; // Edge bundling force constant
  const S = 0.3; // Step size for simulation
  const P = 0.1; // Original position attraction strength
  const maxIterations = 5; // Number of simulation steps
  const NEIGHBOR_RADIUS = 25; // Search radius for nearby control points

  // Find the minimum and maximum edge lengths
  let minEdgeLength = Infinity;
  let maxEdgeLength = 0;

  // Calculate edge lengths and initialize edges
  const edges = targetNodes.map((targetNode) => {
    const dx = targetNode.x - sourceNode.x;
    const dy = targetNode.y - sourceNode.y;
    const length = Math.sqrt(dx * dx + dy * dy);

    if (length < minEdgeLength) minEdgeLength = length;
    if (length > maxEdgeLength) maxEdgeLength = length;

    return {
      source: sourceNode,
      target: targetNode,
      length,
      points: [],
      sourceNode,
      targetNode,
    };
  });

  // Map an edge length linearly onto [MIN_SUBDIVISIONS, MAX_SUBDIVISIONS],
  // so longer edges get more control points; equal lengths use the minimum.
  const getSubdivisions = (edgeLength) => {
    if (maxEdgeLength === minEdgeLength) return MIN_SUBDIVISIONS;
    const normalizedLength =
      (edgeLength - minEdgeLength) / (maxEdgeLength - minEdgeLength);
    return Math.round(
      MIN_SUBDIVISIONS +
        normalizedLength * (MAX_SUBDIVISIONS - MIN_SUBDIVISIONS)
    );
  };

  // Initialize control points for each edge: `points` lie evenly on the
  // straight source-target segment (ox/oy remember that rest position), while
  // `initialPoints` all start at the source so the curve can grow out of it.
  edges.forEach((edge) => {
    const numSubdivisions = getSubdivisions(edge.length);

    const points = [];
    const initialPoints = [];

    for (let i = 0; i <= numSubdivisions + 1; i++) {
      const t = i / (numSubdivisions + 1);
      const x = edge.source.x * (1 - t) + edge.target.x * t;
      const y = edge.source.y * (1 - t) + edge.target.y * t;

      points.push({ x, y, ox: x, oy: y });
      initialPoints.push({
        x: edge.source.x,
        y: edge.source.y,
        ox: x,
        oy: y,
      });
    }
    edge.points = points;
    edge.initialPoints = initialPoints;
    edge.numSubdivisions = numSubdivisions;
  });

  const lineGenerator = d3
    .line()
    .curve(d3.curveBasis)
    .x((d) => d.x)
    .y((d) => d.y);

  // Draw the edges starting from the source node
  const edgePaths = svg
    .append("g")
    .attr("class", "curved-lines")
    .selectAll("path")
    .data(edges)
    .enter()
    .append("path")
    .attr("d", (d) => lineGenerator(d.initialPoints))
    .attr("stroke", (d) => getCurveStyle(d, selectedTagRef.current).stroke)
    .attr("stroke-width", (d) =>
      selectedTagRef.current
        ? getCurveStyle(d, selectedTagRef.current).strokeWidth
        : CURVE_STROKE_WIDTH
    )
    .attr("fill", "none")
    .attr("opacity", (d) => getCurveStyle(d, selectedTagRef.current).opacity);

  // transition.end() rejects when a transition is interrupted or cancelled,
  // e.g. when another transition is started on these paths. That only means
  // this animation was superseded, so stop quietly instead of raising an
  // unhandled promise rejection.
  const ignoreInterruption = () => {};

  let iteration = 0;

  // One bundling step. Nearby control points attract each other, each point
  // springs back toward its straight-line position, the paths animate to the
  // new shape, and the next step is scheduled until maxIterations is reached.
  function runSimulationStep() {
    // A newer selection has started its own simulation; abandon this one.
    if (simulationID !== simulationIDRef.current) {
      return;
    }

    iteration++;

    // Only interior control points move; endpoints stay on their nodes.
    const points = edges.flatMap((edge) => edge.points.slice(1, -1));

    const quadtree = d3.quadtree(
      points,
      (d) => d.x,
      (d) => d.y
    );

    points.forEach((point) => {
      collectNeighbors(quadtree, point, NEIGHBOR_RADIUS).forEach((other) => {
        const dx = other.x - point.x;
        const dy = other.y - point.y;
        const distanceSquared = dx * dx + dy * dy;
        if (distanceSquared > 0) {
          // force * (dx, dy) has length K * S regardless of distance.
          const force = (K * S) / Math.sqrt(distanceSquared);
          point.x += force * dx;
          point.y += force * dy;
        }
      });

      // Move point towards original position
      point.x += (point.ox - point.x) * P * S;
      point.y += (point.oy - point.y) * P * S;
    });

    edgePaths
      .transition()
      .duration(500)
      .attr("d", (d) => lineGenerator(d.points))
      .end()
      .then(() => {
        if (iteration < maxIterations) {
          requestAnimationFrame(runSimulationStep);
        }
      }, ignoreInterruption);
  }

  // Animate edges from the source node to their straight-line positions,
  // then start bundling.
  edgePaths
    .transition()
    .duration(1000)
    .attr("d", (d) => lineGenerator(d.points))
    .end()
    .then(() => {
      runSimulationStep();
    }, ignoreInterruption);
};

// Custom force to pull nodes towards their subnetwork centers
/**
 * Builds a custom d3 force (a function of `alpha`) that nudges velocities:
 * - subnetwork members (ids starting with "Subnetwork-" and containing
 *   "Node", e.g. "Subnetwork-2 - Node 5") toward their hub, i.e. the node
 *   whose id is the part before the first " - ";
 * - every other "Subnetwork-" node (the hubs) toward the node with id
 *   "Central Node".
 * The horizontal component is scaled by height / width, so x-pulls are
 * weaker than y-pulls when the area is wider than it is tall. Nodes whose
 * target does not exist are left untouched.
 *
 * @param {Array<object>} nodes - Simulation nodes (read `id`, `x`, `y`;
 *   update `vx`, `vy`).
 * @param {number} strength - Pull strength per unit of alpha.
 * @param {number} width - Layout width in px.
 * @param {number} height - Layout height in px.
 * @returns {(alpha: number) => void} Force function for d3.forceSimulation.
 */
const forcePullToSubnetworkCenter = (nodes, strength, width, height) => {
  const nodeById = new Map(nodes.map((d) => [d.id, d]));
  const widthHeightRatio = height / width;
  const centralNode = nodeById.get("Central Node");

  const pullTowards = (d, target, alpha) => {
    d.vx += (target.x - d.x) * strength * alpha * widthHeightRatio;
    d.vy += (target.y - d.y) * strength * alpha;
  };

  return (alpha) => {
    nodes.forEach((d) => {
      if (!d.id.startsWith("Subnetwork-")) return;
      // The member check must come first: member ids also start with
      // "Subnetwork-", so testing the hub case first would swallow them.
      if (d.id.includes("Node")) {
        const parentNode = nodeById.get(d.id.split(" - ")[0]);
        if (parentNode) pullTowards(d, parentNode, alpha);
      } else if (centralNode) {
        pullTowards(d, centralNode, alpha);
      }
    });
  };
};

/**
 * Force-directed "bubble" view of conversations and their highlights, drawn
 * with d3 into a single <svg>.
 *
 * Nodes come from generateNodesAndLinks(conversationIds,
 * conversationToHighlights); nodes grouped under "Subnetwork-<n>" ids get a
 * hull outline (presumably one subnetwork per conversation — confirm against
 * generateNodesAndLinks). Clicking a highlight node selects it and draws
 * curves to the highlighted nodes that share a tag with it; clicking the
 * background clears the selection.
 *
 * Props:
 * - conversationIds, conversationToHighlights: inputs to generateNodesAndLinks.
 * - highlights: only used as a dependency of the setup effect in this file.
 * - onNodeClick(highlight | null): called when the selection changes.
 * - autoconnectHighlights: optional highlight ids to select and zoom to one
 *   after another, resetting the view at the end.
 * - conversationNames: optional conversationId -> hull label text.
 * - selectedTag: tag used to style nodes and curves.
 * - showConversationLabels: whether hull labels are drawn.
 * - initiallySelectedHighlight: highlight id to select; re-applied when it
 *   changes.
 *
 * NOTE(review): the d3 setup runs only once (guarded by simulationRef) and
 * its closures capture props from that first render; see the inline notes.
 */
const SquishedBubbles = ({
  conversationIds,
  conversationToHighlights,
  highlights,
  onNodeClick,
  autoconnectHighlights,
  conversationNames,
  selectedTag,
  showConversationLabels,
  initiallySelectedHighlight,
}) => {
  const svgRef = useRef();
  // Currently selected node datum, or null. While set, dragging and the
  // periodic simulation restart are disabled.
  const selectedNodeRef = useRef(null);
  // Alpha the simulation idles at / is restarted with so nodes keep drifting.
  const targetAlpha = 0.01;
  const simulationRef = useRef();
  // dataRef / gCurvesRef / nodesRef expose objects built in the setup effect
  // to the other effects.
  const dataRef = useRef();
  const gCurvesRef = useRef();
  const nodesRef = useRef();
  const selectedTagRef = useRef(selectedTag);

  // Store the initial selection logic in a ref to prevent re-initialization
  const hasInitializedSelection = useRef(false);

  // Ref to manage simulation IDs (bumped to cancel in-flight curve animations)
  const simulationIDRef = useRef(0);

  // Assigned inside the setup effect below; the click, drag and timer
  // callbacks created there close over these bindings.
  let svg;
  let gCurves;

  // Update the selected tag ref
  useEffect(() => {
    selectedTagRef.current = selectedTag;
  }, [selectedTag]);

  // Selects `node`: notifies the parent, draws its connections, records it as
  // selected and freezes the layout.
  const handleNodeSelection = (
    node,
    data,
    gCurves,
    nodes,
    selectedTagRef,
    simulation,
    simulationIDRef
  ) => {
    onNodeClick(node.highlight);
    showConnections(
      node,
      data,
      gCurves,
      nodes,
      selectedTagRef,
      simulationIDRef
    );
    selectedNodeRef.current = node;
    simulation.alpha(0).stop(); // Stop the simulation immediately
  };

  // One-time d3 setup: builds the graph, the force simulation, click/drag/zoom
  // handlers and the optional autoconnect cycle.
  // NOTE(review): because of the simulationRef guard, later changes to the
  // dependencies below do not rebuild anything. The cleanup is only returned
  // from the first run; React runs it before re-running the effect on a
  // dependency change, so the periodic-restart interval is cleared and never
  // recreated. The autoconnect setTimeout chain and the simulation itself are
  // never stopped on unmount.
  useEffect(() => {
    if (!simulationRef.current) {
      const data = generateNodesAndLinks(
        conversationIds,
        conversationToHighlights
      );
      svg = d3.select(svgRef.current);

      // Remove any existing nodes or links
      svg.selectAll("*").remove();

      // Size the SVG to its parent element at setup time (not updated on resize).
      const width = svg.node().parentNode.clientWidth;
      const height = svg.node().parentNode.clientHeight;
      svg.style("background-color", "black");
      svg.attr("width", width).attr("height", height);

      initializeNodePositions(data.nodes, width, height);

      // Layer order: hulls, then curves, then nodes, then labels (appended below).
      const hullsGroup = svg.append("g").attr("class", "hulls");
      gCurves = svg.append("g").attr("class", "curves");

      const nodes = svg
        .append("g")
        .attr("class", "nodes")
        .selectAll("circle")
        .data(data.nodes.filter((d) => d.highlight !== undefined)) // Only show nodes with highlights
        .enter()
        .append("circle")
        .attr("r", NODE_RADIUS)
        .attr("fill", (d) => (d.highlight ? d.highlight.color : "black"))
        .on("click", (event, d) => {
          event.stopPropagation(); // Prevent background click event from firing
          handleNodeSelection(
            d,
            data,
            gCurves,
            nodes,
            selectedTagRef,
            simulationRef.current,
            simulationIDRef
          );
        })
        // Dragging pins the node (fx/fy) while heating the simulation; it is
        // disabled while a node is selected.
        .call(
          d3
            .drag()
            .on("start", (event, d) => {
              if (selectedNodeRef.current) return;
              if (!event.active)
                simulationRef.current.alphaTarget(0.3).restart();
              d.fx = d.x;
              d.fy = d.y;
            })
            .on("drag", (event, d) => {
              if (selectedNodeRef.current) return;
              d.fx = event.x;
              d.fy = event.y;
            })
            .on("end", (event, d) => {
              if (selectedNodeRef.current) return;
              if (!event.active) simulationRef.current.alphaTarget(targetAlpha);
              d.fx = null;
              d.fy = null;
            })
        );

      // Store variables in refs for access in other hooks
      dataRef.current = data;
      gCurvesRef.current = gCurves;
      nodesRef.current = nodes;

      const networkLabelsGroup = svg
        .append("g")
        .attr("class", "conversation-labels");

      // Ids of the "Central Node" and the subnetwork hubs; they get no
      // many-body charge.
      const centralNodes = new Set(
        data.nodes
          .filter((d) => /^Central Node|^Subnetwork-\d+$/.test(d.id))
          .map((d) => d.id)
      );

      const simulation = d3
        .forceSimulation(data.nodes)
        .force(
          "link",
          d3
            .forceLink(data.links)
            .id((d) => d.id)
            .distance(10)
        )
        .force(
          "charge",
          d3
            .forceManyBody()
            // Repulsion weakens as the number of conversations grows.
            .strength((d) =>
              centralNodes.has(d.id)
                ? 0
                : Math.sqrt(400 / conversationIds.length) * -10
            )
        )
        .force("center", d3.forceCenter(width / 2, height / 2))
        .force("collide", d3.forceCollide().radius(8))
        .force(
          "pull",
          forcePullToSubnetworkCenter(data.nodes, 0.05, width, height)
        ) // Custom force to pull nodes to subnetwork centers
        .alpha(1)
        .alphaDecay(0.02)
        .alphaTarget(targetAlpha);

      // On every tick: move the circles, then regroup nodes by their
      // "Subnetwork-<n>" id prefix and redraw every hull (and label) from scratch.
      // NOTE(review): conversationNames and showConversationLabels are read
      // from the props of the render in which this effect ran; later changes
      // are not reflected.
      simulation.on("tick", () => {
        nodes.attr("cx", (d) => d.x).attr("cy", (d) => d.y);

        // Draw network outline for each subnetwork
        const subnetworks = {};
        data.nodes.forEach((d) => {
          const match = d.id.match(/^Subnetwork-(\d+)/);
          if (match) {
            const subnetworkId = match[0];
            if (!subnetworks[subnetworkId]) {
              subnetworks[subnetworkId] = [];
            }
            subnetworks[subnetworkId].push(d);
          }
        });

        hullsGroup.selectAll("*").remove();
        Object.keys(subnetworks).forEach((subnetworkId) => {
          const subnetworkNodes = subnetworks[subnetworkId];
          drawNetworkOutline(
            hullsGroup,
            networkLabelsGroup,
            subnetworkNodes,
            subnetworkId,
            conversationNames,
            showConversationLabels
          );
        });
      });

      simulationRef.current = simulation;

      // Periodically restart the simulation to keep the nodes moving
      const interval = setInterval(() => {
        if (!selectedNodeRef.current) {
          simulationRef.current.alpha(targetAlpha).restart();
        }
      }, 5000);

      // Background click clears the selection.
      // NOTE(review): onNodeClick here (and inside handleNodeSelection as
      // captured by these closures) is the one from the render in which this
      // effect ran.
      svg.on("click", () => {
        // Increment simulation ID to cancel previous simulations
        simulationIDRef.current += 1;

        selectedNodeRef.current = null;
        onNodeClick(null);
        // Clear connections and reset node and link styles
        gCurves.selectAll(".curved-lines").remove();

        nodes
          .transition()
          .duration(500)
          .attr("r", NODE_RADIUS)
          .attr("opacity", NODE_OPACITY);

        // Restart simulation with target alpha
        simulationRef.current.alpha(targetAlpha).restart();
      });

      // Autoconnect highlights if provided
      if (autoconnectHighlights && autoconnectHighlights.length > 0) {
        // NOTE(review): the zoom transform is written to the <svg> element
        // itself rather than to an inner <g>; confirm this is intended.
        const zoom = d3
          .zoom()
          .scaleExtent([0.5, 10])
          .on("zoom", (event) => {
            svg.attr("transform", event.transform);
          });

        // Apply the zoom behavior to the SVG
        svg.call(zoom);

        // Cycle: every highlight gets ~6s (select/zoom, then wait); after the
        // last one the selection and zoom are reset. None of these timeouts
        // are cancelled on unmount.
        let index = 0;
        const runHighlightCycle = () => {
          if (index >= autoconnectHighlights.length) {
            // If we've reached the end of the array, reset all
            gCurves.selectAll(".curved-lines").remove();
            nodes
              .transition()
              .duration(500)
              .attr("r", NODE_RADIUS)
              .attr("opacity", NODE_OPACITY);
            svg
              .transition()
              .duration(1500)
              .call(zoom.transform, d3.zoomIdentity);
            simulationRef.current.alpha(targetAlpha).restart();
            return;
          }

          const highlightId = autoconnectHighlights[index];
          console.log("Highlight ID:", highlightId);
          const node = data.nodes.find(
            (n) => n.highlight && n.highlight.id === highlightId
          );
          if (node) {
            if (selectedNodeRef.current) {
              // Increment simulation ID to cancel previous simulations
              simulationIDRef.current += 1;

              gCurves.selectAll(".curved-lines").remove();
              nodes
                .transition()
                .duration(500)
                .attr("r", NODE_RADIUS)
                .attr("opacity", NODE_OPACITY);
            }
            // Restart simulation to allow nodes to move
            simulationRef.current.alpha(1).restart();
            setTimeout(() => {
              // Stop simulation and show connections after 2 seconds
              simulationRef.current.alpha(0).stop();
              handleNodeSelection(
                node,
                data,
                gCurves,
                nodes,
                selectedTagRef,
                simulationRef.current,
                simulationIDRef
              );
            }, 2000); // Allow nodes to move for 2 seconds
            setTimeout(() => {
              // Apply zoom to the selected node. This fires at 1s, while
              // nodes are still moving until the 2s freeze above.
              // NOTE(review): translate(...) followed by scale(2) maps the
              // node to (node.x + width / 2, node.y + height / 2) rather than
              // the centre; verify the intended centring.
              const transform = d3.zoomIdentity
                .translate(width / 2 - node.x, height / 2 - node.y)
                .scale(2);
              svg.transition().duration(1000).call(zoom.transform, transform);

              index++;
              setTimeout(runHighlightCycle, 5000); // Schedule next execution
            }, 1000);
          } else {
            console.warn(`Highlight ${highlightId} not found`);
            simulationRef.current.alphaTarget(targetAlpha).restart();
            index++;
            setTimeout(runHighlightCycle, 5000); // Schedule next execution
          }
        };

        // Start the cycle
        setTimeout(runHighlightCycle, 5000);
      }

      // Apply the initial selection once.
      if (initiallySelectedHighlight && !hasInitializedSelection.current) {
        const node = data.nodes.find(
          (n) => n.highlight && n.highlight.id === initiallySelectedHighlight
        );
        if (node) {
          handleNodeSelection(
            node,
            data,
            gCurves,
            nodes,
            selectedTagRef,
            simulationRef.current,
            simulationIDRef
          );
          hasInitializedSelection.current = true;
        }
      }

      // Clear interval on unmount
      return () => clearInterval(interval);
    }
  }, [
    conversationIds,
    conversationToHighlights,
    highlights,
    autoconnectHighlights,
  ]);

  // Re-applies the selection when initiallySelectedHighlight changes: selects
  // the matching node, or clears the selection and reheats the simulation when
  // it becomes falsy.
  // NOTE(review): this also runs on mount right after the setup effect, so an
  // initial highlight is selected a second time, and with no initial
  // highlight it calls onNodeClick(null) and restarts at alpha 1.
  useEffect(() => {
    if (simulationRef.current && dataRef.current) {
      if (initiallySelectedHighlight) {
        const node = dataRef.current.nodes.find(
          (n) => n.highlight && n.highlight.id === initiallySelectedHighlight
        );
        if (node) {
          handleNodeSelection(
            node,
            dataRef.current,
            gCurvesRef.current,
            nodesRef.current,
            selectedTagRef,
            simulationRef.current,
            simulationIDRef
          );
        }
      } else {
        // Increment simulation ID to cancel previous simulations
        simulationIDRef.current += 1;

        selectedNodeRef.current = null;
        onNodeClick(null);
        // Clear connections and reset node and link styles
        gCurvesRef.current.selectAll(".curved-lines").remove();

        nodesRef.current
          .transition()
          .duration(500)
          .attr("r", NODE_RADIUS)
          .attr("opacity", NODE_OPACITY);

        simulationRef.current.alpha(1).restart();
      }
    }
  }, [initiallySelectedHighlight]);

  // Restyles all circles and curves for the selected tag.
  // NOTE(review): nothing is reset when selectedTag becomes falsy, and the
  // node opacities set here ignore the current node selection. Starting a
  // transition on the curve paths may also interrupt their running "d"
  // animation.
  useEffect(() => {
    if (selectedTag) {
      const highlightNodes = d3.select(svgRef.current).selectAll("circle");
      const curves = d3.select(svgRef.current).selectAll(".curved-lines path");

      // Highlight nodes based on selected tag
      highlightNodes
        .transition()
        .attr("opacity", (d) =>
          nodeHasTag(d, selectedTag) ? NODE_OPACITY : NODE_HIGHLIGHT_OPACITY
        );

      // Update curves based on selected tag
      curves.transition().each(function (d) {
        const style = getCurveStyle(d, selectedTag);
        d3.select(this)
          .attr("stroke", style.stroke)
          .attr("opacity", style.opacity)
          .attr("stroke-width", style.strokeWidth);
      });
    }
  }, [selectedTag]);

  return <svg ref={svgRef}></svg>;
};

export default SquishedBubbles;
