module Twee.Term.Core where
import Data.Primitive(sizeOf)
#ifdef BOUNDS_CHECKS
import Data.Primitive.ByteArray.Checked
#else
import Data.Primitive.ByteArray
#endif
import Control.Monad.ST.Strict
import Data.Bits
import Data.Int
import GHC.Types(Int(..))
import GHC.Prim
import GHC.ST hiding (liftST)
import Data.Ord
import Twee.Label
import Data.Typeable
-- | A decoded entry of a 'TermList'. Each entry is stored packed into a single
-- 'Int64'; 'toSymbol' decodes and 'fromSymbol' encodes.
data Symbol =
  Symbol
    { isFun :: Bool
      -- ^ 'True' for a function symbol, 'False' for a variable.
    , index :: Int
      -- ^ The function or variable number.
    , size  :: Int
      -- ^ Number of entries occupied by the subterm rooted at this entry,
      -- counting the entry itself.
    }
-- Debug rendering: a function symbol shows as its 'Fun' followed by "=" and
-- its size; a variable shows as its 'Var'.
instance Show Symbol where
  show sym =
    if isFun sym
      then show (F (index sym)) ++ "=" ++ show (size sym)
      else show (V (index sym))
-- | Decode a packed entry. Bits 32 and above hold 'index', bit 31 holds
-- 'isFun', and bits 0-30 hold 'size'.
toSymbol :: Int64 -> Symbol
toSymbol packed = Symbol { isFun = funBit, index = idx, size = len }
  where
    funBit = testBit packed 31
    idx    = fromIntegral (packed `unsafeShiftR` 32)
    len    = fromIntegral (packed .&. 0x7fffffff)
-- | Pack a 'Symbol' into one 'Int64' using the layout 'toSymbol' decodes:
-- 'index' shifted to bit 32, 'isFun' at bit 31, 'size' in the low bits.
--
-- NOTE(review): the parts are combined with '+', so the round trip through
-- 'toSymbol' only holds while 'size' is non-negative and below 2^31 — confirm
-- callers stay within that.
fromSymbol :: Symbol -> Int64
fromSymbol (Symbol funFlag idx len) = sizeBits + indexBits + flagBits
  where
    sizeBits  = fromIntegral len
    indexBits = fromIntegral idx `unsafeShiftL` 32
    flagBits  = fromIntegral (fromEnum funFlag) `unsafeShiftL` 31
-- | A flattened list of terms: the slice @[low, high)@ of 'array'. Each entry
-- is a packed 'Symbol' ('Int64'), and every term is stored with its root before
-- its arguments. The parameter @f@ is phantom (no field mentions it).
data TermList f =
  TermList
    { low   :: !Int
      -- ^ Index of the first entry (inclusive).
    , high  :: !Int
      -- ^ Index one past the last entry (exclusive).
    , array :: !ByteArray
      -- ^ Backing storage of packed entries.
    }
-- | @at n ts@ returns the subterm whose root is entry @n@ of @ts@. Every entry
-- counts, not only top-level terms. Calls 'error' when @n@ is outside the list.
at :: Int -> TermList f -> Term f
at n (TermList lo hi arr)
  | n >= 0 && lo + n < hi =
      case unsafePatHead (TermList (lo + n) hi arr) of
        Just (t, _, _) -> t
  | otherwise = error "term index out of bounds"
-- | Number of entries (symbols) in a 'TermList'. This is not the number of
-- top-level terms.
lenList :: TermList f -> Int
lenList (TermList lo hi _) = hi - lo
-- | A single term: its packed root symbol plus a 'TermList' slice covering
-- the whole term, root included.
data Term f =
  Term
    { root     :: !Int64
      -- ^ Packed root 'Symbol' (decode with 'toSymbol').
    , termlist :: !(TermList f)
      -- ^ The slice holding exactly this term.
    }
-- Two terms are equal exactly when their flattened slices are equal.
instance Eq (Term f) where
  Term _ xs == Term _ ys = xs == ys
-- Terms are ordered by the ordering of their flattened slices.
instance Ord (Term f) where
  compare (Term _ xs) (Term _ ys) = compare xs ys
-- | Matches an empty 'TermList' (one where 'patHead' returns 'Nothing').
pattern Empty :: TermList f
pattern Empty <- (patHead -> Nothing)
-- | @Cons t ts@ matches a non-empty list. @t@ is its first term and @ts@ is
-- the rest of the list starting after the whole of @t@, arguments included.
pattern Cons :: Term f -> TermList f -> TermList f
pattern Cons t ts <- (patHead -> Just (t, _, ts))
-- | Like 'Cons', but does not check for emptiness: it always decodes the
-- entry at 'low', even if the slice is empty.
pattern UnsafeCons :: Term f -> TermList f -> TermList f
pattern UnsafeCons t ts <- (unsafePatHead -> Just (t, _, ts))
-- | @ConsSym t ts@ matches a non-empty list. @t@ is its first term, but @ts@
-- starts at the entry right after @t@'s root symbol, so it still contains
-- @t@'s arguments. Used to walk the list one symbol at a time.
pattern ConsSym :: Term f -> TermList f -> TermList f
pattern ConsSym t ts <- (patHead -> Just (t, ts, _))
-- | Like 'ConsSym', but does not check for emptiness: it always decodes the
-- entry at 'low', even if the slice is empty.
pattern UnsafeConsSym :: Term f -> TermList f -> TermList f
pattern UnsafeConsSym t ts <- (unsafePatHead -> Just (t, ts, _))
-- | Decode the head of a 'TermList' without checking that it is non-empty.
-- Always returns 'Just' of a triple:
--
-- * the first term,
-- * the list starting just after that term's root symbol, and
-- * the list starting just after the whole first term.
unsafePatHead :: TermList f -> Maybe (Term f, TermList f, TermList f)
unsafePatHead (TermList lo hi arr) =
  Just ( Term sym (TermList lo afterTerm arr)
       , TermList (lo + 1) hi arr
       , TermList afterTerm hi arr )
  where
    -- Strict binding: the root entry is read as soon as the result is demanded.
    !sym = indexByteArray arr lo
    -- The first term occupies entries [lo, afterTerm).
    afterTerm = lo + size (toSymbol sym)
-- | Checked version of 'unsafePatHead': returns 'Nothing' for an empty list.
patHead :: TermList f -> Maybe (Term f, TermList f, TermList f)
patHead ts@(TermList lo hi _) =
  if lo == hi then Nothing else unsafePatHead ts
-- | A function symbol, identified by a number. 'fun' obtains the number from a
-- label, and 'fun_value' maps it back to the @f@ value.
newtype Fun f = F { fun_id :: Int }
-- Function symbols are equal when their numbers are equal.
instance Eq (Fun f) where
  F a == F b = a == b
-- Function symbols are ordered by their numbers.
instance Ord (Fun f) where
  compare (F a) (F b) = compare a b
-- | Turn a value into a function symbol, using its label number as the id.
fun :: (Ord f, Typeable f) => f -> Fun f
fun = F . fromIntegral . labelNum . label
-- | Recover the value a function symbol was created from, by looking up the
-- label with the symbol's number.
fun_value :: Fun f -> f
fun_value = find . unsafeMkLabel . fromIntegral . fun_id
-- | A variable, identified by its number.
newtype Var = V { var_id :: Int }
  deriving (Eq, Ord, Enum)
-- A function symbol shows as "f" followed by its number.
instance Show (Fun f) where
  show (F n) = "f" ++ show n
-- A variable shows as "x" followed by its number.
instance Show Var where
  show (V n) = "x" ++ show n
-- | Matches a term whose root symbol is a variable.
pattern Var :: Var -> Term f
pattern Var x <- (patTerm -> Left x)
-- | Matches a term whose root symbol is a function symbol, yielding the
-- function and its argument list.
pattern App :: Fun f -> TermList f -> Term f
pattern App f ts <- (patTerm -> Right (f, ts))
-- | View a 'Term' as either a variable or a function application. For an
-- application, the returned 'TermList' is the term's own slice with the root
-- symbol dropped, i.e. its arguments.
patTerm :: Term f -> Either Var (Fun f, TermList f)
patTerm t@(Term rootSym _) =
  case toSymbol rootSym of
    Symbol { isFun = True, index = n } -> Right (F n, args)
    Symbol { index = n } -> Left (V n)
  where
    -- Strict binding: decoded before the root symbol is inspected.
    !(UnsafeConsSym _ args) = singleton t
-- | The 'TermList' containing just this term.
singleton :: Term f -> TermList f
singleton = termlist
-- Equality on term lists is implemented by 'eqTermList'.
instance Eq (TermList f) where
  (==) = eqTermList
-- | Equality on term lists. Unpacks both lists into their raw unboxed fields
-- and delegates to the worker 'weqTermList'.
-- NOTE(review): presumably split into wrapper/worker so the comparison works
-- on unboxed values — confirm, and consider INLINE/NOINLINE pragmas.
eqTermList :: TermList f -> TermList f -> Bool
eqTermList
  (TermList (I# low1) (I# high1) (ByteArray array1))
  (TermList (I# low2) (I# high2) (ByteArray array2)) =
  weqTermList low1 high1 array1 low2 high2 array2
-- | Worker for 'eqTermList', taking each list as (low, high, array). Two lists
-- are equal when they have the same number of entries and every packed entry
-- matches.
weqTermList ::
  Int# -> Int# -> ByteArray# ->
  Int# -> Int# -> ByteArray# ->
  Bool
weqTermList low1 high1 array1 low2 high2 array2 =
  lenList t == lenList u && eqSameLength t u
  where
    t = TermList (I# low1) (I# high1) (ByteArray array1)
    u = TermList (I# low2) (I# high2) (ByteArray array2)
    -- Entry-by-entry comparison. It only runs after the lengths have been
    -- found equal, so the second list is decoded without an emptiness check.
    eqSameLength Empty !_ = True
    eqSameLength (ConsSym s1 t) (UnsafeConsSym s2 u) =
      root s1 == root s2 && eqSameLength t u
-- Term lists are ordered by length (number of entries) first. Lists of equal
-- length are then compared entry by entry with 'compareContents'.
instance Ord (TermList f) where
  compare t u
    | lenT /= lenU = compare lenT lenU
    | otherwise = compareContents t u
    where
      lenT = lenList t
      lenU = lenList u
-- | Lexicographic comparison of packed entries. The second list is decoded
-- without an emptiness check, so it must be at least as long as the first.
compareContents :: TermList f -> TermList f -> Ordering
compareContents Empty !_ = EQ
compareContents (ConsSym x xs) (UnsafeConsSym y ys)
  | rx == ry = compareContents xs ys
  | otherwise = compare rx ry
  where
    rx = root x
    ry = root y
-- | A builder of packed term entries, usable in any state thread @s@.
-- 'Builder1' describes the underlying step function.
newtype Builder f =
  Builder
    { unBuilder :: forall s. Builder1 s f }
-- | One raw builder step. It takes the state token, the target array, the
-- array capacity in entries, and the current write index. It returns the new
-- state token and the index reached after this step's output.
type Builder1 s f =
  State# s -> MutableByteArray# s -> Int# -> Int# -> (# State# s, Int# #)
-- Builders compose sequentially: the right builder continues from the index
-- where the left one stopped. 'mempty' emits nothing.
-- NOTE(review): GHC >= 8.4 requires a Semigroup instance as a superclass of
-- Monoid; add one if it is not defined elsewhere.
instance Monoid (Builder f) where
  mempty = Builder built
  mappend (Builder lhs) (Builder rhs) = Builder (then_ lhs rhs)
-- | Run a 'Builder' and freeze its output into a 'TermList'.
--
-- The first attempt allocates room for 32 entries. If the final index the
-- builder reaches exceeds the capacity, a fresh array of twice that index is
-- allocated and the builder is rerun from index 0.
buildTermList :: Builder f -> TermList f
buildTermList builder = runST $ do
  let
    Builder m = builder
    loop n@(I# n#) = do
      -- n entries of one packed Int64 each (sizeOf ignores its argument).
      MutableByteArray mbytearray# <-
        newByteArray (n * sizeOf (fromSymbol undefined))
      n' <-
        ST $ \s ->
          case m s mbytearray# n# 0# of
            (# s, n# #) -> (# s, I# n# #)
      if n' <= n then do
        !bytearray <- unsafeFreezeByteArray (MutableByteArray mbytearray#)
        return (TermList 0 n' bytearray)
      else loop (n'*2)
  loop 32
-- | Pass the target array, boxed as a 'MutableByteArray', to a continuation.
getByteArray :: (MutableByteArray s -> Builder1 s f) -> Builder1 s f
getByteArray k s arr# cap# idx# = k (MutableByteArray arr#) s arr# cap# idx#
-- | Pass the array capacity (in entries) to a continuation.
getSize :: (Int -> Builder1 s f) -> Builder1 s f
getSize k s arr# cap# idx# = k (I# cap#) s arr# cap# idx#
-- | Pass the current write index to a continuation.
getIndex :: (Int -> Builder1 s f) -> Builder1 s f
getIndex k s arr# cap# idx# = k (I# idx#) s arr# cap# idx#
-- | Replace the current write index with the given one. Writes nothing.
putIndex :: Int -> Builder1 s f
putIndex (I# next#) = \token _ _ _ -> (# token, next# #)
-- | Run an 'ST' action as a builder step without changing the write index.
liftST :: ST s () -> Builder1 s f
liftST (ST action) s _ _ idx# =
  case action s of
    (# s1, () #) -> (# s1, idx# #)
-- | The do-nothing step: writes nothing and keeps the current index.
built :: Builder1 s f
built s _ _ idx# = (# s, idx# #)
-- | Sequence two steps. The second step starts from the state token and index
-- produced by the first, and shares the same array and capacity.
then_ :: Builder1 s f -> Builder1 s f -> Builder1 s f
then_ step1 step2 s arr# cap# idx# =
  case step1 s arr# cap# idx# of
    (# s1, mid# #) -> step2 s1 arr# cap# mid#
-- | @checked j m@ runs @m@ only if @j@ more entries fit within the array
-- capacity. Otherwise it skips @m@ and just advances the index by @j@.
checked :: Int -> Builder1 s f -> Builder1 s f
checked j m =
  getSize (\cap -> getIndex (\i ->
    if i + j > cap then putIndex (i + j) else m))
-- | Emit one symbol followed by everything @inner@ emits (its arguments).
--
-- The symbol's slot is reserved first and filled in afterwards, once the size
-- of the whole subterm is known. That size counts this entry plus everything
-- @inner@ emitted, and it replaces the 'size' field of @x@.
emitSymbolBuilder :: Symbol -> Builder f -> Builder f
emitSymbolBuilder x inner =
  Builder $ checked 1 $
    getByteArray $ \bytearray ->
    getIndex $ \n ->
    putIndex (n+1) `then_`
    unBuilder inner `then_`
    getIndex (\m ->
      -- The subterm occupies entries [n, m), so its size is m - n.
      liftST $ writeByteArray bytearray n (fromSymbol x { size = m - n }))
-- | Emit a function application: the function symbol followed by @inner@ (the
-- arguments). The size given here is only a placeholder, because
-- 'emitSymbolBuilder' overwrites it.
emitApp :: Fun f -> Builder f -> Builder f
emitApp (F n) = emitSymbolBuilder (Symbol { isFun = True, index = n, size = 0 })
-- | Emit a variable: a single entry with no arguments.
emitVar :: Var -> Builder f
emitVar (V x) =
  emitSymbolBuilder (Symbol { isFun = False, index = x, size = 1 }) mempty
-- | Emit an existing 'TermList' by copying its entries unchanged into the
-- target array at the current index.
emitTermList :: TermList f -> Builder f
emitTermList (TermList lo hi array) =
  Builder $ checked len $
    getByteArray $ \mbytearray ->
    getIndex $ \n ->
    liftST (copyByteArray mbytearray (n*k) array (lo*k) (len*k)) `then_`
    putIndex (n + len)
  where
    -- Number of entries to copy.
    len = hi - lo
    -- Bytes per packed entry (sizeOf does not evaluate its argument).
    k = sizeOf (fromSymbol undefined)
-- | Do the entries of the term appear as a contiguous run inside the list?
isSubtermOfList :: Term f -> TermList f -> Bool
isSubtermOfList t = isSubArrayOf (singleton t)
-- | @isSubArrayOf t u@: do the entries of @t@ appear as a contiguous run
-- somewhere in @u@? Tries each starting position in @u@, one entry at a time.
isSubArrayOf :: TermList f -> TermList f -> Bool
isSubArrayOf t u
  | lenList t > lenList u = False
  | otherwise = prefixOf t u || skipOne u
  where
    -- Is the first list a prefix of the second? The second list is decoded
    -- unchecked, which is safe because it is never shorter than the first.
    prefixOf Empty _ = True
    prefixOf (ConsSym x xs) (UnsafeConsSym y ys) =
      root x == root y && prefixOf xs ys
    -- Try again starting one entry later.
    skipOne (UnsafeConsSym _ rest) = isSubArrayOf t rest
-- | Does the variable occur anywhere in the list? A variable is always stored
-- as a single entry of size 1, so this searches for that exact packed value.
occursList :: Var -> TermList f -> Bool
occursList (V x) =
  symbolOccursList (fromSymbol (Symbol { isFun = False, index = x, size = 1 }))
-- | Is the given packed symbol equal to any entry of the list? Every entry is
-- scanned, not only the roots of top-level terms.
symbolOccursList :: Int64 -> TermList f -> Bool
symbolOccursList !sym ts =
  case patHead ts of
    Nothing -> False
    Just (t, rest, _) -> root t == sym || symbolOccursList sym rest