diff --git a/monads/__init__.py b/monads/__init__.py
index a1f7c9e..e5c7f92 100644
--- a/monads/__init__.py
+++ b/monads/__init__.py
@@ -2,6 +2,7 @@ from .functor import Functor
 from .applicative import Applicative
 from .monad import Monad
 from .list import List
+from .set import Set
 from .maybe import Maybe, Just, Nothing
 from .result import Result, Ok, Err
 from .future import Future
diff --git a/monads/set.py b/monads/set.py
new file mode 100644
index 0000000..11c3e5a
--- /dev/null
+++ b/monads/set.py
@@ -0,0 +1,99 @@
+from __future__ import annotations
+from functools import reduce
+from itertools import chain
+from monads import functor, List
+from typing import (
+    Callable,
+    Iterable,
+    Iterator,
+    Set as _Set,
+    List as _List,
+    Optional,
+    TypeVar,
+    Union,
+    cast,
+)
+
+from .monad import Monad
+from .monoid import Monoidal
+from .currying import CurriedBinary, uncurry
+
+T = TypeVar("T")
+S = TypeVar("S")
+
+
+class Set(Monad[T], Monoidal[set]):
+    @classmethod
+    def pure(cls, value: T) -> Set[T]:
+        t = set()
+        t.add(value)
+        return Set(t)
+
+    def bind(self, function: Callable[[T], Set[S]]) -> Set[S]:
+        return reduce(Set.mappend, map(function, self.value), Set.mzero())
+
+    def map(self, function: Callable[[T], S]) -> Set[S]:
+        return Set(set(map(function, self.value)))
+
+    def apply(self, functor: Set[Callable[[T], S]]) -> Set[S]:
+
+        return Set(
+            set(chain.from_iterable([map(f, self.value) for f in functor.value]))
+        )
+
+    @classmethod
+    def mzero(cls) -> Set[T]:
+        return cls(set())
+
+    @classmethod
+    def sequence(cls, xs: Iterable[Set[T]]) -> Set[_List[T]]:
+        """Evaluate monadic actions in sequence, collecting results."""
+
+        def mcons(acc: Set[_Set[T]], x: Set[T]) -> Set[_Set[T]]:
+            return acc.bind(lambda acc_: x.map(lambda x_: acc_.union(set([x_]))))
+
+        empty: Set[_Set[T]] = Set.pure(set())
+        return Set(set(reduce(mcons, xs, empty)))  # type: ignore
+
+    def flatten(self) -> Set[T]:
+        def flat(acc: Set[T], element: T) -> Set[T]:
+            if element and isinstance(element, Iterable):
+                for k in element:
+                    acc = acc.mappend(Set(set([k])))
+            elif element:
+                acc = acc.mappend(Set(set([element])))
+            return acc
+
+        return Set(reduce(flat, self, Set.mzero()))  # type: ignore
+
+    def sort(self, key: Optional[str] = None, reverse: bool = False) -> Set[T]:
+        lst_copy = self.value.copy()
+        lst_copy.sort(key=key, reverse=reverse)  # type: ignore
+        return Set(lst_copy)
+
+    def fold(
+        self, func: Union[Callable[[S, T], S], CurriedBinary[S, T, S]], base_val: S
+    ) -> S:
+        if isinstance(func, CurriedBinary):
+            functor = uncurry(cast(CurriedBinary, func))
+        else:
+            functor = func
+        return reduce(functor, self.value, base_val)  # type: ignore
+
+    __and__ = lambda other, self: Set.apply(self, other)  # type: ignore
+
+    def mappend(self, other: Set[T]) -> Set[T]:
+        return Set(self.value.union(other.value))
+
+    __add__ = mappend
+    __mul__ = __rmul__ = map
+    __rshift__ = bind
+
+    def __sizeof__(self) -> int:
+        return self.value.__sizeof__()
+
+    def __len__(self) -> int:
+        return len(set(self.value))
+
+    def __iter__(self) -> Iterator[T]:
+        return iter(self.value)
diff --git a/tests/fixtures.py b/tests/fixtures.py
index 79936cf..4caf50c 100644
--- a/tests/fixtures.py
+++ b/tests/fixtures.py
@@ -1,8 +1,8 @@
 import pytest  # type: ignore
 from typing import Type
-from monads import Maybe, List, Result
+from monads import Maybe, List, Result, Set
 
 
-@pytest.fixture(scope="module", params=[Maybe, List, Result])
+@pytest.fixture(scope="module", params=[Maybe, List, Result, Set])
 def monad(request) -> Type:
     return request.param