import { GridApiPro, GridColDef, GridComparatorFn, GridRenderCellParams } from '@mui/x-data-grid-pro';
import { DiseaseArea } from 'data/DiseaseAreaData';
import { PatientCountType } from 'components/grid/GridCountType';
import { MutableRefObject, useMemo } from 'react';
import { CompactGridWrapper } from 'components/grid/CompactGridWrapper';
import { DiseaseAreaMetrics } from 'data/DiseaseAreaMetricsData';
import { orderBy } from 'lodash';
import useMemoTranslation from 'hooks/UseMemoTranslation';
import { generateWidthFromProperty } from 'util/grid/TableUtils';
import { PatientCountGridCell, PatientCountGridCellProps } from 'diseaseAreas/diseaseAreasGrid/PatientCountGridCell';
import { ZeroCountGridCell } from 'components/grid/cell/ZeroCountGridCell';
import { DiseaseAreaGridCell } from './DiseaseAreaGridCell';
import { dashboard, diseaseArea, icd10Codes, lastUpdated, total } from 'util/Constants';
import { renderCellTooltip } from 'components/grid/cell/GridCellTooltip';
import { ConvertAllCodesToUserFormat } from 'data/conversions/Icd10CodeConversions';
import { LoadingState } from 'components/LoadingStateUtil';
import { RefreshStatus } from 'data/RefreshStatusData';

/** Props for `DiseaseAreaCountGrid`. */
export interface DiseaseAreaCountGridProps {
  /** Forwarded unchanged to `CompactGridWrapper`. */
  loadingState: LoadingState;
  /** One entry per disease area; each entry becomes one grid row. */
  diseaseAreaMetrics: DiseaseAreaMetrics[];
  /** `'byBiobank'` shows one count column per biobank; any other value shows one per sample type. */
  countType: PatientCountType;
  /** Its `lastRefreshed` feeds the "last updated" column; when absent that column is blank. */
  refreshStatus: RefreshStatus | undefined;
  /**
   * Passed to every `DiseaseAreaGridCell`. Its call signature is not visible from this file.
   * NOTE(review): consider replacing `Function` with a concrete function type.
   */
  onDiseaseAreaChange: Function;
  /**
   * Passed to every `PatientCountGridCell` and stored on each count cell value.
   * Presumably invoked when a count is clicked — confirm in `PatientCountGridCell`.
   * @param count - the cell's patient count
   * @param diseaseArea - the row's disease area
   * @param patientRecordCountType - `'all'` for the total column, otherwise `'byBiobank'` / `'bySampleType'`
   * @param biobankId - set on per-biobank count cells
   * @param sampleTypeId - set on per-sample-type count cells
   */
  onCountClick(
    count: number,
    diseaseArea: DiseaseArea,
    patientRecordCountType: PatientCountType,
    biobankId?: string,
    sampleTypeId?: string
  ): void;
  /** Optional grid API ref, forwarded to `CompactGridWrapper`. */
  apiRef?: MutableRefObject<GridApiPro>;
}

/**
 * Grid of patient counts per disease area.
 *
 * Count columns are broken down by biobank or by sample type according to `countType`. The
 * disease-area, ICD-10 codes and total columns are pinned to the left.
 */
export const DiseaseAreaCountGrid = ({
  loadingState,
  diseaseAreaMetrics,
  countType,
  refreshStatus,
  onDiseaseAreaChange,
  onCountClick,
  apiRef,
}: DiseaseAreaCountGridProps) => {
  const columns = useColumns(diseaseAreaMetrics, countType, onCountClick, onDiseaseAreaChange);
  const rows = useRows(diseaseAreaMetrics, countType, onCountClick, refreshStatus);

  // Pin using the same field constants the column definitions use, so the pinned list and the
  // column fields cannot drift apart.
  const pinnedLeft = [diseaseArea, icd10Codes, total];

  return (
    <CompactGridWrapper
      loadingState={loadingState}
      apiRef={apiRef}
      rows={rows}
      columns={columns}
      initialState={{ pinnedColumns: { left: pinnedLeft } }}
      hideFooterRowCount
      hideFooterSelectedRowCount
      disableRowSelectionOnClick
    />
  );
};

/**
 * Builds a centred, numeric, non-filterable patient-count column sorted by `patientCountComparator`.
 *
 * Cells with a value render `PatientCountGridCell` wired to `onCountClick`. Cells with no value for
 * the field render `ZeroCountGridCell`.
 */
const buildPatientCountColumn = (
  field: string,
  headerName: string,
  width: GridColDef['width'],
  onCountClick: PatientCountGridCellProps['onCountClick']
): GridColDef => ({
  field,
  headerName,
  type: 'number',
  width,
  headerAlign: 'center',
  align: 'center',
  filterable: false,
  valueFormatter: ({ value }) => value && value.count,
  sortComparator: patientCountComparator,
  renderCell: (params: GridRenderCellParams<PatientCountGridCellProps>) => {
    const cell = params.value;
    if (!cell) {
      return <ZeroCountGridCell />;
    }
    return (
      <PatientCountGridCell
        diseaseArea={cell.diseaseArea}
        biobankId={cell.biobankId}
        sampleTypeId={cell.sampleTypeId}
        count={cell.count}
        countType={cell.countType}
        pendingRefresh={cell.pendingRefresh}
        onCountClick={onCountClick}
      />
    );
  },
});

/**
 * Column definitions for the disease-area grid, memoised on the inputs.
 *
 * Column order: disease area, ICD-10 codes, total, then one count column per biobank (when
 * `countType === 'byBiobank'`) or per sample type (otherwise), then dashboard and last updated.
 */
const useColumns = (
  diseaseAreaMetrics: DiseaseAreaMetrics[],
  countType: PatientCountType,
  onCountClick: (
    count: number,
    diseaseArea: DiseaseArea,
    patientCountType: PatientCountType,
    biobankId?: string,
    sampleTypeId?: string
  ) => void,
  onDiseaseAreaChange: Function
): GridColDef[] => {
  const { t } = useMemoTranslation();

  return useMemo(() => {
    // Every count column (total and per biobank/sample type) shares one width, computed once from
    // the per-disease-area totals. The numeric totals are passed as strings, the same input shape
    // as the disease-area name column, rather than the `patientCount` objects themselves (which
    // also carry `pendingRefresh`).
    // NOTE(review): assumes generateWidthFromProperty sizes by each value's text length — confirm.
    const countColumnWidth = generateWidthFromProperty(
      diseaseAreaMetrics.map(m => String(m.patientCount.patientCount)),
      170,
      row => row,
      7
    );

    const startColumns: GridColDef[] = [
      {
        field: diseaseArea,
        headerName: t(diseaseArea),
        width: generateWidthFromProperty(
          diseaseAreaMetrics.map(m => m.diseaseArea.name),
          170,
          row => row,
          9
        ),
        align: 'left',
        filterable: false,
        valueFormatter: ({ value }) => value && value.name,
        sortComparator: diseaseAreaComparator,
        renderCell: (params: GridRenderCellParams<DiseaseArea>) => {
          if (params.value) {
            return <DiseaseAreaGridCell diseaseArea={params.value} onDiseaseAreaChange={onDiseaseAreaChange} />;
          }
          return undefined;
        },
      },
      {
        field: icd10Codes,
        headerName: t(icd10Codes),
        width: 200,
        renderCell: renderCellTooltip,
      },
      buildPatientCountColumn(total, t(total), countColumnWidth, onCountClick),
    ];

    const columnFields =
      countType === 'byBiobank'
        ? GetBiobankColumnFields(diseaseAreaMetrics)
        : GetSampleTypeColumnFields(diseaseAreaMetrics);

    const countColumns = columnFields.map(f =>
      buildPatientCountColumn(f.id, f.name, countColumnWidth, onCountClick)
    );

    const endColumns: GridColDef[] = [
      {
        field: dashboard,
        headerName: t(dashboard),
        type: 'boolean',
        width: 110,
        headerAlign: 'center',
        align: 'center',
        renderCell: (params: GridRenderCellParams) => (params.value ? 'x' : ''),
      },
      {
        field: lastUpdated,
        headerName: t(lastUpdated),
        type: 'date',
        width: 150,
        headerAlign: 'center',
        align: 'center',
        filterable: false,
        renderCell: (params: GridRenderCellParams<Date>) => (
          <span title={params.value ? params.value.toString() : ''}>
            {params.value
              ? params.value.toLocaleString('en-US', { year: 'numeric', month: 'short', day: 'numeric' })
              : ''}
          </span>
        ),
      },
    ];

    return [...startColumns, ...countColumns, ...endColumns];
  }, [t, countType, diseaseAreaMetrics, onCountClick, onDiseaseAreaChange]);
};

/**
 * One grid row per disease-area metric, memoised on the inputs.
 *
 * Each row has the fixed fields `id`, `diseaseArea`, `icd10Codes`, `total`, `dashboard` and
 * `lastUpdated`. It also has one `PatientCountGridCellProps` value per biobank id (when
 * `countType === 'byBiobank'`) or per sample-type id (otherwise). These are keyed by that id so
 * they line up with the count column fields. Only counts whose `diseaseAreaId` matches the row's
 * disease area are included.
 */
const useRows = (
  diseaseAreaMetrics: DiseaseAreaMetrics[],
  countType: PatientCountType,
  onCountClick: (
    count: number,
    diseaseArea: DiseaseArea,
    patientCountType: PatientCountType,
    biobankId?: string,
    sampleTypeId?: string
  ) => void,
  refreshStatus: RefreshStatus | undefined
) => {
  return useMemo(() => {
    // Every row shows the same refresh timestamp, so build it once instead of per row.
    const lastUpdatedDate = refreshStatus?.lastRefreshed ? new Date(refreshStatus.lastRefreshed) : undefined;

    const rows: any[] = diseaseAreaMetrics.map((d, index) => {
      const area = d.diseaseArea;

      const totalCount: PatientCountGridCellProps = {
        diseaseArea: area,
        countType: 'all',
        count: d.patientCount.patientCount,
        pendingRefresh: d.patientCount.pendingRefresh,
        onCountClick: onCountClick,
      };

      // Count cells are keyed by biobank / sample-type id, so the full row shape is only known at runtime.
      const row: Record<string, unknown> = {
        id: index,
        diseaseArea: area,
        icd10Codes: ConvertAllCodesToUserFormat(area.icd10Codes).join(', '),
        total: totalCount,
        dashboard: area.availableToShowOnDashboard,
        lastUpdated: lastUpdatedDate,
      };

      // Use plain assignment, not Object.defineProperty. defineProperty defaults `enumerable` and
      // `configurable` to false, which would hide these cells from spreads, Object.keys and
      // JSON.stringify.
      if (countType === 'byBiobank') {
        for (const c of d.patientCountsByBiobank) {
          if (c.diseaseAreaId === area.diseaseAreaId) {
            const cell: PatientCountGridCellProps = {
              diseaseArea: area,
              biobankId: c.biobankId,
              count: c.patientCount,
              countType: 'byBiobank',
              pendingRefresh: c.pendingRefresh,
              onCountClick: onCountClick,
            };
            row[c.biobankId] = cell;
          }
        }
      } else {
        for (const c of d.patientCountsBySampleType) {
          if (c.diseaseAreaId === area.diseaseAreaId) {
            const cell: PatientCountGridCellProps = {
              diseaseArea: area,
              sampleTypeId: c.sampleTypeId,
              count: c.patientCount,
              countType: 'bySampleType',
              pendingRefresh: c.pendingRefresh,
              onCountClick: onCountClick,
            };
            row[c.sampleTypeId] = cell;
          }
        }
      }

      return row;
    });

    return rows;
  }, [countType, diseaseAreaMetrics, onCountClick, refreshStatus]);
};

/**
 * Sums patient counts per id and returns one `{ id, name }` entry for every id whose total is
 * greater than zero.
 *
 * Entries are ordered by total (descending), then by id (ascending). `name` is taken from the
 * first count seen for that id, falling back to ''.
 */
function aggregateCountColumnFields<T>(
  counts: T[],
  getId: (count: T) => string,
  getName: (count: T) => string,
  getPatientCount: (count: T) => number
): { id: string; name: string }[] {
  const idToName = new Map<string, string>();
  const totals = new Map<string, number>();

  for (const count of counts) {
    const id = getId(count);
    if (!totals.has(id)) {
      idToName.set(id, getName(count));
    }
    totals.set(id, (totals.get(id) ?? 0) + getPatientCount(count));
  }

  const nonEmpty = Array.from(totals, ([id, totalPatientCount]) => ({ id, totalPatientCount })).filter(
    c => c.totalPatientCount > 0
  );

  // The tie-break key must be a property that exists on these objects (`id`). A missing key makes
  // lodash compare `undefined` for every entry, leaving ties in insertion order.
  return orderBy(nonEmpty, ['totalPatientCount', 'id'], ['desc', 'asc']).map(c => ({
    id: c.id,
    name: idToName.get(c.id) ?? '',
  }));
}

/** Count-column fields (biobank id + biobank name) for biobanks with at least one patient overall. */
function GetBiobankColumnFields(diseaseAreaMetrics: DiseaseAreaMetrics[]) {
  return aggregateCountColumnFields(
    diseaseAreaMetrics.flatMap(m => m.patientCountsByBiobank),
    c => c.biobankId,
    c => c.biobank,
    c => c.patientCount
  );
}

/** Count-column fields (sample type id + name) for sample types with at least one patient overall. */
function GetSampleTypeColumnFields(diseaseAreaMetrics: DiseaseAreaMetrics[]) {
  return aggregateCountColumnFields(
    diseaseAreaMetrics.flatMap(m => m.patientCountsBySampleType),
    c => c.sampleTypeId,
    c => c.sampleType,
    c => c.patientCount
  );
}

/**
 * Orders disease-area cell values by `name`.
 *
 * Uses plain `<` / `>` string comparison (code-unit order, not locale-aware). Names that compare
 * neither less nor greater are treated as equal.
 */
const diseaseAreaComparator: GridComparatorFn = (v1, v2) => {
  const leftName = (v1 as DiseaseArea).name;
  const rightName = (v2 as DiseaseArea).name;
  return leftName < rightName ? -1 : leftName > rightName ? 1 : 0;
};

/**
 * Orders patient-count cell values by `count`.
 *
 * A missing cell (`undefined` / `null`, e.g. a row with no entry for that column) sorts as a count
 * of 0. Previously `v ?? 0` produced the number 0, whose `.count` is `undefined`, so the result was
 * `NaN`.
 */
const patientCountComparator: GridComparatorFn = (v1, v2) => {
  const countOf = (value: unknown): number => (value as PatientCountGridCellProps | undefined)?.count ?? 0;
  return countOf(v1) - countOf(v2);
};
