{-# LANGUAGE TypeFamilies            #-}
{-# LANGUAGE MultiParamTypeClasses   #-}
{-# LANGUAGE DataKinds               #-}
{-# LANGUAGE KindSignatures          #-}
{-# LANGUAGE TypeOperators           #-}
{-# LANGUAGE FlexibleInstances       #-}
{-# LANGUAGE FlexibleContexts        #-}
{-# LANGUAGE ScopedTypeVariables     #-}
{-# LANGUAGE TemplateHaskell         #-}
{-# LANGUAGE UndecidableInstances    #-}
{-# LANGUAGE TypeApplications        #-}
{-# LANGUAGE TypeInType              #-}
{-# LANGUAGE AllowAmbiguousTypes     #-}
{-# LANGUAGE NoImplicitPrelude       #-}
{-# LANGUAGE GADTs                   #-}
{-# LANGUAGE ConstraintKinds         #-}

{-# OPTIONS_GHC -fno-solve-constant-dicts #-} -- See https://ghc.haskell.org/trac/ghc/ticket/13943#comment:2

-----------------------------------------------------------------------------
-- |
-- Module      :  Data.List.Unrolled
-- Copyright   :  (C) 2017 Alexey Vagarenko
-- License     :  BSD-style (see LICENSE)
-- Maintainer  :  Alexey Vagarenko (vagarenko@gmail.com)
-- Stability   :  experimental
-- Portability :  non-portable
--
-- This module provides unrollable versions of functions on lists.
--
-- Classes in this module are assumed to be closed. You __should not__ create
-- new instances for them.
--
----------------------------------------------------------------------------

module Data.List.Unrolled (
      Append(..)
    , Drop(..)
    , Take(..)
    , splitAt
    , ChunksOf(..)
    , ChunksCount
    , Zip(..)
    , Zip3(..)
    , ZipWith(..)
    , Unzip(..)
    , Filter(..)
    , Map(..)
    , All(..)
    , Foldr(..)
    , Foldr1(..)
    , Foldl(..)
    , Foldl1(..)
    , foldMap
    , FoldMap
    , sum
    , Sum
    , Replicate(..)
    , EnumFromN(..)
    , EnumFromStepN(..)
) where

import Data.Type.Bool           (If)
import GHC.TypeLits             (Nat, type (+), type (-), type (<=?))

import Prelude  (Bool(..), otherwise, Num(..), error, Monoid(..), (.))

---------------------------------------------------------------------------------------------------
-- | Append two lists. Type param @l@ is the length of the left list.
class Append (n :: Nat) where
    append :: [a] -> [a] -> [a]

instance {-# OVERLAPPING #-} Append 0 where
    append _ ys = ys
    {-# INLINE append #-}

instance {-# OVERLAPPABLE #-} (Append (n - 1)) => Append n where
    append []       _  = error "append: Not enough elements in the list."
    append (x : xs) ys = x : append @(n - 1) xs ys
    {-# INLINE append #-}

---------------------------------------------------------------------------------------------------
-- | Drop @n@ elements from a list.
class Drop (n :: Nat) where
    drop :: [a] -> [a]

instance {-# OVERLAPPING #-} Drop 0 where
    drop xs = xs
    {-# INLINE drop #-}

instance {-# OVERLAPPABLE #-} (Drop (n - 1)) => Drop n where
    drop [] = error "drop: Not enough elements in the list."
    drop (_ : xs) = drop @(n - 1) xs
    {-# INLINE drop #-}

---------------------------------------------------------------------------------------------------
-- | Take @n@ elements from a list
class Take (n :: Nat) where
    take :: [a] -> [a]

instance {-# OVERLAPPING #-} Take 0 where
    take _ = []
    {-# INLINE take #-}

instance {-# OVERLAPPABLE #-} (Take (n - 1)) => Take n where
    take [] = error "take: Not enough elements in the list."
    take (x : xs) = x : take @(n - 1) xs
    {-# INLINE take #-}

---------------------------------------------------------------------------------------------------
-- | Split list at @n@-th element.
splitAt :: forall (n :: Nat) a. (Take n, Drop n) => [a] -> ([a], [a])
splitAt xs = (take @n xs, drop @n xs)

---------------------------------------------------------------------------------------------------
-- | Split list into chunks of the given length @c@. @n@ is length of the list.
class ChunksOf (n :: Nat) (c :: Nat) where
    chunksOf :: [a] -> [[a]]

instance {-# OVERLAPPING #-} ChunksOf 0 0 where
    chunksOf _ = []
    {-# INLINE chunksOf #-}

instance {-# OVERLAPPABLE #-} ChunksOf 0 c where
    chunksOf _ = []
    {-# INLINE chunksOf #-}

instance {-# OVERLAPPABLE #-} ChunksOf n 0 where
    chunksOf _ = []
    {-# INLINE chunksOf #-}

instance {-# OVERLAPPABLE #-} (Take c, Drop c, ChunksOf (n - 1) c) => ChunksOf n c where
    chunksOf xs =
        let (l, r) = splitAt @c xs
        in l : chunksOf @(n - 1) @c r
    {-# INLINE chunksOf #-}

-- | Number of resulting chunks when list of length @len@ split by chunks of length @clen@.
type family ChunksCount (len :: Nat) (clen :: Nat) where
    ChunksCount 0 _ = 0
    ChunksCount _ 0 = 0
    ChunksCount l c = If (l <=? c) 1 (1 + ChunksCount (l - c) c)

---------------------------------------------------------------------------------------------------
-- | Zip 2 lists together. Type param @n@ is the length of the first list.
class Zip (n :: Nat) where
    zip :: [a] -> [b] -> [(a, b)]

instance {-# OVERLAPPING #-} Zip 0 where
    zip _ _ = []
    {-# INLINE zip #-}

instance {-# OVERLAPPABLE #-} (Zip (n - 1)) => Zip n where
    zip (x : xs) (y : ys) = (x, y) : zip @(n - 1) xs ys
    zip (_ : _ ) []       = []
    zip []        _       = error "zip: Not enough elements in the first list."
    {-# INLINE zip #-}

---------------------------------------------------------------------------------------------------
-- | Zip 3 lists together. Type param @n@ is the length of the first list.
class Zip3 (n :: Nat) where
    zip3 :: [a] -> [b] -> [c] -> [(a, b, c)]

instance {-# OVERLAPPING #-} Zip3 0 where
    zip3 _ _ _ = []
    {-# INLINE zip3 #-}

instance {-# OVERLAPPABLE #-} (Zip3 (n - 1)) => Zip3 n where
    zip3 (x : xs) (y : ys) (z : zs) = (x, y, z) : zip3 @(n - 1) xs ys zs
    zip3 (_ : _ ) []       _        = []
    zip3 (_ : _ ) _        []       = []
    zip3 []       _        _        = error "zip3: Not enough elements in the first list."
    {-# INLINE zip3 #-}

---------------------------------------------------------------------------------------------------
-- | Unzip a list. Type param @n@ is the length of the list.
class Unzip (n :: Nat) where
    unzip :: [(a, b)] -> ([a], [b])

instance {-# OVERLAPPING #-} Unzip 0 where
    unzip _ = ([], [])
    {-# INLINE unzip #-}

instance {-# OVERLAPPABLE #-} (Unzip (n - 1)) => Unzip n where
    unzip []       = error "unzip: Not enough elements in the list."
    unzip (x : xs) = (\(a, b) (as, bs) -> (a : as, b : bs)) x (unzip @(n - 1) xs)
    {-# INLINE unzip #-}

---------------------------------------------------------------------------------------------------
-- | Zip 2 lists together using given function. Type param @n@ is the length of the first list.
class ZipWith (n :: Nat) where
    zipWith :: (a -> b -> c) -> [a] -> [b] -> [c]

instance {-# OVERLAPPING #-} ZipWith 0 where
    zipWith _ _ _ = []
    {-# INLINE zipWith #-}

instance {-# OVERLAPPABLE #-} (ZipWith (n - 1)) => ZipWith n where
    zipWith f (x : xs) (y : ys) = f x y : zipWith @(n - 1) f xs ys
    zipWith _ (_ : _ ) []       = []
    zipWith _ []       _        = error "zipWith: Not enough elements in the first list."
    {-# INLINE zipWith #-}

---------------------------------------------------------------------------------------------------
-- | Filter list with given predicate. Type param @n@ is the length of the list.
class Filter (n :: Nat) where
    filter :: (a -> Bool) -> [a] -> [a]

instance {-# OVERLAPPING #-} Filter 0 where
    filter _ _ = []
    {-# INLINE filter #-}

instance {-# OVERLAPPABLE #-} (Filter (n - 1)) => Filter n where
    filter _ []       = error "filter: Not enough elements in the list."
    filter f (x : xs)
        | f x       = x : filter @(n - 1) f xs
        | otherwise = filter @(n - 1) f xs
    {-# INLINE filter #-}

---------------------------------------------------------------------------------------------------
-- | Apply function to all elements of a list. Type param @n@ is the length of the list.
class Map (n :: Nat) where
    map :: (a -> b) -> [a] -> [b]

instance {-# OVERLAPPING #-} Map 0 where
    map _ _ = []
    {-# INLINE map #-}

instance {-# OVERLAPPABLE #-} (Map (n - 1)) => Map n where
    map _ []       = error "map: Not enough elements in the list."
    map f (x : xs) = f x : map @(n - 1) f xs
    {-# INLINE map #-}

---------------------------------------------------------------------------------------------------
-- | Check if all elements of the list satisfy the predicate. Type param @n@ is the length of the list.
class All (n :: Nat) where
    all :: (a -> Bool) -> [a] -> Bool

instance {-# OVERLAPPING #-} All 0 where
    all _ _ = True
    {-# INLINE all #-}

instance {-# OVERLAPPABLE #-} (All (n - 1)) => All n where
    all _ []        = error "all: Not enough elements in the list."
    all f (x : xs)
        | f x       = all @(n - 1) f xs
        | otherwise = False
    {-# INLINE all #-}

---------------------------------------------------------------------------------------------------
-- | Right fold of a list of length @n@.
class Foldr (n :: Nat) where
    foldr :: (a -> b -> b) -> b -> [a] -> b

instance {-# OVERLAPPING #-} Foldr 0 where
    foldr _ z _ = z
    {-# INLINE foldr #-}

instance {-# OVERLAPPABLE #-} (Foldr (n - 1)) => Foldr n where
    foldr _ _ []       = error "foldr: Not enough elements in the list."
    foldr f z (x : xs) = f x (foldr @(n - 1) f z xs)
    {-# INLINE foldr #-}

---------------------------------------------------------------------------------------------------
-- | Right fold of a list of length @n@ with no base element.
class Foldr1 (n :: Nat) where
    foldr1 :: (a -> a -> a) -> [a] -> a

instance {-# OVERLAPPING #-} Foldr1 1 where
    foldr1 _ []      = error "foldr1: Not enough elements in the list."
    foldr1 _ (x : _) = x
    {-# INLINE foldr1 #-}

instance {-# OVERLAPPABLE #-} (Foldr1 (n - 1)) => Foldr1 n where
    foldr1 _ []       = error "foldr1: Empty list."
    foldr1 f (x : xs) = f x (foldr1 @(n - 1) f xs)
    {-# INLINE foldr1 #-}

---------------------------------------------------------------------------------------------------
-- | Left fold of a list of length @n@.
class Foldl (n :: Nat) where
    foldl :: (b -> a -> b) -> b -> [a] -> b

instance {-# OVERLAPPING #-} Foldl 0 where
    foldl _ z _ = z
    {-# INLINE foldl #-}

instance {-# OVERLAPPABLE #-} (Foldl (n - 1)) => Foldl n where
    foldl _ _ []       = error "foldl: Not enough elements in the list."
    foldl f z (x : xs) = f (foldl @(n - 1) f z xs) x
    {-# INLINE foldl #-}

---------------------------------------------------------------------------------------------------
-- | Right fold of a list of length @n@ with no base element.
class Foldl1 (n :: Nat) where
    foldl1 :: (a -> a -> a) -> [a] -> a

instance {-# OVERLAPPING #-} Foldl1 1 where
    foldl1 _ []      = error "foldl1: Not enough elements in the list."
    foldl1 _ (x : _) = x
    {-# INLINE foldl1 #-}

instance {-# OVERLAPPABLE #-} (Foldl1 (n - 1)) => Foldl1 n where
    foldl1 _ []       = error "foldl1: Empty list."
    foldl1 f (x : xs) = f (foldl1 @(n - 1) f xs) x
    {-# INLINE foldl1 #-}

---------------------------------------------------------------------------------------------------
-- | Map each element of the list of length @n@ to a monoid, and combine the results.
foldMap :: forall (n :: Nat) m a.
           (FoldMap n m) =>
           (a -> m) -> [a] -> m
foldMap f = foldr @n (mappend . f) mempty
{-# INLINE foldMap #-}

-- | Constraint of the 'foldMap' function.
type FoldMap (n :: Nat) m = (Monoid m, Foldr n)

---------------------------------------------------------------------------------------------------
-- | Sum of the elements of the list of length @n@.
sum :: forall (n :: Nat) a.
       (Sum n a) =>
       [a] -> a
sum = foldr @n (+) 0
{-# INLINE sum #-}

-- | Constraint of the 'sum' function.
type Sum (n :: Nat) a = (Foldr n, Num a)

---------------------------------------------------------------------------------------------------
-- | Fill the list of length @n@ with the same values.
class Replicate (n :: Nat) where
    replicate :: a -> [a]

instance {-# OVERLAPPING #-} Replicate 0 where
    replicate _ = []
    {-# INLINE replicate #-}

instance {-# OVERLAPPABLE #-} (Replicate (n - 1)) => Replicate n where
    replicate a = a : replicate @(n - 1) a
    {-# INLINE replicate #-}

---------------------------------------------------------------------------------------------------
-- | Enumeration of length @n@ starting from given value.
class EnumFromN (n :: Nat) where
    enumFromN :: (Num a)
        => a        -- ^ Starting value.
        -> [a]

instance {-# OVERLAPPING #-} EnumFromN 0 where
    enumFromN _ = []
    {-# INLINE enumFromN #-}

instance {-# OVERLAPPABLE #-} (EnumFromN (n - 1)) => EnumFromN n where
    enumFromN a = a : enumFromN @(n - 1) (a + 1)
    {-# INLINE enumFromN #-}

---------------------------------------------------------------------------------------------------
-- | Enumeration of length @n@ starting from given value with given step.
class EnumFromStepN (n :: Nat) where
    enumFromStepN :: (Num a)
        => a        -- ^ Starting value.
        -> a        -- ^ Step.
        -> [a]

instance {-# OVERLAPPING #-} EnumFromStepN 0 where
    enumFromStepN _ _ = []
    {-# INLINE enumFromStepN #-}

instance {-# OVERLAPPABLE #-} (EnumFromStepN (n - 1)) => EnumFromStepN n where
    enumFromStepN a s = a : enumFromStepN @(n - 1) (a + s) s
    {-# INLINE enumFromStepN #-}