{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE NoMonomorphismRestriction #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE MultiParamTypeClasses #-}
module LambdaCube.Compiler.Statements where

import Data.Maybe
import Data.List
import Data.Char
import Data.Function
import qualified Data.Set as Set
import qualified Data.Map as Map
import qualified Data.IntMap as IM
import Control.Monad.Writer
import Control.Arrow hiding ((<+>))

import LambdaCube.Compiler.Utils
import LambdaCube.Compiler.DeBruijn
import LambdaCube.Compiler.Pretty hiding (braces, parens)
import LambdaCube.Compiler.DesugaredSource
import LambdaCube.Compiler.Patterns

-------------------------------------------------------------------------------- declaration representation

-- Intermediate statement form; eliminated during parsing.
-- 'compileStmt' turns lists of these into plain 'Stmt's.
data PreStmt
    = Stmt Stmt
    | TypeAnn SIName SExp
    | TypeFamily SIName SExp{-type-}   -- type family declaration
    | FunAlt SIName [(Visibility, SExp)]{-TODO: remove-} GuardTrees
    | Class SIName [SExp]{-parameters-} [(SIName, SExp)]{-method names and types-}
    | Instance SIName [ParPat]{-parameter patterns-} [SExp]{-constraints-} [Stmt]{-method definitions-}

-- Placeholder pretty printer: every value is rendered as the same fixed text.
instance PShow PreStmt where
    pShow _ = text "PreStmt - TODO"

-- Only 'FunAlt' is supported: the transformation is applied to the second
-- component of each parameter pair and to the guard trees.
-- Any other constructor is reported with 'error'.
instance DeBruijnify SIName PreStmt where
    deBruijnify_ k v = \case
        FunAlt n ts gue -> FunAlt n (map (second $ deBruijnify_ k v) ts) $ deBruijnify_ k v gue
        x -> error $ "deBruijnify @ " ++ ppShow x

-- 'mkLets_' specialised to the 'SLet' constructor.
mkLets :: [Stmt]{-where block-} -> SExp{-main expression-} -> SExp{-big let with lambdas; replaces global names with de bruijn indices-}
mkLets = mkLets_ SLet

-- Groups the statements with 'sortDefs', rewrites every group with
-- 'desugarMutual' and nests the resulting statements around the main
-- expression with 'mkLets''.
mkLets_ mkLet = mkLets' mkLet . concatMap desugarMutual . sortDefs

-- Nests a list of statements around an expression using the supplied let
-- constructor.  The body of each let is passed through 'deBruijnify' for the
-- bound name.  'PrecDef' statements are skipped; a statement that matches
-- neither 'StmtLet' nor 'PrecDef' is reported with 'error'.
mkLets' mkLet = f where
    f [] e = e
    f (StmtLet n x: ds) e = mkLet n x (deBruijnify [n] $ f ds e)
    f (PrecDef{}: ds) e = f ds e
    f (x: _) _ = error $ "mkLets: " ++ ppShow x

type DefinedSet = Set.Set SName

-- Wraps the expression in one hidden 'SPi' binder (with a wildcard type) for
-- every name occurring in it that starts with a lower-case letter and is not
-- a member of the given set of defined names.
addForalls :: DefinedSet -> SExp -> SExp
addForalls defined x = foldl f x [v | v@(sName -> vh:_) <- reverse $ names x, sName v `notElem'` defined, isLower vh]
  where
    f e v = SPi Hidden (Wildcard SType) $ deBruijnify [v] e

    -- A name matching 'Ticked' must be absent both as written and in the form
    -- exposed by the pattern.
    notElem' s@(Ticked s') m = Set.notMember s m && Set.notMember s' m    -- TODO: review
    -- 'Set.notMember' instead of the Foldable 'notElem': same answer for
    -- character-list names, without scanning the whole set.
    notElem' s m = Set.notMember s m

    -- Names collected by 'foldName', duplicates removed (first occurrence kept).
    names :: SExp -> [SIName]
    names = nub . foldName pure

------------------------------------------------------------------------

-- 'compileStmt'_' with the 'SLHS' / 'SRHS' / 'SRHS' constructors.
compileStmt' = compileStmt'_ SLHS SRHS SRHS

-- Splits the statements into runs of adjacent 'FunAlt's that share a name
-- (every other statement forms a run of its own), compiles each run with
-- 'compileStmt' and concatenates the results.  The guard trees of a run are
-- merged with 'mconcat' before being handed to 'compileGuardTree'.
compileStmt'_ lhs ulend lend ds =
    concat <$> mapM (compileStmt lhs (\si vt -> compileGuardTree ulend lend (Just si) vt . mconcat) ds) (groupBy h ds)
  where
    h (FunAlt n _ _) (FunAlt m _ _) = m == n
    h _ _ = False

--compileStmt :: MonadWriter [ParseCheck] m => (SIName -> [(Visibility, SExp)] -> [GuardTrees] -> m SExp) -> [PreStmt] -> [PreStmt] -> m [Stmt]
-- Compiles one run of pre-statements (the argument of the lambda-case) into
-- statements.
--   lhs       wraps the compiled body of a function definition
--   compilegt compiles the collected guard trees of a function
--   ds        all pre-statements of the enclosing block, consulted for
--             related type annotations, instances and alternatives
-- Runs of any other shape are not handled (the case expression is partial).
compileStmt lhs compilegt ds = \case
    -- Instances emit nothing here; they are consumed by the 'Class' case.
    [Instance{}] -> return []
    [Class n ps ms] -> do
        -- The class itself: a type annotation, one alternative per instance
        -- found in 'ds', and a final catch-all alternative.
        cd <- compileStmt' $
            [ TypeAnn n $ foldr (SPi Visible) SConstraint ps ]
         ++ [ funAlt n (map noTA ips) $ noGuards $ foldr (SAppV2 $ SBuiltin F'T2) (SBuiltin FCUnit) cstrs | Instance n' ips cstrs _ <- ds, n == n' ]
         ++ [ funAlt n (replicate (length ps) (noTA $ PVarSimp $ dummyName "cst0")) $ noGuards $ SBuiltin FCEmpty `SAppV` sLit (LString $ "no instance of " ++ sName n ++ " on ???"{-TODO-})]
        -- One definition per method: its type annotation followed by the
        -- alternatives taken from the matching instances.
        cds <- sequence
            [ compileStmt'_ SLHS SRHS SRHS{-id-}
            $ TypeAnn m (UncurryS (map ((,) Hidden) ps) $ SPi Hidden (SCW $ foldl SAppV (SGlobal n) $ downToS "a2" 0 $ length ps) $ up1 t)
            : as
            | (m, t) <- ms
--            , let ts = fst $ getParamsS $ up1 t
            , let as = [ funAlt m p $ noGuards {- -$ SLam Hidden (Wildcard SType) $ up1 -} $ SLet m' e $ sVar "cst" 0
                       | Instance n' i cstrs alts <- ds, n' == n
                       , StLet m' ~Nothing e <- alts, m' == m
                       , let p = zip ((,) Hidden <$> ps) i ++ [((Hidden, Wildcard SType), PVarSimp $ dummyName "cst2")]
--                       , let ic = patVars i
                       ]
            ]
        return $ cd ++ concat cds
    -- A type annotation becomes a primitive only when no alternative in 'ds'
    -- defines the same name.
    [TypeAnn n t] -> return [Primitive n t | n `notElem` [n' | FunAlt n' _ _ <- ds]]
    [TypeFamily n t] -> case [d | d@(FunAlt n' _ _) <- ds, n' == n] of
        [] -> return [Primitive n t]
        alts -> compileStmt lhs compileGuardTrees' [TypeAnn n t] alts
    fs@(FunAlt n vs _: _) -> case groupBy ((==) `on` fst) [(length vs, n') | FunAlt n' vs' _ <- fs, let _ = vs'] of
        [gs@((num, _): _)]
          | num == 0 && length gs > 1 -> fail $ "redefined " ++ sName n ++ ":\n" ++ show (vcat $ pShow . sourceInfo . snd <$> gs)
          -- Alternatives of a type family are compiled by the 'TypeFamily' case.
          | n `elem` [n' | TypeFamily n' _ <- ds] -> return []
          | otherwise -> do
                cf <- compilegt (SIName_ (mconcat [sourceInfo n' | FunAlt n' _ _ <- fs]) (nameFixity n) $ sName n) vs [gt | FunAlt _ _ gt <- fs]
                return [StLet n (listToMaybe [t | TypeAnn n' t <- ds, n' == n]{-TODO: fail if more-}) $ lhs n cf]
        -- More than one group; groups produced by 'groupBy' are never empty,
        -- so the pattern below selects the first element of every group.
        gss -> fail $ "different number of arguments of " ++ sName n ++ ":\n" ++ show (vcat [pShow $ sourceInfo si | (_, si): _ <- gss])
    [Stmt x] -> return [x]
  where
    -- Pairs a pattern with a visible, wildcard-typed parameter description.
    noTA x = ((Visible, Wildcard SType), x)

-- Builds a 'FunAlt' from parameter descriptions paired with their patterns;
-- the patterns are combined with the guard trees by 'compilePatts'.
funAlt :: SIName -> [((Visibility, SExp), ParPat)] -> GuardTrees -> PreStmt
funAlt n pats gt = FunAlt n (fst <$> pats) $ compilePatts (map snd pats) gt

-- Like 'funAlt', with the parameter descriptions and the argument for
-- 'compilePatts' supplied separately.
funAlt' n ts x gt = FunAlt n ts $ compilePatts x gt

-- Desugars a pattern binding: the right-hand side is bound to a single name
-- built from the pattern variables with 'mangleNames', and every pattern
-- variable gets its own definition that matches that name against the
-- pattern with 'compileCase'.
desugarValueDef :: MonadWriter [ParseCheck] m => ParPat -> SExp -> m [PreStmt]
desugarValueDef p e = sequence
    $ pure (FunAlt n [] $ noGuards e)
    : [ FunAlt x [] . noGuards <$> compileCase (SGlobal n) [(p, noGuards $ SVar x i)]
      | (i, x) <- zip [0..] dns
      ]
  where
    dns = reverse $ getPVars p
    n = mangleNames dns

-- Name and body of a 'StmtLet'; 'Nothing' for anything else.
getLet (StmtLet x dx) = Just (x, dx)
getLet _ = Nothing

-- First component of a pair.
fst' (x, _) = x -- TODO

-- Rewrites one group of statements, as produced by 'sortDefs'.
-- A single primitive, data or fixity declaration is returned unchanged; a
-- single definition is passed through 'addFix'.  A group consisting only of
-- let statements is turned into one definition of a combined value plus the
-- definitions produced by 'desugarValueDef'; the writer output of that step
-- is discarded.  Any other group is reported with 'error'.
desugarMutual :: {-MonadWriter [ParseCheck] m => -} [Stmt] -> [Stmt]
desugarMutual [x@Primitive{}] = [x]
desugarMutual [x@Data{}] = [x]
desugarMutual [x@PrecDef{}] = [x]
desugarMutual [StLet n nt nd] = [StLet n nt $ addFix n nd]
--desugarMutual [StmtLet n nd] = [StmtLet n $ addFix n nd]        -- TODO
desugarMutual (traverse getLet -> Just (unzip -> (ns, ds))) = fst' $ runWriter $ do
    ss <- compileStmt'_ sLHS SRHS SRHS =<< desugarValueDef (foldr cHCons cHNil $ PVarSimp <$> ns) (SGlobal xy)
    return $ StmtLet xy (addFix xy $ mkLets' SLet ss $ foldr HCons HNil ds) : ss
  where
    xy = mangleNames ns
desugarMutual _ = error "desugarMutual"

-- Wraps the expression in an application of the 'FprimFix' builtin when
-- 'usedS' reports the name for it; otherwise returns it unchanged.
addFix n x
    | usedS n x = SBuiltin FprimFix `SAppV` SLamV (deBruijnify [n] x)
    | otherwise = x

-- Builds one name from several: the source infos are combined and the
-- textual names are joined with underscores behind a leading underscore.
mangleNames xs = SIName (foldMap sourceInfo xs) $ "_" ++ intercalate "_" (sName <$> xs)

-------------------------------------------------------------------------------- statement with dependencies

-- A statement together with its dependency edges.
data StmtNode = StmtNode
    { snId          :: !Int         -- position of the statement in the input list
    , snValue       :: Stmt
    , snChildren    :: [StmtNode]   -- nodes defining names this statement mentions
    , snRevChildren :: [StmtNode]   -- nodes that list this one among their children
    }

-- Groups the statements with 'scc' (defined elsewhere; presumably strongly
-- connected components of the dependency graph -- confirm against its
-- definition).  The node list and the two lookup maps refer to each other
-- and rely on lazy evaluation.
sortDefs :: [Stmt] -> [[Stmt]]
sortDefs xs = map snValue <$> scc snId snChildren snRevChildren nodes
  where
    nodes = zipWith mkNode [0..] xs
      where
        mkNode i s = StmtNode i s (nubBy ((==) `on` snId) $ mapMaybe (`Map.lookup` defMap) need)
                                  (IM.findWithDefault [] i revMap)
          where
            -- Names mentioned by the statement.
            need = Set.toList $ case s of
                PrecDef{} -> mempty
                StLet _ mt e -> foldMap names mt <> names e
                Data _ ps t cs -> foldMap (names . snd) ps <> names t <> foldMap (names . snd) cs

            names = foldName Set.singleton

    -- For every node id, the nodes that have it as a child.
    revMap = IM.unionsWith (++) [IM.singleton (snId c) [n] | n <- nodes, c <- snChildren n]

    -- Maps every defined name to the node that defines it.
    defMap = Map.fromList [(s, n) | n <- nodes, s <- def $ snValue n]
      where
        def = \case
            PrecDef{} -> mempty
            StLet n _ _ -> [n]
            Data n _ _ cs -> n: map fst cs