-- Terms and substitutions, implemented using flatterms.
-- This module contains all the low-level icky bits
-- and provides primitives for building higher-level stuff.
{-# LANGUAGE BangPatterns, CPP, PatternGuards, PatternSynonyms, ViewPatterns, RecordWildCards, GeneralizedNewtypeDeriving, RankNTypes, MagicHash, UnboxedTuples, MultiParamTypeClasses, FlexibleInstances, FunctionalDependencies, ScopedTypeVariables #-}
module Twee.Term.Core where

#include "errors.h"
import Data.Primitive
import Control.Monad.ST.Strict
import Data.Bits
import Data.Int
import GHC.Types(Int(..))
import GHC.Prim
import GHC.ST hiding (liftST)
import Data.Ord

--------------------------------------------------------------------------------
-- Symbols. A symbol is a single function or variable in a flatterm.
--------------------------------------------------------------------------------

-- The decoded form of one flatterm cell.
data Symbol =
  Symbol {
    -- True for a function symbol, False for a variable.
    isFun :: Bool,
    -- The number identifying the function or variable.
    index :: Int,
    -- The size of the term rooted at this symbol.
    size :: Int }

-- Functions are shown as the function followed by "=" and the size;
-- variables are shown as just the variable.
instance Show Symbol where
  show (Symbol fun idx sz)
    | fun = show (MkFun idx) ++ "=" ++ show sz
    | otherwise = show (MkVar idx)

-- Convert symbols to/from Int64 for storage in flatterms.
-- The encoding:
--   * bits 0-30: size
--   * bit 31: 0 (variable) or 1 (function)
--   * bits 32-63: index

-- Decode a stored Int64 into its three components.
{-# INLINE toSymbol #-}
toSymbol :: Int64 -> Symbol
toSymbol w = Symbol { isFun = funBit, index = idx, size = sz }
  where
    funBit = testBit w 31
    idx = fromIntegral (w `unsafeShiftR` 32)
    sz = fromIntegral (w .&. 0x7fffffff)

-- Encode a symbol as an Int64. A negative index is rejected with an error;
-- the three fields are otherwise shifted into place and added together.
{-# INLINE fromSymbol #-}
fromSymbol :: Symbol -> Int64
fromSymbol (Symbol fun idx sz)
  | idx < 0 = ERROR("negative symbol index")
  | otherwise = sizeBits + indexBits + funBit
  where
    sizeBits = fromIntegral sz
    indexBits = fromIntegral idx `unsafeShiftL` 32
    funBit = fromIntegral (fromEnum fun) `unsafeShiftL` 31

--------------------------------------------------------------------------------
-- Flatterms, or rather lists of terms.
--------------------------------------------------------------------------------

-- A TermList is a slice of an unboxed array of symbols.
data TermList f =
  TermList {
    -- Position of the first symbol of the slice.
    low :: {-# UNPACK #-} !Int,
    -- Position just past the last symbol of the slice.
    high :: {-# UNPACK #-} !Int,
    -- The underlying array that the slice points into.
    array :: {-# UNPACK #-} !ByteArray }

-- Fetch the term starting at symbol offset n within the termlist.
-- Offsets that are negative or fall outside the slice raise an error.
at :: Int -> TermList f -> Term f
at n (TermList lo hi arr)
  | n >= 0 && n + lo < hi =
      case TermList (lo+n) hi arr of
        Cons t _ -> t
  | otherwise = ERROR("term index out of bounds")

{-# INLINE lenList #-}
-- The length (number of symbols in) a flatterm.
lenList :: TermList f -> Int
lenList TermList{low = lo, high = hi} = hi - lo

-- A term is a special case of a termlist.
-- We store it as the termlist together with the root symbol.
data Term f =
  Term {
    -- The encoded root symbol of the term.
    root :: {-# UNPACK #-} !Int64,
    -- The termlist holding the term.
    termlist :: {-# UNPACK #-} !(TermList f) }

-- Terms are equal exactly when their termlists are equal.
instance Eq (Term f) where
  Term _ xs == Term _ ys = xs == ys

-- Terms are ordered by their termlists.
instance Ord (Term f) where
  compare x y = compare (termlist x) (termlist y)

-- Pattern synonyms for termlists:
-- * Empty :: TermList f
--   Empty is the empty termlist.
-- * Cons t ts :: Term f -> TermList f -> TermList f
--   Cons t ts is the termlist t:ts.
-- * ConsSym t ts :: Term f -> TermList f -> TermList f
--   ConsSym t ts is like Cons t ts but ts also includes t's children
--   (operationally, ts seeks one term to the right in the termlist).
-- * UnsafeCons/UnsafeConsSym: like Cons and ConsSym but don't check
--   that the termlist is non-empty.
pattern Empty <- (patHead -> Nothing)
pattern Cons t ts <- (patHead -> Just (t, _, ts))
pattern ConsSym t ts <- (patHead -> Just (t, ts, _))
pattern UnsafeCons t ts <- (unsafePatHead -> Just (t, _, ts))
pattern UnsafeConsSym t ts <- (unsafePatHead -> Just (t, ts, _))

-- Split a termlist at its first symbol without checking for emptiness.
-- The result always is Just a triple of:
--   * the first term (the slice covering the size stored in its symbol),
--   * the termlist starting one symbol further on,
--   * the termlist starting just after the first term.
{-# INLINE unsafePatHead #-}
unsafePatHead :: TermList f -> Maybe (Term f, TermList f, TermList f)
unsafePatHead (TermList lo hi arr) = Just (here, inside, after)
  where
    sym = indexByteArray arr lo
    next = lo + size (toSymbol sym)
    here = Term sym (TermList lo next arr)
    inside = TermList (lo+1) hi arr
    after = TermList next hi arr

-- Like unsafePatHead, but yields Nothing when the slice is empty.
{-# INLINE patHead #-}
patHead :: TermList f -> Maybe (Term f, TermList f, TermList f)
patHead ts
  | low ts == high ts = Nothing
  | otherwise = unsafePatHead ts

-- Pattern synonyms for single terms.
-- * Var :: Var -> Term f
-- * Fun :: Fun f -> TermList f -> Term f

-- A function symbol, identified by its number.
newtype Fun f = MkFun Int deriving Eq
-- A variable, identified by its number.
newtype Var = MkVar Int deriving (Eq, Ord, Enum)

instance Show (Fun f) where
  show (MkFun n) = "f" ++ show n

instance Show Var where
  show (MkVar n) = "x" ++ show n

pattern Var x <- Term (patRoot -> Left x) _
pattern Fun f ts <- Term (patRoot -> Right (f :: Fun f)) (patNext -> (ts :: TermList f))

-- Classify an encoded root symbol as a variable (Left) or a function (Right).
{-# INLINE patRoot #-}
patRoot :: Int64 -> Either Var (Fun f)
patRoot w =
  case toSymbol w of
    Symbol True n _ -> Right (MkFun n)
    Symbol False n _ -> Left (MkVar n)

-- Advance the start of a termlist by one symbol.
{-# INLINE patNext #-}
patNext :: TermList f -> TermList f
patNext ts = ts { low = low ts + 1 }

-- Convert a term to a termlist.
{-# INLINE singleton #-}
singleton :: Term f -> TermList f
singleton t = termlist t

-- We can implement equality almost without access to the
-- internal representation of the termlists, but we cheat by
-- comparing Int64s instead of Symbols.
-- Two termlists are equal when they have the same number of symbols
-- and the same encoded symbol at every position.
instance Eq (TermList f) where
  {-# INLINE (==) #-}
  xs == ys = lenList xs == lenList ys && eqSameLength xs ys

-- Symbol-by-symbol equality. The second list is traversed with the
-- unchecked pattern, so it is only consulted while the first is non-empty.
eqSameLength :: TermList f -> TermList f -> Bool
eqSameLength Empty !_ = True
eqSameLength (ConsSym x xs) (UnsafeConsSym y ys)
  | root x == root y = eqSameLength xs ys
  | otherwise = False

-- Termlists are ordered first by length, then by their encoded symbols.
instance Ord (TermList f) where
  {-# INLINE compare #-}
  compare xs ys =
    compare (lenList xs) (lenList ys) `mappend` compareContents xs ys

-- Symbol-by-symbol comparison, traversing the second list with the
-- unchecked pattern just as eqSameLength does.
compareContents :: TermList f -> TermList f -> Ordering
compareContents Empty !_ = EQ
compareContents (ConsSym x xs) (UnsafeConsSym y ys) =
  case compare (root x) (root y) of
    EQ -> compareContents xs ys
    other -> other

--------------------------------------------------------------------------------
-- Building terms imperatively.
--------------------------------------------------------------------------------

-- A monad for building terms.
newtype Builder f =
  Builder {
    unBuilder ::
      -- Takes: the term array and size, and current position in the term.
      -- Returns the final position, which may be out of bounds.
      forall s. Builder1 s }

-- The raw builder function: state token, array, array size and current
-- position go in; the new state token and position come out.
type Builder1 s = State# s -> MutableByteArray# s -> Int# -> Int# -> (# State# s, Int# #)

-- The empty builder emits nothing; appending runs one builder after another.
instance Monoid (Builder f) where
  {-# INLINE mempty #-}
  mempty = Builder built
  {-# INLINE mappend #-}
  mappend (Builder first) (Builder second) = Builder (first `then_` second)

-- Run a builder to produce a termlist.
-- The builder is run against an array with room for 16 symbols; whenever
-- the position it reports does not fit, it is run again from scratch on a
-- fresh array with room for twice the reported position.
{-# INLINE buildTermList #-}
buildTermList :: Builder f -> TermList f
buildTermList builder = runST $ do
  let
    attempt capacity@(I# capacity#) = do
      MutableByteArray marr# <-
        newByteArray (capacity * sizeOf (fromSymbol __))
      used <-
        ST $ \s0 ->
          case unBuilder builder s0 marr# capacity# 0# of
            (# s1, used# #) -> (# s1, I# used# #)
      if used > capacity
        then attempt (used*2)
        else do
          !frozen <- unsafeFreezeByteArray (MutableByteArray marr#)
          return (TermList 0 used frozen)
  attempt 16

-- Pass the array being written to the continuation.
{-# INLINE getArray #-}
getArray :: (MutableByteArray s -> Builder1 s) -> Builder1 s
getArray k = \s arr cap pos -> k (MutableByteArray arr) s arr cap pos

-- Pass the size of the array to the continuation.
{-# INLINE getSize #-}
getSize :: (Int -> Builder1 s) -> Builder1 s
getSize k = \s arr cap pos -> k (I# cap) s arr cap pos

-- Pass the current position to the continuation.
{-# INLINE getIndex #-}
getIndex :: (Int -> Builder1 s) -> Builder1 s
getIndex k = \s arr cap pos -> k (I# pos) s arr cap pos

-- Replace the current position with the given one.
{-# INLINE putIndex #-}
putIndex :: Int -> Builder1 s
putIndex (I# pos) = \s _ _ _ -> (# s, pos #)

-- Run an ST action, leaving the position unchanged.
{-# INLINE liftST #-}
liftST :: ST s () -> Builder1 s
liftST (ST act) =
  \s0 _ _ pos ->
    case act s0 of
      (# s1, () #) -> (# s1, pos #)

-- Do nothing, leaving the position unchanged.
{-# INLINE built #-}
built :: Builder1 s
built = \s _ _ pos -> (# s, pos #)

-- Run the first builder, then the second starting from where the first
-- stopped.
{-# INLINE then_ #-}
then_ :: Builder1 s -> Builder1 s -> Builder1 s
then_ m1 m2 =
  \s0 arr cap pos0 ->
    case m1 s0 arr cap pos0 of
      (# s1, pos1 #) -> m2 s1 arr cap pos1

-- Run the builder only if j more symbols fit in the array. Otherwise skip
-- it and just advance the position by j, so the caller sees an
-- out-of-bounds final position.
{-# INLINE checked #-}
checked :: Int -> Builder1 s -> Builder1 s
checked j m =
  getSize $ \cap ->
  getIndex $ \pos ->
    if pos + j > cap then putIndex (pos + j) else m

-- Emit a symbol followed by whatever the inner builder emits.
-- One cell is reserved for the symbol, the inner builder runs, and then the
-- reserved cell is written with its size set to the distance the position
-- moved (any size given in the symbol argument is overwritten).
{-# INLINE emitSymbolBuilder #-}
emitSymbolBuilder :: Symbol -> Builder f -> Builder f
emitSymbolBuilder x inner =
  Builder (checked 1 (
    getArray $ \marr ->
    getIndex $ \start ->
      (putIndex (start+1) `then_` unBuilder inner) `then_`
      getIndex (\end ->
        liftST (writeByteArray marr start (fromSymbol x { size = end - start })))))

-- Emit a function
-- symbol.
-- The second argument is called to emit the function's arguments.
{-# INLINE emitFun #-}
emitFun :: Fun f -> Builder f -> Builder f
emitFun (MkFun f) inner =
  emitSymbolBuilder Symbol { isFun = True, index = f, size = 0 } inner

-- Emit a variable.
{-# INLINE emitVar #-}
emitVar :: Var -> Builder f
emitVar (MkVar x) =
  emitSymbolBuilder Symbol { isFun = False, index = x, size = 1 } mempty

-- Emit a whole termlist.
-- The symbols of the slice are copied into the array being built at the
-- current position, which then moves past them. The copy is skipped
-- (and only the position is advanced) when the slice does not fit.
{-# INLINE emitTermList #-}
emitTermList :: TermList f -> Builder f
emitTermList (TermList lo hi src) =
  Builder (checked len (
    getArray $ \dest ->
    getIndex $ \pos ->
      liftST (copyByteArray dest (pos*width) src (lo*width) (len*width)) `then_`
      putIndex (pos + len)))
  where
    -- Number of symbols in the slice.
    len = hi - lo
    -- Number of bytes occupied by one stored symbol.
    width = sizeOf (fromSymbol __)