{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE ViewPatterns #-}
module Clash.Core.Evaluator where
import Control.Arrow (second)
import Control.Concurrent.Supply (Supply, freshId)
import Control.Lens (view, _4)
import Data.Bits (shiftL)
import Data.Either (lefts,rights)
import Data.List
(foldl',mapAccumL,uncons)
import Data.IntMap (IntMap)
import qualified Data.Primitive.ByteArray as BA
import qualified Data.Vector.Primitive as PV
import Data.Text (Text)
import Data.Text.Prettyprint.Doc
import Debug.Trace (trace)
import GHC.Integer.GMP.Internals
(Integer (..), BigNat (..))
import Clash.Core.DataCon
import Clash.Core.FreeVars
import Clash.Core.Literal
import Clash.Core.Name
import Clash.Core.Pretty
import Clash.Core.Subst
import Clash.Core.Term
import Clash.Core.TyCon
import Clash.Core.Type
import Clash.Core.Util
import Clash.Core.Var
import Clash.Core.VarEnv
import Clash.Driver.Types (BindingMap)
import Prelude hiding (lookup)
import Clash.Unique
import Clash.Util (curLoc)
import Clash.Pretty
data Heap = Heap GlobalHeap GPureHeap PureHeap Supply InScopeSet
type PureHeap = VarEnv Term
newtype GPureHeap = GPureHeap { unGPureHeap :: PureHeap }
type GlobalHeap = (IntMap Term, Int)
type Stack = [StackFrame]
data StackFrame
= Update Id
| GUpdate Id
| Apply Id
| Instantiate Type
| PrimApply Text PrimInfo [Type] [Value] [Term]
| Scrutinise Type [Alt]
| Tickish TickInfo
deriving Show
instance ClashPretty StackFrame where
clashPretty (Update i) = hsep ["Update", fromPpr i]
clashPretty (GUpdate i) = hsep ["GUpdate", fromPpr i]
clashPretty (Apply i) = hsep ["Apply", fromPpr i]
clashPretty (Instantiate t) = hsep ["Instantiate", fromPpr t]
clashPretty (PrimApply a b c d e) = do
hsep ["PrimApply", fromPretty a, "::", fromPpr (primType b),
"; type args=", fromPpr c,
"; val args=", fromPpr (map valToTerm d),
"term args=", fromPpr e]
clashPretty (Scrutinise a b) =
hsep ["Scrutinise ", fromPpr a,
fromPpr (Case (Literal (CharLiteral '_')) a b)]
clashPretty (Tickish sp) =
hsep ["Tick", fromPpr sp]
mkTickish
:: Stack
-> [TickInfo]
-> Stack
mkTickish s sps = map Tickish sps ++ s
data Value
= Lambda Id Term
| TyLambda TyVar Term
| DC DataCon [Either Term Type]
| Lit Literal
| PrimVal Text PrimInfo [Type] [Value]
| Suspend Term
deriving Show
type State = (Heap, Stack, Term)
type PrimEvaluator =
Bool ->
TyConMap ->
Heap ->
Stack ->
Text ->
PrimInfo ->
[Type] ->
[Value] ->
Maybe State
whnf'
:: PrimEvaluator
-> BindingMap
-> TyConMap
-> GlobalHeap
-> Supply
-> InScopeSet
-> Bool
-> Term
-> (GlobalHeap, PureHeap, Term)
whnf' eval gbl0 tcm gh ids is isSubj e
= case whnf eval tcm isSubj (Heap gh gbl1 emptyVarEnv ids is,[],e) of
(Heap gh' _ ph' _ _,_,e') -> (gh',ph',e')
where
gbl1 = GPureHeap (mapVarEnv (view _4) gbl0)
whnf
:: PrimEvaluator
-> TyConMap
-> Bool
-> State
-> State
whnf eval tcm isSubj (h,k,e) =
if isSubj
then go (h,Scrutinise ty []:k,e)
else go (h,k,e)
where
ty = termType tcm e
go s = case step eval tcm s of
Just s' -> go s'
Nothing
| Just e' <- unwindStack s
-> e'
| otherwise
-> error $ showDoc $ ppr e
isScrut :: Stack -> Bool
isScrut (Scrutinise {}:_) = True
isScrut (PrimApply {} :_) = True
isScrut (Tickish {}:k) = isScrut k
isScrut _ = False
unwindStack :: State -> Maybe State
unwindStack s@(_,[],_) = Just s
unwindStack (h@(Heap gh gbl h' ids is),(kf:k'),e) = case kf of
PrimApply nm ty tys vs tms ->
unwindStack
(h,k'
,foldl' App
(foldl' App (foldl' TyApp (Prim nm ty) tys) (map valToTerm vs))
(e:tms))
Instantiate ty ->
unwindStack (h,k',TyApp e ty)
Apply id_ -> do
case lookupVarEnv id_ h' of
Just e' -> unwindStack (h,k',App e e')
Nothing -> error $ unlines
$ [ "Clash.Core.Evaluator.unwindStack:"
, "Stack:"
] ++
[ " "++ showDoc (clashPretty frame) | frame <- kf:k'] ++
[ ""
, "Expression:"
, showPpr e
, ""
, "Heap:"
, showDoc (clashPretty h')
]
Scrutinise _ [] ->
unwindStack (h,k',e)
Scrutinise ty alts ->
unwindStack (h,k',Case e ty alts)
Update x ->
unwindStack (Heap gh gbl (extendVarEnv x e h') ids is,k',e)
GUpdate _ ->
unwindStack (h,k',e)
Tickish sp ->
unwindStack (h,k',Tick sp e)
step
:: PrimEvaluator
-> TyConMap
-> State
-> Maybe State
step eval tcm (h, k, e) = case e of
Var v -> force h k v
(Lam x e') -> unwind eval tcm h k (Lambda x e')
(TyLam x e') -> unwind eval tcm h k (TyLambda x e')
(Literal l) -> unwind eval tcm h k (Lit l)
(App e1 e2)
| (Data dc,args,_ticks) <- collectArgsTicks e
, (tys,_) <- splitFunForallTy (dcType dc)
-> case compare (length args) (length tys) of
EQ -> unwind eval tcm h k (DC dc args)
LT -> let (tys',_) = splitFunForallTy (termType tcm e)
(h2,e') = mkAbstr (h,e) tys'
in step eval tcm (h2,k,e')
GT -> error "Overapplied DC"
| (Prim nm pInfo,args,_ticks) <- collectArgsTicks e
, let ty = primType pInfo
, (tys,_) <- splitFunForallTy ty
-> case compare (length args) (length tys) of
EQ -> let (e':es) = lefts args
in Just (h,PrimApply nm pInfo (rights args) [] es:k,e')
LT -> let (tys',_) = splitFunForallTy (termType tcm e)
(h2,e') = mkAbstr (h,e) tys'
in step eval tcm (h2,k,e')
GT -> let (h2,id_) = newLetBinding tcm h e2
in Just (h2,Apply id_:k,e1)
(TyApp e1 ty)
| (Data dc,args,_ticks) <- collectArgsTicks e
, (tys,_) <- splitFunForallTy (dcType dc)
-> case compare (length args) (length tys) of
EQ -> unwind eval tcm h k (DC dc args)
LT -> let (tys',_) = splitFunForallTy (termType tcm e)
(h2,e') = mkAbstr (h,e) tys'
in step eval tcm (h2,k,e')
GT -> error "Overapplied DC"
| (Prim nm pInfo,args,_ticks) <- collectArgsTicks e
, let ty' = primType pInfo
, (tys,_) <- splitFunForallTy ty'
-> case compare (length args) (length tys) of
EQ -> case lefts args of
[] | nm `elem` ["Clash.Transformations.removedArg"]
-> unwind eval tcm h k (PrimVal nm pInfo (rights args) [])
| otherwise
-> eval (isScrut k) tcm h k nm pInfo (rights args) []
(e':es) -> Just (h,PrimApply nm pInfo (rights args) [] es:k,e')
LT -> let (tys',_) = splitFunForallTy (termType tcm e)
(h2,e') = mkAbstr (h,e) tys'
in step eval tcm (h2,k,e')
GT -> Just (h,Instantiate ty:k,e1)
(Data dc) -> unwind eval tcm h k (DC dc [])
(Prim nm pInfo)
| nm `elem` ["GHC.Prim.realWorld#"]
-> unwind eval tcm h k (PrimVal nm pInfo [] [])
| otherwise
, let ty' = primType pInfo
-> case fst (splitFunForallTy ty') of
[] -> eval (isScrut k) tcm h k nm pInfo [] []
tys -> let (h2,e') = mkAbstr (h,e) tys
in step eval tcm (h2,k,e')
(App e1 e2) -> let (h2,id_) = newLetBinding tcm h e2
in Just (h2,Apply id_:k,e1)
(TyApp e1 ty) -> Just (h,Instantiate ty:k,e1)
(Case scrut ty alts) -> Just (h,Scrutinise ty alts:k,scrut)
(Letrec bs e') -> Just (allocate h k bs e')
Tick sp e' -> Just (h,Tickish sp:k,e')
Cast _ _ _ -> trace (unlines ["WARNING: " ++ $(curLoc) ++ "Clash currently can't symbolically evaluate casts"
,"If you have testcase that produces this message, please open an issue about it."]) Nothing
newLetBinding
:: TyConMap
-> Heap
-> Term
-> (Heap,Id)
newLetBinding tcm h@(Heap gh gbl h' ids is0) e
| Var v <- e
, Just _ <- lookupVarEnv v h'
= (h, v)
| otherwise
= (Heap gh gbl (extendVarEnv id_ e h') ids' is1,id_)
where
ty = termType tcm e
((ids',is1),id_) = mkUniqSystemId (ids,is0) ("x",ty)
newLetBindings'
:: TyConMap
-> Heap
-> [Either Term Type]
-> (Heap,[Either Term Type])
newLetBindings' tcm =
(second (map (either (Left . toVar) (Right . id))) .) . mapAccumL go
where
go h (Left tm) = second Left (newLetBinding tcm h tm)
go h (Right ty) = (h,Right ty)
mkAbstr
:: (Heap,Term)
-> [Either TyVar Type]
-> (Heap,Term)
mkAbstr = foldr go
where
go (Left tv) (h,e) =
(h,TyLam tv (TyApp e (VarTy tv)))
go (Right ty) (Heap gh gbl h ids is,e) =
let ((ids',_),id_) = mkUniqSystemId (ids,is) ("x",ty)
in (Heap gh gbl h ids' is,Lam id_ (App e (Var id_)))
force :: Heap -> Stack -> Id -> Maybe State
force (Heap gh g@(GPureHeap gbl) h ids is) k x' = case lookupVarEnv x' h of
Nothing -> case lookupVarEnv x' gbl of
Just e | isGlobalId x'
-> Just (Heap gh (GPureHeap (delVarEnv gbl x')) h ids is,GUpdate x':k,e)
_ -> Nothing
Just e -> Just (Heap gh g (delVarEnv h x') ids is,Update x':k,e)
unwind
:: PrimEvaluator
-> TyConMap
-> Heap -> Stack -> Value -> Maybe State
unwind eval tcm h k v = do
(kf,k') <- uncons k
case kf of
Update x -> return (update h k' x v)
GUpdate x -> return (gupdate h k' x v)
Apply x -> return (apply h k' v x)
Instantiate ty -> return (instantiate h k' v ty)
PrimApply nm ty tys vals tms -> primop eval tcm h k' nm ty tys vals v tms
Scrutinise _ alts -> return (scrutinise h k' v alts)
Tickish _ -> return (h,k',valToTerm v)
update :: Heap -> Stack -> Id -> Value -> State
update (Heap gh gbl h ids is) k x v = (Heap gh gbl (extendVarEnv x v' h) ids is,k,v')
where
v' = valToTerm v
gupdate :: Heap -> Stack -> Id -> Value -> State
gupdate (Heap gh (GPureHeap gbl) h ids is) k x v =
(Heap gh (GPureHeap (extendVarEnv x v' gbl)) h ids is,k,v')
where
v' = valToTerm v
valToTerm :: Value -> Term
valToTerm v = case v of
Lambda x e -> Lam x e
TyLambda x e -> TyLam x e
DC dc pxs -> foldl' (\e a -> either (App e) (TyApp e) a)
(Data dc) pxs
Lit l -> Literal l
PrimVal nm ty tys vs -> foldl' App (foldl' TyApp (Prim nm ty) tys)
(map valToTerm vs)
Suspend e -> e
toVar :: Id -> Term
toVar x = Var x
toType :: TyVar -> Type
toType x = VarTy x
apply :: Heap -> Stack -> Value -> Id -> State
apply h@(Heap _ _ _ _ is0) k (Lambda x' e) x = (h,k,substTm "Evaluator.apply" subst e)
where
subst = extendIdSubst subst0 x' (Var x)
subst0 = mkSubst (extendInScopeSet is0 x)
apply _ _ _ _ = error "not a lambda"
instantiate :: Heap -> Stack -> Value -> Type -> State
instantiate h k (TyLambda x e) ty = (h,k,substTm "Evaluator.instantiate" subst e)
where
subst = extendTvSubst subst0 x ty
subst0 = mkSubst is0
is0 = mkInScopeSet (localFVsOfTerms [e] `unionUniqSet` tyFVsOfTypes [ty])
instantiate _ _ _ _ = error "not a ty lambda"
naturalLiteral :: Value -> Maybe Integer
naturalLiteral v =
case v of
Lit (NaturalLiteral i) -> Just i
DC dc [Left (Literal (WordLiteral i))]
| dcTag dc == 1
-> Just i
DC dc [Left (Literal (ByteArrayLiteral (PV.Vector _ _ (BA.ByteArray ba))))]
| dcTag dc == 2
-> Just (Jp# (BN# ba))
_ -> Nothing
integerLiteral :: Value -> Maybe Integer
integerLiteral v =
case v of
Lit (IntegerLiteral i) -> Just i
DC dc [Left (Literal (IntLiteral i))]
| dcTag dc == 1
-> Just i
DC dc [Left (Literal (ByteArrayLiteral (PV.Vector _ _ (BA.ByteArray ba))))]
| dcTag dc == 2
-> Just (Jp# (BN# ba))
| dcTag dc == 3
-> Just (Jn# (BN# ba))
_ -> Nothing
primop
:: PrimEvaluator
-> TyConMap
-> Heap
-> Stack
-> Text
-> PrimInfo
-> [Type]
-> [Value]
-> Value
-> [Term]
-> Maybe State
primop eval tcm h k nm ty tys vs v []
| nm `elem` ["Clash.Sized.Internal.Index.fromInteger#"
,"GHC.CString.unpackCString#"
,"Clash.Transformations.removedArg"
,"GHC.Prim.MutableByteArray#"
]
= unwind eval tcm h k (PrimVal nm ty tys (vs ++ [v]))
| nm == "Clash.Sized.Internal.BitVector.fromInteger#"
= case (vs,v) of
([naturalLiteral -> Just n,mask], integerLiteral -> Just i) ->
unwind eval tcm h k (PrimVal nm ty tys [Lit (NaturalLiteral n)
,mask
,Lit (IntegerLiteral (wrapUnsigned n i))])
_ -> error ($(curLoc) ++ "Internal error" ++ show (vs,v))
| nm == "Clash.Sized.Internal.BitVector.fromInteger##"
= case (vs,v) of
([mask], integerLiteral -> Just i) ->
unwind eval tcm h k (PrimVal nm ty tys [mask
,Lit (IntegerLiteral (wrapUnsigned 1 i))])
_ -> error ($(curLoc) ++ "Internal error" ++ show (vs,v))
| nm == "Clash.Sized.Internal.Signed.fromInteger#"
= case (vs,v) of
([naturalLiteral -> Just n],integerLiteral -> Just i) ->
unwind eval tcm h k (PrimVal nm ty tys [Lit (NaturalLiteral n)
,Lit (IntegerLiteral (wrapSigned n i))])
_ -> error ($(curLoc) ++ "Internal error" ++ show (vs,v))
| nm == "Clash.Sized.Internal.Unsigned.fromInteger#"
= case (vs,v) of
([naturalLiteral -> Just n],integerLiteral -> Just i) ->
unwind eval tcm h k (PrimVal nm ty tys [Lit (NaturalLiteral n)
,Lit (IntegerLiteral (wrapUnsigned n i))])
_ -> error ($(curLoc) ++ "Internal error" ++ show (vs,v))
| otherwise = eval (isScrut k) tcm h k nm ty tys (vs ++ [v])
primop eval tcm h0 k nm ty tys vs v [e]
| nm `elem` [ "Clash.Sized.Vector.lazyV"
, "Clash.Sized.Vector.replicate"
, "Clash.Sized.Vector.replace_int"
]
= let (h1,i) = newLetBinding tcm h0 e
in eval (isScrut k) tcm h1 k nm ty tys (vs ++ [v,Suspend (Var i)])
primop _ _ h k nm ty tys vs v (e:es) =
Just (h,PrimApply nm ty tys (vs ++ [v]) es:k,e)
scrutinise :: Heap -> Stack -> Value -> [Alt] -> State
scrutinise h k v [] = (h,k,valToTerm v)
scrutinise h k (Lit l) alts = case alts of
(DefaultPat,altE):alts1 -> (h,k,go altE alts1)
_ -> (h,k,go (error ("scrutinise: no match " ++
showPpr (Case (valToTerm (Lit l)) (ConstTy Arrow) alts))) alts)
where
go def [] = def
go _ ((LitPat l1,altE):_) | l1 == l = altE
go _ ((DataPat dc [] [x],altE):_)
| IntegerLiteral l1 <- l
, Just patE <- case dcTag dc of
1 | l1 >= ((-2)^(63::Int)) && l1 < 2^(63::Int) ->
Just (IntLiteral l1)
2 | l1 >= (2^(63::Int)) ->
let !(Jp# !(BN# ba0)) = l1
ba1 = BA.ByteArray ba0
bv = PV.Vector 0 (BA.sizeofByteArray ba1) ba1
in Just (ByteArrayLiteral bv)
3 | l1 < ((-2)^(63::Int)) ->
let !(Jn# !(BN# ba0)) = l1
ba1 = BA.ByteArray ba0
bv = PV.Vector 0 (BA.sizeofByteArray ba1) ba1
in Just (ByteArrayLiteral bv)
_ -> Nothing
= let inScope = localFVsOfTerms [altE]
subst0 = mkSubst (mkInScopeSet inScope)
subst1 = extendIdSubst subst0 x (Literal patE)
in substTm "Evaluator.scrutinise" subst1 altE
| NaturalLiteral l1 <- l
, Just patE <- case dcTag dc of
1 | l1 >= 0 && l1 < 2^(64::Int) ->
Just (WordLiteral l1)
2 | l1 >= (2^(64::Int)) ->
let !(Jp# !(BN# ba0)) = l1
ba1 = BA.ByteArray ba0
bv = PV.Vector 0 (BA.sizeofByteArray ba1) ba1
in Just (ByteArrayLiteral bv)
_ -> Nothing
= let inScope = localFVsOfTerms [altE]
subst0 = mkSubst (mkInScopeSet inScope)
subst1 = extendIdSubst subst0 x (Literal patE)
in substTm "Evaluator.scrutinise" subst1 altE
go def (_:alts1) = go def alts1
scrutinise h k (DC dc xs) alts
| altE:_ <- [substAlt altDc tvs pxs xs altE
| (DataPat altDc tvs pxs,altE) <- alts, altDc == dc ] ++
[altE | (DefaultPat,altE) <- alts ]
= (h,k,altE)
scrutinise h k v@(PrimVal nm _ _ vs) alts
| any (\case {(LitPat {},_) -> True; _ -> False}) alts
= case alts of
((DefaultPat,altE):alts1) -> (h,k,go altE alts1)
_ -> (h,k,go (error ("scrutinise: no match " ++
showPpr (Case (valToTerm v) (ConstTy Arrow) alts))) alts)
where
go def [] = def
go _ ((LitPat l1,altE):_) | l1 == l = altE
go def (_:alts1) = go def alts1
l = case nm of
"Clash.Sized.Internal.BitVector.fromInteger#"
| [_,Lit (IntegerLiteral 0),Lit l0] <- vs -> l0
"Clash.Sized.Internal.Index.fromInteger#"
| [_,Lit l0] <- vs -> l0
"Clash.Sized.Internal.Signed.fromInteger#"
| [_,Lit l0] <- vs -> l0
"Clash.Sized.Internal.Unsigned.fromInteger#"
| [_,Lit l0] <- vs -> l0
_ -> error ("scrutinise: " ++ showPpr (Case (valToTerm v) (ConstTy Arrow) alts))
scrutinise _ _ v alts = error ("scrutinise: " ++ showPpr (Case (valToTerm v) (ConstTy Arrow) alts))
substAlt :: DataCon -> [TyVar] -> [Id] -> [Either Term Type] -> Term -> Term
substAlt dc tvs xs args e = substTm "Evaluator.substAlt" subst e
where
tys = rights args
tms = lefts args
substTyMap = zip tvs (drop (length (dcUnivTyVars dc)) tys)
substTmMap = zip xs tms
inScope = tyFVsOfTypes tys `unionVarSet` localFVsOfTerms (e:tms)
subst = extendTvSubstList (extendIdSubstList subst0 substTmMap) substTyMap
subst0 = mkSubst (mkInScopeSet inScope)
allocate :: Heap -> Stack -> [LetBinding] -> Term -> State
allocate (Heap gh gbl h ids is0) k xes e =
(Heap gh gbl (h `extendVarEnvList` xes') ids' isN,k,e')
where
xNms = map fst xes
is1 = extendInScopeSetList is0 xNms
(ids',s) = mapAccumL (letSubst h) ids xNms
(nms,s') = unzip s
isN = extendInScopeSetList is1 nms
subst = extendIdSubstList subst0 s'
subst0 = mkSubst (foldl' extendInScopeSet is1 nms)
xes' = zip nms (map (substTm "Evaluator.allocate0" subst . snd) xes)
e' = substTm "Evaluator.allocate1" subst e
letSubst
:: PureHeap
-> Supply
-> Id
-> ( Supply
, (Id,(Id,Term)))
letSubst h acc id0 =
let (acc',id1) = uniqueInHeap h acc id0
in (acc',(id1,(id0,Var id1)))
uniqueInHeap
:: PureHeap
-> Supply
-> Id
-> (Supply, Id)
uniqueInHeap h ids x = case lookupVarEnv x' h of
Just _ -> uniqueInHeap h ids' x
_ -> (ids',x')
where
(i,ids') = freshId ids
x' = modifyVarName (\nm -> nm {nameUniq = i}) x
wrapUnsigned :: Integer -> Integer -> Integer
wrapUnsigned n i = i `mod` sz
where
sz = 1 `shiftL` fromInteger n
wrapSigned :: Integer -> Integer -> Integer
wrapSigned n i = if mask == 0 then 0 else res
where
mask = 1 `shiftL` fromInteger (n - 1)
res = case divMod i mask of
(s,i1) | even s -> i1
| otherwise -> i1 - mask