{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DerivingVia         #-}
{-# LANGUAGE TypeApplications    #-}

module ZkFold.Symbolic.Data.Ord (Ord (..), Lexicographical (..), circuitGE, circuitGT, getBitsBE) where

import qualified Data.Bool                                              as Haskell
import           Prelude                                                (concatMap, reverse, zipWith, ($), (.))
import qualified Prelude                                                as Haskell

import           ZkFold.Base.Algebra.Basic.Class
import           ZkFold.Base.Algebra.Basic.Field                        (Zp)
import           ZkFold.Base.Algebra.Basic.Number                       (Prime)
import           ZkFold.Symbolic.Compiler
import           ZkFold.Symbolic.Compiler.ArithmeticCircuit.Combinators (boolCheckC)
import           ZkFold.Symbolic.Data.Bool                              (Bool (..), BoolType (..))
import           ZkFold.Symbolic.Data.Conditional                       (Conditional (..))
import           ZkFold.Symbolic.Data.DiscreteField                     (DiscreteField (..))

-- TODO (Issue #23): add `compare`

-- | Comparison of values of type @a@ with results in a boolean-like type @b@.
--
-- Unlike "Prelude"'s 'Haskell.Ord', the result type @b@ is a class parameter,
-- so the same values can be compared natively (@b ~ 'Haskell.Bool'@) or with
-- the answer encoded as a 'Bool' over a field or circuit (see the instances
-- in this module).
class Ord b a where
    -- | Less than or equal.
    (<=) :: a -> a -> b

    -- | Strictly less than.
    (<) :: a -> a -> b

    -- | Greater than or equal.
    (>=) :: a -> a -> b

    -- | Strictly greater than.
    (>) :: a -> a -> b

    -- | The larger of the two arguments.
    max :: a -> a -> a
    -- No default is given: it would need a @Conditional b a@ constraint for
    -- 'bool', which this class does not have. Sketch, using the same
    -- @bool onFalse onTrue cond@ argument order as the instances below:
    -- max x y = bool @b x y $ x <= y

    -- | The smaller of the two arguments.
    min :: a -> a -> a
    -- min x y = bool @b x y $ x >= y

-- | Native comparison: every method delegates to the corresponding
-- "Prelude" 'Haskell.Ord' operation and yields a plain 'Haskell.Bool'.
instance Haskell.Ord a => Ord Haskell.Bool a where
    x <= y = x Haskell.<= y
    x <  y = x Haskell.<  y
    x >= y = x Haskell.>= y
    x >  y = x Haskell.>  y

    max x y = Haskell.max x y
    min x y = Haskell.min x y

-- | Compares ordinary Haskell values and encodes the answer as a field
-- element wrapped in 'Bool': 'one' when the relation holds, 'zero' otherwise.
instance (Prime p, Haskell.Ord x) => Ord (Bool (Zp p)) x where
    x <= y = Bool (if x Haskell.<= y then one else zero)

    x <  y = Bool (if x Haskell.<  y then one else zero)

    x >= y = Bool (if x Haskell.>= y then one else zero)

    x >  y = Bool (if x Haskell.>  y then one else zero)

    -- The conditions below use the native 'Haskell.Bool' instance of 'Ord'
    -- (forced by @if@); on ties 'max' and 'min' both return the first argument.
    max x y = if x <= y then y else x

    min x y = if x >= y then y else x

newtype Lexicographical a = Lexicographical a
-- ^ A newtype wrapper for easy definition of Ord instances
-- (though not necessarily the most efficient one): a wrapped value is
-- compared bit by bit, using the bits produced by 'getBitsBE'.

-- The wrapper reuses the wrapped type's 'SymbolicData' instance unchanged,
-- which is what lets 'getBitsBE' see through it.
deriving newtype instance SymbolicData a x => SymbolicData a (Lexicographical x)

-- Arithmetic circuits obtain their 'Ord' instance by going through
-- 'Lexicographical', i.e. they are compared via their bit expansions.
deriving via (Lexicographical (ArithmeticCircuit a))
    instance Arithmetic a => Ord (Bool (ArithmeticCircuit a)) (ArithmeticCircuit a)

-- | Every @SymbolicData@ type can be compared lexicographically.
--
-- Both operands are flattened with 'getBitsBE' (eldest bit first) and the two
-- bit lists are compared by 'circuitGE' / 'circuitGT'. The "less than"
-- relations are the "greater than" ones with the arguments swapped.
instance SymbolicData a x => Ord (Bool (ArithmeticCircuit a)) (Lexicographical x) where
    (<=) = Haskell.flip (>=)

    (<) = Haskell.flip (>)

    x >= y = circuitGE xBits yBits
      where
        xBits = getBitsBE x
        yBits = getBitsBE y

    x > y = circuitGT xBits yBits
      where
        xBits = getBitsBE x
        yBits = getBitsBE y

    -- 'bool' takes the false branch first and the true branch second.
    max x y =
        let yIsGreater = x < y
        in bool @(Bool (ArithmeticCircuit a)) x y yIsGreater

    min x y =
        let yIsSmaller = x > y
        in bool @(Bool (ArithmeticCircuit a)) x y yIsSmaller

getBitsBE :: SymbolicData a x => x -> [ArithmeticCircuit a]
-- ^ @getBitsBE x@ returns a list of circuits computing bits of @x@, eldest to
-- youngest. Each piece of @x@ (in 'pieces' order) is expanded with
-- 'binaryExpansion', that expansion is reversed, and the per-piece results
-- are concatenated.
getBitsBE x = [ bit | piece <- pieces x, bit <- reverse (binaryExpansion piece) ]

circuitGE :: Arithmetic a => [ArithmeticCircuit a] -> [ArithmeticCircuit a] -> Bool (ArithmeticCircuit a)
-- ^ Given two lists of bits of equal length, compares them lexicographically
-- and yields whether the first is greater than or equal to the second.
--
-- The comparison runs over the position-wise differences @x - y@, with 'dor'
-- combining mutually exclusive outcomes and 'boolCheckC' used as the check on
-- the last difference (see 'bitCheckGE'). NOTE: 'zipWith' silently truncates
-- to the shorter list if the lengths differ.
circuitGE xs ys = bitCheckGE dor boolCheckC deltas
  where
    deltas = zipWith (-) xs ys

circuitGT :: Arithmetic a => [ArithmeticCircuit a] -> [ArithmeticCircuit a] -> Bool (ArithmeticCircuit a)
-- ^ Given two lists of bits of equal length, compares them lexicographically
-- and yields whether the first is strictly greater than the second.
--
-- The comparison runs over the position-wise differences @x - y@, with 'dor'
-- combining mutually exclusive outcomes (see 'bitCheckGT'). NOTE: 'zipWith'
-- silently truncates to the shorter list if the lengths differ.
circuitGT xs ys = bitCheckGT dor deltas
  where
    deltas = zipWith (-) xs ys

dor
    :: Arithmetic a
    => Bool (ArithmeticCircuit a)
    -> Bool (ArithmeticCircuit a)
    -> Bool (ArithmeticCircuit a)
-- ^ @dor a b@ is a schema which computes @a || b@ given @a && b@ is false.
-- It simply adds the two underlying circuits. Assuming both inputs encode
-- booleans as 0/1, the sum of two values that are never both 1 is again 0/1.
dor (Bool lhs) (Bool rhs) = Bool (lhs + rhs)

bitCheckGE :: DiscreteField b x => (b -> b -> b) -> (x -> x) -> [x] -> b
-- ^ @bitCheckGE pl bc ds@ checks if @ds@ contains delta lexicographically
-- greater than or equal to 0, given @pl a b = a || b@ when @a && b@ is false
-- and @bc d = d (d - 1)@.
--
-- Scanning from the first delta: a delta equal to @one@ decides the result
-- immediately, a zero delta defers to the remaining deltas, and the last
-- delta only needs @bc d@ to be zero. An empty list yields 'true'.
-- (Deltas are presumably differences of bits, i.e. in {-1, 0, 1}.)
bitCheckGE pl bc = go
  where
    go []       = true
    go [d]      = isZero (bc d)
    go (d : ds) = pl (isZero (d - one)) (isZero d && go ds)

bitCheckGT :: DiscreteField b x => (b -> b -> b) -> [x] -> b
-- ^ @bitCheckGT pl ds@ checks if @ds@ contains delta lexicographically greater
-- than 0, given @pl a b = a || b@ when @a && b@ is false.
--
-- Scanning from the first delta: a delta equal to @one@ decides the result
-- immediately, a zero delta defers to the remaining deltas, and the last
-- delta must itself equal @one@. An empty list yields 'false'.
bitCheckGT pl = go
  where
    go []       = false
    go [d]      = isZero (d - one)
    go (d : ds) = pl (isZero (d - one)) (isZero d && go ds)