{-# LANGUAGE TypeFamilies #-}

module Futhark.Analysis.HORep.MapNest
  ( Nesting (..),
    MapNest (..),
    typeOf,
    params,
    inputs,
    setInputs,
    fromSOAC,
    toSOAC,
  )
where

import Data.List (find)
import Data.Map.Strict qualified as M
import Data.Maybe
import Futhark.Analysis.HORep.SOAC (SOAC)
import Futhark.Analysis.HORep.SOAC qualified as SOAC
import Futhark.Construct
import Futhark.IR hiding (typeOf)
import Futhark.IR.SOACS.SOAC qualified as Futhark
import Futhark.Transform.Substitute

-- | One intermediate level of a 'MapNest': a map whose lambda does
-- nothing except run the next map and return its results.  The @rep@
-- parameter is phantom; no field mentions it.
data Nesting rep = Nesting
  { -- | Names of the parameters of this level's lambda ('toSOAC' pairs
    -- them with the row types of the enclosing level's inputs).
    nestingParamNames :: [VName],
    -- | Names bound to the results of the inner map; 'toSOAC' also
    -- uses them as the result of this level's lambda body.
    nestingResult :: [VName],
    -- | Return type of this level's lambda.
    nestingReturnType :: [Type],
    -- | Width of the map performed inside this level's lambda.
    nestingWidth :: SubExp
  }
  deriving (Eq, Ord, Show)

-- | A nest of maps.  In @MapNest w lam nests inps@:
--
-- * @w@ is the width of the outermost map;
--
-- * @lam@ is the function applied at the innermost level;
--
-- * @nests@ describes the intermediate levels, outermost first (when
--   empty, the nest is a single map of @lam@);
--
-- * @inps@ are the inputs to the outermost map.
data MapNest rep = MapNest SubExp (Lambda rep) [Nesting rep] [SOAC.Input]
  deriving (Show)

-- | The types of the values produced by a map nest.  The per-row
-- types are the return type of the lambda when there is no nesting,
-- and otherwise the return type recorded in the outermost 'Nesting';
-- each of them is then passed to 'arrayOfRow' together with the width
-- of the nest.
typeOf :: MapNest rep -> [Type]
typeOf (MapNest w lam [] _) =
  map (`arrayOfRow` w) $ lambdaReturnType lam
typeOf (MapNest w _ (nest : _) _) =
  map (`arrayOfRow` w) $ nestingReturnType nest

-- | The parameter names of the outermost level of a map nest: those
-- of the lambda itself when there is no nesting, and otherwise the
-- ones recorded in the outermost 'Nesting'.
params :: MapNest rep -> [VName]
params (MapNest _ lam [] _) =
  map paramName $ lambdaParams lam
params (MapNest _ _ (nest : _) _) =
  nestingParamNames nest

-- | The inputs stored in a map nest (its last constructor argument).
inputs :: MapNest rep -> [SOAC.Input]
inputs (MapNest _ _ _ inps) = inps

-- | Replace the inputs of a map nest.
--
-- With an empty input list the width and the nestings are left
-- untouched.  Otherwise the shape of the /first/ new input decides
-- the widths: its dimension 0 becomes the width of the nest, and its
-- remaining dimensions are assigned, in order, as the
-- 'nestingWidth' of the nestings.
--
-- Note that the nestings are combined with those remaining dimensions
-- using 'zipWith', so nestings beyond the number of inner dimensions
-- of the first input are dropped from the result.
setInputs :: [SOAC.Input] -> MapNest rep -> MapNest rep
setInputs [] (MapNest w body ns _) = MapNest w body ns []
setInputs (inp : inps) (MapNest _ body ns _) = MapNest w body ns' (inp : inps)
  where
    -- Outermost dimension of the first input.
    w = arraySize 0 $ SOAC.inputType inp
    -- The inner dimensions, one per level of nesting.
    ws = drop 1 $ arrayDims $ SOAC.inputType inp
    ns' = zipWith setDepth ns ws
    setDepth n nw = n {nestingWidth = nw}

-- | Try to view a SOAC as a 'MapNest'.  Returns 'Nothing' when the
-- SOAC does not have the form recognised by the worker @fromSOAC'@,
-- which is started here with no enclosing bound identifiers.
fromSOAC ::
  ( Buildable rep,
    MonadFreshNames m,
    LocalScope rep m,
    Op rep ~ Futhark.SOAC rep
  ) =>
  SOAC rep ->
  m (Maybe (MapNest rep))
fromSOAC = fromSOAC' mempty

-- | Worker for 'fromSOAC'.  The first argument is the list of
-- identifiers bound as lambda parameters by the enclosing levels
-- visited so far.
--
-- Only a 'SOAC.Screma' whose form carries two empty operator lists
-- next to its lambda is recognised (presumably: no scans and no
-- reductions, i.e. a plain map); anything else yields 'Nothing'.
--
-- If the lambda body is a single statement whose pattern names are
-- exactly the body result, and that statement is itself recognised as
-- a map nest, the result is the inner nest with one more 'Nesting'
-- added in front.  Otherwise the SOAC becomes a nest without
-- nestings, where every enclosing-bound identifier that occurs free
-- in the lambda is turned into an extra lambda parameter fed by a
-- replicated input.
fromSOAC' ::
  ( Buildable rep,
    MonadFreshNames m,
    LocalScope rep m,
    Op rep ~ Futhark.SOAC rep
  ) =>
  [Ident] ->
  SOAC rep ->
  m (Maybe (MapNest rep))
fromSOAC' bound (SOAC.Screma w (SOAC.ScremaForm [] [] lam) inps) = do
  -- Both a 'Left' and a 'Right Nothing' here mean "no nested map
  -- nest"; they are handled by the same fallback branch below.
  maybenest <- case ( stmsToList $ bodyStms $ lambdaBody lam,
                      bodyResult $ lambdaBody lam
                    ) of
    ([Let pat _ e], res)
      | map resSubExp res == map Var (patNames pat) ->
          localScope (scopeOfLParams $ lambdaParams lam) $
            SOAC.fromExp e
              >>= either (pure . Left) (fmap (Right . fmap (pat,)) . fromSOAC' bound')
    _ ->
      pure $ Right Nothing

  case maybenest of
    -- Do we have a nested MapNest?
    Right (Just (pat, mn@(MapNest inner_w body' ns' inps'))) -> do
      -- Re-express the inputs of the inner nest in terms of the
      -- inputs of this level, under freshly generated parameter names.
      (ps, inps'') <-
        unzip
          <$> fixInputs
            w
            (zip (map paramName $ lambdaParams lam) inps)
            (zip (params mn) inps')
      let n' =
            Nesting
              { nestingParamNames = ps,
                nestingResult = patNames pat,
                nestingReturnType = typeOf mn,
                nestingWidth = inner_w
              }
      pure $ Just $ MapNest w body' (n' : ns') inps''
    -- No nested MapNest it seems.
    _ -> do
      let -- The enclosing-bound identifier with the given name, if any.
          isBound name = find ((name ==) . identName) bound
          boundUsedInBody =
            mapMaybe isBound $ namesToList $ freeIn lam
      newParams <- mapM (newIdent' (++ "_wasfree")) boundUsedInBody
      let subst =
            M.fromList $
              zip (map identName boundUsedInBody) (map identName newParams)
          -- Each formerly free identifier is supplied as an extra
          -- input, replicated to the width of this map.
          inps' =
            inps
              ++ map
                (SOAC.addTransform (SOAC.Replicate mempty $ Shape [w]) . SOAC.identInput)
                boundUsedInBody
          lam' =
            lam
              { lambdaBody =
                  substituteNames subst $ lambdaBody lam,
                lambdaParams =
                  lambdaParams lam
                    ++ [Param mempty name t | Ident name t <- newParams]
              }
      pure $ Just $ MapNest w lam' [] inps'
  where
    -- What is bound inside this lambda: everything bound outside it,
    -- plus its own parameters.
    bound' = bound <> map paramIdent (lambdaParams lam)
fromSOAC' _ _ = pure Nothing

-- | Convert a 'MapNest' back into a SOAC.
--
-- Without nestings this is a 'SOAC.Screma' of the given width over
-- the given inputs, with the form built by 'Futhark.mapSOAC' from the
-- lambda.  With nestings, the outermost 'Nesting' becomes a lambda
-- whose parameters are its recorded names (typed with the row types
-- of the inputs), whose body binds the recorded result names to the
-- SOAC built recursively from the remaining nest, and which returns
-- those names.
toSOAC ::
  ( MonadFreshNames m,
    HasScope rep m,
    Buildable rep,
    BuilderOps rep,
    Op rep ~ Futhark.SOAC rep
  ) =>
  MapNest rep ->
  m (SOAC rep)
toSOAC (MapNest w lam [] inps) =
  pure $ SOAC.Screma w (Futhark.mapSOAC lam) inps
toSOAC (MapNest w lam (Nesting npnames nres nrettype nw : ns) inps) = do
  let nparams = zipWith (Param mempty) npnames $ map SOAC.inputRowType inps
      -- The inner nest takes this level's parameters as its inputs.
      innerInps = map (SOAC.identInput . paramIdent) nparams
  body <- runBodyBuilder $
    localScope (scopeOfLParams nparams) $ do
      letBindNames nres
        =<< SOAC.toExp
        =<< toSOAC (MapNest nw lam ns innerInps)
      pure $ resultBody $ map Var nres
  let outerlam =
        Lambda
          { lambdaParams = nparams,
            lambdaBody = body,
            lambdaReturnType = nrettype
          }
  pure $ SOAC.Screma w (Futhark.mapSOAC outerlam) inps

-- | @fixInputs w ourInps innerInps@ rewrites the (parameter, input)
-- pairs of an inner map so that they no longer refer to the
-- parameters of the enclosing map, whose own (parameter, input) pairs
-- are @ourInps@ and whose width is @w@.
--
-- For each inner pair:
--
-- * if the array variable of the inner input is the name of one of
--   our parameters, the result is our corresponding input with the
--   inner input's transforms applied through 'SOAC.transformRows',
--   under a fresh name based on our parameter's name;
--
-- * otherwise the inner input is kept, with a 'SOAC.Replicate' of
--   shape @[w]@ appended to its transforms, under a fresh name based
--   on the inner parameter's name with @"_rep"@ appended.
fixInputs ::
  (MonadFreshNames m) =>
  SubExp ->
  [(VName, SOAC.Input)] ->
  [(VName, SOAC.Input)] ->
  m [(VName, SOAC.Input)]
fixInputs w ourInps = mapM inspect
  where
    isParam x (y, _) = x == y

    inspect (_, SOAC.Input ts v _)
      | Just (p, pInp) <- find (isParam v) ourInps = do
          let pInp' = SOAC.transformRows ts pInp
          p' <- newNameFromString $ baseString p
          pure (p', pInp')
    -- Reached when the guard above fails: the inner input does not
    -- refer to any of our parameters.
    inspect (param, SOAC.Input ts a t) = do
      param' <- newNameFromString (baseString param ++ "_rep")
      pure (param', SOAC.Input (ts SOAC.|> SOAC.Replicate mempty (Shape [w])) a t)