import json
import re

import httpx
import pytest
import respx

from cluster_agent.jobbergate.finish import fetch_job_status, finish_active_jobs
from cluster_agent.jobbergate.constants import JobSubmissionStatus
from cluster_agent.utils.exception import SlurmrestdError
from cluster_agent.settings import SETTINGS


@pytest.mark.asyncio
async def test_fetch_pending_submissions__success():
    """
    Verify that ``fetch_job_status()`` reads the ``job_state`` reported by Slurm
    and translates it into the matching ``JobSubmissionStatus``.

    A Slurm state without an explicit mapping is expected to come back as
    ``JobSubmissionStatus.SUBMITTED``.
    """
    # job_state reported by the mocked Slurm REST API, keyed by slurm job id.
    reported_states = {
        11: "COMPLETED",
        22: "FAILED",
        33: "UNMAPPED_STATUS",
    }
    # Status ``fetch_job_status()`` should return for each slurm job id.
    expected_statuses = {
        11: JobSubmissionStatus.COMPLETED,
        22: JobSubmissionStatus.FAILED,
        33: JobSubmissionStatus.SUBMITTED,
    }

    async with respx.mock:
        respx.post(f"https://{SETTINGS.AUTH0_DOMAIN}/oauth/token").mock(
            return_value=httpx.Response(
                status_code=200, json=dict(access_token="dummy-token")
            )
        )

        # Register one canned single-job response per slurm job id.
        for slurm_id, state in reported_states.items():
            job_url = f"{SETTINGS.BASE_SLURMRESTD_URL}/slurm/v0.0.36/job/{slurm_id}"
            payload = dict(jobs=[dict(job_state=state)])
            respx.get(job_url).mock(
                return_value=httpx.Response(status_code=200, json=payload)
            )

        for slurm_id, expected in expected_statuses.items():
            assert await fetch_job_status(slurm_id) == expected


@pytest.mark.asyncio
async def test_fetch_pending_submissions__raises_SlurmrestdError_if_response_is_not_200(
):
    """
    Verify that ``fetch_job_status()`` raises ``SlurmrestdError`` when the Slurm
    REST API answers with a non-200 status code (a 400 in this scenario).
    """
    async with respx.mock:
        token_url = f"https://{SETTINGS.AUTH0_DOMAIN}/oauth/token"
        job_url = f"{SETTINGS.BASE_SLURMRESTD_URL}/slurm/v0.0.36/job/11"

        token_response = httpx.Response(
            status_code=200, json=dict(access_token="dummy-token")
        )
        respx.post(token_url).mock(return_value=token_response)
        # Slurm rejects the lookup for job 11.
        respx.get(job_url).mock(return_value=httpx.Response(status_code=400))

        expected_message = "Failed to fetch job status from slurm"
        with pytest.raises(SlurmrestdError, match=expected_message):
            await fetch_job_status(11)


@pytest.mark.asyncio
async def test_finish_active_jobs():
    """
    Test that the ``finish_active_jobs()`` function can fetch active job submissions,
    retrieve the job state from slurm, map it to a ``JobSubmissionStatus``, and update
    the job submission status via the API.

    Each active submission exercises a different path:

    * id=1 / slurm 11: Slurm reports COMPLETED and the API update succeeds.
    * id=2 / slurm 22: Slurm reports FAILED; the Jobbergate API rejects the update
      with a 400.
    * id=3 / slurm 33: the Slurm REST API answers with a 400; no update expected.
    * id=4 / slurm 44: Slurm returns an empty ``jobs`` list; no update expected.
    * id=5 / slurm 55: Slurm reports an unmapped state; no update expected.

    ``finish_active_jobs()`` is awaited without ``pytest.raises``, so these
    per-job failures must not propagate out of it.
    """
    active_job_submissions_data = [
        dict(id=1, slurm_job_id=11),  # Will complete
        dict(id=2, slurm_job_id=22),  # Jobbergate API gives a 400
        dict(id=3, slurm_job_id=33),  # Slurm REST API gives a 400
        dict(id=4, slurm_job_id=44),  # Slurm has no matching job
        dict(id=5, slurm_job_id=55),  # Unmapped status
    ]

    # job_state reported by the mocked Slurm REST API, keyed by slurm job id.
    slurm_job_states = {
        11: "COMPLETED",
        22: "FAILED",
        33: "COMPLETED",
        44: "COMPLETED",
        55: "UNMAPPED_STATUS",
    }

    def _trailing_id(request: httpx.Request) -> int:
        """Return the integer id in the last path segment of ``request``'s URL."""
        return int(request.url.path.split("/")[-1])

    async with respx.mock:
        respx.post(f"https://{SETTINGS.AUTH0_DOMAIN}/oauth/token").mock(
            return_value=httpx.Response(
                status_code=200, json=dict(access_token="dummy-token")
            )
        )
        fetch_route = respx.get(
            f"{SETTINGS.BASE_API_URL}/jobbergate/job-submissions/agent/active"
        )
        fetch_route.mock(
            return_value=httpx.Response(
                status_code=200,
                json=active_job_submissions_data,
            )
        )

        def _map_slurm_request(request: httpx.Request) -> httpx.Response:
            slurm_job_id = _trailing_id(request)
            # 44 simulates Slurm knowing nothing about the job (empty list);
            # 33 simulates the Slurm REST API itself failing with a 400.
            jobs = (
                []
                if slurm_job_id == 44
                else [dict(job_state=slurm_job_states[slurm_job_id])]
            )
            return httpx.Response(
                status_code=400 if slurm_job_id == 33 else 200,
                json=dict(jobs=jobs),
            )

        # The base URLs are interpolated into regexes, so escape them (and the
        # literal version dots) to make any metacharacters match literally.
        slurm_route = respx.get(
            url__regex=re.escape(SETTINGS.BASE_SLURMRESTD_URL)
            + r"/slurm/v0\.0\.36/job/\d+"
        )
        slurm_route.mock(side_effect=_map_slurm_request)

        def _map_update_request(request: httpx.Request) -> httpx.Response:
            # Submission 2 simulates the Jobbergate API rejecting the update.
            job_submission_id = _trailing_id(request)
            return httpx.Response(status_code=400 if job_submission_id == 2 else 200)

        update_route = respx.put(
            url__regex=re.escape(SETTINGS.BASE_API_URL)
            + r"/jobbergate/job-submissions/agent/\d+"
        )
        update_route.mock(side_effect=_map_update_request)

        await finish_active_jobs()

        assert fetch_route.call_count == 1

        # Every active submission triggers exactly one Slurm lookup, in order.
        assert slurm_route.call_count == 5
        assert [_trailing_id(c.request) for c in slurm_route.calls] == [
            11,
            22,
            33,
            44,
            55,
        ]

        # Only submissions 1 and 2 get an update; 2's is attempted even though
        # the API rejects it.
        assert update_route.call_count == 2
        assert [
            (_trailing_id(c.request), json.loads(c.request.content)["new_status"])
            for c in update_route.calls
        ] == [
            (1, JobSubmissionStatus.COMPLETED),
            (2, JobSubmissionStatus.FAILED),
        ]
