(**
 	Module: Process	
	Description: Faust process classes
	@author WANG Haisheng	
	Created: 03/06/2013	Modified: 14/08/2013
*)

open Types;;
open Aux;;
open Basic;;
open Symbol;;
open Value;;
open Signal;;
open Beam;;

(** Not raised anywhere in this module; presumably reserved for
    features that are not implemented yet. *)
exception NotYetDone;;

(** Raised when two processes are composed with incompatible arities. *)
exception Dimension_error of string;;

(** Raised when a process is built from an expression of the wrong
    shape, or evaluated on an unsuitable input beam. *)
exception Process_error of string;;


(* PARSER *)

(** [exp_of_string s] lexes [s] with [Lexer.token] and parses it with
    [Parser.main] (presumably yielding a [faust_exp]). *)
let exp_of_string (s : string) =
  let lexbuf = Lexing.from_string s in
  Parser.main Lexer.token lexbuf;;


(** Input/output arity of a Faust process, together with the arity rules
    of the five block-diagram composition operators.  Each combinator
    takes the arity of the right operand and returns the arity of the
    composite, with [self] as the left operand.
    @raise Dimension_error when the operands' arities are incompatible. *)
class dimension : int * int -> dimension_type =
  fun (init : int * int) ->
    object (self)
      val _input = fst init
      val _output = snd init

      method input = _input
      method output = _output

      (* [A , B]: inputs and outputs simply add up. *)
      method par : dimension_type -> dimension_type =
        fun dim ->
          new dimension (self#input + dim#input, self#output + dim#output)

      (* [A : B]: outputs of A must match inputs of B one to one. *)
      method seq : dimension_type -> dimension_type =
        fun dim ->
          if self#output = dim#input then
            new dimension (self#input, dim#output)
          else raise (Dimension_error "seq dimension not matched.")

      (* [A <: B]: B's input count must be a multiple of A's output count.
         The zero guard reports a zero-output A as a dimension error
         instead of letting [mod] raise [Division_by_zero]. *)
      method split : dimension_type -> dimension_type =
        fun dim ->
          if self#output <> 0 && dim#input mod self#output = 0 then
            new dimension (self#input, dim#output)
          else raise (Dimension_error "split dimension not matched.")

      (* [A :> B]: A's output count must be a multiple of B's input count.
         Same zero guard as [split], here on B's inputs. *)
      method merge : dimension_type -> dimension_type =
        fun dim ->
          if dim#input <> 0 && self#output mod dim#input = 0 then
            new dimension (self#input, dim#output)
          else raise (Dimension_error "merge dimension not matched.")

      (* [A ~ B]: B's outputs feed some of A's inputs and A's outputs feed
         B's inputs.  The inputs consumed by the loop disappear from the
         composite, and the outputs are A's. *)
      method _rec : dimension_type -> dimension_type =
        fun dim ->
          if self#output >= dim#input && self#input >= dim#output then
            new dimension (self#input - dim#output, self#output)
          else raise (Dimension_error "rec dimension not matched.")
    end;;

(** Abstract base of every Faust process.  It stores the source
    expression.  Concrete subclasses supply the arity ([_dim]), the
    delay ([_delay]) and [eval], which maps an input beam to an output
    beam. *)
class virtual process (exp_init : faust_exp) =
  object
    val _exp = exp_init
    val virtual _dim : dimension_type
    val virtual _delay : int

    (* Read-only accessors for the fields above. *)
    method exp = _exp
    method dim = _dim
    method delay = _delay

    method virtual eval : beam_type -> beam_type
  end

(** Process for a constant expression [Const b]: zero inputs, one output
    signal whose sample at every time is [new value b].
    @raise Process_error when built from a non-[Const] expression, or
    when [eval] receives a non-empty input beam. *)
class proc_const : faust_exp -> process_type =
  fun (exp_init : faust_exp) ->
    let constant =
      match exp_init with
      | Const b -> b
      | _ -> raise (Process_error "const process constructor.")
    in
    object (self)
      inherit process exp_init
      val _dim = new dimension (0, 1)
      val _delay = 0
      method private const = constant

      method eval : beam_type -> beam_type =
        fun (input : beam_type) ->
          (* A length test instead of polymorphic [=] against [||]. *)
          match Array.length input#get with
          | 0 ->
              let sample_at _ = new value self#const in
              new beam [| new signal (new rate 0 1) sample_at |]
          | _ -> raise (Process_error "proc_const accepts no input.")
    end;;


(** Process for a primitive identifier [Ident s] (arithmetic, math,
    comparison, delay, table and vector primitives).  Its arity and delay
    come from [dimension_of_symbol] / [delay_of_symbol].  Every supported
    primitive except [Stop] yields a one-signal beam.
    @raise Process_error when built from a non-[Ident] expression.  From
    [eval], it is raised when the input width does not match the
    primitive's arity, and for symbols without an implementation (GUI
    elements). *)
class proc_ident : faust_exp -> process_type =
  fun (exp_init : faust_exp) ->
    let _symbol =
      match exp_init with
      | Ident s -> s
      | _ -> raise (Process_error "ident process constructor.")
    in
    object (self)
      inherit process exp_init
      val _dim = new dimension (dimension_of_symbol _symbol)
      val _delay = delay_of_symbol _symbol
      method private symb = _symbol

      (* Checks that the input width [n] equals the primitive's arity,
         then wraps the signal produced by [make] into a one-signal beam.
         [make] is a thunk so that no input signal is indexed before the
         check.  Too narrow an input then reports a [Process_error]
         instead of an out-of-bounds [Invalid_argument]. *)
      method private beam_of_ident : int -> (unit -> signal_type) -> beam_type =
        fun (n : int) (make : unit -> signal_type) ->
          if n = self#dim#input then new beam [| make () |]
          else raise (Process_error ("Ident " ^ string_of_symbol self#symb))

      method eval : beam_type -> beam_type =
        fun (input : beam_type) ->
          let signals = input#get in
          let n = Array.length signals in
          let arg i = signals.(i) in
          let output = self#beam_of_ident n in
          (* Helpers for primitives acting on the first one or two inputs. *)
          let unary f = output (fun () -> f (arg 0)) in
          let binary f = output (fun () -> f (arg 0) (arg 1)) in
          match self#symb with
          | Pass -> unary (fun a -> a)
          | Stop ->
              if n = 1 then new beam [||]
              else raise (Process_error "Ident !")
          | Add -> binary (fun a b -> a#add b)
          | Sub -> binary (fun a b -> a#sub b)
          | Mul -> binary (fun a b -> a#mul b)
          | Div -> binary (fun a b -> a#div b)
          | Power -> binary (fun a b -> a#power b)
          | And -> binary (fun a b -> a#_and b)
          | Or -> binary (fun a b -> a#_or b)
          | Xor -> binary (fun a b -> a#_xor b)
          | Mem -> unary (fun a -> a#mem)
          | Delay -> binary (fun a b -> a#delay b)
          | Floor -> unary (fun a -> a#floor)
          | Ceil -> unary (fun a -> a#ceil)
          | Rint -> unary (fun a -> a#rint)
          | Int -> unary (fun a -> a#int)
          | Float -> unary (fun a -> a#float)
          | Sin -> unary (fun a -> a#sin)
          | Asin -> unary (fun a -> a#asin)
          | Cos -> unary (fun a -> a#cos)
          | Acos -> unary (fun a -> a#acos)
          | Tan -> unary (fun a -> a#tan)
          | Atan -> unary (fun a -> a#atan)
          | Atan2 -> binary (fun a b -> a#atan2 b)
          | Exp -> unary (fun a -> a#exp)
          | Sqrt -> unary (fun a -> a#sqrt)
          | Ln -> unary (fun a -> a#ln)
          | Lg -> unary (fun a -> a#lg)
          | Abs -> unary (fun a -> a#abs)
          | Mod -> binary (fun a b -> a#_mod b)
          | Fmod -> binary (fun a b -> a#fmod b)
          | Remainder -> binary (fun a b -> a#remainder b)
          | Vectorize -> binary (fun a b -> a#vectorize b)
          | Vconcat -> binary (fun a b -> a#vconcat b)
          | Vpick -> binary (fun a b -> a#vpick b)
          | Serialize -> unary (fun a -> a#serialize)
          | Gt -> binary (fun a b -> a#gt b)
          | Lt -> binary (fun a b -> a#lt b)
          | Geq -> binary (fun a b -> a#geq b)
          | Leq -> binary (fun a b -> a#leq b)
          | Eq -> binary (fun a b -> a#eq b)
          | Neq -> binary (fun a b -> a#neq b)
          | Max -> binary (fun a b -> a#max b)
          | Min -> binary (fun a b -> a#min b)
          | Shl -> binary (fun a b -> a#shl b)
          | Shr -> binary (fun a b -> a#shr b)
          (* The receiver is the second input for [Prefix] and [Rdtable]. *)
          | Prefix -> binary (fun a b -> b#prefix a)
          | Select2 -> output (fun () -> (arg 0)#select2 (arg 1) (arg 2))
          | Select3 ->
              output (fun () -> (arg 0)#select3 (arg 1) (arg 2) (arg 3))
          | Rdtable -> output (fun () -> (arg 1)#rdtable (arg 0) (arg 2))
          | Rwtable ->
              output (fun () ->
                (arg 0)#rwtable (arg 1) (arg 2) (arg 3) (arg 4))
          | other ->
              let err_message =
                "GUI not supported: " ^ (string_of_symbol other) ^ "."
              in
              raise (Process_error err_message)
    end;;

(** Common base for the five binary composition operators ([Par], [Seq],
    [Split], [Merge], [Rec]).  It builds both sub-processes with
    [proc_factory] and derives the composite's arity and delay from
    theirs.
    @raise Process_error if [exp_init] is not a binary composition.
    @raise Dimension_error if the operands' arities are incompatible. *)
class virtual process_binary =
  fun (exp_init : faust_exp) ->
    (* A single dispatch on the operator yields the operands together with
       the rules that combine their dimensions and their delays. *)
    let (exp_left, exp_right, combine_dim, combine_delay) =
      match exp_init with
      | Par (e1, e2) ->
          (* Parallel branches: the delay is the max of the two. *)
          (e1, e2, (fun (d : dimension_type) -> d#par), max)
      | Seq (e1, e2) ->
          (e1, e2, (fun (d : dimension_type) -> d#seq), ( + ))
      | Split (e1, e2) ->
          (e1, e2, (fun (d : dimension_type) -> d#split), ( + ))
      | Merge (e1, e2) ->
          (e1, e2, (fun (d : dimension_type) -> d#merge), ( + ))
      | Rec (e1, e2) ->
          (* One extra unit of delay accounts for the feedback loop. *)
          (e1, e2, (fun (d : dimension_type) -> d#_rec),
           (fun l r -> 1 + l + r))
      | _ -> raise (Process_error "binary process constructor.")
    in
    (* The factory holds no state, so one instance serves both operands. *)
    let factory = new proc_factory in
    let proc_left = factory#make exp_left in
    let proc_right = factory#make exp_right in
    let dim = combine_dim proc_left#dim proc_right#dim in
    let delay = combine_delay proc_left#delay proc_right#delay in
    object
      inherit process exp_init
      method private proc_left = proc_left
      method private proc_right = proc_right
      val _dim = dim
      val _delay = delay
    end

and proc_par : faust_exp -> process_type =
  fun (exp_init : faust_exp) ->
    object (self)
      inherit process_binary exp_init

      (* [A , B]: the input beam is cut after A's input count.  Each part
         is evaluated by its own branch (A first, then B), and B's outputs
         are appended after A's. *)
      method eval : beam_type -> beam_type =
        fun (input : beam_type) ->
          let left = self#proc_left in
          let right = self#proc_right in
          let (left_in, right_in) = input#cut left#dim#input in
          let left_out = left#eval left_in in
          let right_out = right#eval right_in in
          left_out#append right_out
    end

and proc_split : faust_exp -> process_type =
  fun (exp_init : faust_exp) ->
    object (self)
      inherit process_binary exp_init

      (* [A <: B]: A's output beam is adapted with [matching] to B's input
         width, then fed to B. *)
      method eval : beam_type -> beam_type =
        fun (input : beam_type) ->
          let right = self#proc_right in
          let adapted = (self#proc_left#eval input)#matching right#dim#input in
          right#eval adapted
    end

and proc_merge : faust_exp -> process_type =
  fun (exp_init : faust_exp) ->
    object (self)
      inherit process_binary exp_init

      (* [A :> B]: A's output beam is adapted with [matching] to B's input
         width, then fed to B. *)
      method eval : beam_type -> beam_type =
        fun (input : beam_type) ->
          let right = self#proc_right in
          let adapted = (self#proc_left#eval input)#matching right#dim#input in
          right#eval adapted
    end

and proc_seq : faust_exp -> process_type =
  fun (exp_init : faust_exp) ->
    object (self)
      inherit process_binary exp_init

      (* [A : B]: A's whole output beam is B's input beam. *)
      method eval : beam_type -> beam_type =
        fun (input : beam_type) ->
          let intermediate = self#proc_left#eval input in
          let right = self#proc_right in
          right#eval intermediate
    end

(* [A ~ B]: recursive composition.  The first [proc_right#dim#input]
   outputs of A (proc_left), delayed by one time step, are fed to B
   (proc_right).  B's outputs are placed in front of the external input
   to form A's input.  The composite's outputs are A's outputs. *)
and proc_rec : faust_exp -> process_type =
  fun (exp_init : faust_exp) -> 
    object (self)
      inherit process_binary exp_init    	  
      (* Returns a beam of [self#dim#output] signals.  Each one reads its
	 samples from the memoized [beam_at] defined below. *)
      method eval : beam_type -> beam_type = 
	fun (input : beam_type) ->
	  (* Cache of output sample arrays, keyed by time.  The argument of
	     [Hashtbl.create] is only an initial size hint.  The eviction in
	     [beam_at] keeps roughly the last [self#delay] entries. *)
	  let memory = Hashtbl.create self#delay in
	  (* NOTE(review): [!rates] is dereferenced on the last line of [eval]
	     before [beam_at] has ever run.  The returned signals therefore
	     always carry the initial [new rate 0 1].  The assignment inside
	     [beam_at] only affects beams built by later [feedback] calls.
	     Confirm this is intended. *)
	  let rates = ref (Array.make self#dim#output (new rate 0 1)) in

	  (* Turns a time -> sample-array function into an array of
	     per-channel time -> sample functions. *)
	  let split : (time -> value_type array) -> (time -> value_type) array = 
	    fun beam_at ->
	      let get_signal = 
		fun beam_func -> fun i -> fun t -> 
		(beam_func t).(i) in
	      Array.init self#dim#output (get_signal beam_at) in

	  (* Builds B's input beam from the first [proc_right#dim#input]
	     channels of [beam_at], each shifted by one time step
	     (sample at t reads t - 1). *)
	  let feedback : (time -> value_type array) -> beam = 
	    fun beam_at ->
	      let signals_at = split beam_at in
	      let delay_by_one = fun s -> fun t -> s (t - 1) in
	      let delay_signal_funcs = Array.map delay_by_one 
		  (Array.sub signals_at 0 self#proc_right#dim#input) in
	      new beam (array_map2 (new signal) 
			  (Array.sub !rates 0 self#proc_right#dim#input) 
			  delay_signal_funcs) in

	  (* Output samples of the loop at time [t]:
	     - t < 0: [new value Zero] on every channel (the loop's initial
	       state);
	     - cached: the entry from [memory];
	     - otherwise: evaluate B on the delayed feedback, then A on B's
	       outputs followed by [input], and cache the result.  The
	       processes are re-evaluated on fresh beams for each uncached
	       t. *)
	  let rec beam_at : time -> value_type array = 
	    fun (t : time) ->	      
	      if t < 0 then 
		Array.make self#dim#output (new value Zero)
	      else if Hashtbl.mem memory t then
		Hashtbl.find memory t		  
	      else
		let beam_fb_in = feedback beam_at in
		let beam_fb_out = self#proc_right#eval beam_fb_in in
		let beam_in = beam_fb_out#append input in
		let beam_out = self#proc_left#eval beam_in in
		let values = beam_out#at t in
		let () = (rates := beam_out#frequency) in
		let () = Hashtbl.add memory t values in
		(* Evict the entry [self#delay] steps back.  If it is requested
		   again it will be recomputed rather than read from [memory]. *)
		let () = if t - self#delay >= 0 then 
		  Hashtbl.remove memory (t - self#delay) else () in
		values in	  
	  new beam (array_map2 (new signal) !rates (split beam_at))	      
    end

and proc_factory =
  object
    (* Maps a Faust expression to the process class that implements its
       top-level constructor. *)
    method make : faust_exp -> process_type =
      function
      | Const _ as e -> new proc_const e
      | Ident _ as e -> new proc_ident e
      | Par _ as e -> new proc_par e
      | Seq _ as e -> new proc_seq e
      | Split _ as e -> new proc_split e
      | Merge _ as e -> new proc_merge e
      | Rec _ as e -> new proc_rec e
  end;;