module Sound.DF.UGen where

import Control.Monad
import Sound.DF.Node

-- * Primitive unit generators

-- | Primitive operator @s@ whose inputs must all have type @t@.  The
-- resulting node has @n@ output ports, each @Port t 1@.  Any input of a
-- different type is a fatal error.
uniform_operator :: Type -> Int -> String -> [Node] -> Node
uniform_operator t n s ps
    | any wrong_type ps = error (show ("output operator", ps))
    | otherwise = A s ps (replicate n (Port t 1))
    where wrong_type p = node_type p /= t

-- | Single channel output: a sink taking one 'Real_Type' signal and
-- producing no output ports.
out1 :: Node -> Node
out1 = uniform_operator Real_Type 0 "df_out1" . (: [])

-- | Two channel output: a sink taking a pair of 'Real_Type' signals (in
-- tuple order) and producing no output ports.
out2 :: (Node, Node) -> Node
out2 (l, r) = sink [l, r]
    where sink = uniform_operator Real_Type 0 "df_out2"

-- | Three channel output: a sink taking a triple of 'Real_Type' signals
-- (in tuple order) and producing no output ports.
out3 :: (Node, Node, Node) -> Node
out3 (a, b, c) = sink [a, b, c]
    where sink = uniform_operator Real_Type 0 "df_out3"

-- | Operating sample rate: a source node with no inputs and a single
-- 'Real_Type' output port.
sample_rate :: Node
sample_rate =
    let port = Port Real_Type 1
    in A "df_sample_rate" [] [port]

-- | Equal to: a numerical comparison built from the @df_eq@ primitive.
eq :: Node -> Node -> Node
eq p q = numerical_comparison_operator "df_eq" p q

-- | If @p@ then @q@ else @r@.  The condition must be 'Boolean_Type' and
-- both branches must share a type, which becomes the output type; any
-- other combination is a fatal error.
select2 :: Node -> Node -> Node -> Node
select2 p q r
    | cond_ok && branches_agree = A "df_select2" [p, q, r] [Port (node_type q) 1]
    | otherwise = error (show ("select2", p, q, r))
    where cond_ok = node_type p == Boolean_Type
          branches_agree = node_type q == node_type r

-- | Binary boolean valued operator named @s@.  Both inputs must be
-- 'Boolean_Type' (a fatal error otherwise); the single output port is
-- also 'Boolean_Type'.
logical_operator :: String -> Node -> Node -> Node
logical_operator s p q
    | all is_bool [p, q] = A s [p, q] [Port Boolean_Type 1]
    | otherwise = error (show ("logical operator", s, p, q))
    where is_bool x = node_type x == Boolean_Type
    
-- | Logical and of two boolean signals (primitive @df_and@).
n_and :: Node -> Node -> Node
n_and p q = logical_operator "df_and" p q

-- | Logical or of two boolean signals (primitive @df_or@).
n_or :: Node -> Node -> Node
n_or p q = logical_operator "df_or" p q

-- | Buffer read producing one 'Real_Type' value.  Both inputs must be
-- 'Integer_Type' (a fatal error otherwise).  As used by 'delay', @p@ is
-- the buffer and @q@ the index.
b_read :: Node -> Node -> Node
b_read p q
    | all is_int [p, q] = A "df_b_read" [p, q] [Port Real_Type 1]
    | otherwise = error (show ("b_read", p, q))
    where is_int x = node_type x == Integer_Type

-- | Buffer write with no output ports.  The first two inputs must be
-- 'Integer_Type' and the third 'Real_Type' (a fatal error otherwise).
-- As used by 'delay', they are the buffer, the index and the value.
b_write :: Node -> Node -> Node -> Node
b_write p q r
    | all is_int [p, q] && node_type r == Real_Type = A "df_b_write" [p, q, r] []
    | otherwise = error (show ("b_write", p, q, r))
    where is_int x = node_type x == Integer_Type
 
-- | White noise in (0, 1) from the @df_random@ primitive, with one
-- 'Real_Type' output.  The argument is presumably a seed/identity for
-- the source; 'white_noise_m' supplies a fresh integer constant.
white_noise_u :: Node -> Node
white_noise_u seed = A "df_random" [seed] [Port Real_Type 1]

-- * Ordinary unit generators

-- | Linear pan.  @pan2 p q@ splits signal @p@ into a pair of channels
-- with gains @q@ and @1 - q@.  For @q@ in (0, 1) both gains are
-- non-negative and sum to one, so neither channel is polarity-inverted.
pan2 :: Node -> Node -> (Node, Node)
pan2 p q = (p * q, p * (1.0 - q))

-- | Reversed tuple constructor: @swap x y == (y, x)@, i.e. @flip (,)@.
swap :: a -> b -> (b, a)
swap x y = (y, x)

-- | Duplicate a value into both slots of a pair.
split :: a -> (a, a)
split = \x -> (x, x)

-- | Single sample delay with initial value @y0@.  The function given to
-- 'rec' places the previous value @y1@ in the first slot of its pair and
-- the current input @i@ in the second.
unit_delay :: ID m => Constant -> Node -> m Node
unit_delay y0 i = rec y0 (\y1 -> (y1, i))

-- | Single place infinite impulse response filter with initial value
-- @y0@.  Each step computes @f x y1@, where @x@ is the input @i@ and
-- @y1@ the previous value (initially @y0@).  The result fills both slots
-- of the pair handed to 'rec'.
iir1 :: ID m => Constant -> (Node -> Node -> Node) -> Node -> m Node
iir1 y0 f i = rec y0 (\y1 -> let y = f i y1 in (y, y))

-- | Two place infinite impulse response filter, zero initial state.
-- @f x y1 y2@ receives the input, the previous value, and that previous
-- value delayed once more through 'unit_delay'.  The result fills both
-- slots of the pair returned to 'recm'.
iir2 :: ID m => (Node -> Node -> Node -> Node) -> Node -> m Node
iir2 f i = recm (Real_Constant 0) step
    where step y1 = do y2 <- unit_delay (Real_Constant 0) y1
                       let y0 = f i y1 y2
                       return (y0, y0)

-- | Single place finite impulse response filter: @f x x1@, where @x1@ is
-- the input delayed by one sample (initially zero).
fir1 :: ID m => (Node -> Node -> Node) -> Node -> m Node
fir1 f i = liftM (f i) (unit_delay (Real_Constant 0) i)

-- | Two place finite impulse response filter: @f x x1 x2@, where @x1@
-- and @x2@ are the input delayed by one and two samples (initially zero).
fir2 :: ID m => (Node -> Node -> Node -> Node) -> Node -> m Node
fir2 f i = do
    x1 <- z1 i
    x2 <- z1 x1
    return (f i x1 x2)
    where z1 x = unit_delay (Real_Constant 0) x

-- | Ordinary biquad filter section, zero initial state.  @f@ receives
-- the input, its one- and two-sample delays, then the previous value and
-- its one-sample delay: @f x x1 x2 y1 y2@.
biquad :: ID m => (Node -> Node -> Node -> Node -> Node -> Node) -> Node -> m Node
biquad f i = recm (Real_Constant 0) step
    where z1 x = unit_delay (Real_Constant 0) x
          step y1 = do x1 <- z1 i
                       x2 <- z1 x1
                       y2 <- z1 y1
                       let y0 = f i x1 x2 y1 y2
                       return (y0, y0)

-- | Running sum of the input, starting from @y0@ (an 'iir1' whose step
-- adds the input to the previous value).
counter :: ID m => Constant -> Node -> m Node
counter y0 i = iir1 y0 (\x y1 -> x + y1) i

-- | Environment value: radians advanced per sample by a 1 hz signal,
-- i.e. @'two_pi' / 'sample_rate'@.
radians_per_sample :: Node
radians_per_sample =
    let sr = sample_rate
    in two_pi / sr

-- | Per-sample phase increment for frequency @hz@ when one cycle spans
-- @r@ units (e.g. 'two_pi') and the sample rate is @sr@.
hz_to_incr :: Node -> Node -> Node -> Node
hz_to_incr r hz sr =
    let per_hz = r / sr
    in per_hz * hz

-- | Two pi: one full cycle in radians.
two_pi :: Floating a => a
two_pi =
    let two = 2.0
    in two * pi

-- | Single-step wrap: @q - p@ when @q >= p@, otherwise @q@ unchanged.
-- Only one period is subtracted, so a @q@ at or beyond @2 * p@ is not
-- fully wrapped.
clipr :: Node -> Node -> Node
clipr p q =
    let wrapped = q - p
        at_or_over = q `n_gte` p
    in select2 at_or_over wrapped q

-- | Wrapping accumulator.  @ip@ = initial phase, @r@ = right hand edge,
-- and the third argument is the per-sample increment.  Each step adds
-- the increment to the previous phase and wraps once with 'clipr'.
phasor :: ID m => Constant -> Node -> Node -> m Node
phasor ip r incr = iir1 ip wrap_add incr
    where wrap_add x y1 = clipr r (x + y1)

-- | Sine oscillator.  @f@ = frequency in hz, @ip@ = initial phase in
-- radians (the phase wraps at 'two_pi').
sin_osc :: ID m => Node -> Double -> m Node
sin_osc f ip = liftM sin (phasor (Real_Constant ip) two_pi incr)
    where incr = hz_to_incr two_pi f sample_rate

-- | Non-band limited sawtooth oscillator.  A phase wrapping at 2.0 is
-- shifted down by 1.0 to span [-1, 1); @f@ = frequency in hz,
-- @ip@ = initial phase in [0, 2).
lf_saw :: ID m => Node -> Double -> m Node
lf_saw f ip = liftM (subtract 1.0) (phasor (Real_Constant ip) 2.0 incr)
    where incr = hz_to_incr 2.0 f sample_rate

-- | Non-band limited pulse oscillator.  Outputs 1.0 while the
-- unit-range phase is below the width @w@ (in (0, 1)) and 0.0 from @w@
-- to the end of the cycle.
lf_pulse :: ID m => Node -> Double -> Node -> m Node
lf_pulse f ip w = liftM gate (phasor (Real_Constant ip) 1.0 incr)
    where incr = hz_to_incr 1.0 f sample_rate
          gate p = select2 (p `n_gte` w) 0.0 1.0

-- | Midi note number to cycles per second (note 69 = 440 hz, twelve
-- equal-tempered steps per octave).
midi_cps :: Floating a => a -> a
midi_cps a =
    let semitones = a - 69.0
        octaves = semitones * (1.0 / 12.0)
    in 440.0 * (2.0 ** octaves)

-- | Multiply and add: @i * m + a@.
mul_add :: Num a => a -> a -> a -> a
mul_add i m = (+) (i * m)

-- | Comb filter feedback coefficient.  A signal recirculating once every
-- @delayt@ seconds with this gain falls to 0.001 (-60 dB) after
-- @decayt@ seconds, since @calc_fb delayt decayt ** (decayt / delayt)@
-- is @0.001@.
calc_fb :: Floating a => a -> a -> a
calc_fb delayt decayt =
    let scaled_log = log 0.001 * delayt
    in exp (scaled_log / decayt)

-- | Delay line using buffer @b@.  Input @s@ is written at an integer
-- index from a 'phasor' starting at 0 with increment 1 and wrapping at
-- @n@.  The output is read from the index one ahead of the write
-- position (also wrapped at @n@).  Presumably that slot holds the oldest
-- sample, giving about @n@ samples of delay.
delay :: ID m => Node -> Node -> Node -> m Node
delay b s n = liftM read_write (phasor (Integer_Constant 0) n 1)
    where read_write wi = mrg (b_read b (clipr n (wi + 1))) (b_write b wi s)

-- | Comb filter over buffer @b@: input @s@ plus the 'delay'ed recursion
-- scaled by 'calc_fb'.  @dlt@ = delay time in seconds (rounded to whole
-- samples with 'n_lrint'); @dct@ = decay time in seconds.
--
-- NOTE(review): the value written into the delay line is the previous
-- value supplied by 'recm', so the effective delay may be one sample
-- longer than @dlt@ — confirm against the intended semantics of 'recm'.
buf_comb_n :: ID m => Node -> Node -> Node -> Node -> m Node
buf_comb_n b s dlt dct = recm (Real_Constant 0) feedback
    where n = n_lrint (dlt * sample_rate)
          fb = calc_fb dlt dct
          feedback y = liftM (\x -> split (s + (fb * x))) (delay b y n)

-- | Resonant low pass filter.  @i@ = input, @f@ = cutoff frequency in hz,
-- @r@ = resonance parameter, clamped below at 0.001.  Larger @r@ damps
-- the peak, so it behaves like a reciprocal Q.
--
-- Two stages:
--
-- * an all-pole recursion @y0 = a0 * x + b1 * y1 + b2 * y2@;
--
-- * the zeros @y0 + 2 * y1 + y2@, applied to that recursion's own
--   history.
--
-- The recursion alone has a DC gain of
-- @a0 / (1 - b1 - b2) = 0.25@.  The zeros stage (gain 4 at DC) restores
-- unity pass-band gain.
rlpf :: ID m => Node -> Node -> Node -> m Node
rlpf i f r =
    let qr = max 0.001 r
        pf = f * radians_per_sample
        d = tan (pf * qr * 0.5)
        c = (1.0 - d) / (1.0 + d)
        b1 = (1.0 + c) * cos pf
        b2 = negate c
        a0 = (1.0 + c - b1) * 0.25
        -- Feedback stage: this value is also the recursion's state.
        poles x y1 y2 = a0 * x + b1 * y1 + b2 * y2
        -- Feed-forward stage over the pole-stage history (not fed back).
        zeros y0 y1 y2 = y0 + 2.0 * y1 + y2
    in iir2 poles i >>= fir2 zeros

-- | Constrain @p@ to the interval [-q, q] (@max@ against @-q@, then
-- @min@ against @q@).
clip2 :: Node -> Node -> Node
clip2 p q =
    let floored = max p (negate q)
    in min q floored

-- | White noise in (-1, 1): 'white_noise_u' rescaled from (0, 1).
white_noise :: Node -> Node
white_noise p =
    let u = white_noise_u p
    in u * 2.0 - 1.0

-- | Monadic 'white_noise' in (-1, 1).  Its source argument is an integer
-- constant made from a fresh 'generateID', presumably so that separate
-- calls give independent noise.
white_noise_m :: ID m => m Node
white_noise_m = liftM (white_noise . n_integer_constant) generateID

-- | Brown noise in (-1, 1): a random walk whose steps are
-- 'white_noise_m' scaled by 1/8.  Whenever the running sum passes 1 or
-- -1 it is reflected back into range.
brown_noise_m :: ID m => m Node
brown_noise_m = white_noise_m >>= (iir1 (Real_Constant 0) reflect . (/ 8.0))
    where reflect x y1 =
              let z = x + y1
                  low = select2 (z `n_lt` (-1.0)) ((-2.0) - z) z
              in select2 (z `n_gt` 1.0) (2.0 - z) low

-- | Two zero fixed midpass filter: @(x - x2) * 0.5@, where @x2@ is the
-- input two samples earlier.
bpz2 :: ID m => Node -> m Node
bpz2 = fir2 midpass
    where midpass x _ x2 = (x - x2) * 0.5

-- | Two zero fixed midcut filter: @(x + x2) * 0.5@, where @x2@ is the
-- input two samples earlier.
brz2 :: ID m => Node -> m Node
brz2 = fir2 midcut
    where midcut x _ x2 = (x + x2) * 0.5

-- | Two point average filter: the mean of the current and previous
-- input.
lpz1 :: ID m => Node -> m Node
lpz1 = fir1 average
    where average x x1 = (x + x1) * 0.5

-- | Two zero fixed lowpass filter: @(x + 2 * x1 + x2) * 0.25@ over the
-- current input and its one- and two-sample delays.
lpz2 :: ID m => Node -> m Node
lpz2 = fir2 lowpass
    where lowpass x x1 x2 = (x + (2.0 * x1) + x2) * 0.25

-- | One pole filter with coefficient @cf@:
-- @y = (1 - |cf|) * x + cf * y1@, zero initial state.
one_pole :: ID m => Node -> Node -> m Node
one_pole i cf = iir1 (Real_Constant 0) step i
    where step x y1 = ((1.0 - abs cf) * x) + (cf * y1)

-- | One zero filter with coefficient @cf@:
-- @y = (1 - |cf|) * x + cf * x1@, where @x1@ is the previous input.
one_zero :: ID m => Node -> Node -> m Node
one_zero i cf = fir1 step i
    where step x x1 = ((1.0 - abs cf) * x) + (cf * x1)

-- | Second order filter section with feed-forward coefficients @a0@,
-- @a1@, @a2@ and feedback coefficients @b1@, @b2@:
-- @y = a0*x + a1*x1 + a2*x2 + b1*y1 + b2*y2@.  Note that the feedback
-- terms are added, not subtracted.
sos :: ID m => Node -> Node -> Node -> Node -> Node -> Node -> m Node
sos i a0 a1 a2 b1 b2 = biquad section i
    where section x x1 x2 y1 y2 = a0*x + a1*x1 + a2*x2 + b1*y1 + b2*y2

-- | Impulse oscillator (non band limited).  @f@ = frequency in hz,
-- @ip@ = initial phase of a unit-range 'phasor'.  Outputs 1.0 on each
-- sample where that phase crosses 0.5 upward (previous phase below 0.5,
-- current at or above it) and 0.0 otherwise.
impulse :: ID m => Node -> Double -> m Node
impulse f ip = do
    phase <- phasor (Real_Constant ip) 1.0 (hz_to_incr 1.0 f sample_rate)
    prev <- unit_delay (Real_Constant 0) phase
    let rising = (prev `n_lt` 0.5) `n_and` (phase `n_gte` 0.5)
    return (select2 rising 1.0 0.0)

-- | Two pole resonant (band pass) filter.  @i@ = input, @f@ = centre
-- frequency in hz, @rq@ = bandwidth ratio.
--
-- Two stages:
--
-- * the recursion @y0 = x + b1 * y1 + b2 * y2@, with poles at radius
--   @r = 1 - ff * rq / 2@;
--
-- * the output @a0 * (y0 - y2)@ over that recursion's history.
--
-- Only @y0@ is fed back.  Feeding the scaled output back instead would
-- change the recurrence and move the poles away from radius @r@.
resonz :: ID m => Node -> Node -> Node -> m Node
resonz i f rq =
    let ff = f * radians_per_sample
        b = ff * rq
        r = 1.0 - b * 0.5
        two_r = 2.0 * r
        r2 = r * r
        ct = (two_r * cos ff) / (1.0 + r2)
        b1 = two_r * ct
        b2 = negate r2
        a0 = (1.0 - r2) * 0.5
        -- Feedback stage: this value is the recursion's state.
        poles x y1 y2 = x + b1 * y1 + b2 * y2
        -- Feed-forward stage over the pole-stage history (not fed back).
        zeros y0 _ y2 = a0 * (y0 - y2)
    in iir2 poles i >>= fir2 zeros

-- | Sample and hold.  On each trigger the output takes the current value
-- of @i@ and holds it until the next trigger; it is 0 before the first
-- trigger.  A trigger is a transition of @t@ from non-positive on the
-- previous sample to positive on the current one.
latch :: ID m => Node -> Node -> m Node
latch i t = do
    -- Previous trigger value (initially 0), for edge detection.  Testing
    -- only @t > 0@ would track the input for as long as @t@ stays
    -- positive rather than sampling once per trigger.
    t1 <- unit_delay (Real_Constant 0) t
    let triggered = (t `n_gt` 0.0) `n_and` (0.0 `n_gte` t1)
    iir1 (Real_Constant 0) (\x y1 -> select2 triggered x y1) i

-- | Linear range conversion: map @i@ from [@in_l@, @in_r@] onto
-- [@out_l@, @out_r@] with the affine map sending @in_l@ to @out_l@ and
-- @in_r@ to @out_r@.
lin_lin :: Fractional a => a -> a -> a -> a -> a -> a
lin_lin i in_l in_r out_l out_r = (i * scale) + offset
    where scale = (out_r - out_l) / (in_r - in_l)
          offset = out_l - (scale * in_l)

-- | Exponential range conversion: map @i@ from [@in_l@, @in_r@] onto
-- [@out_l@, @out_r@] so that equal input steps give equal output ratios.
-- @out_l@ and @out_r@ should be non-zero and share a sign.
lin_exp :: Floating a => a -> a -> a -> a -> a -> a
lin_exp i in_l in_r out_l out_r = out_l * (ratio ** (i * inv_span + shift))
    where ratio = out_r / out_l
          inv_span = 1.0 / (in_r - in_l)
          shift = inv_span * negate in_l

-- | Exponential decay.  Each output is the input plus the previous output
-- scaled by @b1 = exp (log 0.001 / (dt * sample_rate))@, so an impulse
-- falls to 0.001 (-60 dB) of its height after @dt@ seconds.
decay :: ID m => Node -> Node -> m Node
decay i dt = iir1 (Real_Constant 0) leak i
    where b1 = exp (log 0.001 / (dt * sample_rate))
          leak x y1 = x + b1 * y1

-- | Difference of two exponential decays of the same input, equivalent to
-- @decay i dcy - decay i atk@.
decay2 :: ID m => Node -> Node -> Node -> m Node
decay2 i atk dcy = do
    with_dcy <- decay i dcy
    with_atk <- decay i atk
    return (with_dcy - with_atk)

-- | Single sample delay, initially zero.
--
-- Written directly with 'unit_delay'.  An 'iir1' whose step just returns
-- the previous value would feed back only its own initial value, so it
-- would output a constant and ignore the input entirely.
delay1 :: ID m => Node -> m Node
delay1 = unit_delay (Real_Constant 0)

-- | Two sample delay, initially zero: two 'unit_delay's in series.
--
-- Not built with 'iir2': its step result is what gets fed back, so a
-- step returning @y2@ would only ever recirculate the initial zero.
delay2 :: ID m => Node -> m Node
delay2 i = unit_delay (Real_Constant 0) i >>= unit_delay (Real_Constant 0)

-- | Simple averaging (exponential lag) filter.  The output moves toward
-- the input @i@ with coefficient @b1 = exp (log 0.001 / (t * sample_rate))@.
-- The remaining distance to a constant input therefore falls to 0.001
-- (-60 dB) after @t@ seconds.
lag :: ID m => Node -> Node -> m Node
lag i t =
    -- Same coefficient form as 'decay'.  The log applies to 0.001 alone;
    -- taking the log of the whole ratio reduces b1 to
    -- 0.001 / (t * sample_rate), which all but disables the smoothing.
    let b1 = exp (log 0.001 / (t * sample_rate))
    in iir1 (Real_Constant 0) (\x y1 -> x + b1 * (y1 - x)) i

-- | Nested lag filter: 'lag' applied twice in series with the same time
-- @t@.
lag2 :: ID m => Node -> Node -> m Node
lag2 i t = foldM lag i [t, t]

-- | Twice nested lag filter: 'lag' applied three times in series with the
-- same time @t@.
lag3 :: ID m => Node -> Node -> m Node
lag3 i t = foldM lag i (replicate 3 t)