{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}

module Jikka.Core.Language.ModuloExpr
  ( -- * Basic functions
    ModuloExpr,
    parseModuloExpr,
    formatTopModuloExpr,
    formatBottomModuloExpr,
    integerModuloExpr,
    negateModuloExpr,
    plusModuloExpr,
    minusModuloExpr,
    multModuloExpr,
    isZeroModuloExpr,
    isOneModuloExpr,
    moduloOfModuloExpr,
    arithmeticExprFromModuloExpr,

    -- * Utilities
    Modulo (..),
    formatModulo,
    isModulo,
  )
where

import Data.List
import Jikka.Common.Error
import Jikka.Core.Format (formatExpr)
import Jikka.Core.Language.ArithmeticExpr
import Jikka.Core.Language.BuiltinPatterns
import Jikka.Core.Language.Expr
import Jikka.Core.Language.Runtime (modinv, modpow)
import Jikka.Core.Language.Util

-- | `Modulo` wraps the modulus (the right-hand operand) of a mod-op.
-- It is a newtype rather than a bare `Expr` so that the modulus and the
-- dividend of a mod-op cannot be swapped by mistake.
newtype Modulo = Modulo
  { unModulo :: Expr
  }
  deriving (Eq, Ord, Show, Read)

-- | Pretty-prints the modulus expression wrapped by a `Modulo`.
formatModulo :: Modulo -> String
formatModulo (Modulo modulusExpr) = formatExpr modulusExpr

-- | @isModulo e m@ is a conservative, purely syntactic check that @e@ is
-- already reduced modulo @m@. It recognises:
--
-- * mod-ops (@floormod@, @modnegate@, @modplus@, @modminus@, @modmult@,
--   @modinv@, @modpow@, their vector/matrix variants, @modsum@ and
--   @modproduct@) whose modulus is syntactically equal to @m@;
-- * integer literals @n@ with @0 <= n < m@, when @m@ is itself a literal;
-- * projections out of vector/matrix tuples, the function passed to @map@,
--   lambda bodies and applied lambdas, by recursing into them;
-- * vector/matrix tuple constructions whose components all pass the check.
--
-- Every other expression yields 'False', which means "not known to be
-- reduced" rather than "known not to be reduced".
isModulo :: Expr -> Modulo -> Bool
isModulo e0 (Modulo m) = go e0
  where
    sameModulo :: Expr -> Bool
    sameModulo m' = m' == m

    isTensorTy :: [Type] -> Bool
    isTensorTy ts = isVectorTy' ts || isMatrixTy' ts

    go :: Expr -> Bool
    go = \case
      FloorMod' _ m' -> sameModulo m'
      ModNegate' _ m' -> sameModulo m'
      ModPlus' _ _ m' -> sameModulo m'
      ModMinus' _ _ m' -> sameModulo m'
      ModMult' _ _ m' -> sameModulo m'
      ModInv' _ m' -> sameModulo m'
      ModPow' _ _ m' -> sameModulo m'
      VecFloorMod' _ _ m' -> sameModulo m'
      MatFloorMod' _ _ _ m' -> sameModulo m'
      ModMatAp' _ _ _ _ m' -> sameModulo m'
      ModMatAdd' _ _ _ _ m' -> sameModulo m'
      ModMatMul' _ _ _ _ _ m' -> sameModulo m'
      ModMatPow' _ _ _ m' -> sameModulo m'
      ModSum' _ m' -> sameModulo m'
      ModProduct' _ m' -> sameModulo m'
      LitInt' n -> case m of
        LitInt' bound -> 0 <= n && n < bound
        _ -> False
      Proj' ts _ inner | isTensorTy ts -> go inner
      Map' _ _ f _ -> go f
      Lam _ _ body -> go body
      app@(App _ _) -> case curryApp app of
        (lam@(Lam _ _ _), _) -> go lam
        (Tuple' ts, args) | isTensorTy ts -> all go args
        _ -> False
      _ -> False

-- | A single product term of a `ModuloExpr`.
--
-- It stands for @productExprConst@ times the product of 'productExprList',
-- divided by the product of 'productExprInvList'. The elements of
-- 'productExprInvList' are the factors that are divided by, as
-- 'invProductExpr' and the @ModInv'@ case of 'parseProductExpr' show.
data ProductExpr = ProductExpr
  { productExprConst :: Integer,
    productExprList :: [Expr],
    productExprInvList :: [Expr]
  }
  deriving (Eq, Ord, Show, Read)

-- | The sum-of-products body of a `ModuloExpr`: the sum of the product terms
-- in 'sumExprList', plus the integer constant 'sumExprConst'.
data SumExpr = SumExpr
  { sumExprList :: [ProductExpr],
    sumExprConst :: Integer
  }
  deriving (Eq, Ord, Show, Read)

-- | An integer expression in sum-of-products form, taken modulo 'modulo'.
-- Only 'Show' is derived. Equality and ordering are given by hand-written
-- instances that compare normalized forms instead of the raw representation.
data ModuloExpr = ModuloExpr
  { unModuloExpr :: SumExpr,
    modulo :: Modulo
  }
  deriving (Show)

-- | Two `ModuloExpr`s are equal when their normalized sums agree and they use
-- the same modulus.
instance Eq ModuloExpr where
  lhs == rhs =
    normalizedSum lhs == normalizedSum rhs
      && modulo lhs == modulo rhs
    where
      normalizedSum = unModuloExpr . normalizeModuloExpr

-- | Orders by the normalized sum first and by the modulus second. This is
-- consistent with the `Eq` instance, which compares the same two components.
instance Ord ModuloExpr where
  compare lhs rhs = compare (sortKey lhs) (sortKey rhs)
    where
      sortKey e = (unModuloExpr (normalizeModuloExpr e), modulo e)

-- | Negates a product term by flipping the sign of its constant factor only.
negateProductExpr :: ProductExpr -> ProductExpr
negateProductExpr (ProductExpr c factors invFactors) =
  ProductExpr (negate c) factors invFactors

-- | Multiplies two product terms. The constants are multiplied, and the
-- factor and inverse-factor lists are concatenated, left operand first.
multProductExpr :: ProductExpr -> ProductExpr -> ProductExpr
multProductExpr lhs rhs =
  ProductExpr
    (productExprConst lhs * productExprConst rhs)
    (productExprList lhs ++ productExprList rhs)
    (productExprInvList lhs ++ productExprInvList rhs)

-- | @powProductExpr m e k@ raises the product term @e@ to the @k@-th power
-- modulo @m@.
--
-- The constant is evaluated with 'modpow'. Every factor and inverse factor
-- @x@ is kept symbolically as @ModPow' x k m@. The 'modpow' result is
-- unwrapped with 'fromSuccess', so an error reported by 'modpow' presumably
-- aborts here; callers should only pass arguments 'modpow' accepts.
powProductExpr :: Integer -> ProductExpr -> Integer -> ProductExpr
powProductExpr m e k =
  ProductExpr
    { productExprConst = fromSuccess (modpow (productExprConst e) k m),
      productExprList = map raise (productExprList e),
      productExprInvList = map raise (productExprInvList e)
    }
  where
    raise x = ModPow' x (LitInt' k) (LitInt' m)

-- | Splits a list using a partial function.
--
-- Elements for which @f@ returns @Just y@ contribute @y@ to the first list.
-- All other elements are kept unchanged in the second list. Relative order
-- is preserved in both lists, and the result is built lazily: the recursive
-- part is only matched irrefutably.
partitionMaybe :: (a -> Maybe b) -> [a] -> ([b], [a])
partitionMaybe f = foldr step ([], [])
  where
    step x ~(hits, misses) = case f x of
      Just y -> (y : hits, misses)
      Nothing -> (hits, x : misses)

-- | Extracts the value of an integer literal, or 'Nothing' for any other
-- expression.
fromLitInt :: Expr -> Maybe Integer
fromLitInt = \case
  LitInt' k -> Just k
  _ -> Nothing

-- | @invProductExpr m e@ builds the multiplicative inverse of the product
-- term @e@ modulo @m@, i.e. a term representing @1 / e@.
--
-- The factors and inverse factors of @e@ swap roles. Integer literals among
-- the inverse factors are folded into the constant. The constant of @e@ is
-- inverted numerically with 'modinv' when @m@ is an integer literal and
-- 'modinv' succeeds. Otherwise it is kept symbolically as an inverse factor.
invProductExpr :: Modulo -> ProductExpr -> ProductExpr
invProductExpr m e =
  case m of
    Modulo (LitInt' n)
      | Right k <- modinv (reduce n c) n ->
        ProductExpr
          { productExprConst = (k * product invKs) `mod` n,
            productExprList = invList,
            productExprInvList = productExprList e
          }
    _ -> symbolic
  where
    c = productExprConst e
    -- A literal inverse factor 1/k turns into a plain multiplication by k.
    (invKs, invList) = partitionMaybe fromLitInt (productExprInvList e)
    -- Fallback when the constant cannot be inverted numerically: 1/c must
    -- remain a divisor, so c goes into the inverse-factor list, not the
    -- factor list (putting it in the factor list would multiply by c).
    symbolic =
      ProductExpr
        { productExprConst = product invKs,
          productExprList = invList,
          productExprInvList = LitInt' c : productExprList e
        }
    -- Reduce the constant into [0, n) before inverting, when that is
    -- well-defined. Then constants such as -1 are not rejected just for
    -- being out of range. The final constant is taken `mod` n anyway, so
    -- when modinv already succeeded on c the result is unchanged.
    -- NOTE(review): assumes modinv's answer depends only on c mod n.
    reduce n x = if n > 0 then x `mod` n else x

-- | Whether the modulus is an integer literal.
isInteger :: Modulo -> Bool
isInteger = \case
  Modulo (LitInt' _) -> True
  _ -> False

-- | Extracts the value of a modulus that is an integer literal.
--
-- Callers are expected to check 'isInteger' first, as 'parseProductExpr'
-- does. Any other modulus is an internal error. The message separates the
-- formatted modulus from the description with ": " so that it stays readable.
moduloToInteger :: Modulo -> Integer
moduloToInteger = \case
  Modulo (LitInt' m) -> m
  m -> error $ "Jikka.Core.Language.ModuloExpr.moduloToInteger: modulo is not an integer: " ++ formatModulo m

-- | @parseProductExpr m e@ reads @e@ as a single product term modulo @m@.
--
-- The following are decomposed:
--
-- * integer literals and negations;
-- * plain multiplications, and modular ones with modulus @m@;
-- * exact divisions;
-- * modular inverses with modulus @m@;
-- * powers by a literal exponent.
--
-- Anything else, including mod-ops whose modulus differs from @m@, becomes a
-- single opaque factor.
--
-- Powers are decomposed only when the exponent is non-negative and the
-- modulus is a positive literal. 'powProductExpr' evaluates the constant with
-- 'modpow' through 'fromSuccess', so other exponents or moduli could abort the
-- rewrite. NOTE(review): this follows the usual domain of modpow; confirm
-- against Jikka.Core.Language.Runtime. Such powers are kept as opaque factors
-- instead, which is always sound.
parseProductExpr :: Modulo -> Expr -> ProductExpr
parseProductExpr m = \case
  LitInt' n -> ProductExpr {productExprConst = n, productExprList = [], productExprInvList = []}
  Negate' e -> negateProductExpr (go e)
  Mult' e1 e2 -> multProductExpr (go e1) (go e2)
  JustDiv' e1 e2 -> multProductExpr (go e1) (invProductExpr m (go e2))
  Pow' e1 (LitInt' k)
    | k >= 0 && isInteger m && moduloToInteger m > 0 ->
      powProductExpr (moduloToInteger m) (go e1) k
  ModNegate' e m' | Modulo m' == m -> negateProductExpr (go e)
  ModMult' e1 e2 m' | Modulo m' == m -> multProductExpr (go e1) (go e2)
  ModPow' e1 (LitInt' k) (LitInt' m')
    | Modulo (LitInt' m') == m && k >= 0 && m' > 0 ->
      powProductExpr m' (go e1) k
  ModInv' e1 m' | Modulo m' == m -> invProductExpr m (go e1)
  e -> ProductExpr {productExprConst = 1, productExprList = [e], productExprInvList = []}
  where
    go = parseProductExpr m

-- | Lifts a single product term into a sum whose constant part is zero.
sumExprFromProductExpr :: ProductExpr -> SumExpr
sumExprFromProductExpr term = SumExpr [term] 0

-- | A sum made of the integer constant alone, with no product terms.
integerSumExpr :: Integer -> SumExpr
integerSumExpr = SumExpr []

-- | The integer constant @k@ as a `ModuloExpr` modulo @m@.
integerModuloExpr :: Modulo -> Integer -> ModuloExpr
integerModuloExpr m k =
  ModuloExpr
    { unModuloExpr = integerSumExpr k,
      modulo = m
    }

-- | Negates a sum by negating each product term and the constant.
negateSumExpr :: SumExpr -> SumExpr
negateSumExpr e =
  SumExpr (negateProductExpr <$> sumExprList e) (negate (sumExprConst e))

-- | Adds two sums. The term lists are concatenated (left operand first) and
-- the constants are added.
plusSumExpr :: SumExpr -> SumExpr -> SumExpr
plusSumExpr lhs rhs =
  SumExpr
    (sumExprList lhs ++ sumExprList rhs)
    (sumExprConst lhs + sumExprConst rhs)

-- | Multiplies two sums modulo @m@ by distributing
-- \((c_1 + \sum_i a_i)(c_2 + \sum_j b_j)\).
--
-- Each constant is turned into a product term and prepended to its term list.
-- All pairwise products are formed, left term outermost. The leading
-- constant-times-constant product is dropped, because that part is kept
-- separately as @c1 * c2@ in 'sumExprConst'.
multSumExpr :: Modulo -> SumExpr -> SumExpr -> SumExpr
multSumExpr m e1 e2 =
  SumExpr
    { sumExprList = drop 1 [multProductExpr x y | x <- lhsTerms, y <- rhsTerms],
      sumExprConst = sumExprConst e1 * sumExprConst e2
    }
  where
    -- Both lists start with the constant term, so they are never empty.
    lhsTerms = parseProductExpr m (LitInt' (sumExprConst e1)) : sumExprList e1
    rhsTerms = parseProductExpr m (LitInt' (sumExprConst e2)) : sumExprList e2

-- | Negates a `ModuloExpr`, keeping its modulus.
negateModuloExpr :: ModuloExpr -> ModuloExpr
negateModuloExpr e = e {unModuloExpr = negateSumExpr (unModuloExpr e)}

-- | Adds two `ModuloExpr`s. Returns 'Nothing' when their moduli differ.
plusModuloExpr :: ModuloExpr -> ModuloExpr -> Maybe ModuloExpr
plusModuloExpr (ModuloExpr e1 m1) (ModuloExpr e2 m2)
  | m1 /= m2 = Nothing
  | otherwise = Just (ModuloExpr (plusSumExpr e1 e2) m1)

-- | Subtracts the second `ModuloExpr` from the first, by adding its negation.
-- Returns 'Nothing' when the moduli differ.
minusModuloExpr :: ModuloExpr -> ModuloExpr -> Maybe ModuloExpr
minusModuloExpr e1 e2 = plusModuloExpr e1 (negateModuloExpr e2)

-- | Multiplies two `ModuloExpr`s. Returns 'Nothing' when their moduli differ.
multModuloExpr :: ModuloExpr -> ModuloExpr -> Maybe ModuloExpr
multModuloExpr (ModuloExpr e1 m1) (ModuloExpr e2 m2)
  | m1 == m2 = Just (ModuloExpr {unModuloExpr = multSumExpr m1 e1 e2, modulo = m1})
  | otherwise = Nothing

-- | @parseSumExpr m e@ flattens @e@ into a sum of product terms modulo @m@.
--
-- The following are expanded, either plain or as mod-ops with modulus @m@:
-- additions, subtractions, negations and multiplications. A @floormod@ by @m@
-- is dropped. Any other expression is handed to 'parseProductExpr' as a
-- single term.
parseSumExpr :: Modulo -> Expr -> SumExpr
parseSumExpr m = \case
  LitInt' n -> integerSumExpr n
  Negate' e -> negateSumExpr (go e)
  Plus' e1 e2 -> plus e1 e2
  Minus' e1 e2 -> minus e1 e2
  Mult' e1 e2 -> mult e1 e2
  FloorMod' e m' | sameModulo m' -> go e
  ModNegate' e m' | sameModulo m' -> negateSumExpr (go e)
  ModPlus' e1 e2 m' | sameModulo m' -> plus e1 e2
  ModMinus' e1 e2 m' | sameModulo m' -> minus e1 e2
  ModMult' e1 e2 m' | sameModulo m' -> mult e1 e2
  e -> sumExprFromProductExpr (parseProductExpr m e)
  where
    go = parseSumExpr m
    sameModulo m' = Modulo m' == m
    plus e1 e2 = plusSumExpr (go e1) (go e2)
    minus e1 e2 = plusSumExpr (go e1) (negateSumExpr (go e2))
    mult e1 e2 = multSumExpr m (go e1) (go e2)

-- | The modulus of an outermost @floormod@, @modnegate@, @modplus@,
-- @modminus@ or @modmult@, or 'Nothing' for any other expression.
getModuloFromExpr :: Expr -> Maybe Modulo
getModuloFromExpr e = Modulo <$> modulusOf e
  where
    modulusOf = \case
      FloorMod' _ m -> Just m
      ModNegate' _ m -> Just m
      ModPlus' _ _ m -> Just m
      ModMinus' _ _ m -> Just m
      ModMult' _ _ m -> Just m
      _ -> Nothing

-- | `parseModuloExpr` converts a given expr to a normal form \((\sum_i \prod_j e_{i,j}) \bmod m\).
-- This assumes given exprs have the type \(\mathbf{int}\).
-- It returns `Nothing` when the root of the expr is not a modular operation.
parseModuloExpr :: Expr -> Maybe ModuloExpr
parseModuloExpr e = build <$> getModuloFromExpr e
  where
    -- parse the whole expr as a sum under the modulus found at its root
    build m = ModuloExpr (parseSumExpr m e) m

-- | `formatTopProductExpr` renders a `ProductExpr` as a plain integer product
-- with no intermediate `FloorMod'` reductions.
--
-- The coefficient is attached last: @0@ collapses the whole product, @1@ is
-- dropped, @-1@ becomes `Negate'`, and any other value becomes a `Mult'`.
-- Factors in `productExprInvList` stand for modular inverses, so they are
-- rendered with `ModInv'`. Rendering them with `FloorMod'` would silently turn
-- \(x^{-1}\) into \(x\).
formatTopProductExpr :: Modulo -> ProductExpr -> Expr
formatTopProductExpr m e =
  let c = productExprConst e
      -- attach the coefficient to a non-empty product of factors
      scale :: Expr -> Expr
      scale e' = case c of
        0 -> LitInt' 0
        1 -> e'
        -1 -> Negate' e'
        _ -> Mult' e' (LitInt' c)
      invList :: [Expr]
      invList = map (`ModInv'` unModulo m) (productExprInvList e)
   in case productExprList e ++ invList of
        [] -> LitInt' c
        eHead : esTail -> scale (foldl' Mult' eHead esTail)

-- | `formatTopSumExpr` renders a `SumExpr` as a plain integer expression.
--
-- The first term is rendered with its own sign. Each following term is joined
-- with `Plus'` or `Minus'` according to the sign of its coefficient and is
-- rendered with the absolute value of that coefficient; tail terms with a zero
-- coefficient are skipped. The constant, if non-zero, is appended last.
formatTopSumExpr :: Modulo -> SumExpr -> Expr
formatTopSumExpr m e = case sumExprList e of
  [] -> LitInt' c
  t : ts -> addConst (foldl' step (formatTopProductExpr m t) ts)
  where
    c = sumExprConst e
    -- render a term without its sign; the sign is expressed by Plus'/Minus'
    unsigned t = formatTopProductExpr m (t {productExprConst = abs (productExprConst t)})
    step acc t = case compare (productExprConst t) 0 of
      GT -> Plus' acc (unsigned t)
      LT -> Minus' acc (unsigned t)
      EQ -> acc
    addConst acc = case compare c 0 of
      GT -> Plus' acc (LitInt' c)
      LT -> Minus' acc (LitInt' (abs c))
      EQ -> acc

-- | `formatTopModuloExpr` converts `ModuloExpr` to `Expr`, adding `FloorMod` only at the root.
-- The expression is normalized first, and its inner terms are rendered without any modular reduction.
formatTopModuloExpr :: ModuloExpr -> Expr
formatTopModuloExpr e =
  let m = modulo e
      body = formatTopSumExpr m (unModuloExpr (normalizeModuloExpr e))
   in FloorMod' body (unModulo m)

-- | `formatBottomInteger` renders the integer constant @k@ reduced modulo @m@.
-- When the modulus is an integer literal the reduction is done immediately;
-- otherwise a `FloorMod'` node is emitted.
formatBottomInteger :: Modulo -> Integer -> Expr
formatBottomInteger (Modulo modulus) k = case modulus of
  LitInt' n -> LitInt' (k `mod` n)
  _ -> FloorMod' (LitInt' k) modulus

-- | `formatBottomProductExpr` renders a `ProductExpr` so that every node is a
-- modular operation.
--
-- Factors that are not already reduced modulo @m@ are wrapped with
-- `FloorMod'`, inverse factors become `ModInv'`, and factors are combined with
-- `ModMult'`. The coefficient is applied last: @0@ collapses the product, @1@
-- is dropped, @-1@ becomes `ModNegate'`, and any other value becomes a
-- `ModMult'` with the reduced coefficient.
formatBottomProductExpr :: Modulo -> ProductExpr -> Expr
formatBottomProductExpr m e = case factors of
  [] -> coeff
  f : fs -> withCoeff (foldl' mult f fs)
  where
    mm = unModulo m
    mult x y = ModMult' x y mm
    coeff = formatBottomInteger m (productExprConst e)
    -- avoid a redundant FloorMod' around factors that are already reduced
    reduced x
      | x `isModulo` m = x
      | otherwise = FloorMod' x mm
    factors = map reduced (productExprList e) ++ map (`ModInv'` mm) (productExprInvList e)
    withCoeff x = case productExprConst e of
      0 -> LitInt' 0
      1 -> x
      -1 -> ModNegate' x mm
      _ -> ModMult' x coeff mm

-- | `formatBottomSumExpr` renders a `SumExpr` with modular operations at every node.
--
-- Terms are chained with `ModPlus'`. A tail term whose coefficient is @-1@ is
-- subtracted with `ModMinus'`, and tail terms with a zero coefficient are
-- skipped. A non-zero constant is added last.
formatBottomSumExpr :: Modulo -> SumExpr -> Expr
formatBottomSumExpr m e = case sumExprList e of
  [] -> formatBottomInteger m c
  t : ts -> addConst (foldl' step (formatBottomProductExpr m t) ts)
  where
    mm = unModulo m
    c = sumExprConst e
    step acc t = case productExprConst t of
      0 -> acc
      -1 -> ModMinus' acc (formatBottomProductExpr m (t {productExprConst = 1})) mm
      _ -> ModPlus' acc (formatBottomProductExpr m t) mm
    addConst acc
      | c == 0 = acc
      | otherwise = ModPlus' acc (formatBottomInteger m c) mm

-- | `formatBottomModuloExpr` converts `ModuloExpr` to `Expr`, adding modular
-- operations (`FloorMod'`, `ModPlus'`, `ModMult'`, ...) at every node.
-- The expression is normalized before it is rendered.
formatBottomModuloExpr :: ModuloExpr -> Expr
formatBottomModuloExpr e =
  formatBottomSumExpr (modulo e) (unModuloExpr (normalizeModuloExpr e))

-- | `normalizeProductExpr` reduces the coefficient modulo @m@ when @m@ is an
-- integer literal, and sorts both factor lists so that equal products compare
-- equal. A product whose coefficient becomes zero loses all of its factors.
normalizeProductExpr :: Modulo -> ProductExpr -> ProductExpr
normalizeProductExpr m e =
  e
    { productExprList = unlessZero (productExprList e),
      productExprInvList = unlessZero (productExprInvList e),
      productExprConst = k
    }
  where
    k = case m of
      Modulo (LitInt' n) -> productExprConst e `mod` n
      _ -> productExprConst e
    -- a zero product has no factors worth keeping
    unlessZero xs = if k == 0 then [] else sort xs

-- | `normalizeSumExpr` puts a `SumExpr` into a canonical form modulo @m@.
--
-- * Every term is normalized with `normalizeProductExpr`.
-- * Like terms are merged by adding their coefficients. Two terms are like
--   terms when both `productExprList` and `productExprInvList` are equal.
-- * Terms whose coefficient becomes zero are dropped.
-- * Terms with neither factors nor inverse factors are folded into
--   `sumExprConst`.
--
-- When the modulus is an integer literal, every coefficient is reduced into
-- @[0, m)@, including merged coefficients and the constant. This keeps the
-- form canonical, so that e.g. @5x + 5x (mod 7)@ becomes @3x@ and
-- @3x + 4x (mod 7)@ disappears.
normalizeSumExpr :: Modulo -> SumExpr -> SumExpr
normalizeSumExpr m e =
  let -- reduce a coefficient modulo m when the modulus is a known literal
      reduce :: Integer -> Integer
      reduce k = case m of
        Modulo (LitInt' n) -> k `mod` n
        _ -> k
      -- the factors of a term; terms with equal keys are like terms
      key :: ProductExpr -> ([Expr], [Expr])
      key p = (productExprList p, productExprInvList p)
      -- a term is constant only when it has no factors AND no inverse factors;
      -- checking productExprList alone would misclassify e.g. 3 * inv(x)
      isConstant :: ProductExpr -> Bool
      isConstant p = null (productExprList p) && null (productExprInvList p)
      grouped :: [[ProductExpr]]
      grouped =
        groupBy (\p q -> key p == key q) . sortOn key $
          map (normalizeProductExpr m) (sumExprList e)
      -- groupBy never yields empty groups, so the pattern below always matches
      merged :: [ProductExpr]
      merged =
        [ p {productExprConst = reduce (sum (map productExprConst grp))}
          | grp@(p : _) <- grouped
        ]
      terms :: [ProductExpr]
      terms = filter (\p -> productExprConst p /= 0 && not (isConstant p)) merged
      constant :: Integer
      constant = sum [productExprConst p | p <- merged, isConstant p]
   in SumExpr
        { sumExprList = terms,
          sumExprConst = reduce (sumExprConst e + constant)
        }

-- | `normalizeModuloExpr` normalizes the underlying `SumExpr` with respect to
-- the expression's own modulus; the modulus itself is kept unchanged.
normalizeModuloExpr :: ModuloExpr -> ModuloExpr
normalizeModuloExpr e =
  let m = modulo e
   in ModuloExpr (normalizeSumExpr m (unModuloExpr e)) m

-- | `isZeroModuloExpr` checks whether the normal form of the expr is the
-- constant @0@ under the same modulus. This is a structural comparison of
-- normal forms, so it only detects zeros that normalization exposes.
isZeroModuloExpr :: ModuloExpr -> Bool
isZeroModuloExpr e = zero == normalizeModuloExpr e
  where
    zero = integerModuloExpr (modulo e) 0

-- | `isOneModuloExpr` checks whether the normal form of the expr is the
-- constant @1@ under the same modulus. This is a structural comparison of
-- normal forms, so it only detects ones that normalization exposes.
isOneModuloExpr :: ModuloExpr -> Bool
isOneModuloExpr e = one == normalizeModuloExpr e
  where
    one = integerModuloExpr (modulo e) 1

-- | `moduloOfModuloExpr` returns the modulus of a `ModuloExpr` as a plain `Expr`.
moduloOfModuloExpr :: ModuloExpr -> Expr
moduloOfModuloExpr e = unModulo (modulo e)

-- | `arithmeticExprFromModuloExpr` renders the sum part with
-- `formatTopSumExpr` and parses the result as an `ArithmeticExpr`. The sum is
-- not normalized first, and no outer `FloorMod'` is added.
arithmeticExprFromModuloExpr :: ModuloExpr -> ArithmeticExpr
arithmeticExprFromModuloExpr e =
  let rendered = formatTopSumExpr (modulo e) (unModuloExpr e)
   in parseArithmeticExpr rendered