{-# LANGUAGE TemplateHaskell #-}
module Data.SortingNetwork.TH (
gMkSortBy,
mkSortListBy,
mkSortTupBy,
mkSortListByFns,
mkSortTupByFns,
) where
import Control.Monad
import Control.Monad.IO.Class
import qualified Data.Vector as V
import qualified Data.Vector.Mutable as VM
import Language.Haskell.TH
-- | Given the number of elements, produce the sequence of index pairs
-- @(i, j)@ that the generated code processes in order. Each index is used
-- to address a vector holding one name per element.
type MkPairs = Int -> [(Int, Int)]
-- | A partially built expression: given the expression for the innermost
-- "hole", produce the complete enclosing expression in the 'Q' monad.
type PartQ = Exp -> Q Exp
-- | Generate a lambda expression that rearranges a fixed number of values
-- by running a sequence of compare-and-swap steps.
--
-- Arguments:
--
-- * @mkPairs@: applied to @n@ to obtain the index pairs @(i, j)@, which are
--   processed in list order.
-- * @n@: the number of element variables to bind.
-- * @mkP@: builds the input pattern from the @n@ initial variable patterns.
-- * @mkE@: builds the result expression from the @n@ final variables.
--
-- The generated expression takes a comparison function followed by the
-- input matched with @mkP@. For each pair @(i, j)@ the current values at
-- those positions are compared; when the comparison yields 'GT' they are
-- exchanged, so position @i@ continues with the value that did not compare
-- greater. Every step binds two fresh variables, so the generated code is a
-- chain of nested lambdas rather than any mutable state.
--
-- An index outside @[0, n)@ raises a bounds error while the splice is being
-- built (bounds-checked vector access is used on purpose: this runs at
-- compile time, where safety matters more than speed).
gMkSortBy :: MkPairs -> Int -> ([Pat] -> Pat) -> ([Exp] -> Exp) -> Q Exp
gMkSortBy mkPairs n mkP mkE = do
  cmp <- newName "cmp"
  swapper <- newName "sw"
  -- Continuation-passing compare-and-swap: hands the two values to @f@,
  -- exchanged when the comparator reports GT.
  swapperVal <- [|\u v f -> if $(varE cmp) u v == GT then f v u else f u v|]
  ns0 <- replicateM n $ newName "v"
  -- Mutable table mapping each position to the name currently holding it.
  nv <- liftIO $ V.thaw (V.fromList ns0)
  let
    -- Outermost layer: bind the swapper once around the whole body.
    step0 :: PartQ
    step0 bd = [|let $(varP swapper) = $(pure swapperVal) in $(pure bd)|]

    -- Extend the expression built so far with one compare-and-swap whose
    -- continuation binds fresh names for positions i and j.
    step :: PartQ -> (Int, Int) -> Q PartQ
    step mk (i, j) = do
      iOld <- liftIO $ VM.read nv i
      jOld <- liftIO $ VM.read nv j
      iNew <- newName "v"
      jNew <- newName "v"
      liftIO $ do
        VM.write nv i iNew
        VM.write nv j jNew
      pure $ \hole ->
        mk
          =<< [|
            $(varE swapper)
              $(varE iOld)
              $(varE jOld)
              (\ $(varP iNew) $(varP jNew) -> $(pure hole))
            |]
  mkBody <- foldM step step0 (mkPairs n)
  -- Names holding each position after the last step.
  ns <- liftIO $ V.toList <$> V.freeze nv
  [|
    \ $(varP cmp)
      $(pure $ mkP $ VarP <$> ns0) ->
        $(mkBody $ mkE $ VarE <$> ns)
    |]
-- | Specialisations of 'gMkSortBy' for the two supported containers.
--
-- 'mkSortListBy' matches the input with a list pattern of exactly @n@
-- variables and returns a list expression of the final variables.
--
-- 'mkSortTupBy' matches the input with a tuple pattern of @n@ variables and
-- returns a tuple expression of the final variables.
--
-- In both cases the first argument supplies the index pairs and the second
-- is the element count @n@.
mkSortListBy, mkSortTupBy :: MkPairs -> Int -> Q Exp
mkSortListBy mkPairs n = gMkSortBy mkPairs n ListP ListE
-- 'TupE' takes optional components, so every element is wrapped in 'Just'.
mkSortTupBy mkPairs n = gMkSortBy mkPairs n TupP (TupE . fmap Just)
-- | Generate one top-level value declaration per requested size.
--
-- For every @n@ in the given list, 'mkSortListByFns' declares
-- @sortList\<n\>By@ with the expression from 'mkSortListBy', and
-- 'mkSortTupByFns' declares @sortTup\<n\>By@ with the expression from
-- 'mkSortTupBy'. The first argument supplies the index pairs for each size.
mkSortListByFns, mkSortTupByFns :: MkPairs -> [Int] -> Q [Dec]
mkSortListByFns = mkSortByFnsWith "sortList" mkSortListBy
mkSortTupByFns = mkSortByFnsWith "sortTup" mkSortTupBy

-- | Shared implementation: for each size @n@, bind the name
-- @\<prefix\>\<n\>By@ to the expression produced by the given builder and
-- collect all resulting declarations.
mkSortByFnsWith ::
  String ->
  (MkPairs -> Int -> Q Exp) ->
  MkPairs ->
  [Int] ->
  Q [Dec]
mkSortByFnsWith prefix mkSortBy mkPairs ns =
  fmap concat . forM ns $ \n -> do
    -- 'mkName' (not 'newName') so the declaration is visible under exactly
    -- this identifier at the splice site.
    let defN = mkName $ prefix <> show n <> "By"
    bd <- mkSortBy mkPairs n
    [d|$(varP defN) = $(pure bd)|]