-- | Type-indexed tree structures.
--
-- A 'Struct' is a tree of elements @e t@ whose shape (and the type index of
-- every leaf) is tracked on the type level by a promoted 'Tree'.  Positions
-- inside a structure are addressed by /paths/: type-level lists of 'Nat's
-- that are represented at runtime as @'List' 'Natural' idx@.
--
-- The path-based functions below only have equations for paths that fit the
-- shape of the structure, mirroring the equations of the corresponding type
-- families.
--
-- NOTE(review): this file carries no LANGUAGE pragmas although it relies on
-- GADT syntax, promoted data kinds, closed type families and rank-2 types;
-- presumably the extensions are enabled by the build configuration - confirm.
module Language.SMTLib2.Internals.Type.Struct where

import Language.SMTLib2.Internals.Type.Nat
import Language.SMTLib2.Internals.Type.List (List(..))
import qualified Language.SMTLib2.Internals.Type.List as List

import Prelude hiding (mapM,insert)
import Data.GADT.Compare
import Data.GADT.Show
import Data.Functor.Identity

-- | Shape of a structure: either a single leaf carrying an @a@ or a node
--   with a list of sub-trees.  It is used in promoted form as the index of
--   'Struct'.
data Tree a
  = Leaf a
  | Node [Tree a]

-- | A structure of elements @e@ whose shape is described by @tp@.
data Struct e tp where
  -- | A structure consisting of exactly one element.
  Singleton :: e t -> Struct e (Leaf t)
  -- | A structure made up of a list of sub-structures.
  Struct :: List (Struct e) ts -> Struct e (Node ts)

-- | The sub-tree found by following the path @idx@ through @struct@.
--   The empty path denotes the tree itself.
type family Index (struct :: Tree a) (idx :: [Nat]) :: Tree a where
  Index x '[] = x
  Index (Node xs) (n ': ns) = Index (List.Index xs n) ns

-- | The leaf content found by following the path @idx@ through @struct@.
--   Only reduces when the path ends exactly at a 'Leaf'.
type family ElementIndex (struct :: Tree a) (idx :: [Nat]) :: a where
  ElementIndex (Leaf x) '[] = x
  ElementIndex (Node xs) (n ': ns) = ElementIndex (List.Index xs n) ns

-- | The tree obtained by replacing the sub-tree at path @idx@ with @el@.
--   Replacing at the empty path yields @el@ itself.
type family Insert (struct :: Tree a) (idx :: [Nat]) (el :: Tree a) :: Tree a where
  Insert x '[] y = y
  Insert (Node xs) (n ': ns) y
    = Node (List.Insert xs n (Insert (List.Index xs n) ns y))

-- | The tree obtained by applying 'List.Remove' to the child list that holds
--   the sub-tree addressed by @idx@.  There is no equation for the empty
--   path.
type family Remove (struct :: Tree a) (idx :: [Nat]) :: Tree a where
  Remove (Node xs) '[n] = Node (List.Remove xs n)
  Remove (Node xs) (n1 ': n2 ': ns)
    = Node (List.Insert xs n1 (Remove (List.Index xs n1) (n2 ': ns)))

-- | Number of leaves of a tree: one for a 'Leaf', the sum over all children
--   for a 'Node'.
type family Size (struct :: Tree a) :: Nat where
  Size (Leaf x) = S Z
  Size (Node '[]) = Z
  Size (Node (x ': xs)) = (Size x) + (Size (Node xs))

-- | Run a monadic action on the element at the given path.  The action
--   yields a result together with a replacement element of the same type;
--   the result and the structure containing the replacement are returned.
access :: Monad m => Struct e tp -> List Natural idx
       -> (e (ElementIndex tp idx) -> m (a,e (ElementIndex tp idx)))
       -> m (a,Struct e tp)
access (Singleton el) Nil act =
  act el >>= \(out, el') -> return (out, Singleton el')
access (Struct children) (i ::: path) act =
  List.access' children i (\child -> access child path act)
    >>= \(out, children') -> return (out, Struct children')

-- | Like 'access', but the replacement element may have a different type
--   index, which is reflected in the shape of the resulting structure.
accessElement :: Monad m => Struct e tp -> List Natural idx
              -> (e (ElementIndex tp idx) -> m (a,e ntp))
              -> m (a,Struct e (Insert tp idx (Leaf ntp)))
accessElement (Singleton el) Nil act =
  act el >>= \(out, el') -> return (out, Singleton el')
accessElement (Struct children) (i ::: path) act =
  List.access children i (\child -> accessElement child path act)
    >>= \(out, children') -> return (out, Struct children')

-- | Extract the sub-structure at the given path.
index :: Struct e tp -> List Natural idx -> Struct e (Index tp idx)
index struct Nil = struct
index (Struct children) (i ::: path) = index (List.index children i) path

-- | Extract the element stored in the leaf at the given path.
elementIndex :: Struct e tp -> List Natural idx -> e (ElementIndex tp idx)
elementIndex (Singleton el) Nil = el
elementIndex (Struct children) (i ::: path) =
  elementIndex (List.index children i) path

-- | Replace the sub-structure at the given path with a new one.  With the
--   empty path the new structure is returned unchanged.
insert :: Struct e tps -> List Natural idx -> Struct e tp
       -> Struct e (Insert tps idx tp)
insert _ Nil new = new
insert (Struct children) (i ::: path) new =
  Struct (List.insert children i (insert (List.index children i) path new))

-- | Remove the sub-structure at the given (non-empty) path by calling
--   'List.remove' on the child list that contains it.
remove :: Struct e tps -> List Natural idx -> Struct e (Remove tps idx)
remove (Struct children) (i ::: Nil) = Struct (List.remove children i)
remove (Struct children) (i ::: j ::: path) =
  Struct (List.insert children i (remove (List.index children i) (j ::: path)))

-- | Apply a monadic, index-preserving function to every element of a
--   structure, keeping its shape.
mapM :: Monad m => (forall x. e x -> m (e' x))
     -> Struct e tps -> m (Struct e' tps)
mapM f (Singleton el) = f el >>= return . Singleton
mapM f (Struct children) = List.mapM (mapM f) children >>= return . Struct

-- | Like 'mapM', but the function additionally receives the path of the
--   element it is applied to.
mapIndexM :: Monad m
          => (forall idx. List Natural idx
              -> e (ElementIndex tps idx)
              -> m (e' (ElementIndex tps idx)))
          -> Struct e tps -> m (Struct e' tps)
mapIndexM f (Singleton el) = f Nil el >>= return . Singleton
mapIndexM f (Struct children) =
  -- Prefix the position within this node to the paths seen by @f@.
  List.mapIndexM (\i -> mapIndexM (\path -> f (i ::: path))) children
    >>= return . Struct

-- | Pure variant of 'mapM', implemented by running it in 'Identity'.
map :: (forall x. e x -> e' x) -> Struct e tps -> Struct e' tps
map f struct = runIdentity (mapM (Identity . f) struct)

-- | The number of elements (leaves) of a structure as a type-level-indexed
--   natural number.
size :: Struct e tps -> Natural (Size tps)
size (Singleton _) = Succ Zero
size (Struct Nil) = Zero
size (Struct (child ::: rest)) = naturalAdd (size child) (size (Struct rest))

-- | Collapse a structure into a single value.  Elements are converted with
--   the first function; at every node the values obtained for its children
--   (collected via 'List.toList') are combined with the second function.
flatten :: Monad m => (forall x. e x -> m a) -> ([a] -> m a)
        -> Struct e tps -> m a
flatten f _ (Singleton el) = f el
flatten f g (Struct children) = List.toList (flatten f g) children >>= g

-- | Like 'flatten', but the element function additionally receives the path
--   of the element it is applied to.
flattenIndex :: Monad m
             => (forall idx. List Natural idx
                 -> e (ElementIndex tps idx) -> m a)
             -> ([a] -> m a)
             -> Struct e tps -> m a
flattenIndex f _ (Singleton el) = f Nil el
flattenIndex f g (Struct children) =
  List.toListIndex
    (\i child -> flattenIndex (\path -> f (i ::: path)) g child)
    children
    >>= g

-- | Combine two structures of the same shape element by element.
zipWithM :: Monad m => (forall x. e1 x -> e2 x -> m (e3 x))
         -> Struct e1 tps -> Struct e2 tps -> m (Struct e3 tps)
zipWithM f (Singleton l) (Singleton r) = f l r >>= return . Singleton
zipWithM f (Struct ls) (Struct rs) =
  List.zipWithM (zipWithM f) ls rs >>= return . Struct

-- | Collapse two structures of the same shape into a single value.
--   Corresponding elements are combined with the first function; at every
--   node the values for the children (collected via 'List.zipToListM') are
--   combined with the second function.
zipFlatten :: Monad m => (forall x. e1 x -> e2 x -> m a)
           -> ([a] -> m a)
           -> Struct e1 tps -> Struct e2 tps -> m a
zipFlatten f _ (Singleton l) (Singleton r) = f l r
zipFlatten f g (Struct ls) (Struct rs) =
  List.zipToListM (zipFlatten f g) ls rs >>= g

-- | Equality of two structures of the same shape, decided element-wise
--   through 'geq'.
instance GEq e => Eq (Struct e tps) where
  Singleton l == Singleton r = case geq l r of
    Just Refl -> True
    Nothing -> False
  Struct ls == Struct rs = ls == rs

-- | Heterogeneous equality: a 'Singleton' is never equal to a 'Struct'.
instance GEq e => GEq (Struct e) where
  geq (Singleton l) (Singleton r) = case geq l r of
    Just Refl -> Just Refl
    Nothing -> Nothing
  geq (Struct ls) (Struct rs) = case geq ls rs of
    Just Refl -> Just Refl
    Nothing -> Nothing
  geq _ _ = Nothing

-- | Ordering of two structures of the same shape, decided through
--   'gcompare' on the elements.
instance GCompare e => Ord (Struct e tps) where
  compare (Singleton l) (Singleton r) = case gcompare l r of
    GEQ -> EQ
    GLT -> LT
    GGT -> GT
  compare (Struct ls) (Struct rs) = compare ls rs

-- | Heterogeneous ordering: every 'Singleton' is smaller than every
--   'Struct'.
instance GCompare e => GCompare (Struct e) where
  gcompare (Singleton l) (Singleton r) = case gcompare l r of
    -- The result is re-built because the type indices differ
    -- (@t@ for the elements versus @Leaf t@ for the structures).
    GEQ -> GEQ
    GLT -> GLT
    GGT -> GGT
  gcompare (Singleton _) _ = GLT
  gcompare _ (Singleton _) = GGT
  gcompare (Struct ls) (Struct rs) = case gcompare ls rs of
    GEQ -> GEQ
    GLT -> GLT
    GGT -> GGT

-- | A 'Singleton' is shown as its element alone; a 'Struct' is shown as the
--   constructor name followed by its children.
instance GShow e => Show (Struct e tps) where
  showsPrec p (Singleton el) = gshowsPrec p el
  showsPrec p (Struct children) =
    showParen (p > 10) (showString "Struct " . showsPrec 11 children)

-- | Delegates to the 'Show' instance.
instance GShow e => GShow (Struct e) where
  gshowsPrec = showsPrec