{-# LANGUAGE RankNTypes #-}
module Data.Collate
(
Collate(..), Collator(..)
, sample, bulkSample
, collate, collateOf
, withCollator, feedCollatorOf, feedCollator
) where
import Control.Arrow (first)
import Control.Monad (void)
import Control.Monad.ST (ST, runST)
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.State.Strict
  ( StateT, get, modify, execStateT
  )
import Data.Functor.Const (Const(..))
import qualified Data.IntMap as IM
import Data.STRef (newSTRef, readSTRef, writeSTRef)

import Control.Lens
  ( Traversing', Sequenced
  , forMOf_, folded, traversed, ifoldMapOf, itraverseOf, taking
  )
import Control.Monad.Primitive (PrimMonad, PrimState, liftPrim)
import qualified Data.Vector.Mutable as V
-- | A table of per-position actions over a stream of @c@ values.
--
-- The action stored under key @i@ is run on the @i@-th element fed to the
-- collator (see 'feedCollatorOf'). The actions live in
-- @'ST' ('PrimState' m)@, so they can be lifted into the 'PrimMonad' @m@
-- that shares that state token.
newtype Collator m c = Collator
  { getCollator :: IM.IntMap (c -> ST (PrimState m) ())
  }
-- | Merge two collators position by position.
--
-- When both sides have an action for the same position, the combined
-- action runs the left action first and then the right one on the same
-- element.
instance Semigroup (Collator m c) where
  Collator l <> Collator r =
    Collator (IM.unionWith (\x y c -> x c >> y c) l r)
-- | The empty collator: it has no action for any position.
instance Monoid (Collator m c) where
  mempty = Collator IM.empty
-- | A value of type @a@ assembled from selected elements of a stream of @c@.
--
-- Running 'unCollate' allocates the mutable storage for the requested
-- elements and returns a pair of:
--
-- * an action that reads the final result out of that storage, and
-- * the 'Collator' that writes incoming elements into it.
newtype Collate c a = Collate
  { unCollate :: forall s. ST s (ST s a, Collator (ST s) c)
  }
-- | Post-process the result. Only the result-reading action is changed; the
-- collator, and so the set of requested positions, is left as is.
instance Functor (Collate c) where
  fmap f (Collate go) = Collate (fmap (first (fmap f)) go)
-- | 'pure' requests no elements.
--
-- '<*>' runs both setups (function side first) and merges the two collators
-- with '<>'. A position requested by both sides is therefore fed to both.
instance Applicative (Collate c) where
  pure x = Collate (return (return x, mempty))
  Collate goF <*> Collate goX = Collate $ do
    (readF, collF) <- goF
    (readX, collX) <- goX
    return (readF <*> readX, collF <> collX)
-- | Run a 'Collate' in any 'PrimMonad'.
--
-- The order of steps is:
--
-- 1. Run the setup (allocating storage).
-- 2. Hand the resulting 'Collator' to the callback, which is responsible for
--    feeding elements to it (for example with 'feedCollatorOf').
-- 3. Read the result back out.
--
-- A position the callback never fed keeps whatever placeholder its sampler
-- stored at setup time. For 'sample', that placeholder is an error thunk.
withCollator :: PrimMonad m => Collate c a -> (Collator m c -> m ()) -> m a
withCollator (Collate go) k = do
  (stA, Collator samples) <- liftPrim go
  -- Re-wrap to change the collator's monad index from @ST (PrimState m)@ to
  -- @m@. Since @PrimState (ST s)@ is @s@, the stored actions keep their type.
  k (Collator samples)
  liftPrim stA
-- | Feed the targets of a traversal to a 'Collator', numbering them from the
-- given starting index @i0@.
--
-- Only the prefix of targets needed to reach the largest position in the
-- collator is visited (via 'taking'). With an empty collator, no target is
-- visited at all. Targets whose index has no entry in the collator are
-- visited but otherwise ignored.
--
-- Returns the index one past the last target visited, i.e. @i0@ plus the
-- number of targets visited.
feedCollatorOf
  :: forall m s c
   . PrimMonad m
  => Traversing' (->) (Const (Sequenced () (StateT Int (ST (PrimState m))))) s c
  -> Int -> Collator m c -> s -> m Int
feedCollatorOf l i0 (Collator samplers) s =
  liftPrim $ flip execStateT i0 $
    forMOf_ (taking n l) s $ \c -> do
      -- The state is the index of the element currently being visited.
      i <- get
      modify (+1)
      lift $ IM.findWithDefault (const (return ())) i samplers c
  where
    -- How many targets to visit: enough to cover indices i0 .. maxKey
    -- inclusive, or none when every requested position lies before i0.
    n :: Int
    n = case IM.maxViewWithKey samplers of
      Nothing -> 0
      Just ((k, _), _) -> max 0 (k - i0 + 1)
-- | 'feedCollatorOf' specialised to the elements of any 'Foldable'.
--
-- Numbering starts at the given index. Returns the index one past the last
-- element visited.
feedCollator
  :: forall m f c
   . (PrimMonad m, Foldable f)
  => Int -> Collator m c -> f c -> m Int
feedCollator = feedCollatorOf folded
-- | Run a 'Collate' purely against the elements of a 'Foldable', numbering
-- them from 0.
collate :: Foldable f => Collate c a -> f c -> a
collate = collateOf folded
-- | Run a 'Collate' purely against the targets of a traversal, numbering
-- them from 0.
--
-- The traversal must be polymorphic in the state token @s0@ because the
-- collation runs inside 'runST'.
collateOf
  :: ( forall s0
     . Traversing' (->) (Const (Sequenced () (StateT Int (ST s0)))) s c
     )
  -> Collate c a -> s -> a
collateOf l coll src =
  runST $ withCollator coll $ \m -> void $ feedCollatorOf l 0 m src
-- | Request the element at index @i@, transformed by @f@.
--
-- The result of @f@ is forced to weak head normal form when the element is
-- fed in. If index @i@ is never fed, the result is an error thunk, and
-- forcing it raises an error.
sample :: Int -> (c -> a) -> Collate c a
sample i f = Collate $ do
  ref <- newSTRef (error "sample: Internal error: unfulfilled promise")
  return
    ( readSTRef ref
    , Collator (IM.singleton i (\c -> writeSTRef ref $! f c))
    )
-- | Request several elements at once.
--
-- Each position of @t@ holds the index of the stream element to request.
-- The result has the same shape as @t@, with @f@ applied to the requested
-- element at each position. Results are forced to weak head normal form
-- when written.
--
-- The same stream index may appear at several positions; every such
-- position is filled, because the per-position collators are combined with
-- '<>'. Positions whose index is never fed are left as allocated by
-- 'V.new'.
bulkSample :: Traversable t => t Int -> (c -> a) -> Collate c (t a)
bulkSample t f = Collate $ do
  -- One slot per position of t, addressed by the traversal index.
  vec <- V.new (length t)
  let collator = ifoldMapOf traversed
        (\iVec iInp -> Collator $
          IM.singleton iInp (\c -> V.write vec iVec $! f c))
        t
  return (itraverseOf traversed (\i _ -> V.read vec i) t, collator)