import React, { useMemo } from 'react';
import PropTypes from 'prop-types';
import { useCommonChartStyles } from './charts/use_common_chart_styles';
import { Typography, Box, Divider, Stack } from '@mui/material';
import { Doughnut } from 'react-chartjs-2';
import 'chart.js/auto';
import { observationsWithResponseForAdmNodeByValue, observationsWithResponseForAdmNode } from './observation_adm_node_filter_utils';
import Grid from '@mui/material/Grid';
import { getColor, getPointScaleColor } from './heatmaps/heatmap_colors';
import { isPointScaleQuestion, getNumPoints } from '../shared/utils';

// Chart.js options shared by every doughnut in this file. The chart stretches
// to fill its container (no fixed aspect ratio). The built-in legend and
// tooltip are switched off because each chart renders its own legend next to
// it using LegendRectangle swatches.
const options = {
  responsive: true,
  maintainAspectRatio: false,
  layout: { padding: { left: 0, right: 0 } },
  plugins: {
    legend: { display: false },
    tooltip: { enabled: false },
  },
};

const LegendRectangle = ({ color }) => (
  <Box
    sx={{
      width: 24,
      height: 24,
      backgroundColor: color,
      margin: 0,
    }}
  />
);

/**
 * Rounds a percentage to a whole number for display.
 *
 * Any non-finite input collapses to 0: NaN (e.g. 0 / 0 when a group has no
 * observations) and ±Infinity. This keeps the UI from ever rendering
 * "NaN%" or "Infinity%". `Number.isFinite` is used instead of the global
 * `isNaN` so that the input is not silently coerced.
 *
 * @param {number} value - Percentage, typically `(part / total) * 100`.
 * @returns {number} `value` rounded to the nearest integer, or 0 if it is not finite.
 */
const formatPercentage = (value) => (Number.isFinite(value) ? Math.round(value) : 0);

/**
 * Builds the Chart.js `data` object for a Yes/No doughnut.
 *
 * The "Yes" slice is coloured by `getColor`, using the rounded share of
 * "Yes" answers. The "No" slice always uses `commonColor.grey`.
 *
 * @param {number} yesCount - Number of "yes" responses.
 * @param {number} noCount - Number of "no" responses.
 * @param {{ grey: string }} commonColor - Palette from useCommonChartStyles.
 * @returns {object} Chart.js data with labels ['Yes', 'No'] and one dataset.
 */
const createChartData = (yesCount, noCount, commonColor) => {
  const responseTotal = yesCount + noCount;
  const yesShare = formatPercentage((yesCount / responseTotal) * 100);
  const sliceColors = [getColor(yesShare), commonColor.grey];

  return {
    labels: ['Yes', 'No'],
    datasets: [
      {
        data: [yesCount, noCount],
        backgroundColor: sliceColors,
        borderWidth: 0,
        borderRadius: 0,
      },
    ],
  };
};

/**
 * Builds the Chart.js `data` object for a point-scale doughnut.
 *
 * Slices are ordered from the highest point value to the lowest. Each slice
 * is coloured by `getPointScaleColor`, passing its point value and the number
 * of keys in `pointCounts`.
 *
 * @param {Object<string, number>} pointCounts - Count of observations per point value (keys are numeric strings).
 * @param {Array} observations - Accepted for interface compatibility; not read.
 * @returns {object} Chart.js data with the point values as labels and one dataset.
 */
const createSlidingScaleChartData = (pointCounts, observations) => {
  const labels = Object.keys(pointCounts).sort((left, right) => Number(right) - Number(left));
  const scaleSize = labels.length;

  const counts = [];
  const colors = [];
  labels.forEach((label) => {
    counts.push(pointCounts[label]);
    colors.push(getPointScaleColor(Number(label), scaleSize));
  });

  return {
    labels,
    datasets: [{ data: counts, backgroundColor: colors, borderWidth: 0, borderRadius: 0 }],
  };
};

const PointScaleChart = ({ numPoints, title, observations, admNode, showSummaryOnly = false }) => {
  const data = {};

  Array.from({ length: numPoints }, (_, i) => i + 1).forEach((pointValue) => {
    data[pointValue] = observationsWithResponseForAdmNodeByValue(observations, admNode, pointValue.toString()).length;
  });

  const chartData = createSlidingScaleChartData(data, observations);
  const totalCount = observationsWithResponseForAdmNode(observations, admNode).length;

  return (
    <Stack direction="column" alignItems="center" gap={2}>
      <Typography
        variant="h6"
        gutterBottom
        style={{
          overflow: 'hidden',
          textOverflow: 'ellipsis',
          whiteSpace: 'nowrap',
        }}
      >
        {title}
      </Typography>
      {totalCount === 0 && <div>No matching observations.</div>}
      {totalCount > 0 && (
        <>
          <Box sx={{ height: '150px', width: 'auto', maxWidth: '100%' }}>
            <Doughnut data={chartData} options={options} />
          </Box>
          <Box sx={{ width: '50%' }}>
            {chartData.labels.map((label, index) => (
              <React.Fragment key={index}>
                <Stack direction="row" spacing={1} justifyContent="space-between" alignItems="center">
                  <Stack direction="column">
                    <Typography color="textSecondary">{label}</Typography>
                    <Typography variant="body2">
                      <strong>
                        {chartData.datasets[0].data[index]} ({formatPercentage((chartData.datasets[0].data[index] / totalCount) * 100)}%)
                      </strong>
                    </Typography>
                  </Stack>
                  <LegendRectangle color={chartData.datasets[0].backgroundColor[index]} />
                </Stack>
                {index < chartData.labels.length - 1 && <Divider sx={{ marginTop: '6px', marginBottom: '6px' }} />}
              </React.Fragment>
            ))}
          </Box>
        </>
      )}
    </Stack>
  );
};

const SingleChart = ({ title, yesObservations, noObservations, commonColor }) => {
  const totalCount = yesObservations.concat(noObservations).length;
  const yesPercentage = formatPercentage((yesObservations.length / totalCount) * 100);
  const noPercentage = formatPercentage((noObservations.length / totalCount) * 100);
  const chartData = createChartData(yesObservations.length, noObservations.length, commonColor);

  return (
    <Stack direction="column" alignItems="center" gap={2}>
      <Typography
        variant="h6"
        gutterBottom
        style={{
          overflow: 'hidden',
          textOverflow: 'ellipsis',
          whiteSpace: 'nowrap',
        }}
      >
        {title}
      </Typography>
      {totalCount === 0 && <div>No matching observations.</div>}
      {totalCount > 0 && (
        <>
          <Box sx={{ height: '150px', width: 'auto', maxWidth: '100%' }}>
            <Doughnut data={chartData} options={options} />
          </Box>
          <Box sx={{ width: '50%' }}>
            <Stack direction="row" spacing={1} justifyContent="space-between" alignItems="center">
              <div>
                <Typography color="textSecondary">Yes</Typography>
                <Typography variant="body2">
                  <strong>
                    {yesObservations.length} ({yesPercentage}%)
                  </strong>
                </Typography>
              </div>
              <LegendRectangle color={chartData.datasets[0].backgroundColor[0]} />
            </Stack>
            <Divider sx={{ marginTop: '6px', marginBottom: '6px' }} />
            <Stack direction="row" spacing={1} justifyContent="space-between" alignItems="center">
              <div>
                <Typography color="textSecondary">No</Typography>
                <Typography variant="body2">
                  <strong>
                    {noObservations.length} ({noPercentage}%)
                  </strong>
                </Typography>
              </div>
              <LegendRectangle color={chartData.datasets[0].backgroundColor[1]} />
            </Stack>
          </Box>
        </>
      )}
    </Stack>
  );
};

const AdmNodeDonutChart = ({ admNode, observations, showHqimCharts, showSummaryOnly = false }) => {
  const { commonColor } = useCommonChartStyles();
  const isPointScale = isPointScaleQuestion(admNode);
  const numPoints = getNumPoints(admNode);

  const col1_observations = useMemo(() => {
    return showHqimCharts ? observations.filter((o) => o.hqim_in_use === true) : observations.filter((o) => o.content_area?.name === 'ELA');
  }, [showHqimCharts, observations]);

  const col2_observations = useMemo(() => {
    return showHqimCharts ? observations.filter((o) => o.hqim_in_use === false) : observations.filter((o) => o.content_area?.name === 'Math');
  }, [showHqimCharts, observations]);

  const yesObservations = useMemo(() => {
    return !isPointScale ? observationsWithResponseForAdmNodeByValue(observations, admNode, 'yes') : [];
  }, [observations, admNode, isPointScale]);

  const noObservations = useMemo(() => {
    return !isPointScale ? observationsWithResponseForAdmNodeByValue(observations, admNode, 'no') : [];
  }, [observations, admNode, isPointScale]);

  return (
    <Grid container wrap="wrap" spacing={2}>
      {isPointScale ? (
        <>
          <Grid item xs={12} sm={4} md={4}>
            <PointScaleChart
              title={showSummaryOnly ? 'All' : 'Total'}
              observations={observations}
              admNode={admNode}
              numPoints={numPoints}
              showSummaryOnly={showSummaryOnly}
            />
          </Grid>
          {!showSummaryOnly && (
            <>
              <Grid item xs={12} sm={4} md={4}>
                <PointScaleChart
                  title={showHqimCharts ? 'With HQIM' : 'ELA'}
                  observations={col1_observations}
                  admNode={admNode}
                  numPoints={numPoints}
                />
              </Grid>
              <Grid item xs={12} sm={4} md={4}>
                <PointScaleChart
                  title={showHqimCharts ? 'Without HQIM' : 'Math'}
                  observations={col2_observations}
                  admNode={admNode}
                  numPoints={numPoints}
                />
              </Grid>
            </>
          )}
        </>
      ) : (
        <>
          <Grid item xs={12} sm={4} md={4}>
            <SingleChart title="Total" yesObservations={yesObservations} noObservations={noObservations} commonColor={commonColor} />
          </Grid>
          {!showSummaryOnly && (
            <>
              <Grid item xs={12} sm={4} md={4}>
                <SingleChart
                  title={showHqimCharts ? 'With HQIM' : 'ELA'}
                  yesObservations={
                    showHqimCharts
                      ? yesObservations.filter((o) => o.hqim_in_use === true)
                      : yesObservations.filter((o) => o.content_area?.name === 'ELA')
                  }
                  noObservations={
                    showHqimCharts
                      ? noObservations.filter((o) => o.hqim_in_use === true)
                      : noObservations.filter((o) => o.content_area?.name === 'ELA')
                  }
                  commonColor={commonColor}
                />
              </Grid>
              <Grid item xs={12} sm={4} md={4}>
                <SingleChart
                  title={showHqimCharts ? 'Without HQIM' : 'Math'}
                  yesObservations={
                    showHqimCharts
                      ? yesObservations.filter((o) => o.hqim_in_use === false)
                      : yesObservations.filter((o) => o.content_area?.name === 'Math')
                  }
                  noObservations={
                    showHqimCharts
                      ? noObservations.filter((o) => o.hqim_in_use === false)
                      : noObservations.filter((o) => o.content_area?.name === 'Math')
                  }
                  commonColor={commonColor}
                />
              </Grid>
            </>
          )}
        </>
      )}
    </Grid>
  );
};

// Runtime prop validation (development builds) for every prop the component reads.
AdmNodeDonutChart.propTypes = {
  // NOTE(review): admNode is passed to isPointScaleQuestion/getNumPoints and the
  // observation filter helpers; it is assumed to be an object. Confirm against callers.
  admNode: PropTypes.object,
  observations: PropTypes.array.isRequired,
  showHqimCharts: PropTypes.bool,
  showSummaryOnly: PropTypes.bool,
};

export default AdmNodeDonutChart;
