module Csound.Dynamic.Tfm.DeduceTypes(
    Var(..), TypeGraph(..), Convert(..), Stmt, deduceTypes
) where

import Data.List(nub)
import qualified Data.Map as M
import qualified Data.IntMap as IM
import qualified Data.Traversable as T

import Data.STRef
import Control.Monad.ST
import Data.Array.ST

-- | Per-label lists of the types requested for each variable, indexed by
-- the variable's integer label.
type TypeRequests s ty = STArray s Int [ty]

-- | Creates request storage for labels @0 .. size - 1@ with no requests yet.
initTypeRequests :: Int -> ST s (TypeRequests s ty)
initTypeRequests size = newArray bounds []
    where bounds = (0, size - 1)

-- | Records that variable @v@ is used with type @varType v@: the type is
-- prepended to the request list stored at index @varId v@.
requestType :: Var ty -> TypeRequests s ty -> ST s ()
requestType v arr = do
    -- Read and write directly instead of calling an unqualified
    -- @modifyArray@: since array-0.5.6 "Data.Array.ST" re-exports its own
    -- 'modifyArray', which makes such a call ambiguous.
    requests <- readArray arr (varId v)
    writeArray arr (varId v) (varType v : requests)

-- | Applies @f@ to the element stored at index @i@ of an 'STArray'.
--
-- The new value is stored unevaluated: @f@ is not forced here.
modifyArray :: Ix i => STArray s i a -> i -> (a -> a) -> ST s ()
modifyArray arr i f = do
    old <- readArray arr i
    writeArray arr i (f old)

-- | Returns every type requested so far for the variable with the given label.
getTypes :: Int -> TypeRequests s ty -> ST s [ty]
getTypes = flip readArray

-- | Typed variable.
data Var a = Var
    { varId   :: Int  -- ^ integer label of the variable
    , varType :: a    -- ^ type label assigned to the variable
    } deriving (Show, Eq, Ord)

-- | How references to a statement's output are resolved in the second pass.
data GetType ty
    -- No converters: every reference gets this type.
    = NoConversion ty
    -- Converters exist: a reference whose requested type is a key of the map
    -- is redirected to the fresh identifier stored under that key; any other
    -- reference keeps the original variable.
    | ConversionLookup (Var ty) (M.Map ty Int)

-- | 'GetType' for every statement output, keyed by the output's label.
type TypeMap ty = IM.IntMap (GetType ty)

-- | Resolves an argument of a statement to the variable it should read.
--
-- The argument's label must be present in the map (it is looked up with
-- 'IM.!'). For 'NoConversion' the argument gets the recorded type. For
-- 'ConversionLookup' the argument's own type selects a fresh converted
-- variable; when no converter exists for that type, the unconverted
-- variable is used.
lookupVar :: (Show a, Ord a) => TypeMap a -> Var a -> Var a
lookupVar m (Var i r) = resolve (m IM.! i)
    where
        resolve (NoConversion ty) = Var i ty
        resolve (ConversionLookup noConv f) = case M.lookup r f of
            Just freshId -> Var freshId r
            Nothing      -> noConv

-- | Statement: an assignment such as
--
-- > leftHandSide = RightHandSide(arguments)
--
-- The first component is the left hand side; the second is the right hand
-- side, whose arguments are the elements held by the functor @f@.
type Stmt f a = (a, f a)

-- | A converter inserted when types collide: the value of 'convertFrom' is
-- converted into the fresh variable 'convertTo'.
data Convert a = Convert
    { convertFrom :: Var a
    , convertTo   :: Var a
    }

-- Result of the first pass for a single statement.
data Line f a = Line
    { lineType     :: (Int, GetType a)  -- label of the output and how references to it resolve
    , lineStmt     :: Stmt f (Var a)    -- the statement with deduced types
    , lineConverts :: [Convert a]       -- converters reading the statement's output
    }

-- | Algorithm specification for the given functor @f@ and type labels @a@.
data TypeGraph f a = TypeGraph
    { mkConvert  :: Convert a -> Stmt f (Var a)
      -- ^ Creates a type conversion statement.
    , defineType :: Stmt f Int -> [a] -> ([a], Stmt f (Var a))
      -- ^ For a given statement and a list of types requested for its output,
      -- produces a pair @(nonConvertibleTypes, statementWithDeducedTypes)@.
      -- @nonConvertibleTypes@ is used to insert converters.
    }

-- | Deduces types for a dag:
--
-- > deduceTypes functorSpecificFuns dag = (dagWithTypes, nextFreshIdentifier)
--
-- Assumption: the dag is labeled with non-negative integers. Labels are
-- unique, and the set of labels is a range @[0, n)@ (which is what the CSE
-- algorithm produces). The statements need not be sorted by label: the size
-- of the label space comes from the largest label, not the last one.
--
-- The second component of the result is the first identifier that is
-- unused, both by the dag and by the converters inserted here.
--
-- The algorithm proceeds as follows. We init an array of type requests and a
-- reference for fresh identifiers. Type requests come from the right hand
-- sides of statements. We need fresh identifiers for converters: if we are
-- going to use a new statement for a conversion, we need new variables.
--
-- (discussLine)
-- Then we process the lines in reverse order and collect type requests by
-- looking at the right hand sides and writing type requests for all
-- arguments.
--
-- (processLine)
-- In the second run we substitute all identifiers with typed variables. This
-- is not so straightforward because of converters. If there are converters,
-- we have to insert new statements and substitute identifiers with new ones.
-- That's why processLine converts variables to variables.
deduceTypes :: (Show a, Ord a, T.Traversable f) => TypeGraph f a -> [Stmt f Int] -> ([Stmt f (Var a)], Int)
deduceTypes spec as = runST $ do
    freshIds <- newSTRef n
    typeRequests <- initTypeRequests n
    -- Walk from the last statement to the first, so that (assuming the dag
    -- is topologically ordered) the requests for a label are collected
    -- before the statement that defines it is discussed.
    lines' <- mapM (discussLine spec typeRequests freshIds) $ reverse as
    let typeMap = IM.fromList $ fmap lineType lines'
    lastId <- readSTRef freshIds
    -- lines' is in reverse order. Reversing the concatenation puts each
    -- statement back in front of the converters that read its output.
    return (reverse $ processLine typeMap =<< lines', lastId)
    where
        -- Size of the label space: one past the largest label. Using the
        -- maximum instead of the last label keeps array accesses in bounds
        -- and fresh identifiers clear of existing labels even when the
        -- statements are unsorted. For sorted input this is the same value.
        n :: Int
        n = succ $ maximum (0 : fmap fst as)

        processLine typeMap line =
            fmap (mkConvert spec) (lineConverts line) ++ [(a, fmap (lookupVar typeMap) b)]
            where (a, b) = lineStmt line

-- | First pass over a single statement.
--
-- Reads the types requested so far for the statement's output and
-- de-duplicates them. Lets 'defineType' assign types and report which of
-- the requested types need a converter. Records a type request for every
-- argument of the typed statement. Finally, allocates the converters for
-- the output via 'mkGetType'.
discussLine :: (Ord a, T.Traversable f) => TypeGraph f a -> TypeRequests s a -> STRef s Int -> Stmt f Int -> ST s (Line f a)
discussLine spec typeRequests freshIds stmt@(pid, _) = do
    (typesToConvert, typedStmt) <- defineType spec stmt . nub <$> getTypes pid typeRequests
    let (outVar, args) = typedStmt
    _ <- T.traverse (`requestType` typeRequests) args
    (getType, convs) <- mkGetType typesToConvert outVar freshIds
    return Line
        { lineType     = (pid, getType)
        , lineStmt     = typedStmt
        , lineConverts = convs
        }

-- | Decides how later references to a statement's output are resolved.
--
-- With no types to convert, every reference keeps the output's own type and
-- no converters are produced. Otherwise one fresh identifier is allocated per
-- type in the list. Each type gets a 'Convert' from the output variable to
-- @Var freshId type@, and the lookup map sends the type to its fresh
-- identifier.
mkGetType :: Ord a => [a] -> Var a -> STRef s Int -> ST s (GetType a, [Convert a])
mkGetType typesToConvert curVar freshIds = case typesToConvert of
    [] -> return (NoConversion (varType curVar), [])
    _  -> do
        fresh <- nextIds (length typesToConvert) freshIds
        let pairs       = zip typesToConvert fresh
            lookupTable = M.fromList pairs
            converters  = [Convert curVar (Var i t) | (t, i) <- pairs]
        return (ConversionLookup curVar lookupTable, converters)

-- | Allocates @n@ consecutive fresh identifiers.
--
-- Returns @[cur .. cur + n - 1]@, where @cur@ is the value currently stored
-- in the reference, and advances the reference to @cur + n@, the first
-- identifier that is still unused. A non-positive @n@ allocates nothing and
-- leaves the reference unchanged.
nextIds :: Int -> STRef s Int -> ST s [Int]
nextIds n ref = do
    curId <- readSTRef ref
    -- Clamp so a non-positive count never moves the counter backwards.
    let count = max 0 n
    writeSTRef ref (curId + count)
    -- Exactly `count` ids. curId + count is not included: it is the first id
    -- of the next allocation.
    return [curId .. curId + count - 1]