{-# Language TupleSections #-}
module Csound.Dynamic.Tfm.UnfoldMultiOuts(
  unfoldMultiOuts, Selector(..)
) where

import Data.List(sortBy)
import Data.Ord(comparing)
import Control.Monad.Trans.State.Strict
import qualified Data.IntMap.Strict as IM
import Data.Either (partitionEithers)

import Csound.Dynamic.Tfm.InferTypes(Var(..), Stmt(..), InferenceResult(..))
import Csound.Dynamic.Types.Exp hiding (Var (..))
import Csound.Dynamic.Build(getRates, isMultiOutSignature)

-- | Maps the variable id of a multi-output statement to the selector
-- 'Port's that read its outputs (built by 'mkChildrenMap').
type ChildrenMap = IM.IntMap [Port]

-- | Selector ports that read outputs of the statement bound to @parentVar@.
--
-- Returns @[]@ when no selector refers to @parentVar@ (e.g. every output of
-- a multi-output opcode is unused), instead of failing with the partial
-- 'IM.!' lookup.
lookupChildren :: ChildrenMap -> Var -> [Port]
lookupChildren m parentVar = IM.findWithDefault [] (varId parentVar) m

-- | Groups selector statements by the multi-output statement they read.
--
-- Every @(var, sel)@ pair becomes a single 'Port' (the id of @var@ plus the
-- output index @selectorOrder sel@) filed under the id of @sel@'s parent
-- variable; ports sharing a parent are concatenated with '(++)'.
mkChildrenMap :: [(Var, Selector)] -> ChildrenMap
mkChildrenMap pairs =
  IM.fromListWith (++)
    [ (varId (selectorParent sel), [Port (varId var) (selectorOrder sel)])
    | (var, sel) <- pairs
    ]

-- | One selector variable reading one output of a multi-output statement.
data Port = Port
  { portId    :: Int  -- ^ id of the selector's own variable
  , portOrder :: Int  -- ^ zero-based index of the output it selects
  }
  deriving Show

-- | A statement before unfolding: one variable bound to an expression.
type SingleStmt = Stmt Var
-- | A statement after unfolding: a list of output variables bound to one
-- rated expression.
type MultiStmt  = ([Var], RatedExp Var)

-- | Reference from a selector expression to the multi-output statement it
-- reads ('selectorParent') and the index of the output it picks
-- ('selectorOrder').
data Selector = Selector
  { selectorParent :: Var
  , selectorOrder  :: Int
  }

-- | Expands the multi-output statements of a type-inferred program.
--
-- Statements whose right-hand side is recognised by 'getSelector' are
-- dropped from the program and collected into a 'ChildrenMap'; every other
-- statement is rewritten in order by 'unfoldStmt'.  The fresh-id counter
-- starts at 'programLastFreshId', and its final value is returned alongside
-- the rewritten statements.
unfoldMultiOuts :: InferenceResult -> ([MultiStmt], Int)
unfoldMultiOuts result =
  runState (mapM (unfoldStmt children) plainStmts) (programLastFreshId result)
  where
    children = mkChildrenMap selectors

    (plainStmts, selectors) =
      partitionEithers (map classify (typedProgram result))

    -- selectors go Right (tagged with the variable they define),
    -- everything else stays Left unchanged
    classify stmt@(Stmt lhs rhs) =
      case getSelector rhs of
        Just sel -> Right (lhs, sel)
        Nothing  -> Left stmt

-- | Rewrites one non-selector statement.
--
-- Statements that 'getParentTypes' does not classify as multi-output keep
-- their single left-hand variable.  Multi-output statements get one output
-- variable per rate reported by 'getParentTypes', built by 'formLhs' from
-- the selector ports recorded for this statement's variable.
unfoldStmt :: ChildrenMap -> SingleStmt -> State Int MultiStmt
unfoldStmt childrenMap (Stmt lhs rhs) =
  maybe (return ([lhs], rhs)) expand (getParentTypes rhs)
  where
    expand types = do
      outs <- formLhs (lookupChildren childrenMap lhs) types
      return (outs, rhs)

-- | Builds the output variables of a multi-output statement.
--
-- The result has one 'Var' per entry of @types@ (the output rates, in output
-- order).  Output @k@ reuses the id of the selector port with
-- @portOrder == k@, if there is one.  Every output that no selector reads
-- gets a fresh id taken from the state counter.
--
-- Fresh ids are handed out contiguously starting at the current state value.
-- The state is advanced by exactly the number of ids handed out, so no id
-- produced here can be handed out again by a later allocation.
--
-- NOTE(review): treats the state as the next unused id.  This matches the
-- gap-filling arithmetic this replaces, which used the state value itself as
-- the first fresh id; confirm against InferTypes.
-- Assumes each output index appears at most once in @ports@.
formLhs :: [Port] -> [Rate] -> State Int [Var]
formLhs ports types = zipWith Var types <$> mapM idFor [0 .. length types - 1]
  where
    -- output index -> id of the selector variable reading that output
    usedIds :: IM.IntMap Int
    usedIds = IM.fromList [ (portOrder p, portId p) | p <- ports ]

    idFor :: Int -> State Int Int
    idFor order = maybe freshId return (IM.lookup order usedIds)

    -- take the current counter value and advance it by one
    freshId :: State Int Int
    freshId = state $ \n -> (n, n + 1)

-----------------------------------------------------------------------
-- helpers that recognise selector and multi-output (multiple-rate) expressions

-- | Recognises a selector expression: a 'Select' whose argument is a
-- variable (not a primitive).  Yields that parent variable together with
-- the selected output index; any other expression yields 'Nothing'.
getSelector :: RatedExp Var -> Maybe Selector
getSelector x
  | Select _ order (PrimOr (Right parent)) <- ratedExpExp x =
      Just (Selector parent order)
  | otherwise = Nothing

-- | Output rates of a multi-output opcode call.
--
-- Returns @Just (getRates e)@ for a 'Tfm' whose signature satisfies
-- 'isMultiOutSignature', and 'Nothing' for every other expression.
getParentTypes :: RatedExp Var -> Maybe [Rate]
getParentTypes x =
  case e of
    Tfm info _ | isMultiOutSignature (infoSignature info) -> Just (getRates e)
    _ -> Nothing
  where
    e = ratedExpExp x