-----------------------------------------------------------------------------
-- |
-- Module      :  Disco.Typecheck.Util
-- Copyright   :  (c) 2016 disco team (see LICENSE)
-- License     :  BSD-style (see LICENSE)
-- Maintainer  :  byorgey@gmail.com
--
-- Definition of type contexts, type errors, and various utilities
-- used during type checking.
--
-----------------------------------------------------------------------------

module Disco.Typecheck.Util where

import           Disco.Effects.Fresh
import           Polysemy
import           Polysemy.Error
import           Polysemy.Output
import           Polysemy.Reader
import           Polysemy.Writer
import           Unbound.Generics.LocallyNameless (Name, bind, string2Name)

import qualified Data.Map                         as M
import           Data.Tuple                       (swap)
import           Prelude                          hiding (lookup)

import           Disco.AST.Surface
import           Disco.Context
import           Disco.Messages
import           Disco.Names                      (ModuleName, QName)
import           Disco.Typecheck.Constraints
import           Disco.Typecheck.Solve
import           Disco.Types

------------------------------------------------------------
-- Contexts
------------------------------------------------------------

-- | Typing context: associates each term-level name with its
--   (possibly polymorphic) type.
type TyCtx
  = Ctx Term PolyType

------------------------------------------------------------
-- Errors
------------------------------------------------------------

-- | A typechecking error, paired with the (optional) qualified name of
--   the definition that was being checked when the error occurred.
data LocTCError = LocTCError (Maybe (QName Term)) TCError
  deriving Show

-- | Wrap a 'TCError' into a 'LocTCError' with no explicit provenance
--   information (no name of the definition being checked).
noLoc :: TCError -> LocTCError
noLoc = LocTCError Nothing

-- | Potential typechecking errors.
data TCError
  = Unbound (Name Term)    -- ^ Encountered an unbound variable.
  | Ambiguous (Name Term) [ModuleName] -- ^ Encountered an ambiguous name.
  | NoType  (Name Term)    -- ^ No type is specified for a definition.
  | NotCon Con Term Type   -- ^ The type of the term should have an
                           --   outermost constructor matching 'Con', but
                           --   it has type 'Type' instead.
  | EmptyCase              -- ^ Case analyses cannot be empty.
  | PatternType Con Pattern Type  -- ^ The given pattern should have the given
                                  --   type, but it doesn't; instead it has a
                                  --   type whose shape is given by the 'Con'.
  | DuplicateDecls (Name Term)  -- ^ Duplicate declarations.
  | DuplicateDefns (Name Term)  -- ^ Duplicate definitions.
  | DuplicateTyDefns String -- ^ Duplicate type definitions.
  | CyclicTyDef String     -- ^ Cyclic type definition.
  | NumPatterns            -- ^ Number of patterns does not match the type
                           --   in a definition.
  | NoSearch Type          -- ^ Type can't be quantified over.
  | Unsolvable SolveError  -- ^ The constraint solver couldn't find a solution.
  | NotTyDef String        -- ^ An undefined type name was used.
  | NoTWild                -- ^ Wildcards are not allowed in terms.
  | NotEnoughArgs Con      -- ^ Not enough arguments provided to type constructor.
  | TooManyArgs Con        -- ^ Too many arguments provided to type constructor.
  | UnboundTyVar (Name Type) -- ^ Unbound type variable.
  | NoPolyRec String [String] [Type] -- ^ Polymorphic recursion is not allowed.
  | NoError                -- ^ Not an error.  The identity of the
                           --   @Monoid TCError@ instance.
  deriving Show

-- | 'TCError' is a monoid where combining two errors keeps the second
--   one and discards the first.  'NoError' is the identity: it is
--   ignored on either side, so combining any error with 'NoError'
--   yields that error.
instance Semigroup TCError where
  -- Without this equation, @e <> NoError@ would evaluate to 'NoError',
  -- breaking the right-identity law @x <> mempty == x@.  With it, the
  -- operation is "keep the last non-'NoError' error", which is still
  -- associative.
  e <> NoError = e
  _ <> r       = r

instance Monoid TCError where
  -- 'mappend' is left to its default, which is '(<>)'.
  mempty = NoError

------------------------------------------------------------
-- Constraints
------------------------------------------------------------

-- | Emit a single constraint into the ambient constraint writer.
constraint :: Member (Writer Constraint) r => Constraint -> Sem r ()
constraint = tell

-- | Emit a list of constraints, combined into a single conjunction
--   with 'cAnd'.
constraints :: Member (Writer Constraint) r => [Constraint] -> Sem r ()
constraints = constraint . cAnd

-- | Close over the constraint emitted by a subcomputation with a
--   forall: whatever constraint the subcomputation writes is bound
--   over the given type variable names and wrapped in 'CAll'.
forAll :: Member (Writer Constraint) r => [Name Type] -> Sem r a -> Sem r a
forAll nms = censor (CAll . bind nms)

-- | Run two constraint-generating actions and combine the constraints
--   they emit via disjunction.
--
--   Each action's output is captured with 'listen', and 'censor'
--   replaces what it would have written with 'CTrue'.  The two
--   captured constraints are then emitted together as a single 'COr'.
cOr :: Members '[Writer Constraint] r => Sem r () -> Sem r () -> Sem r ()
cOr m1 m2 = do
  (c1, _) <- censor (const CTrue) (listen m1)
  (c2, _) <- censor (const CTrue) (listen m2)
  constraint $ COr [c1, c2]

-- | Run a computation that generates constraints, returning the
--   generated 'Constraint' along with the output.  Note that this
--   locally dispatches the constraint writer effect.
--
--   This function is somewhat low-level; typically you should use
--   'solve' instead, which also solves the generated constraints.
withConstraint :: Sem (Writer Constraint ': r) a -> Sem r (a, Constraint)
withConstraint = fmap swap . runWriter

-- | Run a computation and solve its generated constraint, returning
--   the computation's result together with the resulting substitution.
--   A solver failure is rethrown as an 'Unsolvable' 'TCError'.  Note
--   that this locally dispatches the constraint writer effect.
solve
  :: Members '[Reader TyDefCtx, Error TCError, Output Message] r
  => Sem (Writer Constraint ': r) a -> Sem r (a, S)
solve m = do
  (a, c) <- withConstraint m
  -- The solver reads type definitions via 'Input'; serve them from the
  -- ambient 'Reader TyDefCtx'.
  res <- runSolve . inputToReader . solveConstraint $ c
  case res of
    Left e  -> throw (Unsolvable e)
    Right s -> return (a, s)

------------------------------------------------------------
-- Type definition contexts
------------------------------------------------------------

-- | Look up the definition of a named type and instantiate its body
--   with the given type arguments.  Throws a 'NotTyDef' error if the
--   name is not in the 'TyDefCtx'.
--
--   NOTE(review): the parameter list stored in the 'TyDefBody' is
--   ignored here, so the number of @args@ is not checked against it;
--   presumably arity is validated elsewhere -- confirm.
lookupTyDefn
  :: Members '[Reader TyDefCtx, Error TCError] r
  => String -> [Type] -> Sem r Type
lookupTyDefn x args = do
  d <- ask @TyDefCtx
  case M.lookup x d of
    Nothing                 -> throw (NotTyDef x)
    Just (TyDefBody _ body) -> return $ body args

-- | Run a subcomputation with an extended type definition context.
--   The new definitions are merged in with 'M.union', which is
--   left-biased, so they shadow any existing definitions of the same
--   name.
withTyDefns :: Member (Reader TyDefCtx) r => TyDefCtx -> Sem r a -> Sem r a
withTyDefns tyDefnCtx = local (M.union tyDefnCtx)

------------------------------------------------------------
-- Fresh name generation
------------------------------------------------------------

-- | Generate a type variable with a fresh name, derived from the base
--   name @a@.
freshTy :: Member Fresh r => Sem r Type
freshTy = TyVar <$> fresh (string2Name "a")

-- | Generate a fresh variable as an atom.  The name is derived from the
--   base name @c@, tagged with the 'U' variable constructor, and wrapped
--   in 'AVar'.
freshAtom :: Member Fresh r => Sem r Atom
freshAtom = AVar . U <$> fresh (string2Name "c")