# Copyright (c) 2009-2015 testtools developers. See LICENSE for details.

"""Matchers that operate with knowledge of Python data structures."""

from ..helpers import map_values
from ._higherorder import (
    Annotate,
    MatchesAll,
    MismatchesAll,
)
from ._impl import Mismatch

# Public API of this module: ContainsAll is a matcher factory function; the
# other three names are matcher classes defined below.
__all__ = [
    "ContainsAll",
    "MatchesListwise",
    "MatchesSetwise",
    "MatchesStructure",
]


def ContainsAll(items):
    """Return a matcher requiring every element of ``items`` to be in the matchee.

    Put differently, the matcher checks that ``items`` is a subset of the
    matchee. Each item is wrapped in its own `Contains` matcher and the
    results are combined with `MatchesAll` using ``first_only=False``, so
    presumably every missing item is reported rather than just the first.

    :param items: Iterable of things that must all be contained in the matchee.
    """
    from ._basic import Contains

    contains_each = [Contains(item) for item in items]
    return MatchesAll(*contains_each, first_only=False)


class MatchesListwise:
    """Matches if each matcher matches the corresponding value.

    More easily explained by example than in words:

    >>> from ._basic import Equals
    >>> MatchesListwise([Equals(1)]).match([1])
    >>> MatchesListwise([Equals(1), Equals(2)]).match([1, 2])
    >>> print(MatchesListwise([Equals(1), Equals(2)]).match([2, 1]).describe())
    Differences: [
    2 != 1
    1 != 2
    ]
    >>> matcher = MatchesListwise([Equals(1), Equals(2)], first_only=True)
    >>> print(matcher.match([3, 4]).describe())
    3 != 1
    """

    def __init__(self, matchers, first_only=False):
        """Construct a MatchesListwise matcher.

        :param matchers: A list of matchers that the matched values must match.
        :param first_only: If True, then only report the first mismatch,
            otherwise report all of them. Defaults to False.
        """
        self.matchers = matchers
        self.first_only = first_only

    def __str__(self):
        """Describe this matcher, e.g. ``MatchesListwise([Equals(1), Equals(2)])``.

        Sub-matchers are rendered with ``str()`` (as MatchesStructure does),
        because ``str()`` of a list would use each matcher's ``repr()``.
        """
        rendered = "[{}]".format(", ".join(map(str, self.matchers)))
        if self.first_only:
            return f"{self.__class__.__name__}({rendered}, first_only=True)"
        return f"{self.__class__.__name__}({rendered})"

    def match(self, values):
        """Match ``values`` element-wise against ``self.matchers``.

        :param values: The sequence to check, one value per matcher. Its
            length is checked with `HasLength` (presumably needs ``len()``).
        :return: None if the lengths agree and every value matches its
            matcher. Otherwise, without ``first_only``, a `MismatchesAll`
            collecting the length mismatch (if any) followed by each element
            mismatch. With ``first_only``, the first failing element's own
            mismatch is returned directly; a length mismatch alone does not
            short-circuit and is only returned (wrapped) if no element fails.
        """
        from ._basic import HasLength

        mismatches = []
        length_mismatch = Annotate(
            "Length mismatch", HasLength(len(self.matchers))
        ).match(values)
        if length_mismatch:
            mismatches.append(length_mismatch)
        # zip() stops at the shorter side, so surplus values or matchers are
        # only reported through the length mismatch above.
        for matcher, value in zip(self.matchers, values):
            mismatch = matcher.match(value)
            if mismatch:
                if self.first_only:
                    return mismatch
                mismatches.append(mismatch)
        if mismatches:
            return MismatchesAll(mismatches)


class MatchesStructure:
    """Match an object by comparing selected attributes against matchers.

    Every keyword given to the constructor names an attribute of the matchee;
    the keyword's value is the matcher that attribute must satisfy.

    Alternate constructors:

    * `byEquality` wraps every keyword value in `Equals`.
    * `byMatcher` wraps every keyword value with a matcher factory of your
      choosing instead of `Equals`.
    * `fromExample` captures named attributes from a prototype object and
      requires equality with them; `update` then derives modified copies.
    """

    def __init__(self, **kwargs):
        """Construct a `MatchesStructure`.

        :param kwargs: Attribute names mapped to the matchers their values
            must satisfy.
        """
        self.kws = kwargs

    @classmethod
    def byEquality(cls, **kwargs):
        """Match an object whose attributes equal the given keyword values.

        Shorthand for ``byMatcher(Equals, **kwargs)``.
        """
        from ._basic import Equals

        return cls.byMatcher(Equals, **kwargs)

    @classmethod
    def byMatcher(cls, matcher, **kwargs):
        """Match an object whose attributes satisfy ``matcher(keyword_value)``.

        :param matcher: Callable applied to each keyword value to build the
            matcher for that attribute.
        """
        per_attribute = map_values(matcher, kwargs)
        return cls(**per_attribute)

    @classmethod
    def fromExample(cls, example, *attributes):
        """Require ``attributes`` to equal their current values on ``example``.

        :param example: Prototype object; its attribute values are read once,
            when this method is called.
        :param attributes: Names of the attributes to compare.
        """
        from ._basic import Equals

        expected = {name: Equals(getattr(example, name)) for name in attributes}
        return cls(**expected)

    def update(self, **kws):
        """Return a new matcher with some attribute matchers replaced.

        :param kws: Attribute names mapped to new matchers. A value of
            ``None`` removes that attribute (if present) instead.
        :return: A new instance of the same class; ``self`` is left unchanged.
        """
        merged = dict(self.kws)
        for name, replacement in kws.items():
            if replacement is None:
                merged.pop(name, None)
                continue
            merged[name] = replacement
        return type(self)(**merged)

    def __str__(self):
        rendered = ", ".join(
            f"{name}={matcher}" for name, matcher in sorted(self.kws.items())
        )
        return f"{type(self).__name__}({rendered})"

    def match(self, value):
        """Compare each named attribute of ``value`` against its matcher.

        Attributes are checked in sorted name order, and each mismatch is
        annotated with the attribute's name. ``getattr`` is called without a
        default, so a missing attribute raises AttributeError.
        """
        ordered = sorted(self.kws.items())
        annotated = [Annotate(name, matcher) for name, matcher in ordered]
        actual = [getattr(value, name) for name, _ in ordered]
        return MatchesListwise(annotated).match(actual)


class MatchesSetwise:
    """Matches if all the matchers match elements of the value being matched.

    That is, each element in the 'observed' set must match exactly one matcher
    from the set of matchers, with no matchers left over.

    The difference compared to `MatchesListwise` is that the order of the
    matchings does not matter. When a value is accepted by several matchers,
    values and matchers are paired so that as many values as possible get a
    matcher of their own. A match is therefore not missed just because an
    earlier value took a matcher that a later value needed.
    """

    def __init__(self, *matchers):
        """Construct a MatchesSetwise matcher.

        :param matchers: The matchers to pair with the observed elements.
            Each matcher is used at most once, so passing the same matcher
            twice requires two elements that it matches.
        """
        self.matchers = matchers

    def __str__(self):
        return "{}({})".format(
            self.__class__.__name__, ", ".join(map(str, self.matchers))
        )

    def _pair_up(self, values):
        """Pair as many ``values`` as possible with distinct matchers.

        Computes a maximum bipartite matching using augmenting paths (Kuhn's
        algorithm). A value may be paired with any matcher whose ``match``
        returns None for it. The pairing is deterministic, and each
        (value, matcher) pair is evaluated at most once.

        :param values: A list of observed values.
        :return: ``(not_matched, remaining_matchers)``: the values and the
            matchers left unpaired, each in its original order.
        """
        matchers = self.matchers
        # owner[j] is the index of the value currently paired with matchers[j].
        owner = [None] * len(matchers)
        verdicts = {}

        def accepts(i, j):
            # Cache so matcher.match is called at most once per pair.
            if (i, j) not in verdicts:
                verdicts[i, j] = matchers[j].match(values[i]) is None
            return verdicts[i, j]

        def augment(i, visited):
            # Cheap path first: a free matcher that accepts values[i]. This
            # alone behaves like a plain first-fit pass.
            for j in range(len(matchers)):
                if owner[j] is None and j not in visited and accepts(i, j):
                    visited.add(j)
                    owner[j] = i
                    return True
            # Otherwise, try to move the value holding an accepting matcher
            # onto some other matcher. Recursion depth is bounded by the
            # number of matchers, since each level marks one as visited.
            for j in range(len(matchers)):
                if owner[j] is not None and j not in visited and accepts(i, j):
                    visited.add(j)
                    if augment(owner[j], visited):
                        owner[j] = i
                        return True
            return False

        for i in range(len(values)):
            augment(i, set())
        paired = {i for i in owner if i is not None}
        not_matched = [value for i, value in enumerate(values) if i not in paired]
        remaining = [m for j, m in enumerate(matchers) if owner[j] is None]
        return not_matched, remaining

    def match(self, observed):
        """Match ``observed`` against the matchers, ignoring order.

        :param observed: An iterable of values; it is consumed once.
        :return: None if values and matchers can be paired one-to-one with
            every value matching its matcher. Otherwise, a mismatch describing
            the values and/or matchers left unpaired.
        """
        not_matched, remaining_matchers_list = self._pair_up(list(observed))
        if not not_matched and not remaining_matchers_list:
            return None

        # There are various cases that all should be reported somewhat
        # differently.

        # There are two trivial cases:
        # 1) There are just some matchers left over.
        # 2) There are just some values left over.

        # Then there are three more interesting cases:
        # 3) There are the same number of matchers and values left over.
        # 4) There are more matchers left over than values.
        # 5) There are more values left over than matchers.

        if len(not_matched) == 0:
            if len(remaining_matchers_list) > 1:
                count = len(remaining_matchers_list)
                msg = f"There were {count} matchers left over: "
            else:
                msg = "There was 1 matcher left over: "
            msg += ", ".join(map(str, remaining_matchers_list))
            return Mismatch(msg)
        elif len(remaining_matchers_list) == 0:
            if len(not_matched) > 1:
                return Mismatch(
                    f"There were {len(not_matched)} values left over: {not_matched}"
                )
            else:
                return Mismatch(f"There was 1 value left over: {not_matched}")
        else:
            common_length = min(len(remaining_matchers_list), len(not_matched))
            if common_length == 0:
                raise AssertionError("common_length can't be 0 here")
            if common_length > 1:
                msg = f"There were {common_length} mismatches"
            else:
                msg = "There was 1 mismatch"
            if len(remaining_matchers_list) > len(not_matched):
                extra_matchers = remaining_matchers_list[common_length:]
                msg += f" and {len(extra_matchers)} extra matcher"
                if len(extra_matchers) > 1:
                    msg += "s"
                msg += ": " + ", ".join(map(str, extra_matchers))
            elif len(not_matched) > len(remaining_matchers_list):
                extra_values = not_matched[common_length:]
                msg += f" and {len(extra_values)} extra value"
                if len(extra_values) > 1:
                    msg += "s"
                msg += ": " + str(extra_values)
            # Pair the leftovers positionally so each gets a detailed report.
            return Annotate(
                msg, MatchesListwise(remaining_matchers_list[:common_length])
            ).match(not_matched[:common_length])
