import {useResizeDetector} from 'react-resize-detector';
import {useMemo} from 'react';
import {hierarchy, Partition} from '@visx/hierarchy';
import cx from 'classnames';
import {Group} from '@visx/group';
import {Arc} from '@visx/shape';
import {brandColorNames, formatNumber} from 'apstra-ui-common';
import {forEach, isUndefined, keyBy, map, size, transform, uniqueId} from 'lodash';
import {Text} from '@visx/text';
import {scaleOrdinal} from 'd3';

import {Filter3d} from './DonutChart';
import {TooltipPopup, TooltipProvider, useTooltip} from './GraphTooltips';
import ChartLegend, {useHoveredLegendItem} from './ChartLegend';

import './SunburstChart.less';

// Extra outer radius added to a hovered arc that sits on the deepest level of the chart.
const HOVER_OFFSET = 8;
// NOTE(review): not referenced in this file; presumably used by importers for layout — confirm.
export const HORIZONTAL_MARGINS = 70;
// Minimum angular width (in radians, the partition spans 2π) a first-level sector needs
// in order to show its label while nothing is hovered.
export const MIN_SECTOR_SIZE = 0.4;

/**
 * Walks up from `node` and returns its first-level ancestor, i.e. the closest
 * ancestor whose `depth` is 1. If the walk reaches a node without a parent
 * (the root) before finding one, that top-most node is returned instead.
 *
 * @param {Object} node - hierarchy node exposing `parent` and `depth`
 * @returns {Object|null|undefined} the ancestor found, or `node.parent` as-is
 *   (null/undefined) when `node` has no parent at all
 */
const findParentNode = (node) => {
  let ancestor = node.parent;
  // `?.` guards the parentless (root) case, which callers already treat as nullable.
  while (ancestor?.parent && ancestor.depth !== 1) {
    ancestor = ancestor.parent;
  }
  return ancestor;
};

/**
 * Flattens a tree into an array in pre-order: the node itself first, then each
 * child's subtree in `children` order.
 *
 * A single accumulator is filled during the walk, so no intermediate arrays are
 * built per level and no subtree is spread into call arguments (which would
 * fail with a RangeError on very wide trees).
 *
 * @param {Object} node - tree node; descendants are read from `node.children`
 * @returns {Object[]} `node` followed by all of its descendants
 */
const collectNodes = (node) => {
  const collected = [];
  const visit = (current) => {
    collected.push(current);
    // Leaves may have no `children` at all.
    for (const child of current.children || []) {
      visit(child);
    }
  };
  visit(node);
  return collected;
};

const SunburstChart = ({
  data, width: propsWidth, className, thickness, opacityStep, withLegend,
}) => {
  const {width: parentWidth, ref} = useResizeDetector({handleWidth: true});
  const fluidWidth = isUndefined(propsWidth);
  const size = fluidWidth ? parentWidth : propsWidth;

  const root = useMemo(() => {
    return hierarchy(data)
      .sort((a, b) => (b.size || 0) - (a.size || 0))
      .sum((d) => d.size);
  }, [data]);

  const colors = useMemo(() => {
    let i = 0;
    return transform(root.descendants().slice(1), (result, {depth, data, ...node}) => {
      if (depth === 1) {
        result[data.id] = data.color || brandColorNames[i++ % brandColorNames.length];
      } else {
        const parent = findParentNode(node);
        result[data.id] = result[parent?.data.id] || brandColorNames[i++ % brandColorNames.length];
      }
    });
  }, [root]);

  const {onMouseOver, onMouseOut, hoveredItem} = useHoveredLegendItem();

  // All nodes flattened
  const allNodes = useMemo(
    () => collectNodes(root),
    [root]
  );

  // Map node.id -> node
  const nodesMap = useMemo(
    () => keyBy(allNodes, 'data.id'),
    [allNodes]
  );

  const [ordinalColorScale, legendDescriptionMap, maxDepth] = useMemo(() => {
    const domain = [];
    const range = [];
    const legendDescriptionMap = {};
    let maxDepth = 0;

    forEach(allNodes.slice(1), ({depth, data, value}) => {
      domain.push(data.id);
      if (depth < 4) {
        legendDescriptionMap[data.id] = {
          name: data.name,
          value: formatNumber(value, {units: data.units, short: true}),
          glyphClassName: `sunburst-legend-${depth}-glyph`,
          itemClassName: `sunburst-legend-${depth}-item`,
        };
      }
      if (depth > maxDepth) maxDepth = depth;
      range.push(colors[data.id]);
    });
    return [
      scaleOrdinal()
        .domain(domain)
        .range(range),
      legendDescriptionMap,
      maxDepth
    ];
  }, [allNodes, colors]);

  // For hovered item highlight its ascendants and descendants
  const hoveredIds = useMemo(
    () => {
      if (!hoveredItem) return [];
      const hoveredNode = nodesMap[hoveredItem];
      return hoveredNode ?
        map([...hoveredNode.ancestors(), ...hoveredNode.descendants()], 'data.id') :
        [];
    },
    [hoveredItem, nodesMap]
  );

  const filterId = useMemo(() => uniqueId('sunburst-chart-shadow-filter'), []);
  const radius = size / 2;
  const backgroundArcThickness = 2.5 * thickness;
  const outerRadius = radius - (backgroundArcThickness - thickness) / 2;

  return (
    <TooltipProvider>
      <div className='sunburst-chart-wrapper'>
        <div
          ref={ref}
          className={cx('sunburst-chart', className, {fluid: fluidWidth})}
          style={fluidWidth ? undefined : {width: propsWidth, height: propsWidth}}
        >
          <svg width={fluidWidth ? '100%' : size} height={size}>
            <defs>
              <Filter3d filterId={filterId} />
            </defs>
            <Group top={radius} left={radius}>
              <Arc
                className='sunburst-chart-background-arc'
                startAngle={0} endAngle={360}
                innerRadius={0}
                outerRadius={radius}
                filter={`url(#${filterId})`}
              />
              <Partition
                className='partition'
                top={0}
                left={0}
                root={root}
                size={[2 * Math.PI, outerRadius]}
              >
                {(data) => {
                  const props = {
                    data: data,
                    colors: colors,
                    maxDepth,
                    hoveredIds,
                    opacityStep: opacityStep,
                    hoveredItem: hoveredItem,
                    onMouseOver: onMouseOver,
                    onMouseOut: onMouseOut
                  };
                  return [
                    <ArcGroup key='arcs' {...props} />,
                    <ArcGroup key='labels' {...props} onlyLabels />
                  ];
                }}
              </Partition>
            </Group>
          </svg>
          <TooltipPopup hideCloseButton />
        </div>
        {withLegend && (
          <ChartLegend
            ordinalColorScale={ordinalColorScale}
            legendDescriptionMap={legendDescriptionMap}
            onMouseOut={onMouseOut}
            onMouseOver={onMouseOver}
            hoveredItem={hoveredItem}
            ignoreEmpty
          />
        )}
      </div>
    </TooltipProvider>
  );
};

// Defaults for the optional appearance props of SunburstChart:
// - thickness: feeds the background arc thickness, which insets the partition radius
// - opacityStep: how much fill opacity each deeper level loses
// NOTE(review): `defaultProps` on function components is deprecated in recent React
// versions; consider default parameter values when the project upgrades — confirm React version.
SunburstChart.defaultProps = {
  thickness: 10,
  opacityStep: 0.3,
};

const ArcGroup = ({data, colors, opacityStep, hoveredIds, onMouseOver, onMouseOut, onlyLabels, maxDepth}) => {
  const nodes = useMemo(() => {
    return data.descendants().slice(1);
  }, [data]);
  return (
    <Group>
      {map(nodes, (node, index) => (
        <ArcPath
          key={`node-${index}`}
          opacityStep={opacityStep}
          hoveredIds={hoveredIds}
          colors={colors}
          node={node}
          onMouseOver={onMouseOver}
          onMouseOut={onMouseOut}
          onlyLabels={onlyLabels}
          maxDepth={maxDepth}
        />
      ))}
    </Group>
  );
};

const ArcPath = ({opacityStep, hoveredIds, colors, node, onMouseOut, onMouseOver, onlyLabels, maxDepth}) => {
  const {x0, x1, y0, y1, depth, data, value} = node;

  const {sharedTooltip} = useTooltip();

  const isHovered = hoveredIds.includes(node?.data.id);
  const noHover = !size(hoveredIds);

  const handleMouseOver = () => {
    sharedTooltip.show(`${data.name}: ${value}`, true);
    onMouseOver(data.id);
  };

  const handleMouseOut = () => {
    onMouseOut();
    sharedTooltip.hide();
  };

  const arcOffset = (isHovered && depth === maxDepth) ? HOVER_OFFSET : 0;
  const labelAnchor = ((x0 + x1) / 2) < Math.PI ? 'start' : 'end';
  const sectorIsBigEnough = Math.abs(x1 - x0) > MIN_SECTOR_SIZE;
  const showLabel = depth === 1 && ((noHover && sectorIsBigEnough) || isHovered);

  return (
    <Arc
      innerRadius={depth === 1 ? 0 : y0}
      outerRadius={y1 + (onlyLabels ? 0 : arcOffset)}
      startAngle={x0}
      endAngle={x1}
    >
      {({path}) => (
        <g>
          {onlyLabels ?
            (showLabel &&
              <Text
                className='label'
                x={path.centroid()[0]}
                y={path.centroid()[1]}
                textAnchor={labelAnchor}
              >
                {data.name}
              </Text>
            ) :
            <path
              className={cx('visx-arc', colors[data.id], {hovered: isHovered})}
              fillOpacity={1 - opacityStep * (depth - 1)}
              d={path()}
              onMouseOut={handleMouseOut}
              onMouseOver={handleMouseOver}
            />
          }
        </g>
      )}
    </Arc>
  );
};

export default SunburstChart;
