/*
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License").
You may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
import { createAsyncThunk, createSlice, PayloadAction } from '@reduxjs/toolkit';
import {
  InferenceApi,
  InferenceResponseDto,
  InferenceResponseDtoStatusEnum,
  LanguageCorrectionItemTypeEnum,
  UserConfigurationDtoLanguageCheckModeEnum,
} from '@ink-ai/insight-service-sdk';
import { splitEvery } from 'ramda';
import { getApi } from '../common/requestHelper';
import { RootState } from '.';
import { v4 as uuidv4 } from 'uuid';
import {
  concurrentConsume,
  getInfAndCorrections,
  SHA256Digest,
} from '../common/utils';
import { DateTime } from 'luxon';
import { isGrammarConfigChanged } from './configuration';
import { newAnalyzeWritingStyle } from './writing-style';

/**
 * Statuses of an inference that has not produced a final result yet.
 * `syncInferences` uses this list to ignore responses that are still pending.
 */
export const IN_PROGRESS_STATUS: InferenceResponseDtoStatusEnum[] = [
  InferenceResponseDtoStatusEnum.Inprogress,
  InferenceResponseDtoStatusEnum.Submitting,
];

/**
 * Statuses of an inference that has been accepted or rejected.
 * Not referenced in this file; presumably consumed by other modules.
 */
export const REVIEWED_STATUS: InferenceResponseDtoStatusEnum[] = [
  InferenceResponseDtoStatusEnum.Accepted,
  InferenceResponseDtoStatusEnum.Rejected,
];

/**
 * Statuses for which, judging by the name, a fresh inference should be
 * requested instead of reusing the existing one. Not referenced in this file;
 * confirm the exact usage with the importing modules.
 */
export const SHOULD_CREATE_NEW_INF_STATUS: InferenceResponseDtoStatusEnum[] = [
  InferenceResponseDtoStatusEnum.Error,
];

/**
 * A single correction entry attached to an inference. The
 * `acceptPartialInference.fulfilled` reducer accumulates the corrections that
 * remain after a partial accept as `CorrectionDiff[]`.
 */
export type CorrectionDiff = {
  // Kind of correction, as defined by the SDK enum.
  type: LanguageCorrectionItemTypeEnum;
  // Offset into the inference text where the correction applies
  // (presumably a character index: the partial-accept logic in this file
  // uses correction positions as `substring` bounds).
  position: number;
  // Length of the original text span covered by the correction.
  length: number;
  // Replacement text for the span; optional in this type.
  correction?: string;
};

/**
 * An inference as stored in the slice: the SDK response DTO plus
 * client-side bookkeeping fields.
 */
export type CorrectionItem = InferenceResponseDto & {
  // Digest of the item's text, produced by `getTextHash`.
  hash: string;
  // When the inference was requested. `syncInferences` compares it against
  // `DateTime.now().toUnixInteger()`, so it is presumably Unix seconds.
  requestedAt?: number;
  // True when the item carries at least one correction; the reducers in this
  // file set it to `!!corrections?.length`.
  isDiff?: boolean;
};

// Initial value of the `correction` slice; `CorrectionState` is derived from it.
const initialState = {
  // Inference items tracked by the slice.
  grammarList: [] as CorrectionItem[],
  // Id stored by `setCurrentInfId` and after a partial accept; presumably the
  // inference currently in focus (confirm with the UI code).
  currentInfId: '',
  // `Date.now()` value written by `updateAcceptRejectTimestamp`;
  // `inferenceGrammar` refuses to run when it is handed a different timestamp.
  acceptRejectTimestamp: 0,
  // True from `startLoading` until `inferenceGrammar` settles.
  loading: false,
  // True from `startSyncInferences` until `syncInferences` settles.
  syncingInferences: false,
  // Toggled through `setScanning`; `inferenceGrammar` aborts when it is false.
  scanning: false,
};

// State shape of the `correction` slice, derived from `initialState`.
export type CorrectionState = typeof initialState;

/**
 * Computes the digest used to identify a piece of text by delegating to
 * `SHA256Digest`. Callers in this file await the result.
 */
export const getTextHash = (text: string) => {
  return SHA256Digest(text);
};

/**
 * Submits the given texts for grammar inference.
 *
 * Flow:
 *  1. Reject immediately when `timestamp` differs from the slice's
 *     `acceptRejectTimestamp` (the request is considered stale).
 *  2. Hash every text, reconcile with the current grammar list through
 *     `getInfAndCorrections`, and replace the grammar list with the result.
 *  3. Run the writing-style analysis and wait for it.
 *  4. Send the inferences to the API in batches of 5, one batch at a time,
 *     merging each response into the grammar list.
 *
 * A batch aborts the run (by throwing) when the grammar configuration changed
 * since the thunk started or when scanning has been switched off. In
 * selection mode scanning is always switched off once the run ends.
 */
export const inferenceGrammar = createAsyncThunk(
  'correction/InferenceGrammar',
  async (
    { texts, timestamp }: { texts: string[]; timestamp: number },
    { dispatch, getState },
  ) => {
    // Snapshot taken at start; later checks compare against a fresh state.
    const initialSnapshot = getState() as RootState;
    if (initialSnapshot.correction.acceptRejectTimestamp !== timestamp) {
      throw new Error('Timestamp mismatch, discard stale inference');
    }
    dispatch(correction.actions.startLoading());

    // Give every text a fresh id and its content hash.
    const originalInfList = await Promise.all(
      texts.map(async (text) => {
        const id = uuidv4();
        const hash = await getTextHash(text);
        return { id, hash, text };
      }),
    );
    const { infList, corrections } = getInfAndCorrections(
      originalInfList,
      initialSnapshot.correction.grammarList,
    );
    dispatch(correction.actions.replaceGrammarList(corrections));

    // Kick off the writing-style analysis. This thunk also runs after the user
    // accepts a grammar correction, so the grammar corrections are passed
    // along to let the writing-style analysis infer the accepted corrections
    // separately.
    await dispatch(
      newAnalyzeWritingStyle({
        originInf: originalInfList,
        grammarCorrections: corrections,
      }),
    );

    const batches = splitEvery(5, infList);
    const inferenceApi = await getApi(InferenceApi);
    const scanningMode = initialSnapshot.configuration.languageCheckMode;

    // Sends one batch and merges the response, unless the run must be aborted.
    const submitBatch = async (
      payload: {
        text: string;
        id: string;
      }[],
    ) => {
      const response = await inferenceApi.createLanguageInference({
        payload,
        instanceId: initialSnapshot.auth.instanceId,
        async: true,
      });
      const latest = getState() as RootState;
      const configChanged = isGrammarConfigChanged(
        initialSnapshot.configuration,
        latest.configuration,
      );
      if (configChanged) {
        throw new Error('Inference aborted due to configuration changed!');
      }
      if (!latest.correction.scanning) {
        // Scanning was paused: drop every inference still being submitted.
        const withoutSubmitting = latest.correction.grammarList.filter(
          (item) => item.status !== InferenceResponseDtoStatusEnum.Submitting,
        );
        dispatch(correction.actions.replaceGrammarList(withoutSubmitting));
        throw new Error('Inference aborted due to scanning is paused');
      }
      dispatch(correction.actions.mergeGrammarList(response.data));
    };

    try {
      await concurrentConsume(submitBatch, 1)(batches);
    } finally {
      // Selection mode pauses scanning automatically after each inference run.
      if (
        scanningMode === UserConfigurationDtoLanguageCheckModeEnum.Selection
      ) {
        dispatch(correction.actions.setScanning(false));
      }
    }
  },
);

/**
 * Polling fallback that refreshes inferences still marked as in progress.
 *
 * Does nothing when a web-socket message arrived within the last 15 seconds
 * or when another sync is already running. Otherwise it fetches, in groups of
 * 20 ids and one group at a time, every `Inprogress` item whose `requestedAt`
 * is more than 10 units in the past (compared against a Unix-seconds clock)
 * and merges the responses that are no longer in progress.
 */
export const syncInferences = createAsyncThunk(
  'correction/SyncInferences',
  async (_, { dispatch, getState }) => {
    const { webSocket, correction: correctionState } = getState() as RootState;
    const sinceLastWsMessage = Date.now() - webSocket.lastMessageAt;
    if (sinceLastWsMessage < 15000) {
      console.log('skip sync infs due to living ws connection');
      return;
    }
    if (correctionState.syncingInferences) {
      console.log('skip sync infs');
      return;
    }
    dispatch(correction.actions.startSyncInferences());

    const now = DateTime.now().toUnixInteger();
    // Only items that are still in progress and were requested more than 10
    // units ago qualify; items without `requestedAt` are treated as requested
    // at 0.
    const staleInProgress = correctionState.grammarList.filter(
      ({ status, requestedAt = 0 }) =>
        status === InferenceResponseDtoStatusEnum.Inprogress &&
        now - requestedAt > 10,
    );
    const idsToFetch = staleInProgress.map((item) => item.id);
    if (!idsToFetch.length) {
      console.log('empty sync infs');
      return;
    }

    console.log('start sync infs');
    const idGroups = splitEvery(20, idsToFetch);
    const inferenceApi = await getApi(InferenceApi);
    // Groups are fetched strictly one after another.
    for (const ids of idGroups) {
      const { data } = await inferenceApi.batchGet({ ids });
      const settled = data.filter(
        ({ status }) => !IN_PROGRESS_STATUS.includes(status),
      );
      dispatch(correction.actions.mergeGrammarList(settled));
    }
    console.log('finish sync infs');
  },
);

/**
 * Prepares the payload for accepting a whole inference.
 *
 * Resolves to `{ id, hash }`, where `hash` is the digest of the inference's
 * `output`, or to `{ id, newId: id }` when no inference with that id exists.
 */
export const acceptInference = createAsyncThunk(
  'correction/AcceptInference',
  async (id: string, { getState }) => {
    const { grammarList } = (getState() as RootState).correction;
    const target = grammarList.find((item) => item.id === id);
    if (!target) {
      return { id, newId: id };
    }
    const hash = await getTextHash(target.output);
    return { id, hash };
  },
);

/**
 * Prepares the payload for accepting only some corrections of an inference.
 *
 * `correctionIds` are indexes into the inference's `corrections` array. The
 * selected corrections are applied to the inference text in list order, and
 * the thunk resolves to the patched text (`output`) together with its hash.
 * When the inference cannot be found it resolves to
 * `{ id, newId: id, correctionIds, output: '' }`.
 */
export const acceptPartialInference = createAsyncThunk(
  'correction/AcceptPartialInference',
  async (
    { id, correctionIds }: { id: string; correctionIds: number[] },
    { getState },
  ) => {
    const { grammarList } = (getState() as RootState).correction;
    const target = grammarList.find((item) => item.id === id);
    if (!target) {
      return { id, newId: id, correctionIds, output: '' };
    }

    let output = target.text ?? '';
    // Net length change caused by the corrections applied so far; positions
    // refer to the original text, so each one is shifted by this amount.
    let shift = 0;
    target.corrections.forEach(({ position, length, correction }, index) => {
      if (!correctionIds.includes(index)) {
        return;
      }
      const start = position + shift;
      output =
        output.substring(0, start) +
        correction +
        output.substring(start + length);
      shift = shift + correction.length - length;
    });

    const hash = await getTextHash(output);
    return {
      id,
      hash,
      correctionIds,
      output,
    };
  },
);

/**
 * Redux slice that owns the grammar-correction state: the list of inferences,
 * the current inference id, and the loading / syncing / scanning flags.
 *
 * Synchronous reducers cover list maintenance and accept/reject bookkeeping;
 * `extraReducers` react to the async thunks declared in this file.
 */
export const correction = createSlice({
  name: 'correction',
  initialState: initialState,
  reducers: {
    // Turns scanning on or off.
    setScanning: (state, action: PayloadAction<boolean>) => {
      state.scanning = action.payload;
    },
    // Marks a sync as running; cleared when `syncInferences` settles.
    startSyncInferences: (state) => {
      state.syncingInferences = true;
    },
    // Stamps the current time; `inferenceGrammar` compares against this value
    // to discard stale requests.
    updateAcceptRejectTimestamp: (state) => {
      state.acceptRejectTimestamp = Date.now();
    },
    setCurrentInfId: (state, action: PayloadAction<string>) => {
      state.currentInfId = action.payload;
    },
    // Replaces the list with the payload as-is.
    initGrammarList: (state, action: PayloadAction<CorrectionItem[]>) => {
      state.grammarList = action.payload;
    },
    // Patches the item whose id matches the payload; no-op when not found.
    // `isDiff` is recomputed from the payload's `corrections`, so a payload
    // without corrections sets it to false.
    updateGrammarList: (
      state,
      { payload }: PayloadAction<Partial<CorrectionItem>>,
    ) => {
      const index = state.grammarList.findIndex(
        (item) => item.id === payload.id,
      );
      if (index < 0) {
        return;
      }
      state.grammarList = [
        ...state.grammarList.slice(0, index),
        {
          ...state.grammarList[index],
          ...payload,
          isDiff: !!payload.corrections?.length,
        },
        ...state.grammarList.slice(index + 1),
      ];
    },
    appendGrammarList: (state, action: PayloadAction<CorrectionItem[]>) => {
      state.grammarList = [...state.grammarList, ...action.payload];
    },
    // Replaces the list with shallow copies of the payload items.
    replaceGrammarList: (state, action: PayloadAction<CorrectionItem[]>) => {
      state.grammarList = action.payload.map((item) => ({
        ...item,
      }));
    },
    // Marks the inference with the given id as rejected; no-op when not found.
    rejectInference: (state, action: PayloadAction<string>) => {
      const inf = state.grammarList.find(({ id }) => id === action.payload);
      if (!inf) {
        return;
      }
      inf.status = InferenceResponseDtoStatusEnum.Rejected;
    },
    // Drops the corrections at the given indexes from one inference.
    rejectPartialInference: (
      state,
      action: PayloadAction<{ id: string; correctionIds: number[] }>,
    ) => {
      const inf = state.grammarList.find(({ id }) => id === action.payload.id);
      if (!inf) {
        return;
      }
      inf.corrections = inf.corrections?.filter(
        (_, index) => !action.payload.correctionIds.includes(index),
      );
    },
    // Marks the inference with the given id as accepted; no-op when not found.
    acceptInference: (state, action: PayloadAction<string>) => {
      const inf = state.grammarList.find(({ id }) => id === action.payload);
      if (!inf) {
        return;
      }
      inf.status = InferenceResponseDtoStatusEnum.Accepted;
    },
    // Merges server responses into the list by id. Items without a matching
    // response are left untouched; responses without a matching item are
    // ignored. `isDiff` is recomputed from the response's corrections.
    mergeGrammarList: (
      state,
      action: PayloadAction<InferenceResponseDto[]>,
    ) => {
      // Private copy so consumed responses can be removed without touching
      // the dispatched payload.
      const pending = Array.from(action.payload);
      state.grammarList = Array.from(state.grammarList).map((item) => {
        const index = pending.findIndex(({ id }) => id === item.id);
        if (index < 0) {
          return item;
        }
        const newItem = pending[index];
        // Remove only the matched response. Slicing from `index + 1` would
        // also discard every earlier response, silently dropping updates
        // whenever the payload order differs from the list order.
        pending.splice(index, 1);
        return {
          ...item,
          ...newItem,
          isDiff: !!newItem.corrections?.length,
        };
      });
    },
    startLoading: (state) => {
      state.loading = true;
    },
    // Resets the whole slice to its initial value.
    clearAll: () => {
      return initialState;
    },
  },
  extraReducers: (builder) => {
    // Reactions to the async thunks declared in this file.
    builder.addCase(inferenceGrammar.fulfilled, (state) => {
      // The inference run finished: clear the loading flag.
      state.loading = false;
    });
    builder.addCase(syncInferences.fulfilled, (state) => {
      state.syncingInferences = false;
    });
    builder.addCase(syncInferences.rejected, (state) => {
      state.syncingInferences = false;
    });
    builder.addCase(inferenceGrammar.rejected, (state, action) => {
      // The inference run failed or was aborted: log why and clear loading.
      console.error(action.error);
      state.loading = false;
    });
    builder.addCase(acceptInference.fulfilled, (state, { payload }) => {
      const inf = state.grammarList.find(
        (correction) => correction.id === payload.id,
      );
      if (!inf) {
        return;
      }
      // Store the hash computed by the thunk and mark the item as accepted.
      inf.hash = payload.hash;
      inf.status = InferenceResponseDtoStatusEnum.Accepted;
    });
    builder.addCase(acceptPartialInference.fulfilled, (state, { payload }) => {
      const inf = state.grammarList.find(
        (correction) => correction.id === payload.id,
      );
      if (!inf) {
        return;
      }
      inf.hash = payload.hash;
      state.currentInfId = inf.id;
      // The item's text becomes the patched text computed by the thunk.
      inf.text = payload.output;
      // Keep only the corrections that were not accepted, shifting each one
      // by the net length change of the accepted corrections listed before it.
      let diffLength = 0;
      inf.corrections = inf.corrections.reduce((ret, correctionItem, index) => {
        if (!payload.correctionIds.includes(index)) {
          correctionItem.position += diffLength;
          ret.push(correctionItem);
        } else {
          diffLength +=
            correctionItem.correction.length - correctionItem.length;
        }
        return ret;
      }, [] as CorrectionDiff[]);
      // Nothing left to review: the whole inference counts as accepted.
      if (!inf.corrections.length) {
        inf.status = InferenceResponseDtoStatusEnum.Accepted;
      }
    });
  },
});
