{-# LANGUAGE FlexibleContexts #-} module Futhark.Internalise.Lambdas ( InternaliseLambda , internaliseMapLambda , internaliseStreamMapLambda , internaliseFoldLambda , internaliseStreamLambda , internalisePartitionLambda ) where import Control.Monad import Data.Loc import Language.Futhark as E import Futhark.Representation.SOACS as I import Futhark.MonadFreshNames import Futhark.Internalise.Monad import Futhark.Internalise.AccurateSizes import Futhark.Representation.SOACS.Simplify (simplifyLambda) -- | A function for internalising lambdas. type InternaliseLambda = E.Exp -> [I.Type] -> InternaliseM ([I.LParam], I.Body, [I.ExtType]) internaliseMapLambda :: InternaliseLambda -> E.Exp -> [I.SubExp] -> InternaliseM I.Lambda internaliseMapLambda internaliseLambda lam args = do argtypes <- mapM I.subExpType args let rowtypes = map I.rowType argtypes (params, body, rettype) <- internaliseLambda lam rowtypes (rettype', inner_shapes) <- instantiateShapes' rettype let outer_shape = arraysSize 0 argtypes shapefun <- makeShapeFun params body rettype' inner_shapes bindMapShapes index0 [] inner_shapes shapefun args outer_shape body' <- localScope (scopeOfLParams params) $ ensureResultShape asserting (ErrorMsg [ErrorString "not all iterations produce same shape"]) (srclocOf lam) rettype' body return $ I.Lambda params body' rettype' where index0 arg = do arg' <- letExp "arg" $ I.BasicOp $ I.SubExp arg arg_t <- lookupType arg' return $ I.BasicOp $ I.Index arg' $ fullSlice arg_t [I.DimFix zero] zero = constant (0::I.Int32) internaliseStreamMapLambda :: InternaliseLambda -> E.Exp -> [I.SubExp] -> InternaliseM I.Lambda internaliseStreamMapLambda internaliseLambda lam args = do chunk_size <- newVName "chunk_size" let chunk_param = I.Param chunk_size (I.Prim int32) outer = (`setOuterSize` I.Var chunk_size) localScope (scopeOfLParams [chunk_param]) $ do argtypes <- mapM I.subExpType args (lam_params, orig_body, rettype) <- internaliseLambda lam $ I.Prim int32 : map outer argtypes let orig_chunk_param : params = lam_params body <- runBodyBinder $ do letBindNames_ [paramName orig_chunk_param] $ I.BasicOp $ I.SubExp $ I.Var chunk_size return orig_body (rettype', inner_shapes) <- instantiateShapes' rettype let outer_shape = arraysSize 0 argtypes shapefun <- makeShapeFun (chunk_param:params) body rettype' inner_shapes bindMapShapes (slice0 chunk_size) [zero] inner_shapes shapefun args outer_shape body' <- localScope (scopeOfLParams params) $ insertStmsM $ do letBindNames_ [paramName orig_chunk_param] $ I.BasicOp $ I.SubExp $ I.Var chunk_size ensureResultShape asserting (ErrorMsg [ErrorString "not all iterations produce same shape"]) (srclocOf lam) (map outer rettype') body return $ I.Lambda (chunk_param:params) body' (map outer rettype') where slice0 chunk_size arg = do arg' <- letExp "arg" $ I.BasicOp $ I.SubExp arg arg_t <- lookupType arg' return $ I.BasicOp $ I.Index arg' $ fullSlice arg_t [I.DimSlice zero (I.Var chunk_size) one] zero = constant (0::I.Int32) one = constant (1::I.Int32) makeShapeFun :: [I.LParam] -> I.Body -> [Type] -> [I.Ident] -> InternaliseM I.Lambda makeShapeFun params body val_rettype inner_shapes = do -- Some of 'params' may be unique, which means that the shape slice -- would consume its input. This is not acceptable - that input is -- needed for the value function! Hence, for all unique parameters, -- we create a substitute non-unique parameter, and insert a -- copy-binding in the body of the function. (params', copystms) <- nonuniqueParams params shape_body <- runBodyBinder $ localScope (scopeOfLParams params') $ do mapM_ addStm copystms shapeBody (map I.identName inner_shapes) val_rettype body return $ I.Lambda params' shape_body rettype where rettype = replicate (length inner_shapes) $ I.Prim int32 bindMapShapes :: (SubExp -> InternaliseM I.Exp) -> [SubExp] -> [I.Ident] -> I.Lambda -> [I.SubExp] -> SubExp -> InternaliseM () bindMapShapes indexArg extra_args inner_shapes sizefun args outer_shape | null $ I.lambdaReturnType sizefun = return () | otherwise = do let size_args = replicate (length $ lambdaParams sizefun) Nothing sizefun' <- simplifyLambda sizefun size_args let sizefun_safe = all (I.safeExp . I.stmExp) $ I.bodyStms $ I.lambdaBody sizefun' sizefun_arg_invariant = not $ any ((`nameIn` freeIn (I.lambdaBody sizefun')) . I.paramName) $ lambdaParams sizefun' if sizefun_safe && sizefun_arg_invariant then do ses <- bodyBind $ lambdaBody sizefun' forM_ (zip inner_shapes ses) $ \(v, se) -> letBind_ (basicPattern [] [v]) $ I.BasicOp $ I.SubExp se else letBind_ (basicPattern [] inner_shapes) =<< eIf' isnonempty nonemptybranch emptybranch IfFallback where emptybranch = pure $ resultBody (map (const zero) $ I.lambdaReturnType sizefun) nonemptybranch = insertStmsM $ resultBody <$> (eLambda sizefun . (map eSubExp extra_args++) $ map indexArg args) isnonempty = eNot $ eCmpOp (I.CmpEq I.int32) (pure $ I.BasicOp $ I.SubExp outer_shape) (pure $ I.BasicOp $ SubExp zero) zero = constant (0::I.Int32) internaliseFoldLambda :: InternaliseLambda -> E.Exp -> [I.Type] -> [I.Type] -> InternaliseM I.Lambda internaliseFoldLambda internaliseLambda lam acctypes arrtypes = do let rowtypes = map I.rowType arrtypes (params, body, rettype) <- internaliseLambda lam $ acctypes ++ rowtypes let rettype' = [ t `I.setArrayShape` I.arrayShape shape | (t,shape) <- zip rettype acctypes ] -- The result of the body must have the exact same shape as the -- initial accumulator. We accomplish this with an assertion and -- reshape(). body' <- localScope (scopeOfLParams params) $ ensureResultShape asserting (ErrorMsg [ErrorString "shape of result does not match shape of initial value"]) (srclocOf lam) rettype' body return $ I.Lambda params body' rettype' internaliseStreamLambda :: InternaliseLambda -> E.Exp -> [I.Type] -> InternaliseM ([LParam], Body) internaliseStreamLambda internaliseLambda lam rowts = do chunk_size <- newVName "chunk_size" let chunk_param = I.Param chunk_size $ I.Prim int32 chunktypes = map (`arrayOfRow` I.Var chunk_size) rowts localScope (scopeOfLParams [chunk_param]) $ do (lam_params, orig_body, _) <- internaliseLambda lam $ I.Prim int32 : chunktypes let orig_chunk_param : params = lam_params body <- runBodyBinder $ do letBindNames_ [paramName orig_chunk_param] $ I.BasicOp $ I.SubExp $ I.Var chunk_size return orig_body return (chunk_param:params, body) -- Given @k@ lambdas, this will return a lambda that returns an -- (k+2)-element tuple of integers. The first element is the -- equivalence class ID in the range [0,k]. The remaining are all zero -- except for possibly one element. internalisePartitionLambda :: InternaliseLambda -> Int -> E.Exp -> [I.SubExp] -> InternaliseM I.Lambda internalisePartitionLambda internaliseLambda k lam args = do argtypes <- mapM I.subExpType args let rowtypes = map I.rowType argtypes (params, body, _) <- internaliseLambda lam rowtypes body' <- localScope (scopeOfLParams params) $ lambdaWithIncrement body return $ I.Lambda params body' rettype where rettype = replicate (k+2) $ I.Prim int32 result i = map constant $ (fromIntegral i :: Int32) : (replicate i 0 ++ [1::Int32] ++ replicate (k-i) 0) mkResult _ i | i >= k = return $ result i mkResult eq_class i = do is_i <- letSubExp "is_i" $ BasicOp $ CmpOp (CmpEq int32) eq_class (constant i) fmap (map I.Var) . letTupExp "part_res" =<< eIf (eSubExp is_i) (pure $ resultBody $ result i) (resultBody <$> mkResult eq_class (i+1)) lambdaWithIncrement :: I.Body -> InternaliseM I.Body lambdaWithIncrement lam_body = runBodyBinder $ do eq_class <- head <$> bodyBind lam_body resultBody <$> mkResult eq_class 0