-- |
-- Module      :  Cryptol.TypeCheck.Unify
-- Copyright   :  (c) 2013-2016 Galois, Inc.
-- License     :  BSD3
-- Maintainer  :  cryptol@galois.com
-- Stability   :  provisional
-- Portability :  portable

{-# LANGUAGE Safe #-}
{-# LANGUAGE PatternGuards, ViewPatterns #-}
{-# LANGUAGE DeriveFunctor, DeriveGeneric, DeriveAnyClass #-}
{-# LANGUAGE BlockArguments, OverloadedStrings #-}
module Cryptol.TypeCheck.Unify where

import Control.DeepSeq(NFData)
import GHC.Generics(Generic)

import Cryptol.TypeCheck.AST
import Cryptol.TypeCheck.Subst
import Cryptol.Utils.RecordMap
import Cryptol.Utils.Ident(Ident)
import Cryptol.ModuleSystem.Name(nameIdent)

import Cryptol.TypeCheck.PP
import Control.Monad.Writer (Writer, writer, runWriter)
import qualified Data.Set as Set

import Prelude ()
import Prelude.Compat

-- | A most general unifier: a substitution together with a list of
-- constraints that were not solved by substitution and are left for the
-- caller (in this module they are always equalities built with '=#=').
type MGU = (Subst, [Prop])

-- | A unification computation.  Every failure is logged together with the
-- 'Path' at which it was detected, and the computation keeps going.
type Result a = Writer [(Path, UnificationError)] a

-- | Execute a unification computation, returning its value paired with all
-- of the errors it logged.
runResult :: Result a -> (a, [(Path, UnificationError)])
runResult computation = runWriter computation

-- | Reasons why two types failed to unify.
data UnificationError
  = UniTypeMismatch Type Type
    -- ^ The types have the same kind but no unification rule applies.
  | UniKindMismatch Kind Kind
    -- ^ The two sides have different kinds.
  | UniTypeLenMismatch Int Int
    -- ^ Argument lists of different lengths (the two lengths).
  | UniRecursive TVar Type
    -- ^ Binding the variable (of kind 'KType') to the type was reported
    -- as recursive by 'singleSubst'.
  | UniNonPolyDepends TVar [TParam]
    -- ^ Binding the variable was rejected by 'singleSubst' because of
    -- these escaping parameters.
  | UniNonPoly TVar Type
    -- ^ A bound variable was equated with a type of the same,
    -- non-numeric, kind.

-- | Log a unification failure at the given path, then carry on with the
-- trivial unifier 'emptyMGU' so that further errors can still be found.
uniError :: Path -> UnificationError -> Result MGU
uniError path err = writer (emptyMGU, [entry])
  where
    entry = (path, err)


-- | A position inside the types being unified.  Steps are stored with the
-- most recently added (innermost) step at the head of the list.
newtype Path = Path [PathElement]
  deriving (Show, Generic, NFData)

-- | A single step of a 'Path'.
data PathElement
  = TConArg TC Int
    -- ^ The argument at this (zero-based) index of a type constructor.
  | TNewtypeArg Newtype Int
    -- ^ The argument at this (zero-based) index of a newtype.
  | TRecArg Ident
    -- ^ The record field with this name.
  deriving (Show, Generic, NFData)

-- | The path with no steps: the top-level types handed to the unifier.
rootPath :: Path
rootPath = Path []

-- | Does this path have no steps at all?
isRootPath :: Path -> Bool
isRootPath (Path []) = True
isRootPath _         = False

-- | Descend one step further: the new step becomes the head of the list.
extPath :: Path -> PathElement -> Path
extPath (Path steps) step = Path (step : steps)


-- | The unifier that changes nothing: no substitution, no constraints.
emptyMGU :: MGU
emptyMGU = (emptySubst, [])

-- | Unify two top-level types, reporting errors relative to 'rootPath'.
doMGU :: Type -> Type -> Result MGU
doMGU = mgu rootPath

-- | Compute a most general unifier for two types.  Failures are logged at
-- the given 'Path' (the position of these two types inside the top-level
-- types being unified).
--
-- Equations are tried top to bottom and the order matters:
--
--   * identical 'TUser' applications unify trivially;
--   * a type variable on either side is handed to 'bindVar';
--   * remaining 'TUser' wrappers are looked through;
--   * equal constructors, records with the same fields, and equal newtypes
--     are unified argument by argument via 'mguMany';
--   * identical type-function applications unify trivially; otherwise, if
--     either side is a type-function application and both sides have kind
--     'KNum', the equation is returned as a constraint;
--   * anything else is a kind mismatch or a type mismatch.
mgu :: Path -> Type -> Type -> Result MGU

mgu _ (TUser n1 args1 _) (TUser n2 args2 _)
  | n1 == n2, args1 == args2 = return emptyMGU

mgu path (TVar v) ty = bindVar path v ty
mgu path ty (TVar v) = bindVar path v ty

-- Look through TUser wrappers by unifying their third component instead.
mgu path (TUser _ _ under) other = mgu path under other
mgu path other (TUser _ _ under) = mgu path other under

mgu path (TCon (TC c1) args1) (TCon (TC c2) args2)
  | c1 == c2 =
      mguMany path [ extPath path (TConArg c1 ix) | ix <- [0 ..] ] args1 args2

mgu _ (TCon (TF fun1) args1) (TCon (TF fun2) args2)
  | fun1 == fun2, args1 == args2 = return emptyMGU

-- XXX: the path is dropped here, so the resulting constraint does not
-- remember where it came from.
mgu _ lhs rhs
  | TCon (TF _) _ <- lhs, numeric = deferred
  | TCon (TF _) _ <- rhs, numeric = deferred
  where
    lhsKind = kindOf lhs
    rhsKind = kindOf rhs

    -- Both sides are numeric: the first must be KNum and the kinds agree.
    numeric = lhsKind == KNum && lhsKind == rhsKind

    deferred :: Result MGU
    deferred = return (emptySubst, [lhs =#= rhs])

mgu path (TRec fields1) (TRec fields2)
  | fieldSet fields1 == fieldSet fields2 =
      mguMany path
              [ extPath path (TRecArg name) | (name, _) <- canonicalFields fields1 ]
              (recordElements fields1)
              (recordElements fields2)

mgu path (TNewtype nt1 args1) (TNewtype nt2 args2)
  | nt1 == nt2 =
      mguMany path [ extPath path (TNewtypeArg nt1 ix) | ix <- [0 ..] ] args1 args2

mgu path lhs rhs
  | lhsKind /= rhsKind = uniError path (UniKindMismatch lhsKind rhsKind)
  | otherwise          = uniError path (UniTypeMismatch lhs rhs)
  where
    lhsKind = kindOf lhs
    rhsKind = kindOf rhs


-- XXX: the paths could be attached to the type lists themselves.

-- | Unify two lists of types pairwise, left to right.
--
-- Each pair is unified at the corresponding entry of the path list.  The
-- substitution found for a pair is applied to all remaining types before
-- they are unified, and the final substitution composes the later result
-- after the earlier one.  The parent path is only used to report a length
-- mismatch.
mguMany :: Path -> [Path] -> [Type] -> [Type] -> Result MGU
mguMany _ _ [] [] = return emptyMGU
mguMany parent (here : rest) (lhs : lhss) (rhs : rhss) = do
  (headSu, headProps) <- mgu here lhs rhs
  (tailSu, tailProps) <-
    mguMany parent rest (apSubst headSu lhss) (apSubst headSu rhss)
  return (tailSu @@ headSu, headProps ++ tailProps)
mguMany parent _ lhss rhss =
  uniError parent (UniTypeLenMismatch (length lhss) (length rhss))
-- XXX: I think the types should already have been kind checked by this
-- point, so a length mismatch here should not really happen.


-- | Unify a type variable with a type.  The result is a substitution, an
-- equality constraint left for the caller, or a logged error at the given
-- path.  Equations are tried top to bottom.
bindVar :: Path -> TVar -> Type -> Result MGU

-- A variable unifies with itself (compared after 'tNoUser').
bindVar _ var (tNoUser -> TVar other)
  | var == other = return emptyMGU

-- A bound variable against a free one: swap them so the free one is bound.
bindVar path bound@(TVBound {}) (tNoUser -> TVar free@(TVFree {})) =
  bindVar path free (TVar bound)

-- A bound variable is never substituted.  If the kinds differ it is an
-- error; at kind KNum the equation is kept as a constraint; at any other
-- kind it is rejected.
bindVar path bound@(TVBound {}) ty
  | boundKind /= kindOf ty = uniError path (UniKindMismatch boundKind (kindOf ty))
  | boundKind == KNum      = return (emptySubst, [TVar bound =#= ty])
  | otherwise              = uniError path (UniNonPoly bound ty)
  where
    boundKind = kindOf bound

-- Two free variables of the same kind where x's scope is a proper subset of
-- y's: add the reverse binding y ~> x.  The instantiation x ~> y would be
-- forbidden because it would allow y to escape from its scope.
bindVar _ x@(TVFree _ xKind xScope _) (tNoUser -> TVar y@(TVFree _ yKind yScope _))
  | xScope `Set.isProperSubsetOf` yScope, xKind == yKind =
      return (uncheckedSingleSubst y (TVar x), [])

-- Otherwise attempt the ordinary substitution var ~> ty and translate any
-- failure reported by 'singleSubst'.
bindVar path var ty = either onFailure onSuccess (singleSubst var ty)
  where
    onSuccess :: Subst -> Result MGU
    onSuccess su = return (su, [])

    onFailure :: SubstError -> Result MGU
    onFailure SubstRecursive
      | kindOf var == KType = uniError path (UniRecursive var ty)
      | otherwise           = return (emptySubst, [TVar var =#= ty])
    onFailure (SubstEscaped params) =
      uniError path (UniNonPolyDepends var params)
    onFailure (SubstKindMismatch k1 k2) =
      uniError path (UniKindMismatch k1 k2)


--------------------------------------------------------------------------------

-- | Render a single path step as a type skeleton.
--
-- The argument position selected by the step is drawn by the continuation
-- @k@, which receives the precedence required at that position.  Every
-- other argument position is shown as @_@.  @prec@ is the precedence of the
-- surrounding context and decides whether parentheses are needed.
ppPathEl :: PathElement -> Int -> (Int -> Doc) -> Doc
ppPathEl el prec k =
  case el of
    -- Only the selected field is shown; the rest are elided with "…".
    TRecArg l -> braces (pp l <+> ":" <+> k 0 <.> comma <+> "…")

    -- Argument 0 is drawn inside the brackets, argument 1 after them.
    TConArg TCSeq n ->
      optParens (prec > 4) $
        if n == 0
          then brackets (k 0) <+> "_"
          else brackets "_" <+> k 4

    -- Argument 0 is left of the arrow, argument 1 right of it.
    TConArg TCFun n ->
      optParens (prec > 1) $
        if n == 0
          then k 2 <+> "->" <+> "_"
          else "_" <+> "->" <+> k 1

    TConArg (TCTuple arity) n ->
      parens (commaSep (holeAt arity n (k 0)))

    TConArg tc n ->
      prefixApp (countArrows (kindOf tc)) (pp tc) n

    TNewtypeArg nt n ->
      prefixApp (length (ntParams nt)) (pp (nameIdent (ntName nt))) n

  where
    -- A list of @arity@ underscores in which index @n@ is replaced by @hole@.
    holeAt :: Int -> Int -> Doc -> [Doc]
    holeAt arity n hole = replicate n "_" ++ [hole] ++ replicate (arity - n - 1) "_"

    -- Prefix application of @fun@ to @arity@ arguments, hole at index @n@.
    prefixApp :: Int -> Doc -> Int -> Doc
    prefixApp arity fun n =
      optParens (prec > 3) (fun <+> hsep (holeAt arity n (k 5)))

    -- Number of arguments a kind accepts, counted by its arrows.
    countArrows :: Kind -> Int
    countArrows (_ :-> rest) = 1 + countArrows rest
    countArrows _            = 0

-- | A path prints as nested skeletons, outermost step first.  The innermost
-- hole (and the whole of an empty path) is shown as @ERROR@.
instance PP Path where
  ppPrec outerPrec (Path steps) = render (reverse steps) outerPrec
    where
      -- Steps are stored innermost-first, hence the 'reverse' above.
      render :: [PathElement] -> Int -> Doc
      render []           _    = "ERROR"
      render (el : inner) prec = ppPathEl el prec (render inner)