import { CompareLocationParameters, MaskAnnotation } from "@customTypes/data";
import { MakeGenerics, useNavigate, useSearch } from "@tanstack/react-location";
import React, {
  ComponentProps,
  useCallback,
  useEffect,
  useRef,
  useState,
  useMemo,
} from "react";

import { DatasetPicker } from "components/DatasetPicker";
import Loader from "components/Loader";
import { PredictionRunPicker } from "components/PredictionRunPicker";
import AppLayout from "./Layouts/AppLayout";
import CreatableSelect from "react-select/creatable";

import { Prediction } from "@customTypes/data";
import produce from "immer";

import { usePredictions } from "hooks/usePredictions";
import MaskVisualizer2 from "components/MaskVisualizer2";
import { loadImageFromSrc, redToAlpha } from "helpers/imageHelpers";
import { maskAnnotationsUrl, trainingImagesUrl } from "../constants";
import { djangoBackend } from "services/apiServices";

import { OnChangeValue, StylesConfig, createFilter } from "react-select";
import { useTags } from "hooks/useTags";
import { Masonry } from "masonic";

/** Value/label pair used as the option type of the tag selects in `TagPicker`. */
interface SelectOption {
  readonly value: string;
  readonly label: string;
}

/** Search-parameter generics passed to `useSearch` to read the selected `tags`. */
type SearchProps = MakeGenerics<{
  Search: {
    tags: string[];
  };
}>;

/**
 * Metric keys read from each prediction's `metrics` object; used both for the
 * per-card metric display and as the available "Sort by" options.
 */
const metrics = ["iou", "mae"] as const;

/**
 * Two multi-select inputs for choosing tags: one for tags whose name starts
 * with "vision_" and one for every other tag.
 *
 * The selection is owned by the caller: `selectedTags` holds both groups in a
 * single list, and `onTagChange` receives the full merged list (vision tags
 * first) whenever either select changes.
 */
function TagPicker({
  onTagChange,
  selectedTags,
}: {
  onTagChange: (tags: string[]) => void;
  selectedTags: string[];
}): JSX.Element {
  const isVisionTag = (tag: string) => tag.startsWith("vision_");
  const toOption = (tag: string) => ({ value: tag, label: tag });
  const valuesOf = (options: OnChangeValue<SelectOption, true>) =>
    options.map((option) => option.value);

  // Split the incoming selection into the two groups shown in the UI.
  const selectedVisionTags = selectedTags.filter((tag) => isVisionTag(tag));
  const selectedNonVisionTags = selectedTags.filter((tag) => !isVisionTag(tag));

  const { isLoading: tagLoading, error: tagError, data: tagPages } = useTags();

  // Flatten every loaded page of tags into select options keyed by tag id.
  const tagOptions = tagPages?.pages.flatMap((page) =>
    page.results.map((tag) => ({ value: tag.id, label: tag.id }))
  );

  if (tagError) {
    return <div>Error: {tagError}</div>;
  }

  // Each handler replaces its own group and keeps the other group untouched.
  const handleChangeVisionTag = (
    newValue: OnChangeValue<SelectOption, true>
  ) => {
    onTagChange([...valuesOf(newValue), ...selectedNonVisionTags]);
  };

  const handleChangeNonVisionTag = (
    newValue: OnChangeValue<SelectOption, true>
  ) => {
    onTagChange([...selectedVisionTags, ...valuesOf(newValue)]);
  };

  // Let long tag labels scroll horizontally instead of being truncated.
  const customStyles: StylesConfig<SelectOption, true> = {
    multiValueLabel: (base) => ({
      ...base,
      ...{ overflowX: "scroll", textOverflow: "initial" },
    }),
  };

  return (
    <div className="">
      <span>Vision Tags</span>
      <CreatableSelect
        filterOption={createFilter({ ignoreAccents: false })} // Speed optimisation
        isMulti
        value={selectedVisionTags.map(toOption)}
        onChange={handleChangeVisionTag}
        options={tagOptions?.filter((option) => isVisionTag(option.label))}
        isLoading={tagLoading}
      />
      <span>Tags</span>
      <CreatableSelect
        isMulti
        onChange={handleChangeNonVisionTag}
        options={tagOptions?.filter((option) => !isVisionTag(option.label))}
        value={selectedNonVisionTags.map(toOption)}
        isLoading={tagLoading}
        styles={customStyles}
      />
    </div>
  );
}

/**
 * Masonry grid of prediction cards.
 *
 * @param predictions      One entry per image; each entry is the map handed
 *                         to `PredictionsGridCard` as its `data` prop.
 * @param onClick          Called with the image id of the clicked card.
 * @param masonryUniqueKey React `key` for the underlying `Masonry` instance.
 */
function PredictionsGrid({
  predictions,
  onClick,
  masonryUniqueKey,
}: {
  predictions: Map<string, Prediction>[];
  onClick: (imageId: string) => void;
  masonryUniqueKey: string;
}) {
  // Always hold the most recent click handler. The render component below is
  // created once (empty dependency list) so its identity never changes between
  // renders; reading the handler through a ref keeps it from calling a stale
  // `onClick` captured on the first render.
  const onClickRef = useRef(onClick);
  onClickRef.current = onClick;

  const CardWithClick = useCallback(
    (props: { data: Map<string, Prediction> }) => (
      <PredictionsGridCard
        {...props}
        onClick={(imageId) => onClickRef.current(imageId)}
      />
    ),
    []
  );

  return (
    <Masonry
      items={predictions}
      key={masonryUniqueKey}
      columnGutter={8}
      columnWidth={172}
      overscanBy={3}
      render={CardWithClick}
    />
  );
}

/**
 * A single card of the predictions grid.
 *
 * Shows the image thumbnail as a background. While the pointer is over the
 * card, the prediction mask is lazily loaded and applied as a CSS mask, and
 * the collected metric values are listed underneath.
 *
 * @param data    Predictions for one image (values of the map are iterated;
 *                the first defined image id / thumbnail / mask URL are used).
 * @param onClick Called with the image id when the card is clicked.
 */
const PredictionsGridCard = ({
  data: prediction,
  onClick,
}: {
  data: Map<string, Prediction>;
  onClick: (imageId: string) => void;
}) => {
  const size = 200;

  let imageId: string | undefined = undefined;
  let imageUrl: string | undefined = undefined;
  let initialPredictionMaskUrl: string | undefined = undefined;

  const predictionMetrics: Record<typeof metrics[number], number[]> = {
    iou: [],
    mae: [],
  };

  // Take the first defined id / thumbnail / mask URL, and gather every
  // defined metric value across all predictions of this image.
  prediction.forEach((value) => {
    if (imageUrl === undefined) {
      imageUrl = value.image.imageUrl200x200;
    }
    if (imageId === undefined) {
      imageId = value.image.id;
    }
    if (initialPredictionMaskUrl === undefined) {
      initialPredictionMaskUrl = value.imageUrlFullSize;
    }

    for (const metric of metrics) {
      const metricValue = value.metrics[metric];

      if (metricValue !== undefined) {
        predictionMetrics[metric].push(metricValue);
      }
    }
  });

  // undefined when not yet loaded.
  // "loading" while the mask is being loaded.
  // any other string once loaded.
  const [predictionMaskUrl, setPredictionMaskUrl] = useState<
    string | undefined | "loading"
  >(undefined);

  const [mouseOver, setMouseOver] = useState(false);

  // The mask is only applied while hovering and once it is fully loaded.
  const maskImageUrl =
    mouseOver &&
    predictionMaskUrl !== undefined &&
    predictionMaskUrl !== "loading"
      ? predictionMaskUrl
      : undefined;

  useEffect(() => {
    if (
      mouseOver &&
      predictionMaskUrl === undefined &&
      initialPredictionMaskUrl !== undefined
    ) {
      setPredictionMaskUrl("loading");
      loadImageFromSrc(initialPredictionMaskUrl)
        .then((mask) => {
          // Only store the result if nothing reset the state in the meantime
          // (e.g. the card switched to another image).
          setPredictionMaskUrl((current) =>
            current === "loading" ? redToAlpha(mask)?.toDataURL() : current
          );
        })
        .catch((error: unknown) => {
          // Do not leave the card stuck in "loading": go back to "not yet
          // loaded". `predictionMaskUrl` is not a dependency of this effect,
          // so the load is only retried on the next hover.
          console.error("Failed to load prediction mask", error);
          setPredictionMaskUrl((current) =>
            current === "loading" ? undefined : current
          );
        });
    }
  }, [mouseOver, initialPredictionMaskUrl]);

  // A different image invalidates any mask loaded for the previous one.
  useEffect(() => {
    setPredictionMaskUrl(undefined);
  }, [imageId]);

  return (
    <div className="bg-white flex relative hover:z-10 hover:shadow-2xl">
      <button
        onClick={() => {
          if (imageId !== undefined) {
            onClick(imageId);
          }
        }}
        onMouseOver={() => {
          setMouseOver(true);
        }}
        onMouseOut={() => {
          setMouseOver(false);
        }}
        className="bg-cover bg-center grow"
        style={{
          height: size,
          maskImage:
            maskImageUrl !== undefined ? `url(${maskImageUrl})` : undefined,
          maskSize: "cover",
          maskPosition: "center",
          backgroundImage: `url(${imageUrl})`,

          WebkitMaskImage:
            maskImageUrl !== undefined ? `url(${maskImageUrl})` : undefined,
          WebkitMaskSize: "cover",
          WebkitMaskPosition: "center",
        }}
      />
      {mouseOver && (
        <div className="absolute w-full grid place-items-center grid-rows-1 grid-flow-col top-full bg-black/50 p-2 backdrop-blur rounded-b-xl text-white shadow-2xl">
          {Object.entries(predictionMetrics).map(([metric, metricValues]) => (
            <div key={metric} className="flex flex-col items-center">
              <span className="uppercase font-bold">{metric}</span>
              {metricValues.map((metricValue, index) => (
                <span key={index}>{metricValue.toFixed(2)}</span>
              ))}
            </div>
          ))}
        </div>
      )}
    </div>
  );
};

/**
 * Resolves the mask URL to use as ground truth for a training image.
 *
 * Fetches the training image, then every mask annotation it references, and
 * returns the `maskUrl` of the annotation with the greatest `createdAt`.
 *
 * @param imageId Id of the training image.
 * @returns The `maskUrl` of the most recently created mask annotation.
 * @throws Error when the image has no mask annotations (previously this
 *         surfaced as an opaque `TypeError` on `undefined.maskUrl`).
 */
async function getGroundTruthUrl(imageId: string) {
  const imageResponse = await djangoBackend.get(
    `${trainingImagesUrl}${imageId}/`
  );

  // Request all annotations up front so they are fetched in parallel.
  const annotationRequests: Promise<MaskAnnotation>[] = [];

  for (const maskAnnotation of imageResponse.data.maskAnnotations) {
    annotationRequests.push(
      djangoBackend
        .get(`${maskAnnotationsUrl}${maskAnnotation.id}/`)
        .then((annotationResponse) => annotationResponse.data as MaskAnnotation)
    );
  }

  const annotations = await Promise.all(annotationRequests);

  if (annotations.length === 0) {
    throw new Error(`No mask annotations found for image ${imageId}`);
  }

  // TODO: Unsure about which one to choose.  By default take the most recent one.
  // On equal timestamps the later entry in the list wins.
  return annotations.reduce((latest, candidate) =>
    latest.createdAt > candidate.createdAt ? latest : candidate
  ).maskUrl;
}

/**
 * Builds the placeholder mask list shown before any image has been loaded.
 *
 * The list always starts with a "Reference Image" entry (whose `mask` is
 * `null`) and a "Ground Truth" entry, followed by one "Prediction N" entry
 * (1-based) per prediction. Every `image`, and every other `mask`, starts out
 * `undefined`.
 *
 * @param numberOfPredictions How many "Prediction N" entries to append.
 */
function initialMasks(numberOfPredictions: number) {
  const predictionEntries = [...Array(numberOfPredictions).keys()].map(
    (index) => ({
      image: undefined,
      mask: undefined,
      label: `Prediction ${index + 1}`,
    })
  );

  return [
    { image: undefined, mask: null, label: "Reference Image" },
    { image: undefined, mask: undefined, label: "Ground Truth" },
    ...predictionEntries,
  ];
}

/**
 * Full-screen dialog that shows the reference image, its ground-truth mask
 * and one predicted mask per prediction in `MaskVisualizer2`.
 *
 * Whenever `imageId` changes the mask list is reset to placeholders and the
 * reference image, ground truth and prediction masks are loaded
 * independently; each one is filled in as soon as it arrives.
 *
 * @param imageId         Id of the image being inspected.
 * @param predictionRuns  Selected prediction runs (only the count is used, to
 *                        size the placeholder list).
 * @param predictions     Predictions for `imageId`; entry `i` fills mask slot
 *                        `i + 2`.
 * @param onPreviousImage Forwarded to `MaskVisualizer2`.
 * @param onNextImage     Forwarded to `MaskVisualizer2`.
 * @param onClose         Forwarded to `MaskVisualizer2`.
 */
const MaskVisualizerDialog = ({
  imageId,
  predictionRuns,
  predictions,
  onPreviousImage,
  onNextImage,
  onClose,
}: {
  imageId: string;
  predictionRuns: string[];
  predictions: Prediction[];
  onPreviousImage: () => void;
  onNextImage: () => void;
  onClose: () => void;
}) => {
  const [masks, setMasks] = useState<
    ComponentProps<typeof MaskVisualizer2>["masks"]
  >(initialMasks(predictionRuns.length));

  // Lock page scrolling while the dialog is mounted and put back whatever
  // inline overflow value the body had before.
  useEffect(() => {
    const previousOverflow = document.body.style.overflow;
    document.body.style.overflow = "hidden";

    return () => {
      document.body.style.overflow = previousOverflow;
    };
  }, []);

  useEffect(() => {
    // Set by the cleanup below once `imageId` changes or the dialog unmounts,
    // so that late results for a previous image never overwrite the masks of
    // the current one.
    let cancelled = false;

    const reportLoadError = (what: string) => (error: unknown) => {
      if (!cancelled) {
        console.error(`Failed to load ${what} for image ${imageId}`, error);
      }
    };

    setMasks(initialMasks(predictionRuns.length));

    // The reference image is shared by every entry of the mask list.
    if (predictions.length > 0) {
      loadImageFromSrc(predictions[0].image.imageUrlFullSize)
        .then((image) => {
          if (cancelled) {
            return;
          }
          setMasks((masks) =>
            produce(masks, (draft) => {
              for (const mask of draft) {
                Object.assign(mask, { image });
              }
            })
          );
        })
        .catch(reportLoadError("reference image"));
    }

    // Slot 1 is the "Ground Truth" entry.
    getGroundTruthUrl(imageId)
      .then((groundTruthUrl) => loadImageFromSrc(groundTruthUrl))
      .then((image) => {
        if (cancelled) {
          return;
        }
        setMasks((masks) =>
          produce(masks, (draft) => {
            Object.assign(draft[1], { mask: image });
          })
        );
      })
      .catch(reportLoadError("ground truth mask"));

    // Slots 2.. are the "Prediction N" entries.
    predictions.forEach((prediction, index) => {
      loadImageFromSrc(prediction.imageUrlFullSize)
        .then((image) => {
          if (cancelled) {
            return;
          }
          setMasks((masks) =>
            produce(masks, (draft) => {
              // The placeholder list is sized from `predictionRuns`; skip
              // predictions that have no matching slot instead of throwing.
              const target = draft[index + 2];
              if (target !== undefined) {
                Object.assign(target, { mask: image });
              }
            })
          );
        })
        .catch(reportLoadError(`prediction mask ${index + 1}`));
    });

    return () => {
      cancelled = true;
    };
    // Keyed on `imageId` only: `predictions` and `predictionRuns` are read
    // from the render in which `imageId` changed.
  }, [imageId]);

  return (
    <div className="fixed w-full h-full top-0 left-0 z-20">
      <MaskVisualizer2
        masks={masks}
        onClose={onClose}
        onPreviousImage={onPreviousImage}
        onNextImage={onNextImage}
      />
    </div>
  );
};

/**
 * Page that compares prediction runs on a dataset.
 *
 * The dataset, the selected prediction runs, the selected tags and the
 * currently opened image are all kept in the URL search parameters. The
 * sidebar edits those parameters and picks the sort metric; the main area
 * shows the sorted grid of predictions, and a full-screen mask visualizer is
 * shown while an `imageId` is present in the search parameters.
 */
export default function CompareContainer() {
  const navigate = useNavigate();
  const { dataset, predictionRuns, imageId } =
    useSearch<CompareLocationParameters>();

  const { isLoading, data: predictionsByRunId } = usePredictions({
    predictionRuns: predictionRuns || [],
  });

  const [selectedSortMetric, setSelectedSortMetric] = useState<string>(
    metrics[0]
  );

  // Predictions ordered by the selected metric of the first prediction run,
  // highest value first. Entries lacking the metric compare as equal.
  const sortedPredictionRuns = useMemo(() => {
    if (predictionsByRunId === undefined || predictionRuns === undefined) {
      return undefined;
    }

    // Sort a copy: `sort` works in place, and the array belongs to the
    // `usePredictions` hook. A fresh array also gives consumers a new
    // reference whenever the order changes.
    return [...predictionsByRunId].sort((a, b) => {
      const metricA = a.get(predictionRuns[0])?.metrics[selectedSortMetric];
      const metricB = b.get(predictionRuns[0])?.metrics[selectedSortMetric];

      if (metricA === undefined || metricB === undefined) {
        return 0;
      }

      // A comparator must report ties as 0 to be consistent.
      if (metricA === metricB) {
        return 0;
      }

      return metricA > metricB ? -1 : 1;
    });
  }, [predictionsByRunId, predictionRuns, selectedSortMetric]);

  /** Returns every run's prediction for `imageId`, or `[]` when unavailable. */
  const getPredictions = (imageId: string): Prediction[] => {
    // TODO: (eliot) fix this ugly code, find better data structure
    // It's weird that we're passing runs as parameters of usePredictionTable, could be a setter?
    if (predictionsByRunId === undefined || predictionRuns === undefined) {
      return [];
    }

    const predictionRunsForImage = predictionsByRunId.find(
      (predictionRunId) => {
        return predictionRunId.get(predictionRuns[0])?.image.id === imageId;
      }
    );

    return Array.from(predictionRunsForImage?.values() || []);
  };

  // Changing the dataset also clears the selected prediction runs.
  const onDatasetChange = (dataset: string | undefined) => {
    navigate({
      search: (old) => ({
        ...old,
        predictionRuns: undefined,
        dataset: dataset,
      }),
    });
  };

  const onPredictionRunsChange = (predictionRuns: string[]) => {
    navigate({
      search: (old) => ({
        ...old,
        predictionRuns: predictionRuns,
      }),
    });
  };

  // Passing `undefined` removes the image from the URL and closes the dialog.
  const navigateToImageId = (imageId: string | undefined) => {
    navigate({
      search: (old) => ({
        ...old,
        imageId: imageId,
      }),
    });
  };

  // TODO: Find why using `imageId` directly doesn't work.
  const imageIdRef = useRef(imageId);
  imageIdRef.current = imageId;

  /**
   * Returns the id of the image at `relativeIndex(currentIndex)` in the
   * sorted list. Falls back to the current image id when the current image
   * is not in the list or the target index is out of range.
   */
  const getImageIdRelative = (relativeIndex: (index: number) => number) => {
    if (sortedPredictionRuns !== undefined && predictionRuns !== undefined) {
      const currentIndex = sortedPredictionRuns.findIndex((predictionRun) => {
        return (
          predictionRun.get(predictionRuns[0])?.image.id === imageIdRef.current
        );
      });

      // `findIndex` signals "not found" with -1, never with undefined.
      if (currentIndex !== -1) {
        const newIndex = relativeIndex(currentIndex);

        // Both bounds must hold; stepping past either end keeps the current
        // image instead of indexing outside the array.
        if (newIndex >= 0 && newIndex < sortedPredictionRuns.length) {
          const newImageId = sortedPredictionRuns[newIndex].get(
            predictionRuns[0]
          )?.image.id;

          if (newImageId !== undefined) {
            return newImageId;
          }
        }
      }
    }

    return imageIdRef.current;
  };

  const search = useSearch<SearchProps>();

  const onTagChange = (tags: string[]) => {
    navigate({
      search: (old) => ({
        ...old,
        tags: tags,
      }),
    });
  };

  return (
    <AppLayout requiresAuthentication={true}>
      {/* {isLoading && <h1>Loading data</h1>} */}

      {imageId !== undefined &&
        predictionRuns !== undefined &&
        predictionsByRunId !== undefined && (
          <MaskVisualizerDialog
            imageId={imageId}
            predictionRuns={predictionRuns}
            predictions={getPredictions(imageId)}
            onClose={() => navigateToImageId(undefined)}
            onPreviousImage={() =>
              navigateToImageId(getImageIdRelative((index) => index - 1))
            }
            onNextImage={() =>
              navigateToImageId(getImageIdRelative((index) => index + 1))
            }
          />
        )}

      <main className="flex flex-row">
        <aside className="flex flex-col divide-y w-[400px]">
          <div className="flex flex-col gap-2 p-4">
            <DatasetPicker
              initialDataset={dataset}
              onDatasetChange={onDatasetChange}
            />
            <PredictionRunPicker
              selectedPredictionRuns={predictionRuns ?? []}
              dataset={dataset}
              onPredictionRunChange={onPredictionRunsChange}
            />
          </div>

          {predictionRuns !== undefined && (
            <div className="p-4">
              <TagPicker
                onTagChange={onTagChange}
                selectedTags={search.tags ?? []}
              />
            </div>
          )}
          {predictionRuns !== undefined && (
            <div className="p-4 flex flex-col gap-2">
              <span>Sort by</span>
              {metrics.map((metric) => (
                <label key={metric} className="uppercase">
                  {/* One shared name so the radios form a single group. */}
                  <input
                    type="radio"
                    value={metric}
                    name="sortMetric"
                    onChange={(event) =>
                      setSelectedSortMetric(event.target.value)
                    }
                    checked={selectedSortMetric === metric}
                  />{" "}
                  {metric}
                </label>
              ))}
            </div>
          )}
          {predictionRuns !== undefined && (
            <div className="p-4">
              Found {predictionsByRunId?.length} predictions
              {isLoading && <Loader size="small" />}
            </div>
          )}
        </aside>

        {sortedPredictionRuns !== undefined && (
          <section className="p-2 grow">
            <PredictionsGrid
              predictions={sortedPredictionRuns}
              onClick={(imageId) => navigateToImageId(imageId)}
              masonryUniqueKey="predictionImages"
            />
          </section>
        )}
      </main>
    </AppLayout>
  );
}
