{-# LANGUAGE NondecreasingIndentation #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts, FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE TypeOperators #-}
module What4.Utils.Serialize
(
withRounding
, makeSymbol
, asyncLinked
, withAsyncLinked
) where
import qualified Control.Exception as E
import Text.Printf ( printf )
import qualified Data.BitVector.Sized as BV
import What4.BaseTypes
import qualified What4.Interface as S
import What4.Symbol ( SolverSymbol, userSymbol )
import qualified UnliftIO as U
-- | Fork @action@ in a new 'U.Async', link it to the calling thread with
-- 'U.link', and return the handle.
--
-- Inside the child, @action@ is wrapped in 'handleUnliftIO'
-- 'threadKilledHandler'. As a result, a 'E.ThreadKilled' exception makes the
-- child finish normally with @()@ instead of failing, so it is not
-- propagated to the parent through the link. Any other 'E.AsyncException'
-- is rethrown.
--
-- The fork happens under 'U.mask' and @action@ itself runs inside @unmask@.
-- The intent appears to be that the handler is installed before @action@ can
-- be interrupted; this relies on the child inheriting the masked state
-- (NOTE(review): confirm against the async/unliftio docs). The link is
-- established after unmasking.
asyncLinked :: (U.MonadUnliftIO m) => m () -> m (U.Async ())
asyncLinked action =
  U.mask $ \unmask -> do
    child <- U.async (handleUnliftIO threadKilledHandler (unmask action))
    unmask (child <$ U.link child)
-- | Exception handler used around linked children: a 'E.ThreadKilled' is
-- swallowed (the handler returns @()@), and every other
-- 'E.AsyncException' is rethrown.
--
-- NOTE(review): the rethrow uses the pure 'E.throw' because only @Monad m@ is
-- available here. The exception is raised when the returned action is
-- evaluated, not via 'E.throwIO'.
threadKilledHandler :: Monad m => E.AsyncException -> m ()
threadKilledHandler ex = case ex of
  E.ThreadKilled -> return ()
  _              -> E.throw ex
-- | Scoped counterpart of 'asyncLinked', built on 'U.withAsync'.
--
-- Runs @child@ in an 'U.Async' whose lifetime is scoped by 'U.withAsync'.
-- The handle is linked to the calling thread with 'U.link' and then passed
-- to @parent@, whose result is returned.
--
-- As in 'asyncLinked', the child runs under 'handleUnliftIO'
-- 'threadKilledHandler', so a 'E.ThreadKilled' in the child is swallowed
-- rather than propagated through the link. The async is started under
-- 'U.mask'; the child body, the link and @parent@ all run inside @unmask@.
withAsyncLinked :: (U.MonadUnliftIO m) => m () -> (U.Async () -> m a) -> m a
withAsyncLinked child parent =
  U.mask $ \unmask ->
    U.withAsync (handleUnliftIO threadKilledHandler (unmask child)) $ \running ->
      unmask $ U.link running >> parent running
-- | Run @body@; if it throws an exception of type @e@, run @handler@ on it
-- instead.
--
-- This is base's 'E.handle' lifted to any 'U.MonadUnliftIO' monad by
-- unlifting both @body@ and @handler@ into 'IO'. Base's 'E.handle' is used
-- deliberately: the handlers exported by "UnliftIO" do not catch
-- asynchronous exceptions, whereas this function must be able to catch
-- 'E.AsyncException' (see 'threadKilledHandler').
handleUnliftIO :: (U.MonadUnliftIO m, U.Exception e)
               => (e -> m a) -> m a -> m a
handleUnliftIO handler body =
  U.withUnliftIO $ \unlifter ->
    E.handle
      (\ex -> U.unliftIO unlifter (handler ex))
      (U.unliftIO unlifter body)
-- | Convert an arbitrary 'String' into a 'SolverSymbol'.
--
-- Every space and period in @name@ is first replaced by an underscore, and
-- the result is validated with 'userSymbol'. If validation still fails, this
-- calls 'error' with a message containing both the original and the
-- sanitized name.
makeSymbol :: String -> SolverSymbol
makeSymbol name = either rejected id (userSymbol sanitizedName)
  where
    sanitizedName :: String
    sanitizedName = map replaceSeparator name

    -- ' ' and '.' become '_'; every other character is kept unchanged.
    replaceSeparator ch
      | ch == ' ' || ch == '.' = '_'
      | otherwise              = ch

    -- The specific validation error from 'userSymbol' is discarded.
    rejected _ =
      error $ printf "tried to create symbol with bad name: %s (%s)" name sanitizedName
-- | Symbolically dispatch on a rounding mode encoded in a 2-bit bitvector.
--
-- @r@ is compared against the encodings from 'roundingModeToBits':
-- 0 = 'S.RNE', 1 = 'S.RTZ', 2 = 'S.RTP'. The results of @action@ for those
-- modes are combined in a nested 'S.iteM' / 'S.baseTypeIte' chain, tested
-- in that order. The result of @action@ 'S.RTN' is the final else-branch,
-- taken when none of the three comparisons holds. 'S.RNA' is never passed
-- to @action@.
withRounding
  :: forall sym tp
   . S.IsExprBuilder sym
  => sym
  -> S.SymBV sym 2
  -> (S.RoundingMode -> IO (S.SymExpr sym tp))
  -> IO (S.SymExpr sym tp)
withRounding sym r action = do
  -- All three comparison predicates are built up front, in mode order,
  -- before any branch of the ite chain is constructed.
  conds <- mapM roundingCond explicitModes
  foldr branch (action S.RTN) (zip conds explicitModes)
  where
    -- Modes tested explicitly, outermost first; 'S.RTN' is the fallback.
    explicitModes :: [S.RoundingMode]
    explicitModes = [S.RNE, S.RTZ, S.RTP]

    -- Symbolic "if cond then action mode else rest".
    branch (cond, mode) rest = S.iteM S.baseTypeIte sym cond (action mode) rest

    -- Predicate: r equals the 2-bit encoding of rm.
    roundingCond :: S.RoundingMode -> IO (S.Pred sym)
    roundingCond rm = do
      lit <- S.bvLit sym knownNat (BV.mkBV knownNat (roundingModeToBits rm))
      S.bvEq sym r lit
-- | The 2-bit value used to represent each rounding mode in 'withRounding'.
--
-- 'S.RNA' has no encoding; asking for it calls 'error'.
roundingModeToBits :: S.RoundingMode -> Integer
roundingModeToBits S.RNE = 0
roundingModeToBits S.RTZ = 1
roundingModeToBits S.RTP = 2
roundingModeToBits S.RTN = 3
roundingModeToBits S.RNA = error ("unsupported rounding mode: " ++ show S.RNA)