{-|
  Copyright   :  (C) 2012-2016, University of Twente
  License     :  BSD2 (see the file LICENSE)
  Maintainer  :  Christiaan Baaij

  Smart constructor and destructor functions for CoreHW
-}

{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TemplateHaskell   #-}
{-# LANGUAGE ViewPatterns      #-}

module Clash.Core.Util where

import Control.Monad.Trans.Except              (Except, throwE)
import qualified Data.HashMap.Strict           as HMS
import qualified Data.HashMap.Lazy             as HashMap
import Data.HashMap.Lazy                       (HashMap)
import qualified Data.HashSet                  as HashSet
import Data.Maybe                              (fromJust, mapMaybe)
import Unbound.Generics.LocallyNameless
  (Fresh, bind, embed, rebind, unbind, unembed, unrebind, unrec)
import Unbound.Generics.LocallyNameless.Unsafe (unsafeUnbind)

import Clash.Core.DataCon                      (DataCon, dcType, dataConInstArgTys)
import Clash.Core.Literal                      (literalType)
import Clash.Core.Name
  (Name (..), name2String, string2SystemName)
import Clash.Core.Pretty                       (showDoc)
import Clash.Core.Term
  (LetBinding, Pat (..), Term (..), TmName, TmOccName)
import Clash.Core.Type
  (Kind, LitTy (..), TyName, TyOccName, Type (..), TypeView (..), applyTy,
   coreView, isFunTy, isPolyFunCoreTy, mkFunTy, splitFunTy, tyView)
import Clash.Core.TyCon
  (TyCon (..), TyConOccName, tyConDataCons)
import Clash.Core.TysPrim                      (typeNatKind)
import Clash.Core.Var                          (Id, TyVar, Var (..), varType)
import Clash.Util

-- | Type environment/context
type Gamma = HashMap TmOccName Type

-- | Kind environment/context
type Delta = HashMap TyOccName Kind

-- | Determine the type of a term
--
-- Variables, primitives, case-expressions and casts carry their (result) type
-- directly; for the other constructs the type is computed from the sub-terms.
termType
  :: Fresh m
  => HashMap TyConOccName TyCon
  -> Term
  -> m Type
termType m e = case e of
  Var t _      -> return t
  Data dc      -> return $ dcType dc
  Literal l    -> return $ literalType l
  Prim _ t     -> return t
  Lam b        -> do
    (v,e') <- unbind b
    mkFunTy (unembed $ varType v) <$> termType m e'
  TyLam b      -> do
    (tv,e') <- unbind b
    ForAllTy . bind tv <$> termType m e'
  -- Peel off all arguments at once, then apply the type of the function
  -- position to those arguments.
  App _ _      -> case collectArgs e of
    (fun, args) -> termType m fun >>= flip (applyTypeToArgs m) args
  TyApp e' ty  -> termType m e' >>= (\f -> applyTy m f ty)
  -- The type of a let-expression is the type of its body
  Letrec b     -> do
    (_,e') <- unbind b
    termType m e'
  Case _ ty _  -> return ty
  Cast _ _ ty2 -> return ty2

-- | Split a (Type)Application in the applied term and its arguments
--
-- The arguments are returned in application order: 'Left' for term
-- arguments, 'Right' for type arguments.
collectArgs
  :: Term
  -> (Term, [Either Term Type])
collectArgs = go []
  where
    go args (App e1 e2) = go (Left e2:args) e1
    go args (TyApp e t) = go (Right t:args) e
    go args e           = (e, args)

-- | Split a (Type)Abstraction in the bound variables and the abstracted term
--
-- The binders are returned outermost first: 'Left' for term binders, 'Right'
-- for type binders.
collectBndrs
  :: Fresh m
  => Term
  -> m ([Either Id TyVar], Term)
collectBndrs = go []
  where
    go bs (Lam b) = do
      (v,e') <- unbind b
      go (Left v:bs) e'
    go bs (TyLam b) = do
      (tv,e') <- unbind b
      go (Right tv:bs) e'
    -- Binders were accumulated innermost first, so restore the order
    go bs e' = return (reverse bs,e')

-- | Get the result type of a polymorphic function given a list of arguments
--
-- Calls 'error' when a term argument is supplied while the type at hand
-- cannot be split as a function type.
applyTypeToArgs
  :: Fresh m
  => HashMap TyConOccName TyCon
  -> Type
  -> [Either Term Type]
  -> m Type
applyTypeToArgs _ opTy []              = return opTy
applyTypeToArgs m opTy (Right ty:args) =
  applyTy m opTy ty >>= flip (applyTypeToArgs m) args
applyTypeToArgs m opTy (Left e:args)   = case splitFunTy m opTy of
  Just (_,resTy) -> applyTypeToArgs m resTy args
  Nothing        -> error $
    concat [ $(curLoc)
           , "applyTypeToArgs splitFunTy: not a funTy:\n"
           , "opTy: "
           , showDoc opTy
           , "\nTerm: "
           , showDoc e
           , "\nOtherArgs: "
           , unlines (map (either showDoc showDoc) args)
           ]

-- | Get the list of term-binders out of a DataType pattern
--
-- Every pattern other than a 'DataPat' yields the empty list.
patIds :: Pat -> [Id]
patIds (DataPat _ ids) = snd $ unrebind ids
patIds _               = []

-- | Make a type variable
mkTyVar
  :: Kind
  -> TyName
  -> TyVar
mkTyVar tyKind tyName = TyVar tyName (embed tyKind)

-- | Make a term variable
mkId
  :: Type
  -> TmName
  -> Id
mkId tmType tmName = Id tmName (embed tmType)

-- | Abstract a term over a list of term and type variables
--
-- The first variable in the list becomes the outermost binder.
mkAbstraction
  :: Term
  -> [Either Id TyVar]
  -> Term
mkAbstraction = foldr (either (Lam `dot` bind) (TyLam `dot` bind))

-- | Abstract a term over a list of type variables
mkTyLams
  :: Term
  -> [TyVar]
  -> Term
mkTyLams tm = mkAbstraction tm . map Right

-- | Abstract a term over a list of term variables
mkLams
  :: Term
  -> [Id]
  -> Term
mkLams tm = mkAbstraction tm . map Left

-- | Apply a list of types and terms to a term
mkApps
  :: Term
  -> [Either Term Type]
  -> Term
mkApps = foldl (\e a -> either (App e) (TyApp e) a)

-- | Apply a list of terms to a term
mkTmApps
  :: Term
  -> [Term]
  -> Term
mkTmApps = foldl App

-- | Apply a list of types to a term
mkTyApps
  :: Term
  -> [Type]
  -> Term
mkTyApps = foldl TyApp

-- | Does a term have a function type?
isFun
  :: Fresh m
  => HashMap TyConOccName TyCon
  -> Term
  -> m Bool
isFun m t = isFunTy m <$> termType m t

-- | Does a term have a function or polymorphic type?
isPolyFun
  :: Fresh m
  => HashMap TyConOccName TyCon
  -> Term
  -> m Bool
isPolyFun m t = isPolyFunCoreTy m <$> termType m t

-- | Is a term a term-abstraction?
isLam :: Term -> Bool
isLam (Lam _) = True
isLam _       = False

-- | Is a term a recursive let-binding?
isLet :: Term -> Bool
isLet (Letrec _) = True
isLet _          = False

-- | Is a term a variable reference?
isVar :: Term -> Bool
isVar (Var _ _) = True
isVar _         = False

-- | Is a term a datatype constructor?
isCon :: Term -> Bool
isCon (Data _) = True
isCon _        = False

-- | Is a term a primitive?
isPrim :: Term -> Bool
isPrim (Prim _ _) = True
isPrim _          = False

-- | Make variable reference out of term variable
--
-- Calls 'error' when given a type variable.
idToVar :: Id -> Term
idToVar (Id nm tyE) = Var (unembed tyE) nm
idToVar tv          = error $ $(curLoc) ++ "idToVar: tyVar: " ++ showDoc tv

-- | Make a term variable out of a variable reference
--
-- Calls 'error' when the term is not a 'Var'.
varToId :: Term -> Id
varToId (Var ty nm) = Id nm (embed ty)
varToId e           = error $ $(curLoc) ++ "varToId: not a var: " ++ showDoc e

-- | Compute a size measure for a term
--
-- Variables, constructors, literals and primitives count as 1, and every
-- term-abstraction adds 1 to the size of its body. Type-abstractions,
-- type-applications and casts add nothing to the size of the term they wrap.
-- Applications, let-bindings and case-expressions are the sum of the sizes
-- of their sub-terms.
termSize :: Term -> Word
termSize (Var _ _)     = 1
termSize (Data _)      = 1
termSize (Literal _)   = 1
termSize (Prim _ _)    = 1
termSize (Lam b)       = let (_,e) = unsafeUnbind b
                         in  termSize e + 1
termSize (TyLam b)     = let (_,e) = unsafeUnbind b
                         in  termSize e
termSize (App e1 e2)   = termSize e1 + termSize e2
termSize (TyApp e _)   = termSize e
termSize (Cast e _ _)  = termSize e
termSize (Letrec b)    = sum (bodySz:bndrSzs)
  where
    (bndrsR,body) = unsafeUnbind b
    bndrSzs       = map (termSize . unembed . snd) (unrec bndrsR)
    bodySz        = termSize body
termSize (Case subj _ alts) = sum (subjSz:altSzs)
  where
    subjSz = termSize subj
    altSzs = map (termSize . snd . unsafeUnbind) alts

-- | Create a vector of supplied elements
--
-- The length argument is only used for the type-level size indices of the
-- cons-cells; the recursion itself is driven by the list of elements.
mkVec
  :: DataCon
  -- ^ The Nil constructor
  -> DataCon
  -- ^ The Cons (:>) constructor
  -> Type
  -- ^ Element type
  -> Integer
  -- ^ Length of the vector
  -> [Term]
  -- ^ Elements to put in the vector
  -> Term
mkVec nilCon consCon resTy = go
  where
    go _ [] =
      mkApps (Data nilCon)
             [ Right (LitTy (NumTy 0))
             , Right resTy
             , Left  (Prim "_CO_" nilCoTy)
             ]

    go n (x:xs) =
      mkApps (Data consCon)
             [ Right (LitTy (NumTy n))
             , Right resTy
             , Right (LitTy (NumTy (n-1)))
             , Left  (Prim "_CO_" (consCoTy n))
             , Left  x
             , Left  (go (n-1) xs)
             ]

    -- The type of the "_CO_" primitive is the first instantiated argument
    -- type of the respective constructor.
    nilCoTy =
      head (fromJust $! dataConInstArgTys nilCon
                          [ LitTy (NumTy 0)
                          , resTy
                          ])

    consCoTy n =
      head (fromJust $! dataConInstArgTys consCon
                          [ LitTy (NumTy n)
                          , resTy
                          , LitTy (NumTy (n-1))
                          ])

-- | Append elements to the supplied vector
--
-- The elements are consed, in order, in front of the supplied vector: the
-- supplied vector forms the tail of the result.
--
-- NOTE(review): the length argument is used as the size index of the
-- outermost cons-cell, so it presumably has to be the length of the
-- resulting vector; confirm against the callers.
appendToVec
  :: DataCon
  -- ^ The Cons (:>) constructor
  -> Type
  -- ^ Element type
  -> Term
  -- ^ The vector to append the elements to
  -> Integer
  -- ^ Length of the vector
  -> [Term]
  -- ^ Elements to append
  -> Term
appendToVec consCon resTy vec = go
  where
    go _ []     = vec
    go n (x:xs) =
      mkApps (Data consCon)
             [ Right (LitTy (NumTy n))
             , Right resTy
             , Right (LitTy (NumTy (n-1)))
             , Left  (Prim "_CO_" (consCoTy n))
             , Left  x
             , Left  (go (n-1) xs)
             ]

    consCoTy n =
      head (fromJust $! dataConInstArgTys consCon
                          [ LitTy (NumTy n)
                          , resTy
                          , LitTy (NumTy (n-1))
                          ])

-- | Create let-bindings with case-statements that select elements out of a
-- vector. Returns both the variables to which element-selections are bound
-- and the let-bindings
--
-- Every entry of the result describes one position of the vector: the
-- variable bound to the element, and two let-bindings: one selecting the
-- element, one selecting the remainder of the vector. The remainder binder
-- of one position is the scrutinee for the next position.
extractElems
  :: DataCon
  -- ^ The Cons (:>) constructor
  -> Type
  -- ^ The element type
  -> Char
  -- ^ Char to append to the bound variable names
  -> Integer
  -- ^ Length of the vector
  -> Term
  -- ^ The vector
  -> [(Term,[LetBinding])]
extractElems consCon resTy s maxN = go maxN
  where
    go :: Integer -> Term -> [(Term,[LetBinding])]
    go 0 _ = []
    go n e =
      ( elVar
      , [ (Id elBNm   (embed resTy) ,embed lhs)
        , (Id restBNm (embed restTy),embed rhs)
        ]
      ) :
      go (n-1) (Var restTy restBNm)
      where
        -- Binder names are suffixed with the separator character and the
        -- zero-based position of the element
        elBNm     = string2SystemName ("el" ++ s:show (maxN-n))
        restBNm   = string2SystemName ("rest" ++ s:show (maxN-n))
        elVar     = Var resTy elBNm
        pat       = DataPat (embed consCon) (rebind [mTV] [co,el,rest])
        elPatNm   = string2SystemName "el"
        restPatNm = string2SystemName "rest"
        lhs       = Case e resTy  [bind pat (Var resTy  elPatNm)]
        rhs       = Case e restTy [bind pat (Var restTy restPatNm)]

        mName = string2SystemName "m"
        mTV   = TyVar mName (embed typeNatKind)
        tys   = [LitTy (NumTy n),resTy,LitTy (NumTy (n-1))]
        (Just idTys)  = dataConInstArgTys consCon tys
        [co,el,rest]  = zipWith Id
                                [string2SystemName "_co_",elPatNm, restPatNm]
                                (map embed idTys)
        -- The type of the remainder is the type of the last constructor
        -- field
        restTy = last idTys

-- | Create
-- let-bindings with case-statements that select elements out of a
-- tree. Returns both the variables to which element-selections are bound
-- and the let-bindings
--
-- Leaf elements are bound to variables named @el@, and the left and right
-- sub-trees of a branch to variables named @lt@ and @rt@; every name is
-- suffixed with the supplied character and an index.
extractTElems
  :: DataCon
  -- ^ The 'LR' constructor
  -> DataCon
  -- ^ The 'BR' constructor
  -> Type
  -- ^ The element type
  -> Char
  -- ^ Char to append to the bound variable names
  -> Integer
  -- ^ Depth of the tree
  -> Term
  -- ^ The tree
  -> ([Term],[LetBinding])
extractTElems lrCon brCon resTy s maxN =
  go maxN [0..(2^(maxN+1))-2] [0..(2^maxN - 1)]
  where
    -- The first list supplies the indices used to name the sub-tree binders,
    -- the second list the indices used to name the element binders.
    go :: Integer -> [Int] -> [Int] -> Term -> ([Term],[LetBinding])
    go 0 _ ks e = ([elVar],[(Id elBNm (embed resTy), embed lhs)])
      where
        elBNm   = string2SystemName ("el" ++ s:show (head ks))
        elVar   = Var resTy elBNm
        pat     = DataPat (embed lrCon) (rebind [] [co,el])
        elPatNm = string2SystemName "el"
        lhs     = Case e resTy [bind pat (Var resTy elPatNm)]

        tys          = [LitTy (NumTy 0),resTy]
        (Just idTys) = dataConInstArgTys lrCon tys
        [co,el]      = zipWith Id
                               [string2SystemName "_co_",elPatNm]
                               (map embed idTys)

    go n bs ks e =
      ( lVars ++ rVars
      , (Id ltBNm (embed brTy),embed ltLhs):
        (Id rtBNm (embed brTy),embed rtLhs):
        (lBinds ++ rBinds)
      )
      where
        ltBNm   = string2SystemName ("lt" ++ s:show b0)
        rtBNm   = string2SystemName ("rt" ++ s:show b1)
        ltVar   = Var brTy ltBNm
        rtVar   = Var brTy rtBNm
        pat     = DataPat (embed brCon) (rebind [mTV] [co,lt,rt])
        ltPatNm = string2SystemName "lt"
        rtPatNm = string2SystemName "rt"
        ltLhs   = Case e brTy [bind pat (Var brTy ltPatNm)]
        rtLhs   = Case e brTy [bind pat (Var brTy rtPatNm)]

        mName = string2SystemName "m"
        mTV   = TyVar mName (embed typeNatKind)
        tys   = [LitTy (NumTy n),resTy,LitTy (NumTy (n-1))]
        (Just idTys) = dataConInstArgTys brCon tys
        [co,lt,rt]   = zipWith Id
                               [string2SystemName "_co_",ltPatNm,rtPatNm]
                               (map embed idTys)
        -- The type of a sub-tree is the type of the last constructor field
        brTy    = last idTys
        -- Both index lists are split in half: one half for the left
        -- sub-tree, the other half for the right sub-tree. The head of each
        -- half of 'bs' names the binder of the respective sub-tree.
        (kL,kR) = splitAt (length ks `div` 2) ks
        (b0:bL,b1:bR) = splitAt (length bs `div` 2) bs

        (lVars,lBinds) = go (n-1) bL kL ltVar
        (rVars,rBinds) = go (n-1) bR kR rtVar

-- | Create a tree of supplied elements
--
-- A single element becomes a leaf; a longer list is split in half, and the
-- halves become the left and right sub-tree of a branch. Calls 'error' on an
-- empty list of elements.
mkRTree
  :: DataCon
  -- ^ The LR constructor
  -> DataCon
  -- ^ The BR constructor
  -> Type
  -- ^ Element type
  -> Integer
  -- ^ Depth of the tree
  -> [Term]
  -- ^ Elements to put in the tree
  -> Term
mkRTree lrCon brCon resTy = go
  where
    go _ [x] =
      mkApps (Data lrCon)
             [ Right (LitTy (NumTy 0))
             , Right resTy
             , Left  (Prim "_CO_" lrCoTy)
             , Left  x
             ]

    -- Splitting an empty list yields two empty lists, so without this
    -- equation the recursion below would never terminate.
    go _ [] = error $ $(curLoc) ++ "mkRTree: empty list of elements"

    go n xs =
      let (xsL,xsR) = splitAt (length xs `div` 2) xs
      in  mkApps (Data brCon)
                 [ Right (LitTy (NumTy n))
                 , Right resTy
                 , Right (LitTy (NumTy (n-1)))
                 , Left  (Prim "_CO_" (brCoTy n))
                 , Left  (go (n-1) xsL)
                 , Left  (go (n-1) xsR)
                 ]

    -- The type of the "_CO_" primitive is the first instantiated argument
    -- type of the respective constructor.
    lrCoTy =
      head (fromJust $! dataConInstArgTys lrCon
                          [ LitTy (NumTy 0)
                          , resTy
                          ])

    brCoTy n =
      head (fromJust $! dataConInstArgTys brCon
                          [ LitTy (NumTy n)
                          , resTy
                          , LitTy (NumTy (n-1))
                          ])

-- | Determine whether a type is isomorphic to "Clash.Signal.Internal.Signal'"
--
-- That is, whether the type is an application of the type constructor named
-- @Clash.Signal.Internal.Signal@, or an application of a type constructor
-- for which one of the (instantiated) data constructor argument types is,
-- recursively, such a type. For example:
--
--   * Signal' clk a
--   * (Signal' clk a, Signal' clk b)
--   * Vec n (Signal' clk a)
--   * data Wrap = W (Signal clk' Int)
--   * etc.
--
-- Type constructors that are not in the supplied map are reported through
-- 'traceIf' and are treated as not being a signal type.
isSignalType
  :: HashMap TyConOccName TyCon
  -> Type
  -> Bool
isSignalType tcm ty = go HashSet.empty ty
  where
    go tcSeen (tyView -> TyConApp tcNm args) = case name2String tcNm of
      "Clash.Signal.Internal.Signal"  -> True
      _ | tcNm `HashSet.member` tcSeen -> False -- Do not follow rec types
        | otherwise -> case HashMap.lookup (nameOcc tcNm) tcm of
            Just tc ->
              let dcs         = tyConDataCons tc
                  dcInsArgTys = concat
                              $ mapMaybe (`dataConInstArgTys` args) dcs
                  tcSeen'     = HashSet.insert tcNm tcSeen
              in  any (go tcSeen') dcInsArgTys
            Nothing ->
              traceIf True
                      ($(curLoc) ++ "isSignalType: " ++ show tcNm ++ " not found.")
                      False

    go _ _ = False

-- | Determine whether a type is, after looking through 'coreView', an
-- application of the type constructor named @Clash.Signal.Internal.Clock@ or
-- @Clash.Signal.Internal.Reset@
isClockOrReset
  :: HashMap TyConOccName TyCon
  -> Type
  -> Bool
isClockOrReset m (coreView m -> Just ty)    = isClockOrReset m ty
isClockOrReset _ (tyView -> TyConApp tcNm _) = case name2String tcNm of
  "Clash.Signal.Internal.Clock" -> True
  "Clash.Signal.Internal.Reset" -> True
  _ -> False
isClockOrReset _ _ = False

-- | Determine the integer value of a type-level natural number
--
-- Looks through 'coreView' until a numeric type literal is reached; throws
-- an error message when the type does not reduce to such a literal.
tyNatSize
  :: HMS.HashMap TyConOccName TyCon
  -> Type
  -> Except String Integer
tyNatSize m (coreView m -> Just ty) = tyNatSize m ty
tyNatSize _ (LitTy (NumTy i))       = return i
tyNatSize _ ty =
  throwE $ $(curLoc) ++ "Cannot reduce to an integer: " ++ show ty