import ast
import functools
import re

from . import islhelper

from .islhelper import mainctx, libisl, isl_set_basic_sets
from .linexprs import Expression, Symbol, symbolnames


# Public API: the Domain class plus the logical combinators built on it.
__all__ = ['Domain', 'And', 'Or', 'Not']


class Domain:
    """
    A union of polyhedra over a shared, sorted tuple of symbols.

    ``Domain`` objects are only created for two or more polyhedra: operations
    whose result consists of a single polyhedron return that ``Polyhedron``,
    and empty results return ``Empty`` (see ``_fromislset``).  The set
    computations themselves are delegated to isl through ``libisl``.

    Comparison operators implement set inclusion (``a <= b`` means
    ``a.issubset(b)``).  Inclusion is only a *partial* order -- two domains
    may be incomparable -- so all four rich comparisons are written out
    explicitly.  Deriving ``>`` and ``>=`` with ``functools.total_ordering``
    would assume a total order and report, e.g., ``a > b`` for two disjoint
    non-empty domains.
    """

    __slots__ = (
        '_polyhedra',
        '_symbols',
        '_dimension',
    )

    def __new__(cls, *polyhedra):
        """
        Build a domain from a string or from polyhedra.

        ``Domain(string)`` parses *string* with ``fromstring``;
        ``Domain(polyhedron)`` returns *polyhedron* itself; with several
        polyhedra the result is their union as computed by isl (which may be
        a single Polyhedron, ``Empty`` or a Domain).  With no argument the
        empty union ``Empty`` is returned, mirroring ``Or()``.

        Raises TypeError if an argument has an unsupported type.
        """
        from .polyhedra import Polyhedron
        if len(polyhedra) == 1:
            polyhedron = polyhedra[0]
            if isinstance(polyhedron, str):
                return cls.fromstring(polyhedron)
            elif isinstance(polyhedron, Polyhedron):
                return polyhedron
            raise TypeError('argument must be a string '
                'or a Polyhedron instance')
        if not polyhedra:
            # _toislset() needs at least one polyhedron; the union of none
            # is empty.
            from .polyhedra import Empty
            return Empty
        for polyhedron in polyhedra:
            if not isinstance(polyhedron, Polyhedron):
                raise TypeError('arguments must be Polyhedron instances')
        symbols = cls._xsymbols(polyhedra)
        islset = cls._toislset(polyhedra, symbols)
        return cls._fromislset(islset, symbols)

    @classmethod
    def _xsymbols(cls, iterator):
        """
        Return the ordered tuple of symbols present in iterator.
        """
        symbols = set()
        for item in iterator:
            symbols.update(item.symbols)
        return tuple(sorted(symbols))

    @property
    def polyhedra(self):
        """Tuple of the polyhedra whose union is this domain."""
        return self._polyhedra

    @property
    def symbols(self):
        """Sorted tuple of the symbols the domain is defined over."""
        return self._symbols

    @property
    def dimension(self):
        """Number of symbols, i.e. ``len(self.symbols)``."""
        return self._dimension

    def _unary_predicate(self, predicate):
        """
        Evaluate an isl predicate taking one set on this domain.

        *predicate* receives a freshly built isl set that it only borrows
        (``__isl_keep``); the set is freed afterwards, even if the call raises.
        """
        islset = self._toislset(self.polyhedra, self.symbols)
        try:
            return bool(predicate(islset))
        finally:
            libisl.isl_set_free(islset)

    def _binary_predicate(self, other, predicate):
        """
        Evaluate an isl predicate taking two sets on (self, other).

        Both domains are encoded over the union of their symbols so that
        isl compares sets of the same dimension.  Both sets are borrowed by
        *predicate* and freed afterwards, even if something raises.
        """
        symbols = self._xsymbols([self, other])
        islset1 = self._toislset(self.polyhedra, symbols)
        try:
            islset2 = other._toislset(other.polyhedra, symbols)
            try:
                return bool(predicate(islset1, islset2))
            finally:
                libisl.isl_set_free(islset2)
        finally:
            libisl.isl_set_free(islset1)

    def disjoint(self):
        """Return the domain rebuilt from isl's disjoint decomposition of it."""
        islset = self._toislset(self.polyhedra, self.symbols)
        # isl_set_make_disjoint() takes the set only (the isl context travels
        # with the set); passing mainctx first handed isl the context pointer
        # in place of the set.
        islset = libisl.isl_set_make_disjoint(islset)
        return self._fromislset(islset, self.symbols)

    def isempty(self):
        """Return True if the domain contains no point."""
        return self._unary_predicate(libisl.isl_set_is_empty)

    def __bool__(self):
        return not self.isempty()

    def isuniverse(self):
        """
        Return True if isl can tell the domain is the whole space.

        NOTE(review): this uses isl's *plain* check, which per the isl manual
        may return False for a universe that is not obviously one.
        """
        return self._unary_predicate(libisl.isl_set_plain_is_universe)

    def isbounded(self):
        """Return True if the domain is bounded."""
        return self._unary_predicate(libisl.isl_set_is_bounded)

    def __eq__(self, other):
        """Return True if both domains contain exactly the same points."""
        if not isinstance(other, Domain):
            return NotImplemented
        return self._binary_predicate(other, libisl.isl_set_is_equal)

    def isdisjoint(self, other):
        """Return True if the two domains share no point."""
        return self._binary_predicate(other, libisl.isl_set_is_disjoint)

    def issubset(self, other):
        """Return True if every point of this domain belongs to *other*."""
        return self._binary_predicate(other, libisl.isl_set_is_subset)

    def __le__(self, other):
        if not isinstance(other, Domain):
            return NotImplemented
        return self.issubset(other)

    def __lt__(self, other):
        """Return True if this domain is a strict subset of *other*."""
        if not isinstance(other, Domain):
            return NotImplemented
        return self._binary_predicate(other, libisl.isl_set_is_strict_subset)

    def __ge__(self, other):
        if not isinstance(other, Domain):
            return NotImplemented
        return other.issubset(self)

    def __gt__(self, other):
        if not isinstance(other, Domain):
            return NotImplemented
        return other._binary_predicate(self, libisl.isl_set_is_strict_subset)

    def complement(self):
        """Return the complement of the domain over its own symbols."""
        islset = self._toislset(self.polyhedra, self.symbols)
        islset = libisl.isl_set_complement(islset)
        return self._fromislset(islset, self.symbols)

    def __invert__(self):
        return self.complement()

    def simplify(self):
        """
        Return the domain with redundant constraints removed by isl.

        In the examples tried so far this changes nothing; isl appears to
        remove redundancies already.
        """
        islset = self._toislset(self.polyhedra, self.symbols)
        islset = libisl.isl_set_remove_redundancies(islset)
        return self._fromislset(islset, self.symbols)

    def polyhedral_hull(self):
        """
        Return the polyhedral hull of the domain as a single Polyhedron.

        NOTE: isl offers several kinds of hull; the polyhedral one seems the
        most appropriate here, to be confirmed.
        """
        from .polyhedra import Polyhedron
        islset = self._toislset(self.polyhedra, self.symbols)
        islbset = libisl.isl_set_polyhedral_hull(islset)
        return Polyhedron._fromislbasicset(islbset, self.symbols)

    def project_out(self, symbols):
        """
        Eliminate *symbols* from the domain.

        The listed symbols are projected out by isl; the result is expressed
        over the remaining symbols of ``self.symbols``, in their original order.
        """
        symbols = symbolnames(symbols)
        islset = self._toislset(self.polyhedra, self.symbols)
        # Walk the dimensions from last to first so that projecting one out
        # does not shift the indices of the ones still to be visited.
        for index, symbol in reversed(list(enumerate(self.symbols))):
            if symbol in symbols:
                islset = libisl.isl_set_project_out(islset,
                    libisl.isl_dim_set, index, 1)
        remaining = tuple(symbol for symbol in self.symbols
            if symbol not in symbols)
        return Domain._fromislset(islset, remaining)

    def sample(self):
        """Return a Polyhedron built from isl's sample point of the domain."""
        from .polyhedra import Polyhedron
        islset = self._toislset(self.polyhedra, self.symbols)
        islbset = libisl.isl_set_sample(islset)
        return Polyhedron._fromislbasicset(islbset, self.symbols)

    def intersection(self, *others):
        """Return the intersection of this domain with all of *others*."""
        if len(others) == 0:
            return self
        symbols = self._xsymbols((self,) + others)
        islset1 = self._toislset(self.polyhedra, symbols)
        for other in others:
            islset2 = other._toislset(other.polyhedra, symbols)
            islset1 = libisl.isl_set_intersect(islset1, islset2)
        return self._fromislset(islset1, symbols)

    def __and__(self, other):
        return self.intersection(other)

    def union(self, *others):
        """Return the union of this domain with all of *others*."""
        if len(others) == 0:
            return self
        symbols = self._xsymbols((self,) + others)
        islset1 = self._toislset(self.polyhedra, symbols)
        for other in others:
            islset2 = other._toislset(other.polyhedra, symbols)
            islset1 = libisl.isl_set_union(islset1, islset2)
        return self._fromislset(islset1, symbols)

    def __or__(self, other):
        return self.union(other)

    def __add__(self, other):
        return self.union(other)

    def difference(self, other):
        """Return the points of this domain that are not in *other*."""
        symbols = self._xsymbols([self, other])
        islset1 = self._toislset(self.polyhedra, symbols)
        islset2 = other._toislset(other.polyhedra, symbols)
        islset = libisl.isl_set_subtract(islset1, islset2)
        return self._fromislset(islset, symbols)

    def __sub__(self, other):
        return self.difference(other)

    def lexmin(self):
        """Return the lexicographic minimum of the domain, computed by isl."""
        islset = self._toislset(self.polyhedra, self.symbols)
        islset = libisl.isl_set_lexmin(islset)
        return self._fromislset(islset, self.symbols)

    def lexmax(self):
        """Return the lexicographic maximum of the domain, computed by isl."""
        islset = self._toislset(self.polyhedra, self.symbols)
        islset = libisl.isl_set_lexmax(islset)
        return self._fromislset(islset, self.symbols)

    @classmethod
    def _fromislset(cls, islset, symbols):
        """
        Convert an isl set (consumed and freed here) into a Python object.

        Returns ``Empty`` when the set has no basic set, the single Polyhedron
        when it has one, and otherwise a new Domain whose symbols are
        recomputed from the resulting polyhedra.
        """
        from .polyhedra import Polyhedron
        islset = libisl.isl_set_remove_divs(islset)
        islbsets = isl_set_basic_sets(islset)
        libisl.isl_set_free(islset)
        polyhedra = [Polyhedron._fromislbasicset(islbset, symbols)
            for islbset in islbsets]
        if len(polyhedra) == 0:
            from .polyhedra import Empty
            return Empty
        elif len(polyhedra) == 1:
            return polyhedra[0]
        # Bypass Domain.__new__, which would rebuild the isl set; the result
        # is always a plain Domain, even when called through a subclass.
        self = object.__new__(Domain)
        self._polyhedra = tuple(polyhedra)
        self._symbols = cls._xsymbols(polyhedra)
        self._dimension = len(self._symbols)
        return self

    @classmethod
    def _toislset(cls, polyhedra, symbols):
        """
        Return a new isl set, owned by the caller, for the union of *polyhedra*.

        Every polyhedron is encoded over *symbols*.  *polyhedra* must be a
        non-empty sequence.
        """
        polyhedron = polyhedra[0]
        islbset = polyhedron._toislbasicset(polyhedron.equalities,
            polyhedron.inequalities, symbols)
        islset1 = libisl.isl_set_from_basic_set(islbset)
        for polyhedron in polyhedra[1:]:
            islbset = polyhedron._toislbasicset(polyhedron.equalities,
                polyhedron.inequalities, symbols)
            islset2 = libisl.isl_set_from_basic_set(islbset)
            islset1 = libisl.isl_set_union(islset1, islset2)
        return islset1

    @classmethod
    def _fromast(cls, node):
        """
        Build a domain from a Python AST.

        Supports ``~`` (complement), ``&`` (intersection), ``|`` (union) and
        chained comparisons ``<``, ``<=``, ``==``, ``>=``, ``>`` between
        linear expressions.  Raises SyntaxError on anything else.
        """
        from .polyhedra import Polyhedron
        if isinstance(node, ast.Module) and len(node.body) == 1:
            return cls._fromast(node.body[0])
        elif isinstance(node, ast.Expression):
            # produced by ast.parse(..., mode='eval')
            return cls._fromast(node.body)
        elif isinstance(node, ast.Expr):
            return cls._fromast(node.value)
        elif isinstance(node, ast.UnaryOp):
            # the operator lives in node.op (an ast.Invert instance for '~')
            if isinstance(node.op, ast.Invert):
                return Not(cls._fromast(node.operand))
        elif isinstance(node, ast.BinOp):
            domain1 = cls._fromast(node.left)
            domain2 = cls._fromast(node.right)
            if isinstance(node.op, ast.BitAnd):
                return And(domain1, domain2)
            elif isinstance(node.op, ast.BitOr):
                return Or(domain1, domain2)
        elif isinstance(node, ast.Compare):
            equalities = []
            inequalities = []
            left = Expression._fromast(node.left)
            for op, comparator in zip(node.ops, node.comparators):
                right = Expression._fromast(comparator)
                # inequalities are stored in the form "expression >= 0"
                if isinstance(op, ast.Lt):
                    inequalities.append(right - left - 1)
                elif isinstance(op, ast.LtE):
                    inequalities.append(right - left)
                elif isinstance(op, ast.Eq):
                    equalities.append(left - right)
                elif isinstance(op, ast.GtE):
                    inequalities.append(left - right)
                elif isinstance(op, ast.Gt):
                    inequalities.append(left - right - 1)
                else:
                    break
                left = right
            else:
                return Polyhedron(equalities, inequalities)
        raise SyntaxError('invalid syntax')

    _RE_BRACES = re.compile(r'^\{\s*|\s*\}$')
    # a lone '=' (not part of '<=', '==', '>=') -- lookarounds so that
    # chained equalities such as 'x=y=0' are all rewritten
    _RE_EQ = re.compile(r'(?<![<=>])=(?![<=>])')
    _RE_AND = re.compile(r'\band\b|,|&&|/\\|∧|∩')
    _RE_OR = re.compile(r'\bor\b|;|\|\||\\/|∨|∪')
    _RE_NOT = re.compile(r'\bnot\b|!|¬')
    _RE_NUM_VAR = Expression._RE_NUM_VAR
    _RE_OPERATORS = re.compile(r'(&|\||~)')

    @classmethod
    def fromstring(cls, string):
        """
        Parse a domain from a string such as ``'{x >= 0 and y < x}'``.

        Accepted connectives include ``and``/``,``/``&&``/``∧``,
        ``or``/``;``/``||``/``∨`` and ``not``/``!``/``¬``; a single ``=``
        means equality and ``5x`` means ``5*x``.
        """
        # remove curly brackets
        string = cls._RE_BRACES.sub(r'', string)
        # replace '=' by '=='
        string = cls._RE_EQ.sub(r'==', string)
        # replace 'and', 'or', 'not'
        string = cls._RE_AND.sub(r' & ', string)
        string = cls._RE_OR.sub(r' | ', string)
        string = cls._RE_NOT.sub(r' ~', string)
        # add implicit multiplication operators, e.g. '5x' -> '5*x'
        string = cls._RE_NUM_VAR.sub(r'\1*\2', string)
        # Add parentheses around each comparison so that '&', '|' and '~'
        # bind more loosely than the comparisons.  Even indices are operands,
        # odd ones the captured operators.  Blank operands (e.g. between '&'
        # and a following '~') are left alone: wrapping them as '()' would
        # make '~' appear as a binary operator.
        tokens = cls._RE_OPERATORS.split(string)
        for i, token in enumerate(tokens):
            if i % 2 == 0 and token.strip():
                tokens[i] = '({})'.format(token)
        string = ''.join(tokens).strip()
        tree = ast.parse(string, mode='eval')
        return cls._fromast(tree)

    def __repr__(self):
        assert len(self.polyhedra) >= 2
        strings = [repr(polyhedron) for polyhedron in self.polyhedra]
        return 'Or({})'.format(', '.join(strings))

    @classmethod
    def fromsympy(cls, expr):
        """Not implemented yet."""
        raise NotImplementedError

    def tosympy(self):
        """Not implemented yet."""
        raise NotImplementedError


def And(*domains):
    """
    Return the intersection of *domains*.

    With no argument the result is ``Universe``; otherwise the first
    domain's ``intersection`` method is called with all the others.
    """
    if not domains:
        from .polyhedra import Universe
        return Universe
    first, *rest = domains
    return first.intersection(*rest)

def Or(*domains):
    """
    Return the union of *domains*.

    With no argument the result is ``Empty``; otherwise the first
    domain's ``union`` method is called with all the others.
    """
    if not domains:
        from .polyhedra import Empty
        return Empty
    first, *rest = domains
    return first.union(*rest)

def Not(domain):
    """Return the complement of *domain*, i.e. ``~domain``."""
    complement = ~domain
    return complement
