module LambdaCube.Compiler.Statements where
import Data.Maybe
import Data.List
import Data.Char
import Data.Function
import qualified Data.Set as Set
import qualified Data.Map as Map
import qualified Data.IntMap as IM
import Control.Monad.Writer
import Control.Arrow hiding ((<+>))
import LambdaCube.Compiler.Utils
import LambdaCube.Compiler.DeBruijn
import LambdaCube.Compiler.Pretty hiding (braces, parens)
import LambdaCube.Compiler.DesugaredSource
import LambdaCube.Compiler.Patterns
-- | A declaration-level statement as it exists before 'compileStmt' turns it
-- into one or more 'Stmt's.
data PreStmt
    = Stmt Stmt
      -- ^ an already compiled statement; 'compileStmt' passes it through unchanged
    | TypeAnn SIName SExp
      -- ^ type signature: name and type
    | TypeFamily SIName SExp
      -- ^ type family declaration: name and type
    | FunAlt SIName [(Visibility, SExp)] GuardTrees
      -- ^ one alternative of a function: name, (visibility, type) of each argument, guarded right-hand side
    | Class SIName [SExp] [(SIName, SExp)]
      -- ^ class: name, parameters (the argument types of the class function), and method names with their types
    | Instance SIName [ParPat] [SExp] [Stmt]
      -- ^ instance: class name, head patterns, constraints, and the method definitions
-- | Placeholder pretty-printer: every 'PreStmt' renders as the same fixed text.
instance PShow PreStmt where
    pShow = const $ text "PreStmt - TODO"
-- | De Bruijn conversion is only defined for 'FunAlt'.
--
-- The argument types and the guard trees are converted with the same @k@ and
-- @v@; the argument visibilities are kept.  Every other constructor is a hard
-- error.
instance DeBruijnify SIName PreStmt where
    deBruijnify_ k v stmt = case stmt of
        FunAlt n ts gue ->
            FunAlt n (map (\ ~(vis, ty) -> (vis, deBruijnify_ k v ty)) ts) (deBruijnify_ k v gue)
        other ->
            error $ "deBruijnify @ " ++ ppShow other
-- | Wrap @body@ in the let-bindings described by @stmts@, using 'SLet' as the
-- let constructor (see 'mkLets_').
mkLets :: [Stmt] -> SExp -> SExp
mkLets stmts body = mkLets_ SLet stmts body
-- | Like 'mkLets', but parameterised over the let constructor @mkLet@.
--
-- The statements are first grouped by 'sortDefs'.  Each group is then
-- flattened by 'desugarMutual', and the result is nested by 'mkLets''.
mkLets_ mkLet stmts = mkLets' mkLet (sortDefs stmts >>= desugarMutual)
-- | Nest @stmts@ around @body@ with @mkLet@; the first statement becomes the
-- outermost binding.
--
-- Each bound name is abstracted (with 'deBruijnify') out of everything nested
-- inside it.  Fixity declarations are skipped.  Any statement that does not
-- match 'StmtLet' or 'PrecDef' aborts with an error.
mkLets' mkLet stmts body = foldr wrap body stmts
  where
    wrap (StmtLet n x) inner = mkLet n x (deBruijnify [n] inner)
    wrap PrecDef{} inner = inner
    wrap other _ = error $ "mkLets: " ++ ppShow other
-- | Names that are already defined; 'addForalls' never quantifies over them.
type DefinedSet = Set.Set SName
-- | Close a type over its free lower-case names.
--
-- Every distinct name in @x@ qualifies when it starts with a lower-case
-- character and is not in @defined@.  Each qualifying name is abstracted out
-- with 'deBruijnify' and bound by a hidden 'SPi' whose type is a wildcard.
-- Because the names list is reversed and then left-folded, the first name
-- returned by @names@ ends up as the outermost binder.
--
-- For names matching the 'Ticked' pattern, both the name itself and the inner
-- name bound by the pattern must be absent from @defined@.
addForalls :: DefinedSet -> SExp -> SExp
addForalls defined x = foldl' f x [v | v@(sName -> vh:_) <- reverse $ names x, sName v `notElem'` defined, isLower vh]
  where
    -- Bind @v@ around @e@ as a hidden, wildcard-typed Pi.
    f e v = SPi Hidden (Wildcard SType) $ deBruijnify [v] e

    -- Use 'Set.notMember' in both branches.  The Foldable 'notElem' used
    -- previously only has an 'Eq' constraint, so it scans the whole Set.
    notElem' s@(Ticked s') m = Set.notMember s m && Set.notMember s' m
    notElem' s m = Set.notMember s m

    -- Distinct names occurring in an expression.
    names :: SExp -> [SIName]
    names = nub . foldName pure
-- | Compile a list of pre-statements with 'compileStmt'_', using 'SLHS' as the
-- function-body wrapper and 'SRHS' for both values forwarded to
-- 'compileGuardTree'.
--
-- NOTE(review): this is an argument-less binding without a type signature, in
-- the same recursive group as 'compileStmt'_' and 'compileStmt'.  Its binding
-- form matters for the monomorphism restriction unless that restriction is
-- disabled for this module; confirm before reshaping it.
compileStmt' = compileStmt'_ SLHS SRHS SRHS
-- | Compile every pre-statement of a declaration block.
--
-- Consecutive 'FunAlt's that share a name form one group (one function
-- definition); every other pre-statement is a group on its own.  Each group
-- is compiled by 'compileStmt', with the whole block @ds@ as context, and the
-- resulting statements are concatenated in order.
--
-- @lhs@ is handed to 'compileStmt'.  @ulend@ and @lend@ are forwarded to
-- 'compileGuardTree' when a function's guard trees are compiled.
compileStmt'_ lhs ulend lend ds = concat <$> mapM compileGroup (groupBy sameFunction ds)
  where
    compileGroup grp = compileStmt lhs compileAlternatives ds grp

    -- The guard trees of all alternatives are merged with 'mconcat' first.
    compileAlternatives si vt gts = compileGuardTree ulend lend (Just si) vt (mconcat gts)

    sameFunction (FunAlt f _ _) (FunAlt g _ _) = g == f
    sameFunction _ _ = False
-- | Compile one group of 'PreStmt's into 'Stmt's.
--
-- A group, as produced by 'compileStmt'_', is either a run of consecutive
-- 'FunAlt's with the same name or a single other pre-statement.  @ds@ is the
-- complete list of pre-statements of the surrounding block.  It is searched
-- for the type annotations, type families and instances that belong to the
-- group.
--
-- * @lhs@ wraps the compiled body of a function (the right-hand side of the
--   resulting 'StLet').
-- * @compilegt@ compiles a function's name, argument list and guard trees.
--
-- Problems in the source (redefined values, alternatives with different
-- numbers of arguments) are reported with 'fail'.
compileStmt lhs compilegt ds = \case
    -- Instances are only consulted while compiling their class (see the
    -- 'Class' case), so on their own they produce no statements.
    [Instance{}] -> return []
    -- A class is compiled into a function from its parameters to a
    -- constraint, plus one function per method.
    [Class n ps ms] -> do
        -- The class function has three parts:
        --   * its type annotation;
        --   * one alternative per instance of the class, whose body is the
        --     instance's constraints combined with F'T2 / FCUnit;
        --   * a final alternative matching any arguments, whose body is
        --     FCEmpty applied to a "no instance" message.
        cd <- compileStmt' $
            [ TypeAnn n $ foldr (SPi Visible) SConstraint ps ]
            ++ [ funAlt n (map noTA ps) $ noGuards $ foldr (SAppV2 $ SBuiltin F'T2) (SBuiltin FCUnit) cstrs | Instance n' ps cstrs _ <- ds, n == n' ]
            ++ [ funAlt n (replicate (length ps) (noTA $ PVarSimp $ dummyName "cst0")) $ noGuards $ SBuiltin FCEmpty `SAppV` sLit (LString $ "no instance of " ++ sName n ++ " on ???")]
        -- One function per method @(m, t)@.  Its type takes the class
        -- parameters as hidden arguments, then a hidden argument of the class
        -- constraint applied to those parameters.  NOTE(review): 'up1'
        -- presumably shifts @t@ past that extra binder.  Every instance that
        -- defines @m@ contributes one alternative: the instance head patterns
        -- plus a trailing hidden variable pattern, with a body that let-binds
        -- the instance's definition.
        cds <- sequence
            [ compileStmt'_ SLHS SRHS SRHS
                $ TypeAnn m (UncurryS (map ((,) Hidden) ps) $ SPi Hidden (SCW $ foldl SAppV (SGlobal n) $ downToS "a2" 0 $ length ps) $ up1 t)
                : as
            | (m, t) <- ms
            , let as = [ funAlt m p $ noGuards $ SLet m' e $ sVar "cst" 0
                       | Instance n' i cstrs alts <- ds, n' == n
                         -- '~Nothing' is irrefutable: any 'StLet' named @m@
                         -- matches, whether or not it carries a type annotation.
                       , StLet m' ~Nothing e <- alts, m' == m
                       , let p = zip ((,) Hidden <$> ps) i ++ [((Hidden, Wildcard SType), PVarSimp $ dummyName "cst2")]
                       ]
            ]
        return $ cd ++ concat cds
    -- A type signature on its own becomes a primitive, unless the block also
    -- has alternatives for the name.  In that case the signature is picked up
    -- by the 'FunAlt' case below.
    [TypeAnn n t] -> return [Primitive n t | n `notElem` [n' | FunAlt n' _ _ <- ds]]
    -- A type family without alternatives is a primitive.  Otherwise its
    -- alternatives are compiled with 'compileGuardTrees'', with the family's
    -- signature as the only context.
    tf@[TypeFamily n t] -> case [d | d@(FunAlt n' _ _) <- ds, n' == n] of
        [] -> return [Primitive n t]
        alts -> compileStmt lhs compileGuardTrees' [TypeAnn n t] alts
    -- A run of alternatives of one function.  All of them must have the same
    -- number of arguments; the argument list @vs@ of the first one is used.
    fs@(FunAlt n vs _: _) -> case groupBy ((==) `on` fst) [(length vs, n) | FunAlt n vs _ <- fs] of
        [gs@((num, _): _)]
            -- A value (zero arguments) may be defined only once.
          | num == 0 && length gs > 1 -> fail $ "redefined " ++ sName n ++ ":\n" ++ show (vcat $ pShow . sourceInfo . snd <$> gs)
            -- Alternatives of a type family are compiled by the 'TypeFamily' case.
          | n `elem` [n' | TypeFamily n' _ <- ds] -> return []
          | otherwise -> do
            -- The name carries the combined source info of all alternatives.
            cf <- compilegt (SIName_ (mconcat [sourceInfo n | FunAlt n _ _ <- fs]) (nameFixity n) $ sName n) vs [gt | FunAlt _ _ gt <- fs]
            return [StLet n (listToMaybe [t | TypeAnn n' t <- ds, n' == n]) $ lhs n cf]
        fs -> fail $ "different number of arguments of " ++ sName n ++ ":\n" ++ show (vcat $ pShow . sourceInfo . snd . head <$> fs)
    -- An already compiled statement is passed through.
    [Stmt x] -> return [x]
  where
    -- A visible argument with a wildcard type, paired with the given pattern.
    noTA x = ((Visible, Wildcard SType), x)
-- | Build one function alternative from (argument annotation, pattern) pairs.
-- The patterns are compiled into the guard trees with 'compilePatts'.
funAlt :: SIName -> [((Visibility, SExp), ParPat)] -> GuardTrees -> PreStmt
funAlt n pats gt = FunAlt n argTypes (compilePatts argPats gt)
  where
    argTypes = map fst pats
    argPats = map snd pats
-- | Variant of 'funAlt' that takes the argument annotations and the patterns
-- as separate lists.
funAlt' n ts x = FunAlt n ts . compilePatts x
-- | Desugar a pattern binding @p = e@.
--
-- The result has one definition of a synthesised name (built by 'mangleNames'
-- from the pattern's variables) bound to @e@.  It is followed by one
-- definition per pattern variable, which matches @p@ against that synthesised
-- name (via 'compileCase') and yields the variable.
desugarValueDef :: MonadWriter [ParseCheck] m => ParPat -> SExp -> m [PreStmt]
desugarValueDef p e = (whole :) <$> traverse project (zip [0..] dns)
  where
    -- The pattern variables, reversed; a variable's position here is the
    -- index given to its 'SVar'.
    dns = reverse $ getPVars p
    n = mangleNames dns
    whole = FunAlt n [] $ noGuards e
    project (i, x) = FunAlt x [] . noGuards <$> compileCase (SGlobal n) [(p, noGuards $ SVar x i)]
-- | Name and body of a statement matching 'StmtLet'; 'Nothing' otherwise.
getLet stmt = case stmt of
    StmtLet x dx -> Just (x, dx)
    _ -> Nothing
-- | First component of a pair.  'desugarMutual' uses it on the result of
-- 'runWriter', which discards the accumulated Writer output.
fst' pair = fst pair
-- | Desugar one strongly connected component of statements (see 'sortDefs')
-- into statements that 'mkLets'' can nest one after another.
--
-- * A lone primitive, data or fixity declaration is returned unchanged.
-- * A lone 'StLet' is returned with 'addFix' applied to its body.
-- * A group in which every statement matches 'getLet' becomes one definition
--   of a combined value (named by 'mangleNames', built from the bodies with
--   'HCons'/'HNil'), followed by the per-name projections generated by
--   'desugarValueDef'.  Any 'ParseCheck' output written meanwhile is dropped.
-- * Anything else is rejected with 'error'.
--   NOTE(review): presumably this covers e.g. mutually recursive lets with
--   type annotations.
desugarMutual :: [Stmt] -> [Stmt]
desugarMutual [x@Primitive{}] = [x]
desugarMutual [x@Data{}] = [x]
desugarMutual [x@PrecDef{}] = [x]
desugarMutual [StLet n nt nd] = [StLet n nt $ addFix n nd]
desugarMutual (traverse getLet -> Just (unzip -> (ns, ds))) = fst' $ runWriter $ do
    -- Projections of each original name out of the combined value xy.
    ss <- compileStmt'_ sLHS SRHS SRHS =<< desugarValueDef (foldr cHCons cHNil $ PVarSimp <$> ns) (SGlobal xy)
    -- The combined value.  The projections are let-bound around the tuple of
    -- original bodies, so the bodies refer to each other through xy; addFix
    -- then closes the recursion on xy.
    return $ StmtLet xy (addFix xy $ mkLets' SLet ss $ foldr HCons HNil ds) : ss
  where
    xy = mangleNames ns
desugarMutual xs = error "desugarMutual"
-- | If 'usedS' reports that @n@ occurs in @x@, abstract @n@ out of @x@ and
-- apply the 'FprimFix' builtin to the resulting lambda (presumably a fixpoint
-- combinator).  Otherwise return @x@ unchanged.
addFix n x =
    if usedS n x
        then SAppV (SBuiltin FprimFix) (SLamV (deBruijnify [n] x))
        else x
-- | Combine several names into one synthetic name.
--
-- Its source info is the combination of theirs, and its text is an underscore
-- followed by the names joined with underscores (names @a@, @b@ give @_a_b@).
mangleNames xs = SIName (foldMap sourceInfo xs) ('_' : intercalate "_" (map sName xs))
-- | A node of the statement dependency graph built by 'sortDefs'.
--
-- Only 'snId' is strict.  'sortDefs' computes both edge lists from the list of
-- nodes itself, so those fields must stay lazy.
data StmtNode = StmtNode
    { snId :: !Int                  -- ^ index of the statement in the input list
    , snValue :: Stmt               -- ^ the statement
    , snChildren :: [StmtNode]      -- ^ nodes defining a name this statement mentions
    , snRevChildren :: [StmtNode]   -- ^ nodes mentioning a name this statement defines
    }
-- | Group statements into the strongly connected components of their
-- dependency graph, using 'scc' (defined elsewhere).
--
-- Node @i@ is the @i@-th statement.  Its children are the nodes that define a
-- name the statement mentions (found with 'foldName').  Its reverse children
-- are the nodes that mention a name it defines.  A 'PrecDef' neither defines
-- nor mentions names here.
--
-- NOTE(review): 'mkLets_' nests the components in the returned order, which
-- assumes 'scc' yields definitions before their uses.  Confirm against its
-- implementation.
sortDefs :: [Stmt] -> [[Stmt]]
sortDefs xs = map snValue <$> scc snId snChildren snRevChildren nodes
  where
    -- The graph refers to itself lazily.  Children come from 'defMap' and
    -- reverse children from 'revMap', and both maps are computed from 'nodes'.
    nodes = zipWith mkNode [0..] xs
      where
        mkNode i s = StmtNode i s (nubBy ((==) `on` snId) $ catMaybes $ (`Map.lookup` defMap) <$> need)
                                  (fromMaybe [] $ IM.lookup i revMap)
          where
            -- Names mentioned by the statement, including its type annotation.
            need = Set.toList $ case s of
                PrecDef{} -> mempty
                StLet _ mt e -> foldMap names mt <> names e
                Data _ ps t cs -> foldMap (names . snd) ps <> names t <> foldMap (names . snd) cs

            names = foldName Set.singleton

    -- Reverse edges: child id -> the nodes that list it as a child.
    revMap = IM.unionsWith (++) [IM.singleton (snId c) [n] | n <- nodes, c <- snChildren n]

    -- Defining node of each name: let-bound names, data types and their
    -- constructors.  If a name is defined twice, Map.fromList keeps the later node.
    defMap = Map.fromList [(s, n) | n <- nodes, s <- def $ snValue n]
      where
        def = \case
            PrecDef{} -> mempty
            StLet n _ _ -> [n]
            Data n _ _ cs -> n: map fst cs