--
-- (c) Susumu Katayama
--
\begin{code}
module MagicHaskeller.ProgramGenerator where
import MagicHaskeller.Types
import MagicHaskeller.TyConLib
import Control.Monad
import Data.Monoid
import MagicHaskeller.CoreLang
import Control.Monad.Search.Combinatorial
import MagicHaskeller.PriorSubsts
import Data.List(partition, sortBy, genericLength)
import Data.Ix(inRange)
import MagicHaskeller.Instantiate
import MagicHaskeller.Expression
import MagicHaskeller.T10
import qualified Data.Map as Map
import Debug.Trace
import Data.Monoid
import System.Random
import MagicHaskeller.MyDynamic
import MagicHaskeller.Options
import Data.Array
#if __GLASGOW_HASKELL__ >= 710
import Prelude hiding ((<$>))
#endif
-- | A primitive annotated for program generation (built by 'annotateTCEs').
-- Fields, in order:
--
--   * 'getArity' of the primitive's type,
--   * 'getRet' of that type,
--   * @maxVarID@ of the type plus one (handed to @reserveTVars@ by the
--     generators in this module),
--   * the typed expressions themselves.
type Prim = (Int, Type, TyVar, Typed [CoreExpr])
-- | Types that carry a 'Common' environment.
class WithCommon a where
  extractCommon :: a -> Common
-- | Program generators built from primitives and queried for expressions of
-- a requested type.
--
-- 'mkTrie' and 'mkTrieOpt' have default definitions in terms of each other.
-- An instance must therefore define at least one of them, or both defaults
-- loop forever.
class WithCommon a => ProgramGenerator a where
  -- | Build a generator from the shared environment, a list of typed
  -- expressions and a list of lists of typed expressions.
  mkTrie :: Common -> [Typed [CoreExpr]] -> [[Typed [CoreExpr]]] -> a
  -- Default: pass the same list of lists for both arguments of 'mkTrieOpt'.
  mkTrie cmn c t = mkTrieOpt cmn c t t
  -- | Like 'mkTrie', with an extra list-of-lists argument.
  mkTrieOpt :: Common -> [Typed [CoreExpr]] -> [[Typed [CoreExpr]]] -> [[Typed [CoreExpr]]] -> a
  -- Default: ignore the first list of lists and delegate to 'mkTrie'.
  mkTrieOpt cmn c _ t = mkTrie cmn c t
  matchingPrograms, matchingProgramsWOAbsents, unifyingPrograms :: Search m => Type -> a -> m AnnExpr
  -- Default: unify against the quantified version of the requested type.
  matchingPrograms ty memodeb = unifyingPrograms (quantify ty) memodeb
  -- Default: 'matchingPrograms', dropping at every depth the expressions for
  -- which isAbsent holds at the arity of the requested type.
  matchingProgramsWOAbsents ty memodeb = mapDepth (filter (not . isAbsent (getArity ty) . toCE)) $ matchingPrograms ty memodeb
-- | Program generators whose construction runs in 'IO' and whose programs
-- are enumerated in @RecompT IO@.
--
-- 'mkTrieIO' and 'mkTrieOptIO' have default definitions in terms of each
-- other. An instance must therefore define at least one of them.
class WithCommon a => ProgramGeneratorIO a where
  mkTrieIO :: Common -> [Typed [CoreExpr]] -> [[Typed [CoreExpr]]] -> IO a
  -- Default: pass the same list of lists for both arguments of 'mkTrieOptIO'.
  mkTrieIO cmn c t = mkTrieOptIO cmn c t t
  mkTrieOptIO :: Common -> [Typed [CoreExpr]] -> [[Typed [CoreExpr]]] -> [[Typed [CoreExpr]]] -> IO a
  -- Default: ignore the first list of lists and delegate to 'mkTrieIO'.
  mkTrieOptIO cmn c _ t = mkTrieIO cmn c t
  matchingProgramsIO, unifyingProgramsIO :: Type -> a -> RecompT IO AnnExpr
  -- Default: unify against the quantified version of the requested type.
  matchingProgramsIO ty memodeb = unifyingProgramsIO (quantify ty) memodeb
-- | The type-constructor library held in the value's 'Common' environment.
extractTCL :: WithCommon a => a -> TyConLib
extractTCL x = tcl (extractCommon x)
-- | The variable library (of the total primitives) held in the value's
-- 'Common' environment.
extractVL :: WithCommon a => a -> VarLib
extractVL x = vl (extractCommon x)
-- | The random trie held in the value's 'Common' environment.
extractRTrie :: WithCommon a => a -> RTrie
extractRTrie x = rt (extractCommon x)
-- | Evaluate a 'CoreExpr' to a 'Dynamic' by calling 'execute' with the
-- options and the variable library stored in the environment.
reducer :: Common -> CoreExpr -> Dynamic
reducer c = execute options vars
  where
    options = opt c
    vars    = vl c
-- | Environment shared by the program generators (see 'mkCommon').
data Common = Cmn
  { opt :: Opt ()    -- ^ the options, converted with 'forget'
  , tcl :: TyConLib  -- ^ type-constructor library built from the total primitives
  , vl  :: VarLib    -- ^ variable library of the total primitives
  , pvl :: VarLib    -- ^ variable library of the partial primitives
  , rt  :: RTrie     -- ^ random trie built by 'mkRandTrie'
  }
-- | Build the shared environment from the options, the total primitives and
-- the partial primitives. The type-constructor library is derived from the
-- total primitives only. It is then reused to build both variable libraries
-- and the random trie.
mkCommon :: Options -> [Primitive] -> [Primitive] -> Common
mkCommon opts totals partials =
  Cmn { opt = forget opts
      , tcl = tcLib
      , vl  = primitivesToVL tcLib totals
      , pvl = primitivesToVL tcLib partials
      , rt  = mkRandTrie (nrands opts) tcLib (stdgen opts)
      }
  where
    tcLib = primitivesToTCL totals
-- | Generator options whose type parameter is a list of lists of
-- 'Primitive's.
type Options = Opt [[Primitive]]
-- | True when the second (return-type) component of a 'Prim'-shaped tuple is
-- a bare type variable.
retsTVar (_, ret, _, _) = case ret of
  TV _ -> True
  _    -> False
-- | Pair typed expressions with the arity, the return type and
-- (maximum type-variable id + 1) of their type, yielding a 'Prim'.
annotateTCEs :: Typed [CoreExpr] -> Prim
annotateTCEs typed@(_ ::: ty) = (arity, ret, nextTV, typed)
  where
    arity  = getArity ty
    ret    = getRet ty
    nextTV = maxVarID ty + 1
-- | Annotate every typed expression and split the results. The first list
-- holds the primitives whose return type is a type variable ('retsTVar');
-- the second list holds the rest.
splitPrims :: [Typed [CoreExpr]] -> ([Prim],[Prim])
splitPrims txs = partition retsTVar (map annotateTCEs txs)
-- | 'splitPrims' applied to each group; the two halves are collected
-- group-wise.
splitPrimss :: [[Typed [CoreExpr]]] -> ([[Prim]],[[Prim]])
splitPrimss groups = unzip [ splitPrims g | g <- groups ]
-- | Apply @f@ to every element of a list of groups and combine all results
-- with 'mplus'. The alternatives of each later group sit under one more
-- 'delay' than those of the group before it.
mapSum :: (MonadPlus m, Delay m) => (a -> m b) -> [[a]] -> m b
mapSum f = go
  where
    go []         = mzero
    go (xs : xss) = msum (map f xs) `mplus` delay (go xss)
-- | Apply the current substitution to the available types and to the target
-- type before passing both to @fun@.
applyDo :: (Functor m, Monad m) => ([Type] -> Type -> PriorSubsts m a) -> [Type] -> Type -> PriorSubsts m a
applyDo fun avail ty = do
  s <- getSubst
  let avail' = map (apply s) avail
      ty'    = apply s ty
  fun avail' ty'
-- | Peel the arguments off an arrow type, pushing each argument type onto
-- the front of the available-type list, then call @f@ on the final result
-- type. The function @g@ is applied once for every arrow peeled.
wind :: (a->a) -> ([Type] -> Type -> a) -> [Type] -> Type -> a
wind g f = go
  where
    go avail (arg :-> res) = g (go (arg : avail) res)
    go avail res           = f avail res
-- | 'wind' with no per-arrow post-processing.
wind_ :: ([Type] -> Type -> a) -> [Type] -> Type -> a
wind_ f = wind id f
-- | Try every available type as the head of the expression and combine the
-- alternatives with 'msum'. Each candidate goes through 'retMono', whose
-- return-type check is @mps@ applied against @reqret@.
fromAssumptions :: (Search m, Expression e) => Common -> Int -> (Type -> PriorSubsts m [e]) -> (Type -> Type -> PriorSubsts m ()) -> Type -> [Type] -> PriorSubsts m [e]
fromAssumptions cmn lenavails behalf mps reqret avail =
  msum [ retMono cmn lenavails behalf checkRet cand | cand <- fromAvail avail ]
  where
    checkRet t = mps t reqret
-- | Use the assumption with index @n@ as the head of an expression.
--
-- First @tok@ checks the assumption's return type. Then, inside a delay of
-- @arity@ steps ('convertPS' / 'ndelay'), 'fap' supplies one argument per
-- entry in @args@ via @behalf@. The tuple pattern is lazy, matching the
-- original @let@ binding.
retMono :: (Search m, Expression e) => Common -> Int -> (Type -> PriorSubsts m [e]) -> (Type -> PriorSubsts m ()) -> (Int8, (Int8,[Type],Type)) -> PriorSubsts m [e]
retMono cmn lenavails behalf tok ~(n, (arity, args, retty)) = do
  tok retty
  convertPS (ndelay (fromIntegral arity)) $
    fap behalf args [mkHead (reducer cmn) lenavails arity (X n)]
-- | Number the available types from 0 and pair each index with the
-- 'revSplitArgs' decomposition of its type.
fromAvail :: [Type] -> [(Int8, (Int8,[Type],Type))]
fromAvail ts = zip [0..] (map revSplitArgs ts)
-- | 'mguAssumptions'' run on the pattern type and the assumptions after the
-- current substitution has been applied to both (see 'applyDo').
mguAssumptions :: (Functor m, MonadPlus m) => Type -> [Type] -> PriorSubsts m [CoreExpr]
mguAssumptions patty avail = applyDo mguAssumptions' avail patty
-- | For each assumption that unifies with @patty@ (via 'mguPS'), yield the
-- variable expression @X i@, where @i@ is the assumption's 0-based index.
-- The alternatives are combined with 'msum'.
mguAssumptions' assumptions patty =
  msum [ mguPS patty t >> return [X i] | (i, t) <- zip [0..] assumptions ]
-- | Apply the current substitution to @reqty@. Then, for each assumption
-- that 'matchPS' accepts against the result, yield a head expression built
-- by 'mkHead' from the variable @X i@ (@i@ is the assumption's index).
matchAssumptions :: (Functor m, MonadPlus m, Expression e) => Common -> Int -> Type -> [Type] -> PriorSubsts m [e]
matchAssumptions cmn lenavails reqty assumptions = do
  s <- getSubst
  let newty = apply s reqty
      hd i  = mkHead (reducer cmn) lenavails (getLongerArity newty) (X i)
  msum [ matchPS newty t >> return [hd i] | (i, t) <- zip [0..] assumptions ]
-- | 'mguAssumptions_'' run after the current substitution has been applied
-- to the pattern type and the assumptions (see 'applyDo').
mguAssumptions_ :: (Functor m, MonadPlus m) => Type -> [Type] -> PriorSubsts m ()
mguAssumptions_ patty avail = applyDo mguAssumptions_' avail patty
-- | Succeed once for every assumption that unifies with @patty@.
mguAssumptions_' assumptions patty = msum [ mguPS patty a | a <- assumptions ]
-- | Use a primitive as the head of an expression.
--
-- The steps are:
--
--   1. Reserve @numtvs@ fresh type variables and shift the primitive's type
--      variables by the reserved base id.
--   2. Check the shifted return type against @reqret@ with @mps@.
--   3. Inside a delay of @arity@ steps, supply the primitive's arguments
--      with 'funApSub'.
retPrimMono :: (Search m, Expression e) => Common -> Int -> (Type -> PriorSubsts m [e]) -> (Type -> PriorSubsts m [e]) -> (Type -> PriorSubsts m [e]) -> (Type -> Type -> PriorSubsts m ()) -> Type -> Prim -> PriorSubsts m [e]
retPrimMono cmn lenavails clbehalf lltbehalf behalf mps reqret (arity, retty, numtvs, xs:::ty) = do
  tvid <- reserveTVars numtvs
  let shift t = mapTV (tvid +) t
      heads   = map (mkHead (reducer cmn) lenavails (getLongerArity ty)) xs
  mps (shift retty) reqret
  convertPS (ndelay (fromIntegral arity)) $
    funApSub clbehalf lltbehalf behalf (shift ty) heads
-- | 'funApSubOp' specialised to expression application ('<$>' from the
-- expression layer, not 'fmap': Prelude's '<$>' is hidden in this module).
funApSub :: (Search m, Expression e) => (Type -> PriorSubsts m [e]) -> (Type -> PriorSubsts m [e]) -> (Type -> PriorSubsts m [e]) -> Type -> [e] -> PriorSubsts m [e]
funApSub clbehalf lltbehalf behalf = funApSubOp (<$>) clbehalf lltbehalf behalf
-- | Walk the argument spine of a type and combine the accumulated functions
-- with generated arguments using @op@ (all pairs, via 'liftM2').
--
-- Which generator is used depends on the argument's arrow:
--
--   * @_ :=> _@ uses @clbehalf@,
--   * @_ :> _@ uses @lltbehalf@,
--   * @_ :-> _@ uses @behalf@.
--
-- Any other type ends the walk and returns the functions accumulated so far.
funApSubOp op clbehalf lltbehalf behalf = go
  where
    go (t :=> ts) funs = step clbehalf  t ts funs
    go (t :>  ts) funs = step lltbehalf t ts funs
    go (t :-> ts) funs = step behalf    t ts funs
    go _          funs = return funs
    -- One argument: generate candidates for it, apply, continue with the rest.
    step gen t ts funs = do
      args <- gen t
      go ts (liftM2 op funs args)
-- | Apply the functions to one generated argument per type in @ts@, left to
-- right. Every function is combined with every argument via the
-- expression-level '<$>'.
fap behalf ts funs = foldM step funs ts
  where
    step fs t = liftM (liftM2 (<$>) fs) (behalf t)
-- | Run @f@ on each element left to right and fold the results into the
-- accumulator with @op@ as a left fold: the result is
-- @((n `op` y1) `op` y2) ...@.
mapAndFoldM op n f xs = foldr step return xs n
  where
    step x k acc = f x >>= \y -> k (acc `op` y)
-- | Generators for a primitive whose return type is a type variable
-- (presumably the primitives 'splitPrims' puts in its first list; confirm).
--
-- Arguments, in order:
--
--   * the shared environment,
--   * an 'Int' passed through to 'mkHead',
--   * a post-processing function for the generated expressions,
--   * three argument generators, used for @_ :=> _@, @_ :> _@ and @_ :-> _@
--     argument types respectively (see 'funApSubOp'),
--   * the requested return type,
--   * the primitive.
retGen, retGenOrd, retGenTV1
  :: (Search m, Expression e) => Common -> Int -> (Type -> Type -> [e] -> [e]) -> (Type -> PriorSubsts m [e]) -> (Type -> PriorSubsts m [e]) -> (Type -> PriorSubsts m [e]) -> Type -> Prim -> PriorSubsts m [e]
-- 'retGen'' with 'funApSub' (all three generators) supplying the arguments of
-- the instantiated return type.
retGen cmn lenavails fe clbehalf lltbehalf behalf =
  retGen' supplyArgs cmn lenavails fe clbehalf lltbehalf behalf
  where
    supplyArgs t es = funApSub clbehalf lltbehalf behalf t es
-- | Shared core of 'retGen' and 'retGenOrd'. The whole computation runs
-- inside a delay of @arity@ steps ('convertPS' / 'ndelay').
--
--   1. Reserve @numtvs@ fresh type variables (base id @tvid@).
--   2. 'mkSubsts' binds type variable @tvid@ to @reqret@, possibly prefixed
--      with @a@ fresh argument types.
--   3. Apply the primitive, with its type shifted by @tvid@, to arguments
--      for its own parameters via 'funApSub'. The arity given to 'mkHead'
--      is widened by @a@.
--   4. Read back the instantiated type of @TV tvid@ (@gentvar@) and reject
--      it unless 'orderedAndUsedArgs' holds.
--   5. Post-process the expressions with @fe gentvar ty@ and let @fas@
--      supply the arguments for @gentvar@.
retGen' fas cmn lenavails fe clbehalf lltbehalf behalf reqret (arity, _retty, numtvs, xs:::ty)
  = convertPS (ndelay $ fromIntegral arity) $
    do tvid <- reserveTVars numtvs
       a <- mkSubsts (tvndelay $ opt cmn) tvid reqret
       exprs <- funApSub clbehalf lltbehalf behalf (mapTV (tvid+) ty) (map (mkHead (reducer cmn) lenavails (getLongerArity ty+a)) xs)
       gentvar <- applyPS (TV tvid)
       guard (orderedAndUsedArgs gentvar)
       fas gentvar (fe gentvar ty exprs)
-- | Variant of 'retGen' for the arguments of the instantiated return type.
-- When two consecutive argument types are equal, it keeps only the
-- combinations whose arguments are in non-decreasing order (compared as
-- 'CoreExpr's).
--
-- NOTE(review): presumably meant to avoid enumerating programs that differ
-- only in the order of interchangeable arguments; confirm. Only @behalf@ is
-- used here, so @_ :=> _@ / @_ :> _@ arguments of the instantiated type
-- stop the walk (last clause).
retGenOrd cmn lenavails fe clbehalf lltbehalf behalf = retGen' (funApSub'' False) cmn lenavails fe clbehalf lltbehalf behalf
  where
    -- funApSub'' filtexp ty funs: supply one argument (from behalf) per
    -- arrow argument of ty. When filtexp is True, a new argument e is only
    -- combined with f if the argument most recently applied inside f
    -- (d, from toCE f = _ :$ d) satisfies d <= toCE e. The next step filters
    -- iff the current and next argument types are equal (t == u).
    funApSub'' filtexp (t:->ts@(u:->_)) funs
      | otherwise = do
          args <- behalf t
          funApSub'' (t==u) ts
            (if filtexp then [ f <$> e | f <- funs, e <- args, let _:$d = toCE f, d <= toCE e ]
                        else liftM2 (<$>) funs args)
    -- Last arrow argument: apply it and stop.
    funApSub'' filtexp (t:->ts) funs
      = do args <- behalf t
           return (if filtexp then [ f <$> e | f <- funs, e <- args, let _:$d = toCE f, d <= toCE e]
                              else liftM2 (<$>) funs args)
    -- Not an arrow: nothing more to apply.
    funApSub'' _fe _t funs = return funs
-- | Check the argument spine of a type. Returns False if any argument type
-- is a bare type variable, or if some argument type is strictly greater
-- than the one that follows it. Returns True otherwise, including for
-- non-arrow types.
orderedAndUsedArgs ty = case ty of
  TV _ :-> _           -> False
  t :-> rest@(u :-> _) -> not (t > u) && orderedAndUsedArgs rest
  _                    -> True
-- | False exactly when the type is an arrow whose (first) argument is the
-- type variable @n@.
usedArg n ty = case ty of
  TV m :-> _ -> n /= m
  _          -> True
-- | Variant of 'retGen'' with three differences:
--
--   * 'mkSubst' adds at most one fresh argument type in front of @reqret@
--     when binding type variable @tvid@.
--   * Candidates whose instantiated type has the shape @TV (tvid+1) :-> _@
--     are rejected (see 'usedArg').
--   * The remaining arguments are supplied by 'funApSub' directly.
retGenTV1 cmn lenavails fe clbehalf lltbehalf behalf reqret (arity, _retty, numtvs, xs:::ty)
  = convertPS (ndelay $ fromIntegral arity) $
    do tvid <- reserveTVars numtvs
       -- a is 0 or 1: the number of extra argument types put before reqret.
       a <- mkSubst (tvndelay $ opt cmn) tvid reqret
       exprs <- funApSub clbehalf lltbehalf behalf (mapTV (tvid+) ty) (map (mkHead (reducer cmn) lenavails (getLongerArity ty+a)) xs)
       gentvar <- applyPS (TV tvid)
       guard (usedArg (tvid+1) gentvar)
       funApSub clbehalf lltbehalf behalf gentvar (fe gentvar ty exprs)
-- | Simplest variant of the type-variable generators. It binds type variable
-- @tvid@ directly to @reqret@ (no extra arguments) and applies the primitive
-- to arguments for its own parameters. It returns @fe gentvar ty@ of the
-- result without supplying any further arguments.
retGenTV0 cmn lenavails fe clbehalf lltbehalf behalf reqret (arity, _retty, numtvs, xs:::ty)
  = convertPS (ndelay $ fromIntegral arity) $
    do tvid <- reserveTVars numtvs
       updatePS (unitSubst tvid reqret)
       exprs <- funApSub clbehalf lltbehalf behalf (mapTV (tvid+) ty) (map (mkHead (reducer cmn) lenavails (getLongerArity ty)) xs)
       gentvar <- applyPS (TV tvid)
       return $ fe gentvar ty exprs
-- | 'filterExprs' when the flag is set; otherwise the list is returned
-- unchanged.
filtExprs :: Expression e => Bool -> Type -> Type -> [e] -> [e]
filtExprs g a b = if g then filterExprs a b else id
-- | Keep only the expressions whose argument list (from 'getArgExprs')
-- passes these checks, which presumably prune redundant programs:
--
--   * the arguments fail 'retSameVal' and 'includesStrictArg';
--   * additionally, when @gentvar@ is an arrow type, they satisfy 'anyRec'
--     and fail 'constEq'.
--
-- The checks walk @ty@ alongside the arguments.
filterExprs :: Expression e => Type -> Type -> [e] -> [e]
filterExprs gentvar ty = filter (keep . getArgExprs . toCE)
  where
    keep args =
      let basic = not (retSameVal ty args) && not (includesStrictArg args)
      in case gentvar of
           _ :-> _ -> basic && anyRec ty args && not (constEq ty args)
           _       -> basic
-- | Arguments of an application spine, leftmost first. For
-- @((f :$ a1) :$ ...) :$ an@ the result is @[a1, ..., an]@; a
-- non-application yields @[]@.
getArgExprs expr = gae expr []

-- | Accumulating worker for 'getArgExprs'.
gae expr acc = case expr of
  f :$ a -> gae f (a : acc)
  _      -> acc
-- | Walk a primitive's argument types alongside its argument expressions.
--
--   * @_ :> _@ arguments are skipped.
--   * Every @_ :-> _@ argument must be a 'Lambda', otherwise the result is
--     False.
--   * If the argument's type satisfies 'returnsAtoA', the lambda must
--     satisfy 'recHead'.
--   * Otherwise, the first such lambda must not use its bound variable
--     (@isUsed 0@ on its body), and every later argument is checked
--     against it with 'ceq'.
--   * Running out of expressions gives True.
--
-- NOTE(review): non-exhaustive. A type that is neither @_ :-> _@ nor
-- @_ :> _@ while expressions remain is a pattern-match failure.
constEq (t:->u) (e@(Lambda d):es)
  | returnsAtoA t = recHead t e && constEq u es
  | otherwise     = not (isUsed 0 d) && ceq e u es
constEq (t:->u) (_:_) = False
constEq (_:> u) (_ :es) = constEq u es
constEq _ [] = True
-- | Continuation of 'constEq' once a reference lambda @d@ has been fixed.
--
--   * Every remaining @_ :-> _@ argument must be a lambda that either
--     satisfies 'recHead' (when its type satisfies 'returnsAtoA') or is
--     equal to @d@.
--   * Non-lambda arrow arguments give False.
--   * @_ :> _@ arguments are skipped.
--   * An exhausted expression list gives True.
--
-- NOTE(review): non-exhaustive in the same way as 'constEq'.
ceq d (t:->u) (e@(Lambda _):es)
  | returnsAtoA t = recHead t e && ceq d u es
  | otherwise     = d == e && ceq d u es
ceq d (t:->u) (_:_) = False
ceq d (_:> u) (_ :es) = ceq d u es
ceq _ _ [] = True
-- | Strip one lambda from the expression for each arrow of the type, as
-- long as the type's result is still an arrow. The check then succeeds only
-- if the type has reached @TV 0 :-> TV 0@ and the remaining expression is
-- @Lambda (Lambda (X 1 :$ _))@, i.e. its body applies the variable bound by
-- the outer of those two lambdas. Any other shape gives False.
recHead (t:->u@(_:->_)) (Lambda e) = recHead u e
recHead (TV tv0 :-> TV tv1) (Lambda (Lambda (X 1 :$ _))) = tv0 == 0 && tv1 == 0
recHead _u _e = False
-- | Check whether the function arguments all "return the same thing".
-- @_ :> _@ arguments are skipped. At the first @_ :-> _@ argument, either
-- that argument satisfies 'returnsId' and the rest satisfies 'rsv', or all
-- remaining arguments agree with this argument's 'retVal' (see 'rsv'').
-- Unlike 'rsv', the result is False when no arrow argument is reached.
retSameVal (_:>u) (_:es) = retSameVal u es
retSameVal (t:->u) (e:es) = (returnsId t e && rsv u es) || rsv' (retVal t e) u es
retSameVal _ _ = False
-- | Same walk as 'retSameVal', except that running out of arrow arguments
-- (or expressions) gives True.
rsv (_:>u) (_:es) = rsv u es
rsv (t:->u) (e:es) = (returnsId t e && rsv u es) || rsv' (retVal t e) u es
rsv _ _ = True
-- | Every remaining @_ :-> _@ argument must either satisfy 'returnsId' or
-- have 'retVal' equal to the reference value @rve@. @_ :> _@ arguments are
-- skipped; the end of the walk gives True.
rsv' rve (_:>u) (_:es) = rsv' rve u es
rsv' rve (t:->u) (e:es) = (returnsId t e || retVal t e == rve) && rsv' rve u es
rsv' _ _ _ = True
-- | True when the last arrow on the type's spine is @TV 0 :-> TV 0@. The
-- first arrow of the form @TV _ :-> TV _@ decides the result; non-arrow
-- results give False.
returnsAtoA (TV tv0 :-> TV tv1) = tv0 == 0 && tv1 == 0
returnsAtoA (t :-> u) = returnsAtoA u
returnsAtoA _ = False
-- | Follow the argument type's arrows while stripping lambdas from the
-- expression. When the type reaches @TV 0 :-> TV 0@, succeed iff the
-- remaining expression satisfies 'isId'. Any other shape gives False.
returnsId (t:->u@(_:->_)) (Lambda e) = returnsId u e
returnsId (TV tv0 :-> TV tv1) e = tv0 == 0 && tv1 == 0 && isId e
returnsId _u _e = False
-- | True for an (eta-expanded) identity: @n@ lambdas whose body is
-- @X (n-1) :$ X (n-2) :$ ... :$ X 0@ (see 'isId''').
isId expr = isId' 0 expr
-- | Count the leading lambdas in @n@, then check the body with 'isId'''.
isId' n expr = case expr of
  Lambda body -> isId' (n+1) body
  _           -> isId'' n 0 expr
-- | @isId'' n m e@ checks that @e@ is the spine
-- @X (n-1) :$ X (n-2) :$ ... :$ X m@. Arguments are consumed from the
-- right, and the head variable must be the one bound by the outermost of
-- @n@ lambdas.
isId'' n m (e :$ X i) = i==m && isId'' n (m+1) e
isId'' n m (X i) = i==m && n == m+1
isId'' _ _ _ = False
-- | The body of a function argument after stripping one lambda per arrow
-- argument of its type, with variable indices shifted down (see 'rv').
retVal t = rv t 0
-- | @rv t n e@ strips one lambda from @e@ per arrow argument of @t@, counting
-- them in @n@, and then shifts the body with 'mapsub'. It errors if @t@
-- still has arguments but @e@ is no longer a lambda.
rv (_:->t) n (Lambda e) = rv t (n+1) e
rv (_:->_) _ _ = error "rv: impossible"
rv _ n e = mapsub n e
-- | Subtract @n@ from the index of every variable 'X' in the expression.
-- 'rv' calls this after stripping @n@ leading lambdas, so that variables
-- referring past those lambdas are renumbered relative to the body.
--
-- Fix: the variable case read @X (mn)@, an unbound name left by a lost
-- minus sign; it is @X (m - n)@.
--
-- NOTE(review): the 'Lambda' case passes @n@ through unchanged instead of
-- adjusting it for the extra binder; kept as-is.
mapsub n (X m) = X (m - n)
mapsub n (a :$ b) = mapsub n a :$ mapsub n b
mapsub n (Lambda e) = Lambda (mapsub n e)
mapsub _ e = e
-- | True when the head of an application spine is a 'PrimCon'. A head that
-- is a variable, lambda, 'Context' or 'Primitive' gives False.
-- NOTE(review): other 'CoreExpr' constructors, if any exist, are not
-- matched here.
isConstrExpr (X _) = False
isConstrExpr (Lambda _) = False
isConstrExpr (Context _) = False
isConstrExpr (f :$ _) = isConstrExpr f
isConstrExpr (Primitive _) = False
isConstrExpr (PrimCon _) = True
-- | True when the expression has no free variables.
isClosed expr = isClosed' 0 expr
-- | @isClosed' dep e@: every variable in @e@ is bound either inside @e@ or
-- by one of the @dep@ enclosing lambdas.
isClosed' dep expr = case expr of
  X n      -> n < dep
  Lambda e -> isClosed' (dep + 1) e
  f :$ e   -> isClosed' dep f && isClosed' dep e
  _        -> True
-- | True when the first argument is a variable @X n@ and @isUsed n@ holds
-- for at least one of the remaining arguments.
includesStrictArg args = case args of
  X n : rest -> any (isUsed n) rest
  _          -> False
-- | True when at least one @_ :-> _@ argument satisfies 'recursive'. The
-- walk covers the primitive's type and its argument expressions together
-- and skips @_ :> _@ arguments.
--
-- Errors when the type still has an arrow argument but the expressions have
-- run out. The previous placeholder message @"hoge"@ did not say what went
-- wrong.
anyRec (_:>t) (_:es) = anyRec t es
anyRec (t:->u) (e:es) =
  recursive t e || anyRec u es
anyRec (_:->_) _ = error "anyRec: the type has more arrow arguments than there are argument expressions"
anyRec _ [] = False
-- | For a function-typed argument, follow the type's arrows and the
-- expression's lambdas together until the type is @TV 0 :-> TV 0@ and the
-- expression is a 'Lambda'. The result is then True iff that lambda's body
-- uses its bound variable (@isUsed 0@) and is not 'constRec' at depth 0.
-- Any other shape gives False.
recursive (t:->u@(_:->_)) (Lambda e) = recursive u e
recursive (TV tv0 :-> TV tv1) (Lambda e) = tv0 == 0 && tv1 == 0 && isUsed 0 e && not (constRec 0 e)
recursive _ _ = False
-- | @constRec dep e@ first strips leading lambdas, adding one to @dep@ for
-- each. It is True iff the body is then @X dep :$ a@ where @a@ references
-- no variable with index below the current depth (per 'belowIsUsed').
-- Otherwise it is False.
constRec dep (Lambda e) = constRec (dep+1) e
constRec dep (X n :$ e) | n == dep = not (belowIsUsed n e)
constRec _ _ = False
-- | True when some variable with index below @dep@ occurs in the expression.
-- @dep@ is incremented under each 'Lambda'.
belowIsUsed dep expr = case expr of
  X n      -> dep > n
  Lambda e -> belowIsUsed (dep + 1) e
  f :$ e   -> belowIsUsed dep f || belowIsUsed dep e
  _        -> False
-- | True when the variable with index @dep@ occurs in the expression.
-- @dep@ is incremented under each 'Lambda', so it keeps referring to the
-- same binder.
isUsed dep expr = case expr of
  X n      -> dep == n
  Lambda e -> isUsed (dep + 1) e
  f :$ e   -> isUsed dep f || isUsed dep e
  _        -> False
-- | Bind type variable @tvid@ in one of two ways:
--
--   * directly to @reqret@, returning 0;
--   * after an @ndelayPS n@ step, to @reqret@ prefixed with @k@ fresh
--     argument type variables, returning @k@. Each extra argument costs one
--     more delay.
mkSubsts :: Search m => Int -> TyVar -> Type -> PriorSubsts m Int8
mkSubsts n tvid reqret = exact `mplus` ndelayPS n widen
  where
    exact = updatePS (unitSubst tvid reqret) >> return 0
    widen = do
      v <- newTVar
      liftM (+ 1) (mkSubsts n tvid (TV v :-> reqret))
-- | Like 'mkSubsts', but with at most one extra argument. Binds @tvid@ to
-- @reqret@ (returning 0), or, after an @ndelayPS n@ step, to
-- @TV v :-> reqret@ for a fresh @v@ (returning 1).
mkSubst :: Search m => Int -> TyVar -> Type -> PriorSubsts m Int8
mkSubst n tvid reqret =
  bindTo reqret 0 `mplus` ndelayPS n (newTVar >>= \v -> bindTo (TV v :-> reqret) 1)
  where
    bindTo ty k = updatePS (unitSubst tvid ty) >> return k
-- | Pair a type with its return type (return type first).
mkRetty ty = let ret = getRet ty in (ret, ty)
-- | Call @fun@ on the available types after sorting and deduplicating them
-- with 'uniqSort'.
reorganizer_ :: ([Type] -> a) -> [Type] -> a
reorganizer_ fun = fun . uniqSort
-- | True when the total 'size' of the type and the listed types is below 10.
hit :: Type -> [Type] -> Bool
hit ty tys = sum [ size t | t <- ty : tys ] < 10
-- | All sublists of @xs@ with at most @n@ elements, keeping elements in
-- their original order. The empty list comes first, followed by the
-- sublists starting with each element in turn. For negative @n@ the count
-- never reaches zero, so every sublist is returned.
--
-- >>> combs 2 "abc"
-- ["","a","ab","ac","b","bc","c"]
--
-- Fix: the recursive call read @combs (n1)@, an unbound name left by a lost
-- minus sign; it is @combs (n - 1)@.
combs :: (Eq n, Num n) => n -> [a] -> [[a]]
combs 0 _  = [[]]
combs n xs = [] : [ y : zs | y : ys <- tails xs, zs <- combs (n - 1) ys ]

-- | Non-empty suffixes of a list, longest first. Unlike 'Data.List.tails',
-- this omits the final empty suffix.
tails :: [a] -> [[a]]
tails [] = []
tails xs@(_:ys) = xs : tails ys
-- | Apply a function to the first component of a triple.
mapFst3 f triple = case triple of
  (x, y, z) -> (f x, y, z)
-- | Renumber the variables of every expression in the first component of a
-- triple with 'decodeVarsCE'.
decodeVarsPos vs = mapFst3 (\ces -> [ decodeVarsCE vs ce | ce <- ces ])
-- | Renumber the variables of an expression through the lookup list @vs@.
-- The list is stored in a 0-indexed array covering exactly its elements,
-- and the work is delegated to 'decodeVarsCE'' starting at offset 0. An
-- empty @vs@ gives the empty bounds @(0, -1)@.
--
-- Fix: the upper bound read @genericLength vs1@, an unbound name left by a
-- lost minus sign; it is @genericLength vs - 1@.
decodeVarsCE :: [Int8] -> CoreExpr -> CoreExpr
decodeVarsCE vs = decodeVarsCE' 0 (listArray (0, genericLength vs - 1) vs)
-- | Worker for 'decodeVarsCE'. @offset@ counts the lambdas entered so far.
--
--   * A variable @X n@ whose relative index @n - offset@ lies within the
--     array bounds becomes @X (ar ! (n - offset) + offset)@.
--   * Variables bound by those lambdas (negative relative index) and other
--     out-of-range indices are left unchanged.
--
-- Fix: the relative index read @n offset@, which applies a number as a
-- function; it is @n - offset@.
decodeVarsCE' :: Int8 -> Array Int8 Int8 -> CoreExpr -> CoreExpr
decodeVarsCE' offset ar e@(X n) =
  let nn = n - offset
  in if inRange (bounds ar) nn then X $ (ar ! nn) + offset else e
decodeVarsCE' offset ar (Lambda e) = Lambda $ decodeVarsCE' (offset + 1) ar e
decodeVarsCE' offset ar (f :$ e) = decodeVarsCE' offset ar f :$ decodeVarsCE' offset ar e
decodeVarsCE' _ _ e = e
\end{code}