-----------------------------------------------------------------------------
-- |
-- Module      :  Disco.Typecheck.Erase
-- Copyright   :  (c) 2016 disco team (see LICENSE)
-- License     :  BSD-style (see LICENSE)
-- Maintainer  :  byorgey@gmail.com
--
-- Erase the type annotations from the typed ('ATerm') and desugared
-- ('DTerm') Disco ASTs, turning them back into surface-language terms.
--
-----------------------------------------------------------------------------

module Disco.Typecheck.Erase where

import           Unbound.Generics.LocallyNameless
import           Unbound.Generics.LocallyNameless.Unsafe

import           Control.Arrow                           ((***))
import           Data.Coerce

import           Disco.AST.Desugared
import           Disco.AST.Surface
import           Disco.AST.Typed
import           Disco.Names                             (QName (..))

-- | Erase all the type annotations from a term, producing the
--   corresponding surface-language 'Term'.
--
--   * Variable names are converted with 'coerce'; the 'QName'
--     provenance wrapper on variables is dropped.
--   * Binders ('ATLet', 'ATAbs', 'ATContainerComp') are opened with
--     'unsafeUnbind' and the erased pieces are immediately re-closed with
--     'bind', so the exposed names never escape and no freshening is
--     needed.
--   * The 'Type' argument of 'ATTyOp' is part of the term itself (it is
--     passed on to 'TTyOp'), so only the leading annotation is discarded.
--   * 'ATTest' is erased to its underlying term; its annotation list is
--     discarded.
erase :: ATerm -> Term
erase (ATVar _ (QName _ x))      = TVar (coerce x)
erase (ATPrim _ x)               = TPrim x
erase (ATLet _ bs)               = TLet $ bind (mapTelescope eraseBinding tel) (erase at)
  where (tel, at) = unsafeUnbind bs
erase ATUnit                     = TUnit
erase (ATBool _ b)               = TBool b
erase (ATChar c)                 = TChar c
erase (ATString s)               = TString s
erase (ATNat _ i)                = TNat i
erase (ATRat r)                  = TRat r
erase (ATAbs q _ b)              = TAbs q $ bind (map erasePattern x) (erase at)
  where (x, at) = unsafeUnbind b
erase (ATApp _ t1 t2)            = TApp (erase t1) (erase t2)
erase (ATTup _ ats)              = TTup (map erase ats)
erase (ATCase _ brs)             = TCase (map eraseBranch brs)
erase (ATChain _ at lnks)        = TChain (erase at) (map eraseLink lnks)
erase (ATTyOp _ op ty)           = TTyOp op ty
-- Each container element may carry an optional second term (the
-- 'Maybe ATerm' component), and the container may end in an optional
-- ellipsis; erase inside both.
erase (ATContainer _ c ats aell) =
  TContainer c (map (erase *** fmap erase) ats) ((fmap . fmap) erase aell)
erase (ATContainerComp _ c b)    =
  TContainerComp c $ bind (mapTelescope eraseQual tel) (erase at)
  where (tel, at) = unsafeUnbind b
erase (ATTest _ x)               = erase x

-- | Erase the type annotations from a binding.  The optional type
--   signature ('Maybe (Embed PolyType)') is kept unchanged, the bound
--   name is converted with 'coerce', and the embedded body is erased.
eraseBinding :: ABinding -> Binding
eraseBinding (ABinding mty x (unembed -> at)) =
  Binding mty (coerce x) (embed (erase at))

-- | Erase the type annotations from a pattern.  Sub-patterns are erased
--   recursively; the terms embedded in arithmetic patterns ('APAdd',
--   'APMul', 'APSub') are erased with 'erase'.
erasePattern :: APattern -> Pattern
erasePattern (APVar _ n)        = PVar (coerce n)
erasePattern (APWild _)         = PWild
erasePattern APUnit             = PUnit
erasePattern (APBool b)         = PBool b
erasePattern (APChar c)         = PChar c
erasePattern (APString s)       = PString s
erasePattern (APTup _ alp)      = PTup $ map erasePattern alp
erasePattern (APInj _ s apt)    = PInj s (erasePattern apt)
erasePattern (APNat _ n)        = PNat n
erasePattern (APCons _ ap1 ap2) = PCons (erasePattern ap1) (erasePattern ap2)
erasePattern (APList _ alp)     = PList $ map erasePattern alp
erasePattern (APAdd _ s p t)    = PAdd s (erasePattern p) (erase t)
erasePattern (APMul _ s p t)    = PMul s (erasePattern p) (erase t)
erasePattern (APSub _ p t)      = PSub (erasePattern p) (erase t)
erasePattern (APNeg _ p)        = PNeg (erasePattern p)
erasePattern (APFrac _ p1 p2)   = PFrac (erasePattern p1) (erasePattern p2)

-- | Erase the type annotations from a case branch.  The branch binder is
--   opened with 'unsafeUnbind'.  Each guard in the telescope and the body
--   are erased, and the result is immediately re-closed with 'bind'.
eraseBranch :: ABranch -> Branch
eraseBranch b = bind (mapTelescope eraseGuard tel) (erase at)
  where (tel, at) = unsafeUnbind b

-- | Erase the type annotations from a single branch guard.  The embedded
--   scrutinee term is erased and re-embedded.  Patterns and @let@
--   bindings are erased with 'erasePattern' and 'eraseBinding'
--   respectively.
eraseGuard :: AGuard -> Guard
eraseGuard (AGBool (unembed -> at))  = GBool (embed (erase at))
eraseGuard (AGPat (unembed -> at) p) = GPat (embed (erase at)) (erasePattern p)
eraseGuard (AGLet b)                 = GLet (eraseBinding b)

-- | Erase the type annotations from one link of a comparison chain,
--   keeping its operator and erasing its term.
eraseLink :: ALink -> Link
eraseLink (ATLink bop at) = TLink bop (erase at)

-- | Erase the type annotations from a container-comprehension qualifier.
--   For a binding qualifier, the bound name is converted with 'coerce'
--   and the embedded term is erased.  For a guard qualifier, only the
--   embedded term is erased.
eraseQual :: AQual -> Qual
eraseQual (AQBind x (unembed -> at)) = QBind (coerce x) (embed (erase at))
eraseQual (AQGuard (unembed -> at))  = QGuard (embed (erase at))

-- | Erase the type annotations from a property.  'AProperty' and
--   'Property' are the annotated and surface term types, so this is just
--   'erase'.
eraseProperty :: AProperty -> Property
eraseProperty = erase

------------------------------------------------------------
-- DTerm erasure

-- | Convert a desugared term back into a surface 'Term', discarding its
--   type annotations.
--
--   Because the desugared language is smaller than the surface language,
--   some constructs are re-expressed:
--
--   * 'DTAbs' binds a single name, which becomes a one-element list of
--     variable patterns.
--   * 'DTPair' becomes a two-element 'TTup'.
--   * 'DTNil' becomes an empty list literal with no ellipsis.
--   * 'DTTest' is erased to its underlying term; its annotation list is
--     discarded.
eraseDTerm :: DTerm -> Term
eraseDTerm (DTVar _ (QName _ x)) = TVar (coerce x)
eraseDTerm (DTPrim _ x)          = TPrim x
eraseDTerm DTUnit                = TUnit
eraseDTerm (DTBool _ b)          = TBool b
eraseDTerm (DTChar c)            = TChar c
eraseDTerm (DTNat _ n)           = TNat n
eraseDTerm (DTRat r)             = TRat r
eraseDTerm (DTAbs q _ b)         = TAbs q $ bind [PVar . coerce $ x] (eraseDTerm dt)
  where (x, dt) = unsafeUnbind b
eraseDTerm (DTApp _ d1 d2)       = TApp (eraseDTerm d1) (eraseDTerm d2)
eraseDTerm (DTPair _ d1 d2)      = TTup [eraseDTerm d1, eraseDTerm d2]
eraseDTerm (DTCase _ bs)         = TCase (map eraseDBranch bs)
eraseDTerm (DTTyOp _ op ty)      = TTyOp op ty
eraseDTerm (DTNil _)             = TList [] Nothing
eraseDTerm (DTTest _ x)          = eraseDTerm x

-- | Convert a desugared case branch into a surface 'Branch'.  The binder
--   is opened with 'unsafeUnbind'.  The guards and body are erased, and
--   the result is immediately re-closed with 'bind'.
eraseDBranch :: DBranch -> Branch
eraseDBranch b = bind (mapTelescope eraseDGuard tel) (eraseDTerm d)
  where
    (tel, d) = unsafeUnbind b

-- | Convert a desugared pattern guard into a surface 'Guard'.  The
--   embedded scrutinee is erased with 'eraseDTerm' and the pattern with
--   'eraseDPattern'.
eraseDGuard :: DGuard -> Guard
eraseDGuard (DGPat (unembed -> d) p) = GPat (embed (eraseDTerm d)) (eraseDPattern p)

-- | Convert a desugared pattern into a surface 'Pattern', discarding type
--   annotations.  Desugared pairs and injections bind names directly, so
--   those names are wrapped back up as variable patterns ('PVar').
eraseDPattern :: DPattern -> Pattern
eraseDPattern (DPVar _ x)      = PVar (coerce x)
eraseDPattern (DPWild _)       = PWild
eraseDPattern DPUnit           = PUnit
eraseDPattern (DPPair _ x1 x2) = PTup (map (PVar . coerce) [x1, x2])
eraseDPattern (DPInj _ s x)    = PInj s (PVar (coerce x))