-- Please, see the file LICENSE for copyright and license information.
>module HFusion.Internal.Parsing.HyloParser(parse,parseResult2FusionState,parseHsModule,deriveHylos) where
import qualified HsParser as P(parse,ParseResult(..))
> import Language.Haskell.Parser(parseModule,ParseResult(..))
> import Language.Haskell.Syntax(SrcLoc,HsModule)
> import HFusion.Internal.Parsing.Translator
> import HFusion.Internal.HsSyn
> import HFusion.Internal.Utils
> import HFusion.Internal.FuseFace
> import HFusion.Internal.HyloFace
> import HFusion.Internal.Parsing.HyloContext
> import Control.Monad(zipWithM)
> import Control.Monad.Error(throwError,runErrorT)
> import Control.Monad.Trans(lift)
> import Control.Monad.State(StateT(..),State,MonadState(..))
> import Data.List(partition,intersect,union,find,nubBy,(\\),deleteFirstsBy,sort)
> import Data.Maybe(isJust,catMaybes)
import Debug.Trace -- Position of a token. It is used by the parser and the -- lexer to resolve layout problems. data SrcLoc = SrcLoc Int Int -- (Line, Indentation) deriving (Eq,Ord,Show)
> -- | Parses Haskell source text and derives a hylomorphism for each group of
> -- mutually recursive definitions it contains.
> --
> -- Parse errors are thrown as 'ParserError'. If a hylomorphism cannot be derived
> -- for some group, the error of the first such group is thrown.
> parse :: String -> FusionState [HyloT]
> parse inp =
>   do m <- parseResult2FusionState (parseModule inp)
>      defs <- hsModule2HsSyn m
>      (failures, derived) <- lift (deriveHylos defs)
>      case failures of
>        []           -> return (map snd derived)
>        (_, err) : _ -> throwError err
> -- | Obtains hylomorphisms representing functions in the original program.
> -- 
> -- The definitions are first split into groups of mutually recursive definitions
> -- (definitions that are not part of any recursion cycle are left out), and a
> -- hylomorphism is derived for each group.
> -- The hylomorphisms are returned in the second component of the output, each one
> -- paired with the group of definitions it was derived from.
> -- If a hylomorphism cannot be derived for some (possibly) mutually recursive 
> -- function definitions, then they are returned in the first component of the 
> -- output together with the error obtained when attempting derivation.
> deriveHylos :: [Def] -> VarGenState ([([Def],FusionError)],[([Def],HyloT)])
> deriveHylos dfs =
>   do groups  <- removeInputVar dfs >>= handleRegularFunctions . getCycles
>      results <- mapM (runErrorT . deriveHylo) groups
>      let outcomes = zip groups results
>      return ( [ (grp, err)  | (grp, Left err)   <- outcomes ]
>             , [ (grp, hylo) | (grp, Right hylo) <- outcomes ] )
> -- | Lifts the result of parsing an 'HsModule' into a 'FusionState' computation.
> -- A successful parse returns the module; a failed one is thrown as a
> -- 'ParserError' carrying the source location and the parser message.
> -- 
> -- @parseResult2FusionState (Language.Haskell.Parser.parseModule sourceCode)@
> parseResult2FusionState :: ParseResult HsModule -> FusionState HsModule
> parseResult2FusionState result =
>   catchParseState return (\loc msg -> throwError (ParserError loc msg)) result
> -- | Derives hylomorphisms from the definitions of an already parsed module.
> --
> -- Only the second component of the result of 'hsModule2HsSyn_' is used. Returns
> -- the errors of the groups for which derivation failed and the hylomorphisms
> -- derived for the other groups (without the definitions they come from).
> parseHsModule :: HsModule -> VarGenState ([FusionError],[HyloT])
> parseHsModule m =
>   do groups  <- hsModule2HsSyn_ m >>= removeInputVar . snd >>= handleRegularFunctions . getCycles
>      results <- mapM (runErrorT . deriveHylo) groups
>      return ([ err | Left err <- results ], [ hylo | Right hylo <- results ])
catchParseState :: (a->b) -> (String->b) -> P.ParseResult a -> b catchParseState f h (P.Ok _ p) = f p catchParseState f h (P.Failed err) = h err
> -- | Eliminates a 'ParseResult': the first continuation receives a successfully
> -- parsed value, the second one the location and message of a parse failure.
> catchParseState :: (a->b) -> (SrcLoc->String->b) -> ParseResult a -> b
> catchParseState onOk onFail result =
>   case result of
>     ParseOk p           -> onOk p
>     ParseFailed loc err -> onFail loc err
getCycles groups mutually recursive function definitions. Definitions that are not part of any recursion cycle are left out of the result.
> -- | Splits the definitions into groups of mutually recursive definitions. Within a
> -- group, definitions keep the order they have in the input.
> getCycles :: [Def] -> [[Def]]
> getCycles defs =
>   [ map (defs !!) (sort members) | members <- findCycles (getDependencyGraph defs) ]
> -- | Keeps the contents of the 'Just' elements, in their original order.
> collect :: [Maybe a] -> [a]
> collect ms = [ x | Just x <- ms ]
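getDependencyGraph builds the dependency graph of a list of definitions: the k-th element lists the indexes of the definitions whose names occur in the body of the k-th definition. findCycles then groups the indexes of mutually recursive definitions.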
> -- | Builds the dependency graph of a list of definitions: element @k@ lists the
> -- indexes of the definitions whose names are among the variables of the body of
> -- the @k@-th definition.
> getDependencyGraph :: [Def] -> [[Int]]
> getDependencyGraph ds = map edgesOf ds
>  where index = zip (map nameOf ds) [0..]
>        edgesOf (Defvalue _ body) = catMaybes [ lookup x index | x <- vars body ]
>        nameOf (Defvalue name _) = name
> -- | Given an adjacency list (node @k@ points to the nodes in element @k@), returns
> -- the node sets of the cycles, with cycles sharing a node merged into one set.
> -- Nodes that are not on any cycle do not appear in the result. Every path from
> -- every node is explored, which can be expensive on densely connected graphs.
> findCycles :: [[Int]] -> [[Int]]
> findCycles g = mergeOverlapping [] (concatMap (walk []) [0 .. length g - 1])
>   where -- walk path i: path holds the nodes visited so far (most recent first);
>         -- revisiting i closes a cycle made of i and the nodes visited after it.
>         walk :: [Int] -> Int -> [[Int]]
>         walk path i
>           | i `elem` path = [i : takeWhile (/= i) path]
>           | otherwise     = concatMap (walk (i : path)) (g !! i)
>         -- Adds each cycle to the accumulator, uniting it with every accumulated
>         -- set it shares a node with.
>         mergeOverlapping :: [[Int]] -> [[Int]] -> [[Int]]
>         mergeOverlapping acc [] = acc
>         mergeOverlapping acc (c:cs) =
>           let (disjoint, overlapping) = partition (null . intersect c) acc
>           in mergeOverlapping (foldr union c overlapping : disjoint) cs
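removeInputVar transforms a case expression that scrutinises a variable v0 and appears directly under the leading lambdas of a definition: case v0 of c p1 ... pn -> ... v0 ... is turned into case v0 of c p1 ... pn -> ... (c p1 ... pn) ... Alternatives whose pattern binds v0, or whose body does not mention v0, are left unchanged.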
> -- | Rewrites every definition with removeInputVar'.
> removeInputVar :: [Def] -> VarGenState [Def]
> removeInputVar = mapM removeInputVar'
> -- | Rewrites one definition: in a case expression found directly under its
> -- leading lambdas whose scrutinee is a variable @v0@, the occurrences of @v0@ in
> -- each alternative are replaced by the term rebuilt from that alternative's
> -- pattern. Alternatives whose pattern binds @v0@, or whose body does not mention
> -- @v0@, are left unchanged. Pattern variables named @_@ are first given fresh
> -- names so that the pattern can be turned into a term.
> removeInputVar' :: Def -> VarGenState Def
> removeInputVar' (Defvalue v t) = fmap (Defvalue v) (inTlamb t)
>  where inTlamb :: Term -> VarGenState Term
>        -- Only the leading lambdas are traversed.
>        inTlamb (Tlamb bv t) = fmap (Tlamb bv) (inTlamb t)
>        inTlamb (Tcase t0@(Tvar v0) ps ts) =
>          fmap (uncurry (Tcase t0) . unzip) (zipWithM (selectP v0) ps ts)
>        inTlamb t = return t
>        -- Rewrites one alternative (pattern p, body t) of a case on v0.
>        selectP v0 p t
>          | not (elem v0 (vars p)) && elem v0 (vars t) =
>              do p' <- renamePany p
>                 return (p', substitution [(v0, pat2term p')] t)
>          | otherwise = return (p, t)
>        -- Gives a fresh name to every pattern variable named "_".
>        renamePany (Pvar (Vuserdef "_")) = fmap Pvar (getFreshVar "v")
>        renamePany p@(Pvar _) = return p
>        renamePany (Ptuple ps) = fmap Ptuple (mapM renamePany ps)
>        renamePany (Pcons c ps) = fmap (Pcons c) (mapM renamePany ps)
>        renamePany (Pas v p) = fmap (Pas v) (renamePany p)
>        renamePany p@(Plit _) = return p
>        -- Rebuilds the term matched by a pattern; an as-pattern is represented by
>        -- its variable.
>        pat2term (Pvar v) = Tvar v
>        pat2term (Ptuple ps) = Ttuple False$ map pat2term ps
>        pat2term (Pcons c ps) = Tcapp c$ map pat2term ps
>        pat2term (Plit l) = Tlit l
>        pat2term (Pas v _) = Tvar v
> -- | Description of a call to be replaced by a call to a specialised definition:
> -- @(u, d, i, tsargs, t)@ where @u@ is the name of the specialised definition, @d@
> -- the definition being specialised, @i@ the index of its fixed argument, @t@ the
> -- term passed in that position, and @tsargs@ has one entry per argument of @t@ when
> -- @t@ is a function application (empty otherwise): @Just@ the argument when it is
> -- not a variable or mentions a locally bound variable, @Nothing@ otherwise.
> type CallDescription = (Variable,Def,Int,[Maybe Term],Term)
handleRegularFunctions creates new definitions so that recursion over regular functors can be expressed with mutually recursive functions.
> -- | Extends each group of mutually recursive definitions with the specialised
> -- definitions its calls need; one group is returned per input group.
> handleRegularFunctions :: [[Def]] -> VarGenState [[Def]]
> handleRegularFunctions groups =
>   handleRegularFunctions' groups [] [ [] | _ <- groups ] groups
> -- | Worker for 'handleRegularFunctions', run in rounds.
> --
> -- @handleRegularFunctions' p calls dns ds@: @p@ holds the groups of definitions
> -- built so far, @calls@ the call descriptions already known, @dns@ (one list per
> -- group) the names of the definitions handled in earlier rounds and @ds@ (one list
> -- per group) the definitions to handle in this round. A round rewrites the
> -- definitions with 'getCalls', builds a definition for every new call description
> -- with 'buildDef' and handles those definitions in the next round; the recursion
> -- stops at the first round that creates no new call description.
> handleRegularFunctions' :: [[Def]] -> [CallDescription] -> [[Variable]] -> [[Def]] -> VarGenState [[Def]]
> handleRegularFunctions' p _ _ [] = return p
> handleRegularFunctions' p calls dns ds = 
>   do cs<-zipWithM (getCallDefs p calls) dns ds
>      let (dfs,nfs)=unzip cs
>      if all null nfs then return dfs
>        else mapM (mapM buildDef . nubBy eq) nfs 
>             >>= handleRegularFunctions' (zipWith (++) dfs (zipWith (deleteFirstsBy eqDefs) p dfs)) 
>                                         (calls++concat nfs) (zipWith (++) dns (map (map getDefName) ds))
>             >>= return . zipWith (++) dfs
>  -- Two call descriptions denote the same specialisation when they fix the same
>  -- argument of the same definition with the same term and agree on which arguments
>  -- of that term become extra parameters (the test getCalls' uses to reuse a
>  -- description). Without the last condition nubBy could discard a description
>  -- whose fresh name is still used by a rewritten definition.
>  where eq (_,d1,i1,a1,t1) (_,d2,i2,a2,t2) =
>          i1==i2 && getDefName d1==getDefName d2 && t1==t2
>          && and (zipWith (\x y -> isJust x == isJust y) a1 a2)
>        -- Rewrites the definitions of one group. The descriptions created while
>        -- rewriting a definition are visible when rewriting the following
>        -- definitions of the same group, so equal specialisations within a group
>        -- share a single name instead of getting distinct fresh names of which
>        -- only one would end up being defined.
>        getCallDefs :: [[Def]] -> [CallDescription] -> [Variable] -> [Def] -> VarGenState ([Def],[CallDescription])
>        getCallDefs p calls dns ds = go calls ds
>          where names = dns++map getDefName ds
>                go _ [] = return ([],[])
>                go known (d:rest) =
>                  do (d',new) <- getCalls p known names d
>                     (rest',news) <- go (known++new) rest
>                     return (d':rest',new++news)
>        eqDefs d d' = getDefName d == getDefName d'
> -- | Builds the specialised definition described by a call description.
> buildDef :: CallDescription -> VarGenState Def
> buildDef (newName, original, fixedIdx, argMarks, fixedTerm) =
>   buildDef' newName fixedIdx argMarks fixedTerm original
> -- | Builds a specialised definition. @u@ is the name of the new definition, @i@
> -- the index of the fixed parameter of the original definition @Defvalue nd t0@,
> -- @t@ the term passed in that position, and @tsargs@ marks which arguments of
> -- @t@ become extra parameters ('Just') of the new definition.
> --
> -- The new definition takes fresh parameters for the 'Just' entries of @tsargs@,
> -- followed by the original parameters except the @i@-th one. In its body the
> -- @i@-th parameter is replaced by @t@ (with its 'Just' arguments replaced by the
> -- fresh parameters), and recursive calls to @nd@ are redirected to @u@.
> buildDef' u i tsargs t (Defvalue nd t0) = 
>    do (us,t')<-regenVars t
>       t0'<-regenConstantArgs t t0
>       let (bvs,t0'') = getInputVars t0'
>           (ant,pos)=splitAt i bvs
>           bs=vars ant++vars (tail pos)
>       -- NOTE(review): head/tail pos fail when t0 has at most i leading lambdas.
>       return$ Defvalue u (foldr Tlamb (adapt bs us t' (head pos) t0'') (map Bvar us++ant++tail pos))
>  -- adapt: replaces the removed parameter by the fixed term, by substitution when
>  -- the parameter is a plain variable and otherwise by matching the term against
>  -- the pattern built from the parameter.
>  where adapt bs us t (Bvar l) t0 = substitution [(l,t)]$ adaptr bs us t0
>        adapt bs us t bv t0 = Tcase t [bv2pat bv] [adaptr bs us t0]
>        -- Splits the leading lambda binders off a term.
>        getInputVars (Tlamb bv t) = let (bs,t')=getInputVars t in (bv:bs,t')
>        getInputVars t = ([],t)
>        -- Turns a binder into the equivalent pattern.
>        bv2pat (Bvar v) = Pvar v
>        bv2pat (Bvtuple _ bvs) = Ptuple (map bv2pat bvs)
>        -- Applies alphaConvert to t0 with a renaming that gives fresh names to the
>        -- head of t and the variables of its arguments marked Nothing in tsargs
>        -- (when t is a call), or to all variables of t otherwise. Presumably this
>        -- avoids capture when t is later substituted into the body — TODO confirm.
>        regenConstantArgs t t0 = 
>             do let freeVars = case t of
>                                 Tfapp v args -> v : concat (zipWith (\a -> maybe (vars a) (const [])) args tsargs)
>                                 _ -> vars t
>                us<-mapM (getFreshVar . varPrefix) freeVars
>                return$ alphaConvert [] (zip freeVars us) t0
>        -- Creates a fresh variable for every 'Just' entry of tsargs (using the
>        -- prefix of the argument's first variable, or "u") and, when t is a call,
>        -- replaces the corresponding arguments of t by those variables.
>        regenVars t = do us<-mapM (maybe (return Nothing) (\t -> fmap Just$
>                                     case vars t of
>                                       []  -> getFreshVar "u"
>                                       v:_ -> getFreshVar (varPrefix v))) tsargs
>                         return (catMaybes us,
>                                     case t of 
>                                       Tfapp v args -> Tfapp v (zipWith (\a -> maybe a Tvar) args us)
>                                       _ -> t
>                                )
>        -- Redirects recursive calls of nd to u: the i-th argument is dropped and
>        -- the fresh parameters us are passed first. bs accumulates the names bound
>        -- on the way; calls through those names are left untouched.
>        adaptr bs us (Ttuple b ts) = Ttuple b (map (adaptr bs us) ts)
>        adaptr bs us (Tcapp c ts) = Tcapp c (map (adaptr bs us) ts)
>        adaptr bs us (Tcase t0 ps ts) = Tcase (adaptr bs us t0) ps 
>                                                (zipWith (\p->adaptr (vars p++bs) us) ps ts)
>        adaptr bs us (Tlet v t0 t1) = Tlet v (adaptr (v:bs) us t0) (adaptr (v:bs) us t1)
>        adaptr bs us (Tlamb v t) = Tlamb v (adaptr (vars v++bs) us t)
>        adaptr bs us (Tapp t0 t1) = tapp (adaptr bs us t0) (adaptr bs us t1)
>        adaptr bs us (Tfapp fv ts) | elem fv bs = Tfapp fv (map (adaptr bs us) ts)
>                                   | fv == nd = let (ant,pos)=splitAt i (map (adaptr bs us) ts)
>                                                    in Tfapp u (map Tvar us++ant++tail pos)
>                                   | otherwise = Tfapp fv (map (adaptr bs us) ts)
>        adaptr _ _ t = t
getCalls collects the information about each recursive call that can be rewritten as a call to a recursive function which fixes one of the arguments. The returned pair (def,l) contains the rewritten definition (with fresh names for some recursive calls) and a list l with one item for each new definition to be introduced. Each item is a tuple (u,def,i,tsargs,t) where u is the name for the new definition, def is the definition to be rewritten with a fixed argument, i is the index of the fixed argument, t is the term in the ith argument, and tsargs has one element per argument of t when t is a function application: Just the argument when it is not a variable or mentions locally bound variables, Nothing otherwise.
> -- | Rewrites the calls in the body of one definition that can be replaced by a
> -- call to a copy of the called function specialised on one argument.
> --
> -- @ps@ are the groups of definitions being processed, @calls@ the call
> -- descriptions already known (used as the initial state, so a matching
> -- description reuses its name), and @ds@ the function names whose occurrence
> -- inside an argument makes that argument a candidate for specialisation.
> --
> -- Returns the rewritten definition and the call descriptions created while
> -- rewriting it; the final state itself is discarded.
> getCalls :: [[Def]] -> [CallDescription] -> [Variable] -> Def -> VarGenState (Def,[CallDescription])
> getCalls ps calls ds (Defvalue v t) = runStateT (do (t',ds')<-getCalls' [] t; return$ (Defvalue v t',ds')) calls >>= return . fst
>  -- getCalls' bs t rewrites t, where bs are the variables bound at that point. It
>  -- returns the rewritten term and the descriptions it created; the state holds
>  -- every description known so far.
>  where getCalls' :: [Variable] -> Term -> StateT [CallDescription] (State VarGen) (Term,[CallDescription])
>        getCalls' bs (Ttuple b ts) = do (ts',ns)<-mapGetCalls' bs ts
>                                        return (Ttuple b ts',ns)
>        getCalls' bs (Tlamb bv t) = do (t',ns)<-getCalls' (bs++vars bv) t; return (Tlamb bv t',ns)
>        getCalls' bs (Tcase t0 ps ts) = do (t0',n0)<-getCalls' bs t0
>                                           res<-sequence$ zipWith (getCalls'.(bs++).vars) ps ts
>                                           let (ts',ns)=unzip res
>                                           return (Tcase t0' ps ts',n0++concat ns)
>        getCalls' bs (Tapp t0 t1) = do (t0',n0)<-getCalls' bs t0
>                                       (t1',n1)<-getCalls' bs t1
>                                       return (tapp t0' t1',n0++n1)
>        getCalls' bs (Tlet v t0 t1) = do (t0',n0)<-getCalls' (v:bs) t0
>                                         (t1',n1)<-getCalls' (v:bs) t1
>                                         return (Tlet v t0' t1',n0++n1)
>        getCalls' bs (Tcapp c ts) = do (ts',ns)<-mapGetCalls' bs ts
>                                       return (Tcapp c ts',ns)
>        -- Application of a named function: the arguments are rewritten first, then
>        -- the call itself may be redirected to a specialised function.
>        getCalls' bs tt@(Tfapp v ts) =
>            do (ts',ns)<-mapGetCalls' bs ts
>               let rr = return (Tfapp v ts',ns)
>                   -- checkNoPattern (idxs,d): d is the called definition and idxs
>                   -- its constant argument positions (see findConstantArguments).
>                   -- The call is specialised on the first argument that mentions a
>                   -- name in ds and passes callIsOkToSpecialize, provided that
>                   -- position is in idxs; otherwise the call is kept as is (rr).
>                   checkNoPattern (idxs,d@(Defvalue _ t)) = 
>                       case [ i | (i,t)<-zip [0..] ts', any (flip elem (vars t)) ds, callIsOkToSpecialize i t ] of
>                         i:_ | elem i idxs -> mr i d -- recursive calls appear in constant positions
>                         _                 -> rr
>                      where (vargs,t') = extractVars t
>                            -- An argument qualifies if it is a variable, or a call to
>                            -- a function in ds with fewer arguments than the number
>                            -- of parameters extractVars finds in its definition, or
>                            -- with exactly that many when the i-th parameter of d
>                            -- occurs fewer than 2 times in d's body (countLinear).
>                            -- NOTE(review): vargs!!i and getVar fail when d has at
>                            -- most i parameters or its i-th parameter is a tuple binder.
>                            callIsOkToSpecialize i (Tfapp v' ts) = 
>                                          elem v' ds
>                                          && (length ts < lengthvargs'
>                                              || length ts==lengthvargs'
>                                                 -- variable is used at most once
>                                                 && countLinear (getVar (vargs!!i)) t'<2)
>                                where lengthvargs' = maybe (error "lengthvars'") 
>                                                     (length.fst.extractVars.getDefTerm) $ find ((v'==).getDefName)$ concat ps
>                            callIsOkToSpecialize _ (Tvar _) = True
>                            callIsOkToSpecialize _ _ = False
>                            getVar (Bvar v) = v
>                            getVar _ = error "getCalls': getVar"
>                   -- mr i d replaces the call by a call to the specialisation of d
>                   -- on argument i. The Just entries of tsargs (arguments of the
>                   -- fixed term that are not variables or mention locally bound
>                   -- variables) are passed first, followed by the remaining
>                   -- arguments. A matching description in the state is reused;
>                   -- otherwise a fresh name is generated and the new description
>                   -- is stored in the state and returned.
>                   mr i d = do let (ant,pos)=splitAt i ts'
>                                   tsargs = map (\t -> if (not$ isVar t) || (not$ null$ intersect (vars t) bs) then Just t else Nothing)$
>                                              case (head pos) of { Tfapp _ ts -> ts; _ -> [] }
>                               calls<-get
>                               case find (\(_,d',i',tsargs',t')-> i'==i 
>                                                           && getDefName d'==getDefName d
>                                                           && and (zipWith (\t t' -> isJust t == isJust t') tsargs tsargs')
>                                                           && t'==head pos) 
>                                         calls of
>                                Nothing -> do u<-lift$ getFreshVar (varPrefix v)
>                                              let c = (u,d,i,tsargs,head pos)
>                                              put (c:calls)
>                                              return (Tfapp u (catMaybes tsargs++ant++tail pos),c:ns)
>                                Just (u,_,_,_,_) -> return (Tfapp u (catMaybes tsargs++ant++tail pos),ns)
>               -- Calls through locally bound names are never specialised.
>               if elem v bs then rr
>                else maybe rr checkNoPattern (lookupDef v (map constantArgs ps) ps)
>        -- Any other term is returned unchanged.
>        getCalls' _ t = return (t,[])
>        isVar (Tvar _) = True
>        isVar _ = False
>        -- Constant argument positions of a group, as computed by findConstantArguments.
>        constantArgs :: [Def] -> [Int]
>        constantArgs dfs = findConstantArguments dfs 
>        -- Rewrites a list of terms, concatenating the descriptions they create.
>        mapGetCalls' bs ts = do res<-mapM (getCalls' bs) ts
>                                let (ts',ns)=unzip res
>                                return (ts',concat ns)
>        -- Finds the definition named v in the groups, paired with the constant
>        -- argument positions computed for its group.
>        lookupDef :: Variable -> [[Int]] -> [[Def]] -> Maybe ([Int],Def)
>        lookupDef v argidxs ps = find ((v==).getDefName.snd)$ [ (idxs,df) | (idxs,dfs)<-zip argidxs ps, df<-dfs]