-- | An index function represents a mapping from an array index space
-- to a flat byte offset.
module Futhark.Representation.ExplicitMemory.IndexFunction
       ( -- IxFun(..)
         IxFun
       , index
       , iota
       , offsetIndex
       , strideIndex
       , permute
       , rotate
       , reshape
       , slice
       , base
       , rebase
       , repeat
       , shape
       , rank
       , linearWithOffset
       , rearrangeWithOffset
       , isLinear
       , isDirect
       , substituteInIxFun
       , getInfoMaxUnification
       , subsInIndexIxFun
       , ixFunsCompatibleRaw
       , ixFunHasIndex
       , offsetIndexDWIM
       )
       where

import Control.Arrow (first)
import Data.Maybe
import Data.Monoid ((<>))
import Data.List hiding (repeat)
import Control.Monad.Identity
import Control.Monad.Writer

import Prelude hiding (mod, repeat)

import qualified Data.List as L
import qualified Data.Map.Strict as M

import Futhark.Transform.Substitute
import Futhark.Transform.Rename
import Futhark.Representation.AST.Syntax
  (ShapeChange, DimChange(..), DimIndex(..), Slice, sliceDims, unitSlice, VName(..))
import Futhark.Representation.AST.Attributes.Names
import Futhark.Representation.AST.Attributes.Reshape
import Futhark.Representation.AST.Attributes.Rearrange
import Futhark.Representation.AST.Pretty ()
import Futhark.Util.IntegralExp
import Futhark.Util.Pretty
import Futhark.Util
import Futhark.Analysis.PrimExp.Convert

type Shape num = [num]
type Indices num = [num]
type Permutation = [Int]

-- | An index function, built up as a chain of transformations on top
-- of a 'Direct' root.
data IxFun num = Direct (Shape num)
                 -- ^ The root: an array of the given shape.
               | Permute (IxFun num) Permutation
                 -- ^ The dimensions of the inner index function,
                 -- rearranged by the permutation.
               | Rotate (IxFun num) (Indices num)
                 -- ^ Each dimension rotated by the given offset.
               | Index (IxFun num) (Slice num)
                 -- ^ A slice of the inner index function.
               | Reshape (IxFun num) (ShapeChange num)
                 -- ^ The inner index function viewed with a new shape.
               | Repeat (IxFun num) [Shape num] (Shape num)
                 -- ^ The inner index function with extra repeated
                 -- dimensions inserted (see 'shape').
               deriving (Eq,Show)

instance Pretty num => Pretty (IxFun num) where
  ppr (Direct dims) =
    text "Direct" <> parens (commasep $ map ppr dims)
  ppr (Permute fun perm) =
    ppr fun <> ppr perm
  ppr (Rotate fun offsets) =
    ppr fun <> brackets (commasep $ map ((text "+" <>) . ppr) offsets)
  ppr (Index fun is) =
    ppr fun <> brackets (commasep $ map ppr is)
  ppr (Reshape fun oldshape) =
    ppr fun <> text "->reshape" <>
    parens (commasep (map ppr oldshape))
  ppr (Repeat fun outer_shapes inner_shape) =
    ppr fun <> text "->repeat" <>
    parens (commasep (map ppr $ outer_shapes++ [inner_shape]))

instance Substitute num => Substitute (IxFun num) where
  substituteNames substs = fmap $ substituteNames substs

instance FreeIn num => FreeIn (IxFun num) where
  freeIn = foldMap freeIn

instance Functor IxFun where
  fmap f = runIdentity . traverse (return . f)

instance Foldable IxFun where
  foldMap f = execWriter . traverse (tell . f)

-- | Traverses every @num@ stored anywhere in the index function:
-- shapes, rotation offsets, slices, shape changes and repeat shapes.
-- Permutations are plain 'Int's and are left alone.
instance Traversable IxFun where
  traverse f (Direct dims) =
    Direct <$> traverse f dims
  traverse f (Permute ixfun perm) =
    Permute <$> traverse f ixfun <*> pure perm
  traverse f (Rotate ixfun offsets) =
    Rotate <$> traverse f ixfun <*> traverse f offsets
  traverse f (Index ixfun is) =
    Index <$> traverse f ixfun <*> traverse (traverse f) is
  traverse f (Reshape ixfun dims) =
    Reshape <$> traverse f ixfun <*> traverse (traverse f) dims
  traverse f (Repeat ixfun outer_shapes inner_shape) =
    Repeat <$>
    traverse f ixfun <*>
    traverse (traverse f) outer_shapes <*>
    traverse f inner_shape

instance Substitute num => Rename (IxFun num) where
  rename = substituteRename

-- | Compute the flat byte offset for the given indices into the index
-- function, scaling the element offset by @element_size@.
index :: (Pretty num, IntegralExp num) =>
         IxFun num -> Indices num -> num -> num
index (Direct dims) is element_size =
  sum (zipWith (*) is slicesizes) * element_size
  where slicesizes = drop 1 $ sliceSizes dims
index (Permute fun perm) is_new element_size =
  index fun is_old element_size
  where is_old = rearrangeShape (rearrangeInverse perm) is_new
index (Rotate fun offsets) is element_size =
  -- Rotation wraps around: each index is taken modulo its dimension.
  index fun (zipWith mod (zipWith (+) is offsets) dims) element_size
  where dims = shape fun
index (Index fun js) is element_size =
  index fun (adjust js is) element_size
  where -- Fixed dimensions consume no index; sliced dimensions map
        -- index i to start + i * stride.
        adjust (DimFix j:js') is' = j : adjust js' is'
        adjust (DimSlice j _ s:js') (i:is') = j + i * s : adjust js' is'
        adjust _ _ = []
index (Reshape fun newshape) is element_size =
  let new_indices = reshapeIndex (shape fun) (newDims newshape) is
  in index fun new_indices element_size
index (Repeat fun outer_shapes _) is element_size =
  -- Discard those indices that are just repeats.  It is intentional
  -- that we cut off those indices that correspond to the innermost
  -- repeated dimensions.
  index fun is' element_size
  where flags dims = replicate (length dims) True ++ [False]
        is' = map snd $ filter (not . fst) $ zip (concatMap flags outer_shapes) is

-- | The index function of an array of the given shape, with no
-- transformations applied.
iota :: Shape num -> IxFun num
iota = Direct

-- | Drop the first @i@ elements of the outermost dimension.  An offset
-- of zero returns the index function unchanged.  Calls 'error' if the
-- index function has rank zero.
offsetIndex :: (Eq num, IntegralExp num) =>
               IxFun num -> num -> IxFun num
offsetIndex ixfun i | i == 0 = ixfun
offsetIndex ixfun i =
  case shape ixfun of
    d:ds -> slice ixfun (DimSlice i (d-i) 1 : map (unitSlice 0) ds)
    [] -> error "offsetIndex: underlying index function has rank zero"

-- | Slice the outermost dimension from 0 with stride @s@, keeping its
-- size.  Calls 'error' if the index function has rank zero.
strideIndex :: (Eq num, IntegralExp num) =>
               IxFun num -> num -> IxFun num
strideIndex ixfun s =
  case shape ixfun of
    d:ds -> slice ixfun (DimSlice (fromInt32 0) d s : map (unitSlice (fromInt32 0)) ds)
    [] -> error "strideIndex: underlying index function has rank zero"

-- | Permute the dimensions.  Nested permutations are composed, a
-- permutation followed by its inverse cancels out, and identity
-- permutations are dropped.
permute :: IntegralExp num =>
           IxFun num -> Permutation -> IxFun num
permute (Permute ixfun oldperm) perm
  | rearrangeInverse oldperm == perm = ixfun
  | otherwise = permute ixfun (rearrangeCompose perm oldperm)
permute ixfun perm
  | perm == sort perm = ixfun
  | otherwise = Permute ixfun perm

-- | Rotate each dimension by the corresponding offset.  Consecutive
-- rotations are merged by adding their offsets.
rotate :: IntegralExp num =>
          IxFun num -> Indices num -> IxFun num
rotate (Rotate ixfun old_offsets) offsets =
  Rotate ixfun $ zipWith (+) old_offsets offsets
rotate ixfun offsets = Rotate ixfun offsets

-- | Insert repeated dimensions; see the 'Repeat' constructor.
repeat :: IxFun num -> [Shape num] -> Shape num -> IxFun num
repeat = Repeat

-- | Reshape the index function.  Where possible the reshape is pushed
-- underneath 'Permute' and 'Index', or dropped entirely when it does
-- not change the shape.
reshape :: (Eq num, IntegralExp num) =>
           IxFun num -> ShapeChange num -> IxFun num

reshape Direct{} newshape =
  Direct $ map newDim newshape

reshape (Reshape ixfun _) newshape =
  reshape ixfun newshape

-- A reshape of a permutation can be moved below the permutation when
-- the reshaped dimensions are surrounded only by coercions and map to
-- a contiguous, increasing run of underlying dimensions.
reshape (Permute ixfun perm) newshape
  | Just (head_coercions, reshapes, tail_coercions) <- splitCoercions newshape,
    num_coercions <- length (head_coercions ++ tail_coercions),
    (head_perms, mid_perms, end_perms) <-
      splitAt3 (length head_coercions) (length perm - num_coercions) perm,
    sequential mid_perms,
    first_reshaped <- minimum (rank ixfun : mid_perms),
    extra_dims <- length newshape - length (shape ixfun),
    perm' <-
      map (shiftDim first_reshaped extra_dims) head_perms ++
      take (length reshapes) [first_reshaped..] ++
      map (shiftDim first_reshaped extra_dims) end_perms,
    newshape' <- rearrangeShape (rearrangeInverse perm') newshape =
      Permute (reshape ixfun newshape') perm'
  where splitCoercions newshape' = do
          let (head_coercions, newshape'') = span isCoercion newshape'
          let (reshapes, tail_coercions) = break isCoercion newshape''
          guard (all isCoercion tail_coercions)
          return (head_coercions, reshapes, tail_coercions)

        isCoercion DimCoercion{} = True
        isCoercion _ = False

        -- Dimensions after the reshaped run move by the change in rank.
        shiftDim last_reshaped extra_dims x
          | x > last_reshaped = x + extra_dims
          | otherwise = x

        sequential [] = True
        sequential (x:xs) = and $ zipWith (==) xs [x+1, x+2..]

reshape (Index ixfun slicing) newshape
  | [newdim] <- newDims newshape,
    Just slicing' <- findSlice slicing (Just newdim) =
      Index ixfun slicing'
  | (is, rem_slicing) <- splitSlice slicing,
    (fixed_ds, sliced_ds) <- splitAt (length is) $ shape ixfun,
    and $ zipWith isSliceOf rem_slicing sliced_ds =
      -- Move the reshape beneath the slicing.
      let newshape' = map DimCoercion fixed_ds ++ newshape
      in Index (reshape ixfun newshape') $
         map DimFix is ++ map (unitSlice (fromInt32 0)) (newDims newshape)
  where isSliceOf (DimSlice _ d1 1) d2 = d1 == d2
        isSliceOf _ _ = False

        -- Replace the length of the single DimSlice in the slice with
        -- the new dimension; fails unless there is exactly one.
        findSlice (DimFix i:is) d =
          (DimFix i:) <$> findSlice is d
        findSlice (DimSlice j _ stride:is) d = do
          d' <- d
          (DimSlice j d' stride:) <$> findSlice is Nothing
        findSlice [] Just{} = Nothing
        findSlice [] Nothing = Just []

reshape ixfun newshape
  | shape ixfun == map newDim newshape =
      ixfun
  | rank ixfun == length newshape,
    Just _ <- shapeCoercion newshape =
      ixfun
  | otherwise =
      Reshape ixfun newshape

-- | Split off the leading run of 'DimFix' indices, returning their
-- values and the rest of the slice.
splitSlice :: Slice num -> ([num], Slice num)
splitSlice [] = ([], [])
splitSlice (DimFix i:is) = first (i:) $ splitSlice is
splitSlice is = ([], is)

-- | Apply a slice to the index function.  A slice equal to
-- @map (unitSlice 0) (shape ixfun)@ and the empty slice are dropped,
-- and slicing an 'Index' is fused into a single 'Index'.
slice :: (Eq num, IntegralExp num) =>
         IxFun num -> Slice num -> IxFun num
slice ixfun is
  -- Avoid identity slicing.
  | is == map (unitSlice 0) (shape ixfun) = ixfun
slice (Index ixfun mis) is =
  Index ixfun $ reslice mis is
  where reslice mis' [] = mis'
        reslice (DimFix j:mis') is' =
          DimFix j : reslice mis' is'
        reslice (DimSlice orig_k _ orig_s:mis') (DimSlice new_k n new_s:is') =
          DimSlice (orig_k + new_k * orig_s) n (orig_s*new_s) : reslice mis' is'
        reslice (DimSlice orig_k _ orig_s:mis') (DimFix i:is') =
          DimFix (orig_k+i*orig_s) : reslice mis' is'
        reslice _ _ = error "IndexFunction slice: invalid arguments"
slice ixfun [] = ixfun
slice ixfun is = Index ixfun is

-- | The number of dimensions of the index function.
rank :: IntegralExp num =>
        IxFun num -> Int
rank = length . shape

-- | The shape of the index space of the index function.
shape :: IntegralExp num =>
         IxFun num -> Shape num
shape (Direct dims) =
  dims
shape (Permute ixfun perm) =
  rearrangeShape perm $ shape ixfun
shape (Rotate ixfun _) =
  shape ixfun
shape (Index _ how) =
  sliceDims how
shape (Reshape _ dims) =
  map newDim dims
shape (Repeat ixfun outer_shapes inner_shape) =
  concat (zipWith repeated outer_shapes (shape ixfun)) ++ inner_shape
  where repeated outer_ds d = outer_ds ++ [d]

-- | The shape of the 'Direct' at the root of the index function.
base :: IxFun num -> Shape num
base (Direct dims) =
  dims
base (Permute ixfun _) =
  base ixfun
base (Rotate ixfun _) =
  base ixfun
base (Index ixfun _) =
  base ixfun
base (Reshape ixfun _) =
  base ixfun
base (Repeat ixfun _ _) =
  base ixfun

-- | Replace the 'Direct' root of the second index function with the
-- first one, re-applying every transformation on top of it.  If the
-- old root's shape differs from the new base's shape, a coercing
-- reshape is inserted.
rebase :: (Eq num, IntegralExp num) =>
          IxFun num
       -> IxFun num
       -> IxFun num
rebase new_base (Direct old_shape)
  | old_shape == shape new_base = new_base
  | otherwise = reshape new_base $ map DimCoercion old_shape
rebase new_base (Permute ixfun perm) =
  permute (rebase new_base ixfun) perm
rebase new_base (Rotate ixfun offsets) =
  rotate (rebase new_base ixfun) offsets
rebase new_base (Index ixfun is) =
  slice (rebase new_base ixfun) is
rebase new_base (Reshape ixfun new_shape) =
  reshape (rebase new_base ixfun) new_shape
rebase new_base (Repeat ixfun outer_shapes inner_shape) =
  Repeat (rebase new_base ixfun) outer_shapes inner_shape

-- This function does not cover all possible cases.  It's a "best
-- effort" kind of thing.
-- | Try to view the index function as a row-major layout of its base
-- array starting at some offset, and return that offset scaled by
-- @element_size@.  Only 'Direct', 'Reshape' and some forms of 'Index'
-- are recognised; every other constructor yields 'Nothing'.
linearWithOffset :: (Eq num, IntegralExp num) =>
                    IxFun num -> num -> Maybe num
linearWithOffset ixfun element_size =
  case ixfun of
    Direct _ ->
      Just 0
    Reshape inner _ ->
      linearWithOffset inner element_size
    Index inner slicing -> do
      let inner_shape = shape inner
      fixed <- leadingFixes slicing inner_shape
      inner_offset <- linearWithOffset inner element_size
      let strides = take (length slicing) $ drop 1 $ sliceSizes inner_shape
      return $ inner_offset + sum (zipWith (*) strides fixed) * element_size
    _ ->
      Nothing
  where
    -- Accept zero or more 'DimFix' indices, optionally followed by one
    -- stride-1 'DimSlice', provided every remaining dimension matches
    -- @unitSlice 0@.  Returns the fixed indices, plus the start of that
    -- single 'DimSlice' when present.
    leadingFixes (DimFix i:rest) (_:ds) =
      (i:) <$> leadingFixes rest ds
    leadingFixes (DimSlice off _ 1:rest) (_:ds)
      | rest == map (unitSlice 0) ds = Just [off]
    leadingFixes rest ds
      | rest == map (unitSlice 0) ds = Just []
    leadingFixes _ _ =
      Nothing

-- | Recognise a permutation (possibly below reshapes) of something
-- 'linearWithOffset' accepts.  Returns that offset together with each
-- permutation entry paired with the matching dimension of the
-- permuted shape.
rearrangeWithOffset :: (Eq num, IntegralExp num) =>
                       IxFun num -> num -> Maybe (num, [(Int,num)])
rearrangeWithOffset ixfun element_size =
  case ixfun of
    Reshape inner _ ->
      rearrangeWithOffset inner element_size
    Permute inner perm ->
      let dims = zip perm $ rearrangeShape perm $ shape inner
      in (\offset -> (offset, dims)) <$> linearWithOffset inner element_size
    _ ->
      Nothing

-- | Is the index function a linear layout at offset zero, according to
-- 'linearWithOffset' with an element size of 1?
isLinear :: (Eq num, IntegralExp num) => IxFun num -> Bool
isLinear ixfun = linearWithOffset ixfun 1 == Just 0

-- | Is the index function a bare 'Direct', with no transformations?
isDirect :: IxFun num -> Bool
isDirect ixfun =
  case ixfun of
    Direct{} -> True
    _ -> False

-- | Substituting a name with a PrimExp in an index function.
-- Every 'PrimExp' component of the index function is rewritten with
-- 'substituteInPrimExp', including the shapes carried by 'Repeat'.
-- Permutations are plain 'Int's and are left unchanged.
substituteInIxFun :: (Ord a) => M.Map a (PrimExp a) -> IxFun (PrimExp a)
                  -> IxFun (PrimExp a)
substituteInIxFun tab ixfun =
  case ixfun of
    Direct pes ->
      Direct $ map sub pes
    Permute ixfun' p ->
      Permute (substituteInIxFun tab ixfun') p
    Rotate ixfun' pes ->
      Rotate (substituteInIxFun tab ixfun') $ map sub pes
    Index ixfun' sl ->
      Index (substituteInIxFun tab ixfun') $ map (fmap sub) sl
    Reshape ixfun' newshape ->
      Reshape (substituteInIxFun tab ixfun') $ map (fmap sub) newshape
    Repeat ixfun' outer_shapes inner_shape ->
      -- The repeat shapes are PrimExps too and must be substituted like
      -- every other component; previously they were left untouched.
      Repeat (substituteInIxFun tab ixfun')
             (map (map sub) outer_shapes)
             (map sub inner_shape)
  where sub = substituteInPrimExp tab

-----------------------------------------------------------
--- Niels' functions for memory management:             ---
--- these are prime candidates to be removed/re-written ---
-----------------------------------------------------------

-- | Index functions whose components are 'PrimExp's over variable names.
type IxFn = IxFun (PrimExp VName)

-- | Match an 'Index' whose slice is zero or more 'DimFix'es followed by
-- exactly one 'DimSlice' whose length is a single 32-bit integer
-- variable.  Returns the index function below the slice, the leading
-- 'DimFix' part of the slice, and the name of that length variable.
-- Anything else gives 'Nothing'.
getInfoMaxUnification :: IxFn -> Maybe (IxFn, Slice (PrimExp VName), VName)
getInfoMaxUnification (Index ixfun_start slc) =
  case L.span isDimFix slc of
    (indices_start, [DimSlice _start_offset (LeafExp final_dim@VName{} (IntType Int32)) _stride]) ->
      Just (ixfun_start, indices_start, final_dim)
    _ -> Nothing
  where isDimFix DimFix{} = True
        isDimFix _ = False
getInfoMaxUnification _ = Nothing

-- Are two index functions *identical*?  (Silly approach, but the Eq
-- instance is used for something else.)
-- | Structural comparison: both index functions use the same
-- constructors, nested in the same order, with equal fields at every
-- level.
ixFunsCompatibleRaw :: Eq num => IxFun num -> IxFun num -> Bool
ixFunsCompatibleRaw (Direct sa) (Direct sb) =
  sa == sb
ixFunsCompatibleRaw (Permute a1 pa) (Permute b1 pb) =
  ixFunsCompatibleRaw a1 b1 && pa == pb
ixFunsCompatibleRaw (Rotate a1 ia) (Rotate b1 ib) =
  ixFunsCompatibleRaw a1 b1 && ia == ib
ixFunsCompatibleRaw (Index a1 sa) (Index b1 sb) =
  ixFunsCompatibleRaw a1 b1 && sa == sb
ixFunsCompatibleRaw (Reshape a1 sa) (Reshape b1 sb) =
  ixFunsCompatibleRaw a1 b1 && sa == sb
ixFunsCompatibleRaw (Repeat a1 ssa sa) (Repeat b1 ssb sb) =
  ixFunsCompatibleRaw a1 b1 && ssa == ssb && sa == sb
ixFunsCompatibleRaw _ _ =
  False

-- | Does an 'Index' occur anywhere in the chain of transformations?
ixFunHasIndex :: IxFun num -> Bool
ixFunHasIndex Index{} = True
ixFunHasIndex Direct{} = False
ixFunHasIndex (Permute inner _) = ixFunHasIndex inner
ixFunHasIndex (Rotate inner _) = ixFunHasIndex inner
ixFunHasIndex (Reshape inner _) = ixFunHasIndex inner
ixFunHasIndex (Repeat inner _ _) = ixFunHasIndex inner

-- | For an 'Index' index function, replace the variable @final_dim@ by
-- the 32-bit integer variable @final_dim_max_v@ in the index function
-- below the slice.  The slice itself is left unchanged.  Calls 'error'
-- on any other constructor.
subsInIndexIxFun :: IxFn -> VName -> VName -> IxFn
subsInIndexIxFun ixfun final_dim final_dim_max_v =
  case ixfun of
    Index ixfun_start slc ->
      Index (substituteInIxFun tab ixfun_start) slc
    _ ->
      error "In IxFun.subsInIndexIxFun: should-not-happen case reached!"
  where tab = M.singleton final_dim $ LeafExp final_dim_max_v $ IntType Int32

-- | Offset an index function by @offset@.  If it is an 'Index' whose
-- slice has a 'DimFix' at position @n_ignore_initial@ (0-based), the
-- offset is added to that fixed index; otherwise this falls back to
-- 'offsetIndex'.
offsetIndexDWIM :: Int -> IxFn -> PrimExp VName -> IxFn
offsetIndexDWIM n_ignore_initial ixfun offset
  | Index ixfun1 dimindices <- ixfun,
    (dim_first, DimFix i : dim_rest') <- L.splitAt n_ignore_initial dimindices =
      Index ixfun1 (dim_first ++ DimFix (i + offset) : dim_rest')
  | otherwise =
      offsetIndex ixfun offset