-- | Tools to build Fm synthesis graphs 
--
-- Example
--
-- > f a = fmOut1 $ do
-- > 	x1 <- fmOsc 1
-- > 	x2 <- fmOsc 2
-- > 	x1 `fmod` [(a, x2)]
-- > 	return x1
module Csound.Air.Fm(
	-- * Fm graph
	Fm, FmNode,
	fmOsc', fmOsc, fmSig,
	fmod, 
	fmOut, fmOut1, fmOut2,

	-- * Simplified Fm graph
	FmSpec(..), FmGraph(..), fmRun,
	-- ** Specific graphs
	-- | Algorithms for DX7 fm synth
	dx_1,  dx_2,  dx_3,  dx_4 {-,  dx_5,  dx_6,  dx_7,  dx_8,
	dx_9,  dx_10, dx_11, dx_12, dx_13, dx_14, dx_15, dx_16,
	dx_17, dx_18, dx_19, dx_20, dx_21, dx_22, dx_23, dx_24,
	dx_25, dx_26, dx_27, dx_28, dx_29, dx_30, dx_31, dx_32 -}
) where

import qualified Data.IntMap as IM

import Control.Monad.Trans.State.Strict
import Control.Monad

import Csound.Typed
import Csound.Air.Wave
import Csound.SigSpace

-- Fm graph rendering

-- | Monad for building Fm graphs. It is a state monad over the graph
-- under construction: nodes are added with 'fmOsc', 'fmOsc'' and 'fmSig',
-- modulation links are added with 'fmod'.
type Fm a = State St a

-- | Handle of a node in the Fm graph. It wraps the index that the node
-- received when it was created.
newtype FmNode = FmNode { unFmNode :: Int }

-- Reference to a node with a scaling factor: (index of the node, amplitude).
type FmIdx = (Int, Sig)

-- Unit of the Fm graph.
--
-- * @Fmod wave ratio mods@ is an oscillator: @wave@ converts frequency to signal,
--   @ratio@ is multiplied by the main frequency, @mods@ are the modulators of the unit.
--
-- * @Fsig asig@ is a ready-made signal that is used as is.
data Fmod = Fmod (Sig -> SE Sig) Sig [FmIdx] | Fsig Sig

-- State of the graph under construction.
data St = St 
	{ newIdx     :: Int                   -- index for the next node to be created
	, units      :: [Fmod]                -- created units, the most recent one goes first
	, links      :: IM.IntMap [FmIdx]     -- modulators of the nodes, keyed by node index
	}

-- Initial state: no units, no links, indices of the nodes start from zero.
defSt :: St
defSt = St
	{ newIdx = 0
	, units = []
	, links = IM.empty }

-- Renders the list of units to the list of output signals.
--
-- Every unit gets its own reference that is initialised with zero.
-- The units are updated one by one in the order of the list: the modulators
-- of the unit are read from their references, then the new value of the unit
-- is written to the reference of the unit. Finally the outputs are read
-- from the references and scaled with the given amplitudes.
renderGraph :: [Fmod] -> [FmIdx] -> Sig -> SE [Sig]
renderGraph units outs cps = do
	refs <- mapM (const $ newRef (0 :: Sig)) units
	zipWithM_ (updateUnit refs) [0 .. ] units
	mapM (readOut refs) outs
	where
		-- computes the current value of the unit and saves it to the unit's reference
		updateUnit refs n unit = do
			val <- unitSig refs unit
			writeRef (refs !! n) val

		-- a plain signal is passed through, an oscillator is driven
		-- with the scaled main frequency plus the sum of its modulators
		unitSig _    (Fsig asig) = return asig
		unitSig refs (Fmod wave ratio subs) = do
			mods <- mapM (readMod refs) subs
			wave (cps * ratio + sum mods)

		readOut :: [Ref Sig] -> (Int, Sig) -> SE Sig
		readOut refs (idx, amp) = mul amp $ readRef (refs !! idx)

		-- the modulator is scaled also with its own frequency
		-- (signal units are scaled only with the amplitude)
		readMod :: [Ref Sig] -> (Int, Sig) -> SE Sig
		readMod refs (idx, amp) = mul (amp * freq) $ readRef (refs !! idx)
			where
				freq = case units !! idx of
					Fmod _ ratio _ -> ratio * cps
					_              -> 1


-- Converts the state to the list of units in the order of creation
-- (position of the unit in the list equals to the index of the node).
-- The modulators collected in 'links' are attached to the oscillators,
-- the signal units are left unchanged.
mkGraph :: St -> [Fmod]
mkGraph s = zipWith attachMods [0 .. ] (reverse $ units s)
	where
		attachMods n unit = case unit of
			Fmod wave ratio _ -> Fmod wave ratio (IM.findWithDefault [] n (links s))
			_                 -> unit

-- Converts the user-facing pair (amplitude, node) to the internal
-- pair (index of the node, amplitude).
toFmIdx :: (Sig, FmNode) -> FmIdx
toFmIdx (amp, node) = (unFmNode node, amp)

---------------------------------------------------------
-- constructors

-- | Creates fm node with generic wave. The wave is a function from frequency
-- to signal. The second argument is a multiplier of the main frequency.
--
-- > fmOsc' wave modFreq
fmOsc' :: (Sig -> SE Sig) -> Sig -> Fm FmNode
fmOsc' wave modFreq = newFmod $ Fmod wave modFreq []

-- | Creates fm node with sine wave (the wave is produced with 'rndOsc').
--
-- > fmOsc modFreq
fmOsc :: Sig -> Fm FmNode
fmOsc modFreq = fmOsc' rndOsc modFreq

-- | Creates fm node out of a ready-made signal. The signal is used as is,
-- it does not depend on the main frequency.
fmSig :: Sig -> Fm FmNode
fmSig = newFmod . Fsig

-- Registers a new unit in the state. The unit gets the next free index
-- and it is put to the front of the list of units.
newFmod :: Fmod -> Fm FmNode
newFmod a = state register
	where
		register s = (FmNode n, s { newIdx = n + 1, units = a : units s })
			where n = newIdx s

-- modulator

-- | Adds modulators to the node. Every modulator is a pair of
-- amplitude of modulation and the node that modulates.
--
-- > carrier `fmod` [(amp1, modulator1), (amp2, modulator2)]
--
-- If the node already has modulators the new ones are put in front of them.
fmod :: FmNode -> [(Sig, FmNode)] -> Fm ()
fmod (FmNode idx) mods = state $ \s -> ((), s { links = addLinks (links s) })
	where addLinks = IM.insertWith (++) idx (fmap toFmIdx mods)

-- outputs

-- | Renders Fm synth to function. The graph should return the list of
-- output nodes paired with their amplitudes. The result is a function
-- from the main frequency to the list of output signals
-- (one signal per output node).
fmOut :: Fm [(Sig, FmNode)] -> Sig -> SE [Sig]
fmOut fm = renderGraph graph outputs
	where
		(nodes, st) = runState fm defSt
		graph       = mkGraph st
		outputs     = fmap toFmIdx nodes

-- | Renders mono output. The node that is returned from the graph
-- becomes the output (it's taken with unit amplitude).
fmOut1 :: Fm FmNode -> Sig -> SE Sig
fmOut1 fm cps = fmap head $ fmOut (fmap single fm) cps
	where single node = [(1, node)]

-- | Renders stereo output. The pair of nodes that is returned from the graph
-- becomes the pair of output signals (both are taken with unit amplitude).
fmOut2 :: Fm (FmNode, FmNode) -> Sig -> SE Sig2
fmOut2 fm cps = fmap toPair $ fmOut (fmap toList fm) cps
	where
		toList (a, b) = [(1, a), (1, b)]
		-- 'fmOut' returns one signal per output node, so there are exactly two of them
		toPair [a, b] = (a, b)

-----------------------------------------------------------------------

-- | Parameters of the operators for the simplified Fm graph (see 'fmRun').
-- All lists are indexed with the number of the operator.
--
-- * 'fmWave' -- wave of the operator (function from frequency to signal).
--
-- * 'fmCps' -- multiplier of the main frequency for the operator.
--
-- * 'fmInd' -- amplitude of modulation. It's looked up with the number of
--   the operator that modulates.
--
-- * 'fmOuts' -- amplitudes of the outputs. They are matched by position
--   with 'fmGraphOuts'.
--
-- The lists can be shorter than the number of operators. 'fmRun' completes
-- the missing values with defaults ('rndOsc' for the waves and 1 for the rest).
data FmSpec = FmSpec 
	{ fmWave :: [Sig -> SE Sig]
	, fmCps :: [Sig]
	, fmInd :: [Sig]
	, fmOuts :: [Sig] }

-- | Structure of the simplified Fm graph (see 'fmRun').
-- The operators are identified with integers.
--
-- * 'fmGraph' -- list of pairs: operator and the list of operators that modulate it.
--
-- * 'fmGraphOuts' -- operators that are sent to the output.
data FmGraph = FmGraph 
	{ fmGraph 	:: [(Int, [Int])]
	, fmGraphOuts :: [Int] }

-- | Renders the simplified Fm graph. The result is a function from
-- the main frequency to the sum of all outputs of the graph.
--
-- > fmRun graph spec cps
--
-- The operators are created for all indices from zero up to the largest
-- index that is mentioned in the graph. The parameters of the operators
-- are taken from the spec by index, missing values are completed
-- with defaults (see 'addDefaults').
--
-- NOTE(review): the indices are zero-based here (they are used with @!!@),
-- while the @dx_*@ graphs number the operators from 1. So for those graphs
-- the operator 0 is created but left unused and the first elements of
-- the spec lists belong to it. Confirm the intended convention.
fmRun :: FmGraph -> FmSpec -> Sig -> SE Sig
fmRun graph spec' cps = fmap sum $ ($ cps) $ fmOut $ do
	ops <- zipWithM fmOsc' (take opCount $ fmWave spec) (fmCps spec)
	mapM_ (mkMod ops (fmInd spec)) (fmGraph graph)
	return $ zipWith (toOut ops) (fmOuts spec) (fmGraphOuts graph)
	where
		spec = addDefaults spec'

		-- 'addDefaults' makes all lists of the spec infinite. We have to bound
		-- the number of operators, otherwise the creation of the nodes
		-- never terminates.
		usedIdx = fmGraphOuts graph ++ concatMap (\(n, ms) -> n : ms) (fmGraph graph)
		opCount = maximum (0 : fmap (+ 1) usedIdx)

		toOut xs amp n = (amp, xs !! n)
		mkMod ops ixs (n, ms) = (ops !! n) `fmod` (fmap (\m -> (ixs !! m, ops !! m)) ms)

-- Completes the spec with default values: 'rndOsc' for the waves
-- and 1 for all other parameters. The defaults are appended without limit,
-- so all lists of the result are infinite.
addDefaults :: FmSpec -> FmSpec
addDefaults spec = FmSpec
	{ fmWave = pad rndOsc (fmWave spec)
	, fmCps  = pad 1      (fmCps  spec)
	, fmInd  = pad 1      (fmInd  spec)
	, fmOuts = pad 1      (fmOuts spec) }
	where pad def xs = xs ++ repeat def

{-| DX7 algorithm 1. The operators are numbered as on the diagram.

>       +--+
>       6  |
>       +--+
>       5
>       |
>   2   4
>   |   |
>   1   3
>   +---+
-}
dx_1 :: FmGraph
dx_1 = FmGraph
	{ fmGraphOuts = [1, 3]
	, fmGraph =
		[ (1, [2])
		, (3, [4])
		, (4, [5])
		, (5, [6])
		, (6, [6]) ]}

{-| DX7 algorithm 2. The operators are numbered as on the diagram.

>         6
>         |
>         5
>   +--+  |
>   2  |  4
>   +--+  |
>   1     3
>   +-----+
-}
dx_2 :: FmGraph
dx_2 = FmGraph
	{ fmGraphOuts = [1, 3]
	, fmGraph =
		[ (1, [2])
		, (2, [2])
		, (3, [4])
		-- the link from 5 to 4 completes the chain 6 -> 5 -> 4 -> 3 of the diagram
		, (4, [5])
		, (5, [6]) ]}

{-| DX7 algorithm 3. The operators are numbered as on the diagram.

>       +--+
>   3   6  |
>   |   +--+
>   2   5
>   |   |
>   1   4
>   +---+
-}
dx_3 :: FmGraph
dx_3 = FmGraph
	{ fmGraphOuts = [1, 4]
	, fmGraph =
		[ (1, [2])
		, (2, [3])
		, (4, [5])
		, (5, [6])
		, (6, [6]) ]}

{-| DX7 algorithm 4. The operators are numbered as on the diagram.
The feedback loop goes from the operator 4 back to the operator 6.

>       +--+
>   3   6  |
>   |   |  |
>   2   5  |
>   |   |  |
>   1   4  |
>   |   +--+
>   +---+
-}
dx_4 :: FmGraph
dx_4 = FmGraph
	{ fmGraphOuts = [1, 4]
	, fmGraph =
		[ (1, [2])
		, (2, [3])
		, (4, [5])
		, (5, [6])
		, (6, [4]) ]}

{-
dx12 = DxGraph 
	{ dxGraphOuts = [3, 1]
	, dxGraph = 
		[ (3, [4, 5, 6])
		, (1, [2])
		, (2, [2]) ]}

-}