{-# OPTIONS_GHC -fno-warn-redundant-constraints #-} -- | An index function represents a mapping from an array index space -- to a flat byte offset. This implements a representation for the -- index function based on linear-memory accessor descriptors, see -- Zhu, Hoeflinger and David work. Our specific representation is: -- LMAD = \overline{s,r,n}^k + o, where `o` is the offset, and `s_j`, -- `r_j`, and `n_j` are the stride, the rotate factor and the number -- of elements on dimension j. Dimensions are ordered in row major fashion. -- By definition, the LMAD above denotes the set of points: -- \{ o + \Sigma_{j=0}^{k} ((i_j+r_j) `mod` n_j)*s_j, -- \forall i_j such that 0<=i_j Pretty (Lmad num) where ppr (Lmad tau srnps) = let (ss, rs, ns, ps, fs) = unzip5 srnps in text " | " <> ppr tau <> text " + " <> brackets (commasep $ map ppr ss) <> text "v" <> brackets (commasep $ map ppr rs) <> text "v" <> brackets (commasep $ map ppr ns) <> text "v" <> brackets (commasep $ map ppr ps) <> text "v" <> brackets (commasep $ map ppr fs) <> text " | " instance Pretty num => Pretty (IxFun num) where ppr (IxFun lmads orgshp cg) = text "Shape: " <> braces (commasep $ map ppr orgshp) <> text " LMADS: " <> braces (stack $ map ppr lmads) <> text " CONTIG: "<> text (show cg) instance Substitute num => Substitute (Lmad num) where substituteNames substs = fmap $ substituteNames substs instance Substitute num => Substitute (IxFun num) where substituteNames substs = fmap $ substituteNames substs instance Substitute num => Rename (Lmad num) where rename = substituteRename instance Substitute num => Rename (IxFun num) where rename = substituteRename instance FreeIn num => FreeIn (Lmad num) where freeIn = foldMap freeIn instance FreeIn num => FreeIn (IxFun num) where freeIn = foldMap freeIn instance Functor Lmad where fmap f = runIdentity . traverse (return . f) instance Functor IxFun where fmap f = runIdentity . traverse (return . f) instance Foldable Lmad where foldMap f = execWriter . 
traverse (tell . f) instance Foldable IxFun where foldMap f = execWriter . traverse (tell . f) instance Traversable Lmad where traverse f (Lmad x l) = Lmad <$> f x <*> traverse f' l where f' (a, b, c, k, info) = (,,,,) <$> f a <*> f b <*> f c <*> pure k <*> pure info instance Traversable IxFun where traverse f (IxFun lmads shp cg) = IxFun <$> traverse (traverse f) lmads <*> traverse f shp <*> pure cg -- | Substituting a name with a PrimExp in an Lmad. substituteInLmad :: M.Map VName (PrimExp VName) -> Lmad (PrimExp VName) -> Lmad (PrimExp VName) substituteInLmad tab (Lmad off srnpds) = let off' = substituteInPrimExp tab off srnpds' = map (\(s,r,n,p,d) -> ( substituteInPrimExp tab s , substituteInPrimExp tab r , substituteInPrimExp tab n , p, d ) ) srnpds in Lmad off' srnpds' -- | Substituting a name with a PrimExp in an index function. substituteInIxFun :: M.Map VName (PrimExp VName) -> IxFun (PrimExp VName) -> IxFun (PrimExp VName) substituteInIxFun tab (IxFun lmads shp b) = IxFun (map (substituteInLmad tab) lmads) (map (substituteInPrimExp tab) shp) b ------------------------------------------ --- Index Function/LMAD Implementation --- ------------------------------------------ -- | whether this is a row-major array isDirect :: (Eq num, IntegralExp num) => IxFun num -> Bool isDirect (IxFun [Lmad off info] shp True) | length shp == length info, all (\((s,r,n,p,_),i,d) -> s==1 && r==0 && n==d && p==i) (zip3 info [0..length info - 1] shp), off == 0 = True | otherwise = False isDirect _ = False -- | whether an index function has contiguous memory support isContiguous :: (Eq num, IntegralExp num) => IxFun num -> Bool isContiguous (IxFun _ _ cg) = cg -- | Shape of an Lmad shape0 :: (Eq num, IntegralExp num) => Lmad num -> Shape num shape0 lmad@(Lmad _ srns) = map (\(_,_,z,_,_)->z) $ permuteInv (getPermutation lmad) srns -- | Shape of an index function shape :: (Eq num, IntegralExp num) => IxFun num -> Shape num shape (IxFun [] _ _) = error "shape: empty index function" 
shape (IxFun (lmad:_) _ _) = shape0 lmad -- | Computing the flat memory index for a complete set `inds` -- of array indices and a certain element size `elem_size`. index :: (IntegralExp num, Eq num) => IxFun num -> Indices num -> num -> num index (IxFun [] _ _) _ _ = error "index: empty index function" index (IxFun [lmad] _ _) iis elm_size = index0 lmad iis elm_size index (IxFun (lmad1:lmad2:lmads) oshp c) iis elm_size = let i_flat = index0 lmad1 iis 1 new_inds = unflattenIndex (shape0 lmad2) i_flat in index (IxFun (lmad2:lmads) oshp c) new_inds elm_size -- | Helper for index: computing the flat index of an Lmad. index0 :: (Eq num, IntegralExp num) => Lmad num -> Indices num -> num -> num index0 lmad@(Lmad tau srnps) inds elm_size = let prod = sum $ zipWith flatOneDim (map (\(s,r,n,_,_) -> (s,r,n)) srnps) (permuteInv (getPermutation lmad) inds) ind = tau + prod in if elm_size == 1 then ind else ind * elm_size -- | iota iota :: (IntegralExp num) => Shape num -> IxFun num iota ns = IxFun [makeRotIota Inc 0 $ zip rs ns] ns True where rs = replicate (length ns) 0 -- | permute dimensions permute :: IntegralExp num => IxFun num -> Permutation -> IxFun num permute (IxFun [] _ _) _ = error "permute: empty index function" permute (IxFun (lmad:lmads) oshp cg) ps = let perm = map (\p -> ps !! p) $ getPermutation lmad in IxFun (setPermutation perm lmad : lmads) oshp cg -- | repeating dimensions repeat :: (Eq num, IntegralExp num) => IxFun num -> [Shape num] -> Shape num -> IxFun num repeat (IxFun [] _ _) _ _ = error "repeat: empty index function" repeat (IxFun (lmad@(Lmad tau srnps) : lmads) oshp cg) shps shp = let perm = getPermutation lmad -- inverse permute the shapes and update the permutation! 
lens = map (\s -> 1 + length s) shps (shps', lens') = unzip $ permuteInv perm $ zip shps lens scn = drop 1 $ scanl (+) 0 lens' perm' = concatMap (\(p,l) -> map (\i-> (scn!!p)-l+i) [0..l-1]) $ zip perm lens tmp = length perm' perm'' = perm' ++ [tmp..tmp-1+length shp] srnps' = concatMap (\(shp_k, srnp)-> map fakeDim shp_k ++ [srnp] ) $ zip shps' srnps lmad' = setPermutation perm'' $ Lmad tau (srnps' ++ map fakeDim shp) in IxFun (lmad' : lmads) oshp cg where fakeDim x = (0,0,x,0,Unknown) -- | Rotating an index function: rotate :: (Eq num, IntegralExp num) => IxFun num -> Indices num -> IxFun num rotate (IxFun [] _ _) _ = error "rotate: empty index function" rotate (IxFun (lmad@(Lmad off srnps) : lmads) oshp cg) offs = let srnps' = zipWith (\(s,r,n,p,f) o -> if s == 0 then (0,0,n,p,Unknown) else (s,r+o,n,p,f) ) srnps $ permuteInv (getPermutation lmad) offs in IxFun (Lmad off srnps':lmads) oshp cg -- | Slicing an index function. slice :: (Eq num, IntegralExp num) => IxFun num -> Slice num -> IxFun num slice (IxFun [] _ _) _ = error "slice: empty index function" slice _ [] = error "slice: empty slice ???" slice ixfn dim_slices -- Avoid identity slicing. | dim_slices == map (unitSlice 0) (shape ixfn) = ixfn slice (IxFun (lmad@(Lmad _ srnpfs):lmads) oshp cg) is = let perm= getPermutation lmad is' = permuteInv perm is contig = cg && preservesContiguous lmad is' in if harmlessRotation lmad is' then let lmad' = foldl sliceOne (Lmad (getOffset lmad) []) $ zip is' srnpfs -- need to remove the fixed dims from the permutation perm' = updatePerm perm $ map fst $ filter isFixedDim $ zip [0..length is' - 1] is' in IxFun (setPermutation perm' lmad':lmads) oshp contig else -- falls outside LMAD formula, hence append a new LMAD case slice (iota (shape0 lmad)) is of IxFun [lmad'] _ _ -> IxFun (lmad':lmad:lmads) oshp contig _ -> error "slice: reached impossible case!" 
where isFixedDim (_,DimFix{}) = True isFixedDim _ = False updatePerm ps inds = foldl (\acc p -> acc ++ decrease p) [] ps where decrease p = let d = foldl (\n i -> if i == p then (-1) else if i > p then n else if n /= (-1) then n+1 else n ) 0 inds in if d == (-1) then [] else [p-d] harmlessRotation0 :: (Eq num, IntegralExp num) => (num,num,num,Int,DimInfo) -> DimIndex num -> Bool harmlessRotation0 _ (DimFix _) = True harmlessRotation0 (0,_,_,_,_) _ = True harmlessRotation0 (_,0,_,_,_) _ = True harmlessRotation0 (_,_,n,_,_) dslc | dslc == DimSlice (n-1) n (-1) || dslc == unitSlice 0 n = True harmlessRotation0 _ _ = False harmlessRotation :: (Eq num, IntegralExp num) => Lmad num -> Slice num -> Bool harmlessRotation (Lmad _ srnps) iss = and $ zipWith harmlessRotation0 srnps iss -- | TODO: what happens to r on a negative-stride slice; is there a such case? sliceOne :: (Eq num, IntegralExp num) => Lmad num -> (DimIndex num, (num,num,num,Int,DimInfo)) -> Lmad num sliceOne (Lmad tau srns) (DimFix i, (s,r,n,_,_)) = Lmad (tau + flatOneDim (s,r,n) i) srns sliceOne (Lmad tau srns) (DimSlice _ ne _, (0,_,_,p,_)) = Lmad tau (srns ++ [(0,0,ne,p,Unknown)]) sliceOne (Lmad tau srns) (dmind, srn@(_,_,n,_,_)) | dmind == unitSlice 0 n = Lmad tau (srns ++ [srn]) sliceOne (Lmad tau srns) (dmind, (s,r,n,p,f)) | dmind == DimSlice (n-1) n (-1) = let r' = if r == 0 then 0 else n-r in Lmad tau' (srns ++ [(s*(-1),r',n,p, invertInfo f)]) where tau' = tau + flatOneDim (s,0,n) (n-1) sliceOne (Lmad tau srns) (DimSlice b ne 0, (s,r,n,p,_)) = Lmad (tau + flatOneDim (s,r,n) b) (srns ++ [(0,0,ne,p,Unknown)]) sliceOne (Lmad tau srns) (DimSlice bs ns ss, (s,0,_,p,f)) = let f' = case sgn ss of Just 1 -> f Just (-1) -> invertInfo f _ -> Unknown in Lmad (tau + s*bs) (srns ++ [(ss*s,0,ns,p,f')]) sliceOne _ _ = error "slice: reached impossible case!" 
-- | Normalise one slice descriptor: a 'DimSlice' of size 1 or of stride 0
-- is rewritten to a 'DimFix' on its start index; every other descriptor
-- is returned untouched.
normIndex :: (Eq num, IntegralExp num) =>
             DimIndex num -> DimIndex num
normIndex dim =
  case dim of
    DimSlice start len stride
      | len == 1 || stride == 0 -> DimFix start
    _ -> dim

-- | Decide whether applying the slice @slc@ (already given in the order of
-- the Lmad's dimension list) to the Lmad keeps a contiguous memory support.
--
-- Dimensions of stride 0 are ignored: if the Lmad was contiguous in memory,
-- they do not influence the contiguousness of the result.  The slice is
-- first normalised with 'normIndex'.  The check then requires that:
--
--   1. a clean split point exists between fixed and sliced dimensions,
--      i.e. no 'DimFix' occurs after the first 'DimSlice';
--   2. the outermost sliced dimension has stride +1 or -1 and is either
--      unrotated or sliced in full;
--   3. every remaining (inner) sliced dimension is sliced in full, also
--      with stride +1 or -1.
preservesContiguous :: (Eq num, IntegralExp num) =>
                       Lmad num -> Slice num -> Bool
preservesContiguous (Lmad _ dims) slc =
  snd $ foldl step (False, True) strided
  where
    -- (normalised slice entry, Lmad dimension) pairs, minus stride-0 dims.
    strided =
      [ (ix, dim)
      | (dim@(s, _, _, _, _), ix) <- zip dims (map normIndex slc)
      , s /= 0
      ]

    unitStride ds = ds == 1 || ds == (-1)

    -- The state is (a sliced dim has been seen, all checks passed so far).
    step (seenSlice, ok) (ix, (_, r, n, _, _)) =
      case (ix, seenSlice) of
        -- a fixed dim below a sliced one: no clean split point
        (DimFix{}, True) -> (seenSlice, False)
        (DimFix{}, False) -> (seenSlice, ok)
        -- outermost sliced dim: +/-1 stride, unrotated or full
        (DimSlice _ ne ds, False) ->
          (True, ok && (r == 0 || n == ne) && unitStride ds)
        -- inner sliced dim: needs to be full
        (DimSlice _ ne ds, True) ->
          (seenSlice, ok && n == ne && unitStride ds)

-- | Reshaping an index function.
-- There are four conditions that all must hold for the result
-- of a reshape operation to remain in the one-Lmad domain:
-- (1) the permutation of the underlying Lmad must leave unchanged
-- the Lmad dimensions that were *not* reshape coercions.
-- (2) the repetition of dimensions of the underlying Lmad must
-- refer only to the coerced-dimensions of the reshape operation.
-- (3) similarly, the rotated dimensions must refer only to
-- dimensions that are coerced by the reshape operation.
-- (4) finally, the underlying memory is contiguous (and monotonous)
--
-- If any of these conditions does not hold then the reshape operation
-- will conservatively add a new Lmad to the list, leading to a
-- representation that provides fewer opportunities for further analysis.
--
-- Actually there are some special cases that need to be treated,
-- for example if everything is a coercion, then it should succeed
-- no matter what.
reshape :: (Eq num, IntegralExp num) =>
           IxFun num -> ShapeChange num -> IxFun num
reshape (IxFun [] _ _) _ = error "reshape: empty index function"
reshape ixfn@(IxFun (lmad@(Lmad tau srnps):lmads) oshp cg) newshape
  | -- first take care of the case when this is all a coercion!
    perm <- getPermutation lmad,
    Just (head_coercions, reshapes, tail_coercions) <- splitCoercions newshape,
    hd_len <- length head_coercions,
    num_coercions <- hd_len + length tail_coercions,
    srnps' <- permuteFwd perm srnps,
    mid_srnps <- take (length srnps - num_coercions) $ drop hd_len srnps',
    num_rshps <- length reshapes,
    num_rshps == 0 || (num_rshps == 1 && length mid_srnps == 1),
    -- only the dimension sizes change; entries are put back in the
    -- order given by their permutation field.
    srnps'' <- map snd $ L.sortBy sortGT $
               zipWith (\(s,r,_,p,f) n -> (p,(s,r,n,p,f)))
               srnps' $ newDims newshape =
      IxFun (Lmad tau srnps'':lmads) oshp cg

  | perm <- getPermutation lmad,
    Just (head_coercions, reshapes, tail_coercions) <- splitCoercions newshape,
    hd_len <- length head_coercions,
    num_coercions <- hd_len + length tail_coercions,
    srnps_perm <- permuteFwd perm srnps,
    mid_srnps <- take (length srnps - num_coercions) $ drop hd_len srnps_perm,
    -- checking conditions (2) and (3)
    all (\ (s,r,_,_,_) -> s /= 0 && r == 0) mid_srnps,
    -- checking condition (1)
    consecutive hd_len $ map (\(_,_,_,p,_) -> p) mid_srnps,
    -- checking condition (4)
    info <- getMonotonicityRots True ixfn,
    cg && (info == Inc || info == Dec),
    -- make new permutation
    rsh_len <- length reshapes,
    diff <- length newshape - length srnps,
    iota_shape <- [0..length newshape-1],
    perm' <- map (\i -> let ind = if i < hd_len then i else i - diff
                        in  if (i>=hd_len) && (i < hd_len+rsh_len)
                            then i -- already checked mid_srnps not affected
                            else let (_,_,_,p,_) = srnps !! ind
                                 in  if p < hd_len then p else p + diff
                 ) iota_shape,
    -- split the dimensions
    (suport_inds, repeat_inds) <-
      foldl (\(sup,rpt) (i,shpdim,ip) ->
                case (i < hd_len, i >= hd_len+rsh_len, shpdim) of
                  (True,  _, DimCoercion n) ->
                    case srnps_perm !! i of
                      (0,_,_,_,_) -> ( sup, (ip,n) : rpt )
                      (_,r,_,_,_) -> ( (ip,(r,n)) : sup, rpt )
                  (_,  True, DimCoercion n) ->
                    case srnps_perm !! (i-diff) of
                      (0,_,_,_,_) -> ( sup, (ip,n) : rpt )
                      (_,r,_,_,_) -> ( (ip,(r,n)) : sup, rpt )
                  (False, False, _) ->
                    ( (ip, (0, newDim shpdim)) : sup, rpt )
                    -- already checked that the reshaped
                    -- dims cannot be repeats or rotates
                  _ -> error "reshape: reached impossible case!"
            ) ([],[]) $ reverse $ zip3 iota_shape newshape perm',

    (sup_inds, support) <- unzip $ L.sortBy sortGT suport_inds,
    (rpt_inds, repeats) <- unzip repeat_inds,
    Lmad tau' srnps_sup <- makeRotIota info tau support,
    repeats' <- map (\n -> (0,0,n,0,Unknown)) repeats,
    srnps'   <- map snd $ L.sortBy sortGT $
                zip sup_inds srnps_sup ++ zip rpt_inds repeats' =
      IxFun (setPermutation perm' (Lmad tau' srnps') : lmads) oshp cg
  where
    -- Split the shape change into leading coercions, a middle run and
    -- trailing coercions; Nothing if a non-coercion follows the
    -- trailing run.
    splitCoercions newshape' = do
      let (head_coercions, newshape'') = span isCoercion newshape'
      let (reshapes, tail_coercions) = break isCoercion newshape''
      guard (all isCoercion tail_coercions)
      return (head_coercions, reshapes, tail_coercions)

    isCoercion DimCoercion{} = True
    isCoercion _ = False

    -- Does the list equal [i, i+1, ...] (the empty list always qualifies)?
    consecutive _ [] = True
    consecutive i [p] = i == p
    consecutive i ps = and $ zipWith (==) ps [i, i+1..]

-- Fallback: none of the single-Lmad cases applied, so a fresh row-major
-- Lmad over the new dimensions is put in front of the existing ones.
reshape (IxFun lmads oshp cg) newshape =
  let new_dims = newDims newshape
  in case iota new_dims of
       IxFun [lmad] _ _ -> IxFun (lmad : lmads) oshp cg
       _ -> error "reshape: impossible case reached"

-- | Number of dimension entries of the first Lmad of the index function.
-- Calls 'error' on an index function without Lmads.
rank :: IntegralExp num =>
        IxFun num -> Int
rank (IxFun [] _ _) = error "rank: empty index function"
rank (IxFun (Lmad _ sss : _) _ _) = length sss

-- | The shape recorded in the index function itself (its second field).
-- Calls 'error' on an index function without Lmads.
base :: IxFun num -> Shape num
base (IxFun [] _ _) = error "base: empty index function"
base (IxFun _ osh _) = osh

-- | Correctness assumption: the shape of the new base is
-- equal to the base of the index function (to be rebased).
--
-- The first argument is the new base, the second the index function
-- being rebased.  Calls 'error' if either one has no Lmads.
rebase :: (Eq num, IntegralExp num) =>
          IxFun num -> IxFun num -> IxFun num
rebase (IxFun [] _ _) _ = error "rebase: empty index function 1"
rebase _ (IxFun [] _ _) = error "rebase: empty index function 2"
-- Special Case: `x[i, (k1,m,s1), (k2,n,s2)] = orig`
-- The new base would be the slice of x.
-- If orig is full (contiguous) and monotonicity is known
-- for all orig's dimensions (i.e., either Inc or Dec)
-- Then we can compose the two into one lmad, the result
-- mainly adapts the index function of the new base.
-- How to handle repeated dimensions in the original?
--   (a) Shave them off of the last lmad of original
--   (b) Compose the result from (a) with the first
--       lmad of the new base
--   (c) apply a repeat operation on the result of (b).
-- However, I strongly suspect that for in-place update
-- what we need is actually the INVERSE of the rebase function,
-- i.e., given an index function new-base and another one orig,
-- compute the index function ixfn0 such that:
--   new-base == rebase ixfn0 ixfn, or equivalently:
--   new-base == ixfn o ixfn0
-- because then I can go bottom up and compose with ixfn0
-- all the index functions corresponding to the memory
-- block associated with ixfn.
rebase newbase@(IxFun (lmad_base:lmads_base) shp_base cg_base)
       ixfn@(IxFun lmads shp cg)
  | lmad_full <- last lmads,
    (repeats, lmad) <- shaveoffRepeats lmad_full,
    perm <- getPermutation lmad,
    srnps <- getLmadDims lmad,
    -- sanity condition
    base ixfn == shape newbase,
    -- TODO: handle repetitions in both lmads.
    -- 1) orig is full and monotonicity is known for all dims
    cg && length shp == length srnps,
    and $ zipWith (\n2 (_,_,n1,_,i1) -> n1 == n2 && i1 /= Unknown) shp srnps,
    -- Building the result srnps: compose permutations,
    -- reverse strides and adjust offset if necessary.
    perm_base <- getPermutation lmad_base,
    perm' <- map (\p -> perm !! p) perm_base,
    lmad_base' <- setPermutation perm' lmad_base,
    (srnps_base, taus_contrib) <- unzip $
      zipWith (\ (s1,r1,n1,p1,_) (_,r2,_,_,i2) ->
                  -- assumes the monotonicity of all dimensions is known
                  let (s', tau') = if i2 == Inc
                                   then (s1,0)
                                   else (s1*(-1),s1*(n1-1))
                      r' | i2 == Inc = if r2 == 0 then r1 else r1+r2
                         | r1 == 0 = r2
                         | r2 == 0 = n1-r1
                         | otherwise = n1-r1+r2
                  in  ((s',r',n1,p1,Inc),tau')
              ) (getLmadDims lmad_base') $ permuteInv perm_base srnps,
    -- Make resulting lmads:
    tau_base' <- getOffset lmad_base' + sum taus_contrib,
    lmad_base'' <- Lmad tau_base' srnps_base,
    -- Put the repeat back on top of the result
    newbase' <- IxFun (lmad_base'':lmads_base) shp_base cg_base,
    (reps, rep) <- repeats,
    IxFun lmads_base'' _ _ <- repeat newbase' reps rep,
    lmads' <- take (length lmads - 1) lmads ++ lmads_base'' =
      IxFun lmads' shp_base (cg && cg_base)
  -- General case: just concatenate Lmads since this
  -- refers to index-function composition -- always safe!
  | base ixfn == shape newbase =
      IxFun (lmads ++ lmad_base:lmads_base) shp_base (cg && cg_base)
  | otherwise =
      let IxFun lmads' shp_base' _ = reshape newbase $ map DimCoercion shp
      in  IxFun (lmads ++ lmads') shp_base' (cg && cg_base)

-- | Monotonicity of an index function, computed by 'getMonotonicityRots'
-- with its flag set to 'False'.
getMonotonicity :: (Eq num, IntegralExp num) =>
                   IxFun num -> DimInfo
getMonotonicity = getMonotonicityRots False

-- | results in the index function corresponding to indexing
-- with `i` on the outermost dimension: the outermost dimension of
-- size `d` is sliced from `i` with `d-i` elements and stride 1, all
-- other dimensions are kept whole.  A zero `i` returns the argument
-- unchanged.
offsetIndex :: (Eq num, IntegralExp num) =>
               IxFun num -> num -> IxFun num
offsetIndex ixfun i | i == 0 = ixfun
offsetIndex ixfun i =
  case shape ixfun of
    d:ds -> slice ixfun (DimSlice i (d-i) 1 : map (unitSlice 0) ds)
    [] -> error "offsetIndex: underlying index function has rank zero"

-- | results in the index function corresponding to making
-- the outermost dimension strided by `s`.
strideIndex :: (Eq num, IntegralExp num) =>
               IxFun num -> num -> IxFun num
strideIndex ixfun s =
  case shape ixfun of
    d:ds -> slice ixfun (DimSlice 0 d s : map (unitSlice 0) ds)
    [] -> error "strideIndex: underlying index function has rank zero"

-- | If the memory support of the index function is contiguous
-- and row-major (i.e., no transpositions, repetitions,
-- rotates, etc.), then this should return the offset from
-- which the memory-support of this index function starts.
-- The offset is multiplied by the given element size.  Index
-- functions made of more than one Lmad always yield 'Nothing'.
linearWithOffset :: (Eq num, IntegralExp num) =>
                    IxFun num -> num -> Maybe num
linearWithOffset (IxFun [] _ _) _ =
  error "linearWithOffset: empty index function"
linearWithOffset ixfn@(IxFun [lmad] _ cg) elem_size
  | mon <- getMonotonicity ixfn,
    perm <- getPermutation lmad,
    cg && mon == Inc,
    all (\(s,_,_,_,_) -> s /= 0) (getLmadDims lmad),
    perm == [0..length perm - 1],
    off <- getOffset lmad =
      return $ off * elem_size
  | otherwise = Nothing
linearWithOffset _ _ = Nothing

-- | Similar restrictions to `linearWithOffset` except
-- for transpositions, which are returned together
-- with the offset.
rearrangeWithOffset :: (Eq num, IntegralExp num) =>
                       IxFun num -> num -> Maybe (num, [(Int,num)])
rearrangeWithOffset (IxFun [] _ _) _ =
  error "rearrangeWithOffset: empty index function"
rearrangeWithOffset ixfn@(IxFun [lmad] _ cg) elem_size
  | perm <- getPermutation lmad,
    mon <- getMonotonicity ixfn,
    cg && mon == Inc,
    all (\(s,_,_,_,_) -> s /= 0) (getLmadDims lmad),
    -- the identity permutation is not a rearrangement
    perm /= [0..length perm - 1],
    offset <- getOffset lmad * elem_size =
      return (offset, zip perm $ rearrangeShape perm $ shape ixfn)
  | otherwise = Nothing
rearrangeWithOffset _ _ = Nothing

-- | True exactly when 'linearWithOffset', called with element size 1,
-- yields offset 0.
isLinear :: (Eq num, IntegralExp num) =>
            IxFun num -> Bool
isLinear = (==Just 0) . flip linearWithOffset 1

------------------------
--- Helper functions ---
------------------------

-- | Swap 'Inc' and 'Dec'; 'Unknown' stays 'Unknown'.
invertInfo :: DimInfo -> DimInfo
invertInfo Inc = Dec
invertInfo Dec = Inc
invertInfo Unknown = Unknown

-- | The offset (first field) of an Lmad.
getOffset :: Lmad num -> num
getOffset (Lmad tau _) = tau

-- | The permutation of an Lmad: the fourth component of each of its
-- dimension entries, in list order.
getPermutation :: Lmad num -> Permutation
getPermutation (Lmad _ srns) = map (\(_,_,_,p,_) -> p) srns

-- | The dimension entries of an Lmad.
getLmadDims :: Lmad num -> [(num,num,num,Int,DimInfo)]
getLmadDims (Lmad _ srnps) = srnps

-- | Overwrite the permutation component of the dimension entries with
-- the given permutation, pairwise.  Built with 'zipWith', so the result
-- has as many entries as the shorter of the two lists.
setPermutation :: Permutation -> Lmad num -> Lmad num
setPermutation perm (Lmad tau srnps) =
  Lmad tau $ zipWith (\(s,r,n,_,i) p -> (s,r,n,p,i)) srnps perm

--setOffset :: num -> Lmad num -> Lmad num
--setOffset tau (Lmad _ srnps) = Lmad tau srnps

-- | Given an input lmad, this function computes a repetition `r`
-- and a new lmad `res`, such that `repeat r res` is identical
-- to the input lmad.  Entries of stride 0 are treated as the
-- repeated dimensions; they are removed from `res` and the
-- permutation indices of the remaining entries are decreased
-- accordingly.
shaveoffRepeats :: (Eq num, IntegralExp num) =>
                   Lmad num -> (([Shape num], Shape num), Lmad num)
shaveoffRepeats lmad =
  let perm  = getPermutation lmad
      srnps = getLmadDims lmad
      -- compute the Repeat:
      resacc = foldl (\acc (s,_,n,_,_) ->
                         case acc of
                           rpt:acc0 -> if s == 0
                                       then (n:rpt) : acc0
                                       else [] : (rpt:acc0)
                           _ -> error "shaveoffRepeats: empty accum!"
                     ) [[]] $ L.reverse $ permuteFwd perm srnps
      last_shape = last resacc
      shapes = take (length resacc - 1) resacc
      -- update permutation and lmad:
      -- number of stride-0 entries whose permutation index is below k
      howManyRepLT k =
        foldl (\i (s,_,_,p,_) -> if s == 0 && p < k then i + 1 else i
              ) 0 srnps
      srnps' = foldl (\acc (s,r,n,p,info) ->
                         if s == 0 then acc
                         else let p' = p - howManyRepLT p
                              in  (s,r,n,p',info):acc
                     ) [] $ L.reverse srnps
      lmad' = Lmad (getOffset lmad) srnps'
  in ((shapes,last_shape), lmad')

-- | Element `i` of the result is element `ps !! i` of the input list.
-- Uses '!!', so an out-of-range permutation entry is an error.
permuteFwd :: Permutation -> [a] -> [a]
permuteFwd ps ds = map (ds !!) ps

-- | Pair every element with the corresponding permutation entry and
-- return the elements ordered by increasing entry.
permuteInv :: Permutation -> [a] -> [a]
permuteInv ps elems = map snd $ L.sortBy sortGT $ zip ps elems

-- | Compare two pairs by their first component only.
--
-- Equal keys compare as 'EQ', which makes this a lawful comparator:
-- reporting 'GT' for equal keys would make @L.sortBy@ reorder equal
-- elements.
sortGT :: Ord a => (a, b1) -> (a, b2) -> Ordering
sortGT (a1, _) (a2, _) = compare a1 a2

-- | Contribution of one dimension, given as (stride, rotation, size),
-- to the flat offset for index `i`: zero for stride 0, `i*s` when
-- unrotated, and `((i+r) mod n) * s` otherwise.
flatOneDim :: (Eq num, IntegralExp num) =>
              (num, num, num) -> num -> num
flatOneDim (s,r,n) i
  | s == 0 = 0
  | r == 0 = i*s
  | otherwise = ((i+r) `mod` n) * s

-- | Build an Lmad with offset `tau` from a list of (rotation, size)
-- pairs.  The stride of a dimension is the product of the sizes of all
-- later dimensions, negated when `info` is 'Dec'.  The permutation is
-- the identity and every dimension is tagged with `info`.  Calls
-- 'error' unless `info` is 'Inc' or 'Dec'.
makeRotIota :: (IntegralExp num) =>
               DimInfo -> num -> [(num,num)] -> Lmad num
makeRotIota info tau support
  | info == Inc || info == Dec =
    let rk = length support
        (rs,ns) = unzip support
        ss0 = L.reverse $ take rk $ scanl (*) 1 $ L.reverse ns
        ss = if info == Inc then ss0 else map (*(-1)) ss0
        ps = map fromIntegral [0..rk-1]
        fi = replicate rk info
    in Lmad tau $ zip5 ss rs ns ps fi
  | otherwise = error "makeRotIota requires Inc or Dec!"
-- | Monotonicity of a whole index function: the monotonicity shared by
-- all of its Lmads (see 'getLmadMonotonicity'), or 'Unknown' as soon as
-- one Lmad disagrees with the first.  The flag is passed on unchanged.
-- Calls 'error' on an index function without Lmads.
getMonotonicityRots :: (Eq num, IntegralExp num) =>
                       Bool -> IxFun num -> DimInfo
getMonotonicityRots ignore_rots (IxFun lmads _ _) =
  case map (getLmadMonotonicity ignore_rots) lmads of
    [] -> error "getMonotonicityRots: empty index function"
    headMon : restMons
      | all (== headMon) restMons -> headMon
      | otherwise -> Unknown

-- | Monotonicity of a single Lmad: 'Inc' if every dimension passes
-- 'isMonDim' for 'Inc', else 'Dec' if every dimension passes it for
-- 'Dec', else 'Unknown'.
getLmadMonotonicity :: (Eq num, IntegralExp num) =>
                       Bool -> Lmad num -> DimInfo
getLmadMonotonicity ignore_rots (Lmad _ dims)
  | agreesWith Inc = Inc
  | agreesWith Dec = Dec
  | otherwise = Unknown
  where
    agreesWith dir = all (isMonDim ignore_rots dir) dims

-- | Is one dimension compatible with the monotonicity `mon`?
-- A dimension of stride 0 always is.  Otherwise its recorded 'DimInfo'
-- must equal `mon`, and its rotation must be 0 unless `ignore_rots`
-- is set.
isMonDim :: (Eq num, IntegralExp num) =>
            Bool -> DimInfo -> (num, num, num, Int, DimInfo) -> Bool
isMonDim ignore_rots mon (stride, rot, _, _, info)
  | stride == 0 = True
  | otherwise = (ignore_rots || rot == 0) && mon == info