----------------------------------------------------------------------------- -- | -- Module : Disco.Compile -- Copyright : disco team and contributors -- Maintainer : byorgey@gmail.com -- -- SPDX-License-Identifier: BSD-3-Clause -- -- Compiling the typechecked, desugared AST to the untyped core -- language. ----------------------------------------------------------------------------- module Disco.Compile where import Control.Monad ((<=<)) import Data.Bool (bool) import Data.Coerce import qualified Data.Map as M import Data.Ratio import Data.Set (Set) import qualified Data.Set as S import Data.Set.Lens (setOf) import Disco.Effects.Fresh import Polysemy (Member, Sem, run) import Unbound.Generics.LocallyNameless (Name, bind, string2Name, unembed) import Disco.AST.Core import Disco.AST.Desugared import Disco.AST.Generic import Disco.AST.Typed import Disco.Context as Ctx import Disco.Desugar import Disco.Module import Disco.Names import Disco.Syntax.Operators import Disco.Syntax.Prims import qualified Disco.Typecheck.Graph as G import Disco.Types import Disco.Util ------------------------------------------------------------ -- Convenience operations ------------------------------------------------------------ -- | Utility function to desugar and compile a thing, given a -- desugaring function for it. compileThing :: (a -> Sem '[Fresh] DTerm) -> a -> Core compileThing desugarThing = run . runFresh . (compileDTerm <=< desugarThing) -- | Compile a typechecked term ('ATerm') directly to a 'Core' term, -- by desugaring and then compiling. compileTerm :: ATerm -> Core compileTerm = compileThing desugarTerm -- | Compile a typechecked property ('AProperty') directly to a 'Core' term, -- by desugaring and then compilling. compileProperty :: AProperty -> Core compileProperty = compileThing desugarProperty ------------------------------------------------------------ -- Compiling definitions ------------------------------------------------------------ -- | Compile a context of typechecked definitions ('Defn') to a -- sequence of compiled 'Core' bindings, such that the body of each -- binding depends only on previous ones in the list. First -- topologically sorts the definitions into mutually recursive -- groups, then compiles recursive definitions specially in terms of -- 'delay' and 'force'. compileDefns :: Ctx ATerm Defn -> [(QName Core, Core)] compileDefns defs = run . runFresh $ do let vars = Ctx.keysSet defs -- Get a list of pairs of the form (y,x) where x uses y in its -- definition. We want them in the order (y,x) since y needs to -- be evaluated before x. These will be the edges in our -- dependency graph. Note that some of these edges may refer to -- things that were imported, and hence not in the set of -- definitions; those edges will simply be dropped by G.mkGraph. deps :: Set (QName ATerm, QName ATerm) deps = S.unions . map (\(x, body) -> S.map (,x) (setOf (fvQ @Defn @ATerm) body)) . Ctx.assocs $ defs -- Do a topological sort of the condensation of the dependency -- graph. Each SCC corresponds to a group of mutually recursive -- definitions; each such group depends only on groups that come -- before it in the topsort. defnGroups :: [Set (QName ATerm)] defnGroups = G.topsort (G.condensation (G.mkGraph vars deps)) concat <$> mapM (compileDefnGroup . Ctx.assocs . Ctx.restrictKeys defs) defnGroups -- | Compile a group of mutually recursive definitions, using @delay@ -- to compile recursion via references to memory cells. compileDefnGroup :: Member Fresh r => [(QName ATerm, Defn)] -> Sem r [(QName Core, Core)] compileDefnGroup [(f, defn)] -- Informally, a recursive definition f = body compiles to -- -- f = force (delay f. [force f / f] body). -- -- However, we have to be careful: in the informal notation above, -- all the variables are named 'f', but in fully renamed syntax they -- are different. Writing fT for the top-level f bound in a -- specific module etc. and fL for a locally bound f, we really -- have -- -- fT = force (delay fL. [force fL / fT] body) | f `S.member` setOf fvQ defn = return . (:[]) $ (fT, CForce (CProj L (CDelay (bind [qname fL] [substQC fT (CForce (CVar fL)) cdefn])))) -- A non-recursive definition just compiles simply. | otherwise = return [(fT, cdefn)] where fT, fL :: QName Core fT = coerce f fL = localName (coerce (qname f)) cdefn = compileThing desugarDefn defn -- A group of mutually recursive definitions {f = fbody, g = gbody, ...} -- compiles to -- { _grp = delay fL gL ... . (forceVars fbody, forceVars gbody, ...) -- , fT = fst (force _grp) -- , gT = snd (force _grp) -- , ... -- } -- where forceVars is the substitution [force fL / fT, force gL / gT, ...] compileDefnGroup defs = do grp :: QName Core <- freshQ "__grp" let (vars, bodies) = unzip defs varsT, varsL :: [QName Core] varsT = coerce vars varsL = map (localName . qname) varsT forceVars :: [(QName Core, Core)] forceVars = zipWith (\t l -> (t, CForce (CVar l))) varsT varsL bodies' :: [Core] bodies' = map (substsQC forceVars . compileThing desugarDefn) bodies return $ (grp, CDelay (bind (map qname varsL) bodies')) : zip varsT (for [0 ..] $ CForce . flip proj (CVar grp)) where proj :: Int -> Core -> Core proj 0 = CProj L proj n = proj (n -1) . CProj R ------------------------------------------------------------ -- Compiling terms ------------------------------------------------------------ -- | Compile a typechecked, desugared 'DTerm' to an untyped 'Core' -- term. compileDTerm :: Member Fresh r => DTerm -> Sem r Core compileDTerm (DTVar _ x) = return $ CVar (coerce x) compileDTerm (DTPrim ty x) = compilePrim ty x compileDTerm DTUnit = return CUnit compileDTerm (DTBool _ b) = return $ CInj (bool L R b) CUnit compileDTerm (DTChar c) = return $ CNum Fraction (toInteger (fromEnum c) % 1) compileDTerm (DTNat _ n) = return $ CNum Fraction (n % 1) -- compileNat ty n compileDTerm (DTRat r) = return $ CNum Decimal r compileDTerm term@(DTAbs q _ _) = do (xs, tys, body) <- unbindDeep term cbody <- compileDTerm body case q of Lam -> return $ abstract xs cbody Ex -> return $ quantify (OExists tys) (abstract xs cbody) All -> return $ quantify (OForall tys) (abstract xs cbody) where -- Gather nested abstractions with the same quantifier. unbindDeep :: Member Fresh r => DTerm -> Sem r ([Name DTerm], [Type], DTerm) unbindDeep (DTAbs q' ty l) | q == q' = do (name, inner) <- unbind l (ns, tys, body) <- unbindDeep inner return (name : ns, ty : tys, body) unbindDeep t = return ([], [], t) abstract :: [Name DTerm] -> Core -> Core abstract xs body = CAbs (bind (map coerce xs) body) quantify :: Op -> Core -> Core quantify op = CApp (CConst op) -- Special case for Cons, which compiles to a constructor application -- rather than a function application. compileDTerm (DTApp _ (DTPrim _ (PrimBOp Cons)) (DTPair _ t1 t2)) = CInj R <$> (CPair <$> compileDTerm t1 <*> compileDTerm t2) -- Special cases for left and right, which also compile to constructor applications. compileDTerm (DTApp _ (DTPrim _ PrimLeft) t) = CInj L <$> compileDTerm t compileDTerm (DTApp _ (DTPrim _ PrimRight) t) = CInj R <$> compileDTerm t compileDTerm (DTApp _ t1 t2) = CApp <$> compileDTerm t1 <*> compileDTerm t2 compileDTerm (DTPair _ t1 t2) = CPair <$> compileDTerm t1 <*> compileDTerm t2 compileDTerm (DTCase _ bs) = CApp <$> compileCase bs <*> pure CUnit compileDTerm (DTTyOp _ op ty) = return $ CApp (CConst (tyOps ! op)) (CType ty) where tyOps = M.fromList [ Enumerate ==> OEnum, Count ==> OCount ] compileDTerm (DTNil _) = return $ CInj L CUnit compileDTerm (DTTest info t) = CTest (coerce info) <$> compileDTerm t ------------------------------------------------------------ -- | Compile a natural number. A separate function is needed in -- case the number is of a finite type, in which case we must -- mod it by its type. -- compileNat :: Member Fresh r => Type -> Integer -> Sem r Core -- compileNat (TyFin n) x = return $ CNum Fraction ((x `mod` n) % 1) -- compileNat _ x = return $ CNum Fraction (x % 1) ------------------------------------------------------------ -- | Compile a primitive. Typically primitives turn into a -- corresponding function constant in the core language, but -- sometimes the particular constant it turns into may depend on the -- type. compilePrim :: Member Fresh r => Type -> Prim -> Sem r Core compilePrim (argTy :->: _) (PrimUOp uop) = return $ compileUOp argTy uop compilePrim ty p@(PrimUOp _) = compilePrimErr p ty -- This special case for Cons only triggers if we didn't hit the case -- for fully saturated Cons; just fall back to generating a lambda. Have to -- do it here, not in compileBOp, since we need to generate fresh names. compilePrim _ (PrimBOp Cons) = do hd <- fresh (string2Name "hd") tl <- fresh (string2Name "tl") return $ CAbs $ bind [hd, tl] $ CInj R (CPair (CVar (localName hd)) (CVar (localName tl))) compilePrim _ PrimLeft = do a <- fresh (string2Name "a") return $ CAbs $ bind [a] $ CInj L (CVar (localName a)) compilePrim _ PrimRight = do a <- fresh (string2Name "a") return $ CAbs $ bind [a] $ CInj R (CVar (localName a)) compilePrim (ty1 :*: ty2 :->: resTy) (PrimBOp bop) = return $ compileBOp ty1 ty2 resTy bop compilePrim ty p@(PrimBOp _) = compilePrimErr p ty compilePrim _ PrimSqrt = return $ CConst OSqrt compilePrim _ PrimFloor = return $ CConst OFloor compilePrim _ PrimCeil = return $ CConst OCeil compilePrim (TySet _ :->: _) PrimAbs = return $ CVar (Named Stdlib "container" .- string2Name "setSize") compilePrim (TyBag _ :->: _) PrimAbs = return $ CVar (Named Stdlib "container" .- string2Name "bagSize") compilePrim (TyList _ :->: _) PrimAbs = return $ CVar (Named Stdlib "list" .- string2Name "length") compilePrim _ PrimAbs = return $ CConst OAbs compilePrim (TySet _ :->: _) PrimPower = return $ CConst OPower compilePrim (TyBag _ :->: _) PrimPower = return $ CConst OPower compilePrim ty PrimPower = compilePrimErr PrimPower ty compilePrim (TySet _ :->: _) PrimList = return $ CConst OSetToList compilePrim (TyBag _ :->: _) PrimSet = return $ CConst OBagToSet compilePrim (TyBag _ :->: _) PrimList = return $ CConst OBagToList compilePrim (TyList _ :->: _) PrimSet = return $ CConst OListToSet compilePrim (TyList _ :->: _) PrimBag = return $ CConst OListToBag compilePrim _ p | p `elem` [PrimList, PrimBag, PrimSet] = return $ CConst OId compilePrim ty PrimList = compilePrimErr PrimList ty compilePrim ty PrimBag = compilePrimErr PrimBag ty compilePrim ty PrimSet = compilePrimErr PrimSet ty compilePrim _ PrimB2C = return $ CConst OBagToCounts compilePrim (_ :->: TyBag _) PrimC2B = return $ CConst OCountsToBag compilePrim ty PrimC2B = compilePrimErr PrimC2B ty compilePrim (_ :->: TyBag _) PrimUC2B = return $ CConst OUnsafeCountsToBag compilePrim ty PrimUC2B = compilePrimErr PrimUC2B ty compilePrim (TyMap _ _ :->: _) PrimMapToSet = return $ CConst OMapToSet compilePrim (_ :->: TyMap _ _) PrimSetToMap = return $ CConst OSetToMap compilePrim ty PrimMapToSet = compilePrimErr PrimMapToSet ty compilePrim ty PrimSetToMap = compilePrimErr PrimSetToMap ty compilePrim _ PrimSummary = return $ CConst OSummary compilePrim (_ :->: TyGraph _) PrimVertex = return $ CConst OVertex compilePrim (TyGraph _) PrimEmptyGraph = return $ CConst OEmptyGraph compilePrim (_ :->: TyGraph _) PrimOverlay = return $ CConst OOverlay compilePrim (_ :->: TyGraph _) PrimConnect = return $ CConst OConnect compilePrim ty PrimVertex = compilePrimErr PrimVertex ty compilePrim ty PrimEmptyGraph = compilePrimErr PrimEmptyGraph ty compilePrim ty PrimOverlay = compilePrimErr PrimOverlay ty compilePrim ty PrimConnect = compilePrimErr PrimConnect ty compilePrim _ PrimInsert = return $ CConst OInsert compilePrim _ PrimLookup = return $ CConst OLookup compilePrim (_ :*: TyList _ :->: _) PrimEach = return $ CVar (Named Stdlib "list" .- string2Name "eachlist") compilePrim (_ :*: TyBag _ :->: TyBag _) PrimEach = return $ CConst OEachBag compilePrim (_ :*: TySet _ :->: TySet _) PrimEach = return $ CConst OEachSet compilePrim ty PrimEach = compilePrimErr PrimEach ty compilePrim (_ :*: _ :*: TyList _ :->: _) PrimReduce = return $ CVar (Named Stdlib "list" .- string2Name "foldr") compilePrim (_ :*: _ :*: TyBag _ :->: _) PrimReduce = return $ CVar (Named Stdlib "container" .- string2Name "reducebag") compilePrim (_ :*: _ :*: TySet _ :->: _) PrimReduce = return $ CVar (Named Stdlib "container" .- string2Name "reduceset") compilePrim ty PrimReduce = compilePrimErr PrimReduce ty compilePrim (_ :*: TyList _ :->: _) PrimFilter = return $ CVar (Named Stdlib "list" .- string2Name "filterlist") compilePrim (_ :*: TyBag _ :->: _) PrimFilter = return $ CConst OFilterBag compilePrim (_ :*: TySet _ :->: _) PrimFilter = return $ CConst OFilterBag compilePrim ty PrimFilter = compilePrimErr PrimFilter ty compilePrim (_ :->: TyList _) PrimJoin = return $ CVar (Named Stdlib "list" .- string2Name "concat") compilePrim (_ :->: TyBag _) PrimJoin = return $ CConst OBagUnions compilePrim (_ :->: TySet _) PrimJoin = return $ CVar (Named Stdlib "container" .- string2Name "unions") compilePrim ty PrimJoin = compilePrimErr PrimJoin ty compilePrim (_ :*: TyBag _ :*: _ :->: _) PrimMerge = return $ CConst OMerge compilePrim (_ :*: TySet _ :*: _ :->: _) PrimMerge = return $ CConst OMerge compilePrim ty PrimMerge = compilePrimErr PrimMerge ty compilePrim _ PrimIsPrime = return $ CConst OIsPrime compilePrim _ PrimFactor = return $ CConst OFactor compilePrim _ PrimFrac = return $ CConst OFrac compilePrim _ PrimCrash = return $ CConst OCrash compilePrim _ PrimUntil = return $ CConst OUntil compilePrim _ PrimHolds = return $ CConst OHolds compilePrim _ PrimLookupSeq = return $ CConst OLookupSeq compilePrim _ PrimExtendSeq = return $ CConst OExtendSeq compilePrimErr :: Prim -> Type -> a compilePrimErr p ty = error $ "Impossible! compilePrim " ++ show p ++ " on bad type " ++ show ty ------------------------------------------------------------ -- Case expressions ------------------------------------------------------------ -- | Compile a case expression of type τ to a core language expression -- of type (Unit → τ), in order to delay evaluation until explicitly -- applying it to the unit value. compileCase :: Member Fresh r => [DBranch] -> Sem r Core compileCase [] = return $ CAbs (bind [string2Name "_"] (CConst OMatchErr)) -- empty case ==> λ _ . error compileCase (b : bs) = do c1 <- compileBranch b c2 <- compileCase bs return $ CAbs (bind [string2Name "_"] (CApp c1 c2)) -- | Compile a branch of a case expression of type τ to a core -- language expression of type (Unit → τ) → τ. The idea is that it -- takes a failure continuation representing the subsequent branches -- in the case expression. If the branch succeeds, it just returns -- the associated expression of type τ; if it fails, it calls the -- continuation to proceed with the case analysis. compileBranch :: Member Fresh r => DBranch -> Sem r Core compileBranch b = do (gs, e) <- unbind b c <- compileDTerm e k <- fresh (string2Name "k") -- Fresh name for the failure continuation bc <- compileGuards (fromTelescope gs) k c return $ CAbs (bind [k] bc) -- | 'compileGuards' takes a list of guards, the name of the failure -- continuation of type (Unit → τ), and a Core term of type τ to -- return in the case of success, and compiles to an expression of -- type τ which evaluates the guards in sequence, ultimately -- returning the given expression if all guards succeed, or calling -- the failure continuation at any point if a guard fails. compileGuards :: Member Fresh r => [DGuard] -> Name Core -> Core -> Sem r Core compileGuards [] _ e = return e compileGuards (DGPat (unembed -> s) p : gs) k e = do e' <- compileGuards gs k e s' <- compileDTerm s compileMatch p s' k e' -- | 'compileMatch' takes a pattern, the compiled scrutinee, the name -- of the failure continuation, and a Core term representing the -- compilation of any guards which come after this one, and returns -- a Core expression of type τ that performs the match and either -- calls the failure continuation in the case of failure, or the -- rest of the guards in the case of success. compileMatch :: Member Fresh r => DPattern -> Core -> Name Core -> Core -> Sem r Core compileMatch (DPVar _ x) s _ e = return $ CApp (CAbs (bind [coerce x] e)) s -- Note in the below two cases that we can't just discard s since -- that would result in a lazy semantics. With an eager/strict -- semantics, we have to make sure s gets evaluated even if its -- value is then discarded. compileMatch (DPWild _) s _ e = return $ CApp (CAbs (bind [string2Name "_"] e)) s compileMatch DPUnit s _ e = return $ CApp (CAbs (bind [string2Name "_"] e)) s compileMatch (DPPair _ x1 x2) s _ e = do y <- fresh (string2Name "y") -- {? e when s is (x1,x2) ?} ==> (\y. (\x1.\x2. e) (fst y) (snd y)) s return $ CApp ( CAbs ( bind [y] ( CApp ( CApp (CAbs (bind [coerce x1, coerce x2] e)) (CProj L (CVar (localName y))) ) (CProj R (CVar (localName y))) ) ) ) s compileMatch (DPInj _ L x) s k e = -- {? e when s is left(x) ?} ==> case s of {left x -> e; right _ -> k unit} return $ CCase s (bind (coerce x) e) (bind (string2Name "_") (CApp (CVar (localName k)) CUnit)) compileMatch (DPInj _ R x) s k e = -- {? e when s is right(x) ?} ==> case s of {left _ -> k unit; right x -> e} return $ CCase s (bind (string2Name "_") (CApp (CVar (localName k)) CUnit)) (bind (coerce x) e) ------------------------------------------------------------ -- Unary and binary operators ------------------------------------------------------------ -- | Compile a unary operator. compileUOp :: -- | Type of the operator argument Type -> UOp -> Core compileUOp _ op = CConst (coreUOps ! op) where -- Just look up the corresponding core operator. coreUOps = M.fromList [ Neg ==> ONeg, Fact ==> OFact ] -- | Compile a binary operator. This function needs to know the types -- of the arguments and result since some operators are overloaded -- and compile to different code depending on their type. -- -- @arg1 ty -> arg2 ty -> result ty -> op -> result@ compileBOp :: Type -> Type -> Type -> BOp -> Core -- First, compile some operators specially for modular arithmetic. -- Most operators on TyFun (add, mul, sub, etc.) have already been -- desugared to an operation followed by a mod. The only operators -- here are the ones that have a special runtime behavior for Zn that -- can't be implemented in terms of other, existing operators: -- -- - Division on Zn needs to find modular inverses. -- - Divisibility testing on Zn similarly needs to find a gcd etc. -- - Exponentiation on Zn could in theory be implemented as a normal -- exponentiation on naturals followed by a mod, but that would be -- silly and inefficient. Instead we compile to a special modular -- exponentiation operator which takes mods along the way. Also, -- negative powers have similar requirements to division. -- -- We match on the type of arg1 because that is the only one which -- will consistently be TyFin in the case of Div, Exp, and Divides. -- compileBOp (TyFin n) _ _ op -- | op `elem` [Div, Exp, Divides] -- = CConst ((omOps ! op) n) -- where -- omOps = M.fromList -- [ Div ==> OMDiv -- , Exp ==> OMExp -- , Divides ==> OMDivides -- ] -- Graph operations are separate, but use the same syntax, as traditional -- addition and multiplication. compileBOp (TyGraph _) (TyGraph _) (TyGraph _) op | op `elem` [Add, Mul] = CConst (regularOps ! op) where regularOps = M.fromList [ Add ==> OOverlay, Mul ==> OConnect ] -- The Cartesian product operator just compiles to library function calls. compileBOp (TySet _) _ _ CartProd = CVar (Named Stdlib "container" .- string2Name "setCP") compileBOp (TyBag _) _ _ CartProd = CVar (Named Stdlib "container" .- string2Name "bagCP") compileBOp (TyList _) _ _ CartProd = CVar (Named Stdlib "list" .- string2Name "listCP") -- Some regular arithmetic operations that just translate straightforwardly. compileBOp _ _ _ op | op `M.member` regularOps = CConst (regularOps ! op) where regularOps = M.fromList [ Add ==> OAdd, Mul ==> OMul, Div ==> ODiv, Exp ==> OExp, Mod ==> OMod, Divides ==> ODivides, Choose ==> OMultinom, Eq ==> OEq, Lt ==> OLt ] -- ShouldEq needs to know the type at which the comparison is -- occurring, so values can be correctly pretty-printed if the test -- fails. compileBOp ty _ _ ShouldEq = CConst (OShouldEq ty) compileBOp _ty (TyList _) _ Elem = CConst OListElem compileBOp _ty _ _ Elem = CConst OBagElem compileBOp ty1 ty2 resTy op = error $ "Impossible! missing case in compileBOp: " ++ show (ty1, ty2, resTy, op)