import { useState, useCallback, useRef, useLayoutEffect } from 'react';
import { generateClient } from 'aws-amplify/api';
import { useQuery } from '@tanstack/react-query';
import { milliseconds } from 'date-fns';

// Module-level Amplify API client; recursiveFetch sends every GraphQL request through it.
const client = generateClient();

/**
 * Keeps a ref pointing at the most recent version of a callback.
 *
 * The ref is seeded with `callback` and re-assigned inside a layout effect
 * whenever the `callback` identity changes, so consumers can read
 * `ref.current` without listing the callback as a dependency themselves.
 *
 * @param {Function|undefined} callback - Callback to track.
 * @returns {{ current: Function|undefined }} Ref holding the tracked callback.
 */
function useCallbackRef(callback) {
  const latest = useRef(callback);

  const syncLatest = () => {
    latest.current = callback;
  };
  useLayoutEffect(syncLatest, [callback]);

  return latest;
}

/**
 * Runs a GraphQL query through the shared `client` and keeps requesting
 * pages until a response no longer carries a `nextToken`.
 *
 * The first key of `result.data` is treated as the query name. For list
 * responses (`items` present) the items of every page are merged and the
 * promise resolves with `{ [queryName]: { items } }`. For responses that never
 * contained `items` it resolves with a shallow copy of the payload,
 * `{ [queryName]: { ...payload } }`.
 *
 * @param {*} query - GraphQL document forwarded to `client.graphql`.
 * @param {object} [options] - `variables` and any other keys are flattened
 *   together into the request variables (other keys win on conflict).
 * @param {string|null} [nextToken=null] - Pagination token for the first request.
 * @param {Array} [items=[]] - Items collected before this call; not mutated.
 * @returns {Promise<object>} Result keyed by the query name.
 * @throws {TypeError} (as a rejection) when the response has no `data`.
 *   Errors raised by `client.graphql` are propagated unchanged.
 */
const recursiveFetch = async (
  query,
  { variables, ...rest } = {},
  nextToken = null,
  items = [],
) => {
  // Work on a copy so the caller's array is never mutated.
  const collected = [...items];
  // Remember that a list was seen, so items gathered from earlier pages are
  // still returned when the last page happens to come back without `items`.
  let hasListItems = collected.length > 0;
  let token = nextToken;
  let queryName;
  let page;

  do {
    // Pages must be requested one after another: each request needs the
    // token returned by the previous one.
    // eslint-disable-next-line no-await-in-loop
    const result = await client.graphql({
      query,
      variables: {
        ...variables,
        ...rest,
        nextToken: token,
      },
    });

    if (result?.data == null) {
      throw new TypeError('GraphQL response did not contain a data object');
    }

    // queryName follows the format data[queryName]
    [queryName] = Object.keys(result.data);
    page = result.data[queryName];

    if (page?.items) {
      collected.push(...page.items); // merge items from previous pages
      hasListItems = true;
    }

    token = page?.nextToken;
  } while (token);

  return {
    [queryName]: hasListItems ? { items: collected } : { ...page },
  };
};

// Captures the operation name from a GraphQL document string, e.g. `query ListTodos(...)`.
const OPERATION_NAME_PATTERN = /\b(?:query|mutation|subscription)\s+([_A-Za-z]\w*)/;

// Resolves the operation name used as the first part of the query cache key.
const resolveQueryName = (graphqlQuery) => {
  if (typeof graphqlQuery === 'string') {
    return OPERATION_NAME_PATTERN.exec(graphqlQuery)?.[1] ?? 'invalidQuery';
  }
  // Every step is optional so a document without `definitions` (or with an
  // anonymous operation) falls back to the placeholder instead of throwing.
  return graphqlQuery?.definitions?.[0]?.name?.value || 'invalidQuery';
};

/**
 * React hook that runs a GraphQL query through `recursiveFetch` and exposes
 * the outcome via `useQuery`.
 *
 * @param {object|string} graphqlQuery - GraphQL document. Its operation name
 *   (from `definitions[0].name.value`, or parsed from a string document)
 *   becomes part of the cache key; `'invalidQuery'` is used when none is found.
 * @param {object} [options]
 * @param {object} [options.variables] - Query variables; `limit` defaults to 100.
 * @param {Function} [options.onCompleted] - Called with the data (or `[]`) on success.
 * @param {Function} [options.onError] - Called with the error on failure.
 * @param {boolean} [options.skip] - When truthy the query is disabled.
 * @param {object} [options.staleTime] - Duration object passed to date-fns
 *   `milliseconds`; defaults to `{ minutes: 1 }`.
 * @returns {{ data: *, loading: boolean, error: *, startPolling: Function,
 *   stopPolling: Function, refetch: Function }} `data` falls back to `[]`
 *   while nothing has been fetched; `loading` mirrors `isFetching`.
 *   Any remaining options are spread last into the `useQuery` options, so they
 *   override the defaults set here.
 */
const useQueryRecursive = (
  graphqlQuery,
  { variables, onCompleted, onError, skip, staleTime, ...rest } = {},
) => {
  const [refetchInterval, setRefetchInterval] = useState(false);
  const onCompletedRef = useCallbackRef(onCompleted);
  const onErrorRef = useCallbackRef(onError);

  // Polling is driven by the refetch interval handed to useQuery.
  const startPolling = useCallback((value) => {
    setRefetchInterval(value);
  }, []);

  const stopPolling = useCallback(() => {
    setRefetchInterval(false);
  }, []);

  // set a default limit value if not set
  const limit = variables?.limit || 100;

  const queryName = resolveQueryName(graphqlQuery);

  // fetch query
  const { data, error, isFetching, refetch } = useQuery(
    [queryName, variables],
    () =>
      recursiveFetch(graphqlQuery, {
        ...variables,
        limit,
      }),
    {
      refetchInterval,
      refetchIntervalInBackground: true,
      // Callbacks are read through refs so the latest versions are invoked.
      onSuccess: (_data) => onCompletedRef.current?.(_data || []),
      onError: (_error) => onErrorRef.current?.(_error),
      enabled: !skip,
      staleTime: milliseconds(staleTime || { minutes: 1 }),
      ...rest,
    },
  );

  return {
    data: data || [],
    loading: isFetching,
    error,
    startPolling,
    stopPolling,
    refetch,
  };
};

export default useQueryRecursive;
