module CsoundExpr.Translator.AssignmentElimination 
    (eliminateAssignment,
     SubstExpr(..), RateInfo(..), Opcode, 
     numArgName, strArgName)
where

import Control.Monad.State
import qualified Data.Map as M
import qualified Data.Set as S
import Data.Either

import qualified CsoundExpr.Translator.Cs.CsTree     as La
import qualified CsoundExpr.Translator.Cs.CsoundFile as Cs

import CsoundExpr.Translator.ExprTree.ExprTree
import CsoundExpr.Translator.Types

-- | Substitution table built while walking the expression layers in order.
-- Keys are the output ids ('layerOut') of already-translated expressions;
-- values pair the expression's rate information with the opcode that
-- computes it. Later expressions look their inputs up here ('substArg',
-- 'isSubstId').
type SubstTable = M.Map Int (RateInfo, Opcode)

-- | An opcode application: the opcode name and its input arguments.
-- The name @"="@ marks a plain assignment; single-output entries with this
-- name are inlined into the expressions that use them instead of being
-- referenced through a variable (see 'substArg').
type Opcode = (String, [Cs.ArgIn])

-- | Output shape of an expression.
--
-- * 'ZeroOut'   — the expression produces no outputs.
-- * 'SingleOut' — one output at the given rate.
-- * 'MultiOut'  — several outputs: the expression's purity, the index of the
--   output port this particular id refers to, and the rates of all outputs.
data RateInfo = ZeroOut
              | SingleOut Cs.Rate
              | MultiOut (Purity La.Label) Int [Cs.Rate]
                deriving (Show, Eq, Ord)

-- | One translated expression.
--
-- * 'lineNum'  — the output id of the layer it came from; used to decide
--   whether the expression is kept after elimination.
-- * 'argOut'   — the variable the result is assigned to, or 'Nothing' when
--   the opcode has no outputs.
-- * 'rateInfo' — the output shape of the expression.
-- * 'body'     — the opcode (or @"="@ assignment) that computes it.
data SubstExpr = SubstExpr 
                 { lineNum  :: Int
                 , argOut   :: Maybe Cs.ArgOut
                 , rateInfo :: RateInfo
                 , body     :: Opcode
                 }  deriving (Show)


-- | Debug rendering of a 'SubstExpr' as @(lineNum,argOut,body)@.
-- The rate information is not included. The original version never closed
-- the opening parenthesis; it is closed here.
ppSubstExpr :: SubstExpr -> String
ppSubstExpr x =
    "(" ++ show (lineNum x) ++ "," ++ show (argOut x) ++ "," ++ show (body x) ++ ")"

rates = exprType . exprTag . layerOp
op    = exprOp   . exprTag . layerOp



-- | Translate expression layers into 'SubstExpr's and drop the assignments
-- that have been inlined into their users.
--
-- The layers are processed in order with 'substFun'. It threads a list of
-- ids that must be kept as separate lines, together with the substitution
-- table. Only the expressions whose 'lineNum' ends up in that list are
-- returned, and they keep their original order.
--
-- Membership is checked with a 'S.Set'. The previous list 'elem' made the
-- filter quadratic in the number of expressions.
eliminateAssignment :: [ExprLayer Int LaExpr Int] -> [SubstExpr]
eliminateAssignment xs = filter ((`S.member` keep) . lineNum) exprs
    where (exprs, (keptIds, _)) = runState (mapM step xs) ([], M.empty)
          keep                  = S.fromList keptIds
          step x                = state $ \(is, m) -> substFun x is m

-- | Translate one expression layer.
--
-- Takes the layer, the ids collected so far that must be kept as separate
-- lines, and the current substitution table. Returns the translated
-- expression together with the updated id list and table.
--
-- * Argument outputs are always kept; their rate and body are copied from
--   the table entry of their input.
-- * Values, parameters and argument reads are only recorded in the table.
-- * Operators and opcodes record themselves in the table and mark as kept
--   every input that cannot be inlined. An opcode with no outputs also
--   marks itself as kept.
--
-- The order of the id list is irrelevant (it is only used for membership),
-- so new ids are prepended rather than appended. Appending made the whole
-- pass quadratic.
substFun :: ExprLayer Int LaExpr Int 
         -> [Int] -> SubstTable 
         -> (SubstExpr, ([Int], SubstTable))
substFun x is m
    | isArgOut  x       = (SubstExpr outId (aoArgOut x) (aoRateInfo m x) (aoBody m x),
                           (outId : is, m))
    | La.isVal   $ op x = (SubstExpr outId (valArgOut x) (valRateInfo x) (valBody x),
                           (is, M.insert outId (valSubstTableValue x) m))
    | La.isParam $ op x = (SubstExpr outId (paramArgOut x) (paramRateInfo x) (paramBody x),
                           (is, M.insert outId (paramSubstTableValue x) m))
    | La.isArg   $ op x = (SubstExpr outId (aiArgOut x) (aiRateInfo x) (aiBody x),
                           (is, M.insert outId (aiSubstTableValue x) m))
    | La.isOpr   $ op x = (SubstExpr outId (oprArgOut x) (oprRateInfo x) (oprBody m x),
                           (oprIds m x ++ is, M.insert outId (oprSubstTableValue m x) m))
    | La.isOpc   $ op x = (SubstExpr outId (opcArgOut x) (opcRateInfo x) (opcBody m x),
                           (opcIds m x ++ is, M.insert outId (opcSubstTableValue m x) m))
    -- Previously a bare non-exhaustive-guard failure; report which layer failed.
    | otherwise         = error $ "CsoundExpr.Translator.AssignmentElimination.substFun: "
                                ++ "unsupported expression kind at id " ++ show outId
    where outId = layerOut x


-- val

-- | A constant is stored in a 'Cs.I'-rate variable named after the
-- layer's output id.
valArgOut x = Just (numArgName Cs.I (layerOut x))

-- | A constant always has a single 'Cs.I'-rate output.
valRateInfo _ = SingleOut Cs.I

-- | A constant is a plain @"="@ assignment of its literal value.
valBody x = ("=", [Cs.ArgInValue (toValue (La.value (op x)))])

-- | Substitution-table entry for a constant.
valSubstTableValue x = (valRateInfo x, valBody x)

-- param

-- | A parameter read is stored in a 'Cs.I'-rate variable named after the
-- layer's output id.
paramArgOut x = Just (numArgName Cs.I (layerOut x))

-- | A parameter read always has a single 'Cs.I'-rate output.
paramRateInfo _ = SingleOut Cs.I

-- | A parameter read is a plain @"="@ assignment from the numbered parameter.
paramBody x = ("=", [Cs.ArgInParam (Cs.Param (La.paramId (op x)))])

-- | Substitution-table entry for a parameter read.
paramSubstTableValue x = (paramRateInfo x, paramBody x)

-- argOut

-- | An argument output is written to the named variable carried by the
-- operation (see 'strArgName').
aoArgOut x = Just (strArgName outRate outName)
    where o       = op x
          outRate = toCsRate (La.argRate o)
          outName = La.argName o

-- | Rate information copied from the table entry of the argument's input.
aoRateInfo m x = rateOf
    where (rateOf, _) = m M.! aoArgInId x

-- | Body copied from the table entry of the argument's input.
aoBody m x = bodyOf
    where (_, bodyOf) = m M.! aoArgInId x

-- | Id of the (first) input of an argument-output layer; fails on a layer
-- with no inputs.
aoArgInId x = head (layerIn x)

-- argIn

-- | An argument read is stored in a numbered variable whose rate is the
-- first rate of the layer.
aiArgOut x = Just (numArgName firstRate (layerOut x))
    where firstRate = toCsRate (head (rates x))

-- | An argument read has a single output at the layer's first rate.
aiRateInfo x = SingleOut (toCsRate (head (rates x)))

-- | An argument read is a plain @"="@ assignment from the named argument
-- variable.
aiBody x = ("=", [Cs.ArgInName source])
    where o      = op x
          source = strArgName (toCsRate (La.argRate o)) (La.argName o)

-- | Substitution-table entry for an argument read.
aiSubstTableValue x = (aiRateInfo x, aiBody x)

-- opr

-- | An operator result is stored in a numbered variable at the layer's
-- first rate.
oprArgOut x = Just (numArgName (toCsRate (head (rates x))) (layerOut x))

-- | An operator has a single output at the layer's first rate.
oprRateInfo x = SingleOut (toCsRate (head (rates x)))

-- | An operator is an @"="@ assignment of an operator application whose
-- arguments are resolved through the substitution table.
oprBody m x = ("=", [Cs.ArgInOpr name kind args])
    where o    = op x
          name = La.oprName o
          kind = toOprType (La.oprType o)
          args = map (substArg m) (layerIn x)

-- | Substitution-table entry for an operator.
oprSubstTableValue m x = (oprRateInfo x, oprBody m x)

-- | Inputs of the operator that cannot be inlined and so must be kept.
oprIds m x = [i | i <- layerIn x, not (isSubstId m i)]

-- opc

-- | Output variable of an opcode layer: 'Nothing' when the opcode has no
-- outputs. Otherwise it is a numbered variable at the rate of the only
-- output, or at the rate of the selected output port for multi-output
-- opcodes.
opcArgOut x = fmap (\r -> numArgName (toCsRate r) (layerOut x)) chosen
    where chosen = case rates x of
                     []  -> Nothing
                     [r] -> Just r
                     rs  -> Just (rs !! exprOutPort (layerOp x))
-- | Output shape of an opcode layer, derived from how many rates it has.
opcRateInfo x =
    case map toCsRate (rates x) of
      []  -> ZeroOut
      [r] -> SingleOut r
      rs  -> MultiOut (exprPurity (layerOp x)) (exprOutPort (layerOp x)) rs
-- | An opcode call: its name applied to inputs resolved through the table.
opcBody m x = (opcodeName, map (substArg m) (layerIn x))
    where opcodeName = La.opcName (op x)

-- | Substitution-table entry for an opcode.
opcSubstTableValue m x = (opcRateInfo x, opcBody m x)

-- | Ids an opcode layer requires to be kept: every input that cannot be
-- inlined, plus the opcode's own id when it has no outputs.
opcIds m x
    | null (rates x) = layerOut x : kept
    | otherwise      = kept
    where kept = filter (not . isSubstId m) (layerIn x)

-- subst args

-- | Turn a reference to an earlier expression (by output id) into a Csound
-- input argument.
--
-- * A single-output @"="@ assignment is inlined: its right-hand side is used
--   directly.
-- * Any other single-output expression is referenced by its numbered
--   variable.
-- * For a multi-output opcode, the variable for the selected output port is
--   used.
-- * An expression with no outputs cannot be an argument. This used to fall
--   through the @case@ with a pattern-match failure; it now raises an
--   explicit error naming the opcode and id.
--
-- Fails (via 'M.!') if the id is not in the table.
substArg :: SubstTable -> Int -> Cs.ArgIn
substArg m argId =
    case m M.! argId of
      (SingleOut r, (opc, as))
          | opc == "=" -> head as
          | otherwise  -> Cs.ArgInName $ numArgName r argId
      (MultiOut _ i rs, _) -> Cs.ArgInName $ numArgName (rs !! i) argId
      (ZeroOut, (opc, _))  -> error $ "CsoundExpr.Translator.AssignmentElimination.substArg: "
                                   ++ "result of opcode " ++ show opc
                                   ++ " (which has no outputs) used as an argument (id "
                                   ++ show argId ++ ")"


-- | True when the expression with this id is a single-output @"="@
-- assignment and so will be inlined by 'substArg' rather than kept.
-- Fails (via 'M.!') if the id is not in the table.
isSubstId :: SubstTable -> Int -> Bool
isSubstId table key = inlinable (table M.! key)
    where inlinable (SingleOut _, (name, _)) = name == "="
          inlinable _                        = False

--

-- | Variable name built from a numeric id: the id's decimal text at the
-- given rate.
numArgName :: Cs.Rate -> Int -> Cs.ArgName
numArgName rate = Cs.ArgName rate . show


-- | Variable name built from a user-supplied name. At 'Cs.SetupRate' the
-- name is used verbatim; at any other rate it is prefixed with @x@.
strArgName :: Cs.Rate -> String -> Cs.ArgName
strArgName rate@Cs.SetupRate name = Cs.ArgName rate name
strArgName rate              name = Cs.ArgName rate ("x" ++ name)