import {Group} from '@visx/group';
import cx from 'classnames';
import {easeCubic} from 'd3';
import {assign, cloneDeep, drop, filter, find, first, forEach, get, isFinite, map, max, min,
  sortBy, transform, uniq, uniqueId, zip} from 'lodash';
import {observable, action, computed, reaction, comparer, makeObservable} from 'mobx';
import {observer} from 'mobx-react';
import PropTypes from 'prop-types';
import {Component, Fragment, useMemo} from 'react';
import {Spring, animated, to} from '@react-spring/web';
import {Message, Segment} from 'semantic-ui-react';
import {natsort} from 'apstra-ui-common';

import {sankeyCircular} from '../../../vendor/d3-sankey-circular';

import './TrafficDiagram.less';
import {speedToInt} from './trafficUtils';

const Link = animated(({rx = 5, className, y0, y1, xSource, xTarget, stroke, strokeWidth, linkProps}) => {
  const d = useMemo(() => {
    const w = 0.5 * strokeWidth;
    const r = Math.min(rx, strokeWidth);
    const xc = 0.5 * (xSource + xTarget);
    return `M${xSource + r},${y0 - w}C${xc},${y0 - w},${xc},${y1 - w},${xTarget - r},${y1 - w}` +
      `a${r},${r},0,0,1,${r},${r}v${strokeWidth - 2 * r}a${r},${r},0,0,1,${-r},${r}` +
      `C${xc},${y1 + w},${xc},${y0 + w},${xSource + r},${y0 + w}` +
      `a${r},${r},0,0,1,${-r},${-r}v${2 * r - strokeWidth}a${r},${r},0,0,1,${r},${-r}`;
  }, [strokeWidth, rx, xSource, y0, xTarget, y1]);
  return <path
    className={cx('iba-utilization-links', className)}
    d={d}
    fill={stroke}
    {...linkProps}
  />;
});

@observer
export default class TrafficDiagram extends Component {
  static propTypes = {
    nodeLinkTraffic: PropTypes.object.isRequired,
    paths: PropTypes.array.isRequired
  };

  static defaultProps = {
    minZoom: 0.5,
    maxZoom: 8,
    zoomDeltaMultiplier: 0.01,
    baseFontSize: 14,
    baseLinkHeight: 20,
    baseLinkWidth: 200,
    inactiveColor: '#c0c5cb',
    width: 840,
    height: 500,
    nodeWidth: 30,
    minNodeHeight: 3,
    nodePadding: 10,
    nodeLabelOffsetX: 1,
    nodeLabelRotateAngle: -90,
    nodeClipPathMargin: 3,
    interfaceLabelOffset: 12,
    margin: {top: 20, right: 16, bottom: 6, left: 18},
    springConfig: {duration: 1000, easing: easeCubic},
    maxLinkWidth: 100,
    nodeRectRadius: 5,
    linkHorizontalMargin: 1,
    aggregateRectRadius: 8,
    aggregatePadding: 3,
    aggregateWidth: 16,
    aggregateXOffset: 15,
    paddingRecalculateThreshold: 0.1,
    sankeyIterations: 0,
    showLabelsMaxLinkCount: 16,
  };

  @observable verticalPadding = 0;

  @action
  calculateVerticalPadding = () => {
    const {
      props: {margin, maxLinkWidth, height, paddingRecalculateThreshold},
      layout, sankeyGraph, verticalPadding
    } = this;
    const maxWidth = max(map(
      filter(layout(sankeyGraph).links, ({active, width}) => isFinite(width) && active),
      'width'));
    if (!isFinite(maxWidth)) return;
    const diagramHeight = height - margin.top - margin.bottom - 2 * verticalPadding;
    const scaledHeight = diagramHeight / maxWidth * maxLinkWidth;
    if (Math.abs(scaledHeight / diagramHeight - 1) < paddingRecalculateThreshold) return;
    this.verticalPadding = Math.max(0, 0.5 * (height - scaledHeight - margin.top - margin.bottom));
  };

  positionNodeLabel = (node) => {
    const {nodeLabelOffsetX} = this.props;
    const x = (node.x1 - node.x0) / 2 + nodeLabelOffsetX;
    return {x, y: (node.y1 - node.y0) / 2};
  };

  clipPathId = uniqueId();

  getScaledFontSize = (sankey) => {
    const {baseFontSize, baseLinkHeight, baseLinkWidth} = this.props;
    const minLinkWidth = min(map(filter(sankey.links, 'active'), 'width'));
    const sortedNodeX = sortBy(uniq(map(sankey.nodes, 'x0')));
    const nodeXPairs = zip(sortedNodeX, drop(sortedNodeX));
    const minNodeDistance = min(map(nodeXPairs, ([x0, x1]) => x1 - x0));
    return baseFontSize * min([1, minLinkWidth / baseLinkHeight, minNodeDistance / baseLinkWidth]);
  };

  constructor(props) {
    super(props);
    makeObservable(this);
    this.paddingCalculatorDisposer = reaction(
      () => [this.props.height, this.props.paths, this.props.nodeLinkTraffic],
      this.calculateVerticalPadding,
      {equals: comparer.structural, fireImmediately: true}
    );
  }

  componentWillUnmount() {
    this.paddingCalculatorDisposer();
  }

  updateSankeyGraphWithValues = (sankeyGraph) => {
    const {props: {nodeLinkTraffic}} = this;
    forEach(sankeyGraph.links, (link) => {
      assign(link, find(nodeLinkTraffic[link.sourceId].outputs, {
        sourceIntf: link.sourceIntf,
        targetIntf: link.targetIntf,
        targetId: link.targetId,
      }));
    });
    forEach(sankeyGraph.nodes, (node) => {
      assign(node, nodeLinkTraffic[node.id]);
    });
  };

  @computed.struct
  get sankeyGraph() {
    const sankeyGraph = pathsToSankeyGraph(this.props.paths);
    this.updateSankeyGraphWithValues(sankeyGraph);
    return sankeyGraph;
  }

  @computed
  get layout() {
    const {props: {margin, nodePadding, nodeWidth, width, height, paths, sankeyIterations}, verticalPadding} = this;
    return sankeyCircular()
      .nodeWidth(nodeWidth)
      .nodePadding(Math.min(nodePadding, 0.3 * height / paths.length))
      .size([width, height])
      .nodeId(({id}) => id)
      .extent([[margin.left, margin.top + verticalPadding],
        [width - margin.right, height - margin.bottom - verticalPadding]])
      .iterations(sankeyIterations)
      .linkPadding(5)
      .minLinkWidth(3);
  }

  @computed
  get fontSize() {
    return this.getScaledFontSize(this.layout(this.sankeyGraph));
  }

  renderLink = (link, index) => {
    const {
      props: {inactiveColor, springConfig, linkHorizontalMargin},
    } = this;
    const linkWithMargin = cloneDeep(link);
    linkWithMargin.source.x1 += linkHorizontalMargin;
    linkWithMargin.target.x0 -= linkHorizontalMargin;
    const {width, y0, y1, source: {x1: xSource}, target: {x0: xTarget}} = linkWithMargin;
    const linkWidth = Math.max(1, width);
    return (
      <Spring
        key={`link-${index}`}
        to={{y0, y1, strokeWidth: linkWidth, xSource, xTarget}}
        config={springConfig}
      >
        {({y0, y1, strokeWidth, xSource, xTarget}) => {
          const {
            className: linkClassName, stroke = inactiveColor, ...linkProps
          } = linkWithMargin.props ?? {};
          const linkClipPathId = `utilization-link-clip-${uniqueId()}`;
          return (
            <Fragment>
              <clipPath id={linkClipPathId}>
                <animated.rect
                  width={to([xTarget, xSource], (xTarget, xSource) => xTarget - xSource)}
                  height={to([y0, y1, strokeWidth], (y0, y1, strokeWidth) => Math.abs(y1 - y0) + strokeWidth)}
                  transform={to([y0, y1, strokeWidth, xSource], (y0, y1, strokeWidth, xSource) =>
                    `translate(${xSource}, ${Math.min(y0, y1) - 0.5 * strokeWidth})`
                  )}
                />
              </clipPath>
              <g clipPath={`url(#${linkClipPathId})`}>
                <Link
                  className={linkClassName}
                  y0={y0}
                  y1={y1}
                  stroke={stroke}
                  strokeWidth={strokeWidth.to((strokeWidth) => Math.max(1, strokeWidth - 1))}
                  linkProps={linkProps}
                  xSource={xSource}
                  xTarget={xTarget}
                />
              </g>
            </Fragment>
          );
        }}
      </Spring>
    );
  };

  renderNode = (node, index) => {
    const {
      props: {
        inactiveColor, nodeWidth, nodeClipPathMargin, springConfig,
        nodeLabelRotateAngle, nodeRectRadius, minNodeHeight
      },
      fontSize, positionNodeLabel
    } = this;
    const {x, y} = positionNodeLabel(node);
    const {x0, y0, y1, source, target, name} = node;
    const generalX = x + x0;
    const generalY = y + y0;
    const nodeHeight = max([y1 - y0, minNodeHeight]);
    const clipPathId = `utilization-node-text-${index}-${this.clipPathId}`;
    const {
      className: nodeClassName,
      fill = inactiveColor,
      ...nodeProps
    } = node.props ?? {};
    return (
      <Spring
        key={`node-${index}`}
        to={{xNode: x0, yNode: y0, height: nodeHeight, yNodeLabel: generalY, fontSize}}
        config={springConfig}
      >
        {({xNode, yNode, height, yNodeLabel, fontSize}) =>
          <Group
            className='iba-utilization-nodes'
          >
            <animated.rect
              className={cx('iba-utilization-nodes-rect',
                nodeClassName
              )}
              width={nodeWidth}
              height={height}
              transform={to([xNode, yNode], (xNode, yNode) =>
                `translate(${xNode}, ${yNode})`
              )}
              fill={fill}
              rx={nodeRectRadius}
              ry={nodeRectRadius}
              {...nodeProps}
            />
            <clipPath id={clipPathId}>
              <animated.rect
                width={nodeWidth}
                height={height.to((height) => Math.max(height - nodeClipPathMargin * 2, 0))}
                transform={to([xNode, yNode], (xNode, yNode) =>
                  `translate(${xNode}, ${yNode + nodeClipPathMargin})`
                )}
              />
            </clipPath>
            <g clipPath={`url(#${clipPathId})`}>
              <animated.text
                className={
                  cx('iba-utilization-nodes-text', {'node-source-or-target': source ?? target})
                }
                dominantBaseline='middle'
                transform={yNodeLabel.to((yNodeLabel) =>
                  `translate(${generalX}, ${yNodeLabel}) rotate(${nodeLabelRotateAngle})`)
                }
                fontSize={fontSize}
              >
                {name}
              </animated.text>
            </g>
          </Group>
        }
      </Spring>
    );
  };

  renderLinkLabels = ({y0, y1, sourceIntf, targetIntf, source, target}, index) => {
    const {
      props: {interfaceLabelOffset, springConfig},
      fontSize
    } = this;
    return (
      <Spring
        key={`nodeLabel-${index}`}
        to={{xSource: source.x1, ySource: y0, xTarget: target.x0, yTarget: y1, fontSize}}
        config={springConfig}
      >
        {({xSource, ySource, xTarget, yTarget, fontSize}) =>
          <Group>
            {sourceIntf !== 'UNKNOWN' &&
              <animated.text
                className='iba-utilization-labels-source'
                x={xSource.to((xSource) => xSource + interfaceLabelOffset)}
                y={ySource}
                fontSize={fontSize}
                dominantBaseline='central'
              >
                {sourceIntf}
              </animated.text>
            }
            {targetIntf !== 'UNKNOWN' &&
              <animated.text
                className='iba-utilization-labels-target'
                x={xTarget.to((xTarget) => xTarget - interfaceLabelOffset)}
                y={yTarget}
                fontSize={fontSize}
                dominantBaseline='central'
              >
                {targetIntf}
              </animated.text>
            }
          </Group>
        }
      </Spring>
    );
  };

  renderLinkAggregates = (aggregates, aggregateId) => {
    const {aggregateRectRadius, aggregatePadding, aggregateWidth, aggregateXOffset, springConfig} = this.props;
    const renderSource = new Set(map(aggregates, 'sourceId')).size === 1;
    const [node, x, y, direction] = renderSource ? ['source', 'x1', 'y0', 1] : ['target', 'x0', 'y1', -1];
    const xPos = get(first(aggregates), [node, x]);
    const yMin = min(map(aggregates, (aggregate) => aggregate[y] - aggregate.width / 2));
    const yMax = max(map(aggregates, (aggregate) => aggregate[y] + aggregate.width / 2));
    return (
      <Spring
        key={aggregateId}
        to={{
          x: xPos - 0.5 * aggregateWidth + direction * aggregateXOffset,
          y: yMin - aggregatePadding,
          height: yMax - yMin + 2 * aggregatePadding,
        }}
        config={springConfig}
      >
        {({x, y, height}) =>
          <animated.rect
            className='iba-utilization-aggregate'
            width={aggregateWidth}
            height={height}
            x={x}
            y={y}
            rx={aggregateRectRadius}
          />
        }
      </Spring>
    );
  };

  render() {
    const {
      props: {showLabelsMaxLinkCount, width, height, сlassName},
      sankeyGraph, transform, layout, renderLink, renderNode, renderLinkLabels, renderLinkAggregates,
    } = this;

    layout(sankeyGraph).nodes.forEach((node) => {
      if (!node.sourceLinks.length) node.source = true;
      if (!node.targetLinks.length) node.target = true;
    });

    const renderInterfaceLabels = sankeyGraph.links.length < showLabelsMaxLinkCount;

    return (
      <div>
        <Segment size='mini'>
          {!renderInterfaceLabels && <Message
            info
            icon='info circle'
            content={'Interface labels are hidden due to the size of topology. ' +
              'Please hover over the link to see the interface name.'}
          />}
          <svg
            className={cx('iba-utilization', сlassName)}
            viewBox={`0 0 ${width} ${height}`}
          >
            <g transform={transform}>
              {map(sankeyGraph.nodes, renderNode)}
              {map(sankeyGraph.links, renderLink)}
              {renderInterfaceLabels && map(filter(sankeyGraph.links, 'active'), renderLinkLabels)}
              {map(sankeyGraph.linkAggregates, renderLinkAggregates)}
            </g>
          </svg>
        </Segment>
      </div>
    );
  }
}

/**
 * Builds a string id for a link by joining its two endpoints and interfaces with '-'.
 *
 * @param {object} link - Link record. Without `path`, reads `src_system`, `src_if_name`,
 *   `dst_system` and `dst_if_name`; with `path`, reads `src_intf_name` and `dst_intf_name`.
 * @param {object} [path] - Optional path fragment; when truthy, its `src_node` and
 *   `dst_node` supply the endpoint names.
 * @returns {string} Id of the form `<src>-<srcIntf>-<dst>-<dstIntf>`.
 */
export function generateLinkId(link, path) {
  if (path) {
    const {src_node: srcNode, dst_node: dstNode} = path;
    const {src_intf_name: srcIntf, dst_intf_name: dstIntf} = link;
    return [srcNode, srcIntf, dstNode, dstIntf].join('-');
  }
  const {
    src_system: srcSystem, src_if_name: srcIf,
    dst_system: dstSystem, dst_if_name: dstIf,
  } = link;
  return [srcSystem, srcIf, dstSystem, dstIf].join('-');
}

/**
 * Converts path fragments into a graph suitable for the sankey layout.
 *
 * Each fragment contributes its two endpoint nodes, keyed by `srcId`/`dstId`; a later
 * fragment replaces the stored node object for an id already seen. It also contributes one
 * link per entry of `fragment.links`, de-duplicated by `generateLinkId(link, fragment)`,
 * where the first occurrence wins.
 *
 * @param {Array} paths - Collection of paths, each a collection of path fragments.
 * @returns {{nodes: object[], links: object[], linkAggregates: Object<string, object[]>}}
 *   `nodes` sorted naturally by name. `links` sorted with non-aggregated links first, then by
 *   aggregate id, source/destination node name, source/target interface and in-fragment
 *   index. `linkAggregates` groups links by their truthy `aggregateId`.
 */
export function pathsToSankeyGraph(paths) {
  const nodeById = new Map();
  const linkById = new Map();

  forEach(paths, (path) => {
    forEach(path, (fragment) => {
      const {
        src_node: srcNodeName, dst_node: dstNodeName,
        srcId, dstId, srcRole, dstRole, links,
      } = fragment;
      nodeById.set(srcId, {id: srcId, name: srcNodeName, role: srcRole});
      nodeById.set(dstId, {id: dstId, name: dstNodeName, role: dstRole});

      forEach(links, (link, index) => {
        const {speed, src_intf_name: sourceIntf, dst_intf_name: targetIntf, aggregateId} = link;
        const linkId = generateLinkId(link, fragment);
        if (linkById.has(linkId)) return;
        linkById.set(linkId, {
          index,
          speed: speedToInt(speed),
          source: srcId,
          target: dstId,
          sourceIntf,
          targetIntf,
          sourceId: srcId,
          targetId: dstId,
          sourceRole: srcRole,
          targetRole: dstRole,
          srcNodeName,
          dstNodeName,
          aggregateId,
        });
      });
    });
  });

  const byName = (a, b) => natsort(a.name, b.name);
  const byLinkOrder = (a, b) => (
    compareAggregated(a.aggregateId, b.aggregateId) ||
    natsort(a.srcNodeName, b.srcNodeName) ||
    natsort(a.dstNodeName, b.dstNodeName) ||
    natsort(a.sourceIntf, b.sourceIntf) ||
    natsort(a.targetIntf, b.targetIntf) ||
    natsort(a.index, b.index)
  );
  const nodes = Array.from(nodeById.values()).sort(byName);
  const links = Array.from(linkById.values()).sort(byLinkOrder);

  const linkAggregates = {};
  forEach(links, (link) => {
    const {aggregateId} = link;
    if (!aggregateId) return;
    if (!linkAggregates[aggregateId]) linkAggregates[aggregateId] = [];
    linkAggregates[aggregateId].push(link);
  });

  return {nodes, links, linkAggregates};
}

/**
 * Comparator on aggregate membership.
 *
 * Links without a truthy aggregate id sort before aggregated ones. Two aggregated links are
 * ordered by `natsort` on their ids. Two non-aggregated links compare equal.
 *
 * @param {*} aggregateId1
 * @param {*} aggregateId2
 * @returns {number} Negative, zero or positive, as expected by `Array#sort`.
 */
function compareAggregated(aggregateId1, aggregateId2) {
  const isAggregated1 = Boolean(aggregateId1);
  const isAggregated2 = Boolean(aggregateId2);
  if (isAggregated1 !== isAggregated2) {
    return isAggregated1 ? 1 : -1;
  }
  return isAggregated1 ? natsort(aggregateId1, aggregateId2) : 0;
}
