import React, { useEffect, useRef } from "react";
import * as d3 from "d3";

// Styling constants used by the drawing helpers and the component below.

// --- Node (circle) styling ---
const NODE_RADIUS = 5; // default circle radius
const SELECTED_NODE_RADIUS = 8; // radius of a clicked node and the nodes it connects to
const NODE_OPACITY = 1; // fully visible node
const FADED_NODE_OPACITY = 0.3; // node outside the current selection
const NODE_HIGHLIGHT_OPACITY = 0.3; // node lacking the currently selected tag

// --- Curve (connection) styling ---
const CURVE_COLOR = "white"; // stroke of an emphasised connection
const FADED_CURVE_COLOR = "gray"; // stroke of a de-emphasised connection
const CURVE_STROKE_WIDTH = 1; // default connection width
const THICK_CURVE_STROKE_WIDTH = 3; // width when both endpoints share the selected tag
const CURVE_OPACITY = 0.3; // de-emphasised connection opacity
const HIGHLIGHTED_CURVE_OPACITY = 0.7; // emphasised connection opacity

/**
 * Builds the (deterministic) node/link graph for the visualisation.
 *
 * Produces one "Central Node", one `Subnetwork-<i>` node per conversation
 * (linked to the central node and tagged with its `conversationId`), and one
 * `Subnetwork-<i> - Node <j>` node per highlight of that conversation (linked
 * to its subnetwork node and carrying the highlight object).
 *
 * @param {Array} conversationIds - Conversation ids, one subnetwork each, in order.
 * @param {Object} conversationToHighlights - Lookup from conversation id to that
 *   conversation's highlights. A conversation with no entry gets an empty subnetwork.
 * @returns {{nodes: Object[], links: {source: string, target: string}[]}}
 */
export const generateNodesAndLinks = (
  conversationIds,
  conversationToHighlights
) => {
  const nodes = [{ id: "Central Node" }];
  const links = [];

  for (let i = 0; i < conversationIds.length; i++) {
    const conversationId = conversationIds[i];
    const subnetworkId = `Subnetwork-${i}`;
    nodes.push({ id: subnetworkId, conversationId });
    links.push({ source: "Central Node", target: subnetworkId });

    // A missing entry used to crash on `.length`; treat it as "no highlights".
    const highlightsInSubnetwork =
      conversationToHighlights?.[conversationId] ?? [];

    for (let j = 0; j < highlightsInSubnetwork.length; j++) {
      const nodeId = `${subnetworkId} - Node ${j}`;
      nodes.push({ id: nodeId, highlight: highlightsInSubnetwork[j] });
      links.push({ source: subnetworkId, target: nodeId });
    }
  }

  return { nodes, links };
};

/**
 * Draws a smoothed, padded outline around one subnetwork and, optionally, its
 * conversation label. Any outline/label previously drawn for the same `label`
 * is removed first, so this can be called on every simulation tick.
 *
 * @param g - D3 selection the outline `<path>` is appended to.
 * @param networkLabelsGroup - D3 selection the label `<text>` lines are appended to.
 * @param nodes - The subnetwork's simulation nodes (positions read from `x`/`y`);
 *   `nodes[0].conversationId` selects the label text.
 * @param label - Subnetwork key, used to build the CSS class names.
 * @param conversationNames - Optional lookup from conversation id to display
 *   name; a "\n" in a name starts a new label line.
 * @param showConversationLabels - Whether to draw the label.
 */
const drawNetworkOutline = (
  g,
  networkLabelsGroup,
  nodes,
  label,
  conversationNames,
  showConversationLabels
) => {
  const expansionFactor = 20; // distance each hull vertex is pushed outward
  const labelLineHeight = 15; // vertical gap between label lines
  const labelLiftPerExtraLine = 10; // upward shift of the label per extra line

  // Remove this subnetwork's existing outline and label.
  g.selectAll(`.network-outline-${label}`).remove();
  networkLabelsGroup.selectAll(`.network-label-${label}`).remove();

  // d3.polygonHull returns null when given fewer than three points.
  const hull = d3.polygonHull(nodes.map((d) => [d.x, d.y]));
  if (!hull) return;

  // polygonCentroid divides by the signed area, which is 0 for a collinear
  // hull and yields NaN; fall back to the mean of the hull vertices then.
  let [cx, cy] = d3.polygonCentroid(hull);
  if (!Number.isFinite(cx) || !Number.isFinite(cy)) {
    cx = hull.reduce((sum, [x]) => sum + x, 0) / hull.length;
    cy = hull.reduce((sum, [, y]) => sum + y, 0) / hull.length;
  }

  // Push every hull vertex `expansionFactor` further away from the centre.
  const expandedHull = hull.map(([x, y]) => {
    const dx = x - cx;
    const dy = y - cy;
    const dist = Math.sqrt(dx * dx + dy * dy);
    // A vertex on the centre has no outward direction; keep it rather than divide by zero.
    if (dist === 0) return [x, y];
    const scale = (dist + expansionFactor) / dist;
    return [cx + dx * scale, cy + dy * scale];
  });

  // Smooth closed curve through the expanded hull.
  const outlinePathGenerator = d3
    .line()
    .x((d) => d[0])
    .y((d) => d[1])
    .curve(d3.curveBasisClosed);

  g.append("path")
    .datum(expandedHull)
    .attr("class", `network-outline network-outline-${label}`)
    .attr("d", outlinePathGenerator)
    .attr("fill", "white")
    .attr("fill-opacity", 0.15);

  if (!showConversationLabels) return;

  const { conversationId } = nodes[0];
  const networkName =
    conversationNames?.[conversationId] ?? `Conversation ${conversationId}`;

  // Vertex with the smallest y, i.e. the visually highest point of the outline.
  const top = expandedHull.reduce((acc, cur) => (acc[1] > cur[1] ? cur : acc));

  // The label is centred horizontally on the hull and anchored at its top.
  // Multi-line names are shifted up per extra line, and the lines are spaced
  // `labelLineHeight` apart.
  const networkNameLines = networkName.split("\n");
  networkNameLines.forEach((token, index) => {
    networkLabelsGroup
      .append("text")
      .attr("class", `network-label network-label-${label}`)
      .attr("x", cx)
      .attr(
        "y",
        top[1] +
          index * labelLineHeight -
          (networkNameLines.length - 1) * labelLiftPerExtraLine
      )
      .attr("text-anchor", "middle")
      .attr("fill", "white")
      .attr("opacity", 0.8)
      .attr("font-size", "10px")
      .attr("font-weight", "bold")
      .text(token);
  });
};

/**
 * Seeds starting positions by mutating `x`/`y` on the given node objects.
 *
 * - `Subnetwork-<i>` centres are laid out on a grid of ceil(sqrt(count))
 *   columns spanning `width` x `height`, inset by one grid cell on each side.
 * - Each centre's `Subnetwork-<i> - Node ...` children are spaced evenly on a
 *   circle of `radius` around it.
 * - "Central Node" (if present) is placed in the middle of the area.
 *
 * @param {Object[]} nodes - Nodes as produced by generateNodesAndLinks.
 * @param {number} width - Layout width.
 * @param {number} height - Layout height.
 * @param {number} [radius=50] - Distance of child nodes from their centre.
 */
const initializeNodePositions = (nodes, width, height, radius = 50) => {
  const centerPattern = /^Subnetwork-\d+$/;
  const childPattern = /^(Subnetwork-\d+) - Node/;

  const centralNodes = nodes.filter((d) => centerPattern.test(d.id));

  // Group children by their centre once, instead of re-filtering every node
  // for every centre (which was O(centres x nodes)).
  const childrenByCenter = new Map();
  for (const d of nodes) {
    const match = childPattern.exec(d.id);
    if (!match) continue;
    const siblings = childrenByCenter.get(match[1]);
    if (siblings) {
      siblings.push(d);
    } else {
      childrenByCenter.set(match[1], [d]);
    }
  }

  const gridSize = Math.ceil(Math.sqrt(centralNodes.length));
  const spacingX = width / (gridSize + 1);
  const spacingY = height / (gridSize + 1);

  centralNodes.forEach((center, index) => {
    const row = Math.floor(index / gridSize);
    const col = index % gridSize;
    center.x = spacingX * (col + 1);
    center.y = spacingY * (row + 1);

    // Place this subnetwork's nodes evenly around its centre.
    const children = childrenByCenter.get(center.id) ?? [];
    children.forEach((child, childIndex) => {
      const angle = (2 * Math.PI * childIndex) / children.length;
      child.x = center.x + radius * Math.cos(angle);
      child.y = center.y + radius * Math.sin(angle);
    });
  });

  const centralNode = nodes.find((d) => d.id === "Central Node");
  if (centralNode) {
    centralNode.x = width / 2;
    centralNode.y = height / 2;
  }
};

/**
 * Shows the connections of a clicked highlight node.
 *
 * Collects every highlight node sharing at least one tag with `node`, which
 * includes `node` itself. It draws curves to them and, with a 500 ms
 * transition, enlarges them and `node` while shrinking and fading every other
 * circle. When `selectedTag` is set, enlarged nodes lacking that tag are faded
 * too. If nothing shares a tag with `node`, it logs and does nothing.
 *
 * @param node - Clicked node; must carry `highlight.tags`.
 * @param data - `{ nodes }` graph as produced by generateNodesAndLinks.
 * @param gCurves - D3 selection that receives the curves.
 * @param nodes - D3 selection of all rendered circles.
 * @param selectedTag - Currently selected tag, or a falsy value for none.
 */
const showConnections = (node, data, gCurves, nodes, selectedTag) => {
  const nodeTags = new Set(node.highlight.tags);

  // `some` stops at the first shared tag, so each node is collected at most once.
  // (The previous forEach + `return` kept iterating and pushed a node once per
  // shared tag, which produced duplicate curves.)
  const targetNodes = data.nodes.filter((n) =>
    n.highlight?.tags?.some((tag) => nodeTags.has(tag))
  );

  if (targetNodes.length === 0) {
    console.log("No connections found for node", node.highlight);
    return;
  }

  drawCurvedLines(gCurves, node, targetNodes, selectedTag);

  const connected = new Set(targetNodes);
  const isConnected = (n) => n.id === node.id || connected.has(n);

  // Enlarge the current node and its connected nodes, fade others.
  nodes
    .transition()
    .duration(500)
    .attr("r", (n) => (isConnected(n) ? SELECTED_NODE_RADIUS : NODE_RADIUS))
    .attr("opacity", (n) => {
      if (!isConnected(n)) return FADED_NODE_OPACITY;
      return !selectedTag || n.highlight?.tags?.includes(selectedTag)
        ? NODE_OPACITY
        : FADED_NODE_OPACITY;
    });
};

/**
 * Replaces any curves inside `svg` with one animated horizontal link from
 * `sourceNode` to each node in `targetNodes`.
 *
 * Styling:
 * - With no `selectedTag`, every link is CURVE_COLOR, thin, and at
 *   HIGHLIGHTED_CURVE_OPACITY.
 * - With a tag, links whose two endpoints both carry it are CURVE_COLOR,
 *   thick, and at HIGHLIGHTED_CURVE_OPACITY. All other links are
 *   FADED_CURVE_COLOR, thin, and at CURVE_OPACITY.
 *
 * Each link is revealed by a 2 s stroke-dashoffset animation.
 *
 * @param svg - D3 selection to draw into (a fresh `g.curved-lines` is appended).
 * @param sourceNode - Node with `x`, `y` and `highlight.tags`.
 * @param targetNodes - Nodes with `x`, `y` and `highlight.tags`.
 * @param selectedTag - Tag to emphasise, or a falsy value for none.
 */
const drawCurvedLines = (svg, sourceNode, targetNodes, selectedTag) => {
  // Clear existing curves.
  svg.selectAll(".curved-lines").remove();

  const linkHorizontal = d3
    .linkHorizontal()
    .x((d) => d.x)
    .y((d) => d.y);

  // True when a tag is selected and both link endpoints carry it.
  const sharesSelectedTag = ({ source, target }) =>
    Boolean(selectedTag) &&
    Boolean(source.highlight?.tags?.includes(selectedTag)) &&
    Boolean(target.highlight?.tags?.includes(selectedTag));
  // Links keep full colour/opacity when no tag filter is active or they match it.
  const isEmphasized = (link) => !selectedTag || sharesSelectedTag(link);

  const paths = targetNodes.map((targetNode) => ({
    source: sourceNode,
    target: targetNode,
  }));

  svg
    .append("g")
    .attr("class", "curved-lines")
    .selectAll("path")
    .data(paths)
    .enter()
    .append("path")
    .attr("d", linkHorizontal)
    .attr("stroke", (d) => (isEmphasized(d) ? CURVE_COLOR : FADED_CURVE_COLOR))
    .attr("stroke-width", (d) =>
      sharesSelectedTag(d) ? THICK_CURVE_STROKE_WIDTH : CURVE_STROKE_WIDTH
    )
    .attr("fill", "none")
    .attr("opacity", (d) =>
      isEmphasized(d) ? HIGHLIGHTED_CURVE_OPACITY : CURVE_OPACITY
    )
    // Draw-on animation: one dash as long as the path, offset fully, then
    // slide the offset to 0. `function` (not arrow) so `this` is the <path>.
    .attr("stroke-dasharray", function () {
      return this.getTotalLength();
    })
    .attr("stroke-dashoffset", function () {
      return this.getTotalLength();
    })
    .transition()
    .duration(2000)
    .attr("stroke-dashoffset", 0);
};

/**
 * Custom D3 force: pulls each `Subnetwork-<i>` centre towards "Central Node",
 * and each `Subnetwork-<i> - Node ...` highlight node towards its own centre.
 *
 * The returned function has the d3-force custom-force shape `force(alpha)` and
 * adjusts `vx`/`vy` on the `nodes` passed here. Horizontal pull is scaled by
 * height / width (1 when width is not positive, to avoid Infinity).
 *
 * The original code tested `startsWith("Subnetwork-")` first, which also
 * matched highlight nodes. That made the parent-pull branch unreachable and
 * pulled highlights to the central node instead.
 *
 * @param {Object[]} nodes - Simulation nodes (with `id`, `x`, `y`, `vx`, `vy`).
 * @param {number} strength - Pull strength multiplier.
 * @param {number} width - Layout width.
 * @param {number} height - Layout height.
 * @returns {(alpha: number) => void}
 */
const forcePullToSubnetworkCenter = (nodes, strength, width, height) => {
  const nodeById = new Map(nodes.map((d) => [d.id, d]));
  // Node objects are stable for the lifetime of the force; only x/y change.
  const centralNode = nodeById.get("Central Node");
  const widthHeightRatio = width > 0 ? height / width : 1;

  return (alpha) => {
    for (const d of nodes) {
      if (!d.id.startsWith("Subnetwork-")) continue;

      // Highlight nodes ("Subnetwork-<i> - Node <j>") go to their centre;
      // centres themselves go to the central node.
      const attractor = d.id.includes("Node")
        ? nodeById.get(d.id.split(" - ")[0])
        : centralNode;
      if (!attractor) continue;

      d.vx += (attractor.x - d.x) * strength * alpha * widthHeightRatio;
      d.vy += (attractor.y - d.y) * strength * alpha;
    }
  };
};

const SquishedBubbles = ({
  conversationIds,
  conversationToHighlights,
  highlights,
  onNodeClick,
  autoconnectHighlights,
  conversationNames,
  selectedTag,
  showConversationLabels,
}) => {
  const svgRef = useRef();
  const selectedNodeRef = useRef(null);
  const targetAlpha = 0.01;

  // ref for the selected tag
  const selectedTagRef = useRef(selectedTag);

  let svg;
  let gCurves;

  // Update the selected tag ref
  useEffect(() => {
    selectedTagRef.current = selectedTag;
  }, [selectedTag]);

  useEffect(() => {
    if (selectedTag) {
      const highlightNodes = d3.select(svgRef.current).selectAll("circle");
      const curves = d3.select(svgRef.current).selectAll(".curved-lines path");

      // Highlight nodes and thicken curves based on selected tag
      highlightNodes
        .transition()
        .attr("opacity", (d) =>
          d.highlight && d.highlight.tags.includes(selectedTag)
            ? NODE_OPACITY
            : NODE_HIGHLIGHT_OPACITY
        );

      curves
        .transition()
        .attr("stroke", (d) => {
          const sourceTags = new Set(d.source.highlight.tags);
          const targetTags = new Set(d.target.highlight.tags);
          return sourceTags.has(selectedTag) && targetTags.has(selectedTag)
            ? CURVE_COLOR
            : FADED_CURVE_COLOR;
        })
        .attr("opacity", (d) => {
          const sourceTags = new Set(d.source.highlight.tags);
          const targetTags = new Set(d.target.highlight.tags);
          return sourceTags.has(selectedTag) && targetTags.has(selectedTag)
            ? HIGHLIGHTED_CURVE_OPACITY
            : CURVE_OPACITY;
        })
        .attr("stroke-width", (d) => {
          const sourceTags = new Set(d.source.highlight.tags);
          const targetTags = new Set(d.target.highlight.tags);
          return sourceTags.has(selectedTag) && targetTags.has(selectedTag)
            ? THICK_CURVE_STROKE_WIDTH
            : CURVE_STROKE_WIDTH;
        });
    }
  }, [selectedTag]);

  useEffect(() => {
    const data = generateNodesAndLinks(
      conversationIds,
      conversationToHighlights
    );
    svg = d3.select(svgRef.current);

    // Remove any existing nodes or links
    svg.selectAll("*").remove();

    const width = svg.node().parentNode.clientWidth;
    const height = svg.node().parentNode.clientHeight;
    svg.style("background-color", "black");
    svg.attr("width", width).attr("height", height);

    initializeNodePositions(data.nodes, width, height);

    const hullsGroup = svg.append("g").attr("class", "hulls");
    gCurves = svg.append("g").attr("class", "curves");

    const nodes = svg
      .append("g")
      .attr("class", "nodes")
      .selectAll("circle")
      .data(data.nodes.filter((d) => d.highlight !== undefined)) // Only show nodes with highlights
      .enter()
      .append("circle")
      .attr("r", NODE_RADIUS)
      .attr("fill", (d) => (d.highlight ? d.highlight.color : "black"))
      .on("click", (event, d) => {
        event.stopPropagation(); // Prevent background click event from firing
        onNodeClick(d.highlight);
        showConnections(d, data, gCurves, nodes, selectedTagRef.current);
        selectedNodeRef.current = d;
        simulation.alpha(0).stop(); // Stop the simulation immediately
      })
      .call(
        d3
          .drag()
          .on("start", (event, d) => {
            if (selectedNodeRef.current) return;
            if (!event.active) simulation.alphaTarget(0.3).restart();
            d.fx = d.x;
            d.fy = d.y;
          })
          .on("drag", (event, d) => {
            if (selectedNodeRef.current) return;
            d.fx = event.x;
            d.fy = event.y;
          })
          .on("end", (event, d) => {
            if (selectedNodeRef.current) return;
            if (!event.active) simulation.alphaTarget(targetAlpha);
            d.fx = null;
            d.fy = null;
          })
      );

    const networkLabelsGroup = svg
      .append("g")
      .attr("class", "conversation-labels");

    const centralNodes = new Set(
      data.nodes
        .filter((d) => /^Central Node|^Subnetwork-\d+$/.test(d.id))
        .map((d) => d.id)
    );

    const simulation = d3
      .forceSimulation(data.nodes)
      .force(
        "link",
        d3
          .forceLink(data.links)
          .id((d) => d.id)
          .distance(10)
      )
      .force(
        "charge",
        d3
          .forceManyBody()
          .strength((d) =>
            centralNodes.has(d.id)
              ? 0
              : Math.sqrt(400 / conversationIds.length) * -10
          )
      )
      .force("center", d3.forceCenter(width / 2, height / 2))
      .force("collide", d3.forceCollide().radius(8))
      .force(
        "pull",
        forcePullToSubnetworkCenter(data.nodes, 0.05, width, height)
      ) // Custom force to pull nodes to subnetwork centers
      .alphaTarget(targetAlpha);

    simulation.on("tick", () => {
      nodes.attr("cx", (d) => d.x).attr("cy", (d) => d.y);

      // Draw network outline for each subnetwork
      const subnetworks = {};
      data.nodes.forEach((d) => {
        const match = d.id.match(/^Subnetwork-(\d+)/);
        if (match) {
          const subnetworkId = match[0];
          if (!subnetworks[subnetworkId]) {
            subnetworks[subnetworkId] = [];
          }
          subnetworks[subnetworkId].push(d);
        }
      });

      hullsGroup.selectAll("*").remove();
      Object.keys(subnetworks).forEach((subnetworkId) => {
        const subnetworkNodes = subnetworks[subnetworkId];
        drawNetworkOutline(
          hullsGroup,
          networkLabelsGroup,
          subnetworkNodes,
          subnetworkId,
          conversationNames,
          showConversationLabels
        );
      });
    });

    // Periodically restart the simulation to keep the nodes moving
    const interval = setInterval(() => {
      if (!selectedNodeRef.current) {
        simulation.alpha(targetAlpha).restart();
      }
    }, 5000);

    const handleMouseStop = () => {
      if (!selectedNodeRef.current) {
        simulation.alphaTarget(targetAlpha).restart();
      }
    };

    svg.on("click", () => {
      selectedNodeRef.current = null;
      onNodeClick(null);
      // Clear connections and reset node and link styles
      gCurves.selectAll(".curved-lines").remove();

      nodes
        .transition()
        .duration(500)
        .attr("r", NODE_RADIUS)
        .attr("opacity", NODE_OPACITY);

      handleMouseStop();
    });

    // Autoconnect highlights if provided
    if (autoconnectHighlights && autoconnectHighlights.length > 0) {
      // Define the zoom behavior
      const zoom = d3
        .zoom()
        .scaleExtent([0.5, 10])
        .on("zoom", (event) => {
          svg.attr("transform", event.transform);
        });

      // Apply the zoom behavior to the SVG
      svg.call(zoom);

      let index = 0;
      const runHighlightCycle = () => {
        if (index >= autoconnectHighlights.length) {
          // If we've reached the end of the array, reset all
          gCurves.selectAll(".curved-lines").remove();
          nodes
            .transition()
            .duration(500)
            .attr("r", NODE_RADIUS)
            .attr("opacity", NODE_OPACITY);
          svg.transition().duration(1500).call(zoom.transform, d3.zoomIdentity);
          simulation.alphaTarget(targetAlpha).restart();
          return;
        }

        const highlightId = autoconnectHighlights[index];
        console.log("Highlight ID:", highlightId);
        const node = data.nodes.find(
          (n) => n.highlight && n.highlight.id === highlightId
        );
        if (node) {
          if (selectedNodeRef.current) {
            gCurves.selectAll(".curved-lines").remove();
            nodes
              .transition()
              .duration(500)
              .attr("r", NODE_RADIUS)
              .attr("opacity", NODE_OPACITY);
          }
          // restart simulation to allow nodes to move
          simulation.alpha(targetAlpha).restart();
          setTimeout(() => {
            // Stop simulation and show connections after 1 second
            simulation.alpha(0).stop();
            showConnections(node, data, gCurves, nodes, selectedTag);
            selectedNodeRef.current = node;
          }, 2000); // Allow nodes to move for 1 second
          setTimeout(() => {
            // Apply zoom to the selected node
            const transform = d3.zoomIdentity
              .translate(width / 2 - node.x, height / 2 - node.y)
              .scale(2); // Adjust the scale factor as needed
            svg.transition().duration(1000).call(zoom.transform, transform);

            index++;
            setTimeout(runHighlightCycle, 5000); // Schedule next execution
          }, 1000);
        } else {
          console.warn(`Highlight ${highlightId} not found`);
          simulation.alphaTarget(targetAlpha).restart();
          index++;
          setTimeout(runHighlightCycle, 5000); // Schedule next execution
        }
      };

      // Start the cycle
      setTimeout(runHighlightCycle, 5000);
    }

    // Clear interval on unmount
    return () => clearInterval(interval);
  }, [
    conversationIds,
    conversationToHighlights,
    highlights,
    autoconnectHighlights,
  ]);

  return <svg ref={svgRef}></svg>;
};

export default SquishedBubbles;
