import {curveBasis} from '@visx/curve';
import {LinePath} from '@visx/shape';
import cx from 'classnames';
import {compact, find, flatMap, groupBy, keys, pull, some, sortBy, transform, values} from 'lodash';
import {computed, makeObservable} from 'mobx';
import {observer} from 'mobx-react';
import {Component, FunctionComponentElement, ReactNode, cloneElement} from 'react';

import IBAContext, {IBAContextValue} from '../../../IBAContext';
import {ProbeHighlights} from '../types';
import {ProbeProcessorType} from '../../../types';

/** Props for `ConnectorsCanvas`. All lengths are in SVG user units. */
type ConnectorsCanvasProps = {
  // --- What to draw ---
  /** Processors in display order; each processor's `inputs` determine which arrows are drawn. */
  processors: ProbeProcessorType[];
  /** Processors/stages whose connections get the `highlighted` class instead of `not-highlighted`. */
  highlights: ProbeHighlights;

  // --- Vertical row layout ---
  /** Height of the row occupied by each processor. */
  processorStepHeight: number;
  /** Height of the row occupied by each processor output (stage). */
  stageStepHeight: number;
  /** Extra gap added after each processor's group of stage rows. */
  processorSpacing: number;

  // --- Arrow geometry ---
  /** Width of the canvas; arrows start and end at this x coordinate. */
  width: number;
  /** Size of the arrowhead; also the vertical gap between distinct inputs of one processor. */
  arrowMarkerSize: number;
  /** Horizontal distance between consecutive arrow levels (lanes). */
  levelSize: number;
  /** Offset used to round the corners of each arrow's path. */
  lineRadius: number;
  /** Maximum number of arrow levels; arrows beyond it share the last level. */
  maxArrowLevels: number;
  /** Maximum number of distinct vertical slots for a processor's inputs. */
  maxDistinctInputs: number;
  /** Radius of the circle that shows how many arrows leave a stage. */
  outputCounterRadius: number;

  // --- Presentation ---
  /** Extra inline styles merged into the root `<svg>` element. */
  style?: React.CSSProperties;
};

/** One connector arrow between a stage and a processor input, in canvas coordinates. */
type ArrowPosition = {
  /** Name of the processor the arrow points at. */
  processorName: string;
  /** Y coordinate where the arrow meets the processor. */
  processorPosition: number;
  /** Name of the stage the arrow starts from. */
  stageName?: string;
  /** Y coordinate of the centre of the stage row. */
  stagePosition: number;
  /** Vertical distance between `processorPosition` and `stagePosition`. */
  length: number;
  /** Horizontal lane index; unset until lanes have been assigned. */
  level?: number;
};

/**
 * SVG overlay that draws connector arrows from stages to the processors that consume them.
 *
 * Rows are laid out top to bottom:
 * - each processor takes `processorStepHeight`;
 * - each of its outputs (stages, in the order of its definition's `outputs`) takes `stageStepHeight`;
 * - `processorSpacing` separates consecutive processors.
 *
 * Each arrow leaves the right edge (`x = width`), runs to a vertical lane ("level") and comes
 * back. Arrows with overlapping vertical spans go into different levels, up to `maxArrowLevels`.
 * Stages that feed two or more arrows get a small counter at the right edge.
 */
@observer
export class ConnectorsCanvas extends Component<ConnectorsCanvasProps> {
  static contextType = IBAContext;

  static defaultProps = {
    lineRadius: 5,
    arrowMarkerSize: 10,
    maxArrowLevels: 3,
    maxDistinctInputs: 4,
    levelSize: 5,
    outputCounterRadius: 6,
  };

  constructor(props: ConnectorsCanvasProps) {
    super(props);
    makeObservable(this);
  }

  /**
   * Top offsets of every processor row and stage row (keyed by name) plus the total canvas height.
   *
   * A processor whose type has no definition in the IBA context still gets its processor row.
   * It gets no stage rows, and rendering does not crash.
   */
  @computed get elementPositions() {
    const {processors, processorStepHeight, stageStepHeight, processorSpacing} = this.props;
    const context = this.context as IBAContextValue;
    const positions = {canvasHeight: 0, processors: {}, stages: {}};
    let currentPosition = 0;
    for (const processor of processors) {
      positions.processors[processor.name] = currentPosition;
      currentPosition += processorStepHeight;
      const processorDefinition = find(context.processorDefinitions, {name: processor.type});
      // lodash `keys` returns [] for undefined, so an unknown processor type yields no stage rows.
      for (const outputName of keys(processorDefinition?.outputs)) {
        const stageName = processor.outputs[outputName];
        positions.stages[stageName] = currentPosition;
        currentPosition += stageStepHeight;
      }
      currentPosition += processorSpacing;
    }
    positions.canvasHeight = currentPosition;
    return positions;
  }

  /**
   * One entry per (processor input -> stage) connection whose endpoints are both laid out,
   * each with its `level` (horizontal lane) assigned.
   */
  @computed get arrowPositions(): ArrowPosition[] {
    const {elementPositions, props: {
      processors, processorStepHeight, stageStepHeight, arrowMarkerSize, maxArrowLevels, maxDistinctInputs
    }} = this;

    // 1. Build arrows. A processor's inputs are spread vertically around its row centre,
    //    arrowMarkerSize apart, with at most maxDistinctInputs distinct slots.
    const arrowPositions = flatMap(processors, ({inputs, name: processorName}) => {
      return transform(compact(values(inputs)), (result, {stage: stageName}, stageIndex, stages) => {
        if (processorName in elementPositions.processors && stageName! in elementPositions.stages) {
          const inputCount = Math.min(maxDistinctInputs, stages.length);
          const processorPosition =
            elementPositions.processors[processorName] + processorStepHeight / 2 -
            (inputCount - 1) * arrowMarkerSize / 2 + Math.min(stageIndex, maxDistinctInputs - 1) * arrowMarkerSize;
          const stagePosition = elementPositions.stages[stageName!] + stageStepHeight / 2;
          const length = Math.abs(processorPosition - stagePosition);
          result.push({processorName, processorPosition, stageName, stagePosition, length});
        }
      }, [] as ArrowPosition[]);
    });

    // 2. Arrows with identical endpoints cover the same span, so they share one group and level.
    type ArrowPositionGroup = {
      arrowPositions: ArrowPosition[];
      length: number;
      processorPosition: number;
      stagePosition: number;
    };
    const arrowPositionGroups = transform(arrowPositions, (result, position) => {
      const existingGroup = find(result, (group) =>
        group.processorPosition === position.processorPosition &&
        group.stagePosition === position.stagePosition
      );
      if (existingGroup) {
        // The endpoints match exactly, so the group's span and length are already correct.
        existingGroup.arrowPositions.push(position);
      } else {
        result.push({
          arrowPositions: [position],
          processorPosition: position.processorPosition,
          stagePosition: position.stagePosition,
          length: position.length,
        });
      }
    }, [] as ArrowPositionGroup[]);

    // 3. Greedily pack groups into levels. Within a level, vertical spans must not overlap.
    //    Groups that are still unplaced after the last allowed level are appended to it.
    function positionsIntersect(position1: ArrowPositionGroup, position2: ArrowPositionGroup) {
      const [min1, max1] = sortBy([position1.processorPosition, position1.stagePosition]);
      const [min2, max2] = sortBy([position2.processorPosition, position2.stagePosition]);
      return max1 > min2 && max2 > min1;
    }
    // At least one level is needed. Otherwise, leftover groups would be pushed into a
    // non-existent level and crash.
    const levelLimit = Math.max(1, maxArrowLevels);
    const arrowPositionGroupsByLevel: ArrowPositionGroup[][] = [];
    while (arrowPositionGroups.length && arrowPositionGroupsByLevel.length < levelLimit) {
      const arrowPositionGroupsForCurrentLevel = transform(arrowPositionGroups, (result, positionGroup) => {
        if (!result.some((anotherPositionGroup) =>
          positionGroup === anotherPositionGroup || positionsIntersect(positionGroup, anotherPositionGroup)
        )) {
          result.push(positionGroup);
        }
      }, [] as ArrowPositionGroup[]);
      pull(arrowPositionGroups, ...arrowPositionGroupsForCurrentLevel);
      arrowPositionGroupsByLevel.push(arrowPositionGroupsForCurrentLevel);
    }
    if (arrowPositionGroups.length) {
      arrowPositionGroupsByLevel[arrowPositionGroupsByLevel.length - 1].push(...arrowPositionGroups);
    }

    // 4. Flatten back to individual arrows, stamping each with its level index.
    return transform(arrowPositionGroupsByLevel, (result, groupsOnLevel, level) => {
      for (const group of groupsOnLevel) {
        for (const arrowPosition of group.arrowPositions) {
          arrowPosition.level = level;
        }
        result.push(...group.arrowPositions);
      }
    }, [] as ArrowPosition[]);
  }

  /** Arrows grouped by the y coordinate of the stage they start from (keys are stringified). */
  @computed get arrowCounters() {
    return groupBy(this.arrowPositions, 'stagePosition');
  }

  /**
   * Whether the connection between a processor and a stage should be emphasised.
   * True when either endpoint appears in `highlights` with `highlightConnections` set.
   */
  private isConnectionHighlighted(
    {processorName, stageName}: Pick<ArrowPosition, 'processorName' | 'stageName'>
  ): boolean {
    const {highlights} = this.props;
    return Boolean(
      (processorName in highlights.processors && highlights.processors[processorName].highlightConnections) ||
      (stageName !== undefined && stageName in highlights.stages && highlights.stages[stageName].highlightConnections)
    );
  }

  render() {
    const {width, lineRadius, arrowMarkerSize, levelSize, style, outputCounterRadius} = this.props;
    const startX = width;
    // Processor/stage pairs already drawn. Further inputs between the same pair are skipped.
    const drawnArrows = new Set<string>();
    return (
      <svg
        className='connectors-canvas'
        style={{...style, width, height: this.elementPositions.canvasHeight}}
      >
        {/* All regular arrows come first, then their highlightable copies, which paint on top. */}
        {transform(
          this.arrowPositions,
          (
            [regularArrows, highlightableArrows],
            {processorName, processorPosition, stageName, stagePosition, level}
          ) => {
            const key = `${processorName} ${stageName}`;
            if (drawnArrows.has(key)) return;
            drawnArrows.add(key);
            const endX = 1 + (level ?? 0) * levelSize;
            const startY = Math.min(stagePosition, processorPosition);
            const endY = Math.max(stagePosition, processorPosition);
            const arrowBody = (
              <LinePath
                data={[
                  {x: startX, y: startY},
                  {x: endX + lineRadius, y: startY},
                  {x: endX, y: startY},
                  {x: endX, y: startY + lineRadius},
                  {x: endX, y: endY - lineRadius},
                  {x: endX, y: endY},
                  {x: endX + lineRadius, y: endY},
                  {x: startX, y: endY},
                ]}
                x={({x}) => x}
                y={({y}) => y}
                curve={curveBasis}
              />
            );
            const arrowHead = (
              <path
                d={[
                  `M${startX - arrowMarkerSize},${processorPosition - 1 + arrowMarkerSize / 2}`,
                  `L${startX - 1},${processorPosition}`,
                  `L${startX - arrowMarkerSize},${processorPosition + 1 - arrowMarkerSize / 2}`,
                ].join(' ')}
              />
            );
            regularArrows.push(
              cloneElement(arrowBody, {key: `${key} body regular`}),
              cloneElement(arrowHead, {key: `${key} head regular`})
            );
            const highlightedClassName =
              this.isConnectionHighlighted({processorName, stageName}) ? 'highlighted' : 'not-highlighted';
            highlightableArrows.push(
              cloneElement(arrowBody, {key: `${key} body highlightable`, className: highlightedClassName}),
              cloneElement(arrowHead, {key: `${key} head highlightable`, className: highlightedClassName})
            );
          },
          [[] as FunctionComponentElement<any>[], [] as FunctionComponentElement<any>[]]
        )}
        {transform(this.arrowCounters, (acc, arrows, pos) => {
          // Only stages that feed several arrows get a counter.
          if (arrows.length < 2) return;
          const highlighted = some(arrows, (arrow) => this.isConnectionHighlighted(arrow));
          const highlightedClassName = highlighted ? 'highlighted' : 'not-highlighted';
          acc.push(
            <g
              key={pos}
              transform={`translate(${startX - outputCounterRadius},${pos})`}
            >
              <ellipse
                className={cx('output-counter', highlightedClassName)}
                cx={0}
                cy={0}
                rx={outputCounterRadius}
                ry={outputCounterRadius}
              />
              <text className={cx('output-counter', highlightedClassName)} x={0} y={0}>{arrows.length}</text>
            </g>
          );
        }, [] as ReactNode[])}
      </svg>
    );
  }
}
