-----------------------------------------------------------------------------
-- |
-- Module      :  Data.SBV.Plugin.Examples.MergeSort
-- Copyright   :  (c) Levent Erkok
-- License     :  BSD3
-- Maintainer  :  erkokl@gmail.com
-- Stability   :  experimental
--
-- An implementation of merge-sort and its correctness.
-----------------------------------------------------------------------------

{-# LANGUAGE CPP #-}

#ifndef HADDOCK
{-# OPTIONS_GHC -fplugin=Data.SBV.Plugin #-}
#endif

{-# OPTIONS_GHC -Wall -Werror #-}

module Data.SBV.Plugin.Examples.MergeSort where

#ifndef HADDOCK
import Data.SBV.Plugin
#endif

-----------------------------------------------------------------------------
-- * Implementing merge-sort
-- ${mergeSort}
-----------------------------------------------------------------------------
{- $mergeSort
A straightforward implementation of merge sort. We simply divide the input list
into two halves so long as it has at least two elements, sort each half on its
own, and then merge.
-}

-- | Merging two given sorted lists, preserving the order.
--
-- If either argument is empty, the other one is returned unchanged. Otherwise
-- the two heads are compared and the smaller one is emitted first. Since the
-- comparison is a strict '<', on a tie the head of the second list goes first.
merge :: [Int] -> [Int] -> [Int]
merge []     ys           = ys
merge xs     []           = xs
merge xs@(x:xr) ys@(y:yr)
  | x < y                 = x : merge xr ys
  -- Fall-through guard, kept as a literal True as originally written.
  | True                  = y : merge xs yr

-- | Simple merge-sort implementation.
--
-- Lists with zero or one element are returned as they are. Longer lists are
-- split into two parts, each part is sorted recursively, and the two sorted
-- results are combined using 'merge'.
mergeSort :: [Int] -> [Int]
mergeSort []  = []
mergeSort [x] = [x]
mergeSort xs  = merge (mergeSort th) (mergeSort bh)
   where (th, bh) = halve xs ([], [])

         -- Deal the elements out into two accumulators, swapping the
         -- accumulators at every step so that their lengths never differ
         -- by more than one.
         halve :: [Int] -> ([Int], [Int]) -> ([Int], [Int])
         halve []     sofar    = sofar
         halve (a:as) (fs, ss) = halve as (ss, a:fs)

-----------------------------------------------------------------------------
-- * Proving correctness of sorting
-- ${props}
-----------------------------------------------------------------------------
{- $props
There are two main parts to proving that a sorting algorithm is correct:

       * Prove that the output is non-decreasing
 
       * Prove that the output is a permutation of the input
-}

-- | Check whether a given sequence is non-decreasing, i.e., whether each
-- element is less than or equal to the one that follows it. Empty and
-- singleton lists are trivially non-decreasing.
nonDecreasing :: [Int] -> Bool
nonDecreasing []       = True
nonDecreasing [_]      = True
nonDecreasing (a:b:xs) = a <= b && nonDecreasing (b:xs)

-- | Check whether two given sequences are permutations of each other. We simply
-- check that each sequence is contained in the other. The check is slightly
-- complicated by the need to account for possibly duplicated elements: every
-- element of one list has to be matched against a distinct, not yet used,
-- element of the other.
isPermutationOf :: [Int] -> [Int] -> Bool
isPermutationOf as bs = go as [(b, True) | b <- bs] && go bs [(a, True) | a <- as]
  where -- Match every element of the first list against a still-available
        -- entry of the second list; fails as soon as one element has no match.
        go :: [Int] -> [(Int, Bool)] -> Bool
        go []     _  = True
        go (x:xs) ys = found && go xs ys'
           where (found, ys') = mark x ys

        -- Go and mark off an instance of 'x' in the list, if possible. We keep track
        -- of unmarked elements by associating a boolean bit. Note that we have to
        -- keep the lists equal size for the recursive result to merge properly.
        mark :: Int -> [(Int, Bool)] -> (Bool, [(Int, Bool)])
        mark _ []            = (False, [])
        mark x ((y, v) : ys)
          | v && x == y      = (True, (y, not v) : ys)
          | True             = (r, (y, v) : ys')
          where (r, ys') = mark x ys

-----------------------------------------------------------------------------
-- * The correctness theorem
-----------------------------------------------------------------------------

-- | Asserting correctness of merge-sort for a list of the given size. Note that we can
-- only check correctness for fixed-size lists. Also, the proof will get more and more
-- complicated for the backend SMT solver as @n@ increases. Here we try it with 4.
--
-- The property states that the result of 'mergeSort' is 'nonDecreasing' and that
-- the input 'isPermutationOf' that result.
--
-- We have:
--
-- @
--   [SBV] tests/T48.hs:100:1-16 Proving "mergeSortCorrect", using Z3.
--   [Z3] Q.E.D.
-- @
#ifndef HADDOCK
{-# ANN mergeSortCorrect theorem { options = [ListSize 4] } #-}
#endif
mergeSortCorrect :: [Int] -> Bool
mergeSortCorrect xs = nonDecreasing ys && isPermutationOf xs ys
   where ys = mergeSort xs