-----------------------------------------------------------------------------
-- |
-- Module      :  Language.Haskell.TH.Desugar.Subst.Capturing
-- Copyright   :  (C) 2024 Ryan Scott
-- License     :  BSD-style (see LICENSE)
-- Maintainer  :  Ryan Scott
-- Stability   :  experimental
-- Portability :  non-portable
--
-- Substitutions on 'DType's that do /not/ avoid capture. (For capture-avoiding
-- substitution functions, use "Language.Haskell.TH.Desugar.Subst" instead.)
--
----------------------------------------------------------------------------

module Language.Haskell.TH.Desugar.Subst.Capturing (
  DSubst,

  -- * Non–capture-avoiding substitution
  substTy, substForallTelescope, substTyVarBndrs, substTyVarBndr,
  unionSubsts, unionMaybeSubsts,

  -- * Matching a type template against a type
  IgnoreKinds(..), matchTy
  ) where

import Data.Bifunctor (second)
import qualified Data.List as L
import qualified Data.Map as M

import Language.Haskell.TH.Desugar.AST
import Language.Haskell.TH.Desugar.Subst
  (DSubst, unionSubsts, unionMaybeSubsts, IgnoreKinds(..), matchTy)

-- | Non–capture-avoiding substitution on 'DType's. Unlike the @substTy@
-- function in "Language.Haskell.TH.Desugar.Subst", this 'substTy' function is
-- pure, as it never needs to create fresh names.
--
-- Bound variables are never renamed, so a type in the range of the
-- substitution may have its free variables captured by a binder in the type
-- being substituted into. Use the capture-avoiding variant when that matters.
substTy :: DSubst -> DType -> DType
-- Fast path: an empty substitution cannot change anything, so the type is
-- returned as is without being traversed.
substTy subst ty | M.null subst = ty
substTy subst (DForallT tele inner_ty)
  = DForallT tele' inner_ty'
  where
    -- Substituting through the telescope yields a narrowed substitution
    -- (the telescope's binder names are removed), which is what gets applied
    -- to the body of the forall.
    (subst', tele') = substForallTelescope subst tele
    inner_ty'       = substTy subst' inner_ty
substTy subst (DConstrainedT cxt inner_ty) =
  DConstrainedT (map (substTy subst) cxt) (substTy subst inner_ty)
substTy subst (DAppT ty1 ty2) = substTy subst ty1 `DAppT` substTy subst ty2
substTy subst (DAppKindT ty ki) = substTy subst ty `DAppKindT` substTy subst ki
substTy subst (DSigT ty ki) = substTy subst ty `DSigT` substTy subst ki
-- A variable is replaced by its image under the substitution, or left alone
-- when it is not in the substitution's domain.
substTy subst (DVarT n) = M.findWithDefault (DVarT n) n subst
-- The remaining forms contain no type variables and are returned unchanged.
substTy _ ty@(DConT {}) = ty
substTy _ ty@DArrowT    = ty
substTy _ ty@(DLitT {}) = ty
substTy _ ty@DWildCardT = ty

-- | Non–capture-avoiding substitution on 'DForallTelescope's. This returns a
-- pair containing the new 'DSubst' as well as a new 'DForallTelescope' value,
-- where the kinds have been substituted.
--
-- Both the invisible and the visible telescope are handled the same way: the
-- binders are processed by 'substTyVarBndrs' and the result is rewrapped in
-- the constructor it came from.
substForallTelescope :: DSubst -> DForallTelescope -> (DSubst, DForallTelescope)
substForallTelescope s (DForallInvis tvbs) = second DForallInvis $ substTyVarBndrs s tvbs
substForallTelescope s (DForallVis   tvbs) = second DForallVis   $ substTyVarBndrs s tvbs

-- | Non–capture-avoiding substitution on a telescope of 'DTyVarBndr's. This
-- returns a pair containing the new 'DSubst' as well as a new telescope of
-- 'DTyVarBndr's, where the kinds have been substituted.
--
-- The substitution is threaded through the binders from left to right, so
-- each binder is processed with the substitution produced by the binders
-- before it.
substTyVarBndrs :: DSubst -> [DTyVarBndr flag] -> (DSubst, [DTyVarBndr flag])
substTyVarBndrs = L.mapAccumL substTyVarBndr

-- | Non–capture-avoiding substitution on a 'DTyVarBndr'. This updates the
-- 'DSubst' to remove the 'DTyVarBndr' name from the domain (as that name is now
-- bound by the 'DTyVarBndr') and applies the substitution to the kind of the
-- 'DTyVarBndr'.
--
-- Note that a kind annotation is substituted using the incoming substitution,
-- i.e. before the binder's own name is removed from it.
substTyVarBndr :: DSubst -> DTyVarBndr flag -> (DSubst, DTyVarBndr flag)
-- A binder without a kind annotation has nothing to substitute into.
substTyVarBndr s tvb@(DPlainTV n _) = (M.delete n s, tvb)
substTyVarBndr s (DKindedTV n f k)  = (M.delete n s, DKindedTV n f (substTy s k))