{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE OverloadedStrings #-} {-# OPTIONS_GHC -fno-warn-redundant-constraints #-} -- | This module contains a representation for the index function based on -- linear-memory accessor descriptors; see Zhu, Hoeflinger and David work. module Futhark.IR.Mem.IxFun ( IxFun (..), Shape, LMAD (..), LMADDim (..), Monotonicity (..), index, mkExistential, iota, iotaOffset, permute, reshape, coerce, slice, flatSlice, rebase, shape, rank, linearWithOffset, rearrangeWithOffset, isDirect, isLinear, substituteInIxFun, existentialize, closeEnough, equivalent, dynamicEqualsLMAD, ) where import Control.Category import Control.Monad.Identity import Control.Monad.State import Control.Monad.Writer import Data.Function (on, (&)) import Data.List (sort, sortBy, zip4, zipWith4) import Data.List.NonEmpty (NonEmpty (..)) import qualified Data.List.NonEmpty as NE import qualified Data.Map.Strict as M import Data.Maybe (isJust) import Futhark.Analysis.PrimExp import Futhark.Analysis.PrimExp.Convert (substituteInPrimExp) import Futhark.IR.Prop import Futhark.IR.Syntax ( DimIndex (..), FlatDimIndex (..), FlatSlice (..), Slice (..), dimFix, flatSliceDims, flatSliceStrides, unitSlice, ) import Futhark.IR.Syntax.Core (Ext (..)) import Futhark.Transform.Rename import Futhark.Transform.Substitute import Futhark.Util.IntegralExp import Futhark.Util.Pretty import Prelude hiding (id, mod, (.)) -- | The shape of an index function. type Shape num = [num] type Indices num = [num] type Permutation = [Int] -- | The physical element ordering alongside a dimension, i.e. the -- sign of the stride. data Monotonicity = -- | Increasing. Inc | -- | Decreasing. Dec | -- | Unknown. Unknown deriving (Show, Eq) -- | A single dimension in an 'LMAD'. data LMADDim num = LMADDim { ldStride :: num, ldShape :: num, ldPerm :: Int, ldMon :: Monotonicity } deriving (Show, Eq) -- | LMAD's representation consists of a general offset and for each dimension a -- stride, number of elements (or shape), permutation, and -- monotonicity. Note that the permutation is not strictly necessary in that the -- permutation can be performed directly on LMAD dimensions, but then it is -- difficult to extract the permutation back from an LMAD. -- -- LMAD algebra is closed under composition w.r.t. operators such as -- permute, index and slice. However, other operations, such as -- reshape, cannot always be represented inside the LMAD algebra. -- -- It follows that the general representation of an index function is a list of -- LMADS, in which each following LMAD in the list implicitly corresponds to an -- irregular reshaping operation. -- -- However, we expect that the common case is when the index function is one -- LMAD -- we call this the "nice" representation. -- -- Finally, the list of LMADs is kept in an @IxFun@ together with the shape of -- the original array, and a bit to indicate whether the index function is -- contiguous, i.e., if we instantiate all the points of the current index -- function, do we get a contiguous memory interval? -- -- By definition, the LMAD denotes the set of points (simplified): -- -- \{ o + \Sigma_{j=0}^{k} ((i_j+r_j) `mod` n_j)*s_j, -- \forall i_j such that 0<=i_j Pretty (LMAD num) where ppr (LMAD offset dims) = braces $ semisep [ "offset: " <> oneLine (ppr offset), "strides: " <> p ldStride, "shape: " <> p ldShape, "permutation: " <> p ldPerm, "monotonicity: " <> p ldMon ] where p f = oneLine $ brackets $ commasep $ map (ppr . f) dims instance Pretty num => Pretty (IxFun num) where ppr (IxFun lmads oshp cg) = braces $ semisep [ "base: " <> brackets (commasep $ map ppr oshp), "contiguous: " <> if cg then "true" else "false", "LMADs: " <> brackets (commastack $ NE.toList $ NE.map ppr lmads) ] instance Substitute num => Substitute (LMAD num) where substituteNames substs = fmap $ substituteNames substs instance Substitute num => Substitute (IxFun num) where substituteNames substs = fmap $ substituteNames substs instance Substitute num => Rename (LMAD num) where rename = substituteRename instance Substitute num => Rename (IxFun num) where rename = substituteRename instance FreeIn num => FreeIn (LMAD num) where freeIn' = foldMap freeIn' instance FreeIn num => FreeIn (IxFun num) where freeIn' = foldMap freeIn' instance Functor LMAD where fmap f = runIdentity . traverse (pure . f) instance Functor IxFun where fmap f = runIdentity . traverse (pure . f) instance Foldable LMAD where foldMap f = execWriter . traverse (tell . f) instance Foldable IxFun where foldMap f = execWriter . traverse (tell . f) instance Traversable LMAD where traverse f (LMAD offset dims) = LMAD <$> f offset <*> traverse f' dims where f' (LMADDim s n p m) = LMADDim <$> f s <*> f n <*> pure p <*> pure m -- It is important that the traversal order here is the same as in -- mkExistential. instance Traversable IxFun where traverse f (IxFun lmads oshp cg) = IxFun <$> traverse (traverse f) lmads <*> traverse f oshp <*> pure cg (++@) :: [a] -> NonEmpty a -> NonEmpty a es ++@ (ne :| nes) = case es of e : es' -> e :| es' ++ [ne] ++ nes [] -> ne :| nes (@++@) :: NonEmpty a -> NonEmpty a -> NonEmpty a (x :| xs) @++@ (y :| ys) = x :| xs ++ [y] ++ ys invertMonotonicity :: Monotonicity -> Monotonicity invertMonotonicity Inc = Dec invertMonotonicity Dec = Inc invertMonotonicity Unknown = Unknown lmadPermutation :: LMAD num -> Permutation lmadPermutation = map ldPerm . lmadDims setLMADPermutation :: Permutation -> LMAD num -> LMAD num setLMADPermutation perm lmad = lmad {lmadDims = zipWith (\dim p -> dim {ldPerm = p}) (lmadDims lmad) perm} setLMADShape :: Shape num -> LMAD num -> LMAD num setLMADShape shp lmad = lmad {lmadDims = zipWith (\dim s -> dim {ldShape = s}) (lmadDims lmad) shp} -- | Substitute a name with a PrimExp in an LMAD. substituteInLMAD :: Ord a => M.Map a (PrimExp a) -> LMAD (PrimExp a) -> LMAD (PrimExp a) substituteInLMAD tab (LMAD offset dims) = let offset' = substituteInPrimExp tab offset dims' = map ( \(LMADDim s n p m) -> LMADDim (substituteInPrimExp tab s) (substituteInPrimExp tab n) p m ) dims in LMAD offset' dims' -- | Substitute a name with a PrimExp in an index function. substituteInIxFun :: Ord a => M.Map a (TPrimExp t a) -> IxFun (TPrimExp t a) -> IxFun (TPrimExp t a) substituteInIxFun tab (IxFun lmads oshp cg) = IxFun (NE.map (fmap TPrimExp . substituteInLMAD tab' . fmap untyped) lmads) (map (TPrimExp . substituteInPrimExp tab' . untyped) oshp) cg where tab' = fmap untyped tab -- | Is this is a row-major array? isDirect :: (Eq num, IntegralExp num) => IxFun num -> Bool isDirect ixfun@(IxFun (LMAD offset dims :| []) oshp True) = let strides_expected = reverse $ scanl (*) 1 (reverse (tail oshp)) in hasContiguousPerm ixfun && length oshp == length dims && offset == 0 && all (\(LMADDim s n p _, m, d, se) -> s == se && n == d && p == m) (zip4 dims [0 .. length dims - 1] oshp strides_expected) isDirect _ = False -- | Does the index function have an ascending permutation? hasContiguousPerm :: IxFun num -> Bool hasContiguousPerm (IxFun (lmad :| []) _ _) = let perm = lmadPermutation lmad in perm == sort perm hasContiguousPerm _ = False -- | The index space of the index function. This is the same as the -- shape of arrays that the index function supports. shape :: (Eq num, IntegralExp num) => IxFun num -> Shape num shape (IxFun (lmad :| _) _ _) = permuteFwd (lmadPermutation lmad) $ lmadShapeBase lmad -- | Shape of an LMAD. lmadShape :: (Eq num, IntegralExp num) => LMAD num -> Shape num lmadShape lmad = permuteInv (lmadPermutation lmad) $ lmadShapeBase lmad -- | Shape of an LMAD, ignoring permutations. lmadShapeBase :: (Eq num, IntegralExp num) => LMAD num -> Shape num lmadShapeBase = map ldShape . lmadDims -- | Compute the flat memory index for a complete set @inds@ of array indices -- and a certain element size @elem_size@. index :: (IntegralExp num, Eq num) => IxFun num -> Indices num -> num index = indexFromLMADs . ixfunLMADs where indexFromLMADs :: (IntegralExp num, Eq num) => NonEmpty (LMAD num) -> Indices num -> num indexFromLMADs (lmad :| []) inds = indexLMAD lmad inds indexFromLMADs (lmad1 :| lmad2 : lmads) inds = let i_flat = indexLMAD lmad1 inds new_inds = unflattenIndex (permuteFwd (lmadPermutation lmad2) $ lmadShapeBase lmad2) i_flat in indexFromLMADs (lmad2 :| lmads) new_inds indexLMAD :: (IntegralExp num, Eq num) => LMAD num -> Indices num -> num indexLMAD lmad@(LMAD off dims) inds = let prod = sum $ zipWith flatOneDim (map ldStride dims) (permuteInv (lmadPermutation lmad) inds) in off + prod -- | iota with offset. iotaOffset :: IntegralExp num => num -> Shape num -> IxFun num iotaOffset o ns = IxFun (makeRotIota Inc o ns :| []) ns True -- | iota. iota :: IntegralExp num => Shape num -> IxFun num iota = iotaOffset 0 -- | Create a contiguous single-LMAD index function that is -- existential in everything, with the provided permutation, -- monotonicity, and contiguousness. mkExistential :: Int -> [(Int, Monotonicity)] -> Bool -> Int -> IxFun (Ext a) mkExistential basis_rank perm contig start = IxFun (NE.singleton lmad) basis contig where basis = take basis_rank $ map Ext [start + 1 + dims_rank * 2 ..] dims_rank = length perm lmad = LMAD (Ext start) $ zipWith onDim perm [0 ..] onDim (p, mon) i = LMADDim (Ext (start + 1 + i * 2)) (Ext (start + 2 + i * 2)) p mon -- | Permute dimensions. permute :: IntegralExp num => IxFun num -> Permutation -> IxFun num permute (IxFun (lmad :| lmads) oshp cg) perm_new = let perm_cur = lmadPermutation lmad perm = map (perm_cur !!) perm_new in IxFun (setLMADPermutation perm lmad :| lmads) oshp cg -- | Handle the case where a slice can stay within a single LMAD. sliceOneLMAD :: (Eq num, IntegralExp num) => IxFun num -> Slice num -> Maybe (IxFun num) sliceOneLMAD (IxFun (lmad@(LMAD _ ldims) :| lmads) oshp cg) (Slice is) = do let perm = lmadPermutation lmad is' = permuteInv perm is cg' = cg && slicePreservesContiguous lmad (Slice is') let lmad' = foldl sliceOne (LMAD (lmadOffset lmad) []) $ zip is' ldims -- need to remove the fixed dims from the permutation perm' = updatePerm perm $ map fst $ filter (isJust . dimFix . snd) $ zip [0 .. length is' - 1] is' pure $ IxFun (setLMADPermutation perm' lmad' :| lmads) oshp cg' where updatePerm ps inds = concatMap decrease ps where decrease p = let f n i | i == p = -1 | i > p = n | n /= -1 = n + 1 | otherwise = n d = foldl f 0 inds in [p - d | d /= -1] -- XXX: TODO: what happens to r on a negative-stride slice; is there -- such a case? sliceOne :: (Eq num, IntegralExp num) => LMAD num -> (DimIndex num, LMADDim num) -> LMAD num sliceOne (LMAD off dims) (DimFix i, LMADDim s _x _ _) = LMAD (off + flatOneDim s i) dims sliceOne (LMAD off dims) (DimSlice _ ne _, LMADDim 0 _ p _) = LMAD off (dims ++ [LMADDim 0 ne p Unknown]) sliceOne (LMAD off dims) (dmind, dim@(LMADDim _ n _ _)) | dmind == unitSlice 0 n = LMAD off (dims ++ [dim]) sliceOne (LMAD off dims) (dmind, LMADDim s n p m) | dmind == DimSlice (n - 1) n (-1) = let off' = off + flatOneDim s (n - 1) in LMAD off' (dims ++ [LMADDim (s * (-1)) n p (invertMonotonicity m)]) sliceOne (LMAD off dims) (DimSlice b ne 0, LMADDim s _ p _) = LMAD (off + flatOneDim s b) (dims ++ [LMADDim 0 ne p Unknown]) sliceOne (LMAD off dims) (DimSlice bs ns ss, LMADDim s _ p m) = let m' = case sgn ss of Just 1 -> m Just (-1) -> invertMonotonicity m _ -> Unknown in LMAD (off + s * bs) (dims ++ [LMADDim (ss * s) ns p m']) slicePreservesContiguous :: (Eq num, IntegralExp num) => LMAD num -> Slice num -> Bool slicePreservesContiguous (LMAD _ dims) (Slice slc) = -- remove from the slice the LMAD dimensions that have stride 0. -- If the LMAD was contiguous in mem, then these dims will not -- influence the contiguousness of the result. -- Also normalize the input slice, i.e., 0-stride and size-1 -- slices are rewritten as DimFixed. let (dims', slc') = unzip $ filter ((/= 0) . ldStride . fst) $ zip dims $ map normIndex slc -- Check that: -- 1. a clean split point exists between Fixed and Sliced dims -- 2. the outermost sliced dim has +/- 1 stride. -- 3. the rest of inner sliced dims are full. (_, success) = foldl ( \(found, res) (slcdim, LMADDim _ n _ _) -> case (slcdim, found) of (DimFix {}, True) -> (found, False) (DimFix {}, False) -> (found, res) (DimSlice _ _ ds, False) -> -- outermost sliced dim: +/-1 stride let res' = (ds == 1 || ds == -1) in (True, res && res') (DimSlice _ ne ds, True) -> -- inner sliced dim: needs to be full let res' = (n == ne) && (ds == 1 || ds == -1) in (found, res && res') ) (False, True) $ zip slc' dims' in success normIndex :: (Eq num, IntegralExp num) => DimIndex num -> DimIndex num normIndex (DimSlice b 1 _) = DimFix b normIndex (DimSlice b _ 0) = DimFix b normIndex d = d -- | Slice an index function. slice :: (Eq num, IntegralExp num) => IxFun num -> Slice num -> IxFun num slice ixfun@(IxFun (lmad@(LMAD _ _) :| lmads) oshp cg) dim_slices -- Avoid identity slicing. | unSlice dim_slices == map (unitSlice 0) (shape ixfun) = ixfun | Just ixfun' <- sliceOneLMAD ixfun dim_slices = ixfun' | otherwise = case sliceOneLMAD (iota (lmadShape lmad)) dim_slices of Just (IxFun (lmad' :| []) _ cg') -> IxFun (lmad' :| lmad : lmads) oshp (cg && cg') _ -> error "slice: reached impossible case" -- | Flat-slice an index function. flatSlice :: (Eq num, IntegralExp num) => IxFun num -> FlatSlice num -> IxFun num flatSlice ixfun@(IxFun (LMAD offset (dim : dims) :| lmads) oshp cg) (FlatSlice new_offset is) | hasContiguousPerm ixfun = let lmad = LMAD (offset + new_offset * ldStride dim) (map (helper $ ldStride dim) is <> dims) & setLMADPermutation [0 ..] in IxFun (lmad :| lmads) oshp cg where helper s0 (FlatDimIndex n s) = let new_mon = if s0 * s == 1 then Inc else Unknown in LMADDim (s0 * s) n 0 new_mon flatSlice (IxFun (lmad :| lmads) oshp cg) s@(FlatSlice new_offset _) = IxFun (LMAD (new_offset * base_stride) (new_dims <> tail_dims) :| lmad : lmads) oshp cg where tail_shapes = tail $ lmadShape lmad base_stride = product tail_shapes tail_strides = tail $ scanr (*) 1 tail_shapes tail_dims = zipWith4 LMADDim tail_strides tail_shapes [length new_shapes ..] (repeat Inc) new_shapes = flatSliceDims s new_strides = map (* base_stride) $ flatSliceStrides s new_dims = zipWith4 LMADDim new_strides new_shapes [0 ..] (repeat Inc) -- | Handle the case where a reshape operation can stay inside a single LMAD. -- -- There are four conditions that all must hold for the result of a reshape -- operation to remain in the one-LMAD domain: -- -- (1) the permutation of the underlying LMAD must leave unchanged -- the LMAD dimensions that were *not* reshape coercions. -- (2) the repetition of dimensions of the underlying LMAD must -- refer only to the coerced-dimensions of the reshape operation. -- (3) finally, the underlying memory is contiguous (and monotonous). -- -- If any of these conditions do not hold, then the reshape operation will -- conservatively add a new LMAD to the list, leading to a representation that -- provides less opportunities for further analysis. reshapeOneLMAD :: (Eq num, IntegralExp num) => IxFun num -> Shape num -> Maybe (IxFun num) reshapeOneLMAD ixfun@(IxFun (lmad@(LMAD off dims) :| lmads) oldbase cg) newshape = do let perm = lmadPermutation lmad dims_perm = permuteFwd perm dims mid_dims = take (length dims) dims_perm mon = ixfunMonotonicity ixfun guard $ -- checking conditions (2) all (\(LMADDim s _ _ _) -> s /= 0) mid_dims && -- checking condition (1) consecutive 0 (map ldPerm mid_dims) && -- checking condition (3) hasContiguousPerm ixfun && cg && (mon == Inc || mon == Dec) -- make new permutation let rsh_len = length newshape diff = length newshape - length dims iota_shape = [0 .. length newshape - 1] perm' = map ( \i -> let ind = i - diff in if (i >= 0) && (i < rsh_len) then i -- already checked mid_dims not affected else ldPerm (dims !! ind) + diff ) iota_shape -- split the dimensions (support_inds, repeat_inds) = foldl (\(sup, rpt) (shpdim, ip) -> ((ip, shpdim) : sup, rpt)) ([], []) $ reverse $ zip newshape perm' (sup_inds, support) = unzip $ sortBy (compare `on` fst) support_inds (rpt_inds, repeats) = unzip repeat_inds LMAD off' dims_sup = makeRotIota mon off support repeats' = map (\n -> LMADDim 0 n 0 Unknown) repeats dims' = map snd $ sortBy (compare `on` fst) $ zip sup_inds dims_sup ++ zip rpt_inds repeats' lmad' = LMAD off' dims' pure $ IxFun (setLMADPermutation perm' lmad' :| lmads) oldbase cg where consecutive _ [] = True consecutive i [p] = i == p consecutive i ps = and $ zipWith (==) ps [i, i + 1 ..] -- | Reshape an index function. reshape :: (Eq num, IntegralExp num) => IxFun num -> Shape num -> IxFun num reshape ixfun new_shape | Just ixfun' <- reshapeOneLMAD ixfun new_shape = ixfun' reshape (IxFun (lmad0 :| lmad0s) oshp cg) new_shape = case iota new_shape of IxFun (lmad :| []) _ _ -> IxFun (lmad :| lmad0 : lmad0s) oshp cg _ -> error "reshape: reached impossible case" -- | Coerce an index function to look like it has a new shape. -- Dynamically the shape must be the same. coerce :: (Eq num, IntegralExp num) => IxFun num -> Shape num -> IxFun num coerce (IxFun (lmad :| lmads) oshp cg) new_shape = IxFun (onLMAD lmad :| lmads) oshp cg where onLMAD (LMAD offset dims) = LMAD offset $ zipWith onDim dims new_shape onDim ld d = ld {ldShape = d} -- | The number of dimensions in the domain of the input function. rank :: IntegralExp num => IxFun num -> Int rank (IxFun (LMAD _ sss :| _) _ _) = length sss -- | Handle the case where a rebase operation can stay within m + n - 1 LMADs, -- where m is the number of LMADs in the index function, and n is the number of -- LMADs in the new base. If both index function have only on LMAD, this means -- that we stay within the single-LMAD domain. -- -- We can often stay in that domain if the original ixfun is essentially a -- slice, e.g. `x[i, (k1,m,s1), (k2,n,s2)] = orig`. -- -- XXX: TODO: handle repetitions in both lmads. -- -- How to handle repeated dimensions in the original? -- -- (a) Shave them off of the last lmad of original -- (b) Compose the result from (a) with the first -- lmad of the new base -- (c) apply a repeat operation on the result of (b). -- -- However, I strongly suspect that for in-place update what we need is actually -- the INVERSE of the rebase function, i.e., given an index function new-base -- and another one orig, compute the index function ixfun0 such that: -- -- new-base == rebase ixfun0 ixfun, or equivalently: -- new-base == ixfun o ixfun0 -- -- because then I can go bottom up and compose with ixfun0 all the index -- functions corresponding to the memory block associated with ixfun. rebaseNice :: (Eq num, IntegralExp num) => IxFun num -> IxFun num -> Maybe (IxFun num) rebaseNice new_base@(IxFun (lmad_base :| lmads_base) _ cg_base) ixfun@(IxFun lmads shp cg) = do let (lmad :| lmads') = NE.reverse lmads dims = lmadDims lmad perm = lmadPermutation lmad perm_base = lmadPermutation lmad_base guard $ -- Core rebase condition. base ixfun == shape new_base -- Conservative safety conditions: ixfun is contiguous and has known -- monotonicity for all dimensions. && cg && all ((/= Unknown) . ldMon) dims -- XXX: We should be able to handle some basic cases where both index -- functions have non-trivial permutations. && (hasContiguousPerm ixfun || hasContiguousPerm new_base) -- We need the permutations to be of the same size if we want to compose -- them. They don't have to be of the same size if the ixfun has a trivial -- permutation. Supporting this latter case allows us to rebase when ixfun -- has been created by slicing with fixed dimensions. && (length perm == length perm_base || hasContiguousPerm ixfun) -- To not have to worry about ixfun having non-1 strides, we also check that -- it is a row-major array (modulo permutation, which is handled -- separately). Accept a non-full innermost dimension. XXX: Maybe this can -- be less conservative? && and ( zipWith3 (\sn ld inner -> sn == ldShape ld || (inner && ldStride ld == 1)) shp dims (replicate (length dims - 1) False ++ [True]) ) -- Compose permutations, reverse strides and adjust offset if necessary. let perm_base' = if hasContiguousPerm ixfun then perm_base else map (perm !!) perm_base lmad_base' = setLMADPermutation perm_base' lmad_base dims_base = lmadDims lmad_base' n_fewer_dims = length dims_base - length dims (dims_base', offs_contrib) = unzip $ zipWith ( \(LMADDim s1 n1 p1 _) (LMADDim _ _ _ m2) -> let (s', off') | m2 == Inc = (s1, 0) | otherwise = (s1 * (-1), s1 * (n1 - 1)) in (LMADDim s' n1 (p1 - n_fewer_dims) Inc, off') ) -- If @dims@ is morally a slice, it might have fewer dimensions than -- @dims_base@. Drop extraneous outer dimensions. (drop n_fewer_dims dims_base) dims off_base = lmadOffset lmad_base' + sum offs_contrib lmad_base'' | lmadOffset lmad == 0 = LMAD off_base dims_base' | otherwise = -- If the innermost dimension of the ixfun was not full (but still -- had a stride of 1), add its offset relative to the new base. setLMADShape (lmadShape lmad) ( LMAD (off_base + ldStride (last dims_base) * lmadOffset lmad) dims_base' ) new_base' = IxFun (lmad_base'' :| lmads_base) shp cg_base IxFun lmads_base' _ _ = new_base' lmads'' = lmads' ++@ lmads_base' pure $ IxFun lmads'' shp (cg && cg_base) -- | Rebase an index function on top of a new base. rebase :: (Eq num, IntegralExp num) => IxFun num -> IxFun num -> IxFun num rebase new_base@(IxFun lmads_base shp_base cg_base) ixfun@(IxFun lmads shp cg) | Just ixfun' <- rebaseNice new_base ixfun = ixfun' -- In the general case just concatenate LMADs since this refers to index -- function composition, which is always safe. | otherwise = let (lmads_base', shp_base') = if base ixfun == shape new_base then (lmads_base, shp_base) else let IxFun lmads' shp_base'' _ = reshape new_base shp in (lmads', shp_base'') in IxFun (lmads @++@ lmads_base') shp_base' (cg && cg_base) -- | If the memory support of the index function is contiguous and row-major -- (i.e., no transpositions, repetitions, rotates, etc.), then this should -- return the offset from which the memory-support of this index function -- starts. linearWithOffset :: (Eq num, IntegralExp num) => IxFun num -> num -> Maybe num linearWithOffset ixfun@(IxFun (lmad :| []) _ cg) elem_size | hasContiguousPerm ixfun && cg && ixfunMonotonicity ixfun == Inc = Just $ lmadOffset lmad * elem_size linearWithOffset _ _ = Nothing -- | Similar restrictions to @linearWithOffset@ except for transpositions, which -- are returned together with the offset. rearrangeWithOffset :: (Eq num, IntegralExp num) => IxFun num -> num -> Maybe (num, [(Int, num)]) rearrangeWithOffset (IxFun (lmad :| []) oshp cg) elem_size = do -- Note that @cg@ describes whether the index function is -- contiguous, *ignoring permutations*. This function requires that -- functionality. let perm = lmadPermutation lmad perm_contig = [0 .. length perm - 1] offset <- linearWithOffset (IxFun (setLMADPermutation perm_contig lmad :| []) oshp cg) elem_size pure (offset, zip perm (permuteFwd perm (lmadShapeBase lmad))) rearrangeWithOffset _ _ = Nothing -- | Is this a row-major array starting at offset zero? isLinear :: (Eq num, IntegralExp num) => IxFun num -> Bool isLinear = (== Just 0) . flip linearWithOffset 1 permuteFwd :: Permutation -> [a] -> [a] permuteFwd ps elems = map (elems !!) ps permuteInv :: Permutation -> [a] -> [a] permuteInv ps elems = map snd $ sortBy (compare `on` fst) $ zip ps elems flatOneDim :: (Eq num, IntegralExp num) => num -> num -> num flatOneDim s i | s == 0 = 0 | otherwise = i * s -- | Generalised iota with user-specified offset and rotates. makeRotIota :: IntegralExp num => Monotonicity -> -- | Offset num -> -- | Shape [num] -> LMAD num makeRotIota mon off ns | mon == Inc || mon == Dec = let rk = length ns ss0 = reverse $ take rk $ scanl (*) 1 $ reverse ns ss = if mon == Inc then ss0 else map (* (-1)) ss0 ps = map fromIntegral [0 .. rk - 1] fi = replicate rk mon in LMAD off $ zipWith4 LMADDim ss ns ps fi | otherwise = error "makeRotIota: requires Inc or Dec" -- | Check monotonicity of an index function. ixfunMonotonicity :: (Eq num, IntegralExp num) => IxFun num -> Monotonicity ixfunMonotonicity (IxFun (lmad :| lmads) _ _) = let mon0 = lmadMonotonicityRots lmad in if all ((== mon0) . lmadMonotonicityRots) lmads then mon0 else Unknown where lmadMonotonicityRots :: (Eq num, IntegralExp num) => LMAD num -> Monotonicity lmadMonotonicityRots (LMAD _ dims) | all (isMonDim Inc) dims = Inc | all (isMonDim Dec) dims = Dec | otherwise = Unknown isMonDim :: (Eq num, IntegralExp num) => Monotonicity -> LMADDim num -> Bool isMonDim mon (LMADDim s _ _ ldmon) = s == 0 || mon == ldmon -- | Turn all the leaves of the index function into 'Ext's. existentialize :: IxFun (TPrimExp Int64 a) -> IxFun (TPrimExp Int64 (Ext b)) existentialize ixfun = evalState (traverse (const mkExt) ixfun) 0 where mkExt = do i <- get put $ i + 1 pure $ TPrimExp $ LeafExp (Ext i) int64 -- | When comparing index functions as part of the type check in KernelsMem, -- we may run into problems caused by the simplifier. As index functions can be -- generalized over if-then-else expressions, the simplifier might hoist some of -- the code from inside the if-then-else (computing the offset of an array, for -- instance), but now the type checker cannot verify that the generalized index -- function is valid, because some of the existentials are computed somewhere -- else. To Work around this, we've had to relax the KernelsMem type-checker -- a bit, specifically, we've introduced this function to verify whether two -- index functions are "close enough" that we can assume that they match. We use -- this instead of `ixfun1 == ixfun2` and hope that it's good enough. closeEnough :: IxFun num -> IxFun num -> Bool closeEnough ixf1 ixf2 = (length (base ixf1) == length (base ixf2)) && (NE.length (ixfunLMADs ixf1) == NE.length (ixfunLMADs ixf2)) && all closeEnoughLMADs (NE.zip (ixfunLMADs ixf1) (ixfunLMADs ixf2)) -- This treats ixf1 as the "declared type" that we are matching against. && (contiguous ixf1 <= contiguous ixf2) where closeEnoughLMADs :: (LMAD num, LMAD num) -> Bool closeEnoughLMADs (lmad1, lmad2) = length (lmadDims lmad1) == length (lmadDims lmad2) && map ldPerm (lmadDims lmad1) == map ldPerm (lmadDims lmad2) -- | Returns true if two 'IxFun's are equivalent. -- -- Equivalence in this case is defined as having the same number of LMADs, with -- each pair of LMADs matching in permutation, offsets, strides and rotations. equivalent :: Eq num => IxFun num -> IxFun num -> Bool equivalent ixf1 ixf2 = NE.length (ixfunLMADs ixf1) == NE.length (ixfunLMADs ixf2) && all equivalentLMADs (NE.zip (ixfunLMADs ixf1) (ixfunLMADs ixf2)) where equivalentLMADs (lmad1, lmad2) = length (lmadDims lmad1) == length (lmadDims lmad2) && map ldPerm (lmadDims lmad1) == map ldPerm (lmadDims lmad2) && lmadOffset lmad1 == lmadOffset lmad2 && map ldStride (lmadDims lmad1) == map ldStride (lmadDims lmad2) -- | Dynamically determine if two 'LMADDim' are equal. -- -- True if the dynamic values of their constituents are equal. dynamicEqualsLMADDim :: Eq num => LMADDim (TPrimExp t num) -> LMADDim (TPrimExp t num) -> TPrimExp Bool num dynamicEqualsLMADDim dim1 dim2 = ldStride dim1 .==. ldStride dim2 .&&. ldShape dim1 .==. ldShape dim2 .&&. fromBool (ldPerm dim1 == ldPerm dim2) .&&. fromBool (ldMon dim1 == ldMon dim2) -- | Dynamically determine if two 'LMAD' are equal. -- -- True if offset and constituent 'LMADDim' are equal. dynamicEqualsLMAD :: Eq num => LMAD (TPrimExp t num) -> LMAD (TPrimExp t num) -> TPrimExp Bool num dynamicEqualsLMAD lmad1 lmad2 = lmadOffset lmad1 .==. lmadOffset lmad2 .&&. foldr ((.&&.) . uncurry dynamicEqualsLMADDim) true (zip (lmadDims lmad1) (lmadDims lmad2))