module TypeLet.Plugin.Substitution (
    letsToSubst
  , Cycle(..)
  , formatLetCycle
  ) where

import Data.Bifunctor
import Data.Foldable (toList)
import Data.List (intersperse)
import Data.List.NonEmpty (NonEmpty(..))
import Data.Maybe (mapMaybe)

import qualified Data.Graph as G

import TypeLet.Plugin.Constraints
import TypeLet.Plugin.GhcTcPluginAPI

-- | Construct idempotent substitution
--
-- The let bindings are first put in dependency order by 'inorder', which
-- fails with a 'Cycle' if the bindings refer to each other cyclically.
-- The assignments are then added one at a time. Each new assignment
-- @x := t@ is also applied to every assignment accumulated so far. The
-- aim is that no right-hand side in the result still mentions a let-bound
-- variable, i.e. that the substitution is idempotent.
--
-- TODO: Not entirely sure if this is correct, might be too simplistic and/or an
-- abuse of the ghc API; it is a /whole/ lot simpler than @niFixTCvSubst@, which
-- is disconcerting. However, it seems to work for the examples so far; perhaps
-- our use case is simpler? Needs more thought.
letsToSubst ::
     [GenLocated CtLoc CLet]
  -> Either (Cycle (GenLocated CtLoc CLet)) TCvSubst
letsToSubst lets = do
    ordered <- inorder lets
    let (tvs, tys) = unzip (compose [] ordered)
    pure (zipTvSubst tvs tys)
  where
    -- Walk the ordered assignments. The next assignment @x := t@ is pushed
    -- onto the accumulator after @x := t@ has been substituted into the
    -- right-hand side of every assignment already accumulated.
    compose :: [(TyVar, Type)] -> [(TyVar, Type)] -> [(TyVar, Type)]
    compose done []              = done
    compose done ((x, t) : rest) =
        let applyX = second (substTyWith [x] [t])
        in  compose ((x, t) : map applyX done) rest

-- | Order the assignments
--
-- Suppose we have two assignments
--
-- > x := xT
-- > y := yT    where  x in (freevars yT)
--
-- Then the substitution should map @y@ to @(x := xT) yT@. We do this by
-- constructing the substitution in order, adding assignments one by one,
-- applying them to all assignments already in the accumulated substitution
-- as we go. In this example, this means adding @y := yT@ /first/, so that
-- we can apply @x := xT@ later.
--
-- In order to find the right order, we construct a graph of assignments. To
-- continue with our example, this graph will contain an edge
--
-- > (y := yT) -----> (x := xT)
--
-- The required assignment ordering is then obtained by topological sort.
-- Recursive (cyclic) definitions cannot be ordered this way. If the graph
-- contains a cycle, the first cycle found is returned as 'Left' instead.
inorder ::
     [GenLocated CtLoc CLet]
  -> Either (Cycle (GenLocated CtLoc CLet)) [(TyVar, Type)]
inorder lets =
    case cycles edges of
      c:_ -> Left c
      []  -> Right [
          (x, t)
        | (L _ (CLet _ x t), _, _) <- map nodeFromVertex (G.topSort graph)
        ]
  where
    graph          :: G.Graph
    nodeFromVertex :: G.Vertex -> (GenLocated CtLoc CLet, TyVar, [TyVar])
    (graph, nodeFromVertex, _) = G.graphFromEdges edges

    -- All variables bound by some let, in the order the lets were given.
    -- Computed once and shared by every node below.
    boundVars :: [TyVar]
    boundVars = [ x | L _ (CLet _ x _) <- lets ]

    -- One node per binding @y := yT@ (names match the description above),
    -- keyed by @y@, with an edge to every let-bound variable free in @yT@.
    -- The free variables of @yT@ are computed once per binding, not once
    -- per (binding, candidate) pair.
    edges :: [(GenLocated CtLoc CLet, TyVar, [TyVar])]
    edges = [
        (l, y, [ x | x <- boundVars, x `elemVarSet` fvs ])
      | l@(L _ (CLet _ y yT)) <- lets
      , let fvs = tyCoVarsOfType yT
      ]

-- | Format a cycle
--
-- The message lists every binding in the cycle, each rendered with
-- 'formatCLet' and separated by @", "@. We (arbitrarily) pick the location
-- of the first 'CLet' in the cycle as the location of the error.
formatLetCycle ::
     Cycle (GenLocated CtLoc CLet)
  -> GenLocated CtLoc TcPluginErrorMessage
formatLetCycle (Cycle bindings@(L loc _ :| _)) = L loc msg
  where
    -- Header followed by the right-nested concatenation of the bindings.
    msg :: TcPluginErrorMessage
    msg = Txt "Cycle in type-level let bindings: " :|: foldr1 (:|:) separated

    -- Formatted bindings with a comma separator between neighbours; never
    -- empty, because the cycle is a non-empty list.
    separated :: [TcPluginErrorMessage]
    separated =
        intersperse (Txt ", ") [ formatCLet clet | L _ clet <- toList bindings ]

{-------------------------------------------------------------------------------
  Auxiliary
-------------------------------------------------------------------------------}

-- | Cycle in a graph: the non-empty list of nodes forming one cyclic
-- strongly connected component, as produced by 'cycles'.
--
-- Declared as a @newtype@ because it wraps a single field. This keeps the
-- same constructor and pattern interface with no runtime wrapper.
newtype Cycle a = Cycle (NonEmpty a)

-- | All cycles in a graph, given in the same @(node, key, [key])@ adjacency
-- format accepted by 'G.stronglyConnComp'.
--
-- Acyclic components are discarded. Every cyclic strongly connected
-- component with at least one node is returned as a 'Cycle'.
cycles :: Ord key => [(node, key, [key])] -> [Cycle node]
cycles adjacency = mapMaybe asCycle (G.stronglyConnComp adjacency)
  where
    -- Keep only non-empty cyclic components; everything else is dropped.
    asCycle :: G.SCC a -> Maybe (Cycle a)
    asCycle (G.CyclicSCC (v:vs)) = Just (Cycle (v :| vs))
    asCycle _                    = Nothing