module CLaSH.Core.Util where
import Control.Monad.Trans.Except (Except, throwE)
import qualified Data.HashMap.Strict as HMS
import qualified Data.HashMap.Lazy as HashMap
import Data.HashMap.Lazy (HashMap)
import qualified Data.HashSet as HashSet
import Data.Maybe (fromJust, mapMaybe)
import Unbound.Generics.LocallyNameless (Fresh, bind, embed, rebind,
string2Name, unbind, unembed,
unrebind, unrec)
import Unbound.Generics.LocallyNameless.Name (name2String)
import Unbound.Generics.LocallyNameless.Unsafe (unsafeUnbind)
import CLaSH.Core.DataCon (DataCon, dcType, dataConInstArgTys)
import CLaSH.Core.Literal (literalType)
import CLaSH.Core.Pretty (showDoc)
import CLaSH.Core.Term (LetBinding, Pat (..), Term (..),
TmName)
import CLaSH.Core.Type (Kind, LitTy (..), TyName,
Type (..), TypeView (..), applyTy,
findFunSubst, isFunTy,
isPolyFunCoreTy, mkFunTy,
splitFunTy, tyView)
import CLaSH.Core.TyCon (TyCon (..), TyConName,
tyConDataCons)
import CLaSH.Core.TysPrim (typeNatKind)
import CLaSH.Core.Var (Id, TyVar, Var (..), varType)
import CLaSH.Util
-- | Mapping from term-variable names to their types
-- (presumably the term typing environment — confirm against users).
type Gamma = HashMap TmName Type
-- | Mapping from type-variable names to their kinds
-- (presumably the type/kind environment — confirm against users).
type Delta = HashMap TyName Kind
-- | Determine the type of a term.
--
-- Variables, data constructors, literals, primitives and case expressions
-- carry (or directly determine) their type; for abstractions, applications
-- and lets the type is reconstructed from the types of the sub-terms.
termType :: (Functor m, Fresh m)
         => HashMap TyConName TyCon -- ^ Type-constructor environment
         -> Term                    -- ^ Term to type
         -> m Type
termType m e = case e of
  Var t _     -> return t
  Data dc     -> return (dcType dc)
  Literal l   -> return (literalType l)
  Prim _ t    -> return t
  Lam b       -> do
    (v, body) <- unbind b
    bodyTy    <- termType m body
    return (mkFunTy (unembed (varType v)) bodyTy)
  TyLam b     -> do
    (tv, body) <- unbind b
    bodyTy     <- termType m body
    return (ForAllTy (bind tv bodyTy))
  App _ _     -> do
    -- type the head of the application spine, then feed it all arguments
    let (fun, args) = collectArgs e
    funTy <- termType m fun
    applyTypeToArgs m funTy args
  TyApp e' ty -> do
    funTy <- termType m e'
    applyTy m funTy ty
  Letrec b    -> do
    -- the type of a let is the type of its body
    (_, body) <- unbind b
    termType m body
  Case _ ty _ -> return ty
-- | Split a term into the head of its application spine and the term/type
-- arguments it is applied to, in application order (first argument first).
collectArgs :: Term
            -> (Term, [Either Term Type])
collectArgs = walk []
  where
    walk acc term = case term of
      App f x   -> walk (Left x  : acc) f
      TyApp f t -> walk (Right t : acc) f
      _         -> (term, acc)
-- | Strip the outermost term and type lambdas from a term.
--
-- Returns the binders, outermost first, together with the remaining body.
collectBndrs :: Fresh m
             => Term
             -> m ([Either Id TyVar], Term)
collectBndrs = strip []
  where
    strip acc term = case term of
      Lam b   -> unbind b >>= \(v, body)  -> strip (Left v   : acc) body
      TyLam b -> unbind b >>= \(tv, body) -> strip (Right tv : acc) body
      _       -> return (reverse acc, term)
-- | Compute the result type of applying a value of type @opTy@ to the given
-- term and type arguments, in order.
--
-- Type arguments are instantiated with 'applyTy'; each term argument peels
-- off one function arrow with 'splitFunTy'. Calls 'error' when a term
-- argument is applied to something that is not a function type.
applyTypeToArgs :: Fresh m
                => HashMap TyConName TyCon
                -> Type
                -> [Either Term Type]
                -> m Type
applyTypeToArgs m = step
  where
    step opTy []           = return opTy
    step opTy (arg : args) = case arg of
      Right ty -> applyTy m opTy ty >>= \opTy' -> step opTy' args
      Left e   -> maybe (notAFunTy opTy e args)
                        (\(_, resTy) -> step resTy args)
                        (splitFunTy m opTy)

    notAFunTy opTy e args = error $
      concat [ $(curLoc)
             , "applyTypeToArgs splitFunTy: not a funTy:\n"
             , "opTy: "
             , showDoc opTy
             , "\nTerm: "
             , showDoc e
             , "\nOtherArgs: "
             , unlines (map (either showDoc showDoc) args)
             ]
-- | The term variables bound by a pattern; only data-constructor patterns
-- bind any.
patIds :: Pat -> [Id]
patIds pat = case pat of
  DataPat _ ids -> let (_, bound) = unrebind ids in bound
  _             -> []
-- | Build a type variable from its kind and its name.
mkTyVar :: Kind
        -> TyName
        -> TyVar
mkTyVar ki nm = TyVar nm $ embed ki
-- | Build a term identifier from its type and its name.
mkId :: Type
     -> TmName
     -> Id
mkId ty nm = Id nm $ embed ty
-- | Wrap a term in term lambdas ('Left' binders) and type lambdas ('Right'
-- binders); the first binder in the list becomes the outermost lambda.
mkAbstraction :: Term
              -> [Either Id TyVar]
              -> Term
mkAbstraction tm bndrs = foldr wrap tm bndrs
  where
    wrap (Left  v)  body = Lam   (bind v  body)
    wrap (Right tv) body = TyLam (bind tv body)
-- | Wrap a term in type lambdas, first type variable outermost.
mkTyLams :: Term
         -> [TyVar]
         -> Term
mkTyLams tm tvs = mkAbstraction tm (map Right tvs)
-- | Wrap a term in term lambdas, first identifier outermost.
mkLams :: Term
       -> [Id]
       -> Term
mkLams tm ids = mkAbstraction tm (map Left ids)
-- | Apply a term to a mix of term ('Left') and type ('Right') arguments,
-- left to right.
mkApps :: Term
       -> [Either Term Type]
       -> Term
mkApps fun []               = fun
mkApps fun (Left  x : rest) = mkApps (App fun x) rest
mkApps fun (Right t : rest) = mkApps (TyApp fun t) rest
-- | Apply a term to term arguments, left to right.
mkTmApps :: Term
         -> [Term]
         -> Term
mkTmApps fun []         = fun
mkTmApps fun (x : rest) = mkTmApps (App fun x) rest
-- | Apply a term to type arguments, left to right.
mkTyApps :: Term
         -> [Type]
         -> Term
mkTyApps fun []         = fun
mkTyApps fun (t : rest) = mkTyApps (TyApp fun t) rest
-- | Does the term have a function type?
isFun :: (Functor m, Fresh m)
      => HashMap TyConName TyCon
      -> Term
      -> m Bool
isFun m t = isFunTy m <$> termType m t
-- | Does the term have a type accepted by 'isPolyFunCoreTy'?
isPolyFun :: (Functor m, Fresh m)
          => HashMap TyConName TyCon
          -> Term
          -> m Bool
isPolyFun m = fmap (isPolyFunCoreTy m) . termType m
-- | Is the term a term-level lambda?
isLam :: Term
      -> Bool
isLam t = case t of
  Lam _ -> True
  _     -> False
-- | Is the term a (recursive) let-expression?
isLet :: Term
      -> Bool
isLet t = case t of
  Letrec _ -> True
  _        -> False
-- | Is the term a variable reference?
isVar :: Term
      -> Bool
isVar t = case t of
  Var _ _ -> True
  _       -> False
-- | Is the term a data constructor?
isCon :: Term
      -> Bool
isCon t = case t of
  Data _ -> True
  _      -> False
-- | Is the term a primitive?
isPrim :: Term
       -> Bool
isPrim t = case t of
  Prim _ _ -> True
  _        -> False
-- | Turn a term identifier into a variable-reference term.
-- Calls 'error' when given a type variable.
idToVar :: Id
        -> Term
idToVar v = case v of
  Id nm tyE -> Var (unembed tyE) nm
  _         -> error $ $(curLoc) ++ "idToVar: tyVar: " ++ showDoc v
-- | Turn a variable-reference term back into a term identifier.
-- Calls 'error' when the term is not a 'Var'.
varToId :: Term
        -> Id
varToId t = case t of
  Var ty nm -> Id nm (embed ty)
  _         -> error $ $(curLoc) ++ "varToId: not a var: " ++ showDoc t
-- | A size measure for terms.
--
-- Variables, constructors, literals and primitives count 1, each term
-- lambda adds 1, type lambdas and type applications add nothing, and
-- applications, lets and case expressions are the sum of their parts.
termSize :: Term
         -> Int
termSize term = case term of
  Var _ _          -> 1
  Data _           -> 1
  Literal _        -> 1
  Prim _ _         -> 1
  Lam b            -> 1 + termSize (snd (unsafeUnbind b))
  TyLam b          -> termSize (snd (unsafeUnbind b))
  App e1 e2        -> termSize e1 + termSize e2
  TyApp e _        -> termSize e
  Letrec b         ->
    let (bndrs, body) = unsafeUnbind b
    in  termSize body + sum [ termSize (unembed rhs) | (_, rhs) <- unrec bndrs ]
  Case subj _ alts ->
    termSize subj + sum [ termSize alt | (_, alt) <- map unsafeUnbind alts ]
-- | Build a vector term from a list of element terms, using the @Nil@ and
-- @Cons@ data constructors.
--
-- The outermost @Cons@ is given length @n@ and its tail length @n-1@; the
-- final @Nil@ has length 0. Each constructor is also passed a @_CO_@
-- primitive whose type is the constructor's first field instantiated at
-- those lengths.
mkVec :: DataCon -- ^ The Nil constructor
      -> DataCon -- ^ The Cons (:>) constructor
      -> Type    -- ^ Element type
      -> Integer -- ^ Length of the vector
      -> [Term]  -- ^ Elements to put in the vector
      -> Term
mkVec nilCon consCon resTy = go
  where
    go _ [] = mkApps (Data nilCon) [ Right (LitTy (NumTy 0))
                                   , Right resTy
                                   , Left  (Prim "_CO_" nilCoTy)
                                   ]
    go n (x:xs) = mkApps (Data consCon) [ Right (LitTy (NumTy n))
                                        , Right resTy
                                        , Right (LitTy (NumTy (n-1)))
                                        , Left  (Prim "_CO_" (consCoTy n))
                                        , Left  x
                                        , Left  (go (n-1) xs)
                                        ]

    nilCoTy    = firstFieldTy nilCon [LitTy (NumTy 0), resTy]
    consCoTy n = firstFieldTy consCon [LitTy (NumTy n), resTy, LitTy (NumTy (n-1))]

    -- Type of the first field of @dc@ instantiated at @tys@ (the field filled
    -- by the @_CO_@ primitive). Fails with a located message instead of the
    -- uninformative 'head'/'fromJust' errors.
    firstFieldTy dc tys = case dataConInstArgTys dc tys of
      Just (fieldTy:_) -> fieldTy
      _                -> error $ $(curLoc)
                            ++ "mkVec: cannot instantiate the first field of a Nil/Cons constructor"
-- | Put the supplied elements in front of an existing vector term using the
-- @Cons@ data constructor: the result is @x1 :> x2 :> ... :> vec@.
--
-- The outermost @Cons@ is given length @n@ and its tail length @n-1@. Each
-- @Cons@ is also passed a @_CO_@ primitive whose type is the constructor's
-- first field instantiated at those lengths.
appendToVec :: DataCon -- ^ The Cons (:>) constructor
            -> Type    -- ^ Element type
            -> Term    -- ^ The vector the elements are put in front of
            -> Integer -- ^ Length of the resulting vector
            -> [Term]  -- ^ Elements to add
            -> Term
appendToVec consCon resTy vec = go
  where
    go _ [] = vec
    go n (x:xs) = mkApps (Data consCon) [ Right (LitTy (NumTy n))
                                        , Right resTy
                                        , Right (LitTy (NumTy (n-1)))
                                        , Left  (Prim "_CO_" (consCoTy n))
                                        , Left  x
                                        , Left  (go (n-1) xs)
                                        ]

    -- Type of Cons's first field (filled by @_CO_@) at lengths @n@ / @n-1@.
    -- Fails with a located message instead of 'head'/'fromJust' errors.
    consCoTy n =
      case dataConInstArgTys consCon [LitTy (NumTy n), resTy, LitTy (NumTy (n-1))] of
        Just (fieldTy:_) -> fieldTy
        _                -> error $ $(curLoc)
                              ++ "appendToVec: cannot instantiate the first field of the Cons constructor"
-- | Create let-bindings that project each element out of a vector term.
--
-- For every position @i@ (0-based from the head) this yields the variable
-- @el\<s\>\<i\>@ holding the element, plus two bindings: the element itself
-- and the remaining vector @rest\<s\>\<i\>@. Both are obtained by a @case@ on
-- the Cons constructor. Each following position scrutinises the previous
-- @rest@ variable.
extractElems :: DataCon -- ^ The Cons (:>) constructor
             -> Type    -- ^ The element type
             -> Char    -- ^ Char to append to the bound variable names
             -> Integer -- ^ Length of the vector
             -> Term    -- ^ The vector
             -> [(Term,[LetBinding])]
extractElems consCon resTy s maxN = go maxN
  where
    go :: Integer -> Term -> [(Term,[LetBinding])]
    -- '<= 0' rather than '== 0': a negative length must not loop forever
    go n _ | n <= 0 = []
    go n e = ( elVar
             , [ (Id elBNm   (embed resTy) , embed lhs)
               , (Id restBNm (embed restTy), embed rhs)
               ]
             ) :
             go (n-1) (Var restTy restBNm)
      where
        -- position of the current element, counting from the head
        ix        = maxN - n
        elBNm     = string2Name ("el" ++ s:show ix)
        restBNm   = string2Name ("rest" ++ s:show ix)
        elVar     = Var resTy elBNm
        pat       = DataPat (embed consCon) (rebind [mTV] [co,el,rest])
        elPatNm   = string2Name "el"
        restPatNm = string2Name "rest"
        lhs       = Case e resTy  [bind pat (Var resTy  elPatNm)]
        rhs       = Case e restTy [bind pat (Var restTy restPatNm)]
        mName     = string2Name "m"
        mTV       = TyVar mName (embed typeNatKind)
        tys       = [LitTy (NumTy n), resTy, LitTy (NumTy (n-1))]
        -- field types of Cons at this length, computed once and shared
        idTys     = case dataConInstArgTys consCon tys of
          Just argTys -> argTys
          Nothing     -> error $ $(curLoc)
                           ++ "extractElems: cannot instantiate the fields of the Cons constructor"
        [co,el,rest] = zipWith Id [string2Name "_co_",elPatNm,restPatNm]
                                  (map embed idTys)
        restTy    = last idTys
-- | Is the type @CLaSH.Signal.Internal.Signal'@, or does it contain one
-- transitively through the instantiated fields of its data constructors?
--
-- Type constructors already visited are tracked so recursive types
-- terminate. A type constructor missing from the environment is reported
-- through 'traceIf' and treated as not containing a signal.
isSignalType :: HashMap TyConName TyCon -> Type -> Bool
isSignalType tcm = go HashSet.empty
  where
    go tcSeen (tyView -> TyConApp tcNm args)
      | name2String tcNm == "CLaSH.Signal.Internal.Signal'" = True
      | tcNm `HashSet.member` tcSeen                       = False
      | otherwise = case HashMap.lookup tcNm tcm of
          Nothing -> traceIf True ($(curLoc) ++ "isSignalType: " ++ show tcNm
                                     ++ " not found.") False
          Just tc ->
            let fieldTys = concat (mapMaybe (`dataConInstArgTys` args)
                                            (tyConDataCons tc))
            in  any (go (HashSet.insert tcNm tcSeen)) fieldTys
    go _ _ = False
-- | Reduce a type-level natural-number expression to an 'Integer'.
--
-- Understands literals, the binary operators @+ * ^ -@ and @<=?@ from
-- "GHC.TypeLits", @Max@/@Min@ from "CLaSH.Promoted.Ord", @CLog@/@GCD@ from
-- "GHC.TypeLits.Extra", @If@ at kind @Nat@, and any type family whose
-- equations ('findFunSubst') rewrite the application. Throws an error
-- message when the type does not reduce to an integer.
tyNatSize :: HMS.HashMap TyConName TyCon
          -> Type
          -> Except String Integer
tyNatSize tcm ty = case go ty of
    Right (Left i) -> return i
    Right _        -> throwE $ $(curLoc) ++ "Cannot reduce an integer: " ++ show ty
    Left msg       -> throwE msg
  where
    -- Evaluate to either a natural ('Left') or a boolean ('Right').
    go :: Type -> Either String (Either Integer Bool)
    go (LitTy (NumTy i)) = return (Left i)
    go (tyView -> TyConApp tc tys)
      -- Binary operator whose two operands both reduce to naturals.
      | Just op         <- lookup (name2String tc) binOps
      , [ty1, ty2]      <- tys
      , Right (Left i1) <- go ty1
      , Right (Left i2) <- go ty2
      , Just r          <- op i1 i2
      = return r
      -- If at kind Nat: arguments are the kind, condition, then, else.
      | name2String tc == "Data.Type.Bool.If"
      , [kTy, condTy, thenTy, elseTy] <- tys
      , TyConApp tcNat _ <- tyView kTy
      , name2String tcNat == "GHC.TypeLits.Nat"
      , Right (Right b) <- go condTy
      , Right (Left i1) <- go thenTy
      , Right (Left i2) <- go elseTy
      = return (Left (if b then i1 else i2))
      -- Any other type family: rewrite with its equations and retry.
      -- 'HMS.lookup' (not 'HMS.!') so an unknown constructor falls through
      -- to the error below instead of crashing.
      | Just (FunTyCon {tyConSubst = tcSubst}) <- HMS.lookup tc tcm
      , Just ty' <- findFunSubst tcSubst tys
      = go ty'
    go t = Left ($(curLoc) ++ "Can't convert tyNat: " ++ show t)

    -- Binary type-level operators; 'Nothing' means "does not reduce".
    binOps :: [(String, Integer -> Integer -> Maybe (Either Integer Bool))]
    binOps =
      [ ("GHC.TypeLits.+"         , \a b -> nat (a + b))
      , ("GHC.TypeLits.*"         , \a b -> nat (a * b))
      , ("GHC.TypeLits.^"         , \a b -> nat (a ^ b))
        -- Nat subtraction is stuck when the result would be negative
      , ("GHC.TypeLits.-"         , \a b -> if b <= a then nat (a - b) else Nothing)
      , ("CLaSH.Promoted.Ord.Max" , \a b -> nat (max a b))
      , ("CLaSH.Promoted.Ord.Min" , \a b -> nat (min a b))
        -- CLog base x is defined for base > 1 and x > 0
      , ("GHC.TypeLits.Extra.CLog", \b x -> if b > 1 && x > 0 then nat (clog b x) else Nothing)
      , ("GHC.TypeLits.Extra.GCD" , \a b -> nat (gcd a b))
      , ("GHC.TypeLits.<=?"       , \a b -> Just (Right (a <= b)))
      ]

    nat :: Integer -> Maybe (Either Integer Bool)
    nat = Just . Left

    -- Exact ceiling logarithm: the least k with base^k >= x (base > 1, x > 0).
    -- Integer arithmetic avoids the rounding of @ceiling (logBase b x)@, which
    -- overshoots when e.g. @logBase 5 125@ evaluates to 3.0000000000000004.
    clog :: Integer -> Integer -> Integer
    clog base x = climb 0 1
      where
        climb k pow
          | pow >= x  = k
          | otherwise = climb (k + 1) (pow * base)