-- | A cache-oblivious sequential transposition for CPU execution.
-- Generates a recursive function.
module Futhark.CodeGen.ImpGen.Transpose
  ( mapTransposeFunction,
    transposeArgs,
  )
where

import Futhark.CodeGen.ImpCode
import Futhark.IR.Prop.Types
import Futhark.Util.IntegralExp
import Prelude hiding (quot)

-- | Build the concrete argument list for a call to the generated
-- transpose function from well-typed inputs.
--
-- The destination and source offsets are supplied in bytes and are
-- turned into element offsets by dividing by the byte size of @pt@.
-- After the array count and the two dimensions, four bounds are
-- appended: @0@, @m@, @0@ and @n@.
transposeArgs ::
  PrimType ->
  VName ->
  Count Bytes (TExp Int64) ->
  VName ->
  Count Bytes (TExp Int64) ->
  TExp Int64 ->
  TExp Int64 ->
  TExp Int64 ->
  [Arg]
transposeArgs pt destmem destoffset srcmem srcoffset num_arrays m n =
  [ MemArg destmem,
    int64Arg (elemOffset destoffset),
    MemArg srcmem,
    int64Arg (elemOffset srcoffset),
    int64Arg num_arrays,
    int64Arg m,
    int64Arg n,
    -- Initial bounds: 0 up to m, then 0 up to n.
    int64Arg 0,
    int64Arg m,
    int64Arg 0,
    int64Arg n
  ]
  where
    -- Wrap a typed 64-bit expression as an untyped scalar argument.
    int64Arg :: TExp Int64 -> Arg
    int64Arg = ExpArg . untyped

    -- Convert a byte offset into an offset counted in elements of @pt@.
    elemOffset :: Count Bytes (TExp Int64) -> TExp Int64
    elemOffset bytes = unCount bytes `quot` primByteSize pt

-- | Generate the recursive transpose function for elements of type
-- @pt@.  We need to know the name of the function we are generating
-- (@fname@), as this function is recursive.
--
-- The generated function takes, in order: the destination memory and
-- its element offset, the source memory and its element offset, the
-- number of arrays, the dimensions @m@ and @n@, and the four bounds
-- @cb@, @ce@, @rb@ and @re@.  When the number of arrays is 1 it
-- transposes the sub-block delimited by the bounds; otherwise it calls
-- itself once per array with an array count of 1.
mapTransposeFunction :: Name -> PrimType -> Function op
mapTransposeFunction fname pt =
  Function
    Nothing
    []
    params
    ( mconcat
        [ -- Extents of the current sub-block.
          dec r $ le64 re - le64 rb,
          dec c $ le64 ce - le64 cb,
          If (le64 num_arrays .==. 1) doTranspose doMapTranspose
        ]
    )
  where
    -- Formal parameters of the generated function, in call order.
    params =
      [ memparam destmem,
        intparam destoffset,
        memparam srcmem,
        intparam srcoffset,
        intparam num_arrays,
        intparam m,
        intparam n,
        intparam cb,
        intparam ce,
        intparam rb,
        intparam re
      ]

    memparam v = MemParam v DefaultSpace
    intparam v = ScalarParam v int64

    -- Names used in the generated code.  Each name is paired with a
    -- distinct integer tag by position in the list.
    [ destmem,
      destoffset,
      srcmem,
      srcoffset,
      num_arrays,
      n,
      m,
      rb,
      re,
      cb,
      ce,
      r,
      c,
      i,
      j,
      val
      ] =
        zipWith
          (VName . nameFromString)
          [ "destmem",
            "destoffset",
            "srcmem",
            "srcoffset",
            "num_arrays",
            "n",
            "m",
            "rb",
            "re",
            "cb",
            "ce",
            "r",
            "c",
            "i",
            "j", -- local
            "val"
          ]
          [0 ..]

    -- Declare a scalar and initialise it with the given expression.
    -- The scalar is declared as a 64-bit integer: the initialisers are
    -- built from 64-bit parameters and the variables are read back
    -- with 'le64', so a narrower declaration would truncate large
    -- extents.
    dec v e = DeclareScalar v Nonvolatile int64 <> SetScalar v (untyped e)

    -- Element-by-element transposition of the current sub-block: the
    -- source element at @srcoffset + i' * m + j'@ is written to
    -- @destoffset + j' * n + i'@.
    naiveTranspose =
      For j (untyped $ le64 c) $
        For i (untyped $ le64 r) $
          let i' = le64 i + le64 rb
              j' = le64 j + le64 cb
           in mconcat
                [ DeclareScalar val Nonvolatile pt,
                  Read
                    val
                    srcmem
                    (elements $ le64 srcoffset + i' * le64 m + j')
                    pt
                    DefaultSpace
                    Nonvolatile,
                  Write
                    destmem
                    (elements $ le64 destoffset + j' * le64 n + i')
                    pt
                    DefaultSpace
                    Nonvolatile
                    (var val pt)
                ]

    -- Arguments for a recursive call that keeps memory, offsets, array
    -- count and dimensions unchanged and only replaces the four bounds.
    recArgs (cb', ce', rb', re') =
      [ MemArg destmem,
        ExpArg $ untyped $ le64 destoffset,
        MemArg srcmem,
        ExpArg $ untyped $ le64 srcoffset,
        ExpArg $ untyped $ le64 num_arrays,
        ExpArg $ untyped $ le64 m,
        ExpArg $ untyped $ le64 n,
        ExpArg $ untyped cb',
        ExpArg $ untyped ce',
        ExpArg $ untyped rb',
        ExpArg $ untyped re'
      ]

    -- Sub-blocks whose extents are both at most this size are
    -- transposed naively instead of being split further.
    cutoff = 64 -- arbitrary

    -- Transpose a single array: small sub-blocks are handled directly,
    -- larger ones are split in half along their larger extent and each
    -- half is handled by a recursive call.
    doTranspose =
      If
        (le64 r .<=. cutoff .&&. le64 c .<=. cutoff)
        naiveTranspose
        $ If
          (le64 r .>=. le64 c)
          ( recurse (le64 cb, le64 ce, le64 rb, rowMid)
              <> recurse (le64 cb, le64 ce, rowMid, le64 re)
          )
          ( recurse (le64 cb, colMid, le64 rb, le64 re)
              <> recurse (colMid, le64 ce, le64 rb, le64 re)
          )
      where
        -- Midpoints of the two bound ranges.
        rowMid = le64 rb + (le64 r `quot` 2)
        colMid = le64 cb + (le64 c `quot` 2)
        -- A call to the function being generated with new bounds.
        recurse bounds = Call [] fname (recArgs bounds)

    -- Transpose several arrays by calling the generated function once
    -- per array, advancing both offsets by @i * m * n@ elements and
    -- passing an array count of 1.
    doMapTranspose =
      -- In the map-transpose case, we assume that cb==rb==0, ce==m,
      -- re==n.
      For i (untyped $ le64 num_arrays) $
        Call
          []
          fname
          [ MemArg destmem,
            ExpArg $ untyped $ le64 destoffset + le64 i * le64 m * le64 n,
            MemArg srcmem,
            ExpArg $ untyped $ le64 srcoffset + le64 i * le64 m * le64 n,
            ExpArg $ untyped (1 :: TExp Int64),
            ExpArg $ untyped $ le64 m,
            ExpArg $ untyped $ le64 n,
            ExpArg $ untyped $ le64 cb,
            ExpArg $ untyped $ le64 ce,
            ExpArg $ untyped $ le64 rb,
            ExpArg $ untyped $ le64 re
          ]