{-|
Copyright   : (C) 2020-2021, QBayLogic B.V.
License     : BSD2 (see the file LICENSE)
Maintainer  : QBayLogic B.V. <devops@qbaylogic.com>

The AsTerm class and relevant instances for the partial evaluator. This
defines how to convert normal forms back into Terms which can be given as the
result of evaluation.
-}

{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}

module Clash.Core.PartialEval.AsTerm
  ( AsTerm(..)
  ) where

import Data.Bifunctor (first, second)

import Clash.Core.HasFreeVars
import Clash.Core.PartialEval.NormalForm
import Clash.Core.Term (Bind(..), Term(..), Pat, Alt, mkApps)
import Clash.Core.VarEnv (elemVarSet)

-- | Types in some normal form which can be turned back into a 'Term', so
-- that they can be given as the result of evaluation.
--
-- Converting back may involve performing substitutions which have not been
-- performed yet (e.g. when converting from WHNF, where heads still carry the
-- environment at the point they were created).
class AsTerm a where
  -- | Rebuild the 'Term' corresponding to a normal form.
  asTerm :: a -> Term

-- | A neutral term is converted constructor by constructor, using the
-- 'AsTerm' instance of the type parameter for any nested normal forms.
instance (AsTerm a) => AsTerm (Neutral a) where
  asTerm = \case
    -- A stuck variable is simply the variable itself.
    NeVar i ->
      Var i

    -- A stuck primitive is the primitive applied to its converted arguments.
    NePrim pr args ->
      mkApps (Prim pr) (argsToTerms args)

    NeApp x y ->
      App (asTerm x) (asTerm y)

    NeTyApp x ty ->
      TyApp (asTerm x) ty

    -- The let is only rebuilt when the converted body mentions one of its
    -- binders; otherwise just the body is returned.
    NeLet bs x ->
      removeUnusedBindings (fmap asTerm bs) (asTerm x)

    NeCase x ty alts ->
      Case (asTerm x) ty (altsToTerms alts)

-- | Wrap a term in a 'Let' with the given bindings, unless none of the
-- binders occur free in the term, in which case the term is returned
-- unchanged.
--
-- For a 'NonRec' binding the single binder is checked. For a 'Rec' group the
-- whole group is kept as soon as any one of its binders is free in the body.
removeUnusedBindings :: Bind Term -> Term -> Term
removeUnusedBindings bs x
  | isUsed bs = Let bs x
  | otherwise = x
 where
  -- Free variables of the body, computed once and shared by both cases of
  -- 'isUsed'.
  free = freeVarsOf x

  isUsed = \case
    NonRec i _ -> i `elemVarSet` free
    Rec xs -> any ((`elemVarSet` free) . fst) xs

-- | Convert a 'Value' back into a 'Term'.
--
-- NOTE(review): the local environment stored in 'VData', 'VLam', 'VTyLam'
-- and 'VThunk' is discarded by this instance. Confirm that no substitution
-- from that environment still needs to be applied to the resulting term.
instance AsTerm Value where
  asTerm = \case
    VNeutral neu ->
      asTerm neu

    VLiteral lit ->
      Literal lit

    VData dc args _env ->
      mkApps (Data dc) (argsToTerms args)

    -- The body of a (type) lambda in a 'Value' is already a 'Term', so it is
    -- used as-is.
    VLam i x _env ->
      Lam i x

    VTyLam i x _env ->
      TyLam i x

    VCast x a b ->
      Cast (asTerm x) a b

    VTick x tick ->
      Tick tick (asTerm x)

    -- A thunk holds the unevaluated 'Term', which is returned directly.
    VThunk x _env ->
      x

-- | Convert a 'Normal' back into a 'Term'.
--
-- Unlike 'Value', the bodies of lambdas are themselves 'Normal' and so are
-- converted recursively. The environment stored in 'NLam' and 'NTyLam' is
-- not used by the conversion.
instance AsTerm Normal where
  asTerm = \case
    NNeutral neu ->
      asTerm neu

    NLiteral lit ->
      Literal lit

    NData dc args ->
      mkApps (Data dc) (argsToTerms args)

    NLam i x _env ->
      Lam i (asTerm x)

    NTyLam i x _env ->
      TyLam i (asTerm x)

    NCast x a b ->
      Cast (asTerm x) a b

    NTick x tick ->
      Tick tick (asTerm x)

-- | Convert every term argument in a list of arguments with 'asTerm'. Only
-- the 'Left' side of each argument is mapped over; the other side is left
-- untouched.
argsToTerms :: (AsTerm a) => Args a -> Args Term
argsToTerms = fmap (first asTerm)

-- | Convert the right-hand side of every case alternative with 'asTerm',
-- keeping each pattern as it is.
altsToTerms :: (AsTerm a) => [(Pat, a)] -> [Alt]
altsToTerms = fmap (second asTerm)