{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeInType #-}
{-# LANGUAGE TypeOperators #-}
module Numeric.Subroutine.Sort
( SortBy (..), sortBy, sort
, SortableDataFrame
) where
import Control.Monad
import Control.Monad.ST
import Control.Monad.ST.Unsafe
import Data.Kind
import Data.Type.Lits
import Numeric.DataFrame.Internal.PrimArray
import Numeric.DataFrame.ST
import Numeric.DataFrame.Type
import Numeric.Dimensions
import Unsafe.Coerce
sort :: forall (t :: Type) n ns
. ( SortableDataFrame t (n ': ns), Ord t, SortBy n)
=> DataFrame t (n ': ns)
-> DataFrame t (n ': ns)
sort df = case dimKind @(KindOf n) of
DimKNat -> case uniqueOrCumulDims df of
Left _ -> df
Right steps
| SomeDims (Dims :: Dims ms) <- fromSteps steps
, Dict <- (unsafeCoerce (Dict @(ns ~ ns)) :: Dict (ns ~ ms))
-> sortBy compare df
| otherwise
-> error "sort/DimNat/uniqueOrCumulDims -- impossible pattern"
DimKXNat
| XFrame (df' :: DataFrame t ms) <- df
, D :* Dims <- dims @ms
-> XFrame (sortBy compare df')
| otherwise
-> error "sort/DimXNat -- impossible pattern"
sortBy :: forall (t :: Type) n ns
. ( SortableDataFrame t (n ': ns)
, SortBy n)
=> (DataFrame t ns -> DataFrame t ns -> Ordering)
-> DataFrame t (n ': ns)
-> DataFrame t (n ': ns)
sortBy cmp df = case dimKind @(KindOf n) of
DimKNat -> runST $
flip (withThawDataFrame (const $ pure df)) df $ \mdf -> do
sortByInplace
(\x y -> cmp <$> unsafeFreezeDataFrame x <*> unsafeFreezeDataFrame y)
mdf
unsafeFreezeDataFrame mdf
DimKXNat
| XFrame dfN <- df
, D :* Dims <- dims `inSpaceOf` dfN
-> XFrame (sortBy (\a b -> cmp (XFrame a) (XFrame b)) dfN)
| otherwise
-> error "sortBy/DimXNat -- impossible pattern"
type family SortableDataFrame (t :: Type) (ns :: [k]) :: Constraint where
SortableDataFrame t ((n ': ns) :: [Nat])
= (PrimArray t (DataFrame t ns), PrimArray t (DataFrame t (n ': ns)))
SortableDataFrame t ((n ': ns) :: [XNat])
= PrimBytes t
class BoundedDim n => SortBy n where
sortByInplace :: PrimBytes t
=> (STDataFrame s t ns -> STDataFrame s t ns -> ST s Ordering)
-> STDataFrame s t (n ': ns)
-> ST s ()
instance SortBy 0 where
sortByInplace _ _ = pure ()
instance SortBy 1 where
sortByInplace _ _ = pure ()
instance SortBy 2 where
sortByInplace cmp xs = cmp a b >>= \case
GT -> do
tmp <- oneMoreDataFrame a
swapDF tmp a b
_ -> pure ()
where
a = subDataFrameView' (Idx 0 :* U) xs
b = subDataFrameView' (Idx 1 :* U) xs
instance SortBy 3 where
sortByInplace cmp xs = join $
go <$> (unsafeDupableInterleaveST (oneMoreDataFrame a))
<*> cmp a b <*> cmp b c <*> cmp a c
where
a = subDataFrameView' (Idx 0 :* U) xs
b = subDataFrameView' (Idx 1 :* U) xs
c = subDataFrameView' (Idx 2 :* U) xs
go tmp GT LT GT
= swap3DF tmp a b c
go tmp LT GT GT
= swap3DF tmp b a c
go tmp GT bc ac | bc /= GT && ac /= GT
= swapDF tmp a b
go tmp ab GT ac | ab /= GT && ac /= GT
= swapDF tmp b c
go tmp ab bc GT | ab /= LT && bc /= LT
= swapDF tmp a c
go _ _ _ _ = pure ()
instance SortBy 4 where
sortByInplace cmp xs = do
tmp <- unsafeDupableInterleaveST (oneMoreDataFrame a)
cmpSwap tmp a c
cmpSwap tmp b d
cmpSwap tmp a b
cmpSwap tmp c d
cmpSwap tmp b c
where
a = subDataFrameView' (Idx 0 :* U) xs
b = subDataFrameView' (Idx 1 :* U) xs
c = subDataFrameView' (Idx 2 :* U) xs
d = subDataFrameView' (Idx 3 :* U) xs
cmpSwap tmp x y = cmp x y >>= \case
GT -> swapDF tmp x y
_ -> pure ()
instance {-# INCOHERENT #-}
KnownDim n => SortBy (n :: Nat) where
sortByInplace cmp (xs :: STDataFrame s t (n ': ns)) = do
tmp <- oneMoreDataFrame xs
copyMutableDataFrame' U xs tmp
mergeSort D tmp xs
where
mergeSort :: Dim (d :: Nat)
-> STDataFrame s t (d ': ns)
-> STDataFrame s t (d ': ns)
-> ST s ()
mergeSort D0 _ _ = pure ()
mergeSort D1 _ _ = pure ()
mergeSort (d@D :: Dim d) b a = do
d2l@D <- pure $ divDim d D2
Just d2r@D <- pure $ minusDimM d d2l
d2li@D <- pure $ plusDim d2l D1
d2ri@D <- pure $ plusDim d2r D1
Just Dict <- pure $ sameDim (plusDim d D1) (plusDim d2li d2r)
Just Dict <- pure $ sameDim (plusDim d D1) (plusDim d2ri d2l)
let leA = subDataFrameView @t @d @(d - Div d 2 + 1) @(Div d 2) @'[]
(Idx 0 :* U) a
riA = subDataFrameView @t @d @(Div d 2 + 1) @(d - Div d 2) @'[]
(Idx (dimVal d2l) :* U) a
leB = subDataFrameView @t @d @(d - Div d 2 + 1) @(Div d 2) @'[]
(Idx 0 :* U) b
riB = subDataFrameView @t @d @(Div d 2 + 1) @(d - Div d 2) @'[]
(Idx (dimVal d2l) :* U) b
mergeSort d2l leA leB
mergeSort d2r riA riB
merge d2l d2r d leB riB a
merge :: forall (a :: Nat) (b :: Nat) (ab :: Nat)
. Dim a -> Dim b -> Dim ab
-> STDataFrame s t (a ': ns)
-> STDataFrame s t (b ': ns)
-> STDataFrame s t (ab ': ns)
-> ST s ()
merge da@D db@D dab@D a b ab = foldM_ f (Just (0,0)) [0 .. dimVal dab - 1]
where
f Nothing _ = pure Nothing
f (Just (i,j)) k
| i >= dimVal da
, Dx dj@(D :: Dim j) <- someDimVal j
, D <- plusDim dj D1
, Just bmj@D <- minusDimM db dj
, Just bmji@D <- minusDimM (plusDim dab D1) bmj
, Just Dict <- sameDim (plusDim dab D1) (plusDim bmji bmj)
, Just Dict <- sameDim (plusDim db D1) (dj `plusDim` D1 `plusDim` bmj)
= Nothing <$ copyMutableDataFrame @t @ab @(ab + 1 - (b - j))
@(b - j) (Idx k :* U)
(subDataFrameView @t @b @(j + 1) @(b - j) (Idx j :* U) b) ab
| j >= dimVal db
, Dx di@(D :: Dim i) <- someDimVal i
, D <- plusDim di D1
, Just bmi@D <- minusDimM da di
, Just bmii@D <- minusDimM (plusDim dab D1) bmi
, Just Dict <- sameDim (plusDim dab D1) (plusDim bmii bmi)
, Just Dict <- sameDim (plusDim da D1) (di `plusDim` D1 `plusDim` bmi)
= Nothing <$ copyMutableDataFrame (Idx k :* U)
(subDataFrameView @t @a @(i + 1) @(a - i) (Idx i :* U) a) ab
| otherwise
= cmp (subDataFrameView' (Idx i :* U) a)
(subDataFrameView' (Idx j :* U) b) >>= \case
GT -> Just (i, j + 1)
<$ copyMutableDataFrame' (Idx k :* U)
(subDataFrameView' (Idx j :* U) b) ab
_ -> Just (i + 1, j)
<$ copyMutableDataFrame' (Idx k :* U)
(subDataFrameView' (Idx i :* U) a) ab
instance BoundedDim xn => SortBy (xn :: XNat) where
sortByInplace cmp (XSTFrame xs)
| D :* _ <- dims `inSpaceOf` xs
= sortByInplace (\x y -> cmp (castDataFrame x) (castDataFrame y)) xs
| otherwise = error "sortByInplace: impossible pattern"
swapDF :: forall (s :: Type) (t :: Type) (ns :: [Nat])
. PrimBytes t
=> STDataFrame s t ns
-> STDataFrame s t ns
-> STDataFrame s t ns
-> ST s ()
swapDF tmp a b = do
copyMutableDataFrame' U a tmp
copyMutableDataFrame' U b a
copyMutableDataFrame' U tmp b
swap3DF :: forall (s :: Type) (t :: Type) (ns :: [Nat])
. PrimBytes t
=> STDataFrame s t ns
-> STDataFrame s t ns
-> STDataFrame s t ns
-> STDataFrame s t ns
-> ST s ()
swap3DF tmp a b c = do
copyMutableDataFrame' U a tmp
copyMutableDataFrame' U b a
copyMutableDataFrame' U c b
copyMutableDataFrame' U tmp c