# SPDX-License-Identifier: MIT
# Copyright (C) 2022 Max Bachmann
from __future__ import annotations

from rapidfuzz.fuzz import ratio
from rapidfuzz.process_cpp_impl import FLOAT32 as _FLOAT32
from rapidfuzz.process_cpp_impl import FLOAT64 as _FLOAT64
from rapidfuzz.process_cpp_impl import INT8 as _INT8
from rapidfuzz.process_cpp_impl import INT16 as _INT16
from rapidfuzz.process_cpp_impl import INT32 as _INT32
from rapidfuzz.process_cpp_impl import INT64 as _INT64
from rapidfuzz.process_cpp_impl import UINT8 as _UINT8
from rapidfuzz.process_cpp_impl import UINT16 as _UINT16
from rapidfuzz.process_cpp_impl import UINT32 as _UINT32
from rapidfuzz.process_cpp_impl import UINT64 as _UINT64
from rapidfuzz.process_cpp_impl import cdist as _cdist
from rapidfuzz.process_cpp_impl import cpdist as _cpdist
from rapidfuzz.process_cpp_impl import extract, extract_iter, extractOne

__all__ = ["extract", "extract_iter", "extractOne", "cdist", "cpdist"]


def _dtype_to_type_num(dtype):
    """Translate a NumPy dtype specification into the backend's type code.

    Anything ``numpy.dtype`` can interpret is accepted: scalar types such as
    ``np.int32``, dtype instances such as ``np.dtype("int32")`` and string
    names such as ``"int32"``. Equivalent spellings of the same dtype all map
    to the same code.

    Parameters
    ----------
    dtype : data-type or None
        Requested dtype. ``None`` is passed through unchanged.

    Returns
    -------
    type code or None
        The matching constant imported from ``rapidfuzz.process_cpp_impl``
        (one of ``_INT8`` ... ``_UINT64``, ``_FLOAT32``, ``_FLOAT64``), or
        ``None`` when ``dtype`` is ``None``.

    Raises
    ------
    TypeError
        If ``dtype`` cannot be interpreted as a dtype, or if it is not one of
        the ten integer/float types listed in this function.
    """
    import numpy as np

    # Must be handled before np.dtype(): np.dtype(None) yields float64, which
    # would silently turn "no dtype requested" into an explicit request.
    if dtype is None:
        return None

    msg = "unsupported dtype"
    try:
        requested = np.dtype(dtype)
    except (TypeError, ValueError):
        # Keep a single error type and message for every rejected input.
        raise TypeError(msg) from None

    supported = (
        (np.int32, _INT32),
        (np.int8, _INT8),
        (np.int16, _INT16),
        (np.int64, _INT64),
        (np.uint8, _UINT8),
        (np.uint16, _UINT16),
        (np.uint32, _UINT32),
        (np.uint64, _UINT64),
        (np.float32, _FLOAT32),
        (np.float64, _FLOAT64),
    )
    # Compare by dtype equality rather than object identity so that every
    # spelling of a supported dtype is recognised, not only the scalar type.
    for scalar_type, type_num in supported:
        if requested == np.dtype(scalar_type):
            return type_num

    raise TypeError(msg)


def cdist(
    queries,
    choices,
    *,
    scorer=ratio,
    processor=None,
    score_cutoff=None,
    score_hint=None,
    score_multiplier=1,
    dtype=None,
    workers=1,
    **kwargs,
):
    """Run the compiled ``_cdist`` routine and return its result as an array.

    ``dtype`` is first translated with ``_dtype_to_type_num``; every other
    argument is forwarded to ``_cdist`` unchanged. The object returned by
    ``_cdist`` is then converted with ``numpy.asarray``.

    NOTE(review): presumably this scores every query against every choice;
    confirm against the documentation of the compiled implementation.

    Parameters
    ----------
    queries, choices
        Forwarded positionally to ``_cdist``.
    scorer, processor, score_cutoff, score_hint, score_multiplier, workers
        Forwarded to ``_cdist`` as keyword arguments.
    dtype : data-type or None
        Converted by ``_dtype_to_type_num`` before being forwarded.
    **kwargs
        Any additional keyword arguments, forwarded to ``_cdist``.

    Returns
    -------
    numpy.ndarray
        ``numpy.asarray`` applied to the value returned by ``_cdist``.

    Raises
    ------
    TypeError
        If ``_dtype_to_type_num`` rejects ``dtype``.
    """
    import numpy as np

    type_num = _dtype_to_type_num(dtype)
    raw_result = _cdist(
        queries,
        choices,
        scorer=scorer,
        processor=processor,
        score_cutoff=score_cutoff,
        score_hint=score_hint,
        score_multiplier=score_multiplier,
        dtype=type_num,
        workers=workers,
        **kwargs,
    )
    return np.asarray(raw_result)


def cpdist(
    queries,
    choices,
    *,
    scorer=ratio,
    processor=None,
    score_cutoff=None,
    score_hint=None,
    score_multiplier=1,
    dtype=None,
    workers=1,
    **kwargs,
):
    """Run the compiled ``_cpdist`` routine and return its result as an array.

    ``dtype`` is first translated with ``_dtype_to_type_num``; every other
    argument is forwarded to ``_cpdist`` unchanged. The object returned by
    ``_cpdist`` is then converted with ``numpy.asarray``.

    NOTE(review): presumably this scores ``queries[i]`` against
    ``choices[i]`` pairwise; confirm against the documentation of the
    compiled implementation.

    Parameters
    ----------
    queries, choices
        Forwarded positionally to ``_cpdist``.
    scorer, processor, score_cutoff, score_hint, score_multiplier, workers
        Forwarded to ``_cpdist`` as keyword arguments.
    dtype : data-type or None
        Converted by ``_dtype_to_type_num`` before being forwarded.
    **kwargs
        Any additional keyword arguments, forwarded to ``_cpdist``.

    Returns
    -------
    numpy.ndarray
        ``numpy.asarray`` applied to the value returned by ``_cpdist``.

    Raises
    ------
    TypeError
        If ``_dtype_to_type_num`` rejects ``dtype``.
    """
    import numpy as np

    backend_dtype = _dtype_to_type_num(dtype)
    return np.asarray(
        _cpdist(
            queries,
            choices,
            scorer=scorer,
            processor=processor,
            score_cutoff=score_cutoff,
            score_hint=score_hint,
            score_multiplier=score_multiplier,
            dtype=backend_dtype,
            workers=workers,
            **kwargs,
        )
    )
