import React, { useEffect, useMemo } from 'react';
import { difference } from 'ramda';
import isEqual from 'lodash-es/isEqual';
import {
  RenderItemProps,
  DefaultHandlers,
  DefaultState,
  TreeType,
} from './types';
import { Tree, TreeProps } from './Tree';
import { checkTree, filterTreeBy, getNodesIds } from './utils';

/**
 * Props handed to the `renderItem` callback of `CheckboxTree`.
 *
 * This is `RenderItemProps` with two additions on top of the defaults:
 * - `handlers.onCheckedChange` — a zero-argument callback that toggles the
 *   checkbox of the item being rendered;
 * - `state.checked` — whether the item being rendered is currently checked.
 */
export type CheckboxTreeRenderItemProps<T> = RenderItemProps<
  T,
  DefaultHandlers & { onCheckedChange: () => void },
  DefaultState & { checked: boolean }
>;
/**
 * Props of `CheckboxTree`: every `TreeProps` field except `renderItem`, which
 * is redeclared here so the callback receives the checkbox-aware item props.
 */
interface CheckboxTreeProps<T> extends Omit<TreeProps<T>, 'renderItem'> {
  /** Ids of the items that are currently checked (controlled value). */
  checked: string[];
  /**
   * Called with the complete new list of checked ids, both when an item is
   * toggled and when the component re-computes the list and the result
   * differs from `checked`.
   */
  onCheckedChange: (checkedIds: string[]) => void;
  /**
   * A predicate function that is invoked for every node and leaf of the given
   * tree. Specify this function if you want to omit some leaves or nodes from
   * the tree before the "checking" algorithm starts.
   */
  filterTree?: (node: TreeType<T>) => boolean;
  /** Renders a single item; receives the `checked` state and toggle handler. */
  renderItem: (props: CheckboxTreeRenderItemProps<T>) => JSX.Element;
  /**
   * Forwarded as-is to the `checkTree` helper; the component defaults it to
   * `true`. NOTE(review): presumably controls whether a parent becomes checked
   * automatically based on its children — confirm against `checkTree`.
   */
  autoCheckParent?: boolean;
}

/**
 * A controlled tree whose items carry a checkbox state.
 *
 * The component wraps `Tree` and augments every rendered item with a
 * `checked` flag and an `onCheckedChange` handler. Both the checked ids and
 * the expanded ids are owned by the caller and reported back through
 * `onCheckedChange` / `onExpandedChange`.
 *
 * Notes on behavior:
 * - When `filterTree` is given, the filtered tree is used only as the input of
 *   the `checkTree` computation; `Tree` itself still receives the original
 *   `treeData`.
 * - After every change of the tree, the checked ids or `autoCheckParent`, the
 *   checked list is re-computed via `checkTree` (with an empty toggled id) and
 *   reported through `onCheckedChange` if it is not deeply equal to `checked`.
 * - Toggling an item that was unchecked also requests expansion of the ids
 *   returned by `getNodesIds` for that item that are not expanded yet.
 *
 * @param props - see `CheckboxTreeProps`; `autoCheckParent` defaults to `true`.
 * @returns the rendered `Tree` element.
 */
export function CheckboxTree<T>({
  treeData,
  isItemLoaded,
  expanded,
  onExpandedChange,
  checked,
  onCheckedChange,
  renderItem,
  renderLoader,
  loadMoreItems,
  height,
  itemSize,
  threshold,
  filterTree,
  autoCheckParent = true,
}: CheckboxTreeProps<T>) {
  // Tree used by the checking algorithm. Falls back to an empty tree when the
  // filter helper returns a falsy value.
  const filteredTree = useMemo(
    () => (filterTree ? filterTreeBy(treeData, filterTree) || [] : treeData),
    [filterTree, treeData],
  );

  // Computes the checked ids that result from toggling `checkedId`.
  const checkHandler = React.useCallback(
    (checkedId: string) =>
      checkTree(filteredTree, checked, checkedId, { autoCheckParent }),
    [checked, filteredTree, autoCheckParent],
  );

  // Keep the caller's `checked` list in sync with what `checkTree` derives for
  // the current tree. The empty id means "no item is being toggled"
  // (NOTE(review): relies on `checkTree` treating '' as a no-op toggle).
  useEffect(() => {
    const newCheckedIds = checkTree(filteredTree, checked, '', {
      autoCheckParent,
    });
    if (!isEqual(checked, newCheckedIds)) {
      onCheckedChange(newCheckedIds);
    }
  }, [filteredTree, checked, onCheckedChange, autoCheckParent]);

  return (
    <Tree
      onExpandedChange={onExpandedChange}
      expanded={expanded}
      height={height}
      treeData={treeData}
      itemSize={itemSize}
      threshold={threshold}
      isItemLoaded={isItemLoaded}
      loadMoreItems={loadMoreItems}
      renderLoader={renderLoader}
      renderItem={({ data, handlers, level, state }) => {
        const isItemChecked = checked.includes(data.id);

        const enhancedHandlers = {
          ...handlers,
          onCheckedChange: () => {
            // The item flips to the opposite of its current state.
            const willBeChecked = !isItemChecked;

            onCheckedChange(checkHandler(data.id));

            // Expansion is only requested when the item becomes checked, so
            // the subtree ids are not collected at all on uncheck.
            if (!willBeChecked) {
              return;
            }

            const additionalNodesToExpand = difference(
              getNodesIds(data),
              expanded,
            );
            if (additionalNodesToExpand.length) {
              onExpandedChange([...expanded, ...additionalNodesToExpand]);
            }
          },
        };

        const enhancedState = {
          ...state,
          checked: isItemChecked,
        };

        return renderItem({
          data,
          handlers: enhancedHandlers,
          level,
          state: enhancedState,
        });
      }}
    />
  );
}
