import {debounce, flatten, get, isUndefined, map, mean, merge, reduce, sumBy} from 'lodash';
import {observer} from 'mobx-react';
import {FC, useCallback, useEffect, useMemo, useState} from 'react';
import {scaleLinear} from '@visx/scale';
import {Circle} from '@visx/shape';
import cx from 'classnames';
import {useResizeDetector} from 'react-resize-detector';
import {formatChartAxisTime, Axes, AxisProps, AxesConsumerFn} from 'apstra-ui-common';

import ChartPopup from './ChartPopup';
import {extent} from './utils';

import './ScatterPlotChart.less';

/**
 * Ordinary least-squares fit of the line `y = m * x + b` through the points in `data`.
 *
 * @param data - collection passed to lodash `map`
 * @param xAccessor - lodash iteratee (e.g. a property path) selecting each point's x value
 * @param yAccessor - lodash iteratee selecting each point's y value
 * @returns `{m, b}`: slope and intercept. With fewer than two distinct x values (a single
 *   point, identical x values) the slope is 0 / 0, i.e. NaN, and so is the intercept.
 */
const getTrend = (data, xAccessor, yAccessor) => {
  const xValues = map<number>(data, xAccessor);
  const yValues = map<number>(data, yAccessor);
  const xMean = mean(xValues);
  const yMean = mean(yValues);

  // Center both series on their means.
  const xDeviations = map(xValues, (value) => value - xMean);
  const yDeviations = map(yValues, (value) => value - yMean);

  // slope = Σ(dx * dy) / Σ(dx²)
  const covarianceSum = reduce(xDeviations, (total, dx, i) => total + dx * yDeviations[i], 0);
  const varianceSum = sumBy(xDeviations, (dx) => dx * dx);
  const slope = covarianceSum / varianceSum;

  return {m: slope, b: yMean - slope * xMean};
};

/**
 * Props for `ScatterPlotChart`.
 *
 * Each point in `data` is read with lodash `get` using the `xAccessor`, `yAccessor` and
 * `valueAccessor` property paths.
 */
type ScatterPlotChartProps = {
  /**
   * Optional per-axis overrides, lodash-`merge`d over `{isLinear: true}`.
   * The axis `values` are always replaced by the computed domain.
   */
  axes?: Partial<Record<'x' | 'y', AxisProps>>;
  /** Points to plot. (Previously typed as the empty tuple `[]`, which only accepted empty arrays.) */
  data: unknown[];
  /** Fixed chart width in px; when omitted the container's width is measured. */
  width?: number;
  /** Chart height in px; 200 when omitted. */
  height?: number;
  /** Extra class name for the container `div`. */
  className?: string;
  /** Draw the least-squares trend line. */
  showTrend?: boolean;
  /** `[from, to]` fill colors interpolated over the value extent; applied only together with `valueAccessor`. */
  colorRange?: [string, string];
  /** Maximum point radius; when set, the radius scales linearly with the value from 1 to `maxSize` (otherwise 5). */
  maxSize?: number;
  /** Property path of a point's x coordinate. */
  xAccessor: string;
  /** Property path of a point's y coordinate. */
  yAccessor: string;
  /** Property path of the value driving point size and color. */
  valueAccessor: string;
  /** Maps the hovered point to popup content; defaults to returning the point unchanged. */
  processPopupContent?: (data: unknown) => unknown;
  /** Debounce delay in ms before the popup is shown; defaults to 200. */
  debounceWait?: number;
  /** Fill color for points when the value color scale is not used. */
  pointColor?: string;
  /** Stroke color of the trend line. */
  trendColor?: string;
};

/**
 * Default `processPopupContent`: shows the hovered point unchanged.
 * Declared at module level so its identity is stable; an inline default would change on
 * every render and recreate the memoized, debounced popup handler each time.
 */
const passThroughPopupContent = (data: unknown): unknown => data;

/**
 * Scatter plot of `data` with optional value-based point size/color and a least-squares
 * trend line.
 *
 * Drawing and popup state live in `useScatterPlotChart`. This component only wires them
 * into an `Axes` inside the measured container and renders the hover popup.
 */
const ScatterPlotChart: FC<ScatterPlotChartProps> = ({axes, data, width: widthProp, height: heightProp, className,
  showTrend, colorRange, maxSize, xAccessor, yAccessor, valueAccessor,
  // Parameter defaults matching `defaultProps`: React deprecates (18.3) and removes (19)
  // `defaultProps` on function components, so the defaults must not rely on it alone.
  processPopupContent = passThroughPopupContent, debounceWait = 200, pointColor, trendColor}) => {
  const {popupDescription, containerRef, width, height, axesProps, render} = useScatterPlotChart({axes, data,
    widthProp, heightProp, showTrend, colorRange, maxSize, xAccessor, yAccessor, valueAccessor, processPopupContent,
    debounceWait, pointColor, trendColor});
  return (
    <>
      <div ref={containerRef} className={cx('scatter-plot-chart', className)}>
        <svg width={width} height={height}>
          <Axes {...axesProps}>
            {render}
          </Axes>
        </svg>
      </div>
      <ChartPopup popupDescription={popupDescription} />
    </>
  );
};

/**
 * State and drawing logic behind `ScatterPlotChart`.
 *
 * Chart size:
 * - The width is measured from the container unless `widthProp` is given.
 * - The height defaults to 200 px.
 *
 * The hook derives padded x/y domains, the point size and color scales and the
 * least-squares trend line. It returns the props for `Axes` and a render callback. The
 * callback draws one `Circle` per point and, when `showTrend` is set, the trend line.
 *
 * @returns An object with:
 * - `popupDescription`: the current popup, or null.
 * - `render`: the `Axes` child function.
 * - `containerRef`: attach to the measured container.
 * - `width`, `height`: the chart size.
 * - `axesProps`: props for `Axes`.
 */
const useScatterPlotChart = ({axes, data, widthProp, heightProp = 200, colorRange, maxSize, showTrend, trendColor,
  xAccessor, yAccessor, valueAccessor, processPopupContent, debounceWait, pointColor}) => {
  const [popupDescription, setPopupDescription] = useState<unknown>(null);

  const {width: parentWidth = widthProp || 500, ref: containerRef} = useResizeDetector({handleHeight: false});
  const [width, height] = [isUndefined(widthProp) ? parentWidth : widthProp, heightProp];

  // Each domain is widened by 1.5 * maxSize px, converted to data units as range / size.
  // Presumably this keeps the largest circles from being clipped at the edges.
  // Padding is skipped for a non-positive size (e.g. a container measured at 0 px), where
  // the division would produce an infinite or NaN domain.
  const xDomain = useMemo(
    () => {
      const [minX, maxX] = extent(data, xAccessor);
      const range = maxX - minX;
      const offset = maxSize && width > 0 ? (1.5 * maxSize * range) / width : 0;
      return [minX - offset, maxX + offset];
    },
    [data, maxSize, width, xAccessor]
  );

  const yDomain = useMemo(
    () => {
      const [minY, maxY] = extent(data, yAccessor);
      const range = maxY - minY;
      const offset = maxSize && height > 0 ? (1.5 * maxSize * range) / height : 0;
      return [minY - offset, maxY + offset];
    },
    [data, height, maxSize, yAccessor]
  );

  // Radius grows linearly from 1 to maxSize across the value extent.
  const sizeScale = useMemo(
    () =>
      scaleLinear({
        range: [1, maxSize],
        domain: extent(data, valueAccessor)
      }),
    [maxSize, data, valueAccessor]
  );

  // Fill color interpolated across the value extent (red -> blue unless colorRange is given).
  const colorScale = useMemo(
    () =>
      scaleLinear({
        range: colorRange || ['#FF0000', '#0000FF'],
        domain: extent(data, valueAccessor),
      }),
    [colorRange, data, valueAccessor]
  );

  const showPopup = useMemo(
    () => debounce((e, data) => {
      setPopupDescription({
        node: e.target,
        header: 'Plot Data:',
        content: processPopupContent?.(data) ?? null,
      });
    }, debounceWait),
    [debounceWait, processPopupContent]
  );

  const hidePopup = useMemo(
    () => () => {
      showPopup.cancel();
      setPopupDescription(null);
    },
    [showPopup]
  );

  // Cancel a popup still waiting on the debounce timer when the handler is replaced or the
  // component unmounts, so the stale timer cannot set state afterwards.
  useEffect(() => () => showPopup.cancel(), [showPopup]);

  // Not memoized on purpose: `axes` may be mutated in place (e.g. a MobX observable), and
  // rebuilding every render picks such changes up.
  const axesProps = {
    width,
    height,
    x: merge(
      {
        isLinear: true,
        formatLabel: axes?.x?.isTimestamp ? formatChartAxisTime : undefined
      },
      axes?.x,
      {values: xDomain}
    ),
    y: merge(
      {isLinear: true},
      axes?.y,
      {values: yDomain}
    )
  };

  // Trend line calculation
  const {m, b} = useMemo(
    () => getTrend(data, xAccessor, yAccessor),
    [data, xAccessor, yAccessor]
  );

  const render = useCallback<AxesConsumerFn>(
    (xScale, yScale) => {
      // Trend line endpoints: y = m * x + b evaluated at both ends of the x domain.
      const [x1, y1, x2, y2] = flatten(map(xDomain, (x) => [xScale(x), yScale(x * m + b)]));
      // A degenerate fit (one point, identical x values) yields a NaN slope. Skip the line
      // rather than rendering NaN SVG coordinates.
      const trendVisible = showTrend &&
        [x1, y1, x2, y2].every((coordinate) => typeof coordinate === 'number' && Number.isFinite(coordinate));
      return [
        map(data, (d, index) => {
          const value = get(d, valueAccessor);
          return (
            <Circle
              key={index}
              {
                ...(valueAccessor && colorRange) ?
                  {fill: colorScale(value)} :
                  (pointColor ? {fill: pointColor} : {})
              }
              className='plot'
              cx={xScale(get(d, xAccessor))}
              cy={yScale(get(d, yAccessor))}
              r={maxSize ? sizeScale(value) : 5}
              onMouseEnter={(e) => showPopup(e, d)}
              onMouseLeave={hidePopup}
            />
          );
        }),
        trendVisible &&
          <line
            key='trend'
            className='trend'
            {...{x1, y1, x2, y2, style: trendColor ? {stroke: trendColor} : undefined}}
          />
      ];
    },
    [xDomain, data, showTrend, trendColor, m, b, valueAccessor, colorRange, colorScale, pointColor, xAccessor,
      yAccessor, maxSize, sizeScale, hidePopup, showPopup]
  );

  return {popupDescription, render, containerRef, width, height, axesProps};
};

// Defaults applied by React for omitted props.
// NOTE: React 18.3+ deprecates `defaultProps` on function components in favour of default
// parameters; it is kept for React versions that still honour it.
const scatterPlotChartDefaultProps: Partial<ScatterPlotChartProps> = {
  processPopupContent: (data) => data,
  debounceWait: 200
};
ScatterPlotChart.defaultProps = scatterPlotChartDefaultProps;

// Wrapped with mobx-react `observer` so observable reads during render trigger re-renders.
export default observer(ScatterPlotChart);
