{-# Language RankNTypes #-}
{-# Language StandaloneDeriving #-}

module RevdepScanner.Types.HashSet.NonEmpty
    ( NonEmptyHashSet
    , fromHashSet
    , toHashSet
    , singleton
    , insert
    , union
    , fromList
    ) where

import Data.Foldable (foldl')
import Data.Hashable (Hashable, Hashed, hashed, unhashed)
import           Data.HashSet (HashSet)
import qualified Data.HashSet as S

-- | A 'HashSet', but it must contain at least one element.
--
-- Representation: one distinguished element (stored pre-hashed as 'Hashed')
-- plus a 'HashSet' holding the other elements. The full set of elements is
-- the distinguished element inserted into that 'HashSet' (see 'toHashSet').
--
-- Intended invariant (preserved by 'singleton' and 'insert'): the
-- distinguished element does not also appear in the 'HashSet' field.
--
-- The 'Hashable' context on the constructor is captured in the value, so
-- pattern matching on 'NEHashSet' brings @Hashable a@ (and therefore @Eq a@)
-- into scope. That is why functions such as 'toHashSet' need no constraint.
data NonEmptyHashSet a
    = Hashable a => NEHashSet (Hashed a) (HashSet a)

-- Show is derived, so it displays the internal representation
-- (distinguished element and remainder), not a canonical set listing.
deriving instance Show a => Show (NonEmptyHashSet a)

-- | Set equality: two values are equal when they contain the same elements,
-- no matter which element is stored as the distinguished one. A derived
-- instance would compare the representation and report e.g.
-- @fromList [1,2] /= fromList [2,1]@.
instance Eq (NonEmptyHashSet a) where
    -- Matching the constructor brings the captured @Hashable a@ (and thus
    -- @Eq a@) into scope, which @Eq (HashSet a)@ needs.
    nes1@(NEHashSet _ _) == nes2 = toHashSet nes1 == toHashSet nes2

-- | Ordering of the underlying 'HashSet's, kept consistent with the 'Eq'
-- instance above (@compare a b == EQ@ exactly when @a == b@).
instance Ord a => Ord (NonEmptyHashSet a) where
    compare nes1 nes2 = compare (toHashSet nes1) (toHashSet nes2)

-- | Folds over every element of the set, including the distinguished one,
-- by delegating to the 'Foldable' instance of the full 'HashSet'.
instance Foldable NonEmptyHashSet where
    foldMap f nes = foldMap f (toHashSet nes)

-- | Combining two non-empty sets is their 'union'.
instance Semigroup (NonEmptyHashSet a) where
    left <> right = left `union` right

-- | Convert a 'HashSet' into a 'NonEmptyHashSet', or 'Nothing' if it is empty.
--
-- The first element produced by 'S.toList' becomes the distinguished element
-- and is deleted from the existing set. This reuses the input set instead of
-- re-hashing and re-inserting every element, and yields the same distinguished
-- element and contents as @fromList . S.toList@. Only the head of the lazy
-- 'S.toList' result is demanded.
fromHashSet :: Hashable a => HashSet a -> Maybe (NonEmptyHashSet a)
fromHashSet s = case S.toList s of
    [] -> Nothing
    (x:_) -> Just (NEHashSet (hashed x) (S.delete x s))

-- | Forget non-emptiness: the plain 'HashSet' of all elements, with the
-- distinguished element inserted back into the remainder. @O(log n)@
toHashSet :: NonEmptyHashSet a -> HashSet a
toHashSet (NEHashSet front rest) = unhashed front `S.insert` rest

-- | A set containing exactly the given element, which becomes the
-- distinguished element; the remainder is empty. @O(1)@
singleton :: Hashable a => a -> NonEmptyHashSet a
singleton x = NEHashSet front S.empty
  where
    front = hashed x

-- | Add an element to the set. @O(log n)@
--
-- If the element equals the distinguished element, the set is returned
-- unchanged, so the distinguished element is never also added to the
-- remainder. Otherwise the element goes into the remainder and the
-- distinguished element is kept.
insert :: Hashable a => a -> NonEmptyHashSet a -> NonEmptyHashSet a
insert new whole@(NEHashSet front rest) =
    if hashed new == front
        then whole
        else NEHashSet front (S.insert new rest)

-- | Union of two non-empty sets. The result keeps the first argument's
-- distinguished element. @O(n+m)@
--
-- The constructor is not exported, and every function in this module keeps
-- the distinguished element out of the remainder; this function does so too.
union :: NonEmptyHashSet a -> NonEmptyHashSet a -> NonEmptyHashSet a
union (NEHashSet x1 s1) nes2@(NEHashSet x2 s2)
    -- Same distinguished element: neither remainder contains it, so the union
    -- of the remainders is already the correct remainder.
    | x1 == x2 = NEHashSet x1 (s1 <> s2)
    -- Different distinguished elements: x2 lives outside s2, so take the full
    -- second set (toHashSet re-inserts x2). x1 may occur in the second set,
    -- so delete it from the new remainder to avoid storing it twice.
    | otherwise =
        NEHashSet x1 (S.delete (unhashed x1) (s1 <> toHashSet nes2))

-- | Build a set from a list, or 'Nothing' if the list is empty. The head of
-- the list becomes the distinguished element; the remaining elements are
-- added one at a time with 'insert' using a strict left fold. @O(n)@
fromList :: Hashable a => [a] -> Maybe (NonEmptyHashSet a)
fromList xs = case xs of
    [] -> Nothing
    (first:others) ->
        Just (foldl' (\acc y -> insert y acc) (singleton first) others)
