import { useEffect, useState } from 'react';
import toast from 'react-hot-toast';
import { deepEqual, getErrorMessage } from '../common/utils';

import { PromptComposer, PromptComposerMessages, PromptVersionSelector, Selector, TabGroup } from '../components';
import { PV } from '../components/common/PromptVersionSelector';
import {
  getModels,
  gradeModels,
  runModel,
  runModelMessages,
  scoreEvaluation,
  summarizeResults
} from '../services/Models';
import { Model } from '../types';
import { FullMessageContent, ModelRunResult, ModelRunResultMessages } from '../types/Models';
import { CompletionMetric } from '../types/Performance';
import { ModelPicker, PerformanceGraph } from '../components/performance';
import { mean } from '../common/math';
import { SelectorValue } from '../components/common/Selector';
import { EvaluationType, Grades, RatedEvaluation, Summary } from '../types/Evaluations';
import ProgressTracker, { Progress } from '../components/performance/Progress';
import Skeleton from 'react-loading-skeleton';
import ResultsTable from '../components/performance/ResultsTable';
import { translateModelParameters } from '../common/models';
import { FontAwesomeIcon } from '@fortawesome/react-fontawesome';
import { faMagnifyingGlassChart, faPersonRunning } from '@fortawesome/free-solid-svg-icons';
import { PromptVersionTypes } from '../types/Prompt';
import { generateCombinedPrompt, generatePromptPayloadFromVersion } from '../common/prompts';
import { PromptMetricSource } from '../types/Metrics';

/** Run-count choices offered by the run selector (1, 5, 10 or 15 runs per model). */
const runsSelectorValues: SelectorValue[] = [1, 5, 10, 15].map((runs) => ({
  value: runs,
  label: `${runs} ${runs === 1 ? 'Run' : 'Runs'}`
}));

/**
 * Props interface for the Performance component.
 *
 * Currently empty: the component reads no props and keeps all of its data in
 * local state.
 */
interface Props {}

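/**
 * Performance page component.
 *
 * Lets the user pick a prompt version, one or more models and a run count,
 * runs the prompt against every selected model, and then shows the aggregated
 * latency/cost metrics, the scored evaluations, a summary and per-model grades.
 *
 * @component
 * @param {Props} props - The component props (none are currently used).
 * @returns {JSX.Element} The rendered component.
 */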
const Performance: React.FC<Props> = ({}: Props) => {
  // True while the model list is loading or while a run's requests are awaited;
  // passed as `disabled` to the composer, model picker, run selector and Start button.
  const [isBusy, setIsBusy] = useState<boolean>(false);
  // Models returned by getModels() and a model id -> model name lookup built from them.
  const [models, setModels] = useState<Model[]>([]);
  const [modelNameMap, setModelNameMap] = useState<Record<string, string>>({});
  // Prompt/version selection reported by PromptVersionSelector.
  const [apv, setApv] = useState<PV>();
  // Ids of the models ticked in the ModelPicker.
  const [selectedModelIds, setSelectedModelIds] = useState<string[]>([]);
  // Number of runs per model; defaults to the second entry of runsSelectorValues.
  const [selectedRunCount, setSelectedRunCount] = useState<SelectorValue>(runsSelectorValues[1]);
  // Per-model aggregates (latency, TTFB, cost, tokens) derived from the raw results below.
  const [completionMetrics, setCompletionMetrics] = useState<CompletionMetric[]>([]);
  // Raw run results: `rawMetrics` is filled by runCompletion, `rawMetricsMessages` by
  // runMessagesCompletion. `isMessagesType` selects which of the two is in use.
  const [rawMetrics, setRawMetrics] = useState<ModelRunResult[]>([]);
  const [rawMetricsMessages, setRawMetricsMessages] = useState<ModelRunResultMessages[]>([]);
  // True when the selected version's type is PromptVersionTypes.MESSAGING.
  const [isMessagesType, setIsMessagesType] = useState<boolean>(false);
  // Progress of the current run; undefined when no run has been started or after a reset.
  const [progress, setProgress] = useState<Progress>();
  // Payload handed to the prompt composers.
  const [currentPayload, setCurrentPayload] = useState<Record<string, any>>({});
  // Prompt text reported by PromptComposer (completion/legacy prompt types).
  const [generatedPrompt, setGeneratedPrompt] = useState<string>();
  // Message content reported by PromptComposerMessages: the first and second arguments
  // of onMessagesUpdate respectively. The second ("evaluated") one is what gets sent to
  // the models — presumably the templates rendered with the payload; confirm in the composer.
  const [fullMessageContent, setFullMessageContent] = useState<FullMessageContent>();
  const [fullMessageContentEvaluated, setFullMessageContentEvaluated] = useState<FullMessageContent>();
  // Scored evaluations, one list per EvaluationType; each entry is tagged with its model.
  const [completionCoherence, setCompletionCoherence] = useState<RatedEvaluation[]>([]);
  const [completionFluency, setCompletionFluency] = useState<RatedEvaluation[]>([]);
  const [completionBias, setCompletionBias] = useState<RatedEvaluation[]>([]);
  const [completionFactual, setCompletionFactual] = useState<RatedEvaluation[]>([]);
  const [completionToxicity, setCompletionToxicity] = useState<RatedEvaluation[]>([]);
  const [completionRelevancy, setCompletionRelevancy] = useState<RatedEvaluation[]>([]);
  // Results of summarizeResults() and gradeModels() for the finished run.
  const [summary, setSummary] = useState<Summary>();
  const [grades, setGrades] = useState<Grades>();
  // Index of the active tab: 0 = Configuration, 1 = Results, 2 = Data.
  const [selectedTab, setSelectedTab] = useState<number>(0);

  /**
   * Returns the page to its "nothing selected" state: drops all run results and
   * clears the payload, the composed prompt/messages, the selected prompt
   * version and the busy flag.
   */
  const reset = () => {
    resetMetrics();
    setCurrentPayload({});
    setFullMessageContentEvaluated(undefined);
    setFullMessageContent(undefined);
    setGeneratedPrompt(undefined);
    setApv(undefined);
    setIsBusy(false);
  };

  /** Empties the scored-evaluation list of every evaluation type. */
  const resetScoredEvaluations = () => {
    const evaluationSetters = [
      setCompletionCoherence,
      setCompletionFluency,
      setCompletionBias,
      setCompletionFactual,
      setCompletionToxicity,
      setCompletionRelevancy
    ];

    evaluationSetters.forEach((clear) => clear([]));
  };

  /**
   * Discards everything produced by a run: raw results, aggregated metrics,
   * scored evaluations, summary, grades and the progress tracker state.
   */
  const resetMetrics = () => {
    // Raw and derived metrics.
    setRawMetrics([]);
    setRawMetricsMessages([]);
    setCompletionMetrics([]);

    // Evaluation output.
    resetScoredEvaluations();
    setSummary(undefined);
    setGrades(undefined);

    // No run is being tracked any more.
    setProgress(undefined);
  };

  // Fetch the available models once on mount and index their names by model id.
  useEffect(() => {
    const loadModels = async () => {
      setIsBusy(true);
      try {
        const fetched = await getModels();
        setModels(fetched);

        const namesById: Record<string, string> = {};
        for (const model of fetched) {
          namesById[model.mid] = model.name;
        }
        setModelNameMap(namesById);
      } catch (err) {
        toast.error(getErrorMessage(err));
      } finally {
        // Release the form whether or not the fetch succeeded.
        setIsBusy(false);
      }
    };

    loadModels();
  }, []);

  /**
   * Extracts the completion text from a run result.
   *
   * For messaging prompts this is the text of the last entry in
   * `response.content`; otherwise it is the result's `completion` field.
   *
   * Always returns a string: a result without response content or without a
   * completion (possibly a failed run — confirm against the service) yields ''
   * instead of throwing or leaking `undefined` to callers that read `.length`.
   *
   * @param result - A raw run result of either prompt type.
   * @returns The completion text, or '' when there is none.
   */
  const getCompletion = (result: ModelRunResult | ModelRunResultMessages): string => {
    if (isMessagesType) {
      const messages = (result as ModelRunResultMessages).response?.content ?? [];
      // Optional access: the content list may be empty, in which case there is no last entry.
      return messages[messages.length - 1]?.text || '';
    }

    return (result as ModelRunResult).completion ?? '';
  };

  /** Returns the raw result list that matches the current prompt type. */
  const getRawMetrics = (): (ModelRunResult | ModelRunResultMessages)[] => {
    if (isMessagesType) {
      return rawMetricsMessages;
    }

    return rawMetrics;
  };

  /**
   * Runs whenever a raw run result arrives.
   *
   * 1. Rebuilds the per-model aggregates (latency, TTFB, cost, tokens) from the
   *    successful runs and publishes them together with the progress.
   * 2. Once the number of raw results equals `progress.totalMetrics`, grades the
   *    completions, then requests the summary and the model grades in parallel.
   *
   * NOTE(review): a request that rejects never adds a raw result (see the run
   * helpers), so in that case the count may never reach `totalMetrics` and step 2
   * is not triggered — confirm whether that is intended.
   */
  useEffect(() => {
    const results = getRawMetrics();
    if (!results.length) return;

    // Collect the samples of every successful run, keyed by model id.
    const metrics = results.reduce((acc: Record<string, any>, result) => {
      if (!acc[result.modelId]) {
        acc[result.modelId] = {
          model: result.modelName,
          completion: getCompletion(result),
          latencies: [],
          ttfbs: [],
          requestCosts: [],
          responseCosts: [],
          requestTokens: [],
          responseTokens: []
        };
      }

      if (!result.failed) {
        const { modelId, latency, ttfb, tokens } = result;
        const { requestCost, responseCost, requestTokens, responseTokens } = tokens;
        acc[modelId].completion = getCompletion(result);
        acc[modelId].latencies.push(latency);
        acc[modelId].ttfbs.push(ttfb);
        acc[modelId].requestCosts.push(requestCost);
        acc[modelId].responseCosts.push(responseCost);
        acc[modelId].requestTokens.push(requestTokens);
        acc[modelId].responseTokens.push(responseTokens);
      }
      return acc;
    }, {});

    // Math.max()/Math.min() of an empty list are -Infinity/Infinity; report 0 instead
    // for a model that has no successful run (yet).
    const maxOf = (values: number[]): number => (values.length ? Math.max(...values) : 0);
    const minOf = (values: number[]): number => (values.length ? Math.min(...values) : 0);

    const metricsAvg: CompletionMetric[] = Object.values(metrics).map((data: any) => {
      return {
        model: data.model,
        completion: data.completion,
        avgLatency: mean(data.latencies),
        maxLatency: maxOf(data.latencies),
        minLatency: minOf(data.latencies),
        avgTtfb: mean(data.ttfbs),
        maxTtfb: maxOf(data.ttfbs),
        minTtfb: minOf(data.ttfbs),
        avgRequestCost: mean(data.requestCosts),
        avgResponseCost: mean(data.responseCosts),
        avgTotalCost: mean(data.requestCosts) + mean(data.responseCosts),
        avgRequestTokens: mean(data.requestTokens),
        avgResponseTokens: mean(data.responseTokens)
      } as CompletionMetric;
    });

    const isMetricsDone = results.length === progress?.totalMetrics;

    setCompletionMetrics(metricsAvg);
    updateProgress(results.length);

    if (!isMetricsDone) return;

    // `progress.totalCount` reserves two steps for the summary and the grades (see
    // start()). Count each of them when it settles; otherwise the tracker can never
    // reach 100% / isDone.
    const markStepDone = () =>
      setProgress((prev) => {
        if (!prev) return prev;
        const completedCount = Math.min(prev.completedCount + 1, prev.totalCount);
        return {
          ...prev,
          completedCount,
          endTime: Date.now(),
          status: `Completed ${completedCount} of ${prev.totalCount}`,
          percentageComplete: (completedCount / prev.totalCount) * 100,
          isDone: completedCount === prev.totalCount
        };
      });

    void (async () => {
      const [coherence, fluency, bias, factual, toxicity, relevancy] = await gradeCompletions();
      const payload = getRollupPayload(metricsAvg, coherence, fluency, bias, factual, toxicity, relevancy);
      await Promise.all([
        generateSummary(payload).finally(markStepDone),
        generateModelGrades(payload).finally(markStepDone)
      ]);
    })().catch((error) => {
      // Nothing awaits this promise, so surface unexpected failures here.
      toast.error(getErrorMessage(error));
    });
  }, [rawMetrics, rawMetricsMessages]);

  /**
   * Records how many steps of the current run have completed and refreshes the
   * derived progress fields (end time, failed count, status text, percentage,
   * done flag).
   *
   * Does nothing when no run is being tracked: `progress` is cleared by
   * resetMetrics(), and a late result must not dereference the missing state.
   *
   * @param completedCount - Number of steps completed so far.
   */
  const updateProgress = (completedCount: number) => {
    const failedCount = getRawMetrics().filter((m) => m.failed).length;

    setProgress((prev) => {
      if (!prev) return prev;

      return {
        ...prev,
        completedCount,
        endTime: Date.now(),
        failedCount,
        status: `Completed ${completedCount} of ${prev.totalCount}`,
        percentageComplete: (completedCount / prev.totalCount) * 100,
        isDone: completedCount === prev.totalCount
      };
    });
  };

  /**
   * Handles a change of the prompt/version selection.
   *
   * Without a version the whole page is reset. Otherwise the selection is
   * stored, the prompt type flag is derived from the version type, the message
   * content is seeded from the version's message template and tools, and the
   * payload is taken from the version's sample payload (or generated from the
   * version when it has none).
   */
  const onPVChange = (apv: PV) => {
    const { version } = apv;

    if (!version) {
      reset();
      return;
    }

    setApv(apv);
    setIsMessagesType(version.type === PromptVersionTypes.MESSAGING);

    const template = version.messageTemplate;
    setFullMessageContent({
      systemPrompt: template?.systemPrompt || '',
      messages: template?.messages || [],
      tools: version.tools || []
    });

    setCurrentPayload(version.samplePayload || generatePromptPayloadFromVersion(version));
  };

  /** Adds the model id to the selection when checked, removes it when unchecked. */
  const onModelSelect = (mid: string, checked: boolean) => {
    setSelectedModelIds((current) => (checked ? [...current, mid] : current.filter((id) => id !== mid)));
  };

  /**
   * Validates the configuration and launches the run: one request per selected
   * model and per requested run.
   *
   * The requests are dispatched in random order. The order is shuffled BEFORE
   * the requests are created, because shuffling already-started promises has no
   * effect on the order in which they were sent. The busy flag is always
   * released, even when preparing or dispatching a request throws.
   */
  const start = async () => {
    if (!apv?.version) return toast.error('Please select a prompt version');
    if (!selectedModelIds.length) return toast.error('Please select at least one model');
    if (isMessagesType && !fullMessageContentEvaluated) return toast.error('Please enter at least one message');
    if (!isMessagesType && !generatedPrompt) return toast.error('Please enter a prompt');

    const version = apv.version;
    const runsPerModel = Number(selectedRunCount.value);
    const totalMetrics = selectedModelIds.length * runsPerModel;
    // Every run, plus one step per model and evaluation type, plus 2 (summary and grades).
    const totalCount = totalMetrics + 2 + selectedModelIds.length * Object.values(EvaluationType).length;

    resetMetrics();
    setProgress({
      startTime: Date.now(),
      endTime: Date.now(),
      completedCount: 0,
      totalMetrics,
      totalCount,
      percentageComplete: 0,
      status: 'Starting',
      failedCount: 0,
      isDone: false
    });

    setIsBusy(true);

    try {
      // Build deferred jobs first so that nothing is sent before the order is shuffled.
      const jobs: (() => Promise<void>)[] = [];
      for (const modelId of selectedModelIds) {
        for (let run = 0; run < runsPerModel; run++) {
          const translatedParameters = translateModelParameters(
            models,
            version.model,
            modelId,
            version.parameters,
            version.type
          );

          jobs.push(() =>
            isMessagesType
              ? runMessagesCompletion(modelId, translatedParameters)
              : runCompletion(modelId, translatedParameters)
          );
        }
      }

      // Fisher-Yates shuffle: unbiased, unlike sorting with a random comparator.
      for (let i = jobs.length - 1; i > 0; i--) {
        const j = Math.floor(Math.random() * (i + 1));
        [jobs[i], jobs[j]] = [jobs[j], jobs[i]];
      }

      setSelectedTab(1);

      // axios interceptor will throttle the requests
      await Promise.all(jobs.map((job) => job()));
    } catch (error) {
      toast.error(getErrorMessage(error));
    } finally {
      setIsBusy(false);
    }
  };

  /**
   * Runs the generated prompt once against the given model and appends the
   * result, tagged with the model's display name, to the raw results.
   *
   * Never rejects: any failure, including one thrown synchronously while the
   * request is being set up, is reported through a toast so that a single bad
   * run cannot abort the Promise.all() in start().
   *
   * NOTE(review): a failed request adds no raw result, so the raw result count
   * stays below the expected total — confirm whether a failed placeholder
   * result should be recorded instead.
   *
   * @param modelId - Id of the model to run.
   * @param parameters - Model parameters already translated for that model.
   */
  const runCompletion = async (modelId: string, parameters: any): Promise<void> => {
    try {
      const result = await runModel(
        modelId,
        generatedPrompt!,
        parameters,
        apv?.version?.type,
        apv?.version?.tools,
        apv?.version?.promptId,
        apv?.version?.version,
        PromptMetricSource.ANALYSIS
      );

      // Fall back to the id when the model is not in the name lookup.
      result.modelName = modelNameMap[result.modelId] || result.modelId;
      setRawMetrics((prev) => [...prev, result]);
    } catch (error) {
      toast.error(getErrorMessage(error));
    }
  };

  /**
   * Runs the evaluated message content once against the given model and appends
   * the result, tagged with the model's display name, to the raw message results.
   *
   * Never rejects: any failure, including one thrown synchronously while the
   * request is being set up, is reported through a toast so that a single bad
   * run cannot abort the Promise.all() in start().
   *
   * NOTE(review): a failed request adds no raw result, so the raw result count
   * stays below the expected total — confirm whether a failed placeholder
   * result should be recorded instead.
   *
   * @param modelId - Id of the model to run.
   * @param parameters - Model parameters already translated for that model.
   */
  const runMessagesCompletion = async (modelId: string, parameters: any): Promise<void> => {
    try {
      const result = await runModelMessages(
        modelId,
        fullMessageContentEvaluated!.systemPrompt,
        fullMessageContentEvaluated!.messages,
        parameters,
        fullMessageContentEvaluated!.tools,
        apv?.version?.promptId,
        apv?.version?.version,
        PromptMetricSource.ANALYSIS
      );

      // Fall back to the id when the model is not in the name lookup.
      result.modelName = modelNameMap[result.modelId] || result.modelId;
      setRawMetricsMessages((prev) => [...prev, result]);
    } catch (error) {
      toast.error(getErrorMessage(error));
    }
  };

  /**
   * Scores the completions of every model for every evaluation type.
   *
   * The completions are grouped by model name, de-duplicated and stripped of
   * empty entries; each (model, evaluation type) pair is then scored
   * sequentially with scoreEvaluation(). Every evaluation is published to its
   * state list as soon as it arrives, and the progress advances by one step
   * per pair whether or not the scoring succeeded.
   *
   * @returns The evaluations collected during this call, in the order
   *   [coherence, fluency, bias, factual, toxicity, relevancy].
   */
  const gradeCompletions = async (): Promise<
    [RatedEvaluation[], RatedEvaluation[], RatedEvaluation[], RatedEvaluation[], RatedEvaluation[], RatedEvaluation[]]
  > => {
    const _rawMetrics = getRawMetrics();

    const completionsByModel = _rawMetrics.reduce((acc: Record<string, string[]>, result) => {
      if (!acc[result.modelName]) {
        acc[result.modelName] = [];
      }

      acc[result.modelName].push(getCompletion(result));

      return acc;
    }, {});

    let count = 0;
    const _coherence: RatedEvaluation[] = [];
    const _fluency: RatedEvaluation[] = [];
    const _bias: RatedEvaluation[] = [];
    const _factual: RatedEvaluation[] = [];
    const _toxicity: RatedEvaluation[] = [];
    const _relevancy: RatedEvaluation[] = [];

    // Where an evaluation of each type goes: the local list returned to the caller and
    // the state setter that makes it visible in the UI.
    const sinks = new Map([
      [EvaluationType.COHERENCE, { collected: _coherence, publish: setCompletionCoherence }],
      [EvaluationType.FLUENCY, { collected: _fluency, publish: setCompletionFluency }],
      [EvaluationType.BIAS, { collected: _bias, publish: setCompletionBias }],
      [EvaluationType.FACTUAL, { collected: _factual, publish: setCompletionFactual }],
      [EvaluationType.TOXICITY, { collected: _toxicity, publish: setCompletionToxicity }],
      [EvaluationType.RELEVANCY, { collected: _relevancy, publish: setCompletionRelevancy }]
    ]);

    for (const [model, completions] of Object.entries(completionsByModel)) {
      try {
        // The typeof guard keeps a missing completion from throwing on `.length`.
        const uniqueCompletions = Array.from(new Set(completions)).filter(
          (c) => typeof c === 'string' && c.length > 0
        );

        const prompt = getPrompt();

        for (const evaluationType of Object.values(EvaluationType)) {
          try {
            const evaluation = await scoreEvaluation(evaluationType, prompt, uniqueCompletions);
            evaluation.model = model;

            const sink = sinks.get(evaluationType);
            if (sink) {
              sink.collected.push(evaluation);
              sink.publish((prev) => [...prev, evaluation]);
            }
          } catch (error) {
            console.error(error);
          } finally {
            updateProgress(_rawMetrics.length + ++count);
          }
        }
      } catch (error) {
        console.error(error);
        toast.error(getErrorMessage(error));
      }
    }

    // The state lists updated above are not readable by the caller until the next
    // render, so hand back the locally collected copies.
    return [_coherence, _fluency, _bias, _factual, _toxicity, _relevancy];
  };

  /**
   * Returns the prompt text handed to the evaluations: the combined prompt
   * built from the evaluated message content for messaging prompts, otherwise
   * the generated prompt (or '' when there is none).
   */
  const getPrompt = (): string => {
    if (!isMessagesType) {
      return generatedPrompt || '';
    }

    return generateCombinedPrompt(fullMessageContentEvaluated!);
  };

  /**
   * Builds one roll-up row per model metric, combining its average latency,
   * TTFB, cost and token figures with the average rating of each evaluation
   * type. A rating is `undefined` when no evaluation exists for the model.
   *
   * @param finalMetrics - Aggregated per-model metrics.
   * @param coherence - Coherence evaluations.
   * @param fluency - Fluency evaluations.
   * @param bias - Bias evaluations.
   * @param factual - Factual evaluations.
   * @param toxicity - Toxicity evaluations.
   * @param relevancy - Relevancy evaluations.
   * @returns The rows, in the same order as `finalMetrics`.
   */
  const getRollupPayload = (
    finalMetrics: CompletionMetric[],
    coherence: RatedEvaluation[],
    fluency: RatedEvaluation[],
    bias: RatedEvaluation[],
    factual: RatedEvaluation[],
    toxicity: RatedEvaluation[],
    relevancy: RatedEvaluation[]
  ): Record<string, any>[] => {
    // Average rating of the first evaluation recorded for the model, if any.
    const ratingOf = (evaluations: RatedEvaluation[], modelName: CompletionMetric['model']) =>
      evaluations.find((evaluation) => evaluation.model === modelName)?.averageRating;

    const rows: Record<string, any>[] = [];

    for (const metric of finalMetrics) {
      const modelName = metric.model;

      rows.push({
        model: modelName,
        latency: metric.avgLatency,
        ttfb: metric.avgTtfb,
        cost: metric.avgTotalCost,
        request_cost: metric.avgRequestCost,
        response_tokens: metric.avgResponseTokens,
        coherence: ratingOf(coherence, modelName),
        fluency: ratingOf(fluency, modelName),
        bias: ratingOf(bias, modelName),
        factual: ratingOf(factual, modelName),
        toxicity: ratingOf(toxicity, modelName),
        relevancy: ratingOf(relevancy, modelName)
      });
    }

    return rows;
  };

  /**
   * Requests a summary for the roll-up payload and stores it; a failure is
   * reported through a toast and leaves the stored summary untouched.
   */
  const generateSummary = async (payload: Record<string, any>[]) => {
    try {
      const summarized = await summarizeResults(payload);
      setSummary(summarized);
    } catch (err) {
      toast.error(getErrorMessage(err));
    }
  };

  /**
   * Requests the model grades for the roll-up payload and stores them; a
   * failure is reported through a toast and leaves the stored grades untouched.
   */
  const generateModelGrades = async (payload: Record<string, any>[]) => {
    try {
      const graded = await gradeModels(payload);
      setGrades(graded);
    } catch (err) {
      toast.error(getErrorMessage(err));
    }
  };

  /**
   * Total cost of the current run: request and response cost of every raw
   * result, plus the cost of every scored evaluation, the summary and the grades.
   *
   * This is called while rendering, so it must not throw. A raw result without
   * token information contributes 0 — the aggregation effect likewise only
   * reads `tokens` for results that did not fail.
   */
  const getCost = (): number => {
    const evaluationCost = (evaluations: RatedEvaluation[]): number =>
      evaluations.reduce((total, evaluation) => total + evaluation.totalCost, 0);

    const runCost = getRawMetrics().reduce(
      (total, m) => total + ((m.tokens?.requestCost ?? 0) + (m.tokens?.responseCost ?? 0)),
      0
    );

    return (
      runCost +
      evaluationCost(completionCoherence) +
      evaluationCost(completionFluency) +
      evaluationCost(completionBias) +
      evaluationCost(completionFactual) +
      evaluationCost(completionToxicity) +
      evaluationCost(completionRelevancy) +
      (summary?.totalCost || 0) +
      (grades?.totalCost || 0)
    );
  };

  // Renders the three tabs:
  //   0 "Configuration" - prompt/version selector, the composer matching the prompt type,
  //                       the model picker, the run-count selector and the Start button;
  //   1 "Results"       - progress tracker, summary, model grades and the performance graph;
  //   2 "Data"          - table of the raw run results.
  // NOTE(review): the summary is injected with dangerouslySetInnerHTML and nothing in this
  // component sanitizes `summary.results` (the value returned by summarizeResults()).
  // Confirm it is sanitized upstream, or sanitize it before rendering.
  return (
    <TabGroup
      tabNames={['Configuration', 'Results', 'Data']}
      selectedTabIndex={selectedTab}
      onTabChange={setSelectedTab}>
      <>
        <div className="mx-auto">
          {selectedTab === 0 && (
            <div>
              <div className="mb-5"></div>
              <PromptVersionSelector
                selected={apv}
                onChange={onPVChange}
                className="mb-6"
                defaultLabels={['Select Prompt', 'Select Version']}
              />
              <div className="mt-4">
                <div className="flex gap-x-4">
                  {!apv?.version?.type && (
                    <div className="mt-5 text-gray-700 w-full">
                      <div className="mt-2 text-indigo-500 mx-auto text-center">
                        <FontAwesomeIcon icon={faMagnifyingGlassChart} className="w-40 h-40" />
                      </div>
                      <div className="mt-5 text-gray-700 mx-auto text-center">Select a prompt to get started.</div>
                    </div>
                  )}

                  {apv?.version?.type === PromptVersionTypes.MESSAGING && (
                    <PromptComposerMessages
                      fullMessageContent={fullMessageContent}
                      payload={currentPayload}
                      hbsHelpers={apv?.version?.helpers || {}}
                      disabled={isBusy}
                      className="w-full h-full"
                      autoScroll={false}
                      onPayloadChange={(p) => {
                        if (deepEqual(p, currentPayload)) return;
                        setCurrentPayload(p);
                      }}
                      onMessagesUpdate={(messages, messagesEvald) => {
                        setFullMessageContent(messages);
                        setFullMessageContentEvaluated(messagesEvald);
                        return;
                      }}
                      onUnload={setCurrentPayload}
                    />
                  )}

                  {(apv?.version?.type === PromptVersionTypes.COMPLETION ||
                    apv?.version?.type === PromptVersionTypes.LEGACY) && (
                    <PromptComposer
                      template={apv?.version?.template}
                      payload={currentPayload}
                      disabled={isBusy}
                      onPromptGenerate={setGeneratedPrompt}
                      onUnload={setCurrentPayload}
                    />
                  )}
                  <div className="w-64">
                    <ModelPicker
                      models={models}
                      selectedModels={selectedModelIds}
                      disabled={isBusy}
                      onModelClick={onModelSelect}
                    />
                    <div className="mt-4 flex">
                      <div className="flex-1">
                        <Selector
                          values={runsSelectorValues}
                          onChange={setSelectedRunCount}
                          defaultValue={selectedRunCount}
                          classNames="w-28"
                          disabled={isBusy}
                        />
                      </div>
                      <div className="text-right">
                        <button className="standard" onClick={start} disabled={isBusy}>
                          <FontAwesomeIcon icon={faPersonRunning} className="mr-1" />
                          Start
                        </button>
                      </div>
                    </div>
                  </div>
                </div>
              </div>
            </div>
          )}
          {selectedTab === 1 && (
            <div>
              {progress && <ProgressTracker progress={progress} cost={getCost()} />}

              <div className="my-6 grid grid-cols-2 gap-x-16 text-sm text-gray-700">
                <div>
                  <h2 className="text-lg font-semibold text-gray-700 mb-0">Analysis</h2>
                  {summary ? <div dangerouslySetInnerHTML={{ __html: summary.results }} /> : <Skeleton count={3} />}
                </div>
                <div>
                  <h2 className="text-lg font-semibold text-gray-700 mb-0">Model Grading</h2>
                  {grades ? (
                    <div className="grid grid-cols-2 grow-0 w-72">
                      {Object.entries(grades.grades)
                        .sort(([a], [b]) => a.localeCompare(b))
                        .map(([key, value]) => (
                          <div key={key}>
                            <div className="text-gray-500">{key}</div> <div className="font-semibold">{value}</div>
                          </div>
                        ))}
                    </div>
                  ) : (
                    <Skeleton count={3} />
                  )}
                </div>
              </div>
              <PerformanceGraph
                metrics={completionMetrics}
                coherence={completionCoherence}
                fluency={completionFluency}
                bias={completionBias}
                factual={completionFactual}
                toxicity={completionToxicity}
                relevancy={completionRelevancy}
              />
            </div>
          )}
          {selectedTab === 2 && <ResultsTable results={getRawMetrics()} />}
        </div>
      </>
    </TabGroup>
  );
};

export default Performance;
