-----------------------------------------------------------------------------
-- |
-- Module      :  Data.SBV.Plugin.Examples.MergeSort
-- Copyright   :  (c) Levent Erkok
-- License     :  BSD3
-- Maintainer  :  erkokl@gmail.com
-- Stability   :  experimental
--
-- An implementation of merge-sort and its correctness.
-----------------------------------------------------------------------------

{-# LANGUAGE CPP #-}

#ifndef HADDOCK
{-# OPTIONS_GHC -fplugin=Data.SBV.Plugin #-}
#endif

{-# OPTIONS_GHC -Wall -Werror #-}

module Data.SBV.Plugin.Examples.MergeSort where

import Data.SBV.Plugin

-----------------------------------------------------------------------------
-- * Implementing merge-sort
-- ${mergeSort}
-----------------------------------------------------------------------------
{- $mergeSort
A straightforward implementation of merge sort. We simply divide the input list
into two halves so long as it has at least two elements, sort each half on its
own, and then merge.
-}

-- | Merge two sorted lists into a single sorted list.
--
-- If either argument is empty, the other one is returned unchanged. Otherwise
-- the two heads are compared: the head of the first list is emitted only when
-- it is strictly smaller than the head of the second; on ties the head of the
-- second list is emitted first.
merge :: Ord a => [a] -> [a] -> [a]
merge []     ys           = ys
merge xs     []           = xs
merge xs@(x:xr) ys@(y:yr)
  | x < y                 = x : merge xr ys
  -- NOTE(review): the catch-all guard is spelled 'True' rather than
  -- 'otherwise'; presumably this keeps the definition within what the SBV
  -- plugin translates -- confirm before changing it.
  | True                  = y : merge xs yr

-- | Simple merge-sort implementation.
--
-- Lists with zero or one element are returned as they are. Longer lists are
-- split into two parts by the local helper @halve@, each part is sorted
-- recursively, and the two sorted results are combined with 'merge'.
mergeSort :: Ord a => [a] -> [a]
mergeSort []  = []
mergeSort [x] = [x]
mergeSort xs  = merge (mergeSort th) (mergeSort bh)
   where (th, bh) = halve xs ([], [])

         -- Distribute the elements over two accumulators. Each step conses
         -- the next element onto the first accumulator and swaps the pair,
         -- so the two parts never differ in length by more than one.
         halve :: [a] -> ([a], [a]) -> ([a], [a])
         halve []     sofar    = sofar
         halve (a:as) (fs, ss) = halve as (ss, a:fs)

-----------------------------------------------------------------------------
-- * Proving correctness of sorting
-- ${props}
-----------------------------------------------------------------------------
{- $props
There are two main parts to proving that a sorting algorithm is correct:

       * Prove that the output is non-decreasing
 
       * Prove that the output is a permutation of the input
-}

-- | Check whether a given sequence is non-decreasing.
--
-- Empty and singleton lists are trivially non-decreasing. For longer lists,
-- each element must be less than or equal to its successor.
nonDecreasing :: Ord a => [a] -> Bool
nonDecreasing []       = True
nonDecreasing [_]      = True
nonDecreasing (a:b:xs) = a <= b && nonDecreasing (b:xs)

-- | Check whether two given sequences are permutations of each other. We simply
-- check that each sequence is a subset of the other, when considered as a set.
-- The check is slightly complicated by the need to account for possibly
-- duplicated elements: every element of one list must be matched against a
-- distinct, not yet used element of the other list, and this is done in both
-- directions.
isPermutationOf :: Eq a => [a] -> [a] -> Bool
isPermutationOf as bs = go as [(b, True) | b <- bs] && go bs [(a, True) | a <- as]
  where -- Every element of the first list must be found among the still
        -- available (flagged 'True') entries of the second list. The updated
        -- flags returned by 'mark' are threaded through the recursion.
        go :: Eq a => [a] -> [(a, Bool)] -> Bool
        go []     _  = True
        go (x:xs) ys = found && go xs ys'
           where (found, ys') = mark x ys

        -- Go and mark off an instance of 'x' in the list, if possible. We keep track
        -- of unmarked elements by associating a boolean bit. Note that we have to
        -- keep the lists equal size for the recursive result to merge properly.
        mark :: Eq a => a -> [(a, Bool)] -> (Bool, [(a, Bool)])
        mark _ []            = (False, [])
        mark x ((y, v) : ys)
          | v && x == y      = (True, (y, not v) : ys)
          | True             = (r, (y, v) : ys')
          where (r, ys') = mark x ys

-----------------------------------------------------------------------------
-- * The correctness theorem
-----------------------------------------------------------------------------

-- | Asserting correctness of merge-sort for a list of the given size. Note that we can
-- only check correctness for fixed-size lists. Also, the proof will get more and more
-- complicated for the backend SMT solver as @n@ increases. Here we try it with 4.
--
-- The property states that the result of 'mergeSort' is non-decreasing (see
-- 'nonDecreasing') and is a permutation of the input (see 'isPermutationOf').
--
-- We have:
--
-- @
--   [SBV] tests/T48.hs:100:1-16 Proving "mergeSortCorrect", using Z3.
--   [Z3] Q.E.D.
-- @
{-# ANN mergeSortCorrect theorem { options = [ListSize 4] } #-}
mergeSortCorrect :: [Int] -> Bool
mergeSortCorrect xs = nonDecreasing ys && isPermutationOf xs ys
   where ys = mergeSort xs