{-# LANGUAGE DataKinds #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE TypeFamilyDependencies #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE RecursiveDo #-}
{-# LANGUAGE OverloadedLabels #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS -Wno-orphans #-}
module Codec.Candid.TypTable where

import qualified Data.Map as M
import Control.Monad.State.Lazy
import Data.Void
import Data.Text.Prettyprint.Doc
import Data.DList (singleton, DList)
import Data.Graph
import Data.Foldable

import Codec.Candid.Types

-- | A self-contained description of a sequence of Candid types.
--
-- The map is the type table: named type definitions keyed by @k@. The list
-- holds the described types, which may mention table entries through their
-- keys. The key type is existentially hidden; only its 'Pretty' and 'Ord'
-- instances travel with the value.
data SeqDesc where
    SeqDesc
        :: forall k. (Pretty k, Ord k)
        => M.Map k (Type k)  -- type table: key -> definition
        -> [Type k]          -- the described types
        -> SeqDesc

-- | Renders a 'SeqDesc' as a pair: the type table as an association list
-- (ascending key order, as produced by 'M.toList') and the described types.
instance Pretty SeqDesc where
    pretty (SeqDesc m ts) = pretty (M.toList m, ts)

-- | A node of a type graph: a key naming the node, paired with one layer of
-- structure @f@ whose children are again 'Ref' nodes.
--
-- NOTE(review): both fields are lazy. Presumably this lets callers build
-- cyclic graphs by tying the knot ('buildSeqDesc' guards against revisiting
-- keys); confirm before making them strict.
data Ref k f
    = Ref
        k               -- key identifying this node
        (f (Ref k f))   -- one layer of structure over child nodes

-- | Flattens types whose references are 'Ref' nodes into a 'SeqDesc'.
--
-- The reachable nodes are visited depth-first. Each key is entered into the
-- type table once, and every 'Ref' is replaced by its key. A key already in
-- the table is not visited again, so the traversal terminates on graphs
-- that revisit a key (for example, cyclic ones).
buildSeqDesc :: forall k. (Pretty k, Ord k) => [Type (Ref k Type)] -> SeqDesc
buildSeqDesc ts = SeqDesc table ts'
  where
    (ts', table) = runState (traverse (traverse visit) ts) M.empty

    -- Record this node (and, transitively, its children) in the table and
    -- return the key that replaces it.
    visit :: Ref k Type -> State (M.Map k (Type k)) k
    visit (Ref k t) = do
        seen <- gets (M.member k)
        unless seen $ mdo
            -- Insert the entry *before* traversing the children, so that a
            -- path leading back to k finds it already seen. The entry's
            -- value t' is only produced by the traversal below; mdo lets us
            -- refer to it early.
            modify (M.insert k t')
            t' <- traverse visit t
            return ()
        return k

-- | Replaces type-table entries that lie on a cycle of references by 'EmptyT'.
--
-- The graph has one node per table entry. Its edges go to the keys that an
-- entry mentions directly or through (nested) record fields, as computed by
-- 'underRec'. An entry on such a cycle can only be built from itself, as in
-- @type t = record { x : t }@, so it is replaced by 'EmptyT'.
-- The described types themselves are left unchanged.
voidEmptyTypes :: SeqDesc -> SeqDesc
voidEmptyTypes (SeqDesc (m :: M.Map k (Type k)) ts) = SeqDesc m' ts
  where
    edges :: [(k, k, [k])]
    edges = [ (k, k, toList (underRec t)) | (k, t) <- M.toList m ]

    -- Keys on some cycle. 'stronglyConnComp' reports a single key with an
    -- edge to itself as a 'CyclicSCC' too, so direct self-reference counts.
    bad :: [k]
    bad = concat [ ks | CyclicSCC ks <- stronglyConnComp edges ]

    m' :: M.Map k (Type k)
    m' = foldl' (\acc key -> M.insert key EmptyT acc) m bad


-- | The keys referenced by a type without going through any constructor
-- other than a record.
--
-- A 'RefT' contributes its own key. A 'RecT' contributes the keys of all its
-- field types, recursively. Every other type contributes nothing.
underRec :: Type k -> DList k
underRec (RefT x)  = singleton x
underRec (RecT fs) = foldMap (underRec . snd) fs
underRec _         = mempty

-- | Resolves every reference in the described types against the type table,
-- producing closed types. Recursive definitions become cyclic (knot-tied)
-- values.
--
-- This relies on laziness: 'm'' is defined in terms of itself (through 'f').
-- That only works because "Data.Map" is the lazy map API, so 'fmap' does not
-- force the resolved entries while 'm'' is being built.
--
-- Calls 'error', naming the key, if a type mentions a key that has no table
-- entry.
tieKnot :: SeqDesc -> [Type Void]
tieKnot (SeqDesc m (ts :: [Type k])) = ts'
  where
    f :: k -> Type Void
    f k = M.findWithDefault (missing k) k m'

    missing :: k -> Type Void
    missing k = error $
        "Codec.Candid.TypTable.tieKnot: no type table entry for "
            ++ show (pretty k :: Doc ())

    m' :: M.Map k (Type Void)
    m' = (>>= f) <$> m

    ts' :: [Type Void]
    ts' = (>>= f) <$> ts