# Copyright 2014 MINES ParisTech
#
# This file is part of LinPy.
#
# LinPy is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# LinPy is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with LinPy.  If not, see <http://www.gnu.org/licenses/>.

import ast
import functools
import numbers
import re

from collections import OrderedDict, defaultdict
from fractions import Fraction

try:
    # collections.abc exists since Python 3.3; the Mapping alias in
    # collections itself was removed in Python 3.10.
    from collections.abc import Mapping
except ImportError:
    from collections import Mapping

try:
    # math.gcd exists since Python 3.5; fractions.gcd was removed in 3.9.
    from math import gcd
except ImportError:
    from fractions import gcd


# Public names exported by 'from ... import *'.
__all__ = [
    'Expression',
    'Symbol',
    'Dummy',
    'symbols',
    'Rational',
]


def _polymorphic(func):
    @functools.wraps(func)
    def wrapper(left, right):
        if isinstance(right, Expression):
            return func(left, right)
        elif isinstance(right, numbers.Rational):
            right = Rational(right)
            return func(left, right)
        return NotImplemented
    return wrapper


class Expression:
    """
    This class implements linear expressions.

    An expression maps Symbol instances to nonzero rational coefficients and
    carries a rational constant.  The constructor normalizes its result: an
    expression without symbols is returned as a Rational, and an expression
    made of a single symbol with coefficient 1 and constant 0 is returned as
    that Symbol.
    """

    def __new__(cls, coefficients=None, constant=0):
        """
        Create a new expression.

        coefficients is either a string (parsed with fromstring), a mapping
        from symbols to rational coefficients, or an iterable of
        (symbol, coefficient) pairs.  constant is a rational number.

        Raises TypeError when a symbol is not a Symbol instance, when a
        coefficient or the constant is not a rational number, or when a
        nonzero constant is given together with a string.
        """
        if isinstance(coefficients, str):
            if constant != 0:
                raise TypeError('too many arguments')
            return Expression.fromstring(coefficients)
        if coefficients is None:
            return Rational(constant)
        if isinstance(coefficients, Mapping):
            coefficients = coefficients.items()
        coefficients = list(coefficients)
        for symbol, coefficient in coefficients:
            if not isinstance(symbol, Symbol):
                raise TypeError('symbols must be Symbol instances')
            if not isinstance(coefficient, numbers.Rational):
                raise TypeError('coefficients must be rational numbers')
        if not isinstance(constant, numbers.Rational):
            raise TypeError('constant must be a rational number')
        # Zero coefficients are dropped *before* normalizing, so that e.g.
        # x - x yields Rational(0) and x + y - y yields the Symbol x.
        # Otherwise such results would compare equal to a Rational or a
        # Symbol while having a different type and a different hash.
        coefficients = [(symbol, Fraction(coefficient))
            for symbol, coefficient in coefficients if coefficient != 0]
        if not coefficients:
            return Rational(constant)
        if len(coefficients) == 1 and constant == 0:
            symbol, coefficient = coefficients[0]
            if coefficient == 1:
                return symbol
        # Canonical symbol order: __hash__ hashes the items as a tuple and
        # __eq__ compares ordered dicts, both of which are order-sensitive.
        coefficients.sort(key=lambda item: item[0].sortkey())
        self = object.__new__(cls)
        self._coefficients = OrderedDict(coefficients)
        self._constant = Fraction(constant)
        self._symbols = tuple(self._coefficients)
        self._dimension = len(self._symbols)
        return self

    def coefficient(self, symbol):
        """
        Return the coefficient value of the given symbol, as a Rational.

        A symbol that does not appear in the expression has coefficient 0.
        Raises TypeError if symbol is not a Symbol instance.
        """
        if not isinstance(symbol, Symbol):
            raise TypeError('symbol must be a Symbol instance')
        return Rational(self._coefficients.get(symbol, 0))

    __getitem__ = coefficient

    def coefficients(self):
        """
        Iterate over the (symbol, coefficient) pairs of the expression, in
        the order of the symbols attribute; coefficients are Rationals.
        """
        for symbol, coefficient in self._coefficients.items():
            yield symbol, Rational(coefficient)

    @property
    def constant(self):
        """
        Return the constant value of an expression, as a Rational.
        """
        return Rational(self._constant)

    @property
    def symbols(self):
        """
        Return the tuple of symbols appearing in an expression.
        """
        return self._symbols

    @property
    def dimension(self):
        """
        Return the number of symbols appearing in an expression.
        """
        return self._dimension

    def __hash__(self):
        return hash((tuple(self._coefficients.items()), self._constant))

    def isconstant(self):
        """
        Return true if an expression is a constant.
        """
        return False

    def issymbol(self):
        """
        Return true if an expression is a symbol.
        """
        return False

    def values(self):
        """
        Iterate over the coefficient values of an expression followed by its
        constant, all as Rationals.
        """
        for coefficient in self._coefficients.values():
            yield Rational(coefficient)
        yield Rational(self._constant)

    def __bool__(self):
        return True

    def __pos__(self):
        return self

    def __neg__(self):
        return self * -1

    @_polymorphic
    def __add__(self, other):
        """
        Return the sum of two expressions.
        """
        coefficients = defaultdict(Fraction, self._coefficients)
        for symbol, coefficient in other._coefficients.items():
            coefficients[symbol] += coefficient
        constant = self._constant + other._constant
        return Expression(coefficients, constant)

    __radd__ = __add__

    @_polymorphic
    def __sub__(self, other):
        """
        Return the difference between two expressions.
        """
        coefficients = defaultdict(Fraction, self._coefficients)
        for symbol, coefficient in other._coefficients.items():
            coefficients[symbol] -= coefficient
        constant = self._constant - other._constant
        return Expression(coefficients, constant)

    @_polymorphic
    def __rsub__(self, other):
        return other - self

    def __mul__(self, other):
        """
        Return the product of an expression and a rational number.

        Any other operand yields NotImplemented: the product of two
        non-constant expressions is not linear.
        """
        if isinstance(other, numbers.Rational):
            coefficients = ((symbol, coefficient * other)
                for symbol, coefficient in self._coefficients.items())
            constant = self._constant * other
            return Expression(coefficients, constant)
        return NotImplemented

    __rmul__ = __mul__

    def __truediv__(self, other):
        """
        Return an expression divided by a rational number.
        """
        if isinstance(other, numbers.Rational):
            coefficients = ((symbol, coefficient / other)
                for symbol, coefficient in self._coefficients.items())
            constant = self._constant / other
            return Expression(coefficients, constant)
        return NotImplemented

    @_polymorphic
    def __eq__(self, other):
        """
        Test whether two expressions are equal.
        """
        # _polymorphic guarantees that other is an Expression here.
        return self._coefficients == other._coefficients and \
            self._constant == other._constant

    def __le__(self, other):
        from .polyhedra import Le
        return Le(self, other)

    def __lt__(self, other):
        from .polyhedra import Lt
        return Lt(self, other)

    def __ge__(self, other):
        from .polyhedra import Ge
        return Ge(self, other)

    def __gt__(self, other):
        from .polyhedra import Gt
        return Gt(self, other)

    def scaleint(self):
        """
        Multiply an expression by a scalar to make all coefficients integer
        values.

        The scalar is the least common multiple of the denominators of the
        coefficients and of the constant.
        """
        lcm = functools.reduce(lambda a, b: a*b // gcd(a, b),
            [value.denominator for value in self.values()])
        return self * lcm

    def subs(self, symbol, expression=None):
        """
        Substitute symbol by expression and return the resulting expression.

        Alternatively, pass a mapping or an iterable of (symbol, expression)
        pairs as the only argument; substitutions are applied in order.
        Raises TypeError if a substituted key is not a Symbol instance.
        """
        if expression is None:
            if isinstance(symbol, Mapping):
                symbol = symbol.items()
            substitutions = symbol
        else:
            substitutions = [(symbol, expression)]
        result = self
        for symbol, expression in substitutions:
            if not isinstance(symbol, Symbol):
                raise TypeError('symbols must be Symbol instances')
            coefficients = [(othersymbol, coefficient)
                for othersymbol, coefficient in result._coefficients.items()
                if othersymbol != symbol]
            coefficient = result._coefficients.get(symbol, 0)
            constant = result._constant
            result = Expression(coefficients, constant) + coefficient*expression
        return result

    @staticmethod
    def _astnumber(node):
        """
        Return the value of a numeric literal AST node, or None if node is
        not a numeric literal.
        """
        # Python 3.8+ parses literals as ast.Constant.  Older versions use
        # ast.Num (value in .n), which is deprecated since 3.8 and removed
        # in 3.14, so it is not referenced by name here.
        if isinstance(node, getattr(ast, 'Constant', ())):
            value = node.value
        else:
            value = getattr(node, 'n', None)
        # bool is an int subclass, but True/False are not numbers here.
        if isinstance(value, (int, float, complex)) and \
                not isinstance(value, bool):
            return value
        return None

    @classmethod
    def _fromast(cls, node):
        """
        Build an expression from a Python AST node.

        Supports names, numeric literals, unary + and -, and the binary
        operators +, -, * and /.  Raises SyntaxError for anything else.
        """
        if isinstance(node, ast.Module) and len(node.body) == 1:
            return cls._fromast(node.body[0])
        elif isinstance(node, ast.Expr):
            return cls._fromast(node.value)
        elif isinstance(node, ast.Name):
            return Symbol(node.id)
        number = cls._astnumber(node)
        if number is not None:
            return Rational(number)
        if isinstance(node, ast.UnaryOp):
            if isinstance(node.op, ast.USub):
                return -cls._fromast(node.operand)
            elif isinstance(node.op, ast.UAdd):
                return +cls._fromast(node.operand)
        elif isinstance(node, ast.BinOp):
            left = cls._fromast(node.left)
            right = cls._fromast(node.right)
            if isinstance(node.op, ast.Add):
                return left + right
            elif isinstance(node.op, ast.Sub):
                return left - right
            elif isinstance(node.op, ast.Mult):
                return left * right
            elif isinstance(node.op, ast.Div):
                return left / right
        raise SyntaxError('invalid syntax')

    # A number or a closing parenthesis directly followed by a name or an
    # opening parenthesis, e.g. '5x', '2 (x + 1)' or '(x + 1)y'.  The
    # lookbehind keeps digits that are part of an identifier, as in 'x2y',
    # from being split off.
    _RE_NUM_VAR = re.compile(r'((?<!\w)\d+|\))\s*([^\W\d_]\w*|\()')

    @classmethod
    def fromstring(cls, string):
        """
        Create an expression from a string.

        Implicit multiplications such as '5x' or '2(x + 1)' are accepted.
        Raises SyntaxError for syntax that is not supported.
        """
        # add implicit multiplication operators, e.g. '5x' -> '5*x'
        string = cls._RE_NUM_VAR.sub(r'\1*\2', string)
        # Parsed in 'exec' mode, which is why _fromast unwraps the Module
        # and Expr nodes first.
        tree = ast.parse(string, mode='exec')
        return cls._fromast(tree)

    def __repr__(self):
        # Signs are tested on the raw Fraction values: comparing Rational
        # instances with < or > would build polyhedra (see __lt__/__gt__).
        string = ''
        for i, (symbol, coefficient) in enumerate(self._coefficients.items()):
            if coefficient == 1:
                if i != 0:
                    string += ' + '
            elif coefficient == -1:
                string += '-' if i == 0 else ' - '
            elif i == 0:
                string += '{}*'.format(coefficient)
            elif coefficient > 0:
                string += ' + {}*'.format(coefficient)
            else:
                string += ' - {}*'.format(-coefficient)
            string += '{}'.format(symbol)
        constant = self._constant
        if len(string) == 0:
            string += '{}'.format(constant)
        elif constant > 0:
            string += ' + {}'.format(constant)
        elif constant < 0:
            string += ' - {}'.format(-constant)
        return string

    def _repr_latex_(self):
        def latex(value):
            # LaTeX rendering of a rational value, without the '$$' fences.
            return Rational(value)._repr_latex_().strip('$')

        # Signs are tested on the raw Fraction values, as in __repr__.
        string = ''
        for i, (symbol, coefficient) in enumerate(self._coefficients.items()):
            if coefficient == 1:
                if i != 0:
                    string += ' + '
            elif coefficient == -1:
                string += '-' if i == 0 else ' - '
            elif i == 0:
                string += latex(coefficient)
            elif coefficient > 0:
                string += ' + ' + latex(coefficient)
            else:
                string += ' - ' + latex(-coefficient)
            string += symbol._repr_latex_().strip('$')
        constant = self._constant
        if len(string) == 0:
            string += latex(constant)
        elif constant > 0:
            string += ' + ' + latex(constant)
        elif constant < 0:
            string += ' - ' + latex(-constant)
        return '$${}$$'.format(string)

    def _parenstr(self, always=False):
        """
        Return str(self), wrapped in parentheses unless the expression is a
        constant or a symbol (or always is true).
        """
        string = str(self)
        if not always and (self.isconstant() or self.issymbol()):
            return string
        else:
            return '({})'.format(string)

    @classmethod
    def fromsympy(cls, expr):
        """
        Convert a sympy expression to an expression.

        Raises ValueError if a term of expr is neither a constant nor a
        rational multiple of a sympy.Symbol.
        """
        import sympy
        coefficients = []
        constant = 0
        for symbol, coefficient in expr.as_coefficients_dict().items():
            coefficient = Fraction(coefficient.p, coefficient.q)
            if symbol == sympy.S.One:
                constant = coefficient
            elif isinstance(symbol, sympy.Symbol):
                symbol = Symbol(symbol.name)
                coefficients.append((symbol, coefficient))
            else:
                raise ValueError('non-linear expression: {!r}'.format(expr))
        return Expression(coefficients, constant)

    def tosympy(self):
        """
        Return an expression as a sympy object.
        """
        import sympy
        expr = 0
        for symbol, coefficient in self.coefficients():
            term = coefficient * sympy.Symbol(symbol.name)
            expr += term
        expr += self.constant
        return expr


class Symbol(Expression):
    """
    This class implements symbols, the named variables of linear
    expressions.  Two symbols are equal when their sort keys are equal.
    """

    def __new__(cls, name):
        """
        Create and return a symbol from a string.

        Leading and trailing whitespace is stripped from the name.  Raises
        TypeError if name is not a string.
        """
        if not isinstance(name, str):
            raise TypeError('name must be a string')
        self = object.__new__(cls)
        # _name must be set before self is used as a dict key below:
        # hashing a symbol reads its sort key, i.e. its name.
        self._name = name.strip()
        self._coefficients = {self: Fraction(1)}
        self._constant = Fraction(0)
        self._symbols = (self,)
        self._dimension = 1
        return self

    @property
    def name(self):
        """
        Return the name of the symbol.
        """
        return self._name

    def __hash__(self):
        return hash(self.sortkey())

    def sortkey(self):
        """
        Return the key used to order symbols within expressions and to
        compare and hash them.
        """
        return self.name,

    def issymbol(self):
        """
        Return true, since this expression is a symbol.
        """
        return True

    def __eq__(self, other):
        """
        Test whether a symbol is equal to another object.

        Symbols are compared by sort key.  Any other operand is delegated
        to the structural comparison of Expression.__eq__ instead of
        calling other.sortkey(), which raised AttributeError for numbers
        and non-symbol expressions.
        """
        if isinstance(other, Symbol):
            return self.sortkey() == other.sortkey()
        return super().__eq__(other)

    def asdummy(self):
        """
        Return a new Dummy symbol with the same name.
        """
        return Dummy(self.name)

    @classmethod
    def _fromast(cls, node):
        """
        Build a symbol from a Python AST node holding a single name.

        Raises SyntaxError for anything else.
        """
        if isinstance(node, ast.Module) and len(node.body) == 1:
            return cls._fromast(node.body[0])
        elif isinstance(node, ast.Expr):
            return cls._fromast(node.value)
        elif isinstance(node, ast.Name):
            return Symbol(node.id)
        raise SyntaxError('invalid syntax')

    def __repr__(self):
        return self.name

    def _repr_latex_(self):
        return '$${}$$'.format(self.name)

    @classmethod
    def fromsympy(cls, expr):
        """
        Convert a sympy.Symbol (or sympy.Dummy) to a Symbol (or Dummy).

        Raises TypeError for any other object.
        """
        import sympy
        if isinstance(expr, sympy.Dummy):
            return Dummy(expr.name)
        elif isinstance(expr, sympy.Symbol):
            return Symbol(expr.name)
        else:
            raise TypeError('expr must be a sympy.Symbol instance')


class Dummy(Symbol):
    """
    This class implements dummy symbols.

    Every dummy receives a distinct index from a class-wide counter.  The
    index is part of its sort key, which distinguishes dummies from one
    another and from plain symbols, even when they share a name.
    """

    # Number of dummies created so far, i.e. the index of the next one.
    _count = 0

    def __new__(cls, name=None):
        """
        Create and return a new dummy symbol.

        Without a name, the dummy is named 'Dummy_<index>'.  Raises
        TypeError if a name is given and is not a string.
        """
        index = Dummy._count
        if name is None:
            name = 'Dummy_{}'.format(index)
        elif not isinstance(name, str):
            raise TypeError('name must be a string')
        self = object.__new__(cls)
        # _index and _name must exist before self is used as a dict key:
        # hashing a dummy reads its sort key.
        self._index = index
        self._name = name.strip()
        self._coefficients = {self: Fraction(1)}
        self._constant = Fraction(0)
        self._symbols = (self,)
        self._dimension = 1
        Dummy._count = index + 1
        return self

    def __hash__(self):
        return hash(self.sortkey())

    def sortkey(self):
        """
        Return the (name, index) pair that orders and identifies the dummy.
        """
        key = (self._name, self._index)
        return key

    def __repr__(self):
        return '_{}'.format(self.name)

    def _repr_latex_(self):
        return '$${}_{{{}}}$$'.format(self.name, self._index)


def symbols(names):
    """
    Create a tuple of Symbol instances from names.

    names is either a single string, in which names are separated by commas
    and/or whitespace, or an iterable of name strings.
    """
    if isinstance(names, str):
        # Treat commas as whitespace, then split on any whitespace run.
        names = names.replace(',', ' ').split()
    return tuple([Symbol(each) for each in names])


class Rational(Expression, Fraction):
    """
    This class represents integers and rational numbers of any size.

    A Rational is at the same time a constant linear expression (without
    symbols) and a fractions.Fraction.
    """

    def __new__(cls, numerator=0, denominator=None):
        """
        Create a rational number; the arguments are interpreted as by
        fractions.Fraction.
        """
        value = Fraction(numerator, denominator)
        self = object.__new__(cls)
        # Expression state: no symbols, the value is the constant.
        self._coefficients = {}
        self._constant = value
        self._symbols = ()
        self._dimension = 0
        # Fraction state, read by the inherited Fraction methods.
        self._numerator, self._denominator = value.numerator, value.denominator
        return self

    def __hash__(self):
        return Fraction.__hash__(self)

    @property
    def constant(self):
        """
        Return the rational itself, which is its own constant.
        """
        return self

    def isconstant(self):
        """
        Return true, since a rational is always a constant.
        """
        return True

    def __bool__(self):
        return Fraction.__bool__(self)

    def __repr__(self):
        numerator, denominator = self.numerator, self.denominator
        if denominator != 1:
            return '{!r}/{!r}'.format(numerator, denominator)
        return '{!r}'.format(numerator)

    def _repr_latex_(self):
        numerator, denominator = self.numerator, self.denominator
        if denominator == 1:
            return '$${}$$'.format(numerator)
        if numerator < 0:
            return '$$-\\frac{{{}}}{{{}}}$$'.format(-numerator, denominator)
        return '$$\\frac{{{}}}{{{}}}$$'.format(numerator, denominator)

    @classmethod
    def fromsympy(cls, expr):
        """
        Create a rational object from a sympy.Rational or a Python rational
        number.

        Raises TypeError for any other object.
        """
        import sympy
        if isinstance(expr, sympy.Rational):
            return Rational(expr.p, expr.q)
        if isinstance(expr, numbers.Rational):
            return Rational(expr)
        raise TypeError('expr must be a sympy.Rational instance')
