One more fairly standard tech interview problem (for better or for worse, you’re likely to see one of these if you go for a programming job):
Given a linked list and an element `x`, partition the list so that all elements less than `x` are before elements greater than or equal to `x`, but do not otherwise change the order of the elements.
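To pin down exactly what "do not otherwise change the order" means, the expected result can be written with plain Python lists as a stable partition. This is a sketch with a hypothetical helper, `stable_partition_values`, which is not part of the problem: it builds new lists instead of relinking nodes, but it makes a handy oracle for checking a linked-list version.

```python
from typing import Iterable, List


def stable_partition_values(values: Iterable[int], x: int) -> List[int]:
    """Hypothetical reference helper: every value < x, then every value >= x.

    Both groups keep their original relative order, which is what
    "do not otherwise change the order" asks for.
    """
    items = list(values)  # materialize so the input can be scanned twice
    return [v for v in items if v < x] + [v for v in items if v >= x]


print(stable_partition_values([1, 4, 3, 2, 5, 2], 3))  # [1, 2, 2, 4, 3, 5]
```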
Interviewers really do love their linked lists. And it does make sense, since they can be a very efficient data structure for certain kinds of problems, such as moving elements around and inserting them in the middle, where more ‘traditional’ straight arrays would have problems (to insert in the middle of an array, you have to shift every element after the insertion point).
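Here's what that difference looks like, as a small sketch. It uses a minimal dataclass node (the same shape as the `ListNode` built below) and a hypothetical `insert_after` helper. The O(1) splice assumes you already hold a reference to the node you're inserting after; finding that node in the first place still means walking the list.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ListNode:
    val: int = 0
    next: Optional["ListNode"] = None


def insert_after(node: ListNode, val: int) -> ListNode:
    """Hypothetical helper: splice a new node in directly after `node`.

    Only two pointers change, no matter how long the list is (O(1)).
    """
    new_node = ListNode(val, node.next)  # new node points at the old successor
    node.next = new_node                 # predecessor now points at the new node
    return new_node


def to_values(head: Optional[ListNode]) -> List[int]:
    """Hypothetical helper: walk the list and collect its values."""
    out = []
    while head is not None:
        out.append(head.val)
        head = head.next
    return out


head = ListNode(1, ListNode(2, ListNode(3)))
insert_after(head, 99)
print(to_values(head))  # [1, 99, 2, 3]

# The array equivalent: list.insert shifts 2 and 3 one slot to the right
arr = [1, 2, 3]
arr.insert(1, 99)
print(arr)  # [1, 99, 2, 3]
```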
Python doesn’t have a built-in linked list class (its `list` type is really a dynamic array under the hood), but it’s easy enough to make one, especially with dataclasses. Those are wonderful for cutting boilerplate: decorate the class with `@dataclasses.dataclass` and you get a constructor (`__init__`), stringification (`__repr__`), and equality comparison (`__eq__`) for free! I added a method to turn a regular Python list into a linked list, but that’s it:
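```python
from dataclasses import dataclass
from typing import Iterable, Optional


@dataclass
class ListNode:
    val: int = 0
    next: Optional['ListNode'] = None

    @staticmethod
    def from_iterable(ls: Iterable[int]) -> Optional['ListNode']:
        # Build the linked list front to back, linking each new node to the previous one
        head = None
        previous = None
        for el in ls:
            current = ListNode(el)
            if previous:
                previous.next = current
            else:
                head = current
            previous = current
        return head
```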
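To see what the decorator actually buys you, here's a quick check of the generated `__repr__` and `__eq__`. The class is repeated so the snippet runs on its own, and `to_list` is a hypothetical convenience helper, not something `dataclasses` provides.

```python
from dataclasses import dataclass
from typing import Iterable, List, Optional


@dataclass
class ListNode:
    val: int = 0
    next: Optional["ListNode"] = None

    @staticmethod
    def from_iterable(ls: Iterable[int]) -> Optional["ListNode"]:
        head = None
        previous = None
        for el in ls:
            current = ListNode(el)
            if previous is not None:
                previous.next = current
            else:
                head = current
            previous = current
        return head


def to_list(head: Optional[ListNode]) -> List[int]:
    """Hypothetical inverse of from_iterable: collect node values into a list."""
    values = []
    while head is not None:
        values.append(head.val)
        head = head.next
    return values


linked = ListNode.from_iterable([1, 2])
print(linked)                                    # ListNode(val=1, next=ListNode(val=2, next=None))
print(linked == ListNode.from_iterable([1, 2]))  # True: generated __eq__ compares val and next
print(linked == ListNode.from_iterable([1, 3]))  # False
print(to_list(linked))                           # [1, 2]
```

One caveat worth knowing: the generated `__repr__` and `__eq__` recurse through `next`, so on very long lists they can run into Python's recursion limit. For interview-sized inputs that isn't a concern.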
Now for the actual solution, my goal will be essentially to make 5 different pointers into the list and iterate along them at different speeds:
- `current` will store where we are in our progress through the list
- `lt_head` will store the first node in the eventual less-than half of the list; this will become the new head of the list
- `lt_current` will store the current tail of the less-than half of the list; add each ‘lt’ element here and advance it
- `gte_head` will store the head of the greater-than-or-equal half; this will get attached to the last `lt_current`
- `gte_current` will be the tail of the second half of the list; add each ‘gte’ element here
And that’s pretty much the entire algorithm:
- For each element:
  - If it’s less than the pivot point, make it `lt_current`’s `next` and move that pointer forward
  - Otherwise, do the same for `gte_current`
- Store the first `lt` node in `lt_head` and the first `gte` node in `gte_head`
- At the end, cut off the `gte` tail (set `gte_current.next` to `None`) and attach `gte_head` after `lt_current`
That helps us deal with the corner cases:
- If there are no elements less than the pivot, the final head node will be `gte_head` (since both `lt_head` and `lt_current` will be `None`). The reverse case, where every element is less than the pivot, needs no special handling: you can just join the two halves anyway, since `gte_head` will be `None`, and `None` in `.next` means the same thing as no next node.
Code!
```python
from typing import Optional


def partition(head: Optional[ListNode], x: int) -> Optional[ListNode]:
    current = head

    # Create two linked lists: lt (less than) and gte (greater than or equal)
    # The head is the beginning of each new linked list; the current node is the tail we're building
    lt_head = None
    lt_current = None
    gte_head = None
    gte_current = None

    # Iterate through the original list
    while current:
        # Add to the lt partition:
        # if we already have a tail, connect it to this node;
        # if we don't, this is the new head
        if current.val < x:
            if lt_current:
                lt_current.next = current
            lt_current = current
            if not lt_head:
                lt_head = lt_current
        # Otherwise, add it to the gte partition with the same conditions
        else:
            if gte_current:
                gte_current.next = current
            gte_current = current
            if not gte_head:
                gte_head = gte_current

        # Advance the main iteration
        current = current.next

    # If gte_current wasn't the last node of the original list, its next still
    # points at some later (lt) node; it is the tail now, so cut it off
    if gte_current:
        gte_current.next = None

    # So long as at least one element is in the lt partition, connect the two
    if lt_current:
        lt_current.next = gte_head

    # Return the lt partition's head if it exists; otherwise there are only gte nodes
    return lt_head or gte_head
```
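To watch the pointers move at their different speeds, here's a replay of the same loop on the first test input that records both chains after each node. It's a sketch: `trace_partition` and `chain_values` are hypothetical helpers, and it skips the final splice so the two halves stay visible separately. `current` moves on every step, while `lt_current` and `gte_current` only advance when a node of their kind shows up.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class ListNode:
    val: int = 0
    next: Optional["ListNode"] = None


def chain_values(head: Optional[ListNode], tail: Optional[ListNode]) -> List[int]:
    """Hypothetical helper: values from head up to and including tail.

    Mid-loop, a tail's next pointer may still point into the original list,
    so stop at the tail instead of walking until None.
    """
    values = []
    node = head
    while node is not None:
        values.append(node.val)
        if node is tail:
            break
        node = node.next
    return values


def trace_partition(values: List[int], x: int) -> List[Tuple[int, List[int], List[int]]]:
    """Hypothetical replay of the partition loop, recording both chains after each node."""
    head = None
    for v in reversed(values):  # build the input list back to front
        head = ListNode(v, head)

    lt_head = lt_current = gte_head = gte_current = None
    steps = []
    current = head
    while current:
        if current.val < x:
            if lt_current:
                lt_current.next = current
            lt_current = current
            if not lt_head:
                lt_head = lt_current
        else:
            if gte_current:
                gte_current.next = current
            gte_current = current
            if not gte_head:
                gte_head = gte_current
        steps.append((current.val,
                      chain_values(lt_head, lt_current),
                      chain_values(gte_head, gte_current)))
        current = current.next
    return steps


for val, lt, gte in trace_partition([1, 4, 3, 2, 5, 2], 3):
    print(f"after {val}: lt={lt} gte={gte}")
# after 1: lt=[1] gte=[]
# after 4: lt=[1] gte=[4]
# after 3: lt=[1] gte=[4, 3]
# after 2: lt=[1, 2] gte=[4, 3]
# after 5: lt=[1, 2] gte=[4, 3, 5]
# after 2: lt=[1, 2, 2] gte=[4, 3, 5]
```

Concatenating the two chains from the last line gives `[1, 2, 2, 4, 3, 5]`, which is exactly what the final splice produces.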
I thought that was a pretty cool algorithm.
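A common variant of the same idea (not the version above) uses two sentinel, or dummy, nodes as fixed starting points for the two halves. Then every node can be appended unconditionally, and both the head checks and the `lt_head or gte_head` fallback go away. A minimal sketch, with `partition_with_sentinels`, `from_values`, and `to_values` as hypothetical names:

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ListNode:
    val: int = 0
    next: Optional["ListNode"] = None


def partition_with_sentinels(head: Optional[ListNode], x: int) -> Optional[ListNode]:
    """Variant of partition using two dummy nodes, so no head bookkeeping is needed."""
    lt_dummy = ListNode()   # placeholder; its val is never read
    gte_dummy = ListNode()  # placeholder; its val is never read
    lt_tail, gte_tail = lt_dummy, gte_dummy

    current = head
    while current:
        if current.val < x:
            lt_tail.next = current
            lt_tail = current
        else:
            gte_tail.next = current
            gte_tail = current
        current = current.next

    gte_tail.next = None            # the last gte node may still point at a later lt node
    lt_tail.next = gte_dummy.next   # splice; works even if either half is empty
    return lt_dummy.next            # skip the placeholder


def from_values(values: List[int]) -> Optional[ListNode]:
    """Hypothetical helper: build a linked list from a Python list."""
    head = None
    for v in reversed(values):
        head = ListNode(v, head)
    return head


def to_values(head: Optional[ListNode]) -> List[int]:
    """Hypothetical helper: collect a linked list's values."""
    out = []
    while head is not None:
        out.append(head.val)
        head = head.next
    return out


print(to_values(partition_with_sentinels(from_values([1, 4, 3, 2, 5, 2]), 3)))  # [1, 2, 2, 4, 3, 5]
print(to_values(partition_with_sentinels(from_values([]), 3)))                  # []
```

The trade-off is two throwaway node allocations in exchange for simpler control flow; both versions relink the existing nodes in a single pass.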
Some test cases:
```python
import unittest


class TestSolution(unittest.TestCase):
    def test_1(self):
        self.assertEqual(
            partition(ListNode.from_iterable([1,4,3,2,5,2]), 3),
            ListNode.from_iterable([1,2,2,4,3,5])
        )

    def test_2(self):
        self.assertEqual(
            partition(ListNode.from_iterable([2,1]), 2),
            ListNode.from_iterable([1, 2])
        )

    def test_3(self):
        self.assertEqual(
            partition(ListNode.from_iterable([1,4,3,2,5,2]), 3),
            ListNode.from_iterable([1, 2, 2, 4, 3, 5])
        )


if __name__ == '__main__':
    unittest.main()
```
Quick:
```text
...
----------------------------------------------------------------------
Ran 3 tests in 0.000s

OK
```
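Those three tests only cover the happy path (and `test_1` and `test_3` happen to check the same input). A small randomized check against the plain-list definition also exercises the corner cases: an empty list, every value below the pivot, and every value at or above it. This is a sketch using only the standard library's `random`; `from_values`, `to_values`, and `check_against_spec` are hypothetical helpers, and `partition` is a condensed copy of the function above so the block runs on its own.

```python
import random
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ListNode:
    val: int = 0
    next: Optional["ListNode"] = None


def from_values(values: List[int]) -> Optional[ListNode]:
    """Hypothetical helper: build a linked list from a Python list."""
    head = None
    for v in reversed(values):
        head = ListNode(v, head)
    return head


def to_values(head: Optional[ListNode]) -> List[int]:
    """Hypothetical helper: collect a linked list's values."""
    out = []
    while head is not None:
        out.append(head.val)
        head = head.next
    return out


def partition(head: Optional[ListNode], x: int) -> Optional[ListNode]:
    # Condensed copy of the partition function above (same pointer logic)
    current = head
    lt_head = lt_current = gte_head = gte_current = None
    while current:
        if current.val < x:
            if lt_current:
                lt_current.next = current
            lt_current = current
            if not lt_head:
                lt_head = lt_current
        else:
            if gte_current:
                gte_current.next = current
            gte_current = current
            if not gte_head:
                gte_head = gte_current
        current = current.next
    if gte_current:
        gte_current.next = None
    if lt_current:
        lt_current.next = gte_head
    return lt_head or gte_head


def check_against_spec(trials: int = 1000, seed: int = 0) -> int:
    """Compare partition with the plain-list definition on random inputs."""
    rng = random.Random(seed)  # fixed seed so any failure is reproducible
    for _ in range(trials):
        values = [rng.randint(0, 9) for _ in range(rng.randint(0, 8))]  # length 0 covers empty
        x = rng.randint(-1, 10)  # -1: everything is gte; 10: everything is lt
        expected = [v for v in values if v < x] + [v for v in values if v >= x]
        actual = to_values(partition(from_values(values), x))
        assert actual == expected, (values, x, actual)
    return trials


print(f"{check_against_spec()} random cases passed")
```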