{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}

-- |
-- Module      : Jikka.Core.Convert.PropagateMod
-- Description : propagates modulo operations, and replaces integer functions with corresponding functions with modulo. / 剰余演算を伝播させ、整数の関数を対応する modulo 付きの関数で置き換えます。
-- Copyright   : (c) Kimiyuki Onaka, 2021
-- License     : Apache License 2.0
-- Maintainer  : kimiyuki95@gmail.com
-- Stability   : experimental
-- Portability : portable
module Jikka.Core.Convert.PropagateMod
  ( run,
  )
where

import Data.Maybe
import Jikka.Common.Alpha
import Jikka.Common.Error
import Jikka.Core.Format (formatType)
import Jikka.Core.Language.Beta
import Jikka.Core.Language.BuiltinPatterns
import Jikka.Core.Language.Expr
import Jikka.Core.Language.Lint
import Jikka.Core.Language.RewriteRules
import Jikka.Core.Language.TypeCheck
import Jikka.Core.Language.Util

-- | `Mod` wraps the modulus expression of a modulo operation. Giving it a type
-- distinct from the expression being reduced prevents accidentally swapping
-- the left and right operands of a mod-op.
newtype Mod = Mod Expr

-- | @e \`isModulo'\` Mod m@ conservatively decides whether @e@ is already
-- reduced modulo @m@.
--
-- It holds when @e@ is a modular builtin whose modulus argument is
-- syntactically equal to @m@, an integer literal in @[0, m)@ (for a literal
-- @m@), or a vector / matrix projection, map, lambda, beta-redex or
-- vector / matrix tuple literal whose relevant parts satisfy it.
-- 'False' means "not known to be reduced", not "definitely unreduced".
isModulo' :: Expr -> Mod -> Bool
isModulo' expr (Mod m) = case modulusOf expr of
  Just m' -> m' == m
  Nothing -> structurally expr
  where
    recur x = x `isModulo'` Mod m

    -- The modulus argument of a builtin whose result is always reduced.
    modulusOf = \case
      FloorMod' _ m' -> Just m'
      ModNegate' _ m' -> Just m'
      ModPlus' _ _ m' -> Just m'
      ModMinus' _ _ m' -> Just m'
      ModMult' _ _ m' -> Just m'
      ModInv' _ m' -> Just m'
      ModPow' _ _ m' -> Just m'
      VecFloorMod' _ _ m' -> Just m'
      MatFloorMod' _ _ _ m' -> Just m'
      ModMatAp' _ _ _ _ m' -> Just m'
      ModMatAdd' _ _ _ _ m' -> Just m'
      ModMatMul' _ _ _ _ _ m' -> Just m'
      ModMatPow' _ _ _ m' -> Just m'
      ModSum' _ m' -> Just m'
      ModProduct' _ m' -> Just m'
      _ -> Nothing

    -- Shapes that are reduced whenever their relevant sub-expressions are.
    -- A failed guard falls through to the later alternatives.
    structurally = \case
      LitInt' n
        | LitInt' k <- m -> 0 <= n && n < k
      Proj' ts _ x
        | isVectorTy' ts || isMatrixTy' ts -> recur x
      Map' _ _ f _ -> recur f
      Lam _ _ body -> recur body
      app@(App _ _) -> case curryApp app of
        (f@(Lam _ _ _), _) -> recur f
        (Tuple' ts, es)
          | isVectorTy' ts || isMatrixTy' ts -> all recur es
        _ -> False
      _ -> False

-- | Same as 'isModulo'', but takes the modulus as a bare 'Expr'.
isModulo :: Expr -> Expr -> Bool
isModulo e = isModulo' e . Mod

-- | Tries to push the modulus @m@ into @e@. The goal is an expression equal to
-- @e@ reduced modulo @m@, without wrapping @e@ in an explicit @FloorMod@.
--
-- * Integer and matrix arithmetic is replaced with its modular counterpart.
-- * Integer literals are folded when @m@ is a non-zero literal.
-- * Vector and matrix projections get a vector / matrix reduction on their
--   argument.
-- * The modulus is propagated into lambda bodies. This covers the function of
--   a 'Map'' and the head of a beta-redex.
-- * The modulus is propagated into the components of vector / matrix tuple
--   literals.
--
-- Returns 'Nothing' when the modulus cannot be pushed in; callers then fall
-- back to an explicit reduction. A 'Just' result is reduced modulo @m@ in
-- every component. Tuple components that cannot absorb the modulus are
-- wrapped in an explicit reduction rather than being left unreduced.
putFloorMod :: MonadAlpha m => Mod -> Expr -> m (Maybe Expr)
putFloorMod (Mod m) = \case
  Negate' e -> found $ ModNegate' e m
  Plus' e1 e2 -> found $ ModPlus' e1 e2 m
  Minus' e1 e2 -> found $ ModMinus' e1 e2 m
  Mult' e1 e2 -> found $ ModMult' e1 e2 m
  Pow' e1 e2 -> found $ ModPow' e1 e2 m
  MatAp' h w e1 e2 -> found $ ModMatAp' h w e1 e2 m
  MatAdd' h w e1 e2 -> found $ ModMatAdd' h w e1 e2 m
  MatMul' h n w e1 e2 -> found $ ModMatMul' h n w e1 e2 m
  MatPow' n e1 e2 -> found $ ModMatPow' n e1 e2 m
  Sum' e -> found $ ModSum' e m
  Product' e -> found $ ModProduct' e m
  LitInt' n -> case m of
    -- Folding by a zero modulus would raise a Haskell divide-by-zero
    -- exception inside the compiler, so such a reduction is left alone.
    LitInt' k | k /= 0 -> found $ LitInt' (n `mod` k)
    _ -> return Nothing
  Proj' ts i e
    | isVectorTy' ts -> found $ Proj' ts i (VecFloorMod' (length ts) e m)
    | isMatrixTy' ts,
      Just (h, w) <- sizeOfMatrixTy (TupleTy ts) ->
      found $ Proj' ts i (MatFloorMod' h w e m)
  Map' t1 t2 f xs -> fmap (\f' -> Map' t1 t2 f' xs) <$> putFloorMod (Mod m) f
  Lam x t body -> do
    -- The bound variable is renamed to a fresh name before @m@ is inserted
    -- into the body, presumably so that @x@ cannot capture a free variable
    -- of @m@.
    -- TODO: rename only if required
    y <- genVarName x
    body' <- substitute x (Var y) body
    fmap (Lam y t) <$> putFloorMod (Mod m) body'
  e@(App _ _) -> case curryApp e of
    (f@(Lam _ _ _), args) -> fmap (`uncurryApp` args) <$> putFloorMod (Mod m) f
    (Tuple' ts, es)
      | isVectorTy' ts -> putFloorModTuple ts (\x -> FloorMod' x m) es
      | isMatrixTy (TupleTy ts),
        Just (_, w) <- sizeOfMatrixTy (TupleTy ts) ->
        putFloorModTuple ts (\row -> VecFloorMod' w row m) es
    _ -> return Nothing
  _ -> return Nothing
  where
    found e = return (Just e)

    -- Rebuilds the tuple literal @Tuple' ts es@ with every component reduced
    -- modulo @m@. A component that cannot absorb the modulus and is not
    -- already reduced is wrapped with @wrap@. Otherwise the caller would drop
    -- its outer reduction and leave that component unreduced. Returns
    -- 'Nothing' when no component can absorb the modulus, so the caller
    -- reduces the whole tuple instead.
    putFloorModTuple ts wrap es = do
      es' <- mapM (putFloorMod (Mod m)) es
      let reduce e = fromMaybe (if e `isModulo'` Mod m then e else wrap e)
      return $
        if all isNothing es'
          then Nothing
          else Just (uncurryApp (Tuple' ts) (zipWith reduce es es'))

-- | Makes sure @e@ is reduced modulo @m@.
--
-- * If @e@ is already known to be reduced, it is returned unchanged.
-- * Otherwise the modulus is pushed into @e@ with 'putFloorMod'.
-- * If that is impossible, @fallback e m@ builds an explicit reduction
--   around @e@.
putFloorModGeneric :: MonadAlpha m => (Expr -> Mod -> m Expr) -> Mod -> Expr -> m Expr
putFloorModGeneric fallback m e
  | e `isModulo'` m = return e
  | otherwise = putFloorMod m e >>= maybe (fallback e m) return

-- | 'putFloorModGeneric' for an integer expression. If the modulus cannot be
-- pushed in, the expression is wrapped in 'FloorMod''.
putFloorModInt :: MonadAlpha m => Mod -> Expr -> m Expr
putFloorModInt = putFloorModGeneric wrapFloorMod
  where
    wrapFloorMod e (Mod m) = return (FloorMod' e m)

-- | 'putFloorModGeneric' for a list of integers. If the modulus cannot be
-- pushed in, @fun x -> x mod m@ is mapped over the list. @x@ is a freshly
-- generated variable name.
putMapFloorMod :: MonadAlpha m => Mod -> Expr -> m Expr
putMapFloorMod = putFloorModGeneric mapFloorMod
  where
    mapFloorMod xs (Mod m) = do
      x <- genVarName'
      let reduceElem = Lam x IntTy (FloorMod' (Var x) m)
      return $ Map' IntTy IntTy reduceElem xs

-- | 'putFloorModGeneric' for a vector (tuple) expression. If the modulus
-- cannot be pushed in, the expression is type-checked under @env@ to learn
-- its length and wrapped in 'VecFloorMod''. An internal error is thrown when
-- the type is not a tuple.
putVecFloorMod :: (MonadError Error m, MonadAlpha m) => [(VarName, Type)] -> Mod -> Expr -> m Expr
putVecFloorMod env = putFloorModGeneric wrapVecFloorMod
  where
    wrapVecFloorMod e (Mod m) =
      typecheckExpr env e >>= \case
        TupleTy ts -> return $ VecFloorMod' (length ts) e m
        t -> throwInternalError $ "not a vector: " ++ formatType t

-- | 'putFloorModGeneric' for a matrix (tuple of tuples) expression.
--
-- If the modulus cannot be pushed in, the expression is type-checked under
-- @env@. Its height is taken from the outer tuple and its width from the
-- first row, and it is then wrapped in 'MatFloorMod''. An internal error is
-- thrown when the type is not a non-empty tuple whose first component is a
-- tuple.
putMatFloorMod :: (MonadError Error m, MonadAlpha m) => [(VarName, Type)] -> Mod -> Expr -> m Expr
putMatFloorMod env = putFloorModGeneric wrapMatFloorMod
  where
    wrapMatFloorMod e (Mod m) =
      typecheckExpr env e >>= \case
        TupleTy rows@(TupleTy cols : _) -> return $ MatFloorMod' (length rows) (length cols) e m
        t -> throwInternalError $ "not a matrix: " ++ formatType t

-- | Rewrite rule that propagates modulo operations.
--
-- * A modular builtin whose operands are not yet known to be reduced modulo
--   its modulus gets those operands reduced:
--   * integers via 'putFloorModInt';
--   * lists via 'putMapFloorMod';
--   * vectors and matrices via 'putVecFloorMod' and 'putMatFloorMod'.
--
--   The exponents of 'ModPow'' and 'ModMatPow'' are left untouched.
-- * An explicit 'FloorMod'', 'VecFloorMod'' or 'MatFloorMod'' is dropped when
--   its argument is already reduced. Otherwise the modulus is pushed into the
--   argument with 'putFloorMod'.
rule :: (MonadAlpha m, MonadError Error m) => RewriteRule m
rule =
  RewriteRule $ \env -> \case
    ModNegate' e m
      | unreduced e m -> unary m ModNegate' (putFloorModInt, e)
    ModPlus' e1 e2 m
      | unreduced e1 m || unreduced e2 m -> binary m ModPlus' (putFloorModInt, e1) (putFloorModInt, e2)
    ModMinus' e1 e2 m
      | unreduced e1 m || unreduced e2 m -> binary m ModMinus' (putFloorModInt, e1) (putFloorModInt, e2)
    ModMult' e1 e2 m
      | unreduced e1 m || unreduced e2 m -> binary m ModMult' (putFloorModInt, e1) (putFloorModInt, e2)
    ModInv' e m
      | unreduced e m -> unary m ModInv' (putFloorModInt, e)
    ModPow' e1 e2 m
      | unreduced e1 m -> binary m ModPow' (putFloorModInt, e1) (keep, e2)
    ModMatAp' h w e1 e2 m
      | unreduced e1 m || unreduced e2 m -> binary m (ModMatAp' h w) (putMatFloorMod env, e1) (putVecFloorMod env, e2)
    ModMatAdd' h w e1 e2 m
      | unreduced e1 m || unreduced e2 m -> binary m (ModMatAdd' h w) (putMatFloorMod env, e1) (putMatFloorMod env, e2)
    ModMatMul' h n w e1 e2 m
      | unreduced e1 m || unreduced e2 m -> binary m (ModMatMul' h n w) (putMatFloorMod env, e1) (putMatFloorMod env, e2)
    ModMatPow' n e1 e2 m
      | unreduced e1 m -> binary m (ModMatPow' n) (putMatFloorMod env, e1) (keep, e2)
    ModSum' e m
      | unreduced e m -> unary m ModSum' (putMapFloorMod, e)
    ModProduct' e m
      | unreduced e m -> unary m ModProduct' (putMapFloorMod, e)
    FloorMod' e m -> absorb e m
    VecFloorMod' _ e m -> absorb e m
    MatFloorMod' _ _ e m -> absorb e m
    _ -> return Nothing
  where
    -- True when @e@ is not yet known to be reduced modulo @m@.
    unreduced e m = not (e `isModulo` m)

    -- Leaves an operand (an exponent) untouched.
    keep _ e = return e

    -- Drops an explicit reduction whose argument is already reduced, or
    -- pushes the modulus into the argument.
    absorb e m
      | e `isModulo` m = return (Just e)
      | otherwise = putFloorMod (Mod m) e

    -- Rebuilds a unary modular builtin after preparing its operand.
    unary m op (prepare, e) = Just <$> (op <$> prepare (Mod m) e <*> pure m)

    -- Rebuilds a binary modular builtin after preparing both operands,
    -- left operand first.
    binary m op (prepare1, e1) (prepare2, e2) =
      Just <$> (op <$> prepare1 (Mod m) e1 <*> prepare2 (Mod m) e2 <*> pure m)

-- | Applies 'rule' throughout the program with 'applyRewriteRuleProgram''.
runProgram :: (MonadAlpha m, MonadError Error m) => Program -> m Program
runProgram program = applyRewriteRuleProgram' rule program

-- | `run` propagates `FloorMod` towards the leaves of expressions. Along the
-- way it replaces integer operations with their modular counterparts.
-- For example, this converts the following:
--
-- > mod ((fun x -> x * x + x) y) 1000000007
--
-- to:
--
-- > (fun x -> mod (mod (x * x) 1000000007 + x) 1000000007) y
--
-- Well-typedness of the program is asserted both as a precondition and as a
-- postcondition.
run :: (MonadAlpha m, MonadError Error m) => Program -> m Program
run prog = wrapError' "Jikka.Core.Convert.PropagateMod" $ do
  precondition $ ensureWellTyped prog
  prog' <- runProgram prog
  postcondition $ ensureWellTyped prog'
  return prog'