{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}

-- |
-- Module      : Jikka.Core.Convert.MatrixExponentiation
-- Description : replaces repeated applications of linear (or, affine) functions with powers of matrices. / 線形な (あるいは affine な) 関数の繰り返しの適用を行列累乗で置き換えます。
-- Copyright   : (c) Kimiyuki Onaka, 2021
-- License     : Apache License 2.0
-- Maintainer  : kimiyuki95@gmail.com
-- Stability   : experimental
-- Portability : portable
module Jikka.Core.Convert.MatrixExponentiation
  ( run,
  )
where

import Control.Monad.Trans
import Control.Monad.Trans.Maybe
import Data.List
import qualified Data.Vector as V
import Jikka.Common.Alpha
import Jikka.Common.Error
import Jikka.Common.Matrix
import Jikka.Core.Language.ArithmeticExpr
import Jikka.Core.Language.BuiltinPatterns
import Jikka.Core.Language.Expr
import Jikka.Core.Language.FreeVars
import Jikka.Core.Language.Lint
import Jikka.Core.Language.RewriteRules
import Jikka.Core.Language.Util

-- | Try to read @e@ as an affine function @a * x + b@ of the single variable @x@.
--
-- The expression is parsed with 'parseArithmeticExpr' and split over the
-- one-element vector @[x]@ by 'makeVectorFromArithmeticExpr'; the result is
-- 'Nothing' whenever that split returns 'Nothing'.
--
-- On success the pair @(a', b')@ is returned, where
--
-- * @a'@ is 'Nothing' if the coefficient satisfies 'isOneArithmeticExpr', and
--   otherwise 'Just' the formatted coefficient;
-- * @b'@ is 'Nothing' if the constant term satisfies 'isZeroArithmeticExpr', and
--   otherwise 'Just' the formatted constant term.
toLinearExpression :: VarName -> Expr -> Maybe (Maybe Expr, Maybe Expr)
toLinearExpression x e = do
  (coeffs, constant) <- makeVectorFromArithmeticExpr (V.singleton x) (parseArithmeticExpr e)
  case V.toList coeffs of
    [coeff] ->
      let a' = if isOneArithmeticExpr coeff then Nothing else Just (formatArithmeticExpr coeff)
          b' = if isZeroArithmeticExpr constant then Nothing else Just (formatArithmeticExpr constant)
       in Just (a', b')
    -- Only one variable was passed in, so any other number of coefficients is
    -- treated as a broken invariant rather than as an ordinary failure.
    _ -> error $ "Jikka.Core.Convert.MatrixExponentiation.toLinearExpression: size mismtach: " ++ show (V.length coeffs)

-- | Convert a matrix of arithmetic expressions into a nested tuple expression.
--
-- With @(h, w) = matsize f@, the result is an @h@-component 'Tuple'' whose
-- components are themselves @w@-component 'Tuple''s of 'IntTy' entries, one
-- inner tuple per element of @unMatrix f@, each entry formatted with
-- 'formatArithmeticExpr'.
fromMatrix :: Matrix ArithmeticExpr -> Expr
fromMatrix f =
  let (h, w) = matsize f
      -- Component types shared by every inner tuple.
      rowTypes = replicate w IntTy
      fromRow row = uncurryApp (Tuple' rowTypes) (map formatArithmeticExpr (V.toList row))
   in uncurryApp (Tuple' (replicate h (TupleTy rowTypes))) (map fromRow (V.toList (unMatrix f)))

-- | Convert an affine map @v -> a v + b@ into a single augmented matrix expression.
--
-- With @(h, w) = matsize a@, the result is an @(h + 1)@-component tuple of
-- @(w + 1)@-component 'IntTy' tuples: each row of @a@ gets the matching element
-- of @b@ appended as its last entry, and one extra row @(0, ..., 0, 1)@ is
-- appended at the bottom.
--
-- Calls 'error' when the first component of @matsize a@ differs from the
-- length of @b@.
fromAffineMatrix :: Matrix ArithmeticExpr -> V.Vector ArithmeticExpr -> Expr
fromAffineMatrix a b
  | h /= V.length b =
    error $ "Jikka.Core.Convert.MatrixExponentiation.fromAffineMatrix: size mismtach: " ++ show (matsize a) ++ " and " ++ show (V.length b)
  | otherwise =
    uncurryApp (Tuple' (replicate (h + 1) (TupleTy rowTypes))) (V.toList (V.zipWith extendRow (unMatrix a) b) ++ [bottom])
  where
    (h, w) = matsize a
    -- Component types shared by every row of the augmented matrix.
    rowTypes = replicate (w + 1) IntTy
    -- A row of @a@ followed by its constant term from @b@.
    extendRow row c = uncurryApp (Tuple' rowTypes) (map formatArithmeticExpr (V.toList row ++ [c]))
    -- The extra last row: @w@ zeros followed by a one.
    bottom = uncurryApp (Tuple' rowTypes) (replicate w (LitInt' 0) ++ [LitInt' 1])

-- | Try to read @step@, an expression over the tuple variable @x@, as an affine
-- map on @n@ components.
--
-- @step@ must be an application of 'Tuple''; otherwise the result is 'Nothing'.
-- For a tuple, @n@ fresh variables are generated from @x@, every sub-expression
-- of the form @Proj' _ i (Var x)@ is replaced by the @i@-th fresh variable
-- (via 'mapSubExpr', to which @env@ is passed), and each component is then split
-- over the fresh variables with 'makeVectorFromArithmeticExpr'.
--
-- The result is 'Nothing' if any component still uses @x@ after the
-- replacement, if 'makeVectorFromArithmeticExpr' fails for any component, or if
-- 'makeMatrix' rejects the collected coefficient rows. Otherwise it is
-- @Just (a, b)@ where @a@ is the coefficient matrix and @b@ is 'Nothing' when
-- every constant term satisfies 'isZeroArithmeticExpr', and 'Just' the vector
-- of constant terms otherwise.
toMatrix :: MonadAlpha m => [(VarName, Type)] -> VarName -> Integer -> Expr -> m (Maybe (Matrix ArithmeticExpr, Maybe (V.Vector ArithmeticExpr)))
toMatrix env x n step =
  case curryApp step of
    (Tuple' _, es) -> runMaybeT $ do
      xs <- V.fromList <$> replicateM (fromInteger n) (lift (genVarName x))
      -- NOTE(review): 'V.!' is partial; this assumes every projection index on
      -- @x@ is below @n@ -- presumably guaranteed by the caller's type check, confirm.
      let unpackTuple _ e = case e of
            Proj' _ i (Var x') | x' == x -> Var (xs V.! fromInteger i)
            _ -> e
      rows <- MaybeT . return . forM es $ \e -> do
        let e' = mapSubExpr unpackTuple env e
        -- Any remaining use of @x@ means the component is not a function of
        -- the projections alone, so give up.
        guard $ x `isUnusedVar` e'
        makeVectorFromArithmeticExpr xs (parseArithmeticExpr e')
      a <- MaybeT . return $ makeMatrix (V.fromList (map fst rows))
      let b =
            if all (isZeroArithmeticExpr . snd) rows
              then Nothing
              else Just (V.fromList (map snd rows))
      return (a, b)
    _ -> return Nothing

-- | Build the @(n + 1)@-component tuple @(x[0], ..., x[n - 1], 1)@.
--
-- Each @x[i]@ is a 'Proj'' out of @Var x@ annotated with @n@ 'IntTy'
-- components, and the final component is the literal @1@.
addOneToVector :: Integer -> VarName -> Expr
addOneToVector n x =
  let ts = replicate (fromInteger n) IntTy
      projections = map (\i -> Proj' ts i (Var x)) [0 .. n - 1]
   in uncurryApp (Tuple' (IntTy : ts)) (projections ++ [LitInt' 1])

-- | Build the @n@-component tuple @(x[0], ..., x[n - 1])@ out of an
-- @(n + 1)@-component tuple variable @x@, dropping its last component.
--
-- Each @x[i]@ is a 'Proj'' out of @Var x@ annotated with @n + 1@ 'IntTy'
-- components.
removeOneFromVector :: Integer -> VarName -> Expr
removeOneFromVector n x =
  let ts = replicate (fromInteger n) IntTy
      projections = map (\i -> Proj' (IntTy : ts) i (Var x)) [0 .. n - 1]
   in uncurryApp (Tuple' ts) projections

-- | The rewrite rule replacing @iterate@ over an affine step with a closed form.
--
-- Two shapes are handled; everything else is left untouched ('Nothing'):
--
-- * @Iterate' IntTy k (Lam x _ step) base@: when 'toLinearExpression' reads
--   @step@ as @a * x + b@, the iteration becomes @base@, @base + k * b@,
--   @a ** k * base@, or @a ** k * base + (a ** k - 1) / (a - 1) * b@, depending
--   on which of @a@ and @b@ are present.
-- * @Iterate' (TupleTy ts) k (Lam x _ step) base@ with @isVectorTy' ts@: when
--   'toMatrix' reads @step@ as a matrix (plus optional constant vector), the
--   iteration becomes a 'MatAp'' of a 'MatPow''. With a constant vector the
--   matrix is augmented by 'fromAffineMatrix', @base@ is extended by
--   'addOneToVector', and the extra component is dropped again by
--   'removeOneFromVector'.
rule :: MonadAlpha m => RewriteRule m
rule = makeRewriteRule "Jikka.Core.Convert.MatrixExponentiation" $ \env -> \case
  Iterate' IntTy k (Lam x _ step) base ->
    return $ case toLinearExpression x step of
      Nothing -> Nothing
      Just (Nothing, Nothing) -> Just base
      Just (Nothing, Just b) -> Just $ Plus' base (Mult' k b)
      Just (Just a, Nothing) -> Just $ Mult' (Pow' a k) base
      Just (Just a, Just b) ->
        -- NOTE(review): this closed form divides by @a - 1@. Only coefficients
        -- recognised by 'isOneArithmeticExpr' are routed away from this branch;
        -- confirm that a coefficient evaluating to 1 at run time cannot reach it.
        let a' = Pow' a k
            b' = Mult' (FloorDiv' (Minus' (Pow' a k) (LitInt' 1)) (Minus' a (LitInt' 1))) b -- This division has no remainder.
         in Just $ Plus' (Mult' a' base) b'
  Iterate' (TupleTy ts) k (Lam x _ step) base | isVectorTy' ts -> do
    let n = genericLength ts
    -- @matPowAp size mat vec@ applies the @k@-th power of @mat@ to @vec@.
    let matPowAp size mat vec = MatAp' size size (MatPow' size mat k) vec
    affine <- toMatrix (typeEnv env) x n step
    case affine of
      Nothing -> return Nothing
      Just (a, Nothing) -> return . Just $ matPowAp n (fromMatrix a) base
      Just (a, Just b) -> do
        y <- genVarName x
        z <- genVarName x
        return . Just $
          Let y (TupleTy ts) base $
            Let z (TupleTy (IntTy : ts)) (matPowAp (n + 1) (fromAffineMatrix a b) (addOneToVector n y)) $
              removeOneFromVector n z
  _ -> return Nothing

-- | Apply 'rule' to a whole program via 'applyRewriteRuleProgram''.
runProgram :: (MonadAlpha m, MonadError Error m) => Program -> m Program
runProgram = applyRewriteRuleProgram' rule

-- | `run` simplifies affine functions from vectors to vectors in @iterate@ (`Iterate`) functions.
--
-- The program is linted before and after the rewrite, and any error is wrapped
-- with this module's name.
--
-- == Examples
--
-- This makes matrix multiplication. Before:
--
-- > iterate n (fun xs -> (xs[0] + 2 * xs[1], xs[1])) xs
--
-- After:
--
-- > matap (matpow ((1, 2), (0, 1)) n) xs
--
-- Also this works on integers. Before:
--
-- > iterate n (fun x -> (2 * x + 1)) x
--
-- After:
--
-- > (2 ** n) * x + (2 ** n - 1) / (2 - 1) * 1
run :: (MonadAlpha m, MonadError Error m) => Program -> m Program
run prog = wrapError' "Jikka.Core.Convert.MatrixExponentiation" $ do
  precondition $ do
    lint prog
  prog' <- runProgram prog
  postcondition $ do
    lint prog'
  return prog'