{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}
{-# OPTIONS_GHC -fno-warn-redundant-constraints #-}

-- | This module contains a representation for the index function based on
-- linear-memory accessor descriptors (LMADs); see the work of Zhu, Hoeflinger
-- and David.
module Futhark.IR.Mem.IxFun
  ( IxFun (..),
    LMAD (..),
    LMADDim (..),
    Monotonicity (..),
    index,
    iota,
    iotaOffset,
    permute,
    rotate,
    reshape,
    slice,
    flatSlice,
    rebase,
    shape,
    rank,
    linearWithOffset,
    rearrangeWithOffset,
    isDirect,
    isLinear,
    substituteInIxFun,
    leastGeneralGeneralization,
    existentialize,
    closeEnough,
    equivalent,
  )
where

import Control.Category
import Control.Monad.Identity
import Control.Monad.State
import Control.Monad.Writer
import Data.Function (on, (&))
import Data.List (sort, sortBy, zip4, zip5, zipWith5)
import Data.List.NonEmpty (NonEmpty (..))
import qualified Data.List.NonEmpty as NE
import qualified Data.Map.Strict as M
import Data.Maybe (isJust)
import Futhark.Analysis.PrimExp
  ( IntExp,
    PrimExp (..),
    TPrimExp (..),
    primExpType,
  )
import Futhark.Analysis.PrimExp.Convert (substituteInPrimExp)
import qualified Futhark.Analysis.PrimExp.Generalize as PEG
import Futhark.IR.Prop
import Futhark.IR.Syntax
  ( DimChange (..),
    DimIndex (..),
    FlatDimIndex (..),
    FlatSlice (..),
    ShapeChange,
    Slice (..),
    dimFix,
    flatSliceDims,
    flatSliceStrides,
    unitSlice,
  )
import Futhark.IR.Syntax.Core (Ext (..))
import Futhark.Transform.Rename
import Futhark.Transform.Substitute
import Futhark.Util.IntegralExp
import Futhark.Util.Pretty
import Prelude hiding (id, mod, (.))

-- | A shape: one size per dimension.
type Shape num = [num]

-- | A full set of indices: one index per dimension.
type Indices num = [num]

-- | A permutation of dimension positions.
type Permutation = [Int]

-- | Monotonicity of an LMAD dimension: monotonously increasing, decreasing or
-- unknown.
data Monotonicity
  = -- | Monotonously increasing.
    Inc
  | -- | Monotonously decreasing.
    Dec
  | -- | Unknown monotonicity.
    Unknown
  deriving (Show, Eq)

-- | One dimension of an 'LMAD'.
data LMADDim num = LMADDim
  { -- | The stride of this dimension.
    ldStride :: num,
    -- | The rotate factor of this dimension.
    ldRotate :: num,
    -- | The number of elements (shape) of this dimension.
    ldShape :: num,
    -- | The permutation entry of this dimension.
    ldPerm :: Int,
    -- | The monotonicity of this dimension.
    ldMon :: Monotonicity
  }
  deriving (Show, Eq)

-- | LMAD's representation consists of a general offset and for each dimension a
-- stride, rotate factor, number of
-- elements (or shape), permutation, and
-- monotonicity. Note that the permutation is not strictly necessary in that the
-- permutation can be performed directly on LMAD dimensions, but then it is
-- difficult to extract the permutation back from an LMAD.
--
-- LMAD algebra is closed under composition w.r.t. operators such as
-- permute, index and slice.  However, other operations, such as
-- reshape, cannot always be represented inside the LMAD algebra.
--
-- It follows that the general representation of an index function is a list of
-- LMADS, in which each following LMAD in the list implicitly corresponds to an
-- irregular reshaping operation.
--
-- However, we expect that the common case is when the index function is one
-- LMAD -- we call this the "nice" representation.
--
-- Finally, the list of LMADs is kept in an @IxFun@ together with the shape of
-- the original array, and a bit to indicate whether the index function is
-- contiguous, i.e., if we instantiate all the points of the current index
-- function, do we get a contiguous memory interval?
--
-- By definition, the LMAD denotes the set of points (simplified):
--
--   \{ o + \Sigma_{j=0}^{k} ((i_j+r_j) `mod` n_j)*s_j,
--      \forall i_j such that 0<=i_j<n_j, j=0..k \}
data LMAD num = LMAD
  { lmadOffset :: num,
    lmadDims :: [LMADDim num]
  }
  deriving (Show, Eq)

-- | An index function is a mapping from a multidimensional array index space
-- (the domain) to a one-dimensional memory index space.  It is represented as
-- a non-empty sequence of 'LMAD's, together with the shape of the original
-- array and a contiguity flag.
data IxFun num = IxFun
  { ixfunLMADs :: NonEmpty (LMAD num),
    -- | The shape of the original array this index function is based on.
    base :: Shape num,
    -- | Ignoring permutations, is the index function contiguous?
    ixfunContig :: Bool
  }
  deriving (Show, Eq)

instance Pretty Monotonicity where
  ppr = text . show

instance Pretty num => Pretty (LMAD num) where
  ppr (LMAD offset dims) =
    braces $
      semisep
        [ "offset: " <> oneLine (ppr offset),
          "strides: " <> p ldStride,
          "rotates: " <> p ldRotate,
          "shape: " <> p ldShape,
          "permutation: " <> p ldPerm,
          "monotonicity: " <> p ldMon
        ]
    where
      p f = oneLine $ brackets $ commasep $ map (ppr . f) dims

instance Pretty num => Pretty (IxFun num) where
  ppr (IxFun lmads oshp cg) =
    braces $
      semisep
        [ "base: " <> brackets (commasep $ map ppr oshp),
          "contiguous: " <> if cg then "true" else "false",
          "LMADs: " <> brackets (commastack $ NE.toList $ NE.map ppr lmads)
        ]

instance Substitute num => Substitute (LMAD num) where
  substituteNames substs = fmap $ substituteNames substs

instance Substitute num => Substitute (IxFun num) where
  substituteNames substs = fmap $ substituteNames substs

instance Substitute num => Rename (LMAD num) where
  rename = substituteRename

instance Substitute num => Rename (IxFun num) where
  rename = substituteRename

instance FreeIn num => FreeIn (LMAD num) where
  freeIn' = foldMap freeIn'

instance FreeIn num => FreeIn (IxFun num) where
  freeIn' = foldMap freeIn'

instance Functor LMAD where
  fmap f = runIdentity . traverse (return . f)

instance Functor IxFun where
  fmap f = runIdentity . traverse (return . f)

instance Foldable LMAD where
  foldMap f = execWriter . traverse (tell . f)

instance Foldable IxFun where
  foldMap f = execWriter . traverse (tell . f)

-- | Traverses the offset and, per dimension, the stride, rotate and shape.
-- Permutation entries and monotonicity are left untouched.
instance Traversable LMAD where
  traverse f (LMAD offset dims) =
    LMAD <$> f offset <*> traverse f' dims
    where
      f' (LMADDim s r n p m) =
        LMADDim <$> f s <*> f r <*> f n <*> pure p <*> pure m

-- | Traverses every LMAD and then the base shape; the contiguity flag is
-- left untouched.
instance Traversable IxFun where
  traverse f (IxFun lmads oshp cg) =
    IxFun <$> traverse (traverse f) lmads <*> traverse f oshp <*> pure cg

-- | Prepend a list to a non-empty list.
(++@) :: [a] -> NonEmpty a -> NonEmpty a
es ++@ (ne :| nes) = case es of
  e : es' -> e :| es' ++ [ne] ++ nes
  [] -> ne :| nes

-- | Append two non-empty lists.
(@++@) :: NonEmpty a -> NonEmpty a -> NonEmpty a
(x :| xs) @++@ (y :| ys) = x :| xs ++ [y] ++ ys

-- | Swap 'Inc' and 'Dec'; 'Unknown' stays 'Unknown'.
invertMonotonicity :: Monotonicity -> Monotonicity
invertMonotonicity Inc = Dec
invertMonotonicity Dec = Inc
invertMonotonicity Unknown = Unknown

-- | The permutation of an LMAD: the 'ldPerm' of each dimension, in order.
lmadPermutation :: LMAD num -> Permutation
lmadPermutation = map ldPerm . lmadDims

-- | Overwrite the permutation entries of the LMAD's dimensions.  Because this
-- uses 'zipWith', dimensions beyond the length of @perm@ are dropped.
setLMADPermutation :: Permutation -> LMAD num -> LMAD num
setLMADPermutation perm lmad =
  lmad {lmadDims = zipWith (\dim p -> dim {ldPerm = p}) (lmadDims lmad) perm}

-- | Overwrite the shape entries of the LMAD's dimensions, position by
-- position.  Dimensions beyond the length of @shp@ are dropped.
setLMADShape :: Shape num -> LMAD num -> LMAD num
setLMADShape shp lmad =
  lmad {lmadDims = zipWith (\dim s -> dim {ldShape = s}) (lmadDims lmad) shp}

-- | Substitute a name with a PrimExp in an LMAD.
substituteInLMAD ::
  Ord a =>
  M.Map a (PrimExp a) ->
  LMAD (PrimExp a) ->
  LMAD (PrimExp a)
substituteInLMAD tab (LMAD offset dims) =
  let offset' = substituteInPrimExp tab offset
      dims' =
        map
          ( \(LMADDim s r n p m) ->
              LMADDim
                (substituteInPrimExp tab s)
                (substituteInPrimExp tab r)
                (substituteInPrimExp tab n)
                p
                m
          )
          dims
   in LMAD offset' dims'

-- | Substitute a name with a PrimExp in an index function.
substituteInIxFun ::
  Ord a =>
  M.Map a (TPrimExp t a) ->
  IxFun (TPrimExp t a) ->
  IxFun (TPrimExp t a)
substituteInIxFun tab (IxFun lmads oshp cg) =
  IxFun
    (NE.map (fmap TPrimExp . substituteInLMAD tab' . fmap untyped) lmads)
    (map (TPrimExp . substituteInPrimExp tab' . untyped) oshp)
    cg
  where
    tab' = fmap untyped tab

-- | Is this a row-major array?
--
-- This holds when the index function consists of a single contiguous LMAD
-- with an ascending permutation and offset zero.  In addition, every dimension
-- must be unrotated, have the size of the corresponding base dimension, and
-- carry its own position as permutation entry.  Its stride must also equal the
-- row-major stride computed from the base shape.
isDirect :: (Eq num, IntegralExp num) => IxFun num -> Bool
isDirect ixfun@(IxFun (LMAD offset dims :| []) oshp True) =
  -- Row-major strides of the base shape: product of all inner sizes.
  -- 'drop 1' instead of 'tail' keeps this total for an empty base shape.
  let strides_expected = reverse $ scanl (*) 1 (reverse (drop 1 oshp))
   in hasContiguousPerm ixfun
        && length oshp == length dims
        && offset == 0
        && all
          ( \(LMADDim s r n p _, m, d, se) ->
              s == se && r == 0 && n == d && p == m
          )
          (zip4 dims [0 .. length dims - 1] oshp strides_expected)
isDirect _ = False

-- | Does the index function have an ascending permutation?  Only
-- single-LMAD index functions can qualify.
hasContiguousPerm :: IxFun num -> Bool
hasContiguousPerm (IxFun (lmad :| []) _ _) =
  let perm = lmadPermutation lmad
   in perm == sort perm
hasContiguousPerm _ = False

-- | Shape of an index function: the shape of its first LMAD.
shape :: (Eq num, IntegralExp num) => IxFun num -> Shape num
shape (IxFun (lmad :| _) _ _) = lmadShape lmad

-- | Shape of an LMAD.
lmadShape :: (Eq num, IntegralExp num) => LMAD num -> Shape num
lmadShape lmad = permuteInv perm (lmadShapeBase lmad)
  where
    perm = lmadPermutation lmad

-- | Shape of an LMAD, ignoring permutations: the sizes exactly as they are
-- stored in 'lmadDims'.
lmadShapeBase :: (Eq num, IntegralExp num) => LMAD num -> Shape num
lmadShapeBase lmad = [ldShape d | d <- lmadDims lmad]

-- | Compute the flat memory index for a complete set @inds@ of array indices.
--
-- The indices are applied to the first LMAD.  When the index function has
-- further LMADs, the resulting flat index is unflattened according to the
-- (permuted) base shape of the next LMAD and fed into it, and so on until the
-- last LMAD has been applied.  The result is not scaled by any element size.
index :: (IntegralExp num, Eq num) => IxFun num -> Indices num -> num
index (IxFun (lmad0 :| rest) _ _) inds =
  foldl step (indexLMAD lmad0 inds) rest
  where
    -- Turn the flat index produced so far into indices for the next LMAD,
    -- then apply that LMAD.
    step i_flat lmad =
      indexLMAD lmad $
        unflattenIndex
          (permuteFwd (lmadPermutation lmad) (lmadShapeBase lmad))
          i_flat

    -- Apply a single LMAD to a full set of indices.
    indexLMAD l@(LMAD off dims) is =
      off
        + sum
          ( zipWith
              flatOneDim
              [(s, r, n) | LMADDim s r n _ _ <- dims]
              (permuteInv (lmadPermutation l) is)
          )

-- | iota with offset: a single unrotated, increasing LMAD over the given
-- shape, starting at offset @o@.  The index function is marked contiguous.
iotaOffset :: IntegralExp num => num -> Shape num -> IxFun num
iotaOffset o ns = IxFun (lmad :| []) ns True
  where
    lmad = makeRotIota Inc o [(0, n) | n <- ns]

-- | iota: 'iotaOffset' with offset zero.
iota :: IntegralExp num => Shape num -> IxFun num
iota shp = iotaOffset 0 shp

-- | Permute dimensions by composing @perm_new@ with the first LMAD's current
-- permutation.
permute :: IntegralExp num => IxFun num -> Permutation -> IxFun num
permute (IxFun (lmad :| lmads) oshp cg) perm_new =
  IxFun (setLMADPermutation composed lmad :| lmads) oshp cg
  where
    current = lmadPermutation lmad
    composed = [current !! i | i <- perm_new]

-- | Rotate an index function.
--
-- The offsets in @offs@ are given in the index function's own dimension
-- order.  They are reordered through the first LMAD's permutation and added to
-- the rotate factors of its dimensions.  Dimensions with stride 0 instead get
-- rotate 0 and 'Unknown' monotonicity.  Only the first LMAD is changed.
rotate ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  Indices num ->
  IxFun num
rotate (IxFun (lmad@(LMAD off dims) :| lmads) oshp cg) offs =
  let dims' =
        zipWith
          ( \(LMADDim s r n p f) o ->
              if s == 0
                then LMADDim 0 0 n p Unknown
                else LMADDim s (r + o) n p f
          )
          dims
          (permuteInv (lmadPermutation lmad) offs)
   in IxFun (LMAD off dims' :| lmads) oshp cg

-- | Handle the case where a slice can stay within a single LMAD.
--
-- Returns 'Nothing' when a slice interacts with a rotation in a way that is
-- not known to be harmless (see @harmlessRotation@ below).  The contiguity
-- flag of the result is kept only if 'slicePreservesContiguous' holds.
sliceOneLMAD ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  Slice num ->
  Maybe (IxFun num)
sliceOneLMAD (IxFun (lmad@(LMAD _ ldims) :| lmads) oshp cg) (Slice is) = do
  let perm = lmadPermutation lmad
      -- The slice reordered into the LMAD's storage order.
      is' = permuteInv perm is
      cg' = cg && slicePreservesContiguous lmad (Slice is')
  guard $ harmlessRotation lmad (Slice is')
  let lmad' = foldl sliceOne (LMAD (lmadOffset lmad) []) $ zip is' ldims
      -- need to remove the fixed dims from the permutation
      perm' =
        updatePerm perm $
          map fst $
            filter (isJust . dimFix . snd) $
              zip [0 .. length is' - 1] is'

  return $ IxFun (setLMADPermutation perm' lmad' :| lmads) oshp cg'
  where
    -- Drop the permutation entries listed in @inds@ (the positions of the fixed
    -- dimensions, in ascending order).  Every remaining entry @p@ is lowered by
    -- the number of dropped entries smaller than @p@, so the result stays dense.
    updatePerm ps inds = concatMap decrease ps
      where
        decrease p =
          let f n i
                | i == p = -1
                | i > p = n
                | n /= -1 = n + 1
                | otherwise = n
              d = foldl f 0 inds
           in [p - d | d /= -1]

    -- A rotation is harmless for a dimension that is fixed by the slice, has
    -- stride 0, is unrotated, or is sliced in full (forwards or reversed).
    harmlessRotation' ::
      (Eq num, IntegralExp num) =>
      LMADDim num ->
      DimIndex num ->
      Bool
    harmlessRotation' _ (DimFix _) = True
    harmlessRotation' (LMADDim 0 _ _ _ _) _ = True
    harmlessRotation' (LMADDim _ 0 _ _ _) _ = True
    harmlessRotation' (LMADDim _ _ n _ _) dslc
      | dslc == DimSlice (n - 1) n (-1)
          || dslc == unitSlice 0 n =
        True
    harmlessRotation' _ _ = False

    harmlessRotation ::
      (Eq num, IntegralExp num) =>
      LMAD num ->
      Slice num ->
      Bool
    harmlessRotation (LMAD _ dims) (Slice iss) =
      and $ zipWith harmlessRotation' dims iss

    -- Apply one slice component to one LMAD dimension, accumulating into the
    -- result LMAD.  Fixed indices only contribute to the offset.  The other
    -- cases cover stride-0 dimensions, full slices, full reversed slices,
    -- stride-0 slices, and general strided slices of unrotated dimensions.
    --
    -- XXX: TODO: what happens to r on a negative-stride slice; is there
    -- such a case?
    sliceOne ::
      (Eq num, IntegralExp num) =>
      LMAD num ->
      (DimIndex num, LMADDim num) ->
      LMAD num
    sliceOne (LMAD off dims) (DimFix i, LMADDim s r n _ _) =
      LMAD (off + flatOneDim (s, r, n) i) dims
    sliceOne (LMAD off dims) (DimSlice _ ne _, LMADDim 0 _ _ p _) =
      LMAD off (dims ++ [LMADDim 0 0 ne p Unknown])
    sliceOne (LMAD off dims) (dmind, dim@(LMADDim _ _ n _ _))
      | dmind == unitSlice 0 n = LMAD off (dims ++ [dim])
    sliceOne (LMAD off dims) (dmind, LMADDim s r n p m)
      | dmind == DimSlice (n - 1) n (-1) =
        let r' = if r == 0 then 0 else n - r
            off' = off + flatOneDim (s, 0, n) (n - 1)
         in LMAD off' (dims ++ [LMADDim (s * (-1)) r' n p (invertMonotonicity m)])
    sliceOne (LMAD off dims) (DimSlice b ne 0, LMADDim s r n p _) =
      LMAD (off + flatOneDim (s, r, n) b) (dims ++ [LMADDim 0 0 ne p Unknown])
    sliceOne (LMAD off dims) (DimSlice bs ns ss, LMADDim s 0 _ p m) =
      let m' = case sgn ss of
            Just 1 -> m
            Just (-1) -> invertMonotonicity m
            _ -> Unknown
       in LMAD (off + s * bs) (dims ++ [LMADDim (ss * s) 0 ns p m'])
    sliceOne _ _ =
      error "slice: reached impossible case"

    -- Does slicing a contiguous LMAD with this slice (given in storage order)
    -- keep it contiguous?
    slicePreservesContiguous ::
      (Eq num, IntegralExp num) =>
      LMAD num ->
      Slice num ->
      Bool
    slicePreservesContiguous (LMAD _ dims) (Slice slc) =
      -- remove from the slice the LMAD dimensions that have stride 0.
      -- If the LMAD was contiguous in mem, then these dims will not
      -- influence the contiguousness of the result.
      -- Also normalize the input slice, i.e., 0-stride and size-1
      -- slices are rewritten as DimFixed.
      let (dims', slc') =
            unzip $
              filter ((/= 0) . ldStride . fst) $
                zip dims $ map normIndex slc
          -- Check that:
          -- 1. a clean split point exists between Fixed and Sliced dims
          -- 2. the outermost sliced dim has +/- 1 stride AND is unrotated or full.
          -- 3. the rest of inner sliced dims are full.
          -- The fold state is (seen a sliced dim yet?, all checks passed so far?).
          (_, success) =
            foldl
              ( \(found, res) (slcdim, LMADDim _ r n _ _) ->
                  case (slcdim, found) of
                    (DimFix {}, True) -> (found, False)
                    (DimFix {}, False) -> (found, res)
                    (DimSlice _ ne ds, False) ->
                      -- outermost sliced dim: +/-1 stride
                      let res' = (r == 0 || n == ne) && (ds == 1 || ds == -1)
                       in (True, res && res')
                    (DimSlice _ ne ds, True) ->
                      -- inner sliced dim: needs to be full
                      let res' = (n == ne) && (ds == 1 || ds == -1)
                       in (found, res && res')
              )
              (False, True)
              $ zip slc' dims'
       in success

    -- Rewrite size-1 and stride-0 slices as fixed indices.
    normIndex ::
      (Eq num, IntegralExp num) =>
      DimIndex num ->
      DimIndex num
    normIndex (DimSlice b 1 _) = DimFix b
    normIndex (DimSlice b _ 0) = DimFix b
    normIndex d = d

-- | Slice an index function.
--
-- An identity slice returns the index function unchanged.  If the slice can
-- be applied within the first LMAD ('sliceOneLMAD'), that is done.  Otherwise
-- the slice is applied to an iota of the current shape, and the resulting LMAD
-- is prepended to the existing ones.
slice ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  Slice num ->
  IxFun num
slice ixfun@(IxFun (lmad@(LMAD _ _) :| lmads) oshp cg) dim_slices
  -- Avoid identity slicing.
  | unSlice dim_slices == map (unitSlice 0) (shape ixfun) = ixfun
  | Just ixfun' <- sliceOneLMAD ixfun dim_slices = ixfun'
  | otherwise =
    case sliceOneLMAD (iota (lmadShape lmad)) dim_slices of
      Just (IxFun (lmad' :| []) _ cg') ->
        IxFun (lmad' :| lmad : lmads) oshp (cg && cg')
      _ -> error "slice: reached impossible case"

-- | Flat-slice an index function.
--
-- If the index function has an ascending permutation and its first LMAD
-- dimension is unrotated, that dimension is replaced by the dimensions of the
-- flat slice (strides scaled by its stride).  The permutation is then reset to
-- @[0 ..]@.  Otherwise a new LMAD, describing the flat slice over the inner
-- dimensions of the current shape, is prepended.
flatSlice ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  FlatSlice num ->
  IxFun num
flatSlice ixfun@(IxFun (LMAD offset (dim : dims) :| lmads) oshp cg) (FlatSlice new_offset is)
  | hasContiguousPerm ixfun,
    ldRotate dim == 0 =
    let lmad =
          LMAD
            (offset + new_offset * ldStride dim)
            (map (helper $ ldStride dim) is <> dims)
            & setLMADPermutation [0 ..]
     in IxFun (lmad :| lmads) oshp cg
  where
    -- A new dimension is 'Inc' only when its resulting stride is exactly 1.
    helper s0 (FlatDimIndex n s) =
      let new_mon = if s0 * s == 1 then Inc else Unknown
       in LMADDim (s0 * s) 0 n 0 new_mon
flatSlice (IxFun (lmad :| lmads) oshp cg) s@(FlatSlice new_offset _) =
  IxFun (LMAD (new_offset * base_stride) (new_dims <> tail_dims) :| lmad : lmads) oshp cg
  where
    -- NOTE(review): 'tail' is partial; this assumes the LMAD has at least one
    -- dimension.
    tail_shapes = tail $ lmadShape lmad
    base_stride = product tail_shapes
    tail_strides = tail $ scanr (*) 1 tail_shapes
    tail_dims = zipWith5 LMADDim tail_strides (repeat 0) tail_shapes [length new_shapes ..] (repeat Inc)
    new_shapes = flatSliceDims s
    new_strides = map (* base_stride) $ flatSliceStrides s
    new_dims = zipWith5 LMADDim new_strides (repeat 0) new_shapes [0 ..] (repeat Inc)

-- | Handle the simple case where all reshape dimensions are coercions.
--
-- Also accepted is exactly one non-coercion dimension, as long as it
-- corresponds to exactly one existing dimension.  In either case only the
-- 'ldShape' of each dimension is replaced by the new size; strides, rotates,
-- permutation and monotonicity are kept.
reshapeCoercion ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  ShapeChange num ->
  Maybe (IxFun num)
reshapeCoercion (IxFun (lmad@(LMAD off dims) :| lmads) oldbase cg) newshape = do
  let perm = lmadPermutation lmad
  (head_coercions, reshapes, tail_coercions) <- splitCoercions newshape
  let hd_len = length head_coercions
      num_coercions = hd_len + length tail_coercions
      dims' = permuteFwd perm dims
      mid_dims = take (length dims - num_coercions) $ drop hd_len dims'
      num_rshps = length reshapes
  guard (num_rshps == 0 || (num_rshps == 1 && length mid_dims == 1))
  let dims'' =
        permuteInv perm $
          zipWith (\ld n -> ld {ldShape = n}) dims' (newDims newshape)
      lmad' = LMAD off dims''
  return $ IxFun (lmad' :| lmads) oldbase cg

-- | Handle the case where a reshape operation can stay inside a single LMAD.
--
-- There are four conditions that all must hold for the result of a reshape
-- operation to remain in the one-LMAD domain:
--
--   (1) the permutation of the underlying LMAD must leave unchanged
--       the LMAD dimensions that were *not* reshape coercions.
--   (2) the repetition of dimensions of the underlying LMAD must
--       refer only to the coerced-dimensions of the reshape operation.
--   (3) similarly, the rotated dimensions must refer only to
--       dimensions that are coerced by the reshape operation.
--   (4) finally, the underlying memory is contiguous (and monotonous).
--
-- If any of these conditions do not hold, then the reshape operation will
-- conservatively add a new LMAD to the list, leading to a representation that
-- provides fewer opportunities for further analysis.
reshapeOneLMAD ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  ShapeChange num ->
  Maybe (IxFun num)
reshapeOneLMAD ixfun@(IxFun (lmad@(LMAD off dims) :| lmads) oldbase cg) newshape = do
  let perm = lmadPermutation lmad
  (head_coercions, reshapes, tail_coercions) <- splitCoercions newshape
  let hd_len = length head_coercions
      num_coercions = hd_len + length tail_coercions
      dims_perm = permuteFwd perm dims
      -- The dimensions replaced by the non-coercion part of the reshape.
      mid_dims = take (length dims - num_coercions) $ drop hd_len dims_perm
      -- Ignore rotates, as we only care about not having rotates in the
      -- dimensions that aren't coercions (@mid_dims@), which we check
      -- separately.
      mon = ixfunMonotonicityRots True ixfun

  guard $
    -- checking conditions (2) and (3)
    all (\(LMADDim s r _ _ _) -> s /= 0 && r == 0) mid_dims
      -- checking condition (1)
      && consecutive hd_len (map ldPerm mid_dims)
      -- checking condition (4)
      && hasContiguousPerm ixfun
      && cg
      && (mon == Inc || mon == Dec)

  -- make new permutation
  let rsh_len = length reshapes
      -- Change in rank caused by the reshape.
      diff = length newshape - length dims
      iota_shape = [0 .. length newshape -1]
      perm' =
        map
          ( \i ->
              let ind =
                    if i < hd_len
                      then i
                      else i - diff
               in if (i >= hd_len) && (i < hd_len + rsh_len)
                    then i -- already checked mid_dims not affected
                    else
                      let p = ldPerm (dims !! ind)
                       in if p < hd_len
                            then p
                            else p + diff
          )
          iota_shape
      -- split the dimensions: @support_inds@ collects (position, (rotate,
      -- size)) for dimensions that get a real stride, @repeat_inds@ collects
      -- (position, size) for stride-0 (repeated) dimensions.
      (support_inds, repeat_inds) =
        foldl
          ( \(sup, rpt) (i, shpdim, ip) ->
              case (i < hd_len, i >= hd_len + rsh_len, shpdim) of
                (True, _, DimCoercion n) ->
                  case dims_perm !! i of
                    (LMADDim 0 _ _ _ _) -> (sup, (ip, n) : rpt)
                    (LMADDim _ r _ _ _) -> ((ip, (r, n)) : sup, rpt)
                (_, True, DimCoercion n) ->
                  case dims_perm !! (i - diff) of
                    (LMADDim 0 _ _ _ _) -> (sup, (ip, n) : rpt)
                    (LMADDim _ r _ _ _) -> ((ip, (r, n)) : sup, rpt)
                (False, False, _) ->
                  ((ip, (0, newDim shpdim)) : sup, rpt)
                -- already checked that the reshaped
                -- dims cannot be rotates
                _ -> error "reshape: reached impossible case"
          )
          ([], [])
          $ reverse $ zip3 iota_shape newshape perm'

      (sup_inds, support) = unzip $ sortBy (compare `on` fst) support_inds
      (rpt_inds, repeats) = unzip repeat_inds
      -- Fresh strides for the supported dimensions, in the LMAD's monotonicity.
      LMAD off' dims_sup = makeRotIota mon off support
      repeats' = map (\n -> LMADDim 0 0 n 0 Unknown) repeats
      -- Interleave supported and repeated dimensions back into position order.
      dims' =
        map snd $
          sortBy (compare `on` fst) $
            zip sup_inds dims_sup ++ zip rpt_inds repeats'
      lmad' = LMAD off' dims'
  return $ IxFun (setLMADPermutation perm' lmad' :| lmads) oldbase cg
  where
    -- Is the list equal to @[i, i + 1, ...]@ (as far as it goes)?
    consecutive _ [] = True
    consecutive i [p] = i == p
    consecutive i ps = and $ zipWith (==) ps [i, i + 1 ..]

-- | Split a shape change into its leading coercions, the middle
-- (non-coercion) part, and its trailing coercions.  Fails if a coercion
-- appears among the middle dimensions, i.e. if the trailing part contains
-- anything but coercions.
splitCoercions ::
  (Eq num, IntegralExp num) =>
  ShapeChange num ->
  Maybe (ShapeChange num, ShapeChange num, ShapeChange num)
splitCoercions newshape' = do
  let (head_coercions, newshape'') = span isCoercion newshape'
      (reshapes, tail_coercions) = break isCoercion newshape''
  guard (all isCoercion tail_coercions)
  return (head_coercions, reshapes, tail_coercions)
  where
    isCoercion DimCoercion {} = True
    isCoercion _ = False

-- | Reshape an index function.
--
-- Tries 'reshapeCoercion', then 'reshapeOneLMAD'.  If neither applies, an iota
-- LMAD of the new shape is prepended to the existing LMADs.
reshape ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  ShapeChange num ->
  IxFun num
reshape ixfun new_shape
  | Just ixfun' <- reshapeCoercion ixfun new_shape = ixfun'
  | Just ixfun' <- reshapeOneLMAD ixfun new_shape = ixfun'
reshape (IxFun (lmad0 :| lmad0s) oshp cg) new_shape =
  case iota (newDims new_shape) of
    IxFun (lmad :| []) _ _ -> IxFun (lmad :| lmad0 : lmad0s) oshp cg
    _ -> error "reshape: reached impossible case"

-- | The number of dimensions in the domain of the input function.
--
-- This is the number of dimensions of the first LMAD.
rank ::
  IntegralExp num =>
  IxFun num ->
  Int
rank (IxFun (LMAD _ sss :| _) _ _) = length sss

-- | Handle the case where a rebase operation can stay within m + n - 1 LMADs,
-- where m is the number of LMADs in the index function, and n is the number of
-- LMADs in the new base.  If both index functions have only one LMAD, this
-- means that we stay within the single-LMAD domain.
--
-- We can often stay in that domain if the original ixfun is essentially a
-- slice, e.g. `x[i, (k1,m,s1), (k2,n,s2)] = orig`.
--
-- XXX: TODO: handle repetitions in both lmads.
--
-- How to handle repeated dimensions in the original?
--
--   (a) Shave them off of the last lmad of original
--   (b) Compose the result from (a) with the first
--       lmad of the new base
--   (c) apply a repeat operation on the result of (b).
--
-- However, I strongly suspect that for in-place update what we need is actually
-- the INVERSE of the rebase function, i.e., given an index function new-base
-- and another one orig, compute the index function ixfun0 such that:
--
--   new-base == rebase ixfun0 ixfun, or equivalently:
--   new-base == ixfun o ixfun0
--
-- because then I can go bottom up and compose with ixfun0 all the index
-- functions corresponding to the memory block associated with ixfun.
rebaseNice ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  IxFun num ->
  Maybe (IxFun num)
rebaseNice
  new_base@(IxFun (lmad_base :| lmads_base) _ cg_base)
  ixfun@(IxFun lmads shp cg) = do
    -- @lmad@ is the last LMAD of @ixfun@, the one that gets merged into the
    -- first LMAD of the new base.
    let (lmad :| lmads') = NE.reverse lmads
        dims = lmadDims lmad
        perm = lmadPermutation lmad
        perm_base = lmadPermutation lmad_base

    guard $
      -- Core rebase condition.
      base ixfun == shape new_base
        -- Conservative safety conditions: ixfun is contiguous and has known
        -- monotonicity for all dimensions.
        && cg
        && all ((/= Unknown) . ldMon) dims
        -- XXX: We should be able to handle some basic cases where both index
        -- functions have non-trivial permutations.
        && (hasContiguousPerm ixfun || hasContiguousPerm new_base)
        -- We need the permutations to be of the same size if we want to compose
        -- them.  They don't have to be of the same size if the ixfun has a trivial
        -- permutation.  Supporting this latter case allows us to rebase when ixfun
        -- has been created by slicing with fixed dimensions.
        && (length perm == length perm_base || hasContiguousPerm ixfun)
        -- To not have to worry about ixfun having non-1 strides, we also check that
        -- it is a row-major array (modulo permutation, which is handled
        -- separately).  Accept a non-full innermost dimension.  XXX: Maybe this can
        -- be less conservative?
        && and
          ( zipWith3
              (\sn ld inner -> sn == ldShape ld || (inner && ldStride ld == 1))
              shp
              dims
              (replicate (length dims - 1) False ++ [True])
          )

    -- Compose permutations, reverse strides and adjust offset if necessary.
    let perm_base' =
          if hasContiguousPerm ixfun
            then perm_base
            else map (perm !!) perm_base
        lmad_base' = setLMADPermutation perm_base' lmad_base
        dims_base = lmadDims lmad_base'
        n_fewer_dims = length dims_base - length dims
        -- For every dimension: a decreasing dimension of @ixfun@ negates the
        -- base stride and contributes @s1 * (n1 - 1)@ to the offset; rotates
        -- of both are combined.
        (dims_base', offs_contrib) =
          unzip $
            zipWith
              ( \(LMADDim s1 r1 n1 p1 _) (LMADDim _ r2 _ _ m2) ->
                  let (s', off')
                        | m2 == Inc = (s1, 0)
                        | otherwise = (s1 * (-1), s1 * (n1 - 1))
                      r'
                        | m2 == Inc = if r2 == 0 then r1 else r1 + r2
                        | r1 == 0 = r2
                        | r2 == 0 = n1 - r1
                        | otherwise = n1 - r1 + r2
                   in (LMADDim s' r' n1 (p1 - n_fewer_dims) Inc, off')
              )
              -- If @dims@ is morally a slice, it might have fewer dimensions than
              -- @dims_base@.  Drop extraneous outer dimensions.
              (drop n_fewer_dims dims_base)
              dims
        off_base = lmadOffset lmad_base' + sum offs_contrib
        lmad_base''
          | lmadOffset lmad == 0 = LMAD off_base dims_base'
          | otherwise =
            -- If the innermost dimension of the ixfun was not full (but still
            -- had a stride of 1), add its offset relative to the new base.
            setLMADShape
              (lmadShape lmad)
              ( LMAD
                  (off_base + ldStride (last dims_base) * lmadOffset lmad)
                  dims_base'
              )
        new_base' = IxFun (lmad_base'' :| lmads_base) shp cg_base
        IxFun lmads_base' _ _ = new_base'
        lmads'' = lmads' ++@ lmads_base'
    return $ IxFun lmads'' shp (cg && cg_base)

-- | Rebase an index function on top of a new base.
--
-- Uses 'rebaseNice' when possible.  Otherwise the LMADs of the new base are
-- appended after those of the index function.  The new base is first reshaped
-- (by coercion) to the index function's base shape when the two differ.
rebase ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  IxFun num ->
  IxFun num
rebase new_base@(IxFun lmads_base shp_base cg_base) ixfun@(IxFun lmads shp cg)
  | Just ixfun' <- rebaseNice new_base ixfun = ixfun'
  -- In the general case just concatenate LMADs since this refers to index
  -- function composition, which is always safe.
  | otherwise =
    let (lmads_base', shp_base') =
          if base ixfun == shape new_base
            then (lmads_base, shp_base)
            else
              let IxFun lmads' shp_base'' _ = reshape new_base $ map DimCoercion shp
               in (lmads', shp_base'')
     in IxFun (lmads @++@ lmads_base') shp_base' (cg && cg_base)

-- | Monotonicity of an index function, taking rotates into account (a rotated
-- dimension with non-zero stride makes it 'Unknown').
ixfunMonotonicity :: (Eq num, IntegralExp num) => IxFun num -> Monotonicity
ixfunMonotonicity = ixfunMonotonicityRots False

-- | If the memory support of the index function is contiguous and row-major
-- (i.e., no transpositions, repetitions, rotates, etc.), then this should
-- return the offset from which the memory-support of this index function
-- starts.  The offset is multiplied by @elem_size@.  Only single-LMAD index
-- functions qualify.
linearWithOffset ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  num ->
  Maybe num
linearWithOffset ixfun@(IxFun (lmad :| []) _ cg) elem_size
  | hasContiguousPerm ixfun && cg && ixfunMonotonicity ixfun == Inc =
    Just $ lmadOffset lmad * elem_size
linearWithOffset _ _ = Nothing

-- | Similar restrictions to @linearWithOffset@ except for transpositions, which
-- are returned together with the offset.  The returned list pairs each
-- permutation entry with the corresponding (permuted) dimension size.
rearrangeWithOffset ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  num ->
  Maybe (num, [(Int, num)])
rearrangeWithOffset (IxFun (lmad :| []) oshp cg) elem_size = do
  -- Note that @cg@ describes whether the index function is
  -- contiguous, *ignoring permutations*.  This function requires that
  -- functionality.
  let perm = lmadPermutation lmad
      perm_contig = [0 .. length perm -1]
  -- Check linearity as if the permutation were the identity.
  offset <-
    linearWithOffset
      (IxFun (setLMADPermutation perm_contig lmad :| []) oshp cg)
      elem_size
  return (offset, zip perm (permuteFwd perm (lmadShapeBase lmad)))
rearrangeWithOffset _ _ = Nothing

-- | Is this a row-major array starting at offset zero?
isLinear :: (Eq num, IntegralExp num) => IxFun num -> Bool
isLinear = (== Just 0) . flip linearWithOffset 1

-- | Element @i@ of the result is @elems !! (ps !! i)@.
permuteFwd :: Permutation -> [a] -> [a]
permuteFwd ps elems = map (elems !!) ps

-- | Inverse of 'permuteFwd': element @ps !! i@ of the result is
-- @elems !! i@ (assuming @ps@ is a permutation).
permuteInv :: Permutation -> [a] -> [a]
permuteInv ps elems = map snd $ sortBy (compare `on` fst) $ zip ps elems

-- | Flat contribution of index @i@ in a dimension with stride @s@, rotate @r@
-- and size @n@.  A stride-0 dimension contributes nothing, and the modulo is
-- only applied when the dimension is rotated.
flatOneDim ::
  (Eq num, IntegralExp num) =>
  (num, num, num) ->
  num ->
  num
flatOneDim (s, r, n) i
  | s == 0 = 0
  | r == 0 = i * s
  | otherwise = ((i + r) `mod` n) * s

-- | Generalised iota with user-specified offset and rotates.
--
-- Produces row-major strides for the given sizes (negated for 'Dec'), the
-- identity permutation, and the given monotonicity on every dimension.
-- Fails for 'Unknown'.
makeRotIota ::
  IntegralExp num =>
  Monotonicity ->
  -- | Offset
  num ->
  -- | Pairs of rotation and shape, one per dimension
  [(num, num)] ->
  LMAD num
makeRotIota mon off support
  | mon == Inc || mon == Dec =
    let rk = length support
        (rs, ns) = unzip support
        ss0 = reverse $ take rk $ scanl (*) 1 $ reverse ns
        ss =
          if mon == Inc
            then ss0
            else map (* (-1)) ss0
        ps = map fromIntegral [0 .. rk -1]
        fi = replicate rk mon
     in LMAD off $ zipWith5 LMADDim ss rs ns ps fi
  | otherwise = error "makeRotIota: requires Inc or Dec"

-- | Check monotonicity of an index function.
--
-- Returns the monotonicity shared by all LMADs, or 'Unknown' if they disagree.
-- Within an LMAD, stride-0 dimensions are compatible with either direction.
-- Other dimensions must match and, unless @ignore_rots@ is set, be unrotated.
ixfunMonotonicityRots ::
  (Eq num, IntegralExp num) =>
  Bool ->
  IxFun num ->
  Monotonicity
ixfunMonotonicityRots ignore_rots (IxFun (lmad :| lmads) _ _) =
  let mon0 = lmadMonotonicityRots lmad
   in if all ((== mon0) . lmadMonotonicityRots) lmads
        then mon0
        else Unknown
  where
    lmadMonotonicityRots :: (Eq num, IntegralExp num) => LMAD num -> Monotonicity
    lmadMonotonicityRots (LMAD _ dims)
      | all (isMonDim Inc) dims = Inc
      | all (isMonDim Dec) dims = Dec
      | otherwise = Unknown

    isMonDim :: (Eq num, IntegralExp num) => Monotonicity -> LMADDim num -> Bool
    isMonDim mon (LMADDim s r _ _ ldmon) =
      s == 0 || ((ignore_rots || r == 0) && mon == ldmon)

-- | Generalization (anti-unification)
--
-- Anti-unification of two index functions is supported under the following conditions:
--   0. Both index functions are represented by ONE lmad (assumed common case!)
--   1. The support array of the two indexfuns have the same dimensionality
--      (we can relax this condition if we use a 1D support, as we probably should!)
--   2. The contiguous property and the per-dimension monotonicity are the same
--      (otherwise we might lose important information; this can be relaxed!)
--   3. Most importantly, both index functions correspond to the same permutation
--      (since the permutation is represented by INTs, this restriction cannot
--      be relaxed, unless we move to a gated-LMAD representation!)
--
-- The generalization is threaded through 'PEG.leastGeneralGeneralization' in
-- a fixed order: per-dimension shapes, base shape, strides, rotates, and
-- finally the offset.  The second component of the result is the final list
-- threaded through that function.  Presumably it holds the pairs of
-- sub-expressions that were abstracted into existentials; confirm in
-- "Futhark.Analysis.PrimExp.Generalize".
leastGeneralGeneralization ::
  Eq v =>
  IxFun (PrimExp v) ->
  IxFun (PrimExp v) ->
  Maybe (IxFun (PrimExp (Ext v)), [(PrimExp v, PrimExp v)])
leastGeneralGeneralization (IxFun (lmad1 :| []) oshp1 ctg1) (IxFun (lmad2 :| []) oshp2 ctg2) = do
  guard $
    length oshp1 == length oshp2
      && ctg1 == ctg2
      && map ldPerm (lmadDims lmad1) == map ldPerm (lmadDims lmad2)
      && lmadDMon lmad1 == lmadDMon lmad2
  let (ctg, dperm, dmon) = (ctg1, lmadPermutation lmad1, lmadDMon lmad1)
  -- The order of these steps determines the numbering of the existentials.
  (dshp, m1) <- generalize [] (lmadDShp lmad1) (lmadDShp lmad2)
  (oshp, m2) <- generalize m1 oshp1 oshp2
  (dstd, m3) <- generalize m2 (lmadDSrd lmad1) (lmadDSrd lmad2)
  (drot, m4) <- generalize m3 (lmadDRot lmad1) (lmadDRot lmad2)
  let (offt, m5) = PEG.leastGeneralGeneralization m4 (lmadOffset lmad1) (lmadOffset lmad2)
  let lmad_dims =
        map (\(a, b, c, d, e) -> LMADDim a b c d e) $
          zip5 dstd drot dshp dperm dmon
      lmad = LMAD offt lmad_dims
  return (IxFun (lmad :| []) oshp ctg, m5)
  where
    lmadDMon = map ldMon . lmadDims
    lmadDSrd = map ldStride . lmadDims
    lmadDShp = map ldShape . lmadDims
    lmadDRot = map ldRotate . lmadDims
    -- Generalize two lists element-wise, left to right, threading the
    -- accumulator @m@ through every step.
    generalize m l1 l2 =
      foldM
        ( \(l_acc, m') (pe1, pe2) -> do
            let (e, m'') = PEG.leastGeneralGeneralization m' pe1 pe2
            return (l_acc ++ [e], m'')
        )
        ([], m)
        (zip l1 l2)
leastGeneralGeneralization _ _ = Nothing

-- | Is the list exactly @[0, 1, 2, ...]@ up to its length?
isSequential :: [Int] -> Bool
isSequential xs =
  all (uncurry (==)) $ zip xs [0 ..]

-- | Replace an expression by an existential.  The expression is appended to
-- the state list, and a leaf @Ext i@ of the same primitive type is returned,
-- where @i@ is the expression's position in that list.
existentializeExp :: TPrimExp t v -> State [TPrimExp t v] (TPrimExp t (Ext v))
existentializeExp e = do
  i <- gets length
  modify (++ [e])
  let t = primExpType $ untyped e
  return $ TPrimExp $ LeafExp (Ext i) t

-- | Replace the base shape, offset, strides and shapes of an index function by
-- existentials (see 'existentializeExp'), in that order.  Rotates must be zero
-- and are kept as free expressions.
--
-- We require that there's only one LMAD, that the index function is
-- contiguous, that all rotates are zero, that the LMAD has as many dimensions
-- as the base shape, and that its permutation is the identity.  Otherwise the
-- result is 'Nothing'.
existentialize ::
  (IntExp t, Eq v, Pretty v) =>
  IxFun (TPrimExp t v) ->
  State [TPrimExp t v] (Maybe (IxFun (TPrimExp t (Ext v))))
existentialize (IxFun (lmad :| []) oshp True)
  | all ((== 0) . ldRotate) (lmadDims lmad),
    length (lmadShape lmad) == length oshp,
    isSequential (map ldPerm $ lmadDims lmad) = do
    oshp' <- mapM existentializeExp oshp
    lmadOffset' <- existentializeExp $ lmadOffset lmad
    lmadDims' <- mapM existentializeLMADDim $ lmadDims lmad
    let lmad' = LMAD lmadOffset' lmadDims'
    return $ Just $ IxFun (lmad' :| []) oshp' True
  where
    -- Existentialize the stride, then the shape, of one dimension.
    existentializeLMADDim :: LMADDim (TPrimExp t v) -> State [TPrimExp t v] (LMADDim (TPrimExp t (Ext v)))
    existentializeLMADDim (LMADDim str rot shp perm mon) = do
      stride' <- existentializeExp str
      shape' <- existentializeExp shp
      return $ LMADDim stride' (fmap Free rot) shape' perm mon
existentialize _ = return Nothing

-- | When comparing index functions as part of the type check in KernelsMem,
-- we may run into problems caused by the simplifier.  As index functions can be
-- generalized over if-then-else expressions, the simplifier might hoist some of
-- the code from inside the if-then-else (computing the offset of an array, for
-- instance).  Now the type checker cannot verify that the generalized index
-- function is valid, because some of the existentials are computed somewhere
-- else.  To work around this, we've had to relax the KernelsMem type-checker
-- a bit.  Specifically, we've introduced this function to verify whether two
-- index functions are "close enough" that we can assume that they match.  We use
-- this instead of `ixfun1 == ixfun2` and hope that it's good enough.
--
-- Only the rank of the base shapes, the number of LMADs, and each LMAD's rank
-- and permutation are compared.
closeEnough :: IxFun num -> IxFun num -> Bool
closeEnough ixf1 ixf2 =
  (length (base ixf1) == length (base ixf2))
    && (NE.length (ixfunLMADs ixf1) == NE.length (ixfunLMADs ixf2))
    && all closeEnoughLMADs (NE.zip (ixfunLMADs ixf1) (ixfunLMADs ixf2))
  where
    closeEnoughLMADs :: (LMAD num, LMAD num) -> Bool
    closeEnoughLMADs (lmad1, lmad2) =
      length (lmadDims lmad1) == length (lmadDims lmad2)
        && map ldPerm (lmadDims lmad1)
        == map ldPerm (lmadDims lmad2)

-- | Returns true if two 'IxFun's are equivalent.
--
-- Equivalence in this case is defined as having the same number of LMADs, with
-- each pair of LMADs matching in rank, permutation, offset, strides and
-- rotations.  Per-dimension shapes and the base shape are not compared.
equivalent :: Eq num => IxFun num -> IxFun num -> Bool
equivalent ixf1 ixf2 =
  NE.length lmads1 == NE.length lmads2
    && and (NE.zipWith sameLMAD lmads1 lmads2)
  where
    lmads1 = ixfunLMADs ixf1
    lmads2 = ixfunLMADs ixf2

    -- Compare two LMADs field by field, in the same order as before:
    -- rank, permutation, offset, strides, rotations.
    sameLMAD l1 l2 =
      ((==) `on` length) d1 d2
        && ((==) `on` map ldPerm) d1 d2
        && ((==) `on` lmadOffset) l1 l2
        && ((==) `on` map ldStride) d1 d2
        && ((==) `on` map ldRotate) d1 d2
      where
        d1 = lmadDims l1
        d2 = lmadDims l2