#!/usr/bin/env python3

import argparse

import matplotlib.pyplot as plt

from math import ceil

from matplotlib import pylab
from mpl_toolkits.mplot3d import Axes3D

from linpy import *

# Coordinate symbols used by every domain built in this script.
x, y, z = symbols('x y z')
# Auxiliary copies of the coordinates obtained through asdummy(); translate()
# uses them as temporary variables while shifting a domain.
_x, _y, _z = (sym.asdummy() for sym in (x, y, z))

def translate(domain, *, dx=0, dy=0, dz=0):
    """Return *domain* shifted by the offset (dx, dy, dz).

    The shift is expressed symbolically rather than by editing constraints:
    the domain is intersected with a polyhedron relating the dummy
    coordinates _x, _y, _z to x, y, z. The original coordinates are then
    projected away, and the dummies are renamed back to x, y, z.
    """
    # Relate each dummy coordinate to its shifted real coordinate, i.e.
    # _x = x + dx (assuming Polyhedron's first argument lists equalities
    # of the form expr == 0 — TODO confirm against linpy).
    link = Polyhedron([x - _x + dx, y - _y + dy, z - _z + dz])
    # project([x, y, z]) is relied on to eliminate the original coordinates,
    # so that only the shifted dummies remain.
    shifted = (domain & link).project([x, y, z])
    # Rename the dummies back to the real coordinate symbols.
    return shifted.subs({_x: x, _y: y, _z: z})

def _menger(domain, size):
    """Perform one Menger subdivision step around *domain*.

    Copies of *domain* are placed on a 3x3x3 grid with spacing *size*, at
    offsets (dx, dy, dz) * size for dx, dy, dz in {0, 1, 2}. A grid cell is
    kept only when at most one of its coordinates is the middle index 1.
    This drops the central cell and the six face-centre cells and leaves
    20 copies. The copy at offset (0, 0, 0) is *domain* itself.

    *domain* is expected to occupy [0, size]^3, as arranged by menger().
    """
    result = domain
    # Visit layers in z, then rows in x, then columns in y, so the union is
    # accumulated in the same order as the original hand-written list.
    for dz in range(3):
        for dx in range(3):
            for dy in range(3):
                middles = (dx == 1) + (dy == 1) + (dz == 1)
                if middles > 1 or (dx, dy, dz) == (0, 0, 0):
                    # Either a removed cell, or the original copy already
                    # held in result.
                    continue
                result |= translate(domain, dx=dx * size, dy=dy * size,
                                    dz=dz * size)
    return result

def menger(domain, count=1, cut=False):
    """Apply *count* Menger subdivision steps to *domain* and return the result.

    The grid spacing starts at 1 and triples after each step. For the unit
    cube used by the script, the spacing after the last step equals the
    final edge length 3**count.

    If *cut* is true, only the part with x + y + z <= ceil(3 * s / 2) is
    kept, where s is that final spacing. This is the half of the cube
    [0, s]^3 on one side of its central diagonal plane.
    """
    scale = 1
    for _ in range(count):
        domain = _menger(domain, scale)
        scale *= 3
    if not cut:
        return domain
    # ceil keeps the bound an integer; 3 * scale / 2 is fractional for odd scale.
    return domain & Le(x + y + z, ceil(3 * scale / 2))

def main(argv=None):
    """Parse the command line, build the Menger sponge and display it.

    argv: argument list to parse instead of sys.argv[1:]. The default None
    reads the real command line.
    """
    parser = argparse.ArgumentParser(
        description='Compute a Menger sponge.')
    parser.add_argument('-n', '--iterations', type=int, default=1,
        help='number of iterations (default: 1)')
    parser.add_argument('-c', '--cut', action='store_true', default=False,
        help='cut the sponge')
    args = parser.parse_args(argv)
    if args.iterations < 0:
        # menger() loops over range(count). A negative count would be empty,
        # so without this check the starting cube would be drawn silently.
        parser.error('number of iterations must be non-negative')
    cube = Le(0, x) & Le(x, 1) & Le(0, y) & Le(y, 1) & Le(0, z) & Le(z, 1)
    fractal = menger(cube, args.iterations, args.cut)
    fig = plt.figure(facecolor='white')
    plot = fig.add_subplot(1, 1, 1, projection='3d')
    try:
        plot.set_aspect('equal')
    except NotImplementedError:
        # matplotlib 3.1-3.5 only accept 'auto' on 3D axes. Passing
        # aspect='equal' to add_subplot there aborted before drawing anything.
        pass
    fractal.plot(plot)
    plt.show()

if __name__ == '__main__':
    main()
