From: Vivien Maisonneuve <v.maisonneuve@gmail.com>
Date: Tue, 1 Jul 2014 17:05:07 +0000 (+0200)
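Subject: Overloading for Domain.project_out(), to be improved
X-Git-Tag: 1.0~176
X-Git-Url: https://svn.cri.mines-paristech.fr/git/linpy.git/commitdiff_plain/d6444456ceaeed6fcf1d44386edefb5b1fb8ec66?hp=4ae512f39c14835badbfab6fc1ce877f601d104e

Overloading for Domain.project_out(), to be improved

Domain.project_out() now accepts either a single string of symbol names
separated by commas and/or whitespace, or an iterable whose items are str
or Symbol instances; any other item type raises TypeError. The diamond
example is updated to call project_out([y]) instead of drop_dims('y').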
---

diff --git a/examples/diamond.py b/examples/diamond.py
index ad0942b..1681054 100755
--- a/examples/diamond.py
+++ b/examples/diamond.py
@@ -5,4 +5,4 @@ from pypol import *
 x, y = symbols('x y')
 diam = Ge(y, x - 1) & Le(y, x + 1) & Ge(y, -x - 1) & Le(y, -x + 1)
 print('diamond:', diam)
-print('projected on x:', diam.drop_dims('y'))
+print('projected on x:', diam.project_out([y]))
diff --git a/pypol/domains.py b/pypol/domains.py
index c844e55..6b47fe8 100644
--- a/pypol/domains.py
+++ b/pypol/domains.py
@@ -5,7 +5,7 @@ import re
 from . import islhelper
 
 from .islhelper import mainctx, libisl, isl_set_basic_sets
-from .linexprs import Expression
+from .linexprs import Expression, Symbol
 
 
 __all__ = [
@@ -154,6 +154,15 @@ class Domain:
 
     def project_out(self, symbols):
         # use to remove certain variables
+        if isinstance(symbols, str):
+            symbols = symbols.replace(',', ' ').split()
+        else:
+            symbols = list(symbols)
+            for i, symbol in enumerate(symbols):
+                if isinstance(symbol, Symbol):
+                    symbols[i] = symbol.name
+                elif not isinstance(symbol, str):
+                    raise TypeError('symbols must be strings or Symbol instances')
         islset = self._toislset(self.polyhedra, self.symbols)
         # the trick is to walk symbols in reverse order, to avoid index updates
         for index, symbol in reversed(list(enumerate(self.symbols))):