{-# Language FlexibleInstances,GADTs,DeriveDataTypeable,StandaloneDeriving #-}
-- | Data flow nodes.
module Sound.DF.Uniform.GADT.DF where

import Data.Int {- base -}
import Data.Typeable {- base -}
import Data.Digest.Murmur32 {- murmur-hash -}

import Sound.DF.Uniform.LL
import Sound.DF.Uniform.UDF

-- * DF

-- | Data flow node.
--
-- K = constant, A = array, R = recursion, P0..P3 = primitive with
-- zero to three inputs, MCE = list of nodes, MRG = mrg.
data DF a where
    -- Constant, holds a value of a K' type.
    K :: K' a => a -> DF a
    -- Array, wraps a Vec of Float.
    A :: Vec Float -> DF (Vec Float)
    -- Recursion: identifier, stored TypeRep, and either the initial
    -- value (Left) or a (feed-forward,feed-backward) pair (Right),
    -- as constructed by rec_r.
    R :: K' a => R_Id -> TypeRep -> Either a (DF b,DF a) -> DF b
    -- Primitives: name, result TypeRep, and zero to three input nodes.
    P0 :: K' a => String -> TypeRep -> DF a
    P1 :: (K' a,K' b) => String -> TypeRep -> DF a -> DF b
    P2 :: (K' a,K' b,K' c) => String -> TypeRep -> DF a -> DF b -> DF c
    P3 :: (K' a,K' b,K' c,K' d) => String -> TypeRep -> DF a -> DF b -> DF c -> DF d
    -- List of nodes of one type; the lift_mce functions map over it
    -- element-wise (presumably "multiple channel expansion" -- TODO confirm).
    MCE :: [DF a] -> DF a
    -- A node paired with a unit-typed node; df_typeOf reports the
    -- type of the left hand side.
    MRG :: K' a => DF a -> DF () -> DF a

-- Show is derived standalone because DF is a GADT.
deriving instance Show a => Show (DF a)

-- | 'TypeRep' associated with a 'DF' node.
--
-- A constant reports the type of its value, an array reports
-- 'vec_float_t', 'R' and the primitive nodes report the 'TypeRep'
-- they store, 'MCE' reports the type of its first element (an empty
-- 'MCE' is an error) and 'MRG' reports the type of its left hand
-- side.
--
-- > df_typeOf (K (undefined::Int32)) == int32_t
-- > df_typeOf (K (undefined::Float)) == float_t
-- > df_typeOf (A undefined) == vec_float_t
-- > df_typeOf (0::DF Int32) == int32_t
-- > df_typeOf (0.0::DF Float) == float_t
df_typeOf :: K' a => DF a -> TypeRep
df_typeOf (K k) = typeOf k
df_typeOf (A _) = vec_float_t
df_typeOf (R _ t _) = t
df_typeOf (P0 _ t) = t
df_typeOf (P1 _ t _) = t
df_typeOf (P2 _ t _ _) = t
df_typeOf (P3 _ t _ _ _) = t
df_typeOf (MCE []) = error "df_typeOf: MCE []"
df_typeOf (MCE (n:_)) = df_typeOf n
df_typeOf (MRG n _) = df_typeOf n

-- Typeable instance for the DF type constructor, derived standalone.
deriving instance Typeable DF

-- | Name stored in a primitive node ('P0', 'P1', 'P2' or 'P3'),
-- 'Nothing' for every other constructor.
df_primitive :: DF a -> Maybe String
df_primitive (P0 nm _) = Just nm
df_primitive (P1 nm _ _) = Just nm
df_primitive (P2 nm _ _ _) = Just nm
df_primitive (P3 nm _ _ _ _) = Just nm
df_primitive _ = Nothing

-- * MRG

-- | Multiple root graph (alias for the 'MRG' constructor).  The
-- result has the type of the left hand node, the right hand node is
-- unit typed.
mrg :: K' a => DF a -> DF () -> DF a
mrg = MRG

-- * DF Vec

-- | Make an 'A' node from an identifier and a list of values.  The
-- size stored in the 'Vec' is the length of the list.
df_vec :: V_Id -> [Float] -> DF (Vec Float)
df_vec k v = A (Vec k n v)
    where n = length v

-- | Variant of 'df_vec' that draws the 'V_Id' from 'generateId'.
df_vec_m :: UId m => [Float] -> m (DF (Vec Float))
df_vec_m v = generateId >>= \k -> return (df_vec (V_Id k) v)

-- | Size stored in the 'Vec' of an 'A' node, 'Nothing' for any other
-- node.
df_vec_size :: DF a -> Maybe Int
df_vec_size (A (Vec _ n _)) = Just n
df_vec_size _ = Nothing

-- | 'df_vec_size' variant for tables.  Tables have a guard point, so
-- the reported size is one less than the vector size.
df_tbl_size :: DF a -> Maybe Int
df_tbl_size df = fmap (subtract 1) (df_vec_size df)

-- * Operator types

-- | Unary operator, one input and a result of the same type.
type Unary_Op a = a -> a

-- | Binary operator, two inputs and a result of the same type.
type Binary_Op a = a -> a -> a

-- | Ternary operator, three inputs and a result of the same type.
type Ternary_Op a = a -> a -> a -> a

-- | Quaternary operator, four inputs and a result of the same type.
type Quaternary_Op a = a -> a -> a -> a -> a

-- | Quinary operator, five inputs and a result of the same type.
type Quinary_Op a = a -> a -> a -> a -> a -> a

-- | Senary operator, six inputs and a result of the same type.
type Senary_Op a = a -> a -> a -> a -> a -> a -> a

-- * Uniform function types

-- | Binary function, both inputs have type /i/ and the result has
-- type /o/.
type Binary_Fn i o = i -> i -> o

-- * MCE

-- | Is the node an 'MCE'?  Looks through 'MRG' at its left hand side.
is_mce :: DF t -> Bool
is_mce (MCE _) = True
is_mce (MRG l _) = is_mce l
is_mce _ = False

-- | Number of elements of an 'MCE' node, looking through 'MRG' at
-- its left hand side.  Any other node has degree one.
mce_degree :: DF t -> Int
mce_degree (MCE l) = length l
mce_degree (MRG l _) = mce_degree l
mce_degree _ = 1

-- | Extend a node to a list of /k/ nodes.  An 'MCE' is cycled to
-- length /k/, any other node is replicated /k/ times.
--
-- The degree test uses 'mce_degree' and so sees into 'MRG'.  It is
-- an error if /k/ is less than the degree (this function will not
-- reduce).  An 'MRG' node that passes the degree test is itself
-- rejected with an error.
mce_extend :: Int -> DF t -> [DF t]
mce_extend k n
    | k < mce_degree n = error "mce_extend: REDUCE?"
    | otherwise =
        case n of
          MCE l -> take k (cycle l)
          MRG _ _ -> error "mce_extend: MRG"
          _ -> replicate k n

-- | Two element 'MCE' node.
mce2 :: DF a -> DF a -> DF a
mce2 p q = MCE [p,q]

-- | Elements of an 'MCE' node.  For an 'MRG' the left hand side is
-- expanded and the right hand side is re-attached to the first
-- element only (it is an error if there is no first element).  Any
-- other node gives a singleton list.
unmce :: DF t -> [DF t]
unmce (MCE l) = l
unmce (MRG l r) =
    case unmce l of
      [] -> error "unmce: MRG?"
      h:t -> MRG h r : t
unmce n = [n]

-- | 'unmce' of a node that must expand to exactly two elements.  Any
-- other element count is an error that shows the node.
unmce2 :: Show t => DF t -> (DF t, DF t)
unmce2 n
    | [p,q] <- unmce n = (p,q)
    | otherwise = error ("unmce2: " ++ show n)

-- | Apply a unary node function, mapping it over the elements when
-- the input is directly an 'MCE' node.  Unlike 'is_mce' this does
-- not look through 'MRG'.
lift_mce :: (DF a -> DF b) -> DF a -> DF b
lift_mce f (MCE l) = MCE (map f l)
lift_mce f p = f p

-- | Apply a binary node function.  If either input 'is_mce' then both
-- are extended ('mce_extend') to the larger 'mce_degree' and the
-- function is zipped over the two lists, giving an 'MCE' node.
lift_mce2 :: (DF a -> DF b -> DF c) -> DF a -> DF b -> DF c
lift_mce2 f p q
    | is_mce p || is_mce q = MCE (zipWith f (mce_extend k p) (mce_extend k q))
    | otherwise = f p q
    where k = max (mce_degree p) (mce_degree q)

-- | 'mce_extend' three nodes to the largest of their 'mce_degree's.
mce_extend3 :: DF a -> DF b -> DF c -> ([DF a],[DF b],[DF c])
mce_extend3 p q r = (mce_extend k p,mce_extend k q,mce_extend k r)
    where k = maximum [mce_degree p,mce_degree q,mce_degree r]

-- | Apply a ternary node function.  If any input 'is_mce' then all
-- three are extended with 'mce_extend3' and the function is zipped
-- over the three lists, giving an 'MCE' node.
lift_mce3 :: (DF a -> DF b -> DF c -> DF d) -> DF a -> DF b -> DF c -> DF d
lift_mce3 f p q r
    | any_mce = let (ps,qs,rs) = mce_extend3 p q r
                in MCE (zipWith3 f ps qs rs)
    | otherwise = f p q r
    where any_mce = is_mce p || is_mce q || is_mce r

-- * Primitive constructors

-- | 'lift_mce' of 'P1'.  Arguments are the primitive name and the
-- result 'TypeRep'; an 'MCE' input is mapped over element-wise.
mk_p1 :: (K' a, K' b) => String -> TypeRep -> DF a -> DF b
mk_p1 nm ty = lift_mce (P1 nm ty)

-- | Unary operator.  The result 'TypeRep' is that of the input, as
-- given by 'df_typeOf'.
mk_uop :: (K' a) => String -> Unary_Op (DF a)
mk_uop nm p = mk_p1 nm (df_typeOf p) p

-- | 'lift_mce2' of 'P2'.  Arguments are the primitive name and the
-- result 'TypeRep'.
mk_p2 :: (K' a, K' b, K' c) => String -> TypeRep -> DF a -> DF b -> DF c
mk_p2 nm ty = lift_mce2 (P2 nm ty)

-- | Binary operator.  The result 'TypeRep' is that of the left hand
-- input, as given by 'df_typeOf'.
mk_binop :: K' a => String -> Binary_Op (DF a)
mk_binop nm p q = mk_p2 nm (df_typeOf p) p q

-- | 'lift_mce3' of 'P3'.  Arguments are the primitive name and the
-- result 'TypeRep'.
mk_p3 :: (K' a, K' b, K' c, K' d) => String -> TypeRep -> DF a -> DF b -> DF c -> DF d
mk_p3 nm ty = lift_mce3 (P3 nm ty)

-- | Ternary operator.  The result 'TypeRep' is that of the first
-- input, as given by 'df_typeOf'.
mk_ternaryop :: K' a => String -> Ternary_Op (DF a)
mk_ternaryop nm p q r = mk_p3 nm (df_typeOf p) p q r

-- | 'DF' multiply and add, primitive @df_mul_add@.  'df_add_optimise'
-- builds the same primitive with the two multiplier inputs first and
-- the addend last.
df_mul_add :: K_Num a => DF a -> DF a -> DF a -> DF a
df_mul_add = mk_ternaryop "df_mul_add"

-- | Optimising addition primitive.  If either input is directly a
-- @df_mul@ node, the addition is unfolded to a @df_mul_add@ node
-- holding the two multiplier inputs and the other operand.  The left
-- input is examined first.  Otherwise a plain @df_add@ node is made.
--
-- > df_add_optimise (2 * 3) (4::DF Int32)
-- > df_add_optimise (2::DF Int32) (3 * 4)
df_add_optimise :: K_Num a => DF a -> DF a -> DF a
df_add_optimise (P2 "df_mul" t l r) q = mk_p3 "df_mul_add" t l r q
df_add_optimise p (P2 "df_mul" t l r) = mk_p3 "df_mul_add" t l r p
df_add_optimise p q = mk_binop "df_add" p q

-- | 'Num' instance for nodes over any 'K_Num' type.  Operators build
-- primitive nodes, integer literals become 'K' constants.
instance K_Num a => Num (DF a) where
    -- Addition may fuse with a multiplier input, see 'df_add_optimise'.
    (+) = df_add_optimise
    (*) = mk_binop "df_mul"
    (-) = mk_binop "df_sub"
    negate = mk_uop "df_negate"
    abs = mk_uop "df_abs"
    signum = mk_uop "df_signum"
    fromInteger = K . fromInteger

-- | 'Fractional' instance for 'Float' nodes.  Division and reciprocal
-- are primitives with 'float_t' result, rational literals become 'K'
-- constants.
instance Fractional (DF Float) where
    (/) = mk_p2 "df_div" float_t
    recip = mk_p1 "df_recip" float_t
    fromRational = K . fromRational

-- | 'Floating' instance for 'Float' nodes.
--
-- 'exp', 'sqrt', 'log', '**', 'sin', 'tan' and 'cos' are primitives
-- with 'float_t' result.  'logBase', 'sinh', 'cosh' and 'tanh' are
-- derived from those primitives and the arithmetic operators.  The
-- inverse trigonometric and inverse hyperbolic functions have no
-- primitive in this module and raise an 'error' naming the method.
instance Floating (DF Float) where
  pi = K pi
  exp = mk_p1 "df_exp" float_t
  sqrt = mk_p1 "df_sqrt" float_t
  log = mk_p1 "df_log" float_t
  (**) = mk_p2 "df_pow" float_t
  -- Same definition as the default method of the 'Floating' class.
  logBase b x = log x / log b
  sin = mk_p1 "df_sin" float_t
  tan = mk_p1 "df_tan" float_t
  cos = mk_p1 "df_cos" float_t
  asin = error "Floating (DF Float): asin: not implemented"
  atan = error "Floating (DF Float): atan: not implemented"
  acos = error "Floating (DF Float): acos: not implemented"
  sinh x = (exp x - exp (negate x)) / 2
  -- Written as 1 - 2/(e^2x + 1) so that large inputs saturate to
  -- one instead of producing infinity divided by infinity.
  tanh x = 1 - 2 / (exp (2 * x) + 1)
  cosh x = (exp x + exp (negate x)) / 2
  asinh = error "Floating (DF Float): asinh: not implemented"
  atanh = error "Floating (DF Float): atanh: not implemented"
  acosh = error "Floating (DF Float): acosh: not implemented"

-- * Bits

-- | "Data.Bits" @.&.@, primitive @df_bw_and@ with 'int32_t' result.
df_bw_and :: DF Int32 -> DF Int32 -> DF Int32
df_bw_and = mk_p2 "df_bw_and" int32_t

-- | "Data.Bits" @.|.@, primitive @df_bw_or@ with 'int32_t' result.
df_bw_or :: DF Int32 -> DF Int32 -> DF Int32
df_bw_or = mk_p2 "df_bw_or" int32_t

-- | "Data.Bits" @complement@, primitive @df_bw_not@ with 'int32_t'
-- result.
df_bw_not :: DF Int32 -> DF Int32
df_bw_not = mk_p1 "df_bw_not" int32_t

-- * Ord

-- | '==', equal to.  Primitive @df_eq@ with 'bool_t' result.
df_eq :: K_Ord a => DF a -> DF a -> DF Bool
df_eq = mk_p2 "df_eq" bool_t

-- | '<', less than.  Primitive @df_lt@ with 'bool_t' result.
df_lt :: K_Ord a => DF a -> DF a -> DF Bool
df_lt = mk_p2 "df_lt" bool_t

-- | '>=', greater than or equal to.  Primitive @df_gte@ with 'bool_t'
-- result.
df_gte :: K_Ord a => DF a -> DF a -> DF Bool
df_gte = mk_p2 "df_gte" bool_t

-- | '>', greater than.  Primitive @df_gt@ with 'bool_t' result.
df_gt :: K_Ord a => DF a -> DF a -> DF Bool
df_gt = mk_p2 "df_gt" bool_t

-- | '<=', less than or equal to.  Primitive @df_lte@ with 'bool_t'
-- result.
df_lte :: K_Ord a => DF a -> DF a -> DF Bool
df_lte = mk_p2 "df_lte" bool_t

-- | 'max', select maximum.  Primitive @df_max@, the result type is
-- that of the left hand input (see 'mk_binop').
df_max :: K_Ord a => DF a -> DF a -> DF a
df_max = mk_binop "df_max"

-- | 'min', select minimum.  Primitive @df_min@, the result type is
-- that of the left hand input (see 'mk_binop').
df_min :: K_Ord a => DF a -> DF a -> DF a
df_min = mk_binop "df_min"

-- * Cast

-- | Cast floating point to integer.  Primitive @df_float_to_int32@
-- with 'int32_t' result.
df_float_to_int32 :: DF Float -> DF Int32
df_float_to_int32 = mk_p1 "df_float_to_int32" int32_t

-- | Cast integer to floating point.  Primitive @df_int32_to_float@
-- with 'float_t' result.
df_int32_to_float :: DF Int32 -> DF Float
df_int32_to_float = mk_p1 "df_int32_to_float" float_t

-- | Scale 'Int32' to (-1,1) normalised 'Float'.  The input is cast
-- with 'df_int32_to_float' and divided by the 'Int32' upper bound.
--
-- > maxBound == (2147483647::Int32)
i32_to_normal_f32 :: DF Int32 -> DF Float
i32_to_normal_f32 n = df_int32_to_float n / 2147483647

-- * Integral

-- | Integral modulo, ie. 'mod'.  Primitive @df_mod@ with 'int32_t'
-- result.
df_mod :: Binary_Op (DF Int32)
df_mod = mk_p2 "df_mod" int32_t

-- | Floating point modulo, ie. "Foreign.C.Math" /fmodf/.  Primitive
-- @df_fmodf@ with 'float_t' result.
df_fmodf :: Binary_Op (DF Float)
df_fmodf = mk_p2 "df_fmodf" float_t

-- * RealFrac

-- | ceilf(3).  Primitive @df_ceilf@ with 'float_t' result.
df_ceilf :: DF Float -> DF Float
df_ceilf = mk_p1 "df_ceilf" float_t

-- | floorf(3).  Primitive @df_floorf@ with 'float_t' result.
df_floorf :: DF Float -> DF Float
df_floorf = mk_p1 "df_floorf" float_t

-- | lrintf(3), ie. round to nearest integer.  Primitive @df_lrintf@
-- with 'int32_t' result.
df_lrintf :: DF Float -> DF Int32
df_lrintf = mk_p1 "df_lrintf" int32_t

-- | roundf(3).  Primitive @df_roundf@ with 'float_t' result.
df_roundf :: DF Float -> DF Float
df_roundf = mk_p1 "df_roundf" float_t

-- * Backward arcs

-- | Introduce backward arc with implicit unit delay.
--
-- The function receives the previous output as input, initially @y0@,
-- and returns a /(feed-forward,feed-backward)/ pair.  Both the input
-- node given to the function and the resulting node carry the same
-- identifier and the 'TypeRep' of @y0@.
--
-- It is an error if both elements of the pair are 'MCE' nodes.
--
-- > rec_r (R_Id 0) (0::Int32) ((\i->(i,i)) . (+) 1)
-- > rec_r (R_Id 0) (0.0::Float) ((\i->(i,i)) . (+) 1.0)
rec_r :: K' a => R_Id -> a -> (DF a -> (DF b,DF a)) -> DF b
rec_r n y0 f =
    let t = typeOf y0
        i = R n t (Left y0)
    in case f i of
         -- The message names this function; it is also reached via
         -- 'rec_m' and 'rec_h'.
         (MCE _,MCE _) -> error "rec_r: MCE"
         r -> R n t (Right r)

-- | Variant of 'rec_r' that draws the 'R_Id' from 'generateId'.
rec_m :: (K' a,UId m) => a -> (DF a -> (DF b,DF a)) -> m (DF b)
rec_m y0 f = generateId >>= \n -> return (rec_r (R_Id n) y0 f)

-- | Hash-eq variant of 'rec_r'.  The 'R_Id' is derived from the
-- 32-bit hash of the 'show' text of the function applied to the
-- constant initial value.
rec_h :: (K' a,Show b) => a -> (DF a -> (DF b,DF a)) -> DF b
rec_h y0 f = rec_r (R_Id n) y0 f
    where n = abs (fromIntegral (asWord32 (hash32 (show (f (K y0))))))

-- | Variant of 'rec_m' with monadic action in backward arc.  Unlike
-- 'rec_r' the resulting pair is not checked for 'MCE' nodes.
rec_mM :: (K' a,UId m) => a -> (DF a -> m (DF b,DF a)) -> m (DF b)
rec_mM i f = do
  n <- generateId
  let k = R_Id n
      t = typeOf i
  r <- f (R k t (Left i))
  return (R k t (Right r))

-- * Primitives

-- | Single channel input (channel 0).  Primitive @df_in1@ with no
-- inputs and 'float_t' result.
in1 :: DF Float
in1 = P0 "df_in1" float_t

-- | Single channel output (channel 0).  Primitive @df_out1@ with
-- 'nil_t' result.
out1 :: DF Float -> DF ()
out1 = mk_p1 "df_out1" nil_t

-- | Two channel output (channels 1 & 2).  Primitive @df_out2@ with
-- 'nil_t' result.
--
-- NOTE(review): 'out1' is documented as channel 0; confirm whether the
-- channel numbers here are meant to be zero based.
out2 :: DF Float -> DF Float -> DF ()
out2 = mk_p2 "df_out2" nil_t

-- | Three channel output.  Primitive @df_out3@ with 'nil_t' result.
out3 :: DF Float -> DF Float -> DF Float -> DF ()
out3 = mk_p3 "df_out3" nil_t

-- | MCE collapsing output.  An 'MCE' of one, two or three elements is
-- sent to 'out1', 'out2' or 'out3' respectively, an 'MCE' of any
-- other size is an error, and any other node is sent to 'out1'.
out :: DF Float -> DF ()
out (MCE [p]) = out1 p
out (MCE [p,q]) = out2 p q
out (MCE [p,q,r]) = out3 p q r
out (MCE _) = error "out: MCE"
out n = out1 n

-- | Single control input.  Primitive @df_ctl1@ taking an 'Int32' node
-- and giving a 'float_t' result.
ctl1 :: DF Int32 -> DF Float
ctl1 = mk_p1 "df_ctl1" float_t

-- | Logical '&&'.  Primitive @df_and@ with 'bool_t' result.
df_and :: DF Bool -> DF Bool -> DF Bool
df_and = mk_p2 "df_and" bool_t

-- | Logical '||'.  Primitive @df_or@ with 'bool_t' result.
df_or :: DF Bool -> DF Bool -> DF Bool
df_or = mk_p2 "df_or" bool_t

-- | Logical 'not'.  Primitive @df_not@ with 'bool_t' result.
df_not :: DF Bool -> DF Bool
df_not = mk_p1 "df_not" bool_t

-- | If /p/ then /q/ else /r/.  /p/ must have type bool, and /q/
-- and /r/ must have equal types.  The result 'TypeRep' is taken from
-- /q/.
select2 :: K' a => DF Bool -> DF a -> DF a -> DF a
select2 p q r = mk_p3 "df_select2" (df_typeOf q) p q r

-- | Operating sample rate.  Primitive @df_sample_rate@ with no inputs
-- and 'float_t' result.
w_sample_rate :: DF Float
w_sample_rate = P0 "df_sample_rate" float_t

-- | Number of frames in current control period.  Primitive
-- @df_kr_nframes@ with no inputs and 'int32_t' result.
w_kr_nframes :: DF Int32
w_kr_nframes = P0 "df_kr_nframes" int32_t

-- | 'True' at first frame of each control period.  Primitive
-- @df_kr_edge@ with no inputs and 'bool_t' result.
w_kr_edge :: DF Bool
w_kr_edge = P0 "df_kr_edge" bool_t

-- | Buffer read, read from buffer /p/ at index /q/.  Primitive
-- @df_b_read@ with 'float_t' result.
b_read :: DF Int32 -> DF Int32 -> DF Float
b_read = mk_p2 "df_b_read" float_t

-- | Buffer write, write to buffer /p/ at index /q/ value /r/.
-- Primitive @df_b_write@ with 'nil_t' result.
b_write :: DF Int32 -> DF Int32 -> DF Float -> DF ()
b_write = mk_p3 "df_b_write" nil_t

-- | Array read.  Primitive @df_a_read@ taking an array node and an
-- 'Int32' node (presumably the index -- TODO confirm), with 'float_t'
-- result.
a_read :: DF (Vec Float)-> DF Int32 -> DF Float
a_read = mk_p2 "df_a_read" float_t

-- | Array write.  Primitive @df_a_write@ taking an array node, an
-- 'Int32' node and a 'Float' node (presumably index and value -- TODO
-- confirm), with 'nil_t' result.
a_write :: DF (Vec Float) -> DF Int32 -> DF Float -> DF ()
a_write = mk_p3 "df_a_write" nil_t

-- * Untyped

-- | Transform typed 'DF' to un-typed 'UDF'.
--
-- Constants and the initial value of a recursion are converted with
-- 'to_k', primitives keep their name and 'TypeRep' and carry their
-- erased inputs as a list, and 'MRG' erases both sides.  It is an
-- error to erase an 'MCE' node.
df_erase :: K' a => DF a -> UDF
df_erase (K i) = UDF_K (to_k i)
df_erase (A a) = UDF_A a
df_erase (R k _ (Left i)) = UDF_R k (Left (to_k i))
df_erase (R k _ (Right (i,j))) = UDF_R k (Right (df_erase i,df_erase j))
df_erase (P0 nm t) = UDF_P nm t []
df_erase (P1 nm t i) = UDF_P nm t [df_erase i]
df_erase (P2 nm t i j) = UDF_P nm t [df_erase i,df_erase j]
df_erase (P3 nm t i j k) = UDF_P nm t [df_erase i,df_erase j,df_erase k]
df_erase (MCE _) = error "df_erase: MCE"
df_erase (MRG i j) = UDF_MRG (df_erase i) (df_erase j)