{-# LANGUAGE FlexibleContexts #-}

module Futhark.Internalise.AccurateSizes
  ( argShapes,
    ensureResultShape,
    ensureResultExtShape,
    ensureExtShape,
    ensureShape,
    ensureArgShapes,
  )
where

import Control.Monad
import qualified Data.Map.Strict as M
import Data.Maybe
import Futhark.Construct
import Futhark.IR.SOACS
import Futhark.Internalise.Monad
import Futhark.Util (takeLast)

-- | Build a substitution from the size variables occurring in the
-- parameter types to the sizes found at the same positions in the
-- argument types.
--
-- Only the last @length value_arg_types@ parameters are used; they are
-- paired positionally with @value_arg_types@.  For a pair of arrays,
-- every parameter dimension that is a variable is mapped to the
-- argument dimension at the same index.  For a pair of accumulators,
-- the accumulator name, the variable dimensions of the index space,
-- and (recursively) the element types are mapped.  Any other pairing
-- contributes nothing.
shapeMapping ::
  (HasScope SOACS m, Monad m) =>
  [FParam] ->
  [Type] ->
  m (M.Map VName SubExp)
shapeMapping all_params value_arg_types =
  fmap mconcat (zipWithM pairUp param_ts value_arg_types)
  where
    param_ts = map paramType $ takeLast (length value_arg_types) all_params

    -- Keep only the positions where the left-hand size is a variable.
    dimMap ds1 ds2 = M.fromList [(v, se) | (Var v, se) <- zip ds1 ds2]

    pairUp t1@Array {} t2@Array {} =
      pure $ dimMap (arrayDims t1) (arrayDims t2)
    pairUp (Acc acc1 ispace1 ts1 _) (Acc acc2 ispace2 ts2 _) = do
      elem_m <- mconcat <$> zipWithM pairUp ts1 ts2
      pure $
        mconcat
          [ M.singleton acc1 (Var acc2),
            dimMap (shapeDims ispace1) (shapeDims ispace2),
            elem_m
          ]
    pairUp _ _ =
      pure mempty

-- | Compute the arguments for the size parameters @shapes@ of a call.
--
-- Each size name is looked up in the mapping produced by
-- 'shapeMapping' from the function parameters and the types of the
-- value arguments.  A name with no entry in that mapping becomes the
-- constant @0@ of type 'Int64'.
argShapes ::
  (HasScope SOACS m, Monad m) =>
  [VName] ->
  [FParam] ->
  [Type] ->
  m [SubExp]
argShapes shapes all_params valargts =
  resolve <$> shapeMapping all_params valargts
  where
    -- FIXME: we only need the default of 0 because the
    -- defunctionaliser throws away sizes.
    resolve mapping =
      map (\name -> M.findWithDefault (intConst Int64 0) name mapping) shapes

-- | Variant of 'ensureResultExtShape' for a return type whose shapes
-- are fully static.  The types are lifted with 'staticShapes' before
-- delegating.
ensureResultShape ::
  ErrorMsg SubExp ->
  SrcLoc ->
  [Type] ->
  Result ->
  InternaliseM Result
ensureResultShape msg loc ts res =
  ensureResultExtShape msg loc (staticShapes ts) res

-- | Coerce the results to the (possibly existential) return type, then
-- prepend the shape context.
--
-- The results are first coerced with 'ensureResultExtShapeNoCtx'.  The
-- dimensions of the coerced results' types are then passed to
-- 'extractShapeContext' together with @rettype@.  The sizes it yields
-- are placed, as results, in front of the coerced results.
ensureResultExtShape ::
  ErrorMsg SubExp ->
  SrcLoc ->
  [ExtType] ->
  Result ->
  InternaliseM Result
ensureResultExtShape msg loc rettype res = do
  res' <- ensureResultExtShapeNoCtx msg loc rettype res
  dims <- map arrayDims <$> mapM subExpResType res'
  pure $ subExpsRes (extractShapeContext rettype dims) <> res'

-- | Coerce each result to the corresponding return type, without
-- producing any shape context.
--
-- The existential sizes of @rettype@ are first fixed using
-- 'shapeExtMapping' against the actual result types (via 'fixExt').
-- Each result is then passed through 'ensureExtShape' with the
-- instantiated type, keeping its certificates.
ensureResultExtShapeNoCtx ::
  ErrorMsg SubExp ->
  SrcLoc ->
  [ExtType] ->
  Result ->
  InternaliseM Result
ensureResultExtShapeNoCtx msg loc rettype es = do
  es_ts <- mapM subExpResType es
  let rettype' =
        M.foldrWithKey fixExt rettype $ shapeExtMapping rettype es_ts
  zipWithM coerceRes rettype' es
  where
    coerceRes t (SubExpRes cs se) =
      SubExpRes cs <$> ensureExtShape msg loc t "result_proper_shape" se

-- | Coerce @orig@ to the shape of @t@ when @t@ is an array type and
-- @orig@ is a variable, by way of 'ensureShapeVar'.  In every other
-- case @orig@ is returned unchanged.
ensureExtShape ::
  ErrorMsg SubExp ->
  SrcLoc ->
  ExtType ->
  String ->
  SubExp ->
  InternaliseM SubExp
ensureExtShape msg loc t name orig =
  case (t, orig) of
    (Array {}, Var v) -> Var <$> ensureShapeVar msg loc t name v
    _ -> pure orig

-- | Variant of 'ensureExtShape' for a type with fully static shape.  The
-- type is lifted with 'staticShapes1' before delegating.
ensureShape ::
  ErrorMsg SubExp ->
  SrcLoc ->
  Type ->
  String ->
  SubExp ->
  InternaliseM SubExp
ensureShape msg loc t =
  ensureExtShape msg loc (staticShapes1 t)

-- | Reshape the arguments to a function so that they fit the expected
-- shape declarations.  Not used to change rank of arguments.  Assumes
-- everything is otherwise type-correct.
--
-- The expected type of each argument comes from 'expectedTypes'.
-- Constants, and variables whose expected type has rank 0, are passed
-- through unchanged.  Any other variable goes through 'ensureShape',
-- using the variable's base name as the name hint.
ensureArgShapes ::
  (Typed (TypeBase Shape u)) =>
  ErrorMsg SubExp ->
  SrcLoc ->
  [VName] ->
  [TypeBase Shape u] ->
  [SubExp] ->
  InternaliseM [SubExp]
ensureArgShapes msg loc shapes paramts args =
  zipWithM coerceArg (expectedTypes shapes paramts args) args
  where
    coerceArg t (Var v)
      | arrayRank t >= 1 = ensureShape msg loc t (baseString v) (Var v)
    coerceArg _ arg = pure arg

-- | Ensure that the array variable @v@ has the shape described by @t@.
--
-- The desired dimensions are obtained by instantiating @t@ against the
-- current type of @v@ with 'removeExistentials'.  If they already equal
-- @v@'s dimensions, @v@ itself is returned.
--
-- Otherwise:
--
-- * each desired dimension is compared to the existing one;
-- * an assertion with message @msg@ at @loc@ is emitted that all of
--   them are equal;
-- * under that certificate, a new binding (name hint @name@) holds a
--   'shapeCoerce' of @v@ to the desired dimensions, and is returned.
--
-- If @t@ is not an array type, @v@ is returned unchanged.
ensureShapeVar ::
  ErrorMsg SubExp ->
  SrcLoc ->
  ExtType ->
  String ->
  VName ->
  InternaliseM VName
ensureShapeVar msg loc t name v
  | Array {} <- t = do
      -- Look the type up once; both the existing and the desired
      -- dimensions are derived from it.
      v_t <- lookupType v
      let olddims = arrayDims v_t
          newdims = arrayDims $ removeExistentials t v_t
      if newdims == olddims
        then pure v
        else do
          matches <- zipWithM checkDim newdims olddims
          all_match <- letSubExp "match" =<< eAll matches
          cs <- assert "empty_or_match_cert" all_match msg loc
          certifying cs $ letExp name $ shapeCoerce newdims v
  | otherwise = pure v
  where
    -- Emit an i64 equality comparison between a desired and an
    -- existing dimension.
    checkDim desired has =
      letSubExp "dim_match" $ BasicOp $ CmpOp (CmpEq int64) desired has