{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

-- | Definitions for multicore operations.
--
-- Most of the interesting stuff is in "Futhark.IR.SegOp", which is
-- also re-exported from here.
module Futhark.IR.MC.Op
  ( MCOp (..),
    traverseMCOpStms,
    typeCheckMCOp,
    simplifyMCOp,
    module Futhark.IR.SegOp,
  )
where

import Data.Bifunctor (first)
import Futhark.Analysis.Metrics
import Futhark.Analysis.SymbolTable qualified as ST
import Futhark.IR
import Futhark.IR.Aliases (Aliases)
import Futhark.IR.Prop.Aliases
import Futhark.IR.SegOp
import Futhark.IR.TypeCheck qualified as TC
import Futhark.Optimise.Simplify qualified as Simplify
import Futhark.Optimise.Simplify.Engine qualified as Engine
import Futhark.Optimise.Simplify.Rep
import Futhark.Transform.Rename
import Futhark.Transform.Substitute
import Futhark.Util.Pretty
  ( Pretty,
    nestedBlock,
    pretty,
    (<+>),
    (</>),
  )
import Prelude hiding (id, (.))

-- | An operation for the multicore representation.  Feel free to
-- extend this on an ad hoc basis as needed.  Parameterised with some
-- other operation.
data MCOp rep op
  = -- | The first 'SegOp' (if it exists) contains nested parallelism,
    -- while the second one has a fully sequential body.  They are
    -- semantically fully equivalent.
    ParOp
      (Maybe (SegOp () rep))
      (SegOp () rep)
  | -- | Something else (in practice often a SOAC).
    OtherOp op
  deriving (Eq, Ord, Show)

-- | Traverse the statements contained in an 'MCOp'.
--
-- For 'ParOp', the optional nested-parallel 'SegOp' and the
-- sequential 'SegOp' are both traversed with 'traverseSegOpStms'.
-- For 'OtherOp', the first argument (the traverser for the inner
-- operation type) is used instead.
traverseMCOpStms ::
  Monad m =>
  OpStmsTraverser m op rep ->
  OpStmsTraverser m (MCOp rep op) rep
traverseMCOpStms _ f (ParOp par_op op) =
  ParOp <$> traverse (traverseSegOpStms f) par_op <*> traverseSegOpStms f op
traverseMCOpStms onInner f (OtherOp op) = OtherOp <$> onInner f op

-- | Substitution is applied to both versions of a 'ParOp', or to the
-- wrapped operation of an 'OtherOp'.
instance (ASTRep rep, Substitute op) => Substitute (MCOp rep op) where
  substituteNames substs (ParOp par_op op) =
    ParOp (substituteNames substs <$> par_op) (substituteNames substs op)
  substituteNames substs (OtherOp op) =
    OtherOp $ substituteNames substs op

-- | Renaming is applied to both versions of a 'ParOp' (the nested one
-- first), or to the wrapped operation of an 'OtherOp'.
instance (ASTRep rep, Rename op) => Rename (MCOp rep op) where
  rename (ParOp par_op op) = ParOp <$> rename par_op <*> rename op
  rename (OtherOp op) = OtherOp <$> rename op

-- | The free variables of a 'ParOp' are the union of those of both
-- versions.
instance (ASTRep rep, FreeIn op) => FreeIn (MCOp rep op) where
  freeIn' (ParOp par_op op) = freeIn' par_op <> freeIn' op
  freeIn' (OtherOp op) = freeIn' op

instance (ASTRep rep, IsOp op) => IsOp (MCOp rep op) where
  -- Only the sequential version of a 'ParOp' is consulted; the two
  -- versions are documented as semantically equivalent.
  safeOp (ParOp _ op) = safeOp op
  safeOp (OtherOp op) = safeOp op

  cheapOp (ParOp _ op) = cheapOp op
  cheapOp (OtherOp op) = cheapOp op

-- | The type of a 'ParOp' is taken from its sequential version.
instance TypedOp op => TypedOp (MCOp rep op) where
  opType (ParOp _ op) = opType op
  opType (OtherOp op) = opType op

-- | Aliasing and consumption of a 'ParOp' are taken from its
-- sequential version only.
instance
  (Aliased rep, AliasedOp op, ASTRep rep) =>
  AliasedOp (MCOp rep op)
  where
  opAliases (ParOp _ op) = opAliases op
  opAliases (OtherOp op) = opAliases op

  consumedInOp (ParOp _ op) = consumedInOp op
  consumedInOp (OtherOp op) = consumedInOp op

-- | Alias annotations are added to, or removed from, both versions of
-- a 'ParOp', or the wrapped operation of an 'OtherOp'.
instance
  (CanBeAliased (Op rep), CanBeAliased op, ASTRep rep) =>
  CanBeAliased (MCOp rep op)
  where
  type OpWithAliases (MCOp rep op) = MCOp (Aliases rep) (OpWithAliases op)

  addOpAliases aliases (ParOp par_op op) =
    ParOp (addOpAliases aliases <$> par_op) (addOpAliases aliases op)
  addOpAliases aliases (OtherOp op) =
    OtherOp $ addOpAliases aliases op

  removeOpAliases (ParOp par_op op) =
    ParOp (removeOpAliases <$> par_op) (removeOpAliases op)
  removeOpAliases (OtherOp op) =
    OtherOp $ removeOpAliases op

-- | Wisdom annotations are added to, or removed from, both versions of
-- a 'ParOp', or the wrapped operation of an 'OtherOp'.
instance
  (CanBeWise (Op rep), CanBeWise op, ASTRep rep) =>
  CanBeWise (MCOp rep op)
  where
  type OpWithWisdom (MCOp rep op) = MCOp (Wise rep) (OpWithWisdom op)

  removeOpWisdom (ParOp par_op op) =
    ParOp (removeOpWisdom <$> par_op) (removeOpWisdom op)
  removeOpWisdom (OtherOp op) =
    OtherOp $ removeOpWisdom op

  addOpWisdom (ParOp par_op op) =
    ParOp (addOpWisdom <$> par_op) (addOpWisdom op)
  addOpWisdom (OtherOp op) =
    OtherOp $ addOpWisdom op

-- | Indexing into the result of a 'ParOp' is answered by its
-- sequential version.
instance (ASTRep rep, ST.IndexOp op) => ST.IndexOp (MCOp rep op) where
  indexOp vtable k (ParOp _ op) is = ST.indexOp vtable k op is
  indexOp vtable k (OtherOp op) is = ST.indexOp vtable k op is

-- | A 'ParOp' without a nested-parallel version prints as just its
-- sequential 'SegOp'.  Otherwise both versions are printed, labelled
-- @par@ and @seq@, each inside a @{ ... }@ block.
instance (PrettyRep rep, Pretty op) => Pretty (MCOp rep op) where
  pretty (ParOp Nothing op) = pretty op
  pretty (ParOp (Just par_op) op) =
    "par"
      <+> nestedBlock "{" "}" (pretty par_op)
      </> "seq"
      <+> nestedBlock "{" "}" (pretty op)
  pretty (OtherOp op) = pretty op

-- | Metrics are recorded for both versions of a 'ParOp' (the optional
-- nested one first).
instance (OpMetrics (Op rep), OpMetrics op) => OpMetrics (MCOp rep op) where
  opMetrics (ParOp par_op op) = opMetrics par_op >> opMetrics op
  opMetrics (OtherOp op) = opMetrics op

-- | Type-check an 'MCOp'.
--
-- Every 'SegOp' is checked with 'typeCheckSegOp', using 'pure' as the
-- (trivial) checker for its @()@ level.  'OtherOp' is checked with the
-- provided function.
typeCheckMCOp ::
  TC.Checkable rep =>
  (op -> TC.TypeM rep ()) ->
  MCOp (Aliases rep) op ->
  TC.TypeM rep ()
typeCheckMCOp _ (ParOp (Just par_op) op) = do
  -- It is valid for the same array to be consumed in both par_op and op,
  -- so the two versions are checked as alternatives rather than in sequence.
  _ <- typeCheckSegOp pure par_op `TC.alternative` typeCheckSegOp pure op
  pure ()
typeCheckMCOp _ (ParOp Nothing op) =
  typeCheckSegOp pure op
typeCheckMCOp f (OtherOp op) = f op

-- | Simplify an 'MCOp'.
--
-- 'OtherOp' is delegated to the provided simplifier.  For 'ParOp',
-- each present 'SegOp' is simplified with 'simplifySegOp'.  The
-- statements hoisted out of the two are concatenated, with those from
-- the nested-parallel version first.
simplifyMCOp ::
  ( Engine.SimplifiableRep rep,
    BodyDec rep ~ ()
  ) =>
  Simplify.SimplifyOp rep op ->
  MCOp (Wise rep) op ->
  Engine.SimpleM rep (MCOp (Wise rep) op, Stms (Wise rep))
simplifyMCOp f (OtherOp op) = do
  (op', stms) <- f op
  pure (OtherOp op', stms)
simplifyMCOp _ (ParOp par_op op) = do
  (par_op', par_op_hoisted) <-
    case par_op of
      Nothing -> pure (Nothing, mempty)
      Just x -> first Just <$> simplifySegOp x

  (op', op_hoisted) <- simplifySegOp op

  pure (ParOp par_op' op', par_op_hoisted <> op_hoisted)