-- Terms and substitutions, implemented using flatterms.
-- This module contains all the low-level icky bits
-- and provides primitives for building higher-level stuff.
{-# LANGUAGE CPP, PatternSynonyms, ViewPatterns,
    MagicHash, UnboxedTuples, BangPatterns,
    RankNTypes, RecordWildCards, GeneralizedNewtypeDeriving, CPP #-}
{-# OPTIONS_GHC -O2 -fmax-worker-args=100 #-}
#ifdef USE_LLVM
{-# OPTIONS_GHC -fllvm #-}
#endif
module Twee.Term.Core where

import Data.Primitive(sizeOf)
#ifdef BOUNDS_CHECKS
import Data.Primitive.ByteArray.Checked
#else
import Data.Primitive.ByteArray
#endif
import Control.Monad.ST.Strict
import Data.Bits
import Data.Int
import GHC.Types(Int(..))
import GHC.Prim
import GHC.ST hiding (liftST)
import Data.Ord
import Data.Semigroup(Semigroup(..))

--------------------------------------------------------------------------------
-- Symbols. A symbol is a single function or variable in a flatterm.
--------------------------------------------------------------------------------

-- A single entry of a flatterm: a function or variable symbol, together
-- with the size of the term rooted at it.
data Symbol = Symbol
  { isFun :: Bool  -- is it a function (as opposed to a variable)?
  , index :: Int   -- the function's or variable's number
  , size  :: Int   -- number of symbols in the term rooted at this symbol
  }

-- Function symbols print as "f<index>=<size>"; variables print the same way
-- as the corresponding Var.
instance Show Symbol where
  show sym
    | isFun sym = concat ["f", show (index sym), "=", show (size sym)]
    | otherwise = show (V (index sym))

-- Convert symbols to/from Int64 for storage in flatterms.
-- The encoding:
--   * bits 0-30: size
--   * bit  31: 0 (variable) or 1 (function)
--   * bits 32-63: index
{-# INLINE toSymbol #-}
toSymbol :: Int64 -> Symbol
toSymbol n =
  Symbol
    { isFun = testBit n 31
      -- The index lives in the upper half. Shifting an Int64 right is
      -- arithmetic, so those 32 bits are read as a signed number.
    , index = fromIntegral (n `unsafeShiftR` 32)
    , size  = fromIntegral (n .&. sizeMask)
    }
  where
    -- Selects bits 0-30.
    sizeMask :: Int64
    sizeMask = 0x7fffffff

{-# INLINE fromSymbol #-}
-- Pack a Symbol into the Int64 layout that toSymbol decodes: size in the
-- low bits, the function flag in bit 31, the index in the upper 32 bits.
-- The parts are added, not masked, so a size that needs more than 31 bits
-- spills into the flag bit.
fromSymbol :: Symbol -> Int64
fromSymbol sym = sizePart + indexPart + funPart
  where
    sizePart, indexPart, funPart :: Int64
    sizePart  = fromIntegral (size sym)
    indexPart = fromIntegral (index sym) `unsafeShiftL` 32
    funPart   = fromIntegral (fromEnum (isFun sym)) `unsafeShiftL` 31

--------------------------------------------------------------------------------
-- Flatterms, or rather lists of terms.
--------------------------------------------------------------------------------

-- | @'TermList' f@ is a list of terms whose function symbols have type @f@.
-- It is either a 'Cons' or an 'Empty'. You can turn it into a @['Term' f]@
-- with 'Twee.Term.unpack'.

-- A TermList is a slice of an unboxed array of symbols, each symbol being
-- stored as an encoded Int64 (see fromSymbol).
data TermList f = TermList
  { low   :: {-# UNPACK #-} !Int        -- position of the first symbol
  , high  :: {-# UNPACK #-} !Int        -- position one past the last symbol
  , array :: {-# UNPACK #-} !ByteArray  -- the underlying symbol array
  }

-- | Index into a termlist.
--
-- The position counts symbols, not whole terms: the result is the subterm
-- rooted at the @n@-th symbol of the list. Calls 'error' if @n@ lies
-- outside the list.
at :: Int -> TermList f -> Term f
at n t
  | inBounds  = unsafeAt n t
  | otherwise = error "term index out of bounds"
  where
    inBounds = n >= 0 && low t + n < high t

-- | Index into a termlist, without bounds checking.
--
-- Skips the first @n@ symbols and returns the term rooted at the symbol
-- found there.
unsafeAt :: Int -> TermList f -> Term f
unsafeAt n ts =
  case ts { low = low ts + n } of
    UnsafeCons hd _ -> hd

{-# INLINE lenList #-}
-- | The length of (number of symbols in) a termlist.
lenList :: TermList f -> Int
lenList ts = high ts - low ts

-- | @'Term' f@ is a term whose function symbols have type @f@.
-- It is either a 'Var' or an 'App'.

-- A term is a special case of a termlist: the slice holding exactly the
-- term's symbols. Its root symbol is kept alongside in encoded form.
data Term f = Term
  { root     :: {-# UNPACK #-} !Int64         -- the encoded root symbol
  , termlist :: {-# UNPACK #-} !(TermList f)  -- all symbols of the term
  }

-- Terms are compared through their termlists. The root symbol is the first
-- symbol of that termlist, so it needs no separate check.
instance Eq (Term f) where
  Term _ ts == Term _ us = ts == us

-- Terms are ordered exactly as their underlying termlists.
instance Ord (Term f) where
  compare (Term _ ts) (Term _ us) = compare ts us

-- Pattern synonyms for termlists:
-- * Empty :: TermList f
--   Empty is the empty termlist.
-- * Cons t ts :: Term f -> TermList f -> TermList f
--   Cons t ts is the termlist t:ts.
-- * ConsSym t ts us :: Term f -> TermList f -> TermList f -> TermList f
--   ConsSym t ts us is like Cons t ts, but also binds us to the termlist
--   that additionally includes t's children
--   (operationally, us starts one symbol to the right in the termlist,
--   while ts starts one whole term to the right).
-- * UnsafeCons/UnsafeConsSym: like Cons and ConsSym but don't check
--   that the termlist is non-empty.

-- | Matches the empty termlist.
pattern Empty :: TermList f
pattern Empty <- (patHead -> Nothing)

-- | Matches a non-empty termlist, unpacking it into head and tail.
pattern Cons :: Term f -> TermList f -> TermList f
pattern Cons x xs <- (patHead -> Just (x, _, xs))

-- Each pair below covers every termlist: patHead either finds no head
-- (Empty) or finds one (Cons / ConsSym).
{-# COMPLETE Empty, Cons #-}
{-# COMPLETE Empty, ConsSym #-}

-- | Like 'Cons', but does not check that the termlist is non-empty. Use only if
-- you are sure the termlist is non-empty.
pattern UnsafeCons :: Term f -> TermList f -> TermList f
pattern UnsafeCons t ts <- (unsafePatHead -> (t, _, ts))

-- | Matches a non-empty termlist, unpacking it into its head, its tail, and
-- /everything except the root symbol of the head/.
-- Useful for iterating through terms one symbol at a time.
--
-- The fields are 'hd' (the first term), 'tl' (the terms after it) and
-- 'rest' (everything after the root symbol of 'hd', i.e. the arguments of
-- 'hd' followed by 'tl').
--
-- For example, if @ts@ is the termlist @[f(x,y), g(z)]@,
-- then @let ConsSym u vs us = ts@ results in the following bindings:
--
-- > u  = f(x,y)
-- > vs = [g(z)]
-- > us = [x, y, g(z)]
pattern ConsSym :: Term f -> TermList f -> TermList f -> TermList f
pattern ConsSym{hd, tl, rest} <- (patHead -> Just (hd, rest, tl))

-- | Like 'ConsSym', but does not check that the termlist is non-empty. Use only
-- if you are sure the termlist is non-empty.
pattern UnsafeConsSym :: Term f -> TermList f -> TermList f -> TermList f
pattern UnsafeConsSym{uhd, utl, urest} <- (unsafePatHead -> (uhd, urest, utl))

-- A helper for UnsafeCons/UnsafeConsSym.
-- Reads the symbol at position low and returns:
--   * the term rooted at that symbol,
--   * the list starting just after that symbol (its children, then the
--     remaining terms),
--   * the list starting just after the whole term.
-- Does not check that the termlist is non-empty.
{-# INLINE unsafePatHead #-}
unsafePatHead :: TermList f -> (Term f, TermList f, TermList f)
unsafePatHead ts@TermList{..} =
  (Term sym ts{high = end}, ts{low = low + 1}, ts{low = end})
  where
    !sym = indexByteArray array low
    -- Position just past the term rooted at sym.
    end = low + size (toSymbol sym)

-- A helper for Cons/ConsSym.
-- Returns Nothing for an empty termlist, and otherwise the same triple as
-- unsafePatHead.
{-# INLINE patHead #-}
patHead :: TermList f -> Maybe (Term f, TermList f, TermList f)
patHead t
  | lenList t == 0 = Nothing
  | otherwise      = Just (unsafePatHead t)

-- Pattern synonyms for single terms.
-- * Var :: Var -> Term f
-- * App :: Fun f -> TermList f -> Term f

-- | A function symbol. @f@ is the underlying type of function symbols defined
-- by the user; @'Fun' f@ is an @f@ together with an automatically-generated unique number.
newtype Fun f = F
  { -- | The unique number of a 'Fun'. Must fit in 32 bits.
    fun_id :: Int
  }
  deriving (Eq, Ord)

-- | A variable.
newtype Var = V
  { -- | The variable's number.
    -- Don't use huge variable numbers:
    -- they will be truncated to 32 bits when stored in a term.
    var_id :: Int
  }
  deriving (Eq, Ord, Enum)
-- Variables print as "x" followed by their number, e.g. x3.
instance Show Var where
  show (V n) = 'x' : show n

-- | Matches a variable.
pattern Var :: Var -> Term f
pattern Var v <- (patTerm -> Left v)

-- | Matches a function application.
pattern App :: Fun f -> TermList f -> Term f
pattern App fun args <- (patTerm -> Right (fun, args))

-- patTerm always returns Left or Right, so these two cover every term.
{-# COMPLETE Var, App #-}

-- A helper function for Var and App.
-- Decodes the root symbol. A variable yields its number; a function yields
-- the function together with its arguments, which are the symbols after
-- the root within this term's slice.
{-# INLINE patTerm #-}
patTerm :: Term f -> Either Var (Fun f, TermList f)
patTerm (Term r syms)
  | isFun sym = Right (F (index sym), args)
  | otherwise = Left (V (index sym))
  where
    sym = toSymbol r
    !UnsafeConsSym{urest = args} = syms

-- | Convert a term to a termlist.
-- A term already carries the one-element termlist holding its symbols.
{-# INLINE singleton #-}
singleton :: Term f -> TermList f
singleton = termlist

-- Two termlists are equal exactly when the Ord instance below finds no
-- difference, i.e. same length and same encoded symbols.
instance Eq (TermList f) where
  t == u =
    case compare t u of
      EQ -> True
      _  -> False

-- Termlists are ordered first by length, then by comparing the raw bytes of
-- their encoded symbols. This is a cheap total order that agrees with
-- equality; it is not meant to be a meaningful term ordering.
instance Ord (TermList f) where
  {-# INLINE compare #-}
  compare t u =
    case compare (lenList t) (lenList u) of
      EQ    -> compareByteArrays (array t) (low t * width)
                                 (array u) (low u * width)
                                 (lenList t * width)
      other -> other
    where
      -- Bytes per encoded symbol.
      width :: Int
      width = sizeOf (0 :: Int64)

--------------------------------------------------------------------------------
-- Building terms.
--------------------------------------------------------------------------------

-- | A monoid for building terms.
-- 'mempty' represents the empty termlist, while 'mappend' appends two termlists.
newtype Builder f = Builder
  { -- Takes: the term array and size, and current position in the term.
    -- Returns the final position, which may be out of bounds.
    unBuilder :: forall s. Builder1 s f
  }

-- One raw building step. Arguments: the state token, the mutable term
-- array, its capacity in symbols, and the current write position (also in
-- symbols). Result: the new state token and the position after the step.
type Builder1 s f = State# s -> MutableByteArray# s -> Int# -> Int# -> (# State# s, Int# #)

-- Appending two builders runs them one after the other.
instance Semigroup (Builder f) where
  {-# INLINE (<>) #-}
  b1 <> b2 = Builder (unBuilder b1 `then_` unBuilder b2)
-- The empty builder writes nothing and leaves the position unchanged.
instance Monoid (Builder f) where
  {-# INLINE mempty #-}
  mempty = Builder built
  {-# INLINE mappend #-}
  mappend b1 b2 = b1 <> b2

-- Build a termlist from a Builder.
-- Works by guessing an appropriate size, and retrying if that was too small.
{-# INLINE buildTermList #-}
buildTermList :: Builder f -> TermList f
buildTermList builder = runST $ do
  let
    -- Bytes occupied by one encoded symbol.
    symBytes :: Int
    symBytes = sizeOf (0 :: Int64)

    -- Run the builder into a fresh array with room for n symbols. The
    -- builder reports how many symbols it needed. If that fits, shrink the
    -- array to exactly that size and freeze it; otherwise retry with twice
    -- the reported requirement.
    loop n@(I# n#) = do
      mbytes@(MutableByteArray mbytearray#) <- newByteArray (n * symBytes)
      n' <- ST $ \s ->
        case unBuilder builder s mbytearray# n# 0# of
          (# s', used# #) -> (# s', I# used# #)
      if n' <= n
        then do
          -- resizeMutableByteArray may hand back a different array, and the
          -- array passed in must not be used afterwards, so freeze the
          -- array that it returns rather than the original.
          shrunk <- resizeMutableByteArray mbytes (n' * symBytes)
          !bytearray <- unsafeFreezeByteArray shrunk
          return (TermList 0 n' bytearray)
        else loop (n' * 2)
  loop 128

-- Get at the term array.
-- Hands the mutable term array, boxed as a MutableByteArray, to the
-- continuation and runs the step it returns with the same arguments.
{-# INLINE getByteArray #-}
getByteArray :: (MutableByteArray s -> Builder1 s f) -> Builder1 s f
getByteArray k s arr# cap pos = k (MutableByteArray arr#) s arr# cap pos

-- Get at the array size.
-- Hands the array capacity (in symbols) to the continuation and runs the
-- step it returns with the same arguments.
{-# INLINE getSize #-}
getSize :: (Int -> Builder1 s f) -> Builder1 s f
getSize k s arr# cap pos = k (I# cap) s arr# cap pos

-- Get at the current array index.
-- Hands the current write position to the continuation and runs the step
-- it returns with the same arguments.
{-# INLINE getIndex #-}
getIndex :: (Int -> Builder1 s f) -> Builder1 s f
getIndex k s arr# cap pos = k (I# pos) s arr# cap pos

-- Change the current array index.
-- Ignores the array and the old position and finishes at the given index.
{-# INLINE putIndex #-}
putIndex :: Int -> Builder1 s f
putIndex n =
  case n of
    I# pos -> \s _ _ _ -> (# s, pos #)

-- Lift an ST computation into a builder.
-- The computation runs only for its effect; the position is unchanged.
{-# INLINE liftST #-}
liftST :: ST s () -> Builder1 s f
liftST (ST act) s _ _ pos =
  case act s of
    (# s', () #) -> (# s', pos #)

-- Finish building.
-- Writes nothing and finishes at the current position.
{-# INLINE built #-}
built :: Builder1 s f
built s _ _ pos = (# s, pos #)

-- Sequence two builder operations.
-- The second step starts at the position where the first one finished.
{-# INLINE then_ #-}
then_ :: Builder1 s f -> Builder1 s f -> Builder1 s f
then_ step1 step2 s arr# cap pos =
  case step1 s arr# cap pos of
    (# s', pos' #) -> step2 s' arr# cap pos'

-- checked j m executes m only if the array has room for j more symbols.
-- Otherwise m is skipped and the position simply advances by j. The final
-- position then exceeds the capacity, which tells buildTermList to retry
-- with a bigger array.
{-# INLINE checked #-}
checked :: Int -> Builder1 s f -> Builder1 s f
checked j m =
  getSize $ \cap ->
  getIndex $ \pos ->
    if pos + j > cap
      then putIndex (pos + j)
      else m

-- Emit an arbitrary symbol, with given arguments.
-- One slot is reserved for the symbol, the arguments are emitted after it,
-- and the symbol is written last, with its size set to the number of slots
-- used by the symbol and its arguments together.
{-# INLINE emitSymbolBuilder #-}
emitSymbolBuilder :: Symbol -> Builder f -> Builder f
emitSymbolBuilder x inner =
  Builder (checked 1 (
    getByteArray $ \arr ->
    getIndex $ \start ->
      -- Skip the symbol slot for now: its size is only known once the
      -- arguments have been emitted.
      putIndex (start + 1) `then_`
      unBuilder inner `then_`
      -- Fill in the symbol.
      getIndex (\end ->
        liftST $ writeByteArray arr start $ fromSymbol x { size = end - start })))

-- Emit a function application.
-- The size given here is a placeholder; emitSymbolBuilder overwrites it.
{-# INLINE emitApp #-}
emitApp :: Fun f -> Builder f -> Builder f
emitApp fun args =
  emitSymbolBuilder (Symbol { isFun = True, index = fun_id fun, size = 0 }) args

-- Emit a variable: a symbol with no arguments.
{-# INLINE emitVar #-}
emitVar :: Var -> Builder f
emitVar v =
  emitSymbolBuilder (Symbol { isFun = False, index = var_id v, size = 1 }) mempty

-- Emit a whole termlist.
-- Copies the slice's encoded symbols unchanged into the term array at the
-- current position, then advances the position by the slice length.
{-# INLINE emitTermList #-}
emitTermList :: TermList f -> Builder f
emitTermList (TermList start end src) =
  Builder (checked len (
    getByteArray $ \dest ->
    getIndex $ \pos ->
      liftST (copyByteArray dest (pos * width) src (start * width) (len * width))
        `then_` putIndex (pos + len)))
  where
    len, width :: Int
    -- Number of symbols to copy.
    len = end - start
    -- Bytes per encoded symbol.
    width = sizeOf (0 :: Int64)

----------------------------------------------------------------------
-- Efficient subterm testing.
----------------------------------------------------------------------

-- | Is a term contained as a subterm in a given termlist?
--
-- Slides a window as wide as the term over every symbol position of the
-- termlist and checks whether the window holds exactly the term's symbols.
{-# INLINE isSubtermOfList #-}
isSubtermOfList :: Term f -> TermList f -> Bool
isSubtermOfList t u = any matchesAt [0 .. lenList u - width]
  where
    needle = singleton t
    width = lenList needle
    matchesAt i = needle == u { low = low u + i, high = low u + i + width }

-- | Check if a variable occurs in a termlist.
-- Variables are always encoded with size 1, so the encoded variable symbol
-- can be searched for directly.
{-# INLINE occursList #-}
occursList :: Var -> TermList f -> Bool
occursList v =
  symbolOccursList (fromSymbol (Symbol { isFun = False, index = var_id v, size = 1 }))

-- Does the encoded symbol n appear anywhere in the termlist?
-- Scans one symbol at a time, comparing each encoded root with n.
symbolOccursList :: Int64 -> TermList f -> Bool
symbolOccursList !n ts =
  case ts of
    Empty -> False
    ConsSym{hd = t, rest = ts'} -> root t == n || symbolOccursList n ts'