{-# LANGUAGE FlexibleContexts #-} module Futhark.Internalise.AccurateSizes ( shapeBody , annotateArrayShape , argShapes , ensureResultShape , ensureResultExtShape , ensureResultExtShapeNoCtx , ensureExtShape , ensureShape , ensureArgShapes ) where import Control.Monad import Data.Loc import qualified Data.Map.Strict as M import qualified Data.Set as S import Futhark.Construct import Futhark.Representation.AST shapeBody :: (HasScope lore m, MonadFreshNames m, BinderOps lore, Bindable lore) => [VName] -> [Type] -> Body lore -> m (Body lore) shapeBody shapenames ts body = runBodyBinder $ do ses <- bodyBind body sets <- mapM subExpType ses resultBody <$> argShapes shapenames ts sets annotateArrayShape :: ArrayShape shape => TypeBase shape u -> [Int] -> TypeBase Shape u annotateArrayShape t newshape = t `setArrayShape` Shape (take (arrayRank t) $ map (intConst Int32 . toInteger) $ newshape ++ repeat 0) -- Some trickery is needed here to predict sensible values for -- dimensions that are used exclusively as the inner dimension of an -- array. The issue is that the dimension may be inside an empty -- array. In this case, the dimension inside the empty array should -- not count, as it will be zero. The solution we use is to take the -- maximum of such sizes; this will effectively disregard the zeroes. argShapes :: MonadBinder m => [VName] -> [TypeBase Shape u0] -> [TypeBase Shape u1] -> m [SubExp] argShapes shapes valts valargts = mapM addShape shapes where mapping = shapeMapping valts valargts outer_dims = map (arraySize 0) valts addShape name = case M.lookup name mapping of Just s | x:xs <- S.toList s -> if Var name `elem` outer_dims then return x else letSubExp (baseString name) =<< foldBinOp (SMax Int32) x xs _ -> return $ intConst Int32 0 ensureResultShape :: MonadBinder m => (m Certificates -> m Certificates) -> ErrorMsg SubExp -> SrcLoc -> [Type] -> Body (Lore m) -> m (Body (Lore m)) ensureResultShape asserting msg loc = ensureResultExtShape asserting msg loc . staticShapes ensureResultExtShape :: MonadBinder m => (m Certificates -> m Certificates) -> ErrorMsg SubExp -> SrcLoc -> [ExtType] -> Body (Lore m) -> m (Body (Lore m)) ensureResultExtShape asserting msg loc rettype body = insertStmsM $ do reses <- bodyBind =<< ensureResultExtShapeNoCtx asserting msg loc rettype body ts <- mapM subExpType reses let ctx = extractShapeContext rettype $ map arrayDims ts mkBodyM mempty $ ctx ++ reses ensureResultExtShapeNoCtx :: MonadBinder m => (m Certificates -> m Certificates) -> ErrorMsg SubExp -> SrcLoc -> [ExtType] -> Body (Lore m) -> m (Body (Lore m)) ensureResultExtShapeNoCtx asserting msg loc rettype body = insertStmsM $ do es <- bodyBind body es_ts <- mapM subExpType es let ext_mapping = shapeExtMapping rettype es_ts rettype' = foldr (uncurry fixExt) rettype $ M.toList ext_mapping assertProperShape t se = let name = "result_proper_shape" in ensureExtShape asserting msg loc t name se resultBodyM =<< zipWithM assertProperShape rettype' es ensureExtShape :: MonadBinder m => (m Certificates -> m Certificates) -> ErrorMsg SubExp -> SrcLoc -> ExtType -> String -> SubExp -> m SubExp ensureExtShape asserting msg loc t name orig | Array{} <- t, Var v <- orig = Var <$> ensureShapeVar asserting msg loc t name v | otherwise = return orig ensureShape :: MonadBinder m => (m Certificates -> m Certificates) -> ErrorMsg SubExp -> SrcLoc -> Type -> String -> SubExp -> m SubExp ensureShape asserting msg loc = ensureExtShape asserting msg loc . staticShapes1 -- | Reshape the arguments to a function so that they fit the expected -- shape declarations. Not used to change rank of arguments. Assumes -- everything is otherwise type-correct. ensureArgShapes :: (MonadBinder m, Typed (TypeBase Shape u)) => (m Certificates -> m Certificates) -> ErrorMsg SubExp -> SrcLoc -> [VName] -> [TypeBase Shape u] -> [SubExp] -> m [SubExp] ensureArgShapes asserting msg loc shapes paramts args = zipWithM ensureArgShape (expectedTypes shapes paramts args) args where ensureArgShape _ (Constant v) = return $ Constant v ensureArgShape t (Var v) | arrayRank t < 1 = return $ Var v | otherwise = ensureShape asserting msg loc t (baseString v) $ Var v ensureShapeVar :: MonadBinder m => (m Certificates -> m Certificates) -> ErrorMsg SubExp -> SrcLoc -> ExtType -> String -> VName -> m VName ensureShapeVar asserting msg loc t name v | Array{} <- t = do newdims <- arrayDims . removeExistentials t <$> lookupType v olddims <- arrayDims <$> lookupType v if newdims == olddims then return v else do certs <- asserting $ do old_zero <- letSubExp "old_empty" =<< anyZero olddims new_zero <- letSubExp "new_empty" =<< anyZero newdims both_empty <- letSubExp "both_empty" $ BasicOp $ BinOp LogAnd old_zero new_zero matches <- zipWithM checkDim newdims olddims all_match <- letSubExp "match" =<< foldBinOp LogAnd (constant True) matches empty_or_match <- letSubExp "empty_or_match" $ BasicOp $ BinOp LogOr both_empty all_match Certificates . pure <$> letExp "empty_or_match_cert" (BasicOp $ Assert empty_or_match msg (loc, [])) certifying certs $ letExp name $ shapeCoerce newdims v | otherwise = return v where checkDim desired has = letSubExp "dim_match" $ BasicOp $ CmpOp (CmpEq int32) desired has anyZero = foldBinOp LogOr (constant False) <=< mapM (letSubExp "dim_zero" . BasicOp . CmpOp (CmpEq int32) (intConst Int32 0))