module Twee.Term.Core where
#include "errors.h"
import Data.Primitive
import Control.Monad.ST.Strict
import Data.Bits
import Data.Int
import GHC.Types(Int(..))
import GHC.Prim
import GHC.ST hiding (liftST)
import Data.Ord
-- | One symbol of a flattened term, in decoded form (see toSymbol and
-- fromSymbol for the packed 64-bit representation).
data Symbol = Symbol
  { isFun :: Bool
    -- ^ True for a function symbol, False for a variable.
  , index :: Int
    -- ^ The function or variable number.
  , size :: Int
    -- ^ Number of array entries taken by the subterm headed by this
    -- symbol, the symbol itself included.
  }
-- | Function symbols print as their Fun followed by the size; variables
-- print as just their Var.
instance Show Symbol where
  show (Symbol fun n sz) =
    if fun
      then show (MkFun n) ++ "=" ++ show sz
      else show (MkVar n)
-- | Decode a packed symbol word. The index lives in the upper 32 bits,
-- bit 31 is the function flag, and the low 31 bits hold the size.
toSymbol :: Int64 -> Symbol
toSymbol w = Symbol { isFun = funBit, index = idx, size = sz }
  where
    funBit = testBit w 31
    idx = fromIntegral (w `unsafeShiftR` 32)
    sz = fromIntegral (w .&. 0x7fffffff)
-- | Pack a symbol into a single word, in the layout read by toSymbol.
-- A negative index is rejected. The size is added without a range
-- check, so it is expected to fit in 31 bits.
fromSymbol :: Symbol -> Int64
fromSymbol (Symbol fun idx sz)
  | idx < 0 = ERROR("negative symbol index")
  | otherwise = sizePart + indexPart + funPart
  where
    sizePart = fromIntegral sz
    indexPart = fromIntegral idx `unsafeShiftL` 32
    funPart = fromIntegral (fromEnum fun) `unsafeShiftL` 31
-- | A flattened sequence of terms: the packed Int64 symbol words stored
-- in entries low (inclusive) to high (exclusive) of a byte array.
-- The type parameter f is not stored and only tags the term type.
data TermList f = TermList
  { low :: !Int
    -- ^ Index of the first entry.
  , high :: !Int
    -- ^ Index one past the last entry.
  , array :: !ByteArray
    -- ^ Backing storage for the symbol words.
  }
-- | Return the subterm rooted at entry lo+n of the term list. Offsets
-- count symbol words, not terms: offset 0 is the first term.
-- Offsets outside the list are an error.
at :: Int -> TermList f -> Term f
at n (TermList lo hi arr)
  | inRange = headTerm (TermList (lo+n) hi arr)
  | otherwise = ERROR("term index out of bounds")
  where
    inRange = n >= 0 && n + lo < hi
    headTerm (Cons t _) = t
-- | Number of symbol words (array entries) in a term list.
-- The bounds are bound as lo and hi to avoid shadowing the low and high
-- record selectors.
lenList :: TermList f -> Int
lenList (TermList lo hi _) = hi - lo
-- | A single term: its packed root symbol word, plus the term list that
-- spans the whole term, starting with that root symbol.
data Term f = Term
  { root :: !Int64
    -- ^ Packed root symbol.
  , termlist :: !(TermList f)
    -- ^ All entries of this term, the root included.
  }
-- | Terms are compared through their term lists. The root word is the
-- first entry of that list, so it takes part in the comparison.
instance Eq (Term f) where
  Term _ xs == Term _ ys = xs == ys
-- | Terms are ordered by their term lists.
instance Ord (Term f) where
  compare (Term _ xs) (Term _ ys) = compare xs ys
-- | Matches an empty term list.
pattern Empty <- (patHead -> Nothing)
-- | Matches a non-empty term list, binding its first term and the list
-- that follows the whole of that term.
pattern Cons t ts <- (patHead -> Just (t, _, ts))
-- | Matches a non-empty term list, binding its first term and the list
-- that starts right after that term root symbol (so ts begins with the
-- entries inside t, if any).
pattern ConsSym t ts <- (patHead -> Just (t, ts, _))
-- | Like Cons, but without the emptiness check; only use it on a list
-- already known to be non-empty.
pattern UnsafeCons t ts <- (unsafePatHead -> Just (t, _, ts))
-- | Like ConsSym, but without the emptiness check; only use it on a list
-- already known to be non-empty.
pattern UnsafeConsSym t ts <- (unsafePatHead -> Just (t, ts, _))
-- | Split off the first term of a term list without checking that the
-- list is non-empty. Returns the first term, the list after its root
-- symbol, and the list after the whole term.
unsafePatHead :: TermList f -> Maybe (Term f, TermList f, TermList f)
unsafePatHead (TermList lo hi arr) = Just (headTerm, afterSym, afterTerm)
  where
    -- Packed root symbol of the first term.
    w = indexByteArray arr lo
    -- Number of entries taken by the first term.
    width = size (toSymbol w)
    headTerm = Term w (TermList lo (lo + width) arr)
    afterSym = TermList (lo + 1) hi arr
    afterTerm = TermList (lo + width) hi arr
-- | Safe version of unsafePatHead: returns Nothing on an empty list.
patHead :: TermList f -> Maybe (Term f, TermList f, TermList f)
patHead tl@(TermList lo hi _)
  | lo /= hi = unsafePatHead tl
  | otherwise = Nothing
-- | A function symbol number. The parameter f is a phantom tag.
newtype Fun f = MkFun Int
  deriving Eq

-- | A variable number.
newtype Var = MkVar Int
  deriving (Eq, Ord, Enum)

instance Show (Fun f) where
  show (MkFun n) = "f" ++ show n

instance Show Var where
  show (MkVar n) = "x" ++ show n
-- | Matches a term whose root symbol is a variable.
pattern Var x <- Term (patRoot -> Left x) _
-- | Matches a term whose root symbol is a function, binding the function
-- and the term list of that term with its root symbol dropped (that is,
-- its argument entries).
pattern Fun f ts <- Term (patRoot -> Right (f :: Fun f)) (patNext -> (ts :: TermList f))
-- | Classify a packed root symbol as a variable (Left) or a function
-- (Right).
patRoot :: Int64 -> Either Var (Fun f)
patRoot w =
  case toSymbol w of
    Symbol True n _ -> Right (MkFun n)
    Symbol False n _ -> Left (MkVar n)
-- | Drop the first entry (one symbol word) of a term list.
patNext :: TermList f -> TermList f
patNext tl = tl { low = low tl + 1 }
-- | View a single term as a term list containing just that term.
singleton :: Term f -> TermList f
singleton = termlist
-- | Two term lists are equal when they have the same length and the same
-- symbol words at every position.
instance Eq (TermList f) where
  xs == ys
    | lenList xs == lenList ys = eqSameLength xs ys
    | otherwise = False

-- | Compare two term lists symbol word by symbol word. Only the first
-- list is tested for emptiness; the second is split without a check,
-- so both lists must have the same length.
eqSameLength :: TermList f -> TermList f -> Bool
eqSameLength Empty ys = ys `seq` True
eqSameLength (ConsSym x xs) (UnsafeConsSym y ys)
  | root x /= root y = False
  | otherwise = eqSameLength xs ys
-- | Shorter term lists come first; lists of equal length are ordered
-- lexicographically by their packed symbol words.
instance Ord (TermList f) where
  compare xs ys =
    compare (lenList xs) (lenList ys) `mappend` compareContents xs ys

-- | Lexicographic comparison of the symbol words of two term lists. Only
-- the first list is tested for emptiness, so both must have the same
-- length.
compareContents :: TermList f -> TermList f -> Ordering
compareContents Empty ys = ys `seq` EQ
compareContents (ConsSym x xs) (UnsafeConsSym y ys) =
  compare (root x) (root y) `mappend` compareContents xs ys
-- | A builder for a flattened term list. It is polymorphic in the state
-- thread so that buildTermList can run it inside runST.
newtype Builder f = Builder { unBuilder :: forall s. Builder1 s }

-- | The raw builder function. Its arguments are the state token, the
-- output array, the array capacity in entries, and the current write
-- index. It returns the new state token and the index reached.
type Builder1 s =
  State# s -> MutableByteArray# s -> Int# -> Int# -> (# State# s, Int# #)

-- | mempty emits nothing; mappend runs the left builder, then the right.
-- NOTE(review): GHC 8.4 and later require a Semigroup instance before
-- this Monoid instance is accepted; confirm the targeted compiler.
instance Monoid (Builder f) where
  mempty = Builder built
  mappend (Builder a) (Builder b) = Builder (a `then_` b)
-- | Run a builder and return the term list it produces.
--
-- The builder is run against a freshly allocated array, starting with
-- room for 16 symbol words, and reports the index it reached. If that
-- index exceeds the capacity, some output was skipped for lack of room.
-- In that case the array is abandoned and the builder is rerun with a
-- capacity of twice the reported index. Otherwise the array is frozen
-- and the entries before the reported index become the result.
buildTermList :: Builder f -> TermList f
buildTermList builder = runST $ do
  let
    Builder m = builder
    -- One attempt with capacity for n symbol words.
    loop n@(I# n#) = do
      MutableByteArray marray# <-
        newByteArray (n * sizeOf (fromSymbol __))
      -- Run the raw builder from index 0; it yields the final index.
      n' <-
        ST $ \s ->
          case m s marray# n# 0# of
            (# s, n# #) -> (# s, I# n# #)
      if n' <= n then do
        -- Everything fit: freeze the array in place and expose the
        -- entries that were written.
        !array <- unsafeFreezeByteArray (MutableByteArray marray#)
        return (TermList 0 n' array)
      else loop (n'*2)
  loop 16
-- | Hand the output array, boxed as a MutableByteArray, to a
-- continuation.
getArray :: (MutableByteArray s -> Builder1 s) -> Builder1 s
getArray k st arr cap idx = k (MutableByteArray arr) st arr cap idx
-- | Hand the output array capacity (in entries) to a continuation.
getSize :: (Int -> Builder1 s) -> Builder1 s
getSize k st arr cap idx = k (I# cap) st arr cap idx
-- | Hand the current write index to a continuation.
getIndex :: (Int -> Builder1 s) -> Builder1 s
getIndex k st arr cap idx = k (I# idx) st arr cap idx
-- | Finish with the given value as the new write index.
putIndex :: Int -> Builder1 s
putIndex (I# next) = \st _ _ _ -> (# st, next #)
-- | Run an ST action for its effect, leaving the write index unchanged.
liftST :: ST s () -> Builder1 s
liftST (ST act) st _ _ idx =
  case act st of
    (# st', () #) -> (# st', idx #)
-- | The empty builder: emits nothing and keeps the current index.
built :: Builder1 s
built st _ _ idx = (# st, idx #)
-- | Sequence two raw builders: the second starts at the state and index
-- left by the first.
then_ :: Builder1 s -> Builder1 s -> Builder1 s
then_ first second st arr cap idx =
  case first st arr cap idx of
    (# st', idx' #) -> second st' arr cap idx'
-- | Run m only if j more entries fit in the array. Otherwise skip it and
-- just advance the index by j, so that the caller sees the overflow.
checked :: Int -> Builder1 s -> Builder1 s
checked j m =
  getSize $ \cap ->
  getIndex $ \i ->
    let end = i + j
    in if end > cap then putIndex end else m
-- | Emit one symbol, followed by whatever inner emits. The symbol word is
-- written after inner has run, so that its size field can record how
-- many entries the whole subterm occupies (the symbol itself included).
-- The size stored in x on entry is ignored and replaced.
--
-- If there is no room even for the symbol, nothing is written and only
-- the index is advanced (see checked).
emitSymbolBuilder :: Symbol -> Builder f -> Builder f
emitSymbolBuilder x inner =
  Builder $ checked 1 $
    getArray $ \marr ->
    getIndex $ \start ->
      -- Reserve entry start for the symbol, then emit the body after it.
      putIndex (start + 1) `then_`
      unBuilder inner `then_`
      getIndex (\end ->
        -- The subterm spans entries start up to (not including) end.
        liftST $ writeByteArray marr start (fromSymbol x { size = end - start }))
-- | Emit a function symbol whose arguments are produced by the given
-- builder. The size given here is a placeholder that emitSymbolBuilder
-- replaces.
emitFun :: Fun f -> Builder f -> Builder f
emitFun (MkFun f) = emitSymbolBuilder (Symbol { isFun = True, index = f, size = 0 })
-- | Emit a variable: a single symbol with no arguments.
emitVar :: Var -> Builder f
emitVar (MkVar x) = emitSymbolBuilder sym mempty
  where
    sym = Symbol { isFun = False, index = x, size = 1 }
-- | Emit a copy of an existing term list by copying its symbol words
-- directly from its byte array into the output. If they do not all fit,
-- nothing is copied and only the index is advanced (see checked).
emitTermList :: TermList f -> Builder f
emitTermList (TermList lo hi src) =
  Builder $ checked len $
    getArray $ \marray ->
    getIndex $ \n ->
      -- Byte size of one packed symbol word.
      let k = sizeOf (fromSymbol __) in
      liftST (copyByteArray marray (n*k) src (lo*k) (len*k)) `then_`
      putIndex (n + len)
  where
    -- Number of symbol words to copy.
    len = hi - lo