import { annotation as csToolsAnnotation } from '@cornerstonejs/tools';
import { Internal, OHIF } from '@gleamer/types';
import { cloneDeep, isEqual } from 'lodash';
import { ActionTypes } from 'redux-undo';
import { getItemLabelPassState } from './labelPass.selectors';

/** Argument object handed to `UndoRedoMiddlewareConfig.drawRegion`. */
type DrawRegionParams = {
  servicesManager: OHIF.ServicesManager;
  region: Internal.RegionInternal;
  observations: Internal.ObservationInternal[];
};

/** Value returned by `UndoRedoMiddlewareConfig.drawRegion`. */
type DrawRegionResult = {
  initialRegion: Internal.RegionInternal;
  drawnRegion: Internal.RegionInternal | null;
};

/**
 * Dependencies injected into `undoRedoMiddleware`.
 *
 * Passing these in (rather than importing them directly) lets callers supply
 * their own region-drawing routine and Cornerstone annotation API.
 */
export type UndoRedoMiddlewareConfig = {
  /** OHIF services container the middleware pulls its services from. */
  servicesManager: OHIF.ServicesManager;
  /** Draws a single region onto the viewport, given all current observations. */
  drawRegion: (params: DrawRegionParams) => DrawRegionResult;
  /** Cornerstone tools `annotation` namespace (used for its `state` accessors). */
  annotation: typeof csToolsAnnotation;
};

/** Regions that differ between two snapshots of the label-pass state. */
type RegionDiff = {
  /** Regions whose uid exists only in the newer snapshot. */
  added: Internal.RegionInternal[];
  /** Regions whose uid exists only in the older snapshot. */
  deleted: Internal.RegionInternal[];
  /** Regions present in both whose `data.handles.points` are not deep-equal. */
  modified: Internal.RegionInternal[];
};

/**
 * Compares two flat region lists by `uid`.
 *
 * Uses uid-keyed lookups so the diff is linear in the number of regions
 * instead of quadratic. If a uid appears more than once in `previous`, the
 * first occurrence is the one compared against (same as `Array.find`).
 *
 * @param previous - regions before the state change
 * @param current - regions after the state change
 * @returns added / deleted / modified regions, each in its source list's order
 */
function diffRegions(
  previous: Internal.RegionInternal[],
  current: Internal.RegionInternal[]
): RegionDiff {
  const previousByUid = new Map<Internal.RegionInternal['uid'], Internal.RegionInternal>();
  for (const region of previous) {
    if (!previousByUid.has(region.uid)) {
      previousByUid.set(region.uid, region);
    }
  }
  const currentUids = new Set(current.map(region => region.uid));

  const added: Internal.RegionInternal[] = [];
  const modified: Internal.RegionInternal[] = [];
  for (const region of current) {
    const previousRegion = previousByUid.get(region.uid);
    if (!previousRegion) {
      added.push(region);
    } else if (!isEqual(previousRegion.data.handles.points, region.data.handles.points)) {
      modified.push(region);
    }
  }

  const deleted = previous.filter(region => !currentUids.has(region.uid));

  return { added, deleted, modified };
}

/**
 * Redux middleware that keeps the Cornerstone viewport in sync with the
 * label-pass regions when redux-undo history is navigated.
 *
 * Actions other than `ActionTypes.UNDO` / `ActionTypes.REDO` pass straight
 * through. For UNDO/REDO, the regions before and after the reducer runs are
 * diffed by uid, and then:
 * - deleted regions are removed via `MeasurementService.remove`;
 * - added regions are redrawn via the injected `drawRegion`;
 * - regions whose handle points changed get a deep copy of the new points
 *   written into their Cornerstone annotation, followed by a single render.
 *
 * @param config - injected services, drawing routine and annotation API
 * @returns a Redux middleware; it returns whatever `next(action)` returns
 */
export function undoRedoMiddleware({
  servicesManager,
  drawRegion,
  annotation,
}: UndoRedoMiddlewareConfig) {
  const { MeasurementService, CornerstoneViewportService } = servicesManager.services;

  return store => next => action => {
    if (action.type !== ActionTypes.UNDO && action.type !== ActionTypes.REDO) {
      return next(action);
    }

    // Snapshot regions before the undo/redo reducer replaces the state.
    const { observations: previousObservations } = getItemLabelPassState(store.getState());
    const previousRegions = previousObservations.flatMap(obs => obs.regions);

    const result = next(action);

    const { observations } = getItemLabelPassState(store.getState());
    const currentRegions = observations.flatMap(obs => obs.regions);

    const { added, deleted, modified } = diffRegions(previousRegions, currentRegions);

    deleted.forEach(region => {
      MeasurementService.remove(region.uid, region.source);
    });

    added.forEach(region => {
      drawRegion({
        servicesManager,
        region,
        observations,
      });
    });

    let needsRender = false;
    modified.forEach(region => {
      const annotationToUpdate = annotation.state.getAnnotation(region.uid);
      if (!annotationToUpdate) {
        // Without this guard the assignment below throws after the store has
        // already changed, aborting the dispatch and the remaining syncs.
        console.warn(
          `undoRedoMiddleware: no cornerstone annotation found for region ${region.uid}; skipping points sync`
        );
        return;
      }

      // Copy so the mutable Cornerstone annotation never aliases store state.
      // NOTE(review): cached measurement stats may also need invalidating here — confirm.
      annotationToUpdate.data.handles.points = cloneDeep(region.data.handles.points);
      needsRender = true;
    });

    // One render covers every patched annotation.
    if (needsRender) {
      CornerstoneViewportService.getRenderingEngine().render();
    }

    return result;
  };
}
