module Sound.DF.Node where

import Control.Monad
import Data.Unique

-- * The Node data type

-- | Recursion identifier.  The two 'R' nodes built for one backward
-- arc (see 'rec_r') carry the same identifier.
data R_ID
    = R_ID Int
    deriving (Eq)

-- | The kinds of value that may be carried on a port.
data Type
    = Real_Type
    | Integer_Type
    | Boolean_Type
    deriving (Eq, Show)

-- | A literal value, held by an 'S' node or by the constant side of
-- an 'R' node.
data Constant
    = Real_Constant Double
    | Integer_Constant Int
    deriving (Eq)

-- | A constant is displayed as its bare numeric value, with no
-- constructor name.
instance Show Constant where
    show c =
        case c of
          Real_Constant d -> show d
          Integer_Constant k -> show k

-- | Port meta data: one entry per output of an 'A' node.
data Port = Port
    { port_data_type :: Type
      -- ^ Type of the data carried on the port.
    , port_tokens :: Int
      -- ^ Token count.  Every port built in this module uses 1; the
      -- precise meaning is not shown here (TODO confirm).
    }
    deriving (Eq)

-- | Data flow node.
data Node
    -- | A constant source.
    = S { constant :: Constant }
    -- | Application of the named 'operator' to its 'inputs', with one
    -- 'Port' describing each output.
    | A { operator :: String
        , inputs :: [Node]
        , outputs :: [Port] }
    -- | One end of a backward arc, tagged by 'identifier'.  A 'Left'
    -- 'input' holds a constant; a 'Right' 'input' holds the pair
    -- returned by the recursion body (see 'rec_r'), whose first
    -- component determines the node's type.
    | R { identifier :: R_ID
        , input :: Either Constant (Node, Node) }
    -- | Selects output 'port' (a zero-based index into 'outputs') of
    -- the multiple-output node 'proxy'.
    | P { proxy :: Node
        , port :: Int }
    -- | Two graphs held together as a multiple-root graph; see 'mrg'.
    | M { mleft :: Node
        , mright :: Node }
    deriving (Eq)

-- | Node identifier (a plain 'Int' alias; not a distinct type).
type NodeID = Int

-- | Port identifier (a plain 'Int' alias; not a distinct type).
type PortID = Int

-- | Join two graphs into one multiple-root graph.  This is just the
-- 'M' constructor under a function name.
mrg :: Node -> Node -> Node
mrg l r = M l r

-- | Compact, single-line rendering of a node.  An 'A' node shows only
-- its operator name; its inputs are not printed.  'R' nodes render
-- as @rR_<id>:<constant>@ (constant side) or @wR_<id>@ (other side).
instance Show Node where
    show n =
        case n of
          S c -> show c
          A name _ _ -> name
          R (R_ID k) side ->
              case side of
                Left c -> "rR_" ++ show k ++ ":" ++ show c
                Right _ -> "wR_" ++ show k
          P _ ix -> "proxy_" ++ show ix
          M a b -> "m(" ++ show a ++ "," ++ show b ++ ")"

-- * Querying data type on ports

-- | The 'Type' corresponding to a 'Constant' constructor.
constant_type :: Constant -> Type
constant_type c =
    case c of
      Real_Constant _ -> Real_Type
      Integer_Constant _ -> Integer_Type

-- | Type of a node.
--
-- * 'S', and the constant side of 'R', take the type of the constant.
--
-- * The other side of 'R' takes the type of the first node of its pair.
--
-- * 'A' must have exactly one output port, whose type is returned.
--
-- * 'P' takes the type of the selected output port of its proxied 'A'
--   node.  A non-'A' proxy, or an out-of-range port index, is an error
--   with a descriptive message.
--
-- * 'M' takes the type of its left graph.
node_type :: Node -> Type
node_type (S c) = constant_type c
node_type (A _ _ [Port t _]) = t
node_type (A _ _ _) = error "node_type: A: non unary output"
node_type (R _ (Left c)) = constant_type c
node_type (R _ (Right (n, _))) = node_type n
node_type (P n i) =
    -- Match explicitly rather than using the partial 'outputs' selector
    -- and '!!', so that misuse reports what actually went wrong.
    case n of
      A _ _ ps
          | i >= 0, (Port t _ : _) <- drop i ps -> t
          | otherwise ->
              error ("node_type: P: port index out of range: " ++ show i)
      _ -> error "node_type: P: proxy of non-A node"
node_type (M l _) = node_type l

-- * Numeric primitives for class instances

-- | Lift a 'Double' to a constant ('S') node of 'Real_Type'.
n_real_constant :: Double -> Node
n_real_constant x = S (Real_Constant x)

-- | Lift an 'Int' to a constant ('S') node of 'Integer_Type'.
n_integer_constant :: Int -> Node
n_integer_constant k = S (Integer_Constant k)

-- | Unary operator over Real and Integer values.  The single output
-- port takes the operand's type.  Note that the operand type is not
-- checked here, so a Boolean operand is accepted as well.
numerical_unary_operator :: String -> Node -> Node
numerical_unary_operator s p =
    let t = node_type p
    in A s [p] [Port t 1]

-- | Binary operator over Real and Integer values.  Both operands must
-- have the same type, which is also the type of the single output
-- port; mismatched operands raise an error describing both sides.
-- Only type equality is checked, so two Boolean operands are accepted.
numerical_binary_operator :: String -> Node -> Node -> Node
numerical_binary_operator s p q
    | pt == qt = A s [p, q] [Port pt 1]
    | otherwise = error (show ("binary operator", s, pt, qt, p, q))
    where
      pt = node_type p
      qt = node_type q

-- | Unary operator over Real values.  A non-Real operand is an error;
-- the result is a single Real output port.
real_unary_operator :: String -> Node -> Node
real_unary_operator s p
    | node_type p /= Real_Type = error (show ("real unary operator", s, p))
    | otherwise = A s [p] [Port Real_Type 1]

-- | Binary operator over Real values.  Both operands must be Real; the
-- result is a single Real output port.
real_binary_operator :: String -> Node -> Node -> Node
real_binary_operator s p q
    | is_real p && is_real q = A s [p, q] [Port Real_Type 1]
    | otherwise = error (show ("real binary operator", s, p, q))
    where
      is_real n = node_type n == Real_Type
    
-- | Addition (@df_add@); both operands must share a type.
n_add :: Node -> Node -> Node
n_add p q = numerical_binary_operator "df_add" p q

-- | Multiplication (@df_mul@); both operands must share a type.
n_mul :: Node -> Node -> Node
n_mul p q = numerical_binary_operator "df_mul" p q

-- | Subtraction (@df_sub@); both operands must share a type.
n_sub :: Node -> Node -> Node
n_sub p q = numerical_binary_operator "df_sub" p q

-- | Negation (@df_negate@); the result has the operand's type.
n_negate :: Node -> Node
n_negate p = numerical_unary_operator "df_negate" p

-- | Absolute value.  Real and Integer operands map to distinct
-- primitives (@df_fabs@ and @df_iabs@) with a result of the same
-- type; any other operand type is an error.
n_abs :: Node -> Node
n_abs p =
    case node_type p of
      Real_Type -> A "df_fabs" [p] [Port Real_Type 1]
      Integer_Type -> A "df_iabs" [p] [Port Integer_Type 1]
      _ -> error "n_abs"

-- | Sign of a value (@df_signum@); the result has the operand's type.
n_signum :: Node -> Node
n_signum p = numerical_unary_operator "df_signum" p

-- | Arithmetic on nodes builds operator nodes.  Integer literals become
-- 'Integer_Constant' nodes, so (given the type check in
-- 'numerical_binary_operator') they combine only with Integer-typed
-- nodes; Real literals must be written with a decimal point.
instance Num Node where
    a + b = n_add a b
    a * b = n_mul a b
    a - b = n_sub a b
    negate = n_negate
    abs = n_abs
    signum = n_signum
    fromInteger i = n_integer_constant (fromInteger i)

-- | Division (@df_div@) of two Real operands.
n_div :: Node -> Node -> Node
n_div p q = real_binary_operator "df_div" p q

-- | Reciprocal (@df_recip@) of a Real operand.
n_recip :: Node -> Node
n_recip p = real_unary_operator "df_recip" p

-- | Division on Real nodes; fractional literals become
-- 'Real_Constant' nodes.
instance Fractional Node where
    a / b = n_div a b
    recip = n_recip
    fromRational r = n_real_constant (fromRational r)

-- | Natural exponential (@df_exp@) of a Real operand.
n_exp :: Node -> Node
n_exp p = real_unary_operator "df_exp" p

-- | Square root (@df_sqrt@) of a Real operand.
n_sqrt :: Node -> Node
n_sqrt p = real_unary_operator "df_sqrt" p

-- | Natural logarithm (@df_log@) of a Real operand.
n_log :: Node -> Node
n_log p = real_unary_operator "df_log" p

-- | @p@ raised to the power @q@ (@df_pow@); both operands Real.
n_pow :: Node -> Node -> Node
n_pow p q = real_binary_operator "df_pow" p q

-- | Sine (@df_sin@) of a Real operand.
n_sin :: Node -> Node
n_sin p = real_unary_operator "df_sin" p

-- | Cosine (@df_cos@) of a Real operand.
n_cos :: Node -> Node
n_cos p = real_unary_operator "df_cos" p

-- | Tangent (@df_tan@) of a Real operand.
n_tan :: Node -> Node
n_tan p = real_unary_operator "df_tan" p

-- | Floating point operations on nodes.  The operations taking node
-- arguments require 'Real_Type' operands (via 'real_unary_operator'
-- and 'real_binary_operator').  'logBase' is built from 'n_log' and
-- 'n_div'.  The inverse trigonometric and the hyperbolic functions have
-- no primitive in this module and raise a descriptive error if used
-- (previously a bare 'undefined').
instance Floating Node where
  pi = n_real_constant pi
  exp = n_exp
  sqrt = n_sqrt
  log = n_log
  (**) = n_pow
  -- Prelude argument order: logBase base x = log x / log base.
  logBase b x = n_div (n_log x) (n_log b)
  sin = n_sin
  tan = n_tan
  cos = n_cos
  asin _ = error "Floating Node: asin is not supported"
  atan _ = error "Floating Node: atan is not supported"
  acos _ = error "Floating Node: acos is not supported"
  sinh _ = error "Floating Node: sinh is not supported"
  tanh _ = error "Floating Node: tanh is not supported"
  cosh _ = error "Floating Node: cosh is not supported"
  asinh _ = error "Floating Node: asinh is not supported"
  atanh _ = error "Floating Node: atanh is not supported"
  acosh _ = error "Floating Node: acosh is not supported"

-- | Comparison operator from Real or Integer values to a Boolean value.
-- Both operands must have the same numeric (Integer or Real) type;
-- otherwise an error describing both sides is raised.  The result is a
-- single Boolean output port.
numerical_comparison_operator :: String -> Node -> Node -> Node
numerical_comparison_operator s p q
    | pt == qt && numeric pt = A s [p, q] [Port Boolean_Type 1]
    | otherwise = error (show ("comparison operator", s, pt, qt, p, q))
    where
      pt = node_type p
      qt = node_type q
      numeric t = t `elem` [Integer_Type, Real_Type]

-- | Less than (@df_lt@); numeric operands of equal type, Boolean result.
n_lt :: Node -> Node -> Node
n_lt p q = numerical_comparison_operator "df_lt" p q

-- | Greater than or equal to (@df_gte@); numeric operands of equal
-- type, Boolean result.
n_gte :: Node -> Node -> Node
n_gte p q = numerical_comparison_operator "df_gte" p q

-- | Greater than (@df_gt@); numeric operands of equal type, Boolean
-- result.
n_gt :: Node -> Node -> Node
n_gt p q = numerical_comparison_operator "df_gt" p q

-- | Less than or equal to (@df_lte@); numeric operands of equal type,
-- Boolean result.
n_lte :: Node -> Node -> Node
n_lte p q = numerical_comparison_operator "df_lte" p q

-- | Maximum (@df_max@); both operands must share a type.
n_max :: Node -> Node -> Node
n_max p q = numerical_binary_operator "df_max" p q

-- | Minimum (@df_min@); both operands must share a type.
n_min :: Node -> Node -> Node
n_min p q = numerical_binary_operator "df_min" p q

-- | Only 'max' and 'min' are supported; they build @df_max@ and
-- @df_min@ operator nodes.  The 'Ordering'- and 'Bool'-valued methods
-- are not defined for nodes and raise a descriptive error (previously
-- a bare 'undefined').  The node-building comparisons are 'n_lt',
-- 'n_lte', 'n_gt' and 'n_gte'.
instance Ord Node where
  compare _ _ = error "Ord Node: compare is not supported"
  _ < _ = error "Ord Node: (<) is not supported, use n_lt"
  _ >= _ = error "Ord Node: (>=) is not supported, use n_gte"
  _ > _ = error "Ord Node: (>) is not supported, use n_gt"
  _ <= _ = error "Ord Node: (<=) is not supported, use n_lte"
  max = n_max
  min = n_min

-- | Real valued floor (@df_floor@): Real operand, Real result.
n_floor :: Node -> Node
n_floor p = real_unary_operator "df_floor" p

-- | Real to Integer conversion (@df_lrint@): a Real operand yields an
-- Integer-typed node; any other operand type is an error.
-- NOTE(review): originally described as an "integer valued floor", but
-- C's @lrint@ rounds using the current rounding mode (nearest by
-- default); confirm what @df_lrint@ actually does.
n_lrint :: Node -> Node
n_lrint p =
    case node_type p of
      Real_Type -> A "df_lrint" [p] [Port Integer_Type 1]
      _ -> error "n_lrint"

{-
class (Real a, Fractional a) => RealFrac a where
  properFraction :: (Integral b) => a -> (b, a)
  truncate :: (Integral b) => a -> b
  round :: (Integral b) => a -> b
  ceiling :: (Integral b) => a -> b
  floor :: (Integral b) => a -> b
-}

-- * Class of monads generating identifiers

-- | Monads that can produce fresh 'Int' identifiers, used by 'rec' and
-- 'recm' to tag backward arcs.
class (Monad m) => ID m where
    generateID :: m Int

-- | In 'IO', an identifier is the 'hashUnique' of a fresh 'newUnique'.
-- NOTE(review): "Data.Unique" does not guarantee that distinct
-- 'Unique' values hash to distinct 'Int's; confirm this is acceptable.
instance ID IO where
    generateID = fmap hashUnique newUnique

-- * Backward arcs

-- | Introduce backward arc with implicit unit delay.  The body @f@ is
-- applied to an 'R' node carrying the constant @i@ ('Left' side); the
-- pair it returns is wrapped in an 'R' node ('Right' side) with the
-- same identifier @n@.
rec_r :: R_ID -> Constant -> (Node -> (Node, Node)) -> Node
rec_r n i f =
    let back = R n (Left i)
    in R n (Right (f back))

-- | Monadic variant of 'rec_r' that draws a fresh identifier from the
-- 'ID' monad.
rec :: ID m => Constant -> (Node -> (Node, Node)) -> m Node
rec i f = generateID >>= \n -> return (rec_r (R_ID n) i f)

-- | Variant of 'rec' whose body is a monadic action.  A fresh
-- identifier is drawn first; the body then receives the constant-side
-- 'R' node, and its result is wrapped in the other-side 'R' node.
recm :: ID m => Constant -> (Node -> m (Node, Node)) -> m Node
recm i f = do
    k <- generateID
    let rid = R_ID k
    body <- f (R rid (Left i))
    return (R rid (Right body))