import {
  ApolloClient,
  HttpLink,
  InMemoryCache,
  NormalizedCacheObject,
  disableFragmentWarnings,
  from,
  split,
} from '@apollo/client';
import { setContext } from '@apollo/client/link/context';
import { onError } from '@apollo/client/link/error';
import { RetryLink } from '@apollo/client/link/retry';
import { GraphQLWsLink } from '@apollo/client/link/subscriptions';
import { getMainDefinition } from '@apollo/client/utilities';
import { useAuth0 } from '@auth0/auth0-react';
import { fromUnixTime, isAfter, sub } from 'date-fns';
import { CloseCode, createClient } from 'graphql-ws';
import { JwtPayload, jwtDecode } from 'jwt-decode';
import { useCallback, useMemo, useRef } from 'react';

import useLogtail from 'hooks/useLogtail';

/**
 * Resolve the GraphQL HTTP endpoint for the current deploy environment.
 *
 * - `prod`: the fixed production host.
 * - `test`: an onrender.com subdomain named by `REACT_APP_GRAPHQL_HOST`, over https.
 * - anything else: `REACT_APP_GRAPHQL_HOST` used verbatim, over plain http.
 */
function getHostname() {
  const hostname = import.meta.env.REACT_APP_GRAPHQL_HOST;
  const deployEnv = import.meta.env.REACT_APP_DEPLOY_ENV;

  switch (deployEnv) {
    case 'prod':
      return 'https://gql.syllabird.com/v1/graphql';
    case 'test':
      return `https://${hostname}.onrender.com/v1/graphql`;
    default:
      return `http://${hostname}/v1/graphql`;
  }
}

/**
 * Resolve the GraphQL websocket endpoint (used for subscriptions).
 *
 * The URL is derived from `getHostname()` by swapping the scheme
 * (`https` -> `wss`, `http` -> `ws`). This keeps the HTTP and websocket
 * endpoints in lockstep for every deploy environment instead of maintaining
 * two parallel copies of the same branching.
 */
function getWsHostname() {
  return getHostname().replace(/^http/, 'ws');
}

/**
 * Read the `exp` claim of an access token as a Date.
 *
 * Falls back to "now" when the claim is missing (or 0). The token then looks
 * expired, which forces a fresh fetch on its next use.
 */
function getTokenExpiration(token: string): Date {
  const { exp } = jwtDecode<JwtPayload>(token);
  return exp ? fromUnixTime(exp) : new Date();
}

/**
 * Whether a token that expires at `expiration` should be refreshed before use.
 *
 * Tokens within one minute of expiry count as expired, to leave room for the
 * time a network request or websocket handshake takes.
 */
function isTokenExpiring(expiration: Date): boolean {
  return isAfter(new Date(), sub(expiration, { minutes: 1 }));
}

/**
 * Build an Apollo client that authenticates against the GraphQL server with
 * Auth0 access tokens.
 *
 * Subscriptions go over a graphql-ws websocket; all other operations go over
 * HTTP. Each transport caches its own access token and re-fetches it when it
 * is missing or about to expire.
 *
 * @returns `client`, the memoized Apollo client, and `resetClient`, which
 *   drops the cached tokens, forces the websocket to re-authenticate and
 *   resets the Apollo store.
 */
export default function useAuthApolloClient() {
  const { getAccessTokenSilently } = useAuth0();
  const logtail = useLogtail();

  const httpUserTokenRef = useRef<string | null>(null);
  const wsUserTokenRef = useRef<string | null>(null);
  // Socket currently held by the graphql-ws client, so resetClient can force
  // it to reconnect with a fresh token.
  const wsSocketRef = useRef<WebSocket | null>(null);
  const clientRef = useRef<ApolloClient<NormalizedCacheObject> | null>(null);

  const resetClient = useCallback(async () => {
    if (clientRef.current) {
      httpUserTokenRef.current = null;
      wsUserTokenRef.current = null;

      // The websocket sends its token only once, in connectionParams during
      // the handshake. Clearing the cached token alone would leave any open
      // subscriptions running under the previous credentials. Closing with
      // Forbidden reuses the same refresh path as the `connected` expiry
      // check below. graphql-ws does not treat Forbidden as a fatal close
      // code, so it reconnects and connectionParams fetches a new token.
      const socket = wsSocketRef.current;
      if (socket && socket.readyState === WebSocket.OPEN) {
        socket.close(CloseCode.Forbidden, 'Forbidden');
      }

      await clientRef.current.resetStore();
    }
  }, []);

  // Disable warnings about duplicated fragments since we trigger it with our
  // generated types. We still seem to get a few warnings on page load before we
  // run this function. Those warnings are not indicative of an issue.
  disableFragmentWarnings();

  const cache = useMemo(
    () =>
      new InMemoryCache({
        typePolicies: {
          Query: {
            fields: {
              // Resolve single-entity queries straight from the normalized
              // cache when the entity is already present.
              student(_, { args, toReference }) {
                return toReference({
                  __typename: 'students',
                  id: args?.id,
                });
              },
              assignment(_, { args, toReference }) {
                return toReference({
                  __typename: 'assignments',
                  id: args?.id,
                });
              },
              student_assignment(_, { args, toReference }) {
                return toReference({
                  __typename: 'student_assignments',
                  id: args?.id,
                });
              },
            },
          },
        },
      }),
    []
  );

  const link = useMemo(() => {
    // Both start as "now" so the first use of either transport fetches a token.
    let httpUserTokenExpiration = new Date();
    let wsUserTokenExpiration = new Date();

    httpUserTokenRef.current = null;
    wsUserTokenRef.current = null;

    const errorLink = onError(({ graphQLErrors, networkError }) => {
      graphQLErrors?.forEach(
        (graphQLError) =>
          // Expired JWTs are expected and handled by resetExpiredHttpUserToken.
          !graphQLError.message.includes('JWTExpired') &&
          logtail.error('graphql_error', {
            error: {
              message: graphQLError.message,
              code: graphQLError.extensions?.code,
            },
          })
      );

      if (networkError) {
        logtail.error('graphql_network_error', {
          error: { message: networkError.message },
        });
      }
    });

    const retryLink = new RetryLink({
      delay: {
        initial: 300,
      },
      attempts: {
        max: 3,
      },
    });

    // Ensure a non-expiring HTTP token is cached before the request is sent.
    const getHttpUserToken = setContext(async () => {
      if (
        !httpUserTokenRef.current ||
        isTokenExpiring(httpUserTokenExpiration)
      ) {
        const token = await getAccessTokenSilently();
        httpUserTokenRef.current = token;
        httpUserTokenExpiration = getTokenExpiration(token);
      }
    });

    // If the server still rejected the token as expired, drop it so the next
    // operation fetches a new one.
    const resetExpiredHttpUserToken = onError(({ graphQLErrors }) => {
      graphQLErrors?.forEach(({ message }) => {
        if (message.includes('JWTExpired')) {
          httpUserTokenRef.current = null;
        }
      });
    });

    const getHttpAuthHeadersLink = setContext((_, { headers }) => {
      return {
        headers: {
          ...headers,
          authorization: httpUserTokenRef.current
            ? `Bearer ${httpUserTokenRef.current}`
            : null,
        },
      };
    });

    const httpLink = new HttpLink({
      uri: getHostname(),
    });

    const wsLink = new GraphQLWsLink(
      createClient({
        url: getWsHostname(),
        connectionParams: async () => {
          // Re-check expiry here, not only in `connected`. A reconnect (for
          // example after a network drop) would otherwise present a stale
          // cached token that the server may reject before acknowledging
          // the connection.
          if (
            !wsUserTokenRef.current ||
            isTokenExpiring(wsUserTokenExpiration)
          ) {
            const token = await getAccessTokenSilently();
            wsUserTokenRef.current = token;
            wsUserTokenExpiration = getTokenExpiration(token);
          }

          return {
            headers: { authorization: `Bearer ${wsUserTokenRef.current}` },
          };
        },
        on: {
          connected: (socket) => {
            const ws = socket as WebSocket;
            wsSocketRef.current = ws;

            // If the websocket is open and the token expires in a minute we need to
            // refresh it
            if (
              isTokenExpiring(wsUserTokenExpiration) &&
              ws.readyState === WebSocket.OPEN
            ) {
              ws.close(CloseCode.Forbidden, 'Forbidden');
            }
          },
          closed: (event) => {
            wsSocketRef.current = null;

            // If we closed the socket with the Forbidden code we will need to try refreshing the token
            if ((event as CloseEvent).code === CloseCode.Forbidden) {
              wsUserTokenRef.current = null;
            }
          },
        },
      })
    );

    // Send queries and mutations to httpLink and subscriptions to wsLink
    const splitLink = split(
      ({ query }) => {
        const definition = getMainDefinition(query);
        return (
          definition.kind === 'OperationDefinition' &&
          definition.operation === 'subscription'
        );
      },
      wsLink,
      httpLink
    );

    return from([
      errorLink,
      retryLink,
      getHttpUserToken,
      resetExpiredHttpUserToken,
      getHttpAuthHeadersLink,
      splitLink,
    ]);
  }, [getAccessTokenSilently, logtail]);

  const client = useMemo(() => {
    const newClient = new ApolloClient({ cache, link });
    clientRef.current = newClient;
    return newClient;
  }, [cache, link]);

  return { client, resetClient };
}
