{-# LANGUAGE DataKinds             #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE GADTs                 #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE NoStarIsType          #-}
{-# LANGUAGE PolyKinds             #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE TypeApplications      #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE TypeOperators         #-}
{-# LANGUAGE UndecidableInstances  #-}
{-# OPTIONS_HADDOCK prune not-home #-}

{-|
Copyright   : (c) 2020-2021 Tim Emiola
SPDX-License-Identifier: BSD3
Maintainer  : Tim Emiola <adetokunbo@users.noreply.github.com>

Defines type-level combinators for performing a merge sort of type-level lists.

'SortSymbols' sorts type-level lists of @Symbols@.

The other exported combinators ('Take', 'Drop', 'LengthOf' and 'HalfOf') make
it easy to implement the same type-level merge sort for lists of other kinds.

This is an internal module that provides type-level functions used in
various constraints in "System.TmpProc.Docker".

-}
module System.TmpProc.TypeLevel.Sort
  ( -- * Merge sort for @Symbols@.
    SortSymbols

     -- * Sort combinators
   , Take
   , Drop
   , LengthOf
   , HalfOf
  )
where

import           GHC.TypeLits (CmpNat, CmpSymbol, Nat, Symbol, type (*),
                               type (+), type (-))


-- $setup
-- >>> import Data.Proxy
-- >>> :set -XDataKinds
-- >>> :set -XTypeFamilies


{-| Keeps the first @n@ elements of a type-level list, or the whole list when
it has fewer than @n@ elements.

==== __Examples__

>>> :kind! Take '[1, 2, 3, 4] 2
Take '[1, 2, 3, 4] 2 :: [Nat]
= '[1, 2]

-}
type family Take (xs :: [k]) (n :: Nat) :: [k] where
    -- nothing left to keep
    Take _ 0 = '[]
    -- ran out of elements before reaching the count
    Take '[] _ = '[]
    -- keep the head, then take one fewer from the tail
    Take (y ': ys) count = y ': Take ys (count - 1)


{-| Removes the first @n@ elements of a type-level list; the result is empty
when the list has @n@ or fewer elements.

==== __Examples__

>>> :kind! Drop '[1, 2, 3, 4] 2
Drop '[1, 2, 3, 4] 2 :: [Nat]
= '[3, 4]

>>> :kind! Drop '[1] 2
Drop '[1] 2 :: [Nat]
= '[]

-}
type family Drop (xs :: [k]) (n :: Nat) :: [k] where
    Drop '[]         _     = '[]
    Drop rest        0     = rest
    Drop (_ ': rest) count = Drop rest (count - 1)


{-| Computes the number of elements in a type-level list by walking it one
cons cell at a time.

==== __Examples__

>>> :kind! LengthOf '[1, 2, 3, 4]
LengthOf '[1, 2, 3, 4] :: Nat
= 4

-}
type family LengthOf (xs :: [k]) :: Nat where
    LengthOf (_ ': rest) = 1 + LengthOf rest
    LengthOf '[]         = 0


{-| Computes the midpoint of a number, rounding down: half of @n@ when @n@ is
even, and half of @n - 1@ when @n@ is odd.

N.B.: the maximum value that this works for depends on the reduction limit of
the type-checker.

==== __Examples__

>>> :kind! HalfOf 1
HalfOf 1 :: Nat
= 0

>>> :kind! HalfOf 99
HalfOf 99 :: Nat
= 49

>>> :kind! HalfOf 100
HalfOf 100 :: Nat
= 50

-}
type family HalfOf (n :: Nat) :: Nat where
    -- Shortcuts for small values. Apart from 0, they only speed up compilation,
    -- so each must agree with what 'HalfOfImpl' computes for the same input.
    -- (HalfOfImpl 1 0 1 'LT reduces to 0, so HalfOf 1 is 0, not 1.)
    --
    -- 0 is required rather than optional: the general case would return 1 for
    -- it, because the search starts in the 'LT state and immediately overshoots.
    HalfOf 0 = 0
    HalfOf 1 = 0
    HalfOf 2 = 1
    HalfOf 3 = 1
    HalfOf 4 = 2
    HalfOf 5 = 2
    HalfOf 6 = 3
    HalfOf 7 = 3
    -- the general case
    HalfOf n = HalfOfImpl n 0 n 'LT

{-| Implements 'HalfOf' for the general case.

Walks a candidate @m@ upwards from 0 in steps of 2, carrying:

* @dist@, the difference @n - 2m@;
* @o@, the result of comparing @2m@ with @n@.

The walk stops when @2m@ equals @n@ (@'EQ@), when it is exactly one short of
@n@ (@dist@ is 1), or when it overshoots @n@ (@'GT@), in which case @m - 1@ is
the answer.

Expects to be started as @HalfOfImpl n 0 n 'LT@ with @n@ of at least 1.
-}
type family HalfOfImpl (n :: Nat) (m :: Nat) (dist :: Nat) (o :: Ordering) :: Nat where
    -- overshot by 1 or 2, so the previous candidate is the answer
    HalfOfImpl n m dist 'GT = m - 1
    HalfOfImpl n m dist 'EQ = m
    -- one short of n: n is odd and m is its rounded-down half
    HalfOfImpl n m 1 'LT = m
    HalfOfImpl n m dist 'LT = HalfOfImpl n (m + 2) (n - ((m + 2) * 2)) (CmpNat ((m + 2) * 2) n)


{-| Sorts a type-level list of @Symbol@s into ascending 'CmpSymbol' order using
merge sort.

==== __Examples__

>>> :kind! SortSymbols '["xyz", "def", "abc"]
SortSymbols '["xyz", "def", "abc"] :: [Symbol]
= '["abc", "def", "xyz"]

-}
type family SortSymbols (xs :: [Symbol]) :: [Symbol] where
    -- lists with fewer than two elements are already sorted
    SortSymbols '[] = '[]
    SortSymbols (a ': '[]) = '[a]
    -- a pair needs just one merge; this shortcut is not needed for correctness,
    -- it only skips computing the length and midpoint
    SortSymbols (a ': b ': '[]) = MergeSymbols '[a] '[b]
    -- otherwise split at the midpoint, sort both halves and merge them
    SortSymbols syms = SortSymbolsStep syms (HalfOf (LengthOf syms))


{-| Used internally by 'SortSymbols': splits @xs@ after its first @halfLen@
elements, sorts both parts recursively and merges the results. -}
type family SortSymbolsStep (xs :: [Symbol]) (halfLen :: Nat) :: [Symbol] where
    SortSymbolsStep syms mid =
        MergeSymbols (SortSymbols (Take syms mid)) (SortSymbols (Drop syms mid))

{-| Used internally by 'SortSymbols': merges two lists, each already sorted,
into a single sorted list. When either list is empty, the other is returned
unchanged. -}
type family MergeSymbols (xs :: [Symbol]) (ys :: [Symbol]) :: [Symbol] where
    MergeSymbols '[] rest = rest
    MergeSymbols rest '[] = rest
    -- both non-empty: compare the heads and let MergeSymbolsImpl pick one
    MergeSymbols (l ': ls) (r ': rs) =
        MergeSymbolsImpl (l ': ls) (r ': rs) (CmpSymbol l r)

{-| Used internally by 'MergeSymbols': emits one list head, chosen using @o@,
the result of comparing the two heads with 'CmpSymbol', then continues
merging. On a tie, the head of the first list is emitted first. -}
type family MergeSymbolsImpl (xs :: [Symbol]) (ys :: [Symbol]) (o :: Ordering) :: [Symbol] where
    -- the second list's head is smaller: emit it and keep merging
    MergeSymbolsImpl (l ': ls) (r ': rs) 'GT = r ': MergeSymbols (l ': ls) rs
    -- LT or EQ: the first list's head goes first
    MergeSymbolsImpl (l ': ls) (r ': rs) _   = l ': MergeSymbols ls (r ': rs)