module TypeLet.Plugin.Substitution (
    letsToSubst
  , Cycle(..)
  , formatLetCycle
  ) where

import Data.Bifunctor
import Data.Foldable (toList)
import Data.List (intersperse)
import Data.List.NonEmpty (NonEmpty(..))
import Data.Maybe (mapMaybe)

import qualified Data.Graph as G

import TypeLet.Plugin.Constraints
import TypeLet.Plugin.GhcTcPluginAPI

-- | Construct an idempotent substitution from type-level let bindings
--
-- The bindings are first put in dependency order by 'inorder', which returns
-- a 'Cycle' instead if the bindings are recursive. We then walk over them,
-- keeping an accumulated list of assignments. Whenever a new assignment
-- @x := t@ is added, it is first applied to every assignment already
-- accumulated.
--
-- Because 'inorder' places every binding before the bindings it mentions,
-- @t@ never refers to a variable that is already in the accumulator. Every
-- variable it does refer to is substituted away when its own binding is
-- added later. This is what makes the final substitution idempotent.
--
-- TODO: Not entirely sure if this is correct, might be too simplistic and/or an
-- abuse of the ghc API; it is a /whole/ lot simpler than @niFixTCvSubst@, which
-- is disconcerting. However, it seems to work for the examples so far; perhaps
-- our use case is simpler? Needs more thought.
letsToSubst ::
     [GenLocated CtLoc CLet]
  -> Either (Cycle (GenLocated CtLoc CLet)) TCvSubst
letsToSubst = fmap (uncurry zipTvSubst . unzip . go []) . inorder
  where
    -- Walk the ordered assignments. @acc@ holds the assignments seen so far;
    -- each new assignment @x := t@ is substituted into all of them before
    -- being added to the accumulator itself.
    go :: [(TyVar, Type)] -> [(TyVar, Type)] -> [(TyVar, Type)]
    go acc []         = acc
    go acc ((x, t):s) = go ((x, t) : map (second (subst1 x t)) acc) s

    -- Replace the single type variable @x@ by @t@.
    subst1 :: TyVar -> Type -> Type -> Type
    subst1 x t = substTyWith [x] [t]

-- | Order the assignments
--
-- Suppose we have two assignments
--
-- > x := xT
-- > y := yT    where  x in (freevars yT)
--
-- Then the substitution should map @y@ to @(x := xT) yT@. We do this by
-- constructing the substitution in order, adding assignments one by one,
-- applying them to all assignments already in the accumulated substitution
-- as we go. In this example, this means adding @y := yT@ /first/, so that
-- we can apply @x := xT@ later.
--
-- Recursive definitions cannot be ordered this way. They are detected up
-- front, and the first cycle found is returned as a 'Left'.
--
-- In order to find the right order, we construct a graph of assignments. To
-- continue with our example, this graph will contain an edge
--
-- > (y := yT) -----> (x := xT)
--
-- The required assignment ordering is then obtained by topological sort.
inorder ::
     [GenLocated CtLoc CLet]
  -> Either (Cycle (GenLocated CtLoc CLet)) [(TyVar, Type)]
inorder lets =
    case cycles edges of
      c:_ -> Left c
      []  -> Right [
          (x, t)
        | (L _ (CLet _ x t), _, _) <- map nodeFromVertex $ G.topSort graph
        ]
  where
    graph          :: G.Graph
    nodeFromVertex :: G.Vertex -> (GenLocated CtLoc CLet, TyVar, [TyVar])
    _vertexFromKey :: TyVar -> Maybe G.Vertex
    (graph, nodeFromVertex, _vertexFromKey) = G.graphFromEdges edges

    -- One node per binding @y := yT@, keyed on @y@. It has an edge to every
    -- let-bound variable @x@ that occurs free in @yT@.
    edges :: [(GenLocated CtLoc CLet, TyVar, [TyVar])]
    edges = [
        ( l
        , y
        , [ x
          | L _ (CLet _ x _) <- lets
          , x `elemVarSet` tyCoVarsOfType yT
          ]
        )
      | l@(L _ (CLet _ y yT)) <- lets -- variable names match the description above
      ]

-- | Format a cycle
--
-- We (arbitrarily) pick the first 'CLet' in the cycle for the location of the
-- error. The message lists every binding in the cycle, separated by commas.
formatLetCycle ::
     Cycle (GenLocated CtLoc CLet)
  -> GenLocated CtLoc TcPluginErrorMessage
formatLetCycle (Cycle vs@(L l _ :| _)) = L l $
          Txt "Cycle in type-level let bindings: "
      :|: ( -- 'foldr1' is safe here: the list comes from a 'NonEmpty', and
            -- 'intersperse' never turns a non-empty list into an empty one.
            foldr1 (:|:)
          . intersperse (Txt ", ")
          . map (\(L _ l') -> formatCLet l')
          $ toList vs
          )

{-------------------------------------------------------------------------------
  Auxiliary
-------------------------------------------------------------------------------}

-- | Cycle in a graph
--
-- The nodes of one cyclic strongly connected component. The list is
-- non-empty by construction.
data Cycle a = Cycle (NonEmpty a)

-- | All cycles in a graph
--
-- The graph is given in the format expected by 'G.stronglyConnComp': each
-- node paired with its key and the keys of the nodes it has edges to. Every
-- cyclic strongly connected component becomes one 'Cycle'. Nodes that are
-- not part of any cycle are dropped.
cycles :: Ord key => [(node, key, [key])] -> [Cycle node]
cycles = mapMaybe aux . G.stronglyConnComp
  where
    aux :: G.SCC a -> Maybe (Cycle a)
    aux (G.AcyclicSCC _) = Nothing
    aux (G.CyclicSCC vs) =
        case vs of
          v:vs'      -> Just $ Cycle (v :| vs')
          -- A cyclic SCC is never empty; this case only keeps 'aux' total.
          _otherwise -> Nothing