# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution:
    """Solver for LeetCode 337, "House Robber III"."""

    def rob(self, root: "TreeNode") -> int:
        """Return the largest sum of node values with no parent-child pair both taken.

        For every node two totals are tracked for its subtree: ``take`` (the
        node itself is robbed, so both children must be skipped) and ``skip``
        (the node is left alone, so each child independently picks its better
        option). The tree is walked in iterative post-order so degenerate
        (list-shaped) trees deeper than Python's recursion limit still work.

        Args:
            root: Root of the binary tree, or ``None`` for an empty tree. Nodes
                must expose ``val``, ``left`` and ``right`` attributes.

        Returns:
            The maximum achievable total; ``0`` for an empty tree. Because
            skipping every node is always allowed, the result is never below 0.
        """
        # id(node) -> (take, skip). Keyed by identity so nodes need not be
        # hashable; each entry is popped by its parent once consumed, keeping
        # the dict small.
        best = {}
        stack = [(root, False)]
        while stack:
            node, children_done = stack.pop()
            if node is None:
                continue
            if not children_done:
                # Revisit this node only after both children have been solved.
                stack.append((node, True))
                stack.append((node.right, False))
                stack.append((node.left, False))
                continue
            # A missing child is never stored under id(None), so it
            # contributes (0, 0).
            left_take, left_skip = best.pop(id(node.left), (0, 0))
            right_take, right_skip = best.pop(id(node.right), (0, 0))
            # Rob this node: both children must be skipped.
            take = node.val + left_skip + right_skip
            # Skip this node: each child is free to take its better option.
            skip = max(left_take, left_skip) + max(right_take, right_skip)
            best[id(node)] = (take, skip)
        return max(best.get(id(root), (0, 0)))