{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}

-- |
-- Module      : Jikka.Core.Convert.PropagateMod
-- Description : propagates modulo operations, and replaces integer functions with corresponding functions with modulo. / 剰余演算を伝播させ、整数の関数を対応する modulo 付きの関数で置き換えます。
-- Copyright   : (c) Kimiyuki Onaka, 2021
-- License     : Apache License 2.0
-- Maintainer  : kimiyuki95@gmail.com
-- Stability   : experimental
-- Portability : portable
module Jikka.Core.Convert.PropagateMod
  ( run,
  )
where

import Control.Monad.Trans.Maybe
import Data.List
import Data.Maybe
import Jikka.Common.Alpha
import Jikka.Common.Error
import Jikka.Core.Format (formatType)
import Jikka.Core.Language.Beta
import Jikka.Core.Language.BuiltinPatterns
import Jikka.Core.Language.Expr
import Jikka.Core.Language.Lint
import Jikka.Core.Language.ModuloExpr
import Jikka.Core.Language.RewriteRules
import Jikka.Core.Language.TypeCheck
import Jikka.Core.Language.Util

-- | Same as 'isModulo', but the modulus is given as a bare expression instead
-- of being wrapped in 'Modulo'.
isModulo' :: Expr -> Expr -> Bool
isModulo' e = isModulo e . Modulo

-- | @floorModOfType t e m@ wraps @e :: t@ in the floor-mod builtin that matches
-- its type: @VecFloorMod'@ for an integer vector, @MatFloorMod'@ for an integer
-- matrix, and plain @FloorMod'@ otherwise (integers).
floorModOfType :: Type -> Expr -> Expr -> Expr
floorModOfType t e m = case t of
  TupleTy ts
    | isVectorTy' ts -> VecFloorMod' (genericLength ts) e m
    | Just (h, w) <- sizeOfMatrixTy t -> MatFloorMod' (toInteger h) (toInteger w) e m
  _ -> FloorMod' e m

-- | @putFloorMod m e@ tries to build an expression equal to @e@ reduced modulo
-- @m@ (element-wise when @e@ is a vector or a matrix) by pushing the modulo
-- into @e@, e.g. turning @Plus' a b@ into @ModPlus' a b m@.
--
-- Returns 'Nothing' when it does not know how to push the modulo into @e@;
-- the caller is then responsible for keeping an explicit floor-mod.
putFloorMod :: MonadAlpha m => Modulo -> Expr -> m (Maybe Expr)
putFloorMod (Modulo m) =
  runMaybeT . \case
    Negate' e -> return $ ModNegate' e m
    Plus' e1 e2 -> return $ ModPlus' e1 e2 m
    Minus' e1 e2 -> return $ ModMinus' e1 e2 m
    Mult' e1 e2 -> return $ ModMult' e1 e2 m
    -- An exact division becomes a multiplication by the modular inverse.
    JustDiv' e1 e2 -> return $ ModMult' e1 (ModInv' e2 m) m
    Pow' e1 e2 -> return $ ModPow' e1 e2 m
    MatAp' h w e1 e2 -> return $ ModMatAp' h w e1 e2 m
    MatAdd' h w e1 e2 -> return $ ModMatAdd' h w e1 e2 m
    MatMul' h n w e1 e2 -> return $ ModMatMul' h n w e1 e2 m
    MatPow' n e1 e2 -> return $ ModMatPow' n e1 e2 m
    Sum' e -> return $ ModSum' e m
    Product' e -> return $ ModProduct' e m
    LitInt' n -> case m of
      -- Fold constants. A zero modulus is left alone instead of crashing with
      -- a division by zero.
      LitInt' m' | m' /= 0 -> return $ LitInt' (n `mod` m')
      _ -> giveUp
    -- A component projected out of a reduced tuple is reduced.
    Proj' ts i e | isVectorTy' ts -> return $ Proj' ts i (VecFloorMod' (genericLength ts) e m)
    Proj' ts i e
      | isMatrixTy' ts,
        Just (h, w) <- sizeOfMatrixTy (TupleTy ts) ->
        return $ Proj' ts i (MatFloorMod' (toInteger h) (toInteger w) e m)
    -- Reducing the result of the function reduces every element of the list.
    Map' t1 t2 f xs -> do
      f' <- MaybeT $ putFloorMod (Modulo m) f
      return $ Map' t1 t2 f' xs
    Foldl' t1 t2 f acc0 xs -> do
      f' <- MaybeT $ putFloorMod (Modulo m) f
      -- For an empty @xs@ the fold returns the initial accumulator unchanged,
      -- so it has to be reduced too (its type is @t2@).
      return $ Foldl' t1 t2 f' (floorModOfType t2 acc0 m) xs
    Lam x t body -> do
      -- TODO: rename only if required
      y <- lift $ genVarName x
      body' <- lift $ substitute x (Var y) body
      body'' <- MaybeT $ putFloorMod (Modulo m) body'
      return $ Lam y t body''
    e@(App _ _) -> case curryApp e of
      (f@(Lam _ _ _), args) -> do
        f' <- MaybeT $ putFloorMod (Modulo m) f
        return $ uncurryApp f' args
      (Tuple' ts, es) | isVectorTy' ts -> putIntoTuple ts es
      (Tuple' ts, es) | isMatrixTy (TupleTy ts) -> putIntoTuple ts es
      _ -> giveUp
    _ -> giveUp
  where
    giveUp = MaybeT $ return Nothing

    -- Push the modulo into every component of a tuple whose component types
    -- are @ts@. Components that cannot absorb it keep an explicit floor-mod of
    -- their own type, so the whole tuple is reduced. Fails when no component
    -- could absorb the modulo.
    putIntoTuple ts es = do
      es' <- lift $ mapM (putFloorMod (Modulo m)) es
      if all isNothing es'
        then giveUp
        else
          return $
            uncurryApp
              (Tuple' ts)
              (zipWith3 (\t e e' -> fromMaybe (floorModOfType t e m) e') ts es es')

-- | @putFloorModGeneric fallback m e@ returns @e@ reduced modulo @m@:
-- @e@ itself when it is already reduced, the result of 'putFloorMod' when the
-- modulo can be pushed into @e@, and @fallback e m@ otherwise.
putFloorModGeneric :: MonadAlpha m => (Expr -> Modulo -> m Expr) -> Modulo -> Expr -> m Expr
putFloorModGeneric fallback modulo e
  | e `isModulo` modulo = return e
  | otherwise = putFloorMod modulo e >>= maybe (fallback e modulo) return

-- | Reduces every element of an integer list expression modulo @m@.
-- When the modulo cannot be pushed inside, the list is wrapped in
-- @map (fun x -> floormod x m)@ with a freshly generated variable @x@.
putMapFloorMod :: MonadAlpha m => Modulo -> Expr -> m Expr
putMapFloorMod = putFloorModGeneric wrapWithMap
  where
    wrapWithMap xs (Modulo m) = do
      x <- genVarName'
      let reduceOne = Lam x IntTy (FloorMod' (Var x) m)
      return $ Map' IntTy IntTy reduceOne xs

-- | Reduces a vector expression modulo @m@. When the modulo cannot be pushed
-- inside, the expression is type-checked in @env@ and wrapped in
-- @VecFloorMod'@ sized by its tuple type; a non-tuple type is an internal error.
putVecFloorMod :: (MonadError Error m, MonadAlpha m) => [(VarName, Type)] -> Modulo -> Expr -> m Expr
putVecFloorMod env = putFloorModGeneric reduceVector
  where
    reduceVector e (Modulo m) =
      typecheckExpr env e >>= \case
        TupleTy ts -> return $ VecFloorMod' (genericLength ts) e m
        t -> throwInternalError $ "not a vector: " ++ formatType t

-- | Reduces a matrix expression modulo @m@. When the modulo cannot be pushed
-- inside, the expression is type-checked in @env@ and wrapped in
-- @MatFloorMod'@, whose height is the number of rows and whose width is taken
-- from the first row; any other type is an internal error.
putMatFloorMod :: (MonadError Error m, MonadAlpha m) => [(VarName, Type)] -> Modulo -> Expr -> m Expr
putMatFloorMod env = putFloorModGeneric reduceMatrix
  where
    reduceMatrix e (Modulo m) =
      typecheckExpr env e >>= \case
        TupleTy rows@(TupleTy cols : _) ->
          return $ MatFloorMod' (genericLength rows) (genericLength cols) e m
        t -> throwInternalError $ "not a matrix: " ++ formatType t

-- | The rewrite rule applied by this pass:
--
-- * a modulo arithmetic builtin (@ModNegate'@, @ModPlus'@, @ModMinus'@,
--   @ModMult'@, @ModInv'@, @ModPow'@) is re-formatted through
--   'parseModuloExpr' / 'formatBottomModuloExpr' when that changes it;
-- * the operands of the modulo matrix builtins and of @ModSum'@ /
--   @ModProduct'@ are reduced modulo the same modulus when they are not
--   already;
-- * @FloorMod'@, @VecFloorMod'@ and @MatFloorMod'@ are dropped when their
--   operand is already reduced, and otherwise pushed into it by 'putFloorMod'.
rule :: (MonadAlpha m, MonadError Error m) => RewriteRule m
rule =
  let -- Re-format a modulo expression; succeeds only if that changes it.
      normalize :: Expr -> Maybe Expr
      normalize e = do
        e' <- formatBottomModuloExpr <$> parseModuloExpr e
        guard (e' /= e)
        return e'

      -- Rebuild a one-operand builtin after reducing its operand.
      rebuild1 :: Functor f => Expr -> (a -> Expr -> b) -> (Modulo -> s -> f a, s) -> f (Maybe b)
      rebuild1 m build (reduce, e) = (\e' -> Just (build e' m)) <$> reduce (Modulo m) e

      -- Rebuild a two-operand builtin after reducing both operands (left first).
      rebuild2 ::
        Applicative f =>
        Expr ->
        (a -> b -> Expr -> c) ->
        (Modulo -> s -> f a, s) ->
        (Modulo -> t -> f b, t) ->
        f (Maybe c)
      rebuild2 m build (reduce1, e1) (reduce2, e2) =
        (\e1' e2' -> Just (build e1' e2' m)) <$> reduce1 (Modulo m) e1 <*> reduce2 (Modulo m) e2

      -- Remove a floor-mod around an already-reduced operand, or push it inside.
      dropOrPush :: MonadAlpha n => Expr -> Expr -> n (Maybe Expr)
      dropOrPush e m
        | e `isModulo'` m = return (Just e)
        | otherwise = putFloorMod (Modulo m) e
   in makeRewriteRule "Jikka.Core.Convert.PropagateMod" $ \env ->
        let putMat = putMatFloorMod (typeEnv env)
            putVec = putVecFloorMod (typeEnv env)
         in \case
              e@(ModNegate' _ _) -> return (normalize e)
              e@(ModPlus' _ _ _) -> return (normalize e)
              e@(ModMinus' _ _ _) -> return (normalize e)
              e@(ModMult' _ _ _) -> return (normalize e)
              e@(ModInv' _ _) -> return (normalize e)
              e@(ModPow' _ _ _) -> return (normalize e)
              ModMatAp' h w e1 e2 m
                | not (e1 `isModulo'` m && e2 `isModulo'` m) ->
                  rebuild2 m (ModMatAp' h w) (putMat, e1) (putVec, e2)
              ModMatAdd' h w e1 e2 m
                | not (e1 `isModulo'` m && e2 `isModulo'` m) ->
                  rebuild2 m (ModMatAdd' h w) (putMat, e1) (putMat, e2)
              ModMatMul' h n w e1 e2 m
                | not (e1 `isModulo'` m && e2 `isModulo'` m) ->
                  rebuild2 m (ModMatMul' h n w) (putMat, e1) (putMat, e2)
              ModMatPow' n e1 e2 m
                | not (e1 `isModulo'` m) ->
                  -- The exponent is left untouched.
                  rebuild2 m (ModMatPow' n) (putMat, e1) (const return, e2)
              ModSum' e m
                | not (e `isModulo'` m) -> rebuild1 m ModSum' (putMapFloorMod, e)
              ModProduct' e m
                | not (e `isModulo'` m) -> rebuild1 m ModProduct' (putMapFloorMod, e)
              FloorMod' e m -> dropOrPush e m
              VecFloorMod' _ e m -> dropOrPush e m
              MatFloorMod' _ _ e m -> dropOrPush e m
              _ -> return Nothing

-- | Applies 'rule' to the whole program via 'applyRewriteRuleProgram''.
runProgram :: (MonadAlpha m, MonadError Error m) => Program -> m Program
runProgram prog = applyRewriteRuleProgram' rule prog

-- | `run` propagates floor-mod operations towards the leaves of expressions,
-- replacing integer builtins with their modulo counterparts
-- (e.g. @Plus'@ with @ModPlus'@, @Mult'@ with @ModMult'@).
-- Conceptually, it converts the following:
--
-- > mod ((fun x -> x * x + x) y) 1000000007
--
-- to:
--
-- > (fun x -> mod (mod (x * x) 1000000007 + x) 1000000007) y
--
-- The program is linted both before and after the rewrite.
run :: (MonadAlpha m, MonadError Error m) => Program -> m Program
run prog = wrapError' "Jikka.Core.Convert.PropagateMod" $ do
  precondition (lint prog)
  prog' <- runProgram prog
  postcondition (lint prog')
  return prog'