
Climbing trees 2: implementing decision trees


This is the second in a series of posts about decision trees in the context of machine learning. In this post, we'll implement classification and regression trees (CART) in Python. If you haven't already, consider reading the first part of this series: it discusses fundamental concepts that are required to understand the implementation we'll build.



If things become too complicated, try to read the provided references. I've drawn upon various sources instrumental to my understanding of decision trees, including books, documentation, articles, blog posts and lectures. Even if you understand everything, check the references: there is great content there.

The code snippets below sometimes present partial implementations for brevity or to avoid repetition. Complete code is available at the climbing trees repository.

Defining our building blocks

We will use Python's standard library plus numpy and pandas as the only external dependencies. This seems like a reasonable compromise between implementing everything from scratch and delegating some parts to external libraries (in our case, mathematical operations and table-like data structures).

Nodes

First things first: let's define nodes and leaf nodes.

from __future__ import annotations

from dataclasses import dataclass

import numpy as np


@dataclass
class LeafNode:
    value: np.ndarray


# A split value is a float for numerical features
# or a set of levels for categorical features.
SplitValue = float | set


@dataclass
class Node:
    feature_idx: int
    split_value: SplitValue
    left: Node | LeafNode
    right: Node | LeafNode

Leaf nodes only need to store a value, which is a numpy array. In classification settings the value is an array of class probabilities; in regression, it is a single-element array holding a constant.

Internal nodes (referred to simply as nodes from here on) need to store extra information: the index of the feature chosen to split the node, the split value, the left child node, and the right child node. The split value can be either a float (for numerical features) or a set (for categorical features). We'll first implement the numerical case and later extend the implementation to handle categorical features.
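To make the structure concrete, here's a tiny hand-built tree using these dataclasses (the feature index, threshold and leaf values are made up for illustration):

# A hand-built stump, just to illustrate the data structures above.
toy_tree = Node(
    feature_idx=0,
    split_value=2.5,  # numerical rule: samples with x[0] <= 2.5 go left
    left=LeafNode(value=np.array([0.8, 0.2])),   # class probabilities
    right=LeafNode(value=np.array([0.1, 0.9])),
)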

Objective functions

Let's implement the three objective functions discussed previously: Gini impurity, entropy and squared loss.

def entropy(prob):
    prob = prob[prob > 0]
    return np.sum(-prob * np.log2(prob))

def gini_impurity(prob):
    return 1 - np.sum(prob**2)

Here, both functions take an array of probabilities as input and return the objective value. Notice that we drop zero probabilities when calculating entropy: their contribution is zero anyway due to the multiplicative term (-prob), but taking the log of 0 causes numerical instability.
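As a quick sanity check (with made-up probabilities), a uniform two-class distribution maximizes both measures, while a pure node yields zero:

probs = np.array([0.5, 0.5])
print(entropy(probs))         # 1.0 -- maximum for two classes
print(gini_impurity(probs))   # 0.5 -- maximum for two classes

pure = np.array([1.0, 0.0])
print(entropy(pure))          # -0.0 (i.e. zero: a pure node has no impurity)
print(gini_impurity(pure))    # 0.0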

While evaluating a split, we only have the labels of the samples in each candidate node, so it's convenient to have functions that accept labels as input. It's also convenient to add support for sample weights (spoiler alert: they'll come in handy in the future). Sample weights redistribute the importance of each sample when computing the objective function. Duplicating some samples has the same effect as setting their sample weight to 2. That is, samples with higher weight are given more importance when finding splits.

def _class_probabilities(labels, sample_weights=None):
    if sample_weights is None:
        return np.mean(labels, axis=0)

    sample_weights = sample_weights.reshape((-1, 1))
    return (sample_weights * labels).sum(axis=0) / np.sum(sample_weights)

The default case (no sample weights provided) assumes uniform sample weights, which reduces to a simple average. Now we write two functions that compute the objective functions given labels as inputs:
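A quick illustration of the sample weight claim, using made-up one-hot labels: duplicating a sample and giving it a weight of 2 produce the same class probabilities.

labels = np.array([[1, 0], [0, 1], [0, 1]])

# Duplicate the first sample...
duplicated = np.vstack([labels, labels[:1]])
print(_class_probabilities(duplicated))       # [0.5 0.5]

# ...or give it a sample weight of 2: the result is identical.
weights = np.array([2.0, 1.0, 1.0])
print(_class_probabilities(labels, weights))  # [0.5 0.5]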

def entropy_criterion(labels, sample_weights=None):
    return entropy(_class_probabilities(labels, sample_weights))


def gini_criterion(labels, sample_weights=None):
    return gini_impurity(_class_probabilities(labels, sample_weights))

In regression settings there is no need to compute class probabilities, therefore the objective function can be directly implemented as follows.

def squared_loss_criterion(y, sample_weights=None):
    if sample_weights is None:
        sample_weights = np.ones_like(y) / y.shape[0]

    # The node prediction is the weighted mean of the outcomes.
    value = (sample_weights * y).sum() / sample_weights.sum()
    # Weighted mean of squared deviations from the node prediction.
    return np.average(np.power(y - value, 2), weights=sample_weights)

Finding the best split

Recapitulating the decision tree algorithm, we greedily search for the best possible split, recursively. To find the best possible split, we simply test all possible splits. For each feature, we sort the values as well as the outcome. Then, for each of the $N - 1$ possible splits, we calculate the weighted average criterion (the value of the objective function) of the child nodes.

class Split(NamedTuple):
    criterion: float
    feature_idx: int
    split_value: SplitValue
    left_index: np.ndarray
    right_index: np.ndarray
    left_value: np.ndarray
    right_value: np.ndarray


def _find_best_split(X, y, criterion_fn, sample_weights) -> Split | None:
    min_criterion = np.inf
    split = None

    for feat_idx in range(X.shape[1]):
        feature = X[:, feat_idx]
        sort_idx = np.argsort(feature)
        feature_sort = feature[sort_idx]
        y_sort = y[sort_idx]
        weights_sort = sample_weights[sort_idx]

        n_samples = len(sort_idx)
        for idx in range(1, n_samples):
            left = sort_idx[:idx]
            right = sort_idx[idx:]
            criterion_l = criterion_fn(y_sort[:idx], weights_sort[:idx])
            criterion_r = criterion_fn(y_sort[idx:], weights_sort[idx:])
            # Weight each child criterion by the fraction of samples it holds.
            p_l = idx / n_samples
            p_r = (n_samples - idx) / n_samples
            criterion = p_l * criterion_l + p_r * criterion_r
            if criterion < min_criterion:
                min_criterion = criterion
                split = Split(
                    criterion,
                    feat_idx,
                    # Threshold is the largest value on the left, so the rule
                    # "feature <= split_value" sends these samples left.
                    feature_sort[idx - 1],
                    left,
                    right,
                    np.mean(y_sort[:idx], axis=0),
                    np.mean(y_sort[idx:], axis=0),
                )

    return split

This is the core of the algorithm, so let's break this function down. First we define some variables: min_criterion is the lowest observed criterion value (initialized as infinity) and split is the best split so far (initialized as None). The split contains the following: the criterion value of the split, the index of the chosen feature, the split value, the indices of the samples that go to the left and right child nodes, and the values (predictions) of the left and right child nodes.

We loop over all features and, for each one, sort the feature and the outcome based on the selected feature's values. Then, we divide the samples of candidate child nodes around each split point (left and right) and compute the weighted average criterion. If the criterion value is the lowest observed so far, this is the current best split point. Notice that this works for both classification and regression. The regression case is straightforward: each node predicts the average outcome of its samples. We assign the mean outcome value of each child node to left_value and right_value. The classification case requires an assumption: the outcome variable y must be previously one-hot encoded. For instance, if we have three classes, instead of a vector of class labels this function receives a matrix with N rows and 3 columns. Each row has only one column set to 1 and the rest to 0; that is, each column represents one class. We can think of this one-hot encoded matrix as a multi-output regression problem: the model is trying to predict three continuous outcomes, it just happens that they encode class labels. Hence, the mean over all N samples of each column is equivalent to the proportion of that label in the node.
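To see why the column means of a one-hot encoded matrix are class proportions, consider this small made-up example:

# Three classes, five samples: labels 0, 2, 1, 0, 0
y_onehot = np.array([
    [1, 0, 0],
    [0, 0, 1],
    [0, 1, 0],
    [1, 0, 0],
    [1, 0, 0],
])
print(np.mean(y_onehot, axis=0))  # [0.6 0.2 0.2] -> proportions of classes 0, 1, 2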

The greedy algorithm

Now that we have a function to find the best split of a node, we can build the greedy algorithm. We start with a node containing all samples and a naive prediction (the average outcome). The child nodes are then split recursively. The first base case is when there is a single sample in the node: it cannot be split again.

def split_node(
    node: Node | LeafNode,
    X: pd.DataFrame,
    y: np.ndarray,
    value: np.ndarray,
    depth: int,
    criterion,
    sample_weights: np.ndarray,
) -> LeafNode | Node | None:
    if X.shape[0] <= 1:
        return LeafNode(value)

    split = _find_best_split(X, y, criterion, sample_weights)
    if split is None:
        return None

    X_left = X.iloc[split.left_index, :]
    X_right = X.iloc[split.right_index, :]
    y_left = y[split.left_index]
    y_right = y[split.right_index]

    node = Node(
        split.feature_idx,
        split.split_value,
        LeafNode(split.left_value),
        LeafNode(split.right_value),
    )

    left = split_node(
        node=node,
        X=X_left,
        y=y_left,
        value=split.left_value,
        depth=depth + 1,
        criterion=criterion,
        sample_weights=sample_weights[split.left_index],
    )
    right = split_node(
        node=node,
        X=X_right,
        y=y_right,
        value=split.right_value,
        depth=depth + 1,
        criterion=criterion,
        sample_weights=sample_weights[split.right_index],
    )

    if left is not None:
        node.left = left
    if right is not None:
        node.right = right

    return node

Recursive calls increase the depth counter by 1. We can add a max_depth stopping criterion. Similarly, we can check the number of samples in child nodes and enforce a minimum number of samples per leaf (min_samples_leaf) constraint. These constraints are crucial to limit tree size and reduce variance (overfitting).

def split_node(
    node: Node | LeafNode,
    X: pd.DataFrame,
    y: np.ndarray,
    value: np.ndarray,
    depth: int,
    criterion,
    sample_weights: np.ndarray,
    max_depth: int = 0,
    min_samples_leaf: int = 0,
) -> LeafNode | Node | None:
    if X.shape[0] <= 1 or (max_depth and depth >= max_depth):
        return LeafNode(value)

    if X.shape[0] < 2 * min_samples_leaf:
        return LeafNode(value)

    # [...]

Notice that if a node has fewer than twice the minimum number of samples per leaf, every possible pair of child nodes will violate the constraint. The split search function should also consider only splits that do not violate this constraint, as sketched below.
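Inside the split search loop, this amounts to skipping any candidate position that would leave too few samples on either side. A minimal sketch of the check (the helper name is mine, not from the repository):

def _violates_min_samples_leaf(idx: int, n_samples: int, min_samples_leaf: int) -> bool:
    """True if splitting at position idx leaves fewer than
    min_samples_leaf samples in either child node."""
    return idx < min_samples_leaf or (n_samples - idx) < min_samples_leaf


# Inside the split loop of _find_best_split we would then skip such candidates:
# if _violates_min_samples_leaf(idx, n_samples, min_samples_leaf):
#     continue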

We must add another base case: a pure node, that is, a node with samples from a single class (classification) or with a single outcome value (regression), should not be split further. All our objective functions yield a value of 0 for pure nodes, so this is easy to implement. We can also add a minimum criterion reduction constraint, even though I've argued in the previous post that this is generally a bad idea.

def split_node(
    node: Node | LeafNode,
    X: pd.DataFrame,
    y: np.ndarray,
    value: np.ndarray,
    depth: int,
    criterion,
    sample_weights: np.ndarray,
    max_depth: int = 0,
    min_samples_leaf: int = 0,
    min_criterion_reduction: float = 0,
) -> LeafNode | Node | None:
    if X.shape[0] <= 1 or (max_depth and depth >= max_depth):
        return LeafNode(value)

    if X.shape[0] < 2 * min_samples_leaf:
        return LeafNode(value)

    prior_criterion = criterion(y, sample_weights)
    if np.isclose(prior_criterion, 0):
        return LeafNode(value)

    split = _find_best_split(X, y, criterion, sample_weights, min_samples_leaf)
    if split is None:
        return None

    criterion_reduction = prior_criterion - split.criterion
    if min_criterion_reduction and criterion_reduction < min_criterion_reduction:
        return None

    # [...]

At this point we already have a somewhat functional decision tree constructor. However, there are still major improvements to be made to this code: it's not very optimized, it can only handle numerical features, and there's no prediction (inference) implementation.

Time complexity

Implementations of decision trees vary quite a lot in their details and the optimizations they employ. Furthermore, the structure of the tree depends on the data. Thus, it's not straightforward to estimate the average time complexity of decision trees, but it's usually approximated as $O(D\,N \log^2 N)$, where $D$ is the number of dimensions (features) and $N$ is the number of training samples. For each node, we have to try all possible splits, which involves sorting the training samples based on each feature in $O(N \log N)$ time, then trying $N - 1$ splits. We must compute the objective function for each split, which increases the time complexity of this step to $O(N^2)$. Since this is done for all dimensions, we have a time complexity of $O(D\,N^2 \log N)$ for each node. This process is repeated for each new depth level added to the tree, whose maximum depth is approximately $\log N$ (balanced trees), hence we have a time complexity of $O(D\,N^2 \log^2 N)$. The $D$ and $\log^2 N$ terms are not bad, but the $N^2$ term makes training trees on large numbers of samples infeasible.

Luckily, it's possible to compute the objective function in constant time, bringing us back to the usual $O(D\,N \log^2 N)$ time complexity. With some clever strategies we can also reuse the sorted features, giving us an average time complexity of $O(D\,N \log N)$ [1].

In our current implementation, the run time to compute the best split per feature is $O(N^2)$. For each split (and we have $N - 1$ of them) we have to go over all samples in each child node to compute the objective function: this is, effectively, a nested loop. We start by sorting the samples by feature value, which means that each new split in the loop moves exactly one sample from one child node to the other. It's for this reason that we don't have to go over all samples every time: the objective function can be recomputed in constant time by moving a single point at each iteration.
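The idea is easier to see with running class counts. A small illustration (not the final implementation): moving one sample per iteration keeps the left and right counts exact without rescanning the node.

y_sorted = np.array([[1, 0], [1, 0], [0, 1], [0, 1], [0, 1]])  # one-hot labels, already sorted by feature

left_count = np.zeros(2)
right_count = y_sorted.sum(axis=0)

for idx in range(1, len(y_sorted)):
    # O(1) update: move sample idx-1 from the right child to the left child.
    left_count += y_sorted[idx - 1]
    right_count -= y_sorted[idx - 1]

    # Same counts we would get by summing each side from scratch.
    assert np.array_equal(left_count, y_sorted[:idx].sum(axis=0))
    assert np.array_equal(right_count, y_sorted[idx:].sum(axis=0))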

We'll abstract our criterion to accommodate different ways of computing it. When searching for the best split, we need to keep track of some values depending on the criterion.

@dataclass
class BaseSplitStats:
    left_weight: float
    right_weight: float


S = TypeVar("S", bound=BaseSplitStats)


@dataclass
class ClassificationSplitStats(BaseSplitStats):
    left_class_count: np.ndarray
    right_class_count: np.ndarray

Regardless of the criterion, we need to keep track of the number of samples in each child node. Notice they're named weight and not count because when sample weights are present they are no longer integer counts. For instance, a sample with a weight of 2 acts as if we had an extra copy of it. It can be a bit weird to think about fractional counts, but they are an extrapolation of the integer example. For instance, a sample with weight 2.5 will have 2.5 times more impact on the criterion value.

Our criterion abstraction should provide a way to measure node impurity given y, as well as methods to track values during split search and compute the objective function in $O(1)$ time. It should also provide a method to estimate the optimal value of a node given y. Depending on the criterion, the optimal node value (that is, the node value that minimizes the criterion given y) is estimated differently. This was overlooked earlier since, for all the objective functions we've implemented, the optimal value is simply the mean outcome.

class Criterion(Protocol, Generic[S]):
    def node_impurity(self, y: np.ndarray, sample_weights: np.ndarray) -> float: ...

    def node_optimal_value(self, y: np.ndarray) -> np.ndarray: ...

    def init_split_stats(self, y: np.ndarray, sample_weights: np.ndarray) -> S: ...

    def update_split_stats(
        self, stats: S, y_value: np.ndarray, weight: float
    ) -> None: ...

    def split_impurity(self, stats: S) -> float: ...

Both Gini impurity and entropy take the exact same values as input, so we can create a single classification criterion:

class ClassificationCriterion(Criterion):
    def __init__(self, objective_fn: Callable[[np.ndarray], float]):
        self.objective = objective_fn

    def node_optimal_value(self, y: np.ndarray) -> np.ndarray:
        return np.mean(y, axis=0)

    def node_impurity(self, y: np.ndarray, sample_weights: np.ndarray) -> float:
        # For binary classification with single column
        if y.shape[1] == 1:
            y = np.hstack((y, 1 - y))
        return self.objective(_class_probabilities(y, sample_weights))

    def init_split_stats(
        self, y: np.ndarray, sample_weights: np.ndarray
    ) -> ClassificationSplitStats:
        sample_weights = sample_weights.reshape((-1, 1))

        # For binary classification with single column
        if y.shape[1] == 1:
            y = np.hstack((y, 1 - y))

        return ClassificationSplitStats(
            left_weight=0,
            right_weight=np.sum(sample_weights),
            left_class_count=np.zeros(y.shape[1], dtype=sample_weights.dtype),
            right_class_count=np.sum(
                y * sample_weights, axis=0, dtype=sample_weights.dtype
            ),
        )

    def update_split_stats(
        self,
        stats: ClassificationSplitStats,
        y_value: np.ndarray,
        weight: float,
    ) -> None:
        stats.left_weight += weight
        stats.right_weight -= weight

        # For binary classification with single column
        if len(y_value) == 1:
            y_value = np.hstack((y_value, 1 - y_value))

        stats.left_class_count += y_value * weight
        stats.right_class_count -= y_value * weight

    def split_impurity(self, stats: ClassificationSplitStats) -> float:
        criterion_l = self.objective(stats.left_class_count / stats.left_weight)
        criterion_r = self.objective(stats.right_class_count / stats.right_weight)

        total_weight = stats.left_weight + stats.right_weight
        p_l = stats.left_weight / total_weight
        p_r = stats.right_weight / total_weight
        return float(p_l * criterion_l + p_r * criterion_r)

Notice that we treat the binary classification case differently. When there are only two possible outcomes, the y matrix can be encoded with a single binary column (either 1 or 0). Thus, one-hot encoding two classes into two columns is wasteful. Since the objective functions still require the proportions of both classes, we add the second column here.
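For example, with a single-column binary y the second column is just its complement:

y_binary = np.array([[1], [0], [1]])
print(np.hstack((y_binary, 1 - y_binary)))
# [[1 0]
#  [0 1]
#  [1 0]]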

Then, we adapt the split search function to use this criterion.

def _find_best_split(
    X, y, criterion: Criterion, sample_weights: np.ndarray
) -> Split | None:
    min_score = np.inf
    best_split = None

    for feat_idx in range(X.shape[1]):
        sort_idx = np.argsort(X[:, feat_idx])
        x_sorted = X[sort_idx, feat_idx]
        y_sorted = y[sort_idx]
        weights_sorted = sample_weights[sort_idx]

        stats = criterion.init_split_stats(y_sorted, weights_sorted)

        for i in range(1, len(y_sorted)):
            criterion.update_split_stats(stats, y_sorted[i - 1], weights_sorted[i - 1])
            if x_sorted[i] != x_sorted[i - 1]:
                score = criterion.split_impurity(stats)
                if score < min_score:
                    min_score = score
                    best_split = Split(
                        criterion=min_score,
                        feature_idx=feat_idx,
                        split_value=x_sorted[i - 1],
                        left_index=sort_idx[:i],
                        right_index=sort_idx[i:],
                        left_value=criterion.node_optimal_value(y_sorted[:i]),
                        right_value=criterion.node_optimal_value(y_sorted[i:]),
                    )

    return best_split

There is an edge case when two consecutive samples have the exact same feature value. We're moving samples one by one, but when applying the decision rule all tied samples will belong to the same child node. To avoid this issue, we check whether the split value has changed from the previous iteration. If it has not changed, we skip the iteration: it's not a valid split point.

Recalling what we've seen previously, the squared loss is defined as follows:

$$L(\mathcal{D}) = \frac{1}{N} \sum_{i=1}^N (y_i - \bar{y})^2$$

This function quantifies the within-node variance of the target variable. Let's expand the quadratic term:

$$\sum_{i=1}^N (y_i - \bar{y})^2 = \sum_{i=1}^N (y_i^2 - 2 y_i \bar{y} + \bar{y}^2) = \sum_{i=1}^N y_i^2 - 2 \bar{y} \sum_{i=1}^N y_i + N \bar{y}^2$$

We know that $\bar{y}$ is not a parameter; it's the average outcome of the node:

$$\bar{y} = \frac{1}{N} \sum_{i=1}^N y_i$$

Substituting $\bar{y}$:

$$\sum_{i=1}^N y_i^2 - 2 \bar{y} \sum_{i=1}^N y_i + N \bar{y}^2 = \sum_{i=1}^N y_i^2 - \frac{1}{N} \Big( \sum_{i=1}^N y_i \Big)^2$$

Thus, the loss becomes:

$$L(\mathcal{D}) = \frac{\sum_{i=1}^N y_i^2}{N} - \Big(\frac{\sum_{i=1}^N y_i}{N}\Big)^2$$

The first term is the sum of squares divided by $N$ and the second is the squared mean. If we track the sum of squares, the sum, and $N$, we can compute the objective in constant time.
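A quick numerical check of this identity with arbitrary values: tracking only the sum, the sum of squares and N reproduces the variance.

y = np.array([1.0, 4.0, 2.0, 7.0])

n = len(y)
s = y.sum()
s2 = (y**2).sum()

print(s2 / n - (s / n) ** 2)  # 5.25
print(np.var(y))              # 5.25 -- same value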

@dataclass
class SquaredLossSplitStats(BaseSplitStats):
    left_sum: np.ndarray
    right_sum: np.ndarray
    left_sum_squared: np.ndarray
    right_sum_squared: np.ndarray


class SquaredLossCriterion(Criterion):
    def node_impurity(self, y: np.ndarray, sample_weights: np.ndarray) -> float:
        sample_weights = sample_weights.reshape(-1, 1)
        weighted_mean = np.average(y, weights=sample_weights)
        return float(np.average((y - weighted_mean) ** 2, weights=sample_weights))

    def node_optimal_value(self, y: np.ndarray) -> np.ndarray:
        return np.mean(y, axis=0)

    def init_split_stats(
        self, y: np.ndarray, sample_weights: np.ndarray
    ) -> SquaredLossSplitStats:
        sample_weights = sample_weights.reshape((-1, 1))
        return SquaredLossSplitStats(
            left_weight=0,
            right_weight=np.sum(sample_weights),
            left_sum=np.zeros(y.shape[1], dtype=y.dtype),
            right_sum=np.sum(y * sample_weights, axis=0),
            left_sum_squared=np.zeros(y.shape[1], dtype=y.dtype),
            right_sum_squared=np.sum(y * y * sample_weights, axis=0),
        )

    def update_split_stats(
        self,
        stats: SquaredLossSplitStats,
        y_value: np.ndarray,
        weight: float,
    ) -> None:
        stats.left_sum += weight * y_value
        stats.right_sum -= weight * y_value
        stats.left_weight += weight
        stats.right_weight -= weight
        stats.left_sum_squared += weight * y_value * y_value
        stats.right_sum_squared -= weight * y_value * y_value

    def split_impurity(self, stats: SquaredLossSplitStats) -> float:
        left_mean = stats.left_sum / stats.left_weight if stats.left_weight > 0 else 0
        right_mean = (
            stats.right_sum / stats.right_weight if stats.right_weight > 0 else 0
        )

        criterion_l = (
            np.sum(stats.left_sum_squared / stats.left_weight - left_mean * left_mean)
            if stats.left_weight > 0
            else 0
        )
        criterion_r = (
            np.sum(
                stats.right_sum_squared / stats.right_weight - right_mean * right_mean
            )
            if stats.right_weight > 0
            else 0
        )

        total_weight = stats.left_weight + stats.right_weight
        p_l = stats.left_weight / total_weight
        p_r = stats.right_weight / total_weight
        return float(p_l * criterion_l + p_r * criterion_r)

Categorical features

In the previous post we've seen that decision trees accept both numerical and categorical features, yet our implementation can only handle numerical ones. One possibility is to one-hot encode categorical features into numerical ones. Since the time complexity of fitting a tree scales linearly with the number of features, this doesn't hurt performance catastrophically. In fact, the scikit-learn implementation still doesn't support categorical features directly and requires numerical encoding. One-hot encoding shifts the complexity from combinations in each split to the tree structure: many splits are required to capture a relationship that is true for a group of category levels of a feature. Moreover, the encoding process makes the X feature matrix larger in memory.

In the CART-based family of decision tree algorithms, a categorical split may consider one category at a time (one vs rest) or combinations of categories. The problem is that there are $2^{d-1} - 1$ possible combinations of categories, where $d$ is the number of levels (distinct categories) in a feature. If we have 20 levels, that's already 524,287 combinations to try for every split. The tree structure produced by a one vs rest strategy is effectively equivalent to one produced using one-hot encoded features, but it requires less memory for the feature matrix.
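The formula counts the ways of sending a non-empty proper subset of levels to the left child, where a split and its mirror image are the same. A small brute-force check for illustration (the function name is mine):

from itertools import combinations

def count_binary_partitions(levels):
    # Each split sends a non-empty proper subset of levels left; a subset and
    # its complement describe the same split, so fix one level on the left.
    first, *rest = levels
    count = 0
    for size in range(0, len(rest) + 1):
        for subset in combinations(rest, size):
            left = {first, *subset}
            if len(left) < len(levels):  # exclude the "everything left" split
                count += 1
    return count

print(count_binary_partitions(["A", "B", "C", "D"]))  # 7 == 2**(4-1) - 1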

One vs rest strategy

In this strategy, the split search considers each level of a feature individually and compares it with all other levels combined. In other words, the left child node contains samples of a single level while the right child node contains all other levels. The split can be quite uneven and result in less balanced trees, but the performance trade-off can be worth it. We avoid the exponential number of combinations seen above and don't expand the memory footprint of the X matrix with many one-hot encoded columns.

The criterion interface should have new methods to handle the categorical split, preferably in constant time. Each feature level is considered individually, so we only need to pre-compute the indices of each level. This can be done in $O(N \log N)$ time by lexicographically sorting the feature.

cat_indices = {}
sort_idx = np.argsort(x)
x_sorted = x[sort_idx]
y_sorted = y[sort_idx]
weights_sorted = sample_weights[sort_idx]

start_idx = 0
for i in range(1, len(x_sorted) + 1):
    if i == len(x_sorted) or x_sorted[i] != x_sorted[start_idx]:
        cat_indices[x_sorted[start_idx]] = {
            "indices": sort_idx[start_idx:i],
            "y": y_sorted[start_idx:i],
            "weights": weights_sorted[start_idx:i],
        }
        start_idx = i

Whenever the feature value changes, we've reached a new level and can therefore compute the indices, the subset of y, and the subset of sample_weights that belong to this level. The criterion of the left child node can be computed directly using the subset of y defined above, whereas the criterion for the right child node can be found by removing the class counts of the level from the total. You may have noticed this is not an $O(1)$ operation, because we have to sum over subsets of y. Still, this is done once per level, and we avoid any quadratic or exponential operations.

The split stats are initialized the same way as before. We implement a new method that builds new stats for a categorical level using the logic we've just described:

class ClassificationCriterion(Criterion):
    # [...]

    def make_stats_from_categorical_level(
        self,
        stats: ClassificationSplitStats,
        y: np.ndarray,
        sample_weights: np.ndarray,
        is_left: bool,
    ) -> ClassificationSplitStats:
        level_weights = sample_weights.reshape(-1, 1)
        level_weight = np.sum(sample_weights)

        # For binary classification with single column
        if y.shape[1] == 1:
            y = np.hstack((y, 1 - y))

        level_sum = np.sum(y * level_weights, axis=0)

        if is_left:
            stats = replace(
                stats,
                left_weight=level_weight,
                right_weight=stats.right_weight - level_weight,
                left_class_count=level_sum,
                right_class_count=stats.right_class_count - level_sum,
            )
        else:
            stats = replace(
                stats,
                left_weight=stats.left_weight - level_weight,
                right_weight=level_weight,
                left_class_count=stats.left_class_count - level_sum,
                right_class_count=level_sum,
            )
        return stats

An analogous method has been implemented for the squared loss criterion. Then, we use this method in the new _best_categorical_split function:

def _best_categorical_split(
    x: np.ndarray,
    y: np.ndarray,
    feat_idx: int,
    criterion: Criterion,
    sample_weights: np.ndarray,
    min_samples_leaf: int,
):
    min_score = np.inf
    best_split = None
    unique_values = np.unique(x)

    # Pre-compute category indices
    cat_indices = {}
    sort_idx = np.argsort(x)
    x_sorted = x[sort_idx]
    y_sorted = y[sort_idx]
    weights_sorted = sample_weights[sort_idx]

    start_idx = 0
    for i in range(1, len(x_sorted) + 1):
        if i == len(x_sorted) or x_sorted[i] != x_sorted[start_idx]:
            cat_indices[x_sorted[start_idx]] = {
                "indices": sort_idx[start_idx:i],
                "y": y_sorted[start_idx:i],
                "weights": weights_sorted[start_idx:i],
            }
            start_idx = i

    stats = criterion.init_split_stats(y, sample_weights)

    n_samples = len(y_sorted)
    for value in unique_values:
        cat_data = cat_indices[value]
        level_size = len(cat_data["indices"])
        if level_size == 0 or level_size == len(x):
            continue

        if level_size < min_samples_leaf or (n_samples - level_size) < min_samples_leaf:
            continue

        level_stats = criterion.make_stats_from_categorical_level(
            stats, cat_data["y"], cat_data["weights"], is_left=True
        )
        score = criterion.split_impurity(level_stats)

        if score < min_score:
            min_score = score
            left_indices = cat_data["indices"]
            right_indices = np.setdiff1d(np.arange(len(x)), left_indices)
            best_split = Split(
                criterion=score,
                feature_idx=feat_idx,
                split_value=set([value]),
                left_index=left_indices,
                right_index=right_indices,
                left_value=criterion.node_optimal_value(y[left_indices]),
                right_value=criterion.node_optimal_value(y[right_indices]),
            )

    return min_score, best_split

Optimal partitioning

Luckily, there is a very handy optimization to find the optimal partitioning of levels when the output is univariate, that is, when we're dealing with binary classification or univariate regression. It was first proposed by Fisher in 1958 as a method to group a set of numbers so that the variance within groups is minimized. Here, again, the number of possible combinations grows exponentially. The proof states that we only need to look at the partitions of the levels sorted by average outcome instead of all possible combinations.

Suppose we want to divide TV shows into two groups so that the audience is as homogeneous as possible within each group. This is equivalent to finding the categorical split set that minimizes the weighted variance in each group, i.e. minimizes the squared loss as defined above. Any best split must contain only contiguous TV shows (as ordered by audience), as any non-contiguous item would always increase the group variance. Therefore, it's possible to try $N - 1$ splits over the ordered TV shows without considering all possible combinations. In the regression setting, we calculate the average outcome per category level, which renders the algorithm above suitable for our purposes. Each category level is analogous to a TV show, whereas the audience is the average outcome per level.

The proof for binary outcomes (binary classification treated as 0/1 regression) was only described later and can be found in Breiman et al. (1984) and Ripley (1996). In this case, we compute the average positive outcome for each category level: one class is (arbitrarily) represented as 1, hence the average outcome is the proportion of this class. This method is used by LightGBM and XGBoost, both gradient-boosted tree libraries [2].

The average outcome is computed by grouping the outcome vector y by feature x and applying sample weights:

df = pd.DataFrame({"x": x, "y": y.ravel(), "w": sample_weights})

cat_stats = df.groupby("x").agg(
    y_avg=pd.NamedAgg(
        column="y", aggfunc=lambda x: np.average(x, weights=df.loc[x.index, "w"])
    ),
    y_count=pd.NamedAgg(column="y", aggfunc=len),
    w=pd.NamedAgg(column="w", aggfunc="sum"),
)

cat_stats = cat_stats.sort_values("y_avg")

After sorting by the average outcome, we split this predictor as if it were an ordered predictor. In other words, we go over the levels sorted by average outcome and move them one by one to the left child node, computing the criterion at each step. Note that, unlike in previous cases, the feature is not sorted by its values, but rather by the average outcome.

The criterion can be updated with a constant-time operation by treating a feature level as a single sample. The idea of using an average outcome as a single sample may require some intuition massaging, but it does make sense. Consider a level with 8 samples of outcome 1 and 2 samples of outcome 0. The average outcome is 0.8 [3]. Considering uniform sample weights, the weight of the level is 10 (equal to the number of samples). We can expand the average outcome to a two-class vector $[0.8, 0.2]$ and multiply it by the weight, resulting in the vector $[8, 2]$. This vector is added to the left child node's label count and subtracted from the right child node's. Hence, moving one average-outcome "sample" carrying the level's summed sample weight is equivalent to moving each sample individually. After computing and ordering the average outcomes, the implementation closely resembles the numerical feature split search and won't be shown here for brevity.
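Still, the gist can be illustrated with a simplified sketch, assuming a binary 0/1 outcome vector, uniform sample weights and Gini impurity as the criterion; the function name is illustrative, not the repository's implementation.

def sketch_optimal_partitioning(x, y):
    """Simplified illustration: best left-level set and its weighted Gini
    impurity for a categorical feature x and a binary 0/1 outcome y."""
    levels, level_pos = np.unique(x, return_inverse=True)
    counts = np.bincount(level_pos, minlength=len(levels))                   # samples per level
    positives = np.bincount(level_pos, weights=y, minlength=len(levels))     # class-1 count per level

    order = np.argsort(positives / counts)  # sort levels by average outcome
    counts, positives = counts[order], positives[order]

    def gini(pos, total):
        p = pos / total
        return 1 - p**2 - (1 - p) ** 2

    best_score, best_levels = np.inf, None
    left_n = left_pos = 0
    total_n, total_pos = counts.sum(), positives.sum()
    for i in range(len(levels) - 1):  # only the contiguous splits are evaluated
        left_n += counts[i]
        left_pos += positives[i]
        right_n, right_pos = total_n - left_n, total_pos - left_pos
        score = (left_n * gini(left_pos, left_n) + right_n * gini(right_pos, right_n)) / total_n
        if score < best_score:
            best_score = score
            best_levels = set(map(str, levels[order[: i + 1]]))
    return best_levels, best_score


x = np.array(["A", "A", "B", "B", "B", "C", "C"])
y = np.array([0, 0, 1, 1, 0, 1, 1])
print(sketch_optimal_partitioning(x, y))  # ({'A'}, ~0.229): 'A' has the lowest average outcome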

Combining split search methods

Finally, we combine our three split search functions into one. This function applies numerical split search to numerical features (of course). For categorical features, optimal partitioning is used if the y outcome is univariate. Otherwise, a one vs rest approach is used.

def _find_best_split(
    X: pd.DataFrame,
    y: np.ndarray,
    criterion: Criterion,
    sample_weights: np.ndarray,
    min_samples_leaf: int,
) -> Split | None:
    min_score = np.inf
    best_split = None

    categorical_splitter = (
        _best_categorical_optimal_partitioning
        if y.shape[1] == 1
        else _best_categorical_split
    )

    feature_types = np.array(
        [np.issubdtype(X.iloc[:, i].dtype, np.number) for i in range(X.shape[1])]
    )
    feature_values = [X.iloc[:, i].values for i in range(X.shape[1])]

    for feat_idx in range(X.shape[1]):
        splitter = (
            _best_numerical_split if feature_types[feat_idx] else categorical_splitter
        )

        score, split = splitter(
            feature_values[feat_idx],
            y,
            feat_idx,
            criterion,
            sample_weights,
            min_samples_leaf,
        )

        if split is not None and score < min_score:
            min_score = score
            best_split = split

    return best_split

Training and inference

The scikit-learn interface has become so ubiquitous in the Python world that it seems only reasonable to use it here. Let's define two interfaces (in scikit-learn style), one for classification and the other for regression:

class Regressor(Protocol):
    def fit(
        self, X: pd.DataFrame, y: np.ndarray, sample_weights: np.ndarray | None = None
    ) -> None: ...

    def predict(self, X: pd.DataFrame) -> np.ndarray: ...


class Classifier(Protocol):
    def fit(
        self, X: pd.DataFrame, y: np.ndarray, sample_weights: np.ndarray | None = None
    ) -> None: ...

    def predict(self, X: pd.DataFrame) -> np.ndarray: ...

    def predict_proba(self, X: pd.DataFrame) -> np.ndarray: ...

The classes following these interfaces are mostly boilerplate code, but the inference code is new. To make a prediction for a single sample, we check whether its feature value lies on the left or on the right side of the root node split. We repeat this process until we reach a leaf node, and then return the leaf node value as the prediction. It's quite a simple inference process with an $O(\log N)$ average time complexity, that is, proportional to the depth of the tree [4].

For the regression case:

class DecisionTreeRegressor:
    # [...]

    def predict(self, X: pd.DataFrame) -> np.ndarray:
        if self._root_node is None:
            raise ValueError("model must be trained before prediction")

        def traverse_tree(x, node):
            while isinstance(node, Node):
                feature_val = x.iloc[node.feature_idx]
                if isinstance(node.split_value, set):
                    node = node.left if feature_val in node.split_value else node.right
                else:
                    node = node.left if feature_val <= node.split_value else node.right
            return node.value

        y_pred = np.array([traverse_tree(x, self._root_node) for _, x in X.iterrows()])
        return y_pred

For classification, this method becomes predict_proba, which returns vectors of probabilities [5]. The predict method then chooses the most likely class as the prediction for each instance using the following function:

def _prob_to_class(prob: np.ndarray) -> np.ndarray:
    if prob.shape[1] > 1:
        return np.argmax(prob, axis=1)

    return (prob.squeeze(1) >= 0.5).astype(int)

Conclusion

We have implemented classification and regression trees (CART) with a decent time complexity (no quadratic-time operations), support for sample weights, and categorical features. Let's test them!

import random

import numpy as np
import pandas as pd
from sklearn.datasets import load_diabetes
from sklearn.metrics import accuracy_score, f1_score, mean_squared_error

from src.cart import DecisionTreeClassifier, DecisionTreeRegressor, print_tree


random.seed(42)
np.random.seed(42)
X = pd.DataFrame(np.random.random(size=(20, 4)))
X["cat"] = ["B"] * 9 + ["A"] * 6 + ["C"] * 5
y = np.array([0] * 10 + [1] * 10)

max_depth = 2
min_samples_leaf = 3

tree = DecisionTreeClassifier(max_depth, min_samples_leaf)
tree.fit(X, y, sample_weights=np.ones(len(y)) / 10)
pred = tree.predict(X)
score = f1_score(y, pred, average="macro")
acc = accuracy_score(y, pred)
print_tree(tree._root_node)
print(f"classification tree -> F1: {score:.2f} accuracy: {acc:.2%}")
print()

X, y = load_diabetes(return_X_y=True)
X = pd.DataFrame(X)

tree = DecisionTreeRegressor(max_depth, min_samples_leaf)
tree.fit(X, y)
pred = tree.predict(X)
mse = mean_squared_error(y, pred)
print_tree(tree._root_node)
print(f"regression tree -> MSE: {mse:.2f}")

You may have noticed the scikit-learn dependency. Well, it's used only as a convenient way to load a toy dataset, so I think this is fair enough.

Node(feature_idx=4, split_value={'B'})
Left:
  LeafNode(value=[0.])
Right:
  Node(feature_idx=3, split_value=0.44)
  Left:
    LeafNode(value=[0.8])
  Right:
    LeafNode(value=[1.])
classification tree -> F1: 0.95 accuracy: 95.00%

Node(feature_idx=8, split_value=-0.00)
Left:
  Node(feature_idx=2, split_value=0.01)
  Left:
    LeafNode(value=[96.31])
  Right:
    LeafNode(value=[159.74])
Right:
  Node(feature_idx=2, split_value=0.01)
  Left:
    LeafNode(value=[162.68])
  Right:
    LeafNode(value=[225.88])
regression tree -> MSE: 3360.05

It works! In the next part of this series we'll talk about bagging, a strategy to reduce model variance, and the most famous bagged tree algorithm: random forests.

References

Footnotes

  1. We will not implement this kind of optimization.

  2. We'll cover gradient-boosted decision trees (GBDT) in the future.

  3. The average outcome can be seen as the proportion of positive (1) outcomes.

  4. We assume roughly balanced trees and therefore the depth of the tree is proportional to $\log N$.

  5. These probabilities are not well calibrated and are not a good measure of the uncertainty of the model.


