Program to find the minimum tree sum from a list of leaves in Python

Suppose we have a list of numbers called nums. This list gives the leaf values of a binary tree in inorder traversal. Every internal node has exactly 2 children, and its value equals the product of the largest leaf value in its left subtree and the largest leaf value in its right subtree. Among all trees that can be built this way, we need to find the minimum possible sum of all node values (leaves and internal nodes).

For example, if the input is nums = [3, 5, 10], the output is 83.

Figure: the optimal tree for [3, 5, 10]. The root 50 (= 5 × 10) has two children: an internal node 15 (= 3 × 5) over the leaves 3 and 5, and the leaf 10. Tree sum = 3 + 5 + 10 + 15 + 50 = 83.
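To make the definition concrete, here is a small sketch that evaluates the total sum of one specific tree. It assumes a hypothetical representation in which a leaf is an int and an internal node is a (left, right) tuple. The helper name tree_sum is illustrative and not part of the original solution. For the leaves [3, 5, 10], it shows that grouping 3 and 5 first gives 83, while grouping 5 and 10 first gives 98.

```python
def tree_sum(node):
    """Return (largest leaf, sum of all node values) for a tree.

    Hypothetical representation: a leaf is an int, an internal node is a
    (left, right) tuple.
    """
    if isinstance(node, int):
        return node, node  # a leaf contributes its own value
    left_max, left_sum = tree_sum(node[0])
    right_max, right_sum = tree_sum(node[1])
    # Internal node value = largest leaf on the left * largest leaf on the right
    value = left_max * right_max
    return max(left_max, right_max), left_sum + right_sum + value


# The two possible trees whose inorder leaves are [3, 5, 10]
print(tree_sum(((3, 5), 10))[1])  # 3 + 5 + 10 + 15 + 50 = 83
print(tree_sum((3, (5, 10)))[1])  # 3 + 5 + 10 + 50 + 30 = 98
```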

Algorithm

To solve this problem, we follow these steps:

  • res := sum of all elements in nums
  • While size of nums > 1, do:
    • i := index of minimum element of nums
    • left := nums[i - 1] when i > 0, otherwise infinity
    • right := nums[i + 1] when i < size of nums - 1, otherwise infinity
    • res := res + (minimum of left and right) * ith element of nums, then delete ith element from nums
  • Return res

Example

Let's implement this algorithm to find the minimum tree sum:

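class Solution:
    def solve(self, nums):
        nums = list(nums)  # work on a copy so the caller's list is not modified
        res = sum(nums)  # every leaf contributes its own value
        while len(nums) > 1:
            # The smallest remaining leaf is merged first
            i = nums.index(min(nums))
            left = nums[i - 1] if i > 0 else float("inf")
            right = nums[i + 1] if i < len(nums) - 1 else float("inf")
            # Its parent costs (smaller neighbor) * (smallest leaf); then drop the leaf
            res += min(left, right) * nums.pop(i)

        return res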

# Test the solution
ob = Solution()
nums = [3, 5, 10]
print(f"Input: {nums}")
print(f"Minimum sum tree: {ob.solve(nums)}")
Output

Input: [3, 5, 10]
Minimum sum tree: 83

How It Works

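The algorithm repeatedly removes the smallest remaining leaf and adds the value of the internal node created as its parent. Since this leaf is no larger than its neighbors, it is never needed as the largest leaf of a bigger subtree, so it only affects that one parent node. The parent's other factor is the largest leaf of an adjacent block of leaves, which is at least the neighbor on that side. The cheapest possible parent therefore pairs the leaf directly with its smaller neighbor. After the merge, the new subtree's largest leaf is that neighbor, which is still in the list, so the remaining values are exactly what later merges need.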
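One way to gain confidence in the greedy choice is to compare it with an exhaustive interval dynamic program that tries every split point. The sketch below is such a cross-check, not the article's method. The names greedy_tree_sum and dp_tree_sum are hypothetical. The DP runs in O(n³) time, so it is only meant for small inputs.

```python
import random
from functools import lru_cache


def greedy_tree_sum(nums):
    # Same greedy as the article's solve(), working on a copy
    nums = list(nums)
    res = sum(nums)
    while len(nums) > 1:
        i = nums.index(min(nums))
        left = nums[i - 1] if i > 0 else float("inf")
        right = nums[i + 1] if i < len(nums) - 1 else float("inf")
        res += min(left, right) * nums.pop(i)
    return res


def dp_tree_sum(nums):
    # Exhaustive interval DP: best(i, j) is the minimum sum of internal
    # nodes for the leaves nums[i..j], trying every split point k.
    n = len(nums)

    @lru_cache(maxsize=None)
    def best(i, j):
        if i == j:
            return 0  # a single leaf has no internal nodes
        return min(
            best(i, k) + best(k + 1, j)
            + max(nums[i:k + 1]) * max(nums[k + 1:j + 1])
            for k in range(i, j)
        )

    return sum(nums) + best(0, n - 1)


# Compare both on small random inputs
random.seed(0)
for _ in range(200):
    sample = [random.randint(1, 20) for _ in range(random.randint(1, 8))]
    assert greedy_tree_sum(sample) == dp_tree_sum(sample), sample
print("greedy matches DP on 200 random inputs")
```

The DP explores every tree shape. Agreement on random inputs supports the greedy's correctness, but it does not prove it.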

Step-by-Step Execution

def solve_with_steps(nums):
    print(f"Initial nums: {nums}")
    res = sum(nums)
    print(f"Initial sum: {res}")
    
    step = 1
    while len(nums) > 1:
        print(f"\nStep {step}:")
        i = nums.index(min(nums))
        min_val = nums[i]
        left = nums[i - 1] if i > 0 else float("inf")
        right = nums[i + 1] if i < len(nums) - 1 else float("inf")
        
        cost = min(left, right) * min_val
        res += cost
        
        print(f"  Remove minimum: {min_val} at index {i}")
        print(f"  Left neighbor: {left}, Right neighbor: {right}")
        print(f"  Cost added: {min(left, right)} × {min_val} = {cost}")
        print(f"  Running sum: {res}")
        
        nums.pop(i)
        print(f"  Remaining nums: {nums}")
        step += 1
    
    return res

# Demonstrate with example
nums = [3, 5, 10]
result = solve_with_steps(nums.copy())
print(f"\nFinal result: {result}")
Output

Initial nums: [3, 5, 10]
Initial sum: 18

Step 1:
  Remove minimum: 3 at index 0
  Left neighbor: inf, Right neighbor: 5
  Cost added: 5 × 3 = 15
  Running sum: 33
  Remaining nums: [5, 10]

Step 2:
  Remove minimum: 5 at index 0
  Left neighbor: inf, Right neighbor: 10
  Cost added: 10 × 5 = 50
  Running sum: 83
  Remaining nums: [10]

Final result: 83

Conclusion

This greedy algorithm finds the minimum tree sum by always removing the smallest leaf first and pairing it with its smaller neighbor. The time complexity is O(n²): each of the n - 1 iterations scans the list for the minimum and pops an element, and both operations take O(n) time.
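The O(n²) cost can be reduced to O(n) with a monotonic decreasing stack, a common alternative for this problem that is not part of the original article. In the sketch below, the name min_tree_sum_stack is hypothetical. Each stack element is popped once the incoming value is at least as large. It is then charged against the smaller of the new stack top and the incoming value, which mirrors the greedy rule of pairing a local minimum with its smaller neighbor.

```python
def min_tree_sum_stack(nums):
    """Minimum tree sum (leaves + internal nodes) using a monotonic stack."""
    internal = 0
    # Sentinel so every real element has a larger value below it
    stack = [float("inf")]
    for value in nums:
        # Pop local minima: the top is <= both its left neighbor and value
        while stack[-1] <= value:
            smallest = stack.pop()
            internal += smallest * min(stack[-1], value)
        stack.append(value)
    # The remaining stack is strictly decreasing; merge from the right end
    while len(stack) > 2:
        internal += stack.pop() * stack[-1]
    return sum(nums) + internal


print(min_tree_sum_stack([3, 5, 10]))  # 83
```

Each element is pushed and popped at most once, so this runs in O(n) time with O(n) extra space for the stack.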

Updated on: 2026-03-25T12:42:21+05:30
