import {debounce, flatten, get, isUndefined, map, mean, reduce, round, sumBy} from 'lodash';
import {observer} from 'mobx-react';
import {useMemo, useState} from 'react';
import {scaleLinear, scaleTime} from '@visx/scale';
import {Circle} from '@visx/shape';
import {Group} from '@visx/group';
import cx from 'classnames';
import {Axis, AxisLeft} from '@visx/axis';
import {GridColumns, GridRows} from '@visx/grid';
import {useResizeDetector} from 'react-resize-detector';
import {useAxisWidth, formatChartAxisTime} from 'apstra-ui-common';

import ChartPopup from './ChartPopup';
import {extent} from './utils';

import './ScatterPlotChart.less';
import {DATA_FORMAT} from './BoxplotChart';

// Divisors used by useScatterPlotChart to derive the default tick counts:
// the inner width is divided by minXTick and the inner height by minYTick
// (then rounded), i.e. roughly one tick per this many units of chart size.
const minXTick = 50;
const minYTick = 20;

/**
 * Fits a least-squares straight line `y = m * x + b` through the data points.
 *
 * @param {Array<Object>} data - Points to fit.
 * @param {string|Function} xAccessor - lodash iteratee yielding a point's x value.
 * @param {string|Function} yAccessor - lodash iteratee yielding a point's y value.
 * @returns {{m: number, b: number}} Slope `m` and intercept `b` of the fitted line.
 *   When the x values have no spread (empty data, a single point, or all x equal)
 *   the slope is undefined mathematically; a horizontal line through the mean of y
 *   is returned instead (or `{m: 0, b: 0}` when there is no finite mean).
 */
const getTrend = (data, xAccessor, yAccessor) => {
  const xs = map(data, xAccessor);
  const ys = map(data, yAccessor);
  const meanX = mean(xs);
  const meanY = mean(ys);
  const xDeviations = map(xs, (x) => x - meanX);
  const yDeviations = map(ys, (y) => y - meanY);

  const sumOfSquares = sumBy(xDeviations, (dx) => dx * dx);
  if (sumOfSquares === 0) {
    // Dividing by zero here would give NaN for both m and b and propagate
    // NaN coordinates to the rendered trend line.
    return {m: 0, b: Number.isFinite(meanY) ? meanY : 0};
  }

  const covariance = reduce(xDeviations, (acc, dx, index) => acc + dx * yDeviations[index], 0);
  const m = covariance / sumOfSquares;
  const b = meanY - m * meanX;
  return {m, b};
};

const ScatterPlotChart = ({data, width: widthProp, height: heightProp, className, showTrend, colorRange, maxSize,
  x = {}, y = {}, valueAccessor, processPopupContent, debounceWait, pointColor, trendColor, dataXFormat}) => {
  const {popupDescription, margin, xLabel, yLabel, xTicks, yTicks, xScale, yScale, innerWidth, innerHeight,
    x1, y1, x2, y2, dots, containerRef, width, height, leftAxisRef, leftAxisWidth,
    formatXTick} = useScatterPlotChart({data,
    widthProp, heightProp, className, showTrend, colorRange, maxSize, x, y, valueAccessor, processPopupContent,
    debounceWait, pointColor, dataXFormat});
  return (
    <>
      <div ref={containerRef} className={cx('scatter-plot-chart', className)}>
        <svg width={width} height={height}>
          <GridRows
            left={margin.left}
            scale={yScale}
            width={innerWidth}
            strokeOpacity={0.8}
            pointerEvents='none'
            numTicks={yTicks}
          />
          <GridColumns
            top={margin.top}
            scale={xScale}
            height={innerHeight}
            strokeOpacity={0.8}
            pointerEvents='none'
            numTicks={xTicks}
          />
          <AxisLeft
            innerRef={leftAxisRef}
            scale={yScale}
            left={margin.left}
            label={yLabel}
            labelProps={{y: 5 - leftAxisWidth}}
            numTicks={yTicks}
          />
          <Axis
            orientation='bottom'
            scale={xScale}
            top={innerHeight + margin.top}
            numTicks={xTicks}
            label={xLabel}
            tickFormat={formatXTick}
          />
          <Group top={0} left={0}>{dots}</Group>
          {
            showTrend &&
              <line className='trend' {...{x1, y1, x2, y2, style: trendColor ? {stroke: trendColor} : undefined}} />
          }
        </svg>
      </div>
      <ChartPopup popupDescription={popupDescription} />
    </>
  );
};

const useScatterPlotChart = ({data, widthProp, heightProp = 200, colorRange, maxSize,
  x = {}, y = {}, valueAccessor, processPopupContent, debounceWait, pointColor, dataXFormat}) => {
  const [popupDescription, setPopupDescription] = useState(null);
  const {ref: leftAxisRef, width: leftAxisWidth} = useAxisWidth(y?.label);

  const {width: parentWidth = widthProp || 500, ref: containerRef} = useResizeDetector({handleHeight: false});
  const [width, height] = [isUndefined(widthProp) ? parentWidth : widthProp, heightProp];

  const formatIsTimestamp = dataXFormat === DATA_FORMAT.timestamp;

  // Margings with and without the axis labels
  const margin = {
    top: 10,
    right: 10,
    bottom: x?.label ? 50 : 30,
    left: 5 + leftAxisWidth
  };

  // Dimensions of the graph itself
  const innerWidth = (width - margin.left - margin.right) || 1;
  const innerHeight = (height - margin.top - margin.bottom) || 1;

  const {
    accessor: xAccessor = 'x',
    ticks: xTicks = round(innerWidth / minXTick),
    label: xLabel
  } = x;

  const {
    accessor: yAccessor = 'y',
    ticks: yTicks = round(innerHeight / minYTick),
    label: yLabel
  } = y;

  const xDomain = useMemo(
    () => {
      const [minX, maxX] = extent(data, xAccessor);
      const range = maxX - minX;
      const offset = maxSize ? (1.5 * maxSize * range) / innerWidth : 0;
      return [minX - offset, maxX + offset];
    },
    [data, innerWidth, maxSize, xAccessor]
  );

  const yDomain = useMemo(
    () => {
      const [minY, maxY] = extent(data, yAccessor);
      const range = maxY - minY;
      const offset = maxSize ? (1.5 * maxSize * range) / innerHeight : 0;
      return [minY - offset, maxY + offset];
    },
    [data, innerHeight, maxSize, yAccessor]
  );

  const xScale = useMemo(
    () =>
      (formatIsTimestamp ? scaleTime : scaleLinear)({
        range: [margin.left, innerWidth + margin.left],
        domain: xDomain
      }),
    [formatIsTimestamp, innerWidth, margin.left, xDomain]
  );

  const yScale = useMemo(
    () =>
      scaleLinear({
        range: [innerHeight + margin.top, margin.top],
        domain: yDomain
      }),
    [innerHeight, margin.top, yDomain]
  );

  const formatXTick = formatIsTimestamp ? formatChartAxisTime : undefined;

  const sizeScale = useMemo(
    () =>
      scaleLinear({
        range: [1, maxSize],
        domain: extent(data, valueAccessor)
      }),
    [maxSize, data, valueAccessor]
  );

  const colorScale = useMemo(
    () =>
      scaleLinear({
        range: colorRange || ['#FF0000', '#0000FF'],
        domain: extent(data, valueAccessor),
      }),
    [colorRange, data, valueAccessor]
  );

  const showPopup = useMemo(
    () => debounce((e, data) => {
      setPopupDescription({
        node: e.target,
        header: 'Plot Data:',
        content: processPopupContent?.(data) ?? null,
      });
    }, debounceWait),
    [debounceWait, processPopupContent]
  );

  const hidePopup = useMemo(
    () => () => {
      showPopup.cancel();
      setPopupDescription(null);
    },
    [showPopup]
  );

  const dots = useMemo(
    () =>
      map(data, (d, index) => {
        const value = get(d, valueAccessor);
        return (
          <Circle
            key={index}
            {
              ...(valueAccessor && colorRange) ?
                {fill: colorScale(value)} :
                (pointColor ? {fill: pointColor} : {})
            }
            className='plot'
            cx={xScale(get(d, xAccessor))}
            cy={yScale(get(d, yAccessor))}
            r={maxSize ? sizeScale(value) : 5}
            onMouseEnter={(e) => showPopup(e, d)}
            onMouseLeave={hidePopup}
          />
        );
      }),
    [colorRange, colorScale, hidePopup, maxSize, data, showPopup, sizeScale,
      valueAccessor, xAccessor, xScale, yAccessor, yScale, pointColor]
  );

  // Trend line calculation
  const {m, b} = useMemo(
    () => getTrend(data, xAccessor, yAccessor),
    [data, xAccessor, yAccessor]
  );

  const [x1, y1, x2, y2] = useMemo(
    () => flatten(map(xDomain, (x) => [xScale(x), yScale(x * m + b)])),
    [b, m, xDomain, xScale, yScale]
  );

  return {popupDescription, margin, xLabel, yLabel, xTicks, yTicks, xScale, yScale, innerWidth, innerHeight,
    x1, y1, x2, y2, dots, containerRef, width, height, leftAxisRef, leftAxisWidth, formatXTick};
};

// Fallback props used when the caller omits them: popup content defaults to
// the hovered datum itself, and the popup debounce wait defaults to 200.
// NOTE(review): defaultProps on function components is deprecated in newer
// React versions — confirm against the React version this project targets.
ScatterPlotChart.defaultProps = {
  processPopupContent: (data) => data,
  debounceWait: 200
};

// The component is exported wrapped in mobx-react's observer.
export default observer(ScatterPlotChart);
