{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

-- |
-- Module      :   Grisette.Internal.SymPrim.Prim.Internal.Instances.PEvalRotateTerm
-- Copyright   :   (c) Sirui Lu 2024
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Internal.SymPrim.Prim.Internal.Instances.PEvalRotateTerm
  ( pevalFiniteBitsSymRotateRotateLeftTerm,
    pevalFiniteBitsSymRotateRotateRightTerm,
  )
where

import Data.Bits (Bits (rotateR), FiniteBits (finiteBitSize))
import Data.Proxy (Proxy (Proxy))
import qualified Data.SBV as SBV
import GHC.TypeLits (KnownNat, type (<=))
import Grisette.Internal.Core.Data.Class.SymRotate (SymRotate (symRotate))
import Grisette.Internal.SymPrim.BV (IntN, WordN)
import Grisette.Internal.SymPrim.Prim.Internal.Instances.SupportedPrim (bvIsNonZeroFromGEq1)
import Grisette.Internal.SymPrim.Prim.Internal.Term
  ( PEvalRotateTerm (pevalRotateLeftTerm, pevalRotateRightTerm, sbvRotateLeftTerm, sbvRotateRightTerm, withSbvRotateTermConstraint),
    SupportedNonFuncPrim (withNonFuncPrim),
    Term (ConTerm),
    conTerm,
    rotateLeftTerm,
    rotateRightTerm,
  )
import Grisette.Internal.SymPrim.Prim.Internal.Unfold (unaryUnfoldOnce)

-- | Partially evaluate a left rotation of the term @t@ by the term @n@.
--
-- The rewrite rules of 'doPevalFiniteBitsSymRotateRotateLeftTerm' (with the
-- rotation amount @n@ fixed) are passed to 'unaryUnfoldOnce' as the partial
-- rule, and 'rotateLeftTerm' (again with @n@ fixed) as the total rule; the
-- combined rule is then applied to @t@.
--
-- NOTE(review): how 'unaryUnfoldOnce' combines the partial and the total rule
-- is defined in the @Unfold@ module and is not visible from this file;
-- presumably the total rule is the fallback used when no rewrite applies.
pevalFiniteBitsSymRotateRotateLeftTerm ::
  forall a.
  (Integral a, SymRotate a, FiniteBits a, PEvalRotateTerm a) =>
  Term a ->
  Term a ->
  Term a
pevalFiniteBitsSymRotateRotateLeftTerm t n =
  unaryUnfoldOnce
    (`doPevalFiniteBitsSymRotateRotateLeftTerm` n)
    (`rotateLeftTerm` n)
    t

-- | Rewrite rules for a left rotation; returns 'Nothing' when none applies.
--
-- The first argument is the rotated term, the second the rotation amount.
-- Rules are tried in order:
--
-- 1. Both arguments are concrete and the amount is non-negative: fold the
--    rotation into a concrete term using 'symRotate'.
-- 2. The amount is the concrete value @0@: the rotated term is returned
--    unchanged.
-- 3. The amount is concrete, non-negative and at least the bit width of the
--    type: reduce the amount modulo the bit width and re-run
--    'pevalFiniteBitsSymRotateRotateLeftTerm' with the reduced amount.
doPevalFiniteBitsSymRotateRotateLeftTerm ::
  forall a.
  (Integral a, SymRotate a, FiniteBits a, PEvalRotateTerm a) =>
  Term a ->
  Term a ->
  Maybe (Term a)
doPevalFiniteBitsSymRotateRotateLeftTerm (ConTerm _ a) (ConTerm _ n)
  | n >= 0 = Just $ conTerm $ symRotate a n -- Just $ conTerm $ rotateL a (fromIntegral n)
doPevalFiniteBitsSymRotateRotateLeftTerm x (ConTerm _ 0) = Just x
-- doPevalFiniteBitsSymRotateRotateLeftTerm (RotateLeftTerm _ x (ConTerm _ n)) (ConTerm _ n1)
--   | n >= 0 && n1 >= 0 = Just $ pevalFiniteBitsSymRotateRotateLeftTerm x (conTerm $ n + n1)
doPevalFiniteBitsSymRotateRotateLeftTerm x (ConTerm _ n)
  -- The comparison against the bit width is done in 'Integer' so that neither
  -- side is truncated to the bit-vector type before being compared.
  | n >= 0 && (fromIntegral n :: Integer) >= fromIntegral bs =
      Just $
        pevalFiniteBitsSymRotateRotateLeftTerm
          x
          (conTerm $ n `mod` fromIntegral bs)
  where
    -- Bit width of the rotation-amount type (same type as the rotated value).
    bs = finiteBitSize n
doPevalFiniteBitsSymRotateRotateLeftTerm _ _ = Nothing

-- | Partially evaluate a right rotation of the term @t@ by the term @n@.
--
-- The rewrite rules of 'doPevalFiniteBitsSymRotateRotateRightTerm' (with the
-- rotation amount @n@ fixed) are passed to 'unaryUnfoldOnce' as the partial
-- rule, and 'rotateRightTerm' (again with @n@ fixed) as the total rule; the
-- combined rule is then applied to @t@.
--
-- NOTE(review): how 'unaryUnfoldOnce' combines the partial and the total rule
-- is defined in the @Unfold@ module and is not visible from this file;
-- presumably the total rule is the fallback used when no rewrite applies.
pevalFiniteBitsSymRotateRotateRightTerm ::
  forall a.
  (Integral a, SymRotate a, FiniteBits a, PEvalRotateTerm a) =>
  Term a ->
  Term a ->
  Term a
pevalFiniteBitsSymRotateRotateRightTerm t n =
  unaryUnfoldOnce
    (`doPevalFiniteBitsSymRotateRotateRightTerm` n)
    (`rotateRightTerm` n)
    t

-- | Rewrite rules for a right rotation; returns 'Nothing' when none applies.
--
-- The first argument is the rotated term, the second the rotation amount.
-- Rules are tried in order:
--
-- 1. Both arguments are concrete and the amount is non-negative: fold the
--    rotation into a concrete term using 'rotateR'.
-- 2. The amount is the concrete value @0@: the rotated term is returned
--    unchanged.
-- 3. The amount is concrete, non-negative and at least the bit width of the
--    type: reduce the amount modulo the bit width and re-run
--    'pevalFiniteBitsSymRotateRotateRightTerm' with the reduced amount.
doPevalFiniteBitsSymRotateRotateRightTerm ::
  forall a.
  (Integral a, SymRotate a, FiniteBits a, PEvalRotateTerm a) =>
  Term a ->
  Term a ->
  Maybe (Term a)
doPevalFiniteBitsSymRotateRotateRightTerm (ConTerm _ a) (ConTerm _ n)
  | n >= 0 =
      Just . conTerm $
        rotateR
          a
          -- The amount is reduced modulo the bit width in 'Integer' first, so
          -- the value converted for 'rotateR' lies in @[0, bit width)@.
          ( fromIntegral $
              (fromIntegral n :: Integer)
                `mod` fromIntegral (finiteBitSize n)
          )
doPevalFiniteBitsSymRotateRotateRightTerm x (ConTerm _ 0) = Just x
-- doPevalFiniteBitsSymRotateRotateRightTerm (RotateRightTerm _ x (ConTerm _ n)) (ConTerm _ n1)
--   | n >= 0 && n1 >= 0 = Just $ pevalFiniteBitsSymRotateRotateRightTerm x (conTerm $ n + n1)
doPevalFiniteBitsSymRotateRotateRightTerm x (ConTerm _ n)
  -- The comparison against the bit width is done in 'Integer' so that neither
  -- side is truncated to the bit-vector type before being compared.
  | n >= 0 && (fromIntegral n :: Integer) >= fromIntegral bs =
      Just $
        pevalFiniteBitsSymRotateRotateRightTerm
          x
          (conTerm $ n `mod` fromIntegral bs)
  where
    -- Bit width of the rotation-amount type (same type as the rotated value).
    bs = finiteBitSize n
doPevalFiniteBitsSymRotateRotateRightTerm _ _ = Nothing

-- | Rotation support for signed bit-vectors of non-zero width.
--
-- Partial evaluation is delegated to the generic 'FiniteBits'-based helpers
-- of this module; the SBV lowering is overridden to route through unsigned
-- words (see the comment on 'sbvRotateLeftTerm' below).
instance (KnownNat n, 1 <= n) => PEvalRotateTerm (IntN n) where
  pevalRotateLeftTerm = pevalFiniteBitsSymRotateRotateLeftTerm
  pevalRotateRightTerm = pevalFiniteBitsSymRotateRotateRightTerm

  -- Obtain the non-zero-width evidence from @1 <= n@, then run the
  -- continuation @r@ under the constraints supplied by 'withNonFuncPrim'.
  withSbvRotateTermConstraint p r =
    bvIsNonZeroFromGEq1 (Proxy @n) $
      withNonFuncPrim @(IntN n) p r

  -- SBV's rotateLeft and rotateRight are broken for signed values
  -- (https://github.com/LeventErkok/sbv/issues/673), so both operands are
  -- converted to unsigned words of the same width, rotated there, and the
  -- result is converted back to the signed type.
  sbvRotateLeftTerm p l r =
    withNonFuncPrim @(IntN n) p $
      withSbvRotateTermConstraint @(IntN n) p $
        SBV.sFromIntegral $
          SBV.sRotateLeft
            (SBV.sFromIntegral l :: SBV.SWord n)
            (SBV.sFromIntegral r :: SBV.SWord n)

  -- Same unsigned round trip as 'sbvRotateLeftTerm', for right rotation.
  sbvRotateRightTerm p l r =
    withNonFuncPrim @(IntN n) p $
      withSbvRotateTermConstraint @(IntN n) p $
        SBV.sFromIntegral $
          SBV.sRotateRight
            (SBV.sFromIntegral l :: SBV.SWord n)
            (SBV.sFromIntegral r :: SBV.SWord n)

-- | Rotation support for unsigned bit-vectors of non-zero width.
--
-- Partial evaluation is delegated to the generic 'FiniteBits'-based helpers
-- of this module. No SBV rotation methods are defined in the part of the
-- instance visible here.
--
-- NOTE(review): presumably the class defaults for 'sbvRotateLeftTerm' and
-- 'sbvRotateRightTerm' are used for the unsigned case; confirm against the
-- class definition.
instance (KnownNat n, 1 <= n) => PEvalRotateTerm (WordN n) where
  pevalRotateLeftTerm = pevalFiniteBitsSymRotateRotateLeftTerm
  pevalRotateRightTerm = pevalFiniteBitsSymRotateRotateRightTerm

  -- Obtain the non-zero-width evidence from @1 <= n@, then run the
  -- continuation @r@ under the constraints supplied by 'withNonFuncPrim'.
  withSbvRotateTermConstraint p r =
    bvIsNonZeroFromGEq1 (Proxy @n) $
      withNonFuncPrim @(WordN n) p r