import { buildWaveforms, buildMeanWaveforms } from './hooks'
import { seedUnitColor } from '../../../../../../shared/utils'
import { OFFSETY, STD_TYPE } from './consts'
import { buildKey } from './helpers/plotHelpers'
import { arrFromNumber } from '../../../../../TimeseriesView/NewLeftPanel/utils'

/**
 * Build a per-channel lookup of plot coordinates.
 *
 * @param {Array|undefined|null} arr - one y-series per channel; `arr[i]` belongs to `ch[i]`
 * @param {Array} ch - channel ids, index-aligned with `arr`
 * @param {number} trace_length - forwarded to `arrFromNumber` to build each channel's x axis
 * @returns {Object|undefined} `{ [channel]: { y, x } }`, or `undefined` when `arr` is nullish
 */
const buildCoords = (arr, ch, trace_length) => {
  if (arr == null) return undefined

  // Filled in place instead of spreading a reduce accumulator, which copied the
  // whole object on every step (quadratic in the number of channels).
  const coords = {}
  arr.forEach((v, i) => {
    // arrFromNumber is called per channel (not hoisted) in case callers mutate x.
    coords[ch[i]] = { y: v, x: arrFromNumber(trace_length) }
  })
  return coords
}

/**
 * Stretch a channel position horizontally so neighbouring traces are easier to read.
 *
 * @param {number[]} geometry - `[x, y]` position
 * @returns {number[]} a new `[x * 2, y]` pair; the input array is not modified
 */
const increaseGeom = ([xPos, yPos]) => [2 * xPos, yPos]

/**
 * Resolve the plotted [x, y] position of one channel.
 *
 * @param {number[]|undefined} geometry - the channel's raw [x, y] position, if known
 * @param {number[]} base - returned as-is when `geometry` is falsy
 * @param {number} yGeom - y used when the channel is assigned a column
 * @param {number|undefined} column - column index; `undefined` means no column layout
 * @param {number} trace_length - width of one column on the x axis
 * @returns {number[]} `[x, y]` position
 */
const getGeometry = (geometry, base, yGeom, column, trace_length) => {
  if (geometry) {
    // Column layout: x comes from the column index, y from the caller's slot.
    // Otherwise, fall back to the widened raw geometry.
    return column === undefined
      ? increaseGeom(geometry)
      : [column * trace_length, yGeom]
  }
  return base
}

/**
 * Order channels by their geometry y coordinate and assign each one a plot position.
 *
 * @param {Array} channels - channel ids to lay out
 * @param {Map} geometry - channel id -> [x, y] position
 * @param {number} trace_length - samples per trace; used as the fallback x and as the column width
 * @param {number} from - y of the first slot in the evenly spaced layout
 * @param {number} gap - spacing between consecutive slots
 * @param {Function} makeGeometry - `(channel, y) => ({ bucket, y })`, called with each channel id
 *   and its geometry y coordinate; `bucket` is the column, `y` the fallback y
 * @returns {Array} `[channel, [x, y]]` pairs in ascending geometry-y order
 */
const sortChannels = (
  channels,
  geometry,
  trace_length,
  from,
  gap,
  makeGeometry
) => {
  const positioned = channels.map(channel => [channel, geometry?.get(channel)])
  positioned.sort(([, left], [, right]) => left[1] - right[1])

  // First pass: ask the caller for each channel's column bucket and fallback y.
  const layouts = positioned.map(([channel, pos]) => ({
    channel,
    pos,
    layout: makeGeometry(channel, pos[1])
  }))

  // Second pass: resolve the final [x, y]; `slot` is the channel's rank after sorting.
  return layouts.map(({ channel, pos, layout }, slot) => {
    const { bucket, y } = layout
    const slotY = from + gap * slot
    return [
      channel,
      getGeometry(pos, [trace_length, y], slotY, bucket, trace_length)
    ]
  })
}

/**
 * Build the mean-waveform traces for one unit.
 *
 * For every channel that has data, this emits three traces via `buildMeanWaveforms`:
 * the solid mean, plus dotted mean + std and mean - std traces. It also records the
 * channel's label position in `channelLabelsUnit`.
 *
 * @param {*} unit - unit id; used for the trace colour, trace metadata and the `buildKey` column lookup
 * @param {Object} [payload] - `{ mean, std, geometry, nsw }`. `mean[i]` and `std[i]` are paired with
 *   the i-th key of `new Map(geometry)` (NOTE(review): assumes the backend sends them in the same
 *   channel order — confirm). `nsw` is used as the trace length.
 * @param {Object} HEIGHT - `HEIGHT.mean_waveforms` is forwarded to each trace as `H`
 * @param {number} deph - forwarded unchanged to each trace
 * @param {Function} makeGeometry - `(channels) => ({ columnsPos })`; `columnsPos[buildKey(channel, unit)].idx`
 *   selects the channel's column
 * @param {boolean} waveformOverlapping - when truthy, every channel is placed in column 0
 * @returns {{ unit: *, tracesUnit: Array, channelLabelsUnit: Map }} always an object;
 *   `tracesUnit` and `channelLabelsUnit` are empty when the payload is missing or incomplete
 */
export const mean_waveforms = (
  unit,
  payload,
  HEIGHT,
  deph,
  makeGeometry,
  waveformOverlapping
) => {
  const tracesUnit = []
  const channelLabelsUnit = new Map()
  // Every exit returns this same shape so callers can always destructure it.
  const result = { unit, tracesUnit, channelLabelsUnit }

  if (!payload) return result

  const { mean, std, geometry, nsw: trace_length } = payload
  if (!mean || !std || !geometry) return result

  const geometryMap = new Map(geometry)
  const channels = Array.from(geometryMap.keys())

  const { columnsPos } = makeGeometry(channels)

  const ys = Array.from(geometryMap.values(), pos => pos[1])
  const from = Math.min(...ys)
  const to = Math.max(...ys)
  const gap = (to - from) / channels.length

  const channelsWithGeometry = sortChannels(
    channels,
    geometryMap,
    trace_length,
    from,
    gap,
    (channel, y) => ({
      bucket: waveformOverlapping
        ? 0
        : columnsPos?.[buildKey(channel, unit)]?.idx,
      y
    })
  )

  // mean +/- std per channel, keyed like `buildCoords`. Filled in place (no
  // accumulator spread), so building stays linear in the channel count.
  const buildCoordsStd = type => {
    const isAdd = type === STD_TYPE.ADD
    const coords = {}
    mean.forEach((row, i) => {
      coords[channels[i]] = {
        y: row.map((item, st) => {
          const s = std[i][st]
          return isAdd ? item + s : item - s
        }),
        x: arrFromNumber(trace_length)
      }
    })
    return coords
  }

  const mn = buildCoords(mean, channels, trace_length)
  const stdAdd = buildCoordsStd(STD_TYPE.ADD)
  const stdSub = buildCoordsStd(STD_TYPE.SUB)

  // Hoisted out of the loop; assumes seedUnitColor is deterministic for a given unit.
  const color = seedUnitColor(unit)

  for (const [rawChannel, [offsetX, y]] of channelsWithGeometry) {
    const channel = String(rawChannel)
    if (!mn[channel]) continue

    const offsetY = y * OFFSETY
    channelLabelsUnit.set(channel, [offsetX, offsetY])

    const common = {
      channel,
      H: HEIGHT.mean_waveforms,
      trace_length,
      offsetX,
      offsetY,
      unit,
      deph
    }

    // Solid mean first, then the dotted mean+std and mean-std envelopes.
    // A fresh style object is created for every trace.
    const series = [
      [mn[channel], { color }],
      [stdAdd[channel], { color, dash: 'dot' }],
      [stdSub[channel], { color, dash: 'dot' }]
    ]

    for (const [coords, style] of series) {
      tracesUnit.push(
        buildMeanWaveforms({
          style,
          xData: coords.x,
          yData: coords.y,
          ...common
        })
      )
    }
  }

  return result
}

/**
 * Build the overlapping-waveform traces for one unit.
 *
 * @param {*} unit - unit id; used for the trace colour, trace metadata and the `buildKey` column lookup
 * @param {Object} [payloadData] - `{ data, geometry, nsw }`. `data[i]` is paired with the i-th key of
 *   `new Map(geometry)` (NOTE(review): assumes matching channel order — confirm). `nsw` is the trace length.
 * @param {Object} HEIGHT - `HEIGHT.overlap_waveforms` is forwarded to each trace as `H`
 * @param {number} deph - halved when column positions exist, otherwise divided by `OFFSETY`
 * @param {Function} makeGeometry - `(channels) => ({ columnsPos })`; `columnsPos[buildKey(channel, unit)].idx`
 *   selects the channel's column
 * @param {boolean} waveformOverlapping - when truthy, every channel is placed in column 0
 * @returns {{ unit: *, tracesUnit: Array, channelLabelsUnit: Map }} always an object;
 *   `tracesUnit` and `channelLabelsUnit` are empty when the payload is missing or incomplete
 */
export const overlap_waveforms = (
  unit,
  payloadData,
  HEIGHT,
  deph,
  makeGeometry,
  waveformOverlapping
) => {
  const tracesUnit = []
  const channelLabelsUnit = new Map()
  // Every exit returns this same shape so callers can always destructure it.
  const result = { unit, tracesUnit, channelLabelsUnit }

  if (!payloadData) return result

  const { data, geometry, nsw: trace_length } = payloadData
  // Without data every `wf[channel]` lookup below would throw.
  if (!data || !geometry) return result

  const geometryMap = new Map(geometry)
  const channels = Array.from(geometryMap.keys())

  const { columnsPos } = makeGeometry(channels)

  // The y range is doubled here to match the doubled `y` passed to sortChannels below.
  const ys = Array.from(geometryMap.values(), pos => pos[1])
  const from = Math.min(...ys) * 2
  const to = Math.max(...ys) * 2
  const gap = (to - from) / channels.length

  const channelsWithGeometry = sortChannels(
    channels,
    geometryMap,
    trace_length,
    from,
    gap,
    (channel, y) => ({
      bucket: waveformOverlapping
        ? 0
        : columnsPos?.[buildKey(channel, unit)]?.idx,
      y: y * 2
    })
  )

  const wf = buildCoords(data, channels, trace_length)

  // Loop-invariant trace settings (assumes seedUnitColor is deterministic for a given unit).
  const color = seedUnitColor(unit)
  const traceDeph = columnsPos ? deph / 2 : deph / OFFSETY

  for (const [rawChannel, geom] of channelsWithGeometry) {
    const channel = String(rawChannel)
    if (!wf[channel]) continue

    const yData = wf[channel].y ?? []
    // Overlap-waveform y data is inverted (see the `* -1` below), so the start
    // of each trail is drawn at the back and vice versa.
    const xData = wf[channel].x

    channelLabelsUnit.set(channel, geom)

    const [offsetX, offsetY] = geom

    // One trace per x sample: it takes the x-th value of every row of yData, negated.
    // NOTE(review): each trace pairs xData (xData.length points) with one y point per
    // row of yData — confirm these lengths are meant to match.
    for (let x = 0; x < xData.length; x++) {
      tracesUnit.push(
        buildWaveforms({
          unit,
          xData,
          yData: yData.map(item => item[x] * -1),
          H: HEIGHT.overlap_waveforms,
          offsetX,
          offsetY,
          style: {
            color,
            width: 1
          },
          deph: traceDeph
        })
      )
    }
  }

  return result
}
