module CsoundExpr.Translator.MultiOutsFolding 
    (foldMultiOuts)
where

import Data.List
import Data.Function
import Data.Maybe

import qualified Data.Map as M

import Control.Monad
import Control.Monad.State


import CsoundExpr.Translator.AssignmentElimination

import qualified CsoundExpr.Translator.Cs.CsoundFile as Cs
import CsoundExpr.Translator.Cs.Utils



-- | Tell whether an expression is one output of a multi-output opcode,
-- i.e. whether its 'rateInfo' was built with the @MultiOut@ constructor.
isMO :: SubstExpr -> Bool
isMO = viaMultiOut . rateInfo
    where viaMultiOut (MultiOut _ _ _) = True
          viaMultiOut _                = False


-- | Turn a list of substituted expressions into opcode lines, folding the
-- separately tracked outputs of multi-output opcodes back into single lines.
--
-- The input is split with 'isMO'.  Ordinary expressions are converted one by
-- one with 'procS'; multi-output expressions are grouped, folded and named
-- ('groupMOs', 'sortMOs', 'procMOs').  Both streams are tagged with a line
-- number, combined with 'unionSortedBy' on that number, and the tags are
-- dropped from the result.
--
-- The counter that 'procMOs' uses to name unbound outputs starts at
-- @length xs@ (presumably so that generated names do not clash with names
-- already used by the input -- confirm against 'numArgName').
foldMultiOuts :: [SubstExpr] -> [Cs.OpcodeExpr]
foldMultiOuts xs = map snd $ unionSortedBy byLine singles multis
    where (moExprs, plain) = partition isMO xs
          byLine a b = compare (fst a) (fst b)
          singles    = map procS plain
          -- The folded entries come out group by group, and groups are
          -- ordered by the key used in 'groupMOs', not by line number.
          -- 'unionSortedBy' is defined outside this module; its name suggests
          -- it expects inputs already ordered by the comparison, so order the
          -- entries here.  The sort runs after naming, hence generated names
          -- are unaffected, and 'sortBy' is stable, so this is harmless if
          -- 'unionSortedBy' sorts on its own.
          -- NOTE(review): 'singles' is left in input order, on the assumption
          -- that the input is already ordered by line number -- confirm.
          multis     = sortBy byLine $
                       procMOs (length xs) $ sortMOs $ groupMOs moExprs


-- | Convert an ordinary (single-output) expression into an opcode line
-- tagged with its line number.
--
-- The output list holds the expression's 'argOut' when there is one and is
-- empty otherwise; the two components of 'body' become the remaining
-- arguments of 'Cs.OpcodeExpr'.
procS :: SubstExpr -> (Int, Cs.OpcodeExpr)
procS x = (lineNum x, Cs.OpcodeExpr outputs opcodeHead opcodeArgs)
    where outputs                  = maybeToList (argOut x)
          (opcodeHead, opcodeArgs) = body x


-- | Bucket multi-output expressions: two expressions land in the same bucket
-- when the first and third fields of their @MultiOut@ rate info and their
-- 'body' are all equal.  Buckets are returned in ascending order of that key.
--
-- Only meant for expressions accepted by 'isMO'; any other rate info hits
-- the incomplete pattern match in @key@.
groupMOs :: [SubstExpr] -> [[SubstExpr]]
groupMOs mos = groupBy sameKey (sortBy byKey mos)
    where sameKey a b = key a == key b
          byKey a b   = compare (key a) (key b)
          key x = case rateInfo x of
                    MultiOut purity _ rates -> (purity, rates, body x)


-- | Fold every group with 'foldMOs', expand it with 'flattenMOs' and
-- concatenate the results.
--
-- Each resulting triple is a line number, the output slots paired with their
-- rates, and the opcode.  Despite the name, nothing is reordered here: the
-- results follow the order of the incoming groups.
sortMOs :: [[SubstExpr]] -> [(Int, [(Maybe Cs.ArgOut, Cs.Rate)], Opcode)]
sortMOs = concatMap (flattenMOs . foldMOs)


-- | One bound output of a multi-output opcode, as extracted by 'foldMOs'.
data FoldedMO = FoldedMO
    { moLineNum :: Int        -- ^ 'lineNum' of the originating expression
    , moPort    :: Int        -- ^ second field of its @MultiOut@ rate info
    , moArgOut  :: Cs.ArgOut  -- ^ the output the expression is bound to
    }
    deriving (Show)

-- | Summarise one group produced by 'groupMOs'.
--
-- Returns a 'FoldedMO' for every member, plus the rate list and the 'body'
-- taken from the first member.
--
-- Partial: the group must be non-empty, every member must carry @MultiOut@
-- rate info, and every member must have an 'argOut'.
foldMOs :: [SubstExpr] -> ([FoldedMO], [Cs.Rate], Opcode)
foldMOs xs = (map toFolded xs, groupRates, body first)
    where first      = head xs
          groupRates = case rateInfo first of
                         MultiOut _ _ rs -> rs
          toFolded x = case rateInfo x of
                         MultiOut _ port _ ->
                             FoldedMO { moLineNum = lineNum x
                                      , moPort    = port
                                      , moArgOut  = fromJust (argOut x)
                                      }


-- | Expand a folded group with 'flatten' and attach the group's opcode to
-- every resulting (line number, output slots) pair.
flattenMOs :: ([FoldedMO], [Cs.Rate], Opcode) 
           -> [(Int, [(Maybe Cs.ArgOut, Cs.Rate)], Opcode)]
flattenMOs (xs, rates, opc) = map withOpcode (flatten rates xs)
    where withOpcode (lineId, slots) = (lineId, slots, opc)
    
-- | Arrange the recorded outputs of one group into rows.
--
-- Ports are taken to be numbered @0 .. length rs - 1@.  For every port the
-- recorded outputs are fetched with 'getPortVals' and padded with 'Nothing'
-- to a common length; transposing then yields rows in which the i-th row
-- holds the i-th recorded output of each port.
--
-- Each row is returned as the smallest line number among its recorded
-- outputs, together with the per-port outputs zipped with the rates @rs@.
flatten :: [Cs.Rate] -> [FoldedMO] -> [(Int, [(Maybe Cs.ArgOut, Cs.Rate)])]
flatten rs xs = map describe rows
    where describe row      = (earliestLine row, slotsWithRates row)
          earliestLine row  = minimum (map fst (catMaybes row))
          slotsWithRates row = zip (map (fmap snd) row) rs
          rows = transpose [ getPortVals portMap depth port
                           | port <- [0 .. length rs - 1] ]
          (portMap, depth)  = foldPorts xs


-- | Fetch the entries stored for a port and pad them to a fixed length.
--
-- Arguments, in order: the map from port to entries, the target length, and
-- the port to look up.  Stored entries are wrapped in 'Just' and followed by
-- as many 'Nothing's as needed to reach the target length; a port missing
-- from the map yields only 'Nothing's.
getPortVals :: M.Map Int [(Int, Cs.ArgOut)]
            -> Int -> Int 
            -> [Maybe (Int, Cs.ArgOut)]
getPortVals m n port = present ++ padding
    where stored  = M.findWithDefault [] port m
          present = map Just stored
          padding = replicate (n - length stored) Nothing

foldPorts :: [FoldedMO] -> (M.Map Int [(Int, Cs.ArgOut)], Int)
foldPorts = (\x -> (M.fromList x, getMaxLength x)) . map select . 
            groupBy ((==) `on` moPort) . sortBy (compare `on` moPort) 
    where getMaxLength = maximum . map (length . snd)
          select xs    = let key = moPort $ head xs
                             val = map (\x -> (moLineNum x, moArgOut x)) xs
                         in  (key, val)
         


-- | Build the opcode line for every folded multi-output entry.
--
-- The entries are processed left to right in a state monad whose counter
-- starts at @n@; 'mkOuts' resolves each output slot and advances the counter
-- whenever it has to invent a name.  The final counter value is discarded.
procMOs :: Int -> [(Int, [(Maybe Cs.ArgOut, Cs.Rate)], Opcode)] 
        -> [(Int, Cs.OpcodeExpr)]
procMOs n xs = evalState (mapM build xs) n
    where build (lineId, outs, opc) = do
              names <- mapM mkOuts outs
              return (lineId, uncurry (Cs.OpcodeExpr names) opc)


-- | Pick the name for one output slot of a multi-output opcode.
--
-- A slot that already carries an output keeps it and leaves the counter
-- untouched.  An empty slot is given the name @numArgName rate n@, where @n@
-- is the current counter value, and the counter is then incremented.
--
-- Written with 'get' and 'put' rather than the @State@ data constructor:
-- from mtl-2 on @State@ is a type synonym and no such constructor exists,
-- whereas 'get' and 'put' are available in both old and new mtl.
mkOuts :: (Maybe Cs.ArgOut, Cs.Rate) -> State Int Cs.ArgOut
mkOuts (arg, rate) = case arg of
    Just a  -> return a
    Nothing -> do
        n <- get
        put (n + 1)
        return (numArgName rate n)