module DDC.Core.Flow.Transform.Rates.CnfFromExp
(cnfOfExp, takeXLamFlags_safe) where
import DDC.Core.Collect
import DDC.Core.Flow.Compounds
import DDC.Core.Flow.Prim
import DDC.Core.Flow.Exp
import DDC.Core.Flow.Transform.Rates.Fail
import DDC.Core.Flow.Transform.Rates.Combinators as CNF
import qualified DDC.Type.Env as Env
import Control.Monad
import Data.List (intersect, nub)
import Data.Maybe (catMaybes)
import Data.Monoid
import qualified Data.Set as Set
-- | Convert a flow function into a combinator-normal-form 'Program'.
--
--   * Leading lambda binders not flagged 'True' by 'takeXLamFlags_safe'
--     and carrying a named binder become program inputs: array inputs when
--     their type satisfies 'isTypeArray', scalar inputs otherwise.
--   * Each let-binding of the body becomes a CNF binding (see 'getBinds').
--   * The outputs are the names of the final environment (the inputs
--     extended, binding by binding, via 'envOfBind') that occur free in
--     the result expression after the lets.
--
--   Fails with 'FailNamesNotUnique' if any lambda-bound or let-bound name
--   occurs more than once, or with whatever error 'takeLets' reports for an
--   unsupported binding form.
cnfOfExp :: ExpF -> Either ConversionError (Program Name Name)
cnfOfExp fun
 = do   let (params, inner) = takeXLamFlags_safe fun
            (letsF, result) = splitXLets inner

        lets    <- takeLets letsF

        -- Lambda-bound and let-bound names are checked together,
        -- so a clash between the two kinds is also rejected.
        let bound = catMaybes (map (takeNameOfBind . snd) params)
                 ++ map fst lets
        when (length (nub bound) /= length bound)
         $ Left FailNamesNotUnique

        -- Sort the parameters into (scalar inputs, array inputs);
        -- flagged binders and unnamed binders contribute nothing.
        let classify (isType, b)
             | not isType
             , BName n ty <- b
             = if isTypeArray ty
                then ([], [n])
                else ([n], [])
             | otherwise
             = ([], [])
            ins = mconcat (map classify params)

        let (cnfBinds, envAll) = getBinds lets ins
        return (Program ins cnfBinds (localEnv envAll result))
-- | Decide whether values of a type are treated as arrays (rather than
--   scalars) by the CNF conversion.  A type counts as an array type exactly
--   when 'isVectorType' accepts it.
isTypeArray :: TypeF -> Bool
isTypeArray ty = isVectorType ty
-- | Convert let-bindings to CNF bindings in order, threading the
--   (scalar names, array names) environment through them.  Each binding
--   is converted with 'getBind' in the environment produced by the
--   bindings before it.
--
--   Returns the converted bindings together with the final environment:
--   the starting environment extended with 'envOfBind' of every binding.
getBinds :: [(Name,(TypeF,ExpF))] -> ([Name],[Name]) -> ([CNF.Bind Name Name], ([Name],[Name]))
getBinds bs env
 = walk env bs
 where
  walk scope todo
   = case todo of
      []        -> ([], scope)
      (x : xs)  ->
        let cnf            = getBind x scope
            scope'         = envOfBind cnf <> scope
            (done, final)  = walk scope' xs
        in  (cnf : done, final)
-- | Convert a single let-binding into a CNF binding.
--
--   The environment is a pair of (scalar names, array names) in scope.
--
--   If the right-hand side applies one of the recognised vector primitives
--   (reduce, map, filter, generate, gather) to value arguments of the
--   expected shape, it becomes the corresponding combinator binding.
--   Worker functions are only accepted when they mention no array name
--   from the environment.  Every other right-hand side becomes an external
--   binding, 'Ext', that records the environment names it refers to.
getBind :: (Name,(TypeF,ExpF)) -> ([Name], [Name]) -> CNF.Bind Name Name
getBind (nm, (t, x)) env
 = maybe external id combinator
 where
        -- The combinator binding for x, if x is a recognised vector
        -- primitive applied to arguments of the expected shape.
        combinator
         = do   (f, args)   <- takeXApps x
                ov          <- vectorOp f
                -- Only value arguments matter; type arguments are dropped.
                fromOp ov (filter ((==Nothing) . takeXType) args)

        vectorOp (XVar (UPrim (NameOpVector ov) _)) = Just ov
        vectorOp _                                   = Nothing

        fromOp OpVectorReduce [worker, seed, arr]
         = do   fun     <- getFun worker
                a       <- name arr
                return (SBind nm (Fold fun (Scalar seed (name seed)) a))

        fromOp (OpVectorMap n) (worker : arrs)
         | length arrs == n
         = do   fun     <- getFun worker
                as      <- names arrs
                return (ABind nm (MapN fun as))

        fromOp OpVectorFilter [worker, arr]
         = do   fun     <- getFun worker
                a       <- name arr
                return (ABind nm (Filter fun a))

        fromOp OpVectorGenerate [sz, worker]
         = do   fun     <- getFun worker
                return (ABind nm (Generate (Scalar sz (name sz)) fun))

        fromOp OpVectorGather [v, ix]
         = do   v'      <- name v
                ix'     <- name ix
                return (ABind nm (Gather v' ix'))

        fromOp _ _
         = Nothing

        -- Fallback: an opaque external computation whose output is an
        -- array or a scalar depending on the bound type.
        external
         = let  ins = localEnv env x
                out | isTypeArray t = NameArray nm
                    | otherwise     = NameScalar nm
           in   Ext out x ins

        -- The names of all the expressions, only if every one of them
        -- is a plain named variable.
        names xs
         = let  found = catMaybes (map name xs)
           in   if length found == length xs then Just found else Nothing

        -- The name of a plain named variable.
        name (XVar (UName n)) = Just n
        name _                = Nothing

        -- Wrap a worker expression as a function over the scalars it uses,
        -- rejecting it if it refers to any array in the environment.
        getFun xx
         = case localEnv env xx of
                (ss, [])        -> Just (Fun xx ss)
                _               -> Nothing
-- | Restrict a (scalar names, array names) environment to the names that
--   occur free in an expression.
--
--   Each component of the result lists the free names of the expression
--   that also appear in the corresponding environment component, in the
--   order of the expression's free-variable set.
localEnv :: ([Name],[Name]) -> ExpF -> ([Name],[Name])
localEnv env xx
 = (keep scalars, keep arrays)
 where
  -- Lazy pattern binding: the environment is only inspected on demand.
  (scalars, arrays) = env

  freeNames
   = catMaybes [ takeNameOfBound u | u <- Set.toList (freeX Env.empty xx) ]

  keep known = freeNames `intersect` known
-- | Split the leading lambdas off an expression.
--
--   Like 'takeXLamFlags', but instead of failing when there are no leading
--   lambdas it returns an empty binder list paired with the whole
--   expression.  Each binder keeps the flag that 'takeXLamFlags' attached
--   to it (presumably marking type-level lambdas -- TODO confirm).
takeXLamFlags_safe x
 = case takeXLamFlags x of
        Just (binds, body)      -> (binds, body)
        Nothing                 -> ([], x)
-- | Extract let-bindings as (name, (type, right-hand side)) pairs, in order.
--
--   Only non-recursive @let@s with a named binder are accepted.  Anonymous
--   binders, de Bruijn binders, recursive groups and private-region
--   bindings each abort the conversion with their own 'ConversionError'.
takeLets :: [LetsF] -> Either ConversionError [(Name, (TypeF, ExpF))]
takeLets = mapM fromLets
 where
  fromLets lts
   = case lts of
        LLet b x        -> fromBind b x
        LRec _          -> Left FailRecursiveBindings
        LPrivate _ _ _  -> Left FailLetRegionNotHandled

  fromBind (BName n t) x = Right (n, (t, x))
  fromBind (BNone _)   _ = Left FailNoAnonAllowed
  fromBind (BAnon _)   _ = Left FailNoDeBruijnAllowed