from __future__ import annotations

import ast
from abc import ABC, abstractmethod
import typing

if typing.TYPE_CHECKING:
    from collections.abc import Collection, Iterable


class SetDefinitions:
    """ A collection of set definitions, where each set is defined by an id, a
    name, its supersets, and the sets that are disjoint with it.  This object
    is used as a factory to create set expressions, which are combinations of
    named sets with union, intersection and complement.
    """
    __slots__ = ('__leaves',)

    def __init__(self, definitions: dict[int, dict]):
        """ Initialize the object with ``definitions``, a dict which maps each
        set id to a dict with optional keys ``"ref"`` (value is the set's name),
        ``"supersets"`` (value is a collection of set ids), and ``"disjoints"``
        (value is a collection of set ids).

        A set without ``"ref"`` can only be looked up by its id; its leaf is
        then labelled with ``str(id)``.

        Here is an example of set definitions, with natural numbers (N), integer
        numbers (Z), rational numbers (Q), real numbers (R), imaginary numbers
        (I) and complex numbers (C)::

            {
                1: {"ref": "N", "supersets": [2]},
                2: {"ref": "Z", "supersets": [3]},
                3: {"ref": "Q", "supersets": [4]},
                4: {"ref": "R", "supersets": [6]},
                5: {"ref": "I", "supersets": [6], "disjoints": [4]},
                6: {"ref": "C"},
            }

        The subset/superset relation is closed transitively, and all the
        subsets of two disjoint sets are made disjoint with each other.
        """
        # leaves indexed both by id and by reference (when a reference is given)
        self.__leaves: dict[int | str, Leaf] = {}

        for leaf_id, info in definitions.items():
            ref = info.get('ref')
            assert ref != '*', "The set reference '*' is reserved for the universal set."
            leaf = Leaf(leaf_id, ref)
            self.__leaves[leaf_id] = leaf
            if ref is not None:
                self.__leaves[ref] = leaf

        # compute transitive closure of subsets and supersets; these dicts hold
        # the leaves' own set objects, so updating them updates the leaves
        subsets = {leaf.id: leaf.subsets for leaf in self.__leaves.values()}
        supersets = {leaf.id: leaf.supersets for leaf in self.__leaves.values()}
        for leaf_id, info in definitions.items():
            for greater_id in info.get('supersets', ()):
                # transitive closure: smaller_ids <= leaf_id <= greater_id <= greater_ids
                smaller_ids = subsets[leaf_id]
                greater_ids = supersets[greater_id]
                for smaller_id in smaller_ids:
                    supersets[smaller_id].update(greater_ids)
                for ancestor_id in greater_ids:
                    subsets[ancestor_id].update(smaller_ids)

        # compute transitive closure of disjoint relation
        disjoints = {leaf.id: leaf.disjoints for leaf in self.__leaves.values()}
        for leaf_id, info in definitions.items():
            for distinct_id in info.get('disjoints', set()):
                # all subsets[leaf_id] are disjoint from all subsets[distinct_id]
                left_ids = subsets[leaf_id]
                right_ids = subsets[distinct_id]
                for left_id in left_ids:
                    disjoints[left_id].update(right_ids)
                for right_id in right_ids:
                    disjoints[right_id].update(left_ids)

    @property
    def empty(self) -> SetExpression:
        """ The empty set expression. """
        return EMPTY_UNION

    @property
    def universe(self) -> SetExpression:
        """ The universal set expression. """
        return UNIVERSAL_UNION

    def parse(self, refs: str, raise_if_not_found: bool = True) -> SetExpression:
        """ Return the set expression corresponding to ``refs``.

        :param str refs: comma-separated list of set references, each
            optionally preceded by ``!`` (negative item); ``*`` denotes the
            universal set.  The result is the union of the positive items,
            each of them intersected with the complement of every negative
            item.  Without positive items, the result is the intersection of
            the complements of the negative items.
            (e.g. ``base.group_user,base.group_portal,!base.group_system``)
        :param bool raise_if_not_found: when false, unknown references become
            leaves with an :class:`UnknownId` id instead of raising ``KeyError``
        """
        positives: list[Leaf] = []
        negatives: list[Leaf] = []
        for xmlid in refs.split(','):
            if xmlid.startswith('!'):
                negatives.append(~self.__get_leaf(xmlid[1:], raise_if_not_found))
            else:
                positives.append(self.__get_leaf(xmlid, raise_if_not_found))

        if positives:
            return Union(Inter([leaf] + negatives) for leaf in positives)
        else:
            return Union([Inter(negatives)])

    def from_ids(self, ids: Iterable[int], keep_subsets: bool = False) -> SetExpression:
        """ Return the set expression corresponding to the union of the given set ids.

        :param ids: ids of defined sets (an unknown id raises ``KeyError``)
        :param keep_subsets: when true, drop every id that has a strict subset
            among ``ids``, so that only the smallest given sets are kept
        """
        if keep_subsets:
            id_set = set(ids)
            # ``isdisjoint`` instead of ``any(...)``: ``any`` over a set of ids
            # would overlook a falsy id such as 0
            ids = [
                leaf_id for leaf_id in id_set
                if id_set.isdisjoint(self.__leaves[leaf_id].subsets - {leaf_id})
            ]
        return Union(Inter([self.__leaves[leaf_id]]) for leaf_id in ids)

    def from_key(self, key: str) -> SetExpression:
        """ Return the set expression corresponding to the given key.

        ``key`` is parsed with :func:`ast.literal_eval` (no code execution);
        ids that are not defined become leaves with an :class:`UnknownId` id.
        """
        # union_tuple = tuple(tuple(tuple(leaf_id, negative), ...), ...)
        union_tuple = ast.literal_eval(key)
        return Union([
            Inter([
                ~leaf if negative else leaf
                for leaf_id, negative in inter_tuple
                for leaf in [self.__get_leaf(leaf_id, raise_if_not_found=False)]
            ], optimal=True)
            for inter_tuple in union_tuple
        ], optimal=True)

    def get_id(self, ref: LeafIdType) -> LeafIdType | None:
        """ Return a set id from its reference, or ``None`` if it does not exist. """
        if ref == '*':
            return UNIVERSAL_LEAF.id
        leaf = self.__leaves.get(ref)
        return None if leaf is None else leaf.id

    def __get_leaf(self, ref: str | int, raise_if_not_found: bool = True) -> Leaf:
        """ Return the leaf corresponding to ``ref``.

        :param ref: a set id or a set reference; ``'*'`` is the universal set
        :param raise_if_not_found: when false, return a new leaf with an
            :class:`UnknownId` id for an unknown ``ref`` instead of raising
            ``KeyError``
        """
        if ref == '*':
            return UNIVERSAL_LEAF
        if not raise_if_not_found and ref not in self.__leaves:
            return Leaf(UnknownId(ref), ref)
        return self.__leaves[ref]


class SetExpression(ABC):
    """ Abstract interface of a set expression: named sets combined with
    union (``|``), intersection (``&``) and complement (``~``).

    Concrete implementations must provide every method below.
    """

    # --- predicates ---------------------------------------------------------

    @abstractmethod
    def is_empty(self) -> bool:
        """ Tell whether ``self`` denotes the empty set (no element at all). """
        raise NotImplementedError

    @abstractmethod
    def is_universal(self) -> bool:
        """ Tell whether ``self`` denotes the universal set (every element). """
        raise NotImplementedError

    @abstractmethod
    def matches(self, user_group_ids: Iterable[int]) -> bool:
        """ Tell whether an element belonging to exactly the sets
        ``user_group_ids`` is included in ``self``.
        """
        raise NotImplementedError

    # --- identity -----------------------------------------------------------

    @property
    @abstractmethod
    def key(self) -> str:
        """ A string that uniquely identifies the expression. """
        raise NotImplementedError

    @abstractmethod
    def __eq__(self, other) -> bool:
        raise NotImplementedError

    @abstractmethod
    def __hash__(self):
        raise NotImplementedError

    # --- algebra ------------------------------------------------------------

    @abstractmethod
    def invert_intersect(self, factor: SetExpression) -> SetExpression | None:
        """ Factor ``factor`` out of ``self``: return ``result`` such that
        ``self == result & factor``, or ``None`` when no such result is found.
        """
        raise NotImplementedError

    @abstractmethod
    def __and__(self, other: SetExpression) -> SetExpression:
        """ Intersection. """
        raise NotImplementedError

    @abstractmethod
    def __or__(self, other: SetExpression) -> SetExpression:
        """ Union. """
        raise NotImplementedError

    @abstractmethod
    def __invert__(self) -> SetExpression:
        """ Complement. """
        raise NotImplementedError

    # --- inclusion ----------------------------------------------------------

    @abstractmethod
    def __le__(self, other: SetExpression) -> bool:
        """ Inclusion. """
        raise NotImplementedError

    @abstractmethod
    def __lt__(self, other: SetExpression) -> bool:
        """ Strict inclusion. """
        raise NotImplementedError


class Union(SetExpression):
    """ Implementation of a set expression, that represents it as a union of
    intersections of named sets or their complement.

    The intersections are kept sorted by key, and ``key`` (the string form of
    the tuple of intersection keys) identifies the expression: equality and
    hashing are based on it.  No method modifies an instance after
    construction; operators return new instances or shared constants.
    """
    def __init__(self, inters: Iterable[Inter] = (), optimal=False):
        """ Build the union of ``inters``.

        :param inters: the intersections to unite
        :param optimal: when true, ``inters`` is only sorted, not simplified;
            the caller guarantees it is already in simplified form
        """
        if inters and not optimal:
            inters = self.__combine((), inters)
        self.__inters = sorted(inters, key=lambda inter: inter.key)
        self.__key = str(tuple(inter.key for inter in self.__inters))
        self.__hash = hash(self.__key)

    @property
    def key(self) -> str:
        return self.__key

    @staticmethod
    def __combine(inters: Iterable[Inter], inters_to_add: Iterable[Inter]) -> list[Inter]:
        """ Combine some existing union of intersections with extra intersections.

        A universal intersection makes the whole union universal; empty ones
        are dropped.  Two intersections whose union is itself an intersection
        (see ``Inter._union_merge``) are replaced by that merge.  The merge is
        put back in the queue so that it may merge further.
        """
        result = list(inters)

        todo = list(inters_to_add)
        while todo:
            inter_to_add = todo.pop()
            if inter_to_add.is_universal():
                return [UNIVERSAL_INTER]
            if inter_to_add.is_empty():
                continue

            for index, inter in enumerate(result):
                merged = inter._union_merge(inter_to_add)
                if merged is not None:
                    # replace ``inter`` by the merge, which is re-examined
                    # against the remaining intersections
                    result.pop(index)
                    todo.append(merged)
                    break
            else:
                result.append(inter_to_add)

        return result

    def is_empty(self) -> bool:
        """ Returns whether ``self`` is the empty set, that contains nothing. """
        return not self.__inters

    def is_universal(self) -> bool:
        """ Returns whether ``self`` is the universal set, that contains all possible elements. """
        return any(item.is_universal() for item in self.__inters)

    def invert_intersect(self, factor: SetExpression) -> Union | None:
        """ Performs the inverse operation of intersection (a sort of factorization)
        such that: ``self == result & factor``.

        By De Morgan, ``self == R & factor`` means ``~self == ~R | ~factor``.
        So when every intersection of ``~factor`` appears in ``~self``, the
        remaining intersections form ``~R``.  Returns ``None`` when this
        syntactic factorization is not possible.
        """
        if factor == self:
            return UNIVERSAL_UNION

        rfactor = ~factor
        if rfactor.is_empty() or rfactor.is_universal():
            return None
        rself = ~self

        assert isinstance(rfactor, Union)
        inters = [inter for inter in rself.__inters if inter not in rfactor.__inters]
        if len(rself.__inters) - len(inters) != len(rfactor.__inters):
            # not possible to invert the intersection
            return None

        rself_value = Union(inters)
        return ~rself_value

    def __and__(self, other: SetExpression) -> Union:
        """ Intersection, distributed over the unions:
        ``(A | B) & (C | D) == (A & C) | (A & D) | (B & C) | (B & D)``.
        """
        assert isinstance(other, Union)
        if self.is_universal():
            return other
        if other.is_universal():
            return self
        if self.is_empty() or other.is_empty():
            return EMPTY_UNION
        if self == other:
            return self
        return Union(
            self_inter & other_inter
            for self_inter in self.__inters
            for other_inter in other.__inters
        )

    def __or__(self, other: SetExpression) -> Union:
        """ Union of both expressions, simplified with ``__combine``. """
        assert isinstance(other, Union)
        if self.is_empty():
            return other
        if other.is_empty():
            return self
        if self.is_universal() or other.is_universal():
            return UNIVERSAL_UNION
        if self == other:
            return self
        inters = self.__combine(self.__inters, other.__inters)
        return Union(inters, optimal=True)

    def __invert__(self) -> Union:
        """ Complement, computed with De Morgan's laws. """
        if self.is_empty():
            return UNIVERSAL_UNION
        if self.is_universal():
            return EMPTY_UNION

        # apply De Morgan's laws
        inverses_of_inters = [
            # ~(A & B) = ~A | ~B
            Union(Inter([~leaf]) for leaf in inter.leaves)
            for inter in self.__inters
        ]
        result = inverses_of_inters[0]
        # ~(A | B) = ~A & ~B
        for inverse in inverses_of_inters[1:]:
            result = result & inverse

        return result

    def matches(self, user_group_ids) -> bool:
        """ Return whether an element in exactly the sets ``user_group_ids``
        belongs to ``self``.

        NOTE(review): a falsy ``user_group_ids`` (e.g. an empty list or set)
        returns False even for the universe or a pure complement such as
        ``~A``.  An empty *iterator* is truthy and is evaluated normally.
        Confirm the short-circuit is the intended policy.
        """
        if self.is_empty() or not user_group_ids:
            return False
        if self.is_universal():
            return True
        user_group_ids = set(user_group_ids)
        return any(inter.matches(user_group_ids) for inter in self.__inters)

    def __bool__(self):
        # Truth-testing an expression is rejected, presumably to force the
        # explicit use of is_empty() / is_universal().
        raise NotImplementedError()

    def __eq__(self, other) -> bool:
        return isinstance(other, Union) and self.__key == other.__key

    def __le__(self, other: SetExpression) -> bool:
        """ Inclusion test.

        This is a sufficient condition: every intersection of ``self`` must be
        included in a single intersection of ``other``.  An inclusion that
        needs several intersections of ``other`` to cover one intersection of
        ``self`` is not detected.
        """
        if not isinstance(other, Union):
            return False
        if self.__key == other.__key:
            return True
        if self.is_universal() or other.is_empty():
            return False
        if other.is_universal() or self.is_empty():
            return True
        return all(
            any(self_inter <= other_inter for other_inter in other.__inters)
            for self_inter in self.__inters
        )

    def __lt__(self, other: SetExpression) -> bool:
        return self != other and self.__le__(other)

    def __str__(self):
        """ Returns an intersection union representation of groups using user-readable references.

            e.g. ('base.group_user' & 'base.group_multi_company') | ('base.group_portal' & ~'base.group_multi_company') | 'base.group_public'
        """
        if self.is_empty():
            return "~*"

        def leaf_to_str(leaf):
            return f"{'~' if leaf.negative else ''}{leaf.ref!r}"

        def inter_to_str(inter, wrapped=False):
            result = " & ".join(leaf_to_str(leaf) for leaf in inter.leaves) or "*"
            return f"({result})" if wrapped and len(inter.leaves) > 1 else result

        wrapped = len(self.__inters) > 1
        return " | ".join(inter_to_str(inter, wrapped) for inter in self.__inters)

    def __repr__(self):
        return repr(self.__str__())

    def __hash__(self):
        return self.__hash


class Inter:
    """ Part of the implementation of a set expression, that represents an
    intersection of named sets or their complement.

    ``leaves`` is sorted by leaf key and ``key`` is the tuple of the leaf
    keys; equality and hashing are based on ``key``.  An intersection of no
    leaves is the universal set.
    """
    __slots__ = ('key', 'leaves')

    def __init__(self, leaves: Iterable[Leaf] = (), optimal=False):
        # ``optimal=True``: the caller guarantees ``leaves`` is already
        # simplified, so the combination step is skipped
        if leaves and not optimal:
            leaves = self.__combine((), leaves)
        self.leaves: list[Leaf] = sorted(leaves, key=lambda leaf: leaf.key)
        self.key: tuple[tuple[LeafIdType, bool], ...] = tuple(leaf.key for leaf in self.leaves)

    @staticmethod
    def __combine(leaves: Iterable[Leaf], leaves_to_add: Iterable[Leaf]) -> list[Leaf]:
        """ Combine some existing intersection of leaves with extra leaves.

        Each added leaf is compared with *all* current leaves:

        - if it is disjoint from one of them, the intersection is empty;
        - if one of them is included in it, it adds nothing;
        - every current leaf that includes it becomes redundant and is dropped.

        The result is unordered; a universal leaf is never kept.
        """
        result = list(leaves)
        for leaf_to_add in leaves_to_add:
            kept = []
            for leaf in result:
                if leaf.isdisjoint(leaf_to_add):  # leaf & leaf_to_add = empty
                    return [EMPTY_LEAF]
                if leaf <= leaf_to_add:  # leaf & leaf_to_add = leaf
                    break
                if not (leaf_to_add <= leaf):  # else leaf & leaf_to_add = leaf_to_add
                    kept.append(leaf)
            else:
                # Scanning goes on after a superset is found, so that
                # disjointness with the remaining leaves is still detected.
                if not leaf_to_add.is_universal():
                    kept.append(leaf_to_add)
                result = kept
        return result

    def is_empty(self) -> bool:
        """ Returns whether ``self`` is the empty set (one of its leaves is empty). """
        return any(item.is_empty() for item in self.leaves)

    def is_universal(self) -> bool:
        """ Returns whether ``self`` is the universal set, that contains all possible elements. """
        return not self.leaves

    def matches(self, user_group_ids) -> bool:
        """ Return whether every leaf matches ``user_group_ids``. """
        return all(leaf.matches(user_group_ids) for leaf in self.leaves)

    def _union_merge(self, other: Inter) -> Inter | None:
        """ Return the union of ``self`` with another intersection, if it can be
        represented as an intersection. Otherwise return ``None``.
        """
        # the following covers cases like (A & B) | A -> A
        if self.is_universal() or other <= self:
            return self
        if self <= other:
            return other

        # combine complementary parts: (A & ~B) | (A & B) -> A
        if len(self.leaves) == len(other.leaves):
            opposite_index = None
            # leaves are sorted by (id, negative), so two intersections over
            # the same ids list them in the same order
            for index, (self_leaf, other_leaf) in enumerate(zip(self.leaves, other.leaves)):
                if self_leaf.id != other_leaf.id:
                    return None
                if self_leaf.negative != other_leaf.negative:
                    if opposite_index is not None:
                        return None  # we already have two opposite leaves
                    opposite_index = index
            if opposite_index is not None:
                leaves = list(self.leaves)
                leaves.pop(opposite_index)
                return Inter(leaves, optimal=True)
        return None

    def __and__(self, other: Inter) -> Inter:
        """ Intersection of both intersections, simplified with ``__combine``. """
        if self.is_empty() or other.is_empty():
            return EMPTY_INTER
        if self.is_universal():
            return other
        if other.is_universal():
            return self
        leaves = self.__combine(self.leaves, other.leaves)
        return Inter(leaves, optimal=True)

    def __eq__(self, other) -> bool:
        return isinstance(other, Inter) and self.key == other.key

    def __le__(self, other: Inter) -> bool:
        """ Sufficient inclusion test: every leaf of ``other`` includes some
        leaf of ``self``.
        """
        return self.key == other.key or all(
            any(self_leaf <= other_leaf for self_leaf in self.leaves)
            for other_leaf in other.leaves
        )

    def __lt__(self, other: Inter) -> bool:
        return self != other and self <= other

    def __hash__(self):
        return hash(self.key)


class Leaf:
    """ Atomic part of a set expression: one named set, or its complement when
    ``negative`` is true.

    The relation sets ``subsets``, ``supersets`` and ``disjoints`` contain
    leaf ids.  They are not computed here: the code creating the leaves fills
    them.  A leaf and its complement share the very same set objects.
    """
    __slots__ = ('disjoints', 'id', 'inverse', 'key', 'negative', 'ref', 'subsets', 'supersets')

    def __init__(self, leaf_id: LeafIdType, ref: str | int | None = None, negative: bool = False):
        flag = bool(negative)
        self.id = leaf_id
        self.ref = ref if ref else str(leaf_id)
        self.negative = flag
        self.key: tuple[LeafIdType, bool] = (leaf_id, flag)

        self.subsets: set[LeafIdType] = {leaf_id}       # ids of the leaves <= self
        self.supersets: set[LeafIdType] = {leaf_id}     # ids of the leaves >= self
        self.disjoints: set[LeafIdType] = set()         # ids of the leaves disjoint from self
        self.inverse: Leaf | None = None                # complement, built lazily by ~

    def __invert__(self) -> Leaf:
        """ Return the complement leaf, created once and then cached. """
        cached = self.inverse
        if cached is not None:
            return cached
        opposite = Leaf(self.id, self.ref, negative=not self.negative)
        opposite.inverse = self
        opposite.subsets = self.subsets
        opposite.supersets = self.supersets
        opposite.disjoints = self.disjoints
        self.inverse = opposite
        return opposite

    def is_empty(self) -> bool:
        """ The empty leaf is the complement of the universal leaf ``'*'``. """
        return self.ref == '*' and self.negative

    def is_universal(self) -> bool:
        """ The universal leaf is the positive leaf ``'*'``. """
        return self.ref == '*' and not self.negative

    def isdisjoint(self, other: Leaf) -> bool:
        """ Return whether ``self & other`` is known to be empty. """
        if self.negative:
            # ~A & X is empty iff X <= A
            return other <= ~self
        if other.negative:
            # A & ~B is empty iff A <= B
            return self <= ~other
        return self.id in other.disjoints

    def matches(self, user_group_ids: Collection[int]) -> bool:
        """ Return whether ``user_group_ids`` satisfies this leaf. """
        present = self.id in user_group_ids
        return not present if self.negative else present

    def __eq__(self, other) -> bool:
        return isinstance(other, Leaf) and self.key == other.key

    def __le__(self, other: Leaf) -> bool:
        """ Return whether ``self`` is known to be included in ``other``. """
        if self.is_empty() or other.is_universal():
            return True
        if self.is_universal() or other.is_empty():
            return False
        if self.negative:
            # ~A <= ~B iff B <= A; ~A <= B is never inferred
            return other.negative and ~other <= ~self
        if other.negative:
            # A <= ~B iff A and B are disjoint
            return self.id in other.disjoints
        return self.id in other.subsets

    def __lt__(self, other: Leaf) -> bool:
        if self == other:
            return False
        return self <= other

    def __hash__(self):
        return hash(self.key)


class UnknownId(str):
    """ Special id object for unknown leaves.  It behaves as being strictly
    greater than any other kind of id, while unknown ids compare with each
    other as regular strings.

    All four ordering operators are overridden, so the ordering holds whichever
    operator (or reflected operator) Python dispatches to.  With only
    ``__lt__``/``__gt__``, ``UnknownId('x') >= 3`` would raise ``TypeError``.
    """
    __slots__ = ()

    def __lt__(self, other) -> bool:
        if isinstance(other, UnknownId):
            return super().__lt__(other)
        return False

    def __le__(self, other) -> bool:
        if isinstance(other, UnknownId):
            return super().__le__(other)
        return False

    def __gt__(self, other) -> bool:
        if isinstance(other, UnknownId):
            return super().__gt__(other)
        return True

    def __ge__(self, other) -> bool:
        if isinstance(other, UnknownId):
            return super().__ge__(other)
        return True


# Type of a leaf id: an int for defined sets, '*' for the universal set, and
# an UnknownId for references that could not be resolved.
LeafIdType = int | typing.Literal["*"] | UnknownId

# constants
# The universal leaf '*' and its complement, the empty leaf; building
# EMPTY_LEAF with ``~`` links both leaves as each other's inverse.
UNIVERSAL_LEAF = Leaf('*')
EMPTY_LEAF = ~UNIVERSAL_LEAF

# An intersection containing the empty leaf is empty; the intersection of no
# leaves is universal.
EMPTY_INTER = Inter([EMPTY_LEAF])
UNIVERSAL_INTER = Inter()

# The union of no intersections is empty; a union containing the universal
# intersection is universal.
EMPTY_UNION = Union()
UNIVERSAL_UNION = Union([UNIVERSAL_INTER])
