{-# LANGUAGE TypeFamilies #-}

-- | Interchanging scans and reductions with inner maps.
module Futhark.Pass.ExtractKernels.ISRWIM
  ( iswim,
    irwim,
    rwimPossible,
  )
where

import Control.Arrow (first)
import Control.Monad.State
import Futhark.IR.SOACS
import Futhark.MonadFreshNames
import Futhark.Tools

-- | Interchange Scan With Inner Map.  Tries to turn a @scan(map)@ into a
-- @map(scan)@.
--
-- The arguments are the result pattern, the outer (scan) width @w@, the
-- scan operator, and the @(neutral element, input array)@ pairs.  Returns
-- 'Nothing' unless 'rwimPossible' accepts the scan operator.  Otherwise the
-- returned action
--
--   1. transposes the input arrays and binds each neutral element to a
--      fresh variable, so that both can be passed as @map@ inputs;
--   2. emits a @map@ over the inner map's width whose body is a @scan@
--      over @w@; the @map@ statement carries the certificates of the inner
--      map statement;
--   3. binds the @map@'s results to fresh names suffixed @_transposed@ and
--      rearranges each one into the corresponding element of @res_pat@ by
--      swapping its two outermost dimensions.
iswim ::
  (MonadBuilder m, Rep m ~ SOACS) =>
  Pat Type ->
  SubExp ->
  Lambda SOACS ->
  [(SubExp, VName)] ->
  Maybe (m ())
iswim res_pat w scan_fun scan_input
  | Just (map_pat, map_cs, map_w, map_fun) <- rwimPossible scan_fun = Just $ do
      let (accs, arrs) = unzip scan_input
      arrs' <- transposedArrays arrs
      -- Neutral elements become ordinary map inputs, so they must be
      -- bound to variables.
      accs' <- mapM (letExp "acc" . BasicOp . SubExp) accs

      let map_arrs' = accs' ++ arrs'
          -- The scan operator's parameters: accumulators first, then
          -- elements (one of each per input array).
          (scan_acc_params, scan_elem_params) =
            splitAt (length arrs) $ lambdaParams scan_fun
          map_params =
            map removeParamOuterDim scan_acc_params
              ++ map (setParamOuterDimTo w) scan_elem_params
          map_rettype = map (setOuterDimTo w) $ lambdaReturnType scan_fun

          -- Inside the map, the accumulator parameters provide the inner
          -- scan's neutral elements and the element parameters provide its
          -- input arrays, paired up positionally.
          (nes', scan_arrs) =
            unzip $
              map (first Var) $
                uncurry zip $
                  splitAt (length arrs') $
                    map paramName map_params

      -- The inner map's lambda is reused unchanged as the new scan operator.
      scan_soac <- scanSOAC [Scan map_fun nes']
      let map_body =
            mkBody
              ( oneStm $
                  Let (setPatOuterDimTo w map_pat) (defAux ()) $
                    Op $
                      Screma w scan_arrs scan_soac
              )
              $ varsRes
              $ patNames map_pat
          map_fun' = Lambda map_params map_body map_rettype

      -- The map is bound to a pattern whose types are the transposes of
      -- those in res_pat; the results are rearranged into res_pat below.
      res_pat' <-
        fmap basicPat $
          mapM (newIdent' (<> "_transposed") . transposeIdentType) $
            patIdents res_pat

      addStm $
        Let res_pat' (StmAux map_cs mempty ()) $
          Op $
            Screma map_w map_arrs' (mapSOAC map_fun')

      forM_ (zip (patIdents res_pat) (patIdents res_pat')) $ \(to, from) -> do
        -- Swap the two outermost dimensions and keep the rest in place.
        let perm = [1, 0] ++ [2 .. arrayRank (identType from) - 1]
        addStm $
          Let (basicPat [to]) (defAux ()) $
            BasicOp $
              Rearrange perm $
                identName from
  | otherwise = Nothing

-- | Interchange Reduce With Inner Map.  Tries to turn a @reduce(map)@ into
-- a @map(reduce)@.
--
-- The arguments are the result pattern, the outer (reduction) width @w@,
-- the operator's commutativity, the reduce operator, and the
-- @(neutral element, input array)@ pairs.  Returns 'Nothing' unless
-- 'rwimPossible' accepts the reduce operator.  Otherwise the returned
-- action transposes the input arrays and binds @res_pat@ to a @map@ over
-- the inner map's width whose body reduces over @w@.  The @map@ statement
-- carries the certificates of the inner map statement.  If the new inner
-- reduction is itself of the @reduce(map)@ form, it is interchanged again
-- by a recursive call.
--
-- Each neutral element must be a variable: its first row (index 0) is used
-- as the neutral element of the inner reduction.  A constant neutral
-- element is reported with 'error'.
irwim ::
  (MonadBuilder m, Rep m ~ SOACS) =>
  Pat Type ->
  SubExp ->
  Commutativity ->
  Lambda SOACS ->
  [(SubExp, VName)] ->
  Maybe (m ())
irwim res_pat w comm red_fun red_input
  | Just (map_pat, map_cs, map_w, map_fun) <- rwimPossible red_fun = Just $ do
      let (accs, arrs) = unzip red_input
      arrs' <- transposedArrays arrs
      -- FIXME?  Can we reasonably assume that the accumulator is a
      -- replicate?  We also assume that it is non-empty.
      let indexAcc (Var v) = do
            v_t <- lookupType v
            letSubExp "acc" $
              BasicOp $
                Index v $
                  fullSlice v_t [DimFix $ intConst Int64 0]
          indexAcc Constant {} =
            error "irwim: array accumulator is a constant."
      accs' <- mapM indexAcc accs

      let -- The reduce operator's parameters: accumulators first, then
          -- elements; only the element parameters are needed here.
          (_red_acc_params, red_elem_params) =
            splitAt (length arrs) $ lambdaParams red_fun
          map_rettype = map rowType $ lambdaReturnType red_fun
          map_params = map (setParamOuterDimTo w) red_elem_params
          -- Inside the map, each row of the transposed inputs is reduced
          -- starting from the indexed neutral elements.
          red_input' = zip accs' $ map paramName map_params
          red_pat = stripPatOuterDim map_pat

      -- The inner map's lambda is reused unchanged as the new reduce
      -- operator.
      map_body <-
        case irwim red_pat w comm map_fun red_input' of
          Nothing -> do
            reduce_soac <- reduceSOAC [Reduce comm map_fun $ map fst red_input']
            pure
              $ mkBody
                ( oneStm $
                    Let red_pat (defAux ()) $
                      Op $
                        Screma w (map snd red_input') reduce_soac
                )
              $ varsRes
              $ patNames map_pat
          Just m -> localScope (scopeOfLParams map_params) $ do
            -- The nested interchange must see the map parameters in scope.
            map_body_stms <- collectStms_ m
            pure $ mkBody map_body_stms $ varsRes $ patNames map_pat

      let map_fun' = Lambda map_params map_body map_rettype

      addStm $
        Let res_pat (StmAux map_cs mempty ()) $
          Op $
            Screma map_w arrs' $
              mapSOAC map_fun'
  | otherwise = Nothing

-- | Does this reduce (or scan) operator contain an inner map, and if so,
-- what does that map look like?
--
-- The operator qualifies when its body consists of exactly one statement,
-- that statement is a 'Screma' recognised by 'isMapSOAC', the body returns
-- the statement's pattern variables in order, and the map's input arrays
-- are exactly the operator's parameters, in order.  On success, returns the
-- map statement's pattern, its certificates, the map width, and the map's
-- lambda.
rwimPossible ::
  Lambda SOACS ->
  Maybe (Pat Type, Certs, SubExp, Lambda SOACS)
rwimPossible fun
  | Body _ stms res <- lambdaBody fun,
    [stm] <- stmsToList stms, -- Body has a single binding
    map_pat <- stmPat stm,
    map Var (patNames map_pat) == map resSubExp res, -- Returned verbatim
    Op (Screma map_w map_arrs form) <- stmExp stm,
    Just map_fun <- isMapSOAC form,
    map paramName (lambdaParams fun) == map_arrs =
      Just (map_pat, stmCerts stm, map_w, map_fun)
  | otherwise =
      Nothing

-- | Bind a copy of each array with its two outermost dimensions swapped
-- (all other dimensions stay in place).  Each copy is named after the base
-- name of its source array.
transposedArrays :: MonadBuilder m => [VName] -> m [VName]
transposedArrays arrs = forM arrs $ \arr -> do
  t <- lookupType arr
  let perm = [1, 0] ++ [2 .. arrayRank t - 1]
  letExp (baseString arr) $ BasicOp $ Rearrange perm arr

-- | Drop the outermost dimension of a parameter's type, leaving its row
-- type.
removeParamOuterDim :: LParam SOACS -> LParam SOACS
removeParamOuterDim param =
  let t = rowType $ paramType param
   in param {paramDec = t}

-- | Replace the outermost dimension of a parameter's type with @w@.
setParamOuterDimTo :: SubExp -> LParam SOACS -> LParam SOACS
setParamOuterDimTo w param =
  let t = setOuterDimTo w $ paramType param
   in param {paramDec = t}

-- | Replace the outermost dimension of an identifier's type with @w@.
setIdentOuterDimTo :: SubExp -> Ident -> Ident
setIdentOuterDimTo w ident =
  let t = setOuterDimTo w $ identType ident
   in ident {identType = t}

-- | Replace the outermost dimension of an array type with @w@, keeping its
-- row type.
setOuterDimTo :: SubExp -> Type -> Type
setOuterDimTo w t =
  arrayOfRow (rowType t) w

-- | Replace the outermost dimension of the type of every identifier in a
-- pattern with @w@.  The pattern is rebuilt with 'basicPat'.
setPatOuterDimTo :: SubExp -> Pat Type -> Pat Type
setPatOuterDimTo w pat =
  basicPat $ map (setIdentOuterDimTo w) $ patIdents pat

-- | Apply 'transposeType' to the type of an identifier.
transposeIdentType :: Ident -> Ident
transposeIdentType ident =
  ident {identType = transposeType $ identType ident}

-- | Drop the outermost dimension of an identifier's type, leaving its row
-- type.
stripIdentOuterDim :: Ident -> Ident
stripIdentOuterDim ident =
  ident {identType = rowType $ identType ident}

-- | Drop the outermost dimension of the type of every identifier in a
-- pattern.  The pattern is rebuilt with 'basicPat'.
stripPatOuterDim :: Pat Type -> Pat Type
stripPatOuterDim pat =
  basicPat $ map stripIdentOuterDim $ patIdents pat