{-# LANGUAGE CPP #-}

-----------------------------------------------------------------------------

-- |

-- Module      :  Language.Haskell.TH.Desugar.Subst

-- Copyright   :  (C) 2018 Richard Eisenberg

-- License     :  BSD-style (see LICENSE)

-- Maintainer  :  Ryan Scott

-- Stability   :  experimental

-- Portability :  non-portable

--

-- Capture-avoiding substitutions on 'DType's

--

----------------------------------------------------------------------------


module Language.Haskell.TH.Desugar.Subst (
  DSubst,

  -- * Capture-avoiding substitution

  substTy, substForallTelescope, substTyVarBndrs,
  unionSubsts, unionMaybeSubsts,

  -- * Matching a type template against a type

  IgnoreKinds(..), matchTy
  ) where

import qualified Data.List as L
import qualified Data.Map as M
import qualified Data.Set as S

import Language.Haskell.TH.Desugar.AST
import Language.Haskell.TH.Syntax
import Language.Haskell.TH.Desugar.Util

-- | A substitution: a finite map sending each type-variable 'Name' to the
-- 'DType' that should replace occurrences of that variable (see 'substTy').
type DSubst
  = M.Map Name DType

-- | Capture-avoiding substitution on types.
--
-- Each type variable that has an entry in the substitution is replaced by
-- its image; variables without an entry are left alone. Whenever a
-- 'DForallT' is entered, every binder of its telescope is renamed to a fresh
-- name (via 'substForallTelescope') and the body is processed under the
-- resulting extended substitution, so binders in the input type cannot
-- capture variables occurring in the substitution's range.
substTy :: Quasi q => DSubst -> DType -> q DType
substTy vars ty =
    case ty of
      DForallT tele body -> do
        (vars', tele') <- substForallTelescope vars tele
        DForallT tele' <$> substTy vars' body
      DConstrainedT cxt body ->
        DConstrainedT <$> traverse recur cxt <*> recur body
      DAppT fun arg    -> DAppT <$> recur fun <*> recur arg
      DAppKindT t kind -> DAppKindT <$> recur t <*> recur kind
      DSigT t kind     -> DSigT <$> recur t <*> recur kind
      -- A variable is either rewritten by the substitution or kept as is.
      DVarT n          -> pure (M.findWithDefault ty n vars)
      -- Leaves that contain no type variables are returned unchanged.
      DConT _    -> pure ty
      DArrowT    -> pure ty
      DLitT _    -> pure ty
      DWildCardT -> pure ty
  where
    -- Recurse into a sub-term that lives in the same scope as 'ty'.
    recur = substTy vars

-- | Apply a substitution to a @forall@ telescope, visible or invisible.
--
-- Every binder in the telescope is renamed to a fresh name (see
-- 'substTyVarBndrs'). Returns the substitution extended with a mapping from
-- each old binder name to its replacement variable, paired with the renamed
-- telescope; the extended substitution is the one to use for anything in
-- the scope of these binders.
substForallTelescope :: Quasi q => DSubst -> DForallTelescope
                     -> q (DSubst, DForallTelescope)
substForallTelescope vars (DForallVis tvbs) =
  -- 'fmap' on a pair only touches its second component: the extended
  -- substitution passes through while the binders are rewrapped.
  fmap DForallVis <$> substTyVarBndrs vars tvbs
substForallTelescope vars (DForallInvis tvbs) =
  fmap DForallInvis <$> substTyVarBndrs vars tvbs

-- | Apply a substitution to a list of type-variable binders, renaming each
-- binder to a fresh name with 'substTvb'.
--
-- The substitution is threaded through the binders as the accumulator of
-- 'mapAccumLM', so each binder is handled under the substitution produced
-- by the binders processed before it (presumably left to right, as with
-- 'Data.List.mapAccumL' — confirm against 'mapAccumLM'). Returns the final
-- substitution together with the renamed binders.
substTyVarBndrs :: Quasi q => DSubst -> [DTyVarBndr flag]
                -> q (DSubst, [DTyVarBndr flag])
substTyVarBndrs vars tvbs = mapAccumLM substTvb vars tvbs

-- | Rename a single type-variable binder to a fresh name.
--
-- The fresh name is produced by 'qNewName' from the binder's base name.
-- A kind annotation, if present, is substituted using the incoming
-- substitution, i.e. without this binder's own renaming. Returns the
-- substitution extended to map the old name to the fresh variable, together
-- with the rewritten binder (its flag is carried over unchanged).
substTvb :: Quasi q => DSubst -> DTyVarBndr flag
         -> q (DSubst, DTyVarBndr flag)
substTvb vars tvb =
    case tvb of
      DPlainTV n flag -> do
        (vars', n') <- freshen n
        pure (vars', DPlainTV n' flag)
      DKindedTV n flag kind -> do
        (vars', n') <- freshen n
        -- The kind is outside this binder's own scope, so it sees 'vars'.
        kind' <- substTy vars kind
        pure (vars', DKindedTV n' flag kind')
  where
    -- Draw a fresh name and record the renaming in the substitution.
    freshen old = do
      new <- qNewName (nameBase old)
      pure (M.insert old (DVarT new) vars, new)

-- | Computes the union of two substitutions. Fails (returns 'Nothing') if
-- both substitutions map the same variable to different types; otherwise
-- returns 'Just' the combined substitution. Since the two inputs agree on
-- every shared variable, the left bias of 'M.union' is irrelevant.
--
-- Agreement on a shared variable is decided with '==' on 'DType'.
-- NOTE(review): presumably that equality is structural, so two
-- alpha-equivalent types that differ only in bound-variable names would be
-- reported as a conflict — confirm against the 'DType' instance.
unionSubsts :: DSubst -> DSubst -> Maybe DSubst
unionSubsts a b
    -- 'M.intersectionWith' visits exactly the shared keys, pairing their
    -- values, so agreement is checked without partial 'M.!' lookups.
  | and (M.intersectionWith (==) a b) = Just (a `M.union` b)
  | otherwise                         = Nothing

---------------------------

-- Matching


-- | Whether 'matchTy' should look through kind annotations that appear in
-- the type template.
data IgnoreKinds
  = YesIgnore  -- ^ Drop a template kind signature and match the type under it.
  | NoIgnore   -- ^ Fail the match if a template kind signature is reached.

-- | @matchTy ign tmpl targ@ matches a type template @tmpl@ against a type
-- target @targ@. On success it returns a substitution from the names of the
-- type variables in the template to types; if the two do not match up it
-- returns @Nothing@.
--
-- Kind signatures on the target are always looked through. Kind signatures
-- in the template are governed by @ign@. A signature there might mean that
-- a type variable has a more restrictive kind than otherwise possible, and
-- mapping that variable to a type of a different kind could be disastrous.
--
-- * With 'NoIgnore', this function returns @Nothing@ if the template has a
--   signature anywhere.
--
-- * With 'YesIgnore', the signature is dropped and only the type beneath it
--   is matched. The returned map may then be ill-kinded, and type variables
--   that occur only inside an ignored kind are not mapped. Use at your own
--   risk.
--
-- Apart from such ignored kinds, every type variable mentioned in the
-- template is mapped by the substitution returned in the @Just@ case.
--
-- A template variable matches any target, including a @forall@ type.
-- Otherwise, having to match a 'DForallT' in either the template or the
-- target calls 'error' instead of returning @Nothing@. Any combination not
-- handled below (e.g. a 'DConstrainedT', 'DAppKindT' or 'DWildCardT'
-- template) yields @Nothing@.
matchTy :: IgnoreKinds -> DType -> DType -> Maybe DSubst
matchTy ign = go
  where
    go (DVarT var_name) targ = Just (M.singleton var_name targ)
    -- If a pattern has a kind signature, it's really easy to get this
    -- wrong, so only look through it when the caller opted in.
    go (DSigT tmpl _ki) targ
      | YesIgnore <- ign = go tmpl targ
      | otherwise        = Nothing
    -- Kind signatures on the target, however, can safely be ignored.
    go tmpl (DSigT targ _ki) = go tmpl targ
    go (DForallT {}) _ = error "Cannot match a forall in a pattern"
    go _ (DForallT {}) = error "Cannot match a forall in a target"
    go (DAppT tmpl1 tmpl2) (DAppT targ1 targ2) =
      unionMaybeSubsts [go tmpl1 targ1, go tmpl2 targ2]
    go (DConT tmpl_con) (DConT targ_con)
      | tmpl_con == targ_con = Just M.empty
    go DArrowT DArrowT = Just M.empty
    go (DLitT tmpl_lit) (DLitT targ_lit)
      | tmpl_lit == targ_lit = Just M.empty
    go _ _ = Nothing

-- | Combine a list of optional substitutions with 'unionSubsts', starting
-- from the empty substitution and folding strictly from the left.
--
-- Returns 'Nothing' if any element is 'Nothing' or if two substitutions map
-- the same variable to different types. Once the running result becomes
-- 'Nothing', later elements are not inspected (the list spine is still
-- walked to its end by the fold).
unionMaybeSubsts :: [Maybe DSubst] -> Maybe DSubst
unionMaybeSubsts = L.foldl' step (Just M.empty)
  where
    step :: Maybe DSubst -> Maybe DSubst -> Maybe DSubst
    step acc next = acc >>= \so_far -> next >>= unionSubsts so_far