Ticket #3181: StrictPair.hs

File StrictPair.hs, 1.9 KB (added by dolio, 4 years ago)

A test case: a portion of a heap sort algorithm

Line 
1{-# LANGUAGE TypeOperators #-}
2
3module Main (main) where
4
5import Control.Monad.ST
6
7import Data.Array.Vector
8
9siftByOffset :: (UA e) => (e -> e -> Ordering) -> MUArr e s -> e -> Int -> Int -> Int -> ST s ()
10siftByOffset cmp a val off start len = sift val start len
11 where
12 sift val root len
13   | child < len = do (child' :*: ac) <- maximumChild cmp a off child len
14                      case cmp val ac of
15                        LT -> writeMU a (root + off) ac >> sift val child' len
16                        _  -> writeMU a (root + off) val
17   | otherwise = writeMU a (root + off) val
18  where child = root * 3 + 1
19{-# INLINE siftByOffset #-}
20
21maximumChild :: (UA e) => (e -> e -> Ordering) -> MUArr e s -> Int -> Int -> Int -> ST s (Int :*: e)
22maximumChild cmp a off child1 len
23  | child3 < len = do ac1 <- readMU a (child1 + off)
24                      ac2 <- readMU a (child2 + off)
25                      ac3 <- readMU a (child3 + off)
26                      return $ case cmp ac1 ac2 of
27                                 LT -> case cmp ac2 ac3 of
28                                         LT -> child3 :*: ac3
29                                         _  -> child2 :*: ac2
30                                 _  -> case cmp ac1 ac3 of
31                                         LT -> child3 :*: ac3
32                                         _  -> child1 :*: ac1
33  | child2 < len = do ac1 <- readMU a (child1 + off)
34                      ac2 <- readMU a (child2 + off)
35                      return $ case cmp ac1 ac2 of
36                                 LT -> child2 :*: ac2
37                                 _  -> child1 :*: ac1
38  | otherwise    = do ac1 <- readMU a (child1 + off) ; return (child1 :*: ac1)
39 where
40 child2 = child1 + 1
41 child3 = child1 + 2
42{-# INLINE maximumChild #-}
43
44test :: MUArr Int s -> ST s ()
45test arr = siftByOffset compare arr len len len len
46 where len = lengthMU arr
47
48main = stToIO (newMU 40 >>= test)