# Copyright 2014 MINES ParisTech
#
# This file is part of LinPy.
#
# LinPy is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# LinPy is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with LinPy.  If not, see <http://www.gnu.org/licenses/>.

import ast
import functools
import numbers
import re

from collections import OrderedDict, defaultdict
from collections.abc import Mapping
from fractions import Fraction

try:
    from math import gcd
except ImportError:
    # Python < 3.5: gcd is only available from the fractions module.
    from fractions import gcd


# Public names exported by "from <module> import *".
__all__ = [
    'LinExpr',
    'Symbol', 'Dummy', 'symbols',
    'Rational',
]


def _polymorphic(func):
    """
    Decorate a binary method so that its right operand is always a LinExpr.

    A LinExpr operand is passed through unchanged, a rational number is
    wrapped into a Rational instance first, and for any other operand the
    wrapper returns NotImplemented without calling the decorated method.
    """
    @functools.wraps(func)
    def wrapper(left, right):
        if not isinstance(right, LinExpr):
            if not isinstance(right, numbers.Rational):
                return NotImplemented
            # Promote plain rational numbers to constant expressions.
            right = Rational(right)
        return func(left, right)
    return wrapper


class LinExpr:
    """
    A linear expression consists of a list of coefficient-variable pairs
    that capture the linear terms, plus a constant term. Linear expressions
    are used to build constraints. They are temporary objects that typically
    have short lifespans.

    Linear expressions are generally built using overloaded operators. For
    example, if x is a Symbol, then x + 1 is an instance of LinExpr.

    LinExpr instances are hashable, and should be treated as immutable.

    The operators <, <=, >= and > do not return booleans: they return the
    Lt, Le, Ge and Gt objects of the polyhedra module.
    """

    def __new__(cls, coefficients=None, constant=0):
        """
        Return a linear expression from a dictionary or a sequence, that maps
        symbols to their coefficients, and a constant term. The coefficients and
        the constant term must be rational numbers.

        For example, the linear expression x + 2y + 1 can be constructed using
        one of the following instructions:

        >>> x, y = symbols('x y')
        >>> LinExpr({x: 1, y: 2}, 1)
        >>> LinExpr([(x, 1), (y, 2)], 1)

        However, it may be easier to use overloaded operators:

        >>> x, y = symbols('x y')
        >>> x + 2*y + 1

        Alternatively, linear expressions can be constructed from a string:

        >>> LinExpr('x + 2*y + 1')

        A linear expression with a single symbol of coefficient 1 and no
        constant term is automatically subclassed as a Symbol instance. A linear
        expression with no symbol, only a constant term, is automatically
        subclassed as a Rational instance. Terms with a null coefficient are
        discarded before this choice is made, so that x - x is the Rational 0
        and x + y - y is the Symbol x.

        Raise TypeError if a symbol is not a Symbol instance, if a coefficient
        or the constant is not a rational number, or if a constant is given
        together with a string.
        """
        if isinstance(coefficients, str):
            if constant != 0:
                raise TypeError('too many arguments')
            return LinExpr.fromstring(coefficients)
        if coefficients is None:
            return Rational(constant)
        if isinstance(coefficients, Mapping):
            coefficients = coefficients.items()
        coefficients = list(coefficients)
        for symbol, coefficient in coefficients:
            if not isinstance(symbol, Symbol):
                raise TypeError('symbols must be Symbol instances')
            if not isinstance(coefficient, numbers.Rational):
                raise TypeError('coefficients must be rational numbers')
        if not isinstance(constant, numbers.Rational):
            raise TypeError('constant must be a rational number')
        # Normalize first (drop null terms, sort, merge duplicates), and only
        # then pick the canonical class. Doing it the other way round lets
        # expressions such as x - x escape as generic instances that compare
        # equal to 0 but do not hash like it.
        terms = [(symbol, Fraction(coefficient))
            for symbol, coefficient in coefficients if coefficient != 0]
        terms.sort(key=lambda item: item[0].sortkey())
        terms = OrderedDict(terms)
        if not terms:
            return Rational(constant)
        if len(terms) == 1 and constant == 0:
            (symbol, coefficient), = terms.items()
            if coefficient == 1:
                return symbol
        self = object.__new__(cls)
        self._coefficients = terms
        self._constant = Fraction(constant)
        self._symbols = tuple(self._coefficients)
        self._dimension = len(self._symbols)
        return self

    def coefficient(self, symbol):
        """
        Return the coefficient value of the given symbol, or 0 if the symbol
        does not appear in the expression. Raise TypeError if the argument is
        not a Symbol instance.
        """
        if not isinstance(symbol, Symbol):
            raise TypeError('symbol must be a Symbol instance')
        return self._coefficients.get(symbol, Fraction(0))

    # expr[symbol] is a shorthand for expr.coefficient(symbol).
    __getitem__ = coefficient

    def coefficients(self):
        """
        Iterate over the pairs (symbol, value) of linear terms in the
        expression. The constant term is ignored.
        """
        yield from self._coefficients.items()

    @property
    def constant(self):
        """
        The constant term of the expression.
        """
        return self._constant

    @property
    def symbols(self):
        """
        The tuple of symbols present in the expression, sorted according to
        Symbol.sortkey().
        """
        return self._symbols

    @property
    def dimension(self):
        """
        The dimension of the expression, i.e. the number of symbols present in
        it.
        """
        return self._dimension

    def __hash__(self):
        """
        Hash the sorted (symbol, coefficient) pairs together with the constant.
        """
        return hash((tuple(self._coefficients.items()), self._constant))

    def isconstant(self):
        """
        Return True if the expression only consists of a constant term. In this
        case, it is a Rational instance.
        """
        return False

    def issymbol(self):
        """
        Return True if an expression only consists of a symbol with coefficient
        1. In this case, it is a Symbol instance.
        """
        return False

    def values(self):
        """
        Iterate over the coefficient values in the expression, and the constant
        term.
        """
        yield from self._coefficients.values()
        yield self._constant

    def __bool__(self):
        """
        Return True: a generic linear expression is always truthy.
        """
        return True

    def __pos__(self):
        """
        Return the expression itself.
        """
        return self

    def __neg__(self):
        """
        Return the opposite of the expression.
        """
        return self * -1

    @_polymorphic
    def __add__(self, other):
        """
        Return the sum of two linear expressions.
        """
        # defaultdict(Fraction) yields Fraction(0) for symbols that only
        # appear in the right operand.
        coefficients = defaultdict(Fraction, self._coefficients)
        for symbol, coefficient in other._coefficients.items():
            coefficients[symbol] += coefficient
        constant = self._constant + other._constant
        return LinExpr(coefficients, constant)

    # Addition is commutative.
    __radd__ = __add__

    @_polymorphic
    def __sub__(self, other):
        """
        Return the difference between two linear expressions.
        """
        coefficients = defaultdict(Fraction, self._coefficients)
        for symbol, coefficient in other._coefficients.items():
            coefficients[symbol] -= coefficient
        constant = self._constant - other._constant
        return LinExpr(coefficients, constant)

    @_polymorphic
    def __rsub__(self, other):
        """
        Return the difference other - self, other being the left operand.
        """
        return other - self

    def __mul__(self, other):
        """
        Return the product of the linear expression by a rational.
        """
        if isinstance(other, numbers.Rational):
            coefficients = ((symbol, coefficient * other)
                for symbol, coefficient in self._coefficients.items())
            constant = self._constant * other
            return LinExpr(coefficients, constant)
        return NotImplemented

    # Multiplication by a rational is commutative.
    __rmul__ = __mul__

    def __truediv__(self, other):
        """
        Return the quotient of the linear expression by a rational.
        """
        if isinstance(other, numbers.Rational):
            coefficients = ((symbol, coefficient / other)
                for symbol, coefficient in self._coefficients.items())
            constant = self._constant / other
            return LinExpr(coefficients, constant)
        return NotImplemented

    @_polymorphic
    def __eq__(self, other):
        """
        Test whether two linear expressions are equal, i.e. whether they have
        the same coefficients and the same constant term.
        """
        # _polymorphic guarantees that other is a LinExpr at this point.
        return self._coefficients == other._coefficients and \
            self._constant == other._constant

    def __le__(self, other):
        """
        Return the Le object of the polyhedra module built from both operands.
        """
        # Imported here rather than at module level.
        from .polyhedra import Le
        return Le(self, other)

    def __lt__(self, other):
        """
        Return the Lt object of the polyhedra module built from both operands.
        """
        from .polyhedra import Lt
        return Lt(self, other)

    def __ge__(self, other):
        """
        Return the Ge object of the polyhedra module built from both operands.
        """
        from .polyhedra import Ge
        return Ge(self, other)

    def __gt__(self, other):
        """
        Return the Gt object of the polyhedra module built from both operands.
        """
        from .polyhedra import Gt
        return Gt(self, other)

    def scaleint(self):
        """
        Return the expression multiplied by its lowest common denominator to
        make all values integer.
        """
        # values() always yields at least the constant term, so reduce() never
        # receives an empty sequence.
        lcd = functools.reduce(lambda a, b: a*b // gcd(a, b),
            [value.denominator for value in self.values()])
        return self * lcd

    def subs(self, symbol, expression=None):
        """
        Substitute the given symbol by an expression and return the resulting
        expression. Raise TypeError if the resulting expression is not linear.

        >>> x, y = symbols('x y')
        >>> e = x + 2*y + 1
        >>> e.subs(y, x - 1)
        3*x - 1

        To perform multiple substitutions at once, pass a sequence or a
        dictionary of (old, new) pairs to subs.

        >>> e.subs({x: y, y: x})
        2*x + y + 1

        The result is always a LinExpr, including when every symbol has been
        replaced by a number.
        """
        if expression is None:
            substitutions = dict(symbol)
        else:
            substitutions = {symbol: expression}
        for symbol in substitutions:
            if not isinstance(symbol, Symbol):
                raise TypeError('symbols must be Symbol instances')
        # Start from a Rational rather than a bare Fraction, otherwise a
        # purely constant result would not be a LinExpr.
        result = Rational(self._constant)
        for symbol, coefficient in self._coefficients.items():
            # Symbols without a substitution are kept as they are.
            expression = substitutions.get(symbol, symbol)
            result += coefficient * expression
        return result

    @staticmethod
    def _astnumber(node):
        """
        Return the value of a numeric literal AST node, or None if the node is
        not a numeric literal (booleans are not considered numeric).
        """
        if isinstance(node, getattr(ast, 'Constant', ())):
            value = node.value
        elif type(node).__name__ == 'Num':
            # Node type emitted by the parser of Python < 3.8. It is matched
            # by name because ast.Num is deprecated and removed in recent
            # versions of Python.
            value = node.n
        else:
            return None
        if isinstance(value, bool):
            return None
        if not isinstance(value, (int, float, complex)):
            return None
        return value

    @classmethod
    def _fromast(cls, node):
        """
        Convert an AST node into a linear expression, recursively. Raise
        SyntaxError if the node is not a name, a numeric literal, a unary
        minus, or one of the binary operations +, -, * and /.
        """
        if isinstance(node, ast.Module) and len(node.body) == 1:
            return cls._fromast(node.body[0])
        elif isinstance(node, ast.Expr):
            return cls._fromast(node.value)
        elif isinstance(node, ast.Name):
            return Symbol(node.id)
        elif isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
            return -cls._fromast(node.operand)
        elif isinstance(node, ast.BinOp):
            left = cls._fromast(node.left)
            right = cls._fromast(node.right)
            if isinstance(node.op, ast.Add):
                return left + right
            elif isinstance(node.op, ast.Sub):
                return left - right
            elif isinstance(node.op, ast.Mult):
                return left * right
            elif isinstance(node.op, ast.Div):
                return left / right
        else:
            number = cls._astnumber(node)
            if number is not None:
                return Rational(number)
        raise SyntaxError('invalid syntax')

    # A number or a closing parenthesis directly followed by an identifier or
    # an opening parenthesis. The leading \b prevents digits located inside an
    # identifier (e.g. 'x1y') from being treated as a number.
    _RE_NUM_VAR = re.compile(r'(\b\d+|\))\s*([^\W\d]\w*|\()')

    @classmethod
    def fromstring(cls, string):
        """
        Create an expression from a string. Raise SyntaxError if the string is
        not properly formatted.
        """
        # Add implicit multiplication operators, e.g. '5x' -> '5*x'.
        string = LinExpr._RE_NUM_VAR.sub(r'\1*\2', string)
        # NOTE: the second positional argument of ast.parse() is the file name,
        # not the mode. The string is thus parsed in the default 'exec' mode,
        # which is why _fromast() unwraps ast.Module and ast.Expr nodes.
        tree = ast.parse(string, 'eval')
        expr = cls._fromast(tree)
        if not isinstance(expr, cls):
            raise SyntaxError('invalid syntax')
        return expr

    def __repr__(self):
        """
        Return a string such as '3*x - y + 1/2', terms being listed in symbol
        order and followed by the constant term if it is not null.
        """
        string = ''
        for i, (symbol, coefficient) in enumerate(self.coefficients()):
            if coefficient == 1:
                if i != 0:
                    string += ' + '
            elif coefficient == -1:
                string += '-' if i == 0 else ' - '
            elif i == 0:
                string += '{}*'.format(coefficient)
            elif coefficient > 0:
                string += ' + {}*'.format(coefficient)
            else:
                string += ' - {}*'.format(-coefficient)
            string += '{}'.format(symbol)
        constant = self.constant
        if len(string) == 0:
            string += '{}'.format(constant)
        elif constant > 0:
            string += ' + {}'.format(constant)
        elif constant < 0:
            string += ' - {}'.format(-constant)
        return string

    @staticmethod
    def _latex(value):
        """
        Return the LaTeX code of a rational value, without math delimiters.
        """
        # Coefficients and constants are stored as plain Fraction instances,
        # which have no _repr_latex_() method: go through Rational.
        return Rational(value)._repr_latex_().strip('$')

    def _repr_latex_(self):
        """
        Return the LaTeX representation of the expression, enclosed in '$$'
        delimiters.
        """
        string = ''
        for i, (symbol, coefficient) in enumerate(self.coefficients()):
            if coefficient == 1:
                if i != 0:
                    string += ' + '
            elif coefficient == -1:
                string += '-' if i == 0 else ' - '
            elif i == 0:
                string += '{}'.format(self._latex(coefficient))
            elif coefficient > 0:
                string += ' + {}'.format(self._latex(coefficient))
            elif coefficient < 0:
                string += ' - {}'.format(self._latex(-coefficient))
            string += '{}'.format(symbol._repr_latex_().strip('$'))
        constant = self.constant
        if len(string) == 0:
            string += '{}'.format(self._latex(constant))
        elif constant > 0:
            string += ' + {}'.format(self._latex(constant))
        elif constant < 0:
            string += ' - {}'.format(self._latex(-constant))
        return '$${}$$'.format(string)

    def _parenstr(self, always=False):
        """
        Return the string of the expression enclosed in parentheses. Constants
        and symbols are left bare unless always is true.
        """
        string = str(self)
        if not always and (self.isconstant() or self.issymbol()):
            return string
        else:
            return '({})'.format(string)

    @classmethod
    def fromsympy(cls, expr):
        """
        Create a linear expression from a SymPy expression. Raise TypeError if
        the sympy expression is not linear, or if it contains dummy symbols.
        """
        import sympy
        coefficients = []
        constant = 0
        for symbol, coefficient in expr.as_coefficients_dict().items():
            coefficient = Fraction(coefficient.p, coefficient.q)
            if symbol == sympy.S.One:
                constant = coefficient
            elif isinstance(symbol, sympy.Dummy):
                # We cannot properly convert dummy symbols with respect to
                # symbol equalities.
                raise TypeError('cannot convert dummy symbols')
            elif isinstance(symbol, sympy.Symbol):
                symbol = Symbol(symbol.name)
                coefficients.append((symbol, coefficient))
            else:
                raise TypeError('non-linear expression: {!r}'.format(expr))
        expr = LinExpr(coefficients, constant)
        if not isinstance(expr, cls):
            raise TypeError('cannot convert to a {} instance'.format(cls.__name__))
        return expr

    def tosympy(self):
        """
        Convert the linear expression to a SymPy expression.
        """
        import sympy
        # Every number is converted explicitly to a SymPy rational. Starting
        # from the integer 0 and adding self.constant would dispatch to
        # LinExpr.__radd__ for constant expressions, and return a LinExpr
        # instead of a SymPy object.
        constant = Fraction(self.constant)
        expr = sympy.Rational(constant.numerator, constant.denominator)
        for symbol, coefficient in self.coefficients():
            coefficient = Fraction(coefficient)
            factor = sympy.Rational(coefficient.numerator,
                coefficient.denominator)
            expr += factor * sympy.Symbol(symbol.name)
        return expr


class Symbol(LinExpr):
    """
    Symbols are the basic components to build expressions and constraints.
    They correspond to mathematical variables. Symbols are instances of
    class LinExpr and inherit its functionalities.

    Two instances of Symbol are equal if they have the same name.
    """

    # NOTE: LinExpr does not declare __slots__, so instances still carry a
    # __dict__ (subclasses may set additional attributes).
    __slots__ = (
        '_name',
        '_constant',
        '_symbols',
        '_dimension',
    )

    def __new__(cls, name):
        """
        Return a symbol with the name string given in argument.

        Raise TypeError if the name is not a string, and SyntaxError if it is
        not a single valid Python identifier.
        """
        if not isinstance(name, str):
            raise TypeError('name must be a string')
        node = ast.parse(name)
        # The name must parse to exactly one statement: an empty string would
        # otherwise fail with IndexError, and 'x; y' would silently be
        # truncated to 'x'.
        if len(node.body) != 1:
            raise SyntaxError('invalid syntax')
        try:
            name = node.body[0].value.id
        except AttributeError:
            # The statement is not a bare identifier.
            raise SyntaxError('invalid syntax')
        self = object.__new__(cls)
        self._name = name
        self._constant = Fraction(0)
        self._symbols = (self,)
        self._dimension = 1
        return self

    @property
    def _coefficients(self):
        # This is not implemented as an attribute, because __hash__ is not
        # callable in __new__ in class Dummy.
        return {self: Fraction(1)}

    @property
    def name(self):
        """
        The name of the symbol.
        """
        return self._name

    def __hash__(self):
        """
        Hash the sorting key, consistently with __eq__.
        """
        return hash(self.sortkey())

    def sortkey(self):
        """
        Return a sorting key for the symbol. It is useful to sort a list of
        symbols in a consistent order, as comparison functions are overridden
        (see the documentation of class LinExpr).

        >>> sorted(symbols, key=Symbol.sortkey)
        """
        return self.name,

    def issymbol(self):
        """
        Return True, as the expression is a symbol.
        """
        return True

    def __eq__(self, other):
        """
        Test whether two symbols have the same sorting key. Return
        NotImplemented if the other operand is not a symbol.
        """
        if isinstance(other, Symbol):
            return self.sortkey() == other.sortkey()
        return NotImplemented

    def asdummy(self):
        """
        Return a new Dummy symbol instance with the same name.
        """
        return Dummy(self.name)

    def __repr__(self):
        """
        Return the name of the symbol.
        """
        return self.name

    def _repr_latex_(self):
        """
        Return the name of the symbol enclosed in '$$' delimiters.
        """
        return '$${}$$'.format(self.name)


def symbols(names):
    """
    Build several symbols at once and return them as a tuple. The names are
    given either as a single string, in which they are separated by commas
    and/or whitespace, or as a sequence of strings.

    >>> x, y = symbols('x y')
    >>> x, y = symbols('x, y')
    >>> x, y = symbols(['x', 'y'])
    """
    labels = names
    if isinstance(labels, str):
        # Commas are turned into spaces, then the string is split on runs of
        # whitespace.
        labels = ' '.join(labels.split(',')).split()
    return tuple(Symbol(label) for label in labels)


class Dummy(Symbol):
    """
    A Symbol variant whose instances are all distinct: each one is identified
    by its name together with an index taken from a class-wide counter. When
    no name is supplied, one is derived from that counter. Dummies are handy
    when a fresh, temporary variable is needed and its name is irrelevant.

    Unlike Symbol, Dummy instances with the same name are not equal:

    >>> x = Symbol('x')
    >>> x1, x2 = Dummy('x'), Dummy('x')
    >>> x == x1
    False
    >>> x1 == x2
    False
    >>> x1 == x1
    True
    """

    # Index assigned to the next dummy symbol created.
    _count = 0

    def __new__(cls, name=None):
        """
        Create a fresh dummy symbol. If name is None, the name is 'Dummy_'
        followed by the current value of the counter.
        """
        index = Dummy._count
        if name is None:
            name = 'Dummy_{}'.format(index)
        self = super().__new__(cls, name)
        self._index = index
        # The counter only advances once the symbol has been built.
        Dummy._count = index + 1
        return self

    def __hash__(self):
        """
        Hash the sorting key of the dummy symbol.
        """
        key = self.sortkey()
        return hash(key)

    def sortkey(self):
        """
        Return the pair (name, index), which is the identity of the dummy
        symbol.
        """
        return (self._name, self._index)

    def __repr__(self):
        """
        Return the name of the symbol, prefixed with an underscore.
        """
        return '_{}'.format(self.name)

    def _repr_latex_(self):
        """
        Return the name of the symbol with its index as a subscript, enclosed
        in '$$' delimiters.
        """
        return '$${}_{{{}}}$$'.format(self.name, self._index)


class Rational(LinExpr, Fraction):
    """
    Rational values are the linear expressions that have no symbol, only a
    constant term. The class derives from both LinExpr and fractions.Fraction,
    so that its instances can be used as expressions as well as numbers.
    """

    __slots__ = (
        '_coefficients',
        '_constant',
        '_symbols',
        '_dimension',
    ) + Fraction.__slots__

    def __new__(cls, numerator=0, denominator=None):
        """
        Return the rational built from the given arguments, which are those
        accepted by fractions.Fraction.
        """
        value = Fraction(numerator, denominator)
        self = object.__new__(cls)
        # LinExpr side: no linear term, only a constant.
        self._coefficients = {}
        self._constant = value
        self._symbols = ()
        self._dimension = 0
        # Fraction side.
        self._numerator = value.numerator
        self._denominator = value.denominator
        return self

    def __hash__(self):
        """
        Hash the value as fractions.Fraction does.
        """
        return Fraction.__hash__(self)

    @property
    def constant(self):
        """
        The constant term of the expression, i.e. the rational itself.
        """
        return self

    def isconstant(self):
        """
        Return True, as the expression has no symbol.
        """
        return True

    def __bool__(self):
        """
        Return the truth value of the number, as fractions.Fraction does.
        """
        return Fraction.__bool__(self)

    def __repr__(self):
        """
        Return 'n' for an integer value, 'n/d' otherwise.
        """
        if self.denominator != 1:
            return '{!r}/{!r}'.format(self.numerator, self.denominator)
        return '{!r}'.format(self.numerator)

    def _repr_latex_(self):
        """
        Return the LaTeX representation of the value enclosed in '$$'
        delimiters, a non-integer value being typeset as a fraction with its
        sign in front.
        """
        numerator, denominator = self.numerator, self.denominator
        if denominator == 1:
            return '$${}$$'.format(numerator)
        if numerator < 0:
            return '$$-\\frac{{{}}}{{{}}}$$'.format(-numerator, denominator)
        return '$$\\frac{{{}}}{{{}}}$$'.format(numerator, denominator)
