{-|
Copyright   : (C) 2020 QBayLogic B.V.
License     : BSD2 (see the file LICENSE)
Maintainer  : QBayLogic B.V. <devops@qbaylogic.com>

The AsTerm class and relevant instances for the partial evaluator. This
defines how to convert normal forms back into Terms which can be given as the
result of evaluation.
-}

{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}

module Clash.Core.PartialEval.AsTerm
  ( AsTerm(..)
  ) where

import Data.Bifunctor (first, second)
import Data.Graph (SCC(..), flattenSCCs)

import Clash.Core.FreeVars (localFVsOfTerms)
import Clash.Core.PartialEval.NormalForm
import Clash.Core.Term (Term(..), LetBinding, Pat, Alt, mkApps)
import Clash.Core.Util (sccLetBindings)
import Clash.Core.VarEnv (elemVarSet, unionVarSet)

-- | Types whose values can be turned back into a 'Term', so that the outcome
-- of partial evaluation can be handed back as ordinary core syntax.
--
-- Conversion is the natural point at which to apply substitutions that
-- evaluation deferred (for instance the environment captured by a WHNF
-- lambda or thunk).
-- NOTE(review): the instances in this module currently discard such captured
-- environments (@_env@) rather than substituting them — confirm intent.
class AsTerm a where
  -- | Rebuild a 'Term' from a value in some normal form.
  asTerm :: a -> Term

-- | Neutral terms are rebuilt constructor by constructor. Every embedded
-- sub-value of type @a@ is converted with its own 'asTerm'.
--
-- For 'NeLetrec', the converted bindings and body are handed to
-- 'removeUnusedBindings', which drops bindings it considers unused and
-- returns the bare body when none remain.
instance (AsTerm a) => AsTerm (Neutral a) where
  asTerm = \case
    NeVar i -> Var i
    NePrim pr args -> mkApps (Prim pr) (argsToTerms args)
    NeApp x y -> App (asTerm x) (asTerm y)
    NeTyApp x ty -> TyApp (asTerm x) ty
    NeLetrec bs x ->
      removeUnusedBindings (fmap (second asTerm) bs) (asTerm x)
    NeCase x ty alts -> Case (asTerm x) ty (altsToTerms alts)

-- | Build a 'Letrec' from the given bindings and body, keeping only the
-- bindings reachable from the body. If no binding survives, the body is
-- returned on its own.
--
-- A binding is reachable when its binder is free in the body, or free in the
-- right-hand side of another binding that is itself kept. The second clause
-- matters: checking only the body's free variables would drop a binding used
-- solely by another kept binding, leaving an unbound variable behind.
--
-- Bindings are considered a strongly connected component at a time, so a
-- mutually recursive group is kept or dropped as a whole.
removeUnusedBindings :: [LetBinding] -> Term -> Term
removeUnusedBindings bs x
  | null used = x
  | otherwise = Letrec used x
 where
  -- 'foldr' examines the last SCC first, so each group is visited only after
  -- every group that could reference it.
  -- NOTE(review): relies on 'sccLetBindings' returning groups in reverse
  -- topological order (dependencies first), as 'Data.Graph.stronglyConnComp'
  -- does — verify against its definition.
  used = fst (foldr keepIfUsed ([], localFVsOfTerms [x]) (sccLetBindings bs))

  -- Keep the whole SCC when any of its binders is live. Its right-hand sides
  -- then add their free variables to the live set for earlier groups.
  keepIfUsed scc acc@(kept, live)
    | any ((`elemVarSet` live) . fst) members =
        ( members ++ kept
        , live `unionVarSet` localFVsOfTerms (fmap snd members)
        )
    | otherwise = acc
   where
    members = flattenSCCs [scc]

-- | Values in WHNF are turned back into terms. Only the outermost layer is
-- rebuilt, plus any nested 'Value's, which are converted recursively.
--
-- Lambda bodies and thunks are already plain 'Term's and are returned
-- unchanged. The environment each of them captured is ignored here (as is
-- the one carried by 'VData').
-- NOTE(review): pending substitutions from that environment are therefore
-- not applied — confirm this is intended.
instance AsTerm Value where
  asTerm = \case
    VNeutral neu -> asTerm neu
    VLiteral lit -> Literal lit
    VData dc args _env -> mkApps (Data dc) (argsToTerms args)
    VLam i x _env -> Lam i x
    VTyLam i x _env -> TyLam i x
    VCast x a b -> Cast (asTerm x) a b
    VTick x tick -> Tick tick (asTerm x)
    VThunk x _env -> x

-- | Fully normalised values are converted recursively. Unlike 'Value',
-- lambda bodies here are themselves 'Normal' and are converted as well.
-- The environment stored alongside 'NLam' and 'NTyLam' is not used.
instance AsTerm Normal where
  asTerm = \case
    NNeutral neu -> asTerm neu
    NLiteral lit -> Literal lit
    NData dc args -> mkApps (Data dc) (argsToTerms args)
    NLam i x _env -> Lam i (asTerm x)
    NTyLam i x _env -> TyLam i (asTerm x)
    NCast x a b -> Cast (asTerm x) a b
    NTick x tick -> Tick tick (asTerm x)

-- | Convert the term arguments of an argument list with 'asTerm'. Type
-- arguments ('Right') pass through untouched.
argsToTerms :: (AsTerm a) => Args a -> Args Term
argsToTerms = fmap (first asTerm)

-- | Convert the right-hand side of every case alternative with 'asTerm'.
-- Patterns are left unchanged.
altsToTerms :: (AsTerm a) => [(Pat, a)] -> [Alt]
altsToTerms = fmap (second asTerm)