import React, { useEffect, useRef } from 'react'
import { useSelector } from 'react-redux'
import lodashUpdate from 'lodash/update'
import { useUnitsGen } from '../../hooks'
import BaseWidget from '../BaseWidget'
import { useDarkModeEffect } from '../../hooks'
import { usePlot, useWaveformData } from './hooks'
import { arrayMap, promisify } from '../../../../../../utils/utils'

import { PlotBuilder } from '../../PlotJS'
import { mean_waveforms, overlap_waveforms } from './builders'
import { updateAxesRange, onUpdateDeph, factoryColumnBuckets } from './helpers'
import {
  DEPH,
  DEPTH_UPDATER,
  SENSITIVITY,
  WAVEFORM_CANVAS_ID,
  LAYOUT_ARGS_DEFAULT
} from './consts'
import { StandardKeyEvents } from '../../../../../TimeseriesView/CaptureKeyEvent/helpers'
import { DrawingLoader } from '../../components/DrawingLoader'
import { getOr } from 'lodash/fp'
import {
  REINIT_CLUSTER_UNITS,
  REINIT_SIMILARITY_UNITS
} from '../../redux/actions'
import { useReinitClusters } from '../utils'

/**
 * Static bridge between the WaveFormView component and the code that drives it.
 *
 * WaveFormView registers its plot renderer (setRenderPlot) and its drawing
 * handlers (initEvents) here. Because every member is static, any module can
 * reach the registered renderer, the handlers and runBuilder without holding
 * a reference to the component instance.
 */
export class PlotWaveforms {
  // Plot renderer registered by WaveFormView; set back to null on cleanup.
  static renderPlot = null
  // Units that may be drawn; runBuilder ignores any unit not in this list.
  static units = null
  static plot = new PlotBuilder()
  // Builder functions keyed by view type, selected through setBuilder().
  static builders = {}
  // No-op placeholders, replaced by WaveFormView through initEvents().
  static updatePlotData = () => {}
  static clearTraces = () => {}
  static clearAllTraces = () => {}

  // Column-bucket bookkeeping shared by every WaveFormView instance.
  static logic = factoryColumnBuckets()

  // Multiplier applied to logic.handleBuckets(channel) to obtain the x
  // coordinate used by makeGeometry (value kept from the original code).
  static BUCKET_X_SCALE = 17

  static setRenderPlot(builder) {
    this.renderPlot = builder
  }

  static setUnits(units) {
    this.units = units
  }

  /**
   * Install handlers as static members, e.g. { updatePlotData, clearTraces }.
   * Every key of `updaters` overwrites the static member of the same name.
   */
  static initEvents(updaters) {
    Object.keys(updaters).forEach(key => {
      this[key] = updaters[key]
    })
  }

  static initBuilders(builders) {
    this.builders = builders
  }

  /** Make builders[type] the active builder of the underlying PlotBuilder. */
  static setBuilder(type) {
    this.plot.changeBuilder(this.builders[type])
  }

  /** The active builder function (PlotBuilder#builderType). */
  static get builder() {
    return this.plot.builderType
  }

  static clearAllBuckets() {
    this.logic.clearAllBuckets()
  }

  static clearBucket(unit) {
    this.logic.clearBucket(unit)
  }

  /** x coordinate of the column bucket assigned to `channel`. */
  static handleBuckets(channel) {
    return this.logic.handleBuckets(channel) * this.BUCKET_X_SCALE
  }

  static getBucketIndices(channel, unit) {
    return this.logic.getBucketIndices(channel, unit)
  }

  /**
   * Create a function mapping { channel, y, base } to an [x, y] position.
   * In geometry mode a provided `base` position is returned unchanged;
   * otherwise x comes from the channel's column bucket.
   *
   * @param {boolean} waveformGeometry - whether geometry mode is active
   * @returns {function({channel: *, y: *, base: *}): *}
   */
  static makeGeometry(waveformGeometry) {
    return ({ channel, y, base }) => {
      if (waveformGeometry && base) return base

      return [this.handleBuckets(channel), y]
    }
  }

  /**
   * Run the active builder for one unit.
   *
   * @param {object} args
   * @param {*} args.unit - unit to draw; skipped unless it is in `units`
   * @param {*} args.payloadData - forwarded to the builder
   * @param {*} args.HEIGHT - forwarded to the builder
   * @param {number} args.deph - forwarded to the builder
   * @param {boolean} args.waveformGeometry - when true the column callback returns {}
   * @param {boolean} args.waveformOverlapping - forwarded to the builder
   * @param {number} args.unitIdx - forwarded to the builder
   * @returns {object} the builder's result, or {} when the unit is not
   *   drawable (including when no units have been registered yet)
   */
  static runBuilder(args) {
    const { unit, waveformGeometry, waveformOverlapping, unitIdx } = args

    // `units` is null until setUnits() runs; treat that as "nothing to draw"
    // instead of throwing a TypeError.
    if (!this.units?.includes(unit)) return {}

    // Passed to the builder, presumably with the unit's channel list: it
    // assigns each channel to this unit in the column buckets and returns
    // the resulting column positions and y-axis chunk.
    const placeColumns = channels => {
      if (waveformGeometry) return {}

      const channelOwners = Object.fromEntries(
        channels.map(channel => [channel, unit])
      )

      return {
        columnsPos: this.logic.mutateBuckets(channelOwners, String(unit)),
        yAxisChunk: this.logic.getYAxisBucket(unit)
      }
    }

    return this.builder(
      unit,
      args.payloadData,
      args.HEIGHT,
      args.deph,
      placeColumns,
      waveformOverlapping,
      unitIdx
    )
  }
}

export const WaveFormView = ({ W, H }) => {
  const elem = React.useRef(null)
  const traces = useRef(new Map([]))
  const channelLabels = useRef(new Map([]))
  const channelsMapGeometry = useRef(new Map([]))
  const deph = useRef(DEPH)
  const maxXAxis = useRef([])
  const darkMode = useSelector(state => state.darkMode)
  const loading = useSelector(state => state.widgetsCache.waveforms.loading)
  const { clusterUnits, similarityUnits, uniqueUnits, render } = useUnitsGen()
  const {
    waveformsView,
    algo,
    sortingId,
    curationId,
    waveformGeometry,
    waveformOverlapping
  } = useWaveformData()

  const layoutArgs = useRef({
    axisArgs: LAYOUT_ARGS_DEFAULT[algo]
  })

  const scalingRef = useRef({
    x: 0
  })

  const renderPlot = usePlot(W - 3, H - 30, darkMode, layoutArgs.current)

  useEffect(() => {
    PlotWaveforms.setRenderPlot(renderPlot)
    PlotWaveforms.initEvents({ updatePlotData, clearTraces, clearAllTraces })

    return () => {
      //if waveform view is not rendered the request should not be made
      PlotWaveforms.setRenderPlot(null)
    }
  }, [renderPlot])

  useEffect(() => {
    PlotWaveforms.initBuilders({ mean_waveforms, overlap_waveforms })
    PlotWaveforms.setBuilder(waveformsView)
  }, [])

  useEffect(() => {
    clearAllTraces()
  }, [sortingId, curationId, algo])

  function clearAllTraces() {
    const gd = document.getElementById(WAVEFORM_CANVAS_ID)

    channelLabels.current = new Map([])
    channelsMapGeometry.current = new Map([])
    traces.current = new Map([])

    if (Array.isArray(gd?.data)) {
      channelLabels.current = new Map()
      channelsMapGeometry.current = new Map()
      traces.current = new Map()

      const indices = gd.data.map((_, i) => i)

      renderPlot.removeTrace(elem.current, indices, { offsetY })

      PlotWaveforms.clearAllBuckets()
    }
  }

  function updateRemainingTraces(gd, indices) {
    const data = gd.data
      .filter(x => x.updateX)
      .reduce((acc, v) => {
        const trace_length = v.trace_length
        const channel = v.text
        const unit = v['data-id']
        const xAxis = waveformOverlapping
          ? 0
          : PlotWaveforms.logic.getBucketIndices(channel, unit) * trace_length

        return [...acc, v.updateX(xAxis)]
      }, [])

    renderPlot.updateAttribute(elem.current ?? gd, { x: data }, indices)
  }

  const updateYOffsets = gd => {
    const annotations = renderPlot.getAnnotations()
    let dict = new Map()

    for (let k in annotations) {
      const { text, y, x } = annotations[k]
      if (!dict.has(text)) {
        dict.set(text, [x, y])
      }
    }

    const sortedData = [...dict.entries()].sort((a, b) => a[1][1] - b[1][1])

    for (let i = 0; i < sortedData.length; i++) {
      const elem = sortedData[i]

      if (sortedData[i - 1]) {
        if (elem[1][1] === sortedData[i - 1][1][1]) {
          elem[1][1] += 4

          for (let j = i + 1; j < sortedData.length; j++) {
            sortedData[j][1][1] += 4
          }
        }
      }
    }

    dict = new Map(sortedData)

    const indices = gd.data.map((_, i) => i)

    const data = gd.data
      .filter(x => x.updateY)
      .reduce((acc, v) => {
        const [x, y] = dict.get(v.text) || [0, v.offsetY]
        return [...acc, v.updateY({ newDeph: deph.current, offset: y })]
      }, [])

    renderPlot.updateAttribute(elem.current ?? gd, { y: data }, indices)
    renderPlot.updateAnnotations(
      elem.current ?? gd,
      sortedData,
      waveformGeometry
    )
  }

  function clearTraces(units, removedState, currentUnit) {
    var gd = document.getElementById(WAVEFORM_CANVAS_ID)

    const removedUnits = new Set()

    if (Array.isArray(gd?.data)) {
      const removedUnitIndices = gd.data
        .map((trace, i) => {
          const unit = trace['data-id']

          /*
          here is the check if the plot has already drawn
          the unit you are selected twice to not draw it again
          */
          if (unit === currentUnit) {
            removedUnits.add(currentUnit)
            return i
          } else if (units.has(unit)) {
            return null
          }
          removedUnits.add(unit)
          return i
        })
        .filter(x => x !== null)

      for (let unit of removedUnits) {
        channelLabels.current.delete(unit)
        traces.current.delete(unit)
        PlotWaveforms.clearBucket(String(unit))
      }

      const cc = arrayMap(channelLabels.current)

      for (let [k] of channelsMapGeometry.current) {
        if (!cc.includes(k)) {
          channelsMapGeometry.current.delete(k)
        }
      }

      return promisify(() => {
        renderPlot.removeTrace(elem.current, removedUnitIndices)
        const newIndices = gd.data.map((t, idx) => idx)
        if (!waveformGeometry && removedState) {
          updateRemainingTraces(gd, newIndices)
        }

        renderPlot.updateAnnotations(
          elem.current,
          [...channelsMapGeometry.current.entries()],
          waveformGeometry
        )
        if (!waveformGeometry) {
          updateYOffsets(gd)
        }
      })
    }

    return Promise.resolve()
  }

  const offsetYObj =
    algo === 'Klusta'
      ? { mean_waveforms: 4, overlap_waveforms: 4 }
      : { mean_waveforms: 4, overlap_waveforms: 4 }

  const offsetY = offsetYObj[waveformsView]
  const HEIGHT =
    algo === 'Klusta'
      ? { mean_waveforms: H || 900, overlap_waveforms: H || 900 }
      : { mean_waveforms: H || 900, overlap_waveforms: H || 900 }

  useDarkModeEffect(elem.current, renderPlot, darkMode)

  const size = W > 0 ? W + H : 0

  const reinitClusters = useReinitClusters()

  useEffect(() => {
    if (size === 0) return

    reinitClusters({
      arr: clusterUnits,
      type: REINIT_CLUSTER_UNITS
    })
    reinitClusters({
      arr: similarityUnits,
      type: REINIT_SIMILARITY_UNITS,
      multiple: true
    })
  }, [size])

  const onGetDom = cb => {
    const gd = document.getElementById(WAVEFORM_CANVAS_ID)

    cb(gd)
  }

  const onRestyle = () => {
    onGetDom(gd => {
      renderPlot.restyle(elem.current ?? gd, {
        'xaxis.range': layoutArgs.current.axisArgs.xaxis.range,
        darkMode
      })
    })
  }

  StandardKeyEvents.registerWithPreventDefault([87, 65, 39], e => {
    //right-widden

    scalingRef.current = lodashUpdate(scalingRef.current, 'x', n =>
      Math.min(n - SENSITIVITY, maxXAxis.current[0])
    )

    layoutArgs.current = lodashUpdate(
      layoutArgs.current,
      'axisArgs.xaxis.range',
      n => [
        Math.min(n[0] + SENSITIVITY, maxXAxis.current[0]),
        Math.max(n[1] - SENSITIVITY, maxXAxis.current[1])
      ]
    )

    onRestyle()
  })

  StandardKeyEvents.registerWithPreventDefault([87, 65, 37], e => {
    //left-narrow

    scalingRef.current = lodashUpdate(scalingRef.current, 'x', function (n) {
      return n + SENSITIVITY
    })

    layoutArgs.current = lodashUpdate(
      layoutArgs.current,
      'axisArgs.xaxis.range',
      function (n) {
        return [n[0] - SENSITIVITY, n[1] + SENSITIVITY]
      }
    )

    onRestyle()
  })

  StandardKeyEvents.registerWithPreventDefault([87, 65, 40], e => {
    //down
    onGetDom(gd => {
      if (Array.isArray(gd?.data)) {
        const { indices, data, newDeph } = onUpdateDeph(
          gd,
          Math.max(deph.current - DEPTH_UPDATER, 0)
        )
        deph.current = newDeph
        renderPlot.updateAttribute(elem.current ?? gd, { y: data }, indices)
      }
    })
  })

  StandardKeyEvents.registerWithPreventDefault(
    [87, 65, 38],
    _e => {
      //up
      onGetDom(gd => {
        if (Array.isArray(gd?.data)) {
          const { indices, data, newDeph } = onUpdateDeph(
            gd,
            deph.current + DEPTH_UPDATER
          )
          deph.current = newDeph
          renderPlot.updateAttribute(elem.current ?? gd, { y: data }, indices)
        }
      })
    },
    []
  )

  function updatePlotData(
    unit,
    payloadData,
    unitIdx,
    waveformGeometry,
    waveformOverlapping
  ) {
    const domCanvasElement = document.getElementById(WAVEFORM_CANVAS_ID)

    if (!elem.current) return

    if (!Array.isArray(domCanvasElement?.data)) {
      renderPlot.render(elem.current, [])
    }

    const { tracesUnit, channelLabelsUnit } = PlotWaveforms.runBuilder({
      unit,
      payloadData,
      unitIdx,
      HEIGHT,
      deph: deph.current,
      waveformGeometry,
      waveformOverlapping
    })

    if (!tracesUnit || !channelLabelsUnit) return

    traces.current.set(unit, tracesUnit)

    for (let [k, v] of channelLabelsUnit) {
      channelsMapGeometry.current.set(k, v)
    }

    channelLabels.current.set(unit, [...channelLabelsUnit.keys()])

    if (payloadData) {
      promisify(() => {
        const arr = [...channelsMapGeometry.current.values()]
        const axs = updateAxesRange(arr)
        const trace_length = payloadData?.nsw
        const prevScaleX = getOr(0, ['current', 'x'], scalingRef)

        //asigning the object directly is not working but only value by value
        const [x1, x2] = axs.xaxis.range
        const [y1, y2] = axs.yaxis.range

        layoutArgs.current.axisArgs.xaxis.range[0] =
          x1 - prevScaleX - trace_length - SENSITIVITY
        layoutArgs.current.axisArgs.xaxis.range[1] =
          x1 + prevScaleX + trace_length * 2
        layoutArgs.current.axisArgs.yaxis.range[0] = y1
        layoutArgs.current.axisArgs.yaxis.range[1] = y2

        const diff = (x2 - x1) / 2
        maxXAxis.current = [x1 + diff - SENSITIVITY, x2 - diff]

        renderPlot.update(elem.current, traces.current.get(unit))
        renderPlot.updateAnnotations(
          elem.current,
          [...channelsMapGeometry.current.entries()],
          waveformGeometry
        )

        if (!waveformGeometry) {
          updateYOffsets(domCanvasElement)
        }
      })
    }
  }

  return uniqueUnits.length ? (
    <div ref={elem} id={WAVEFORM_CANVAS_ID} style={{ position: 'relative' }}>
      {render(() => (
        <>{loading && <DrawingLoader width={W} height={H} />}</>
      ))}
    </div>
  ) : (
    render(() => <BaseWidget />)
  )
}
