BVHについての考察

加速構造 BVH の基本的な構造と代表的な構築手法について
computerscience
accelerationstructure
bvh
公開

2025年7月16日

更新日

2025年8月2日

Open In Colab

Setup for Running in Google Colab

コード
!pip install anytree --quiet

1 テストシーン(共通データセット)

コード
import random
import numpy as np
import numpy.typing as npt
from abc import ABC, abstractmethod
from itertools import combinations
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Callable, Protocol, Sequence, TypeVar


@dataclass(slots=True)
class Ray:
    origin: np.ndarray
    direction: np.ndarray

    def at(self, t: float) -> np.ndarray:
        return self.origin + t * self.direction


@dataclass(slots=True)
class Bounds:
    lower: np.ndarray
    upper: np.ndarray

    def __post_init__(self) -> None:
        self.lower = np.array(self.lower)
        self.upper = np.array(self.upper)

    @classmethod
    def empty(cls, dim: int = 2) -> Bounds:
        return cls(np.full(dim, np.inf), np.full(dim, -np.inf))

    @classmethod
    def union(cls, lhs: npt.ArrayLike, rhs: npt.ArrayLike) -> Bounds:
        if isinstance(lhs, cls) and isinstance(rhs, cls):
            return cls(np.minimum(lhs.lower, rhs.lower), np.maximum(lhs.upper, rhs.upper))

        if isinstance(lhs, cls) and isinstance(rhs, (np.ndarray, list, tuple)):
            return cls(np.minimum(lhs.lower, rhs), np.maximum(lhs.upper, rhs))

        if isinstance(lhs, (np.ndarray, list, tuple)) and isinstance(rhs, cls):
            return cls(np.minimum(lhs, rhs.lower), np.maximum(lhs, rhs.upper))

        raise TypeError(f"Unsupported types: {type(lhs)}, {type(rhs)}")

    @property
    def center(self) -> np.ndarray:
        return 0.5 * (self.lower + self.upper)

    @property
    def extent(self) -> np.ndarray:
        return self.upper - self.lower

    @property
    def surface_area(self) -> float:
        extent = self.extent
        if extent.size < 2:
            return 0.0
        return 2.0 * sum(np.prod(extent[k]) for k in combinations(range(extent.size), extent.size - 1))  # type: ignore

    def intersect(self, ray: Ray) -> tuple[bool, float, float]:
        return intersect_ray_aabb(ray.origin, ray.direction, self.lower, self.upper)


class Primitive(ABC):
    """Abstract shape that can be bounded, named and intersected by a ray.

    Every member below is abstract; a subclass that leaves any of them
    unimplemented cannot be instantiated.
    """

    @property
    @abstractmethod
    def bounds(self) -> Bounds:
        """Axis-aligned box enclosing the primitive."""

    @property
    @abstractmethod
    def centroid(self) -> np.ndarray:
        """Representative centre point of the primitive."""

    @property
    @abstractmethod
    def name(self) -> str:
        """Label shown when the primitive is displayed."""

    @abstractmethod
    def intersect(self, ray: Ray) -> float | None:
        """Return the ray parameter of the hit, or ``None`` on a miss."""


@dataclass(slots=True)
class Triangle(Primitive):
    """Triangle primitive backed by a vertex array, optionally indexed.

    When ``indices`` is given, the triangle's corners are the rows of
    ``vertices`` selected by those indices; otherwise ``vertices`` itself holds
    the corners.
    """

    _vertices: np.ndarray
    _indices: list[int] | None
    _name: str

    def __init__(self, vertices, indices=None, name="") -> None:
        self._name = name
        self._indices = indices
        self._vertices = np.asarray(vertices)

    def vertices(self, copy: bool = False) -> np.ndarray:
        """Return the corner coordinates.

        ``copy`` is forwarded to ``np.array`` only for the non-indexed case;
        with an index list the result comes from fancy indexing and ``copy``
        is not consulted.
        """
        if self._indices is not None:
            return self._vertices[self._indices]
        return np.array(self._vertices, copy=copy)

    @property
    def bounds(self) -> Bounds:
        """Box spanned by the per-axis minimum and maximum of the corners."""
        corners = self.vertices()
        return Bounds(corners.min(axis=0), corners.max(axis=0))

    @property
    def centroid(self) -> np.ndarray:
        """Arithmetic mean of the corners."""
        corners = self.vertices()
        return corners.mean(axis=0)

    @property
    def name(self) -> str:
        """Label passed at construction."""
        return self._name

    def intersect(self, ray: Ray) -> float | None:
        """Delegate to ``intersect_ray_triangle`` with the corner rows unpacked."""
        corners = self.vertices()
        return intersect_ray_triangle(ray.origin, ray.direction, *corners)

交差関数群

コード
def perp_dot(a: np.ndarray, b: np.ndarray) -> float:
    """Return the 2-D perp-dot product ``a.x * b.y - a.y * b.x``."""
    ax, ay = a[0], a[1]
    bx, by = b[0], b[1]
    return ax * by - ay * bx


def intersect_ray_segment(ro: np.ndarray, rd: np.ndarray, p0: np.ndarray, p1: np.ndarray) -> float | None:
    """Intersect the ray ``ro + s * rd`` with the segment ``p0 -> p1`` in 2-D.

    Returns the ray parameter ``s`` when the hit lies on the ray (``s >= 0`` and
    finite) and on the segment (parameter in ``[0, 1]``); otherwise ``None``.
    """
    edge = p1 - p0
    to_origin = ro - p0

    # Division by zero is left to IEEE 754: a parallel ray yields inf/nan,
    # which the range checks below reject.
    scale = 1.0 / perp_dot(rd, edge)  # IEEE 754
    ray_param = perp_dot(edge, to_origin) * scale
    seg_param = perp_dot(rd, to_origin) * scale

    off_ray = not (0.0 <= ray_param < np.inf)
    if off_ray or not (0.0 <= seg_param <= 1.0):
        return None

    return ray_param


def intersect_ray_triangle(ro: np.ndarray, rd: np.ndarray, v0: np.ndarray, v1: np.ndarray, v2: np.ndarray) -> float | None:
    """Return the nearest hit of the ray with the triangle's three edges.

    Each edge is tested with ``intersect_ray_segment``; the smallest ray
    parameter among the hits is returned, or ``None`` when every edge is missed.
    """
    nearest = None
    for a, b in ((v0, v1), (v1, v2), (v2, v0)):
        hit = intersect_ray_segment(ro, rd, a, b)
        if hit is None:
            continue
        if nearest is None or hit < nearest:
            nearest = hit

    return nearest


def intersect_ray_aabb(ro: np.ndarray, rd: np.ndarray, lower: np.ndarray, upper: np.ndarray) -> tuple[bool, float, float]:
    """Slab test of the ray ``ro + t * rd`` against the box ``[lower, upper]``.

    Returns ``(hit, t_enter, t_exit)`` where the parametric interval starts at
    ``t = 0``. ``hit`` is true only when ``t_enter < t_exit`` (strict).
    """
    enter, leave = 0.0, np.inf

    for axis in range(ro.size):
        # A zero direction component divides by zero on purpose; fmin/fmax
        # ignore the NaNs this can produce.
        inv_d = 1.0 / rd[axis]  # IEEE 754
        t_a = (lower[axis] - ro[axis]) * inv_d
        t_b = (upper[axis] - ro[axis]) * inv_d
        near = np.fmin(t_a, t_b)
        far = np.fmax(t_a, t_b)
        enter = np.fmax(enter, near)
        leave = np.fmin(leave, far)

    hit = enter < leave
    return hit, enter, leave

三角形群の生成

コード
from scipy.spatial import cKDTree


def generate_triangles(n, xmin, ymin, xmax, ymax, *, points_kwargs=None, triangle_kwargs=None):
    """Generate ``n`` random triangles inside the rectangle ``[xmin, xmax] x [ymin, ymax]``.

    Centres are drawn uniformly, pushed apart by a short relaxation, and kept
    ``max_radius`` away from the rectangle's border; one jittered triangle is
    then built around each centre. All randomness comes from the global
    ``np.random`` state.

    Args:
        n: number of triangles.
        xmin, ymin, xmax, ymax: rectangle the triangles are placed in.
        points_kwargs: overrides for the centre relaxation
            (``min_dist``, ``steps``, ``lr``; ``domain`` is always replaced).
        triangle_kwargs: overrides for the triangle shape
            (``angle_jitter``, ``max_radius``, ``scale_range``).

    Returns:
        A list of ``n`` arrays of shape ``(3, 2)``, one per triangle.
    """
    point_opts = {
        "min_dist": 1.0,
        "domain": (0, 1, 0, 1),
        "steps": 50,
        "lr": 0.01,
    }
    if points_kwargs is not None:
        point_opts.update(points_kwargs)

    tri_opts = {
        "angle_jitter": 30.0,
        "max_radius": 1.0,
        "scale_range": (0.5, 1.0),
    }
    if triangle_kwargs is not None:
        tri_opts.update(triangle_kwargs)

    def generate_triangle(center, *, angle_jitter=30.0, max_radius=1.0, scale_range=(0.5, 1.0)):
        # Three directions 120 degrees apart, each perturbed independently.
        assert abs(angle_jitter) < 60.0, "angle_jitter must be less than 60 degrees to preserve CCW order"
        base_angles = np.linspace(0.0, 2.0 * np.pi, 4)[:-1]
        jitter = np.deg2rad(np.random.uniform(-angle_jitter, angle_jitter, 3))
        angles = base_angles + jitter
        unit_dirs = np.stack([np.cos(angles), np.sin(angles)], axis=-1)
        scale_lo, scale_hi = scale_range[0], scale_range[1]
        radii = np.random.uniform(scale_lo, scale_hi, (3, 1)) * max_radius
        return center.reshape(1, 2) + radii * unit_dirs

    def generate_points(n, *, min_dist=1.0, domain=(0, 1, 0, 1), steps=50, lr=0.01):
        # Work in the unit square, then map to ``domain`` at the end.
        x_lo, x_hi, y_lo, y_hi = domain
        pts = np.random.rand(n, 2)
        radius = min_dist / max(x_hi - x_lo, y_hi - y_lo)

        for _ in range(steps):
            push = np.zeros_like(pts)
            # Every pair closer than ``radius`` repels along its connecting line.
            for a, b in cKDTree(pts).query_pairs(radius):
                diff = pts[a] - pts[b]
                length = np.linalg.norm(diff) + 1e-10
                unit = diff / length
                push[a] += unit
                push[b] -= unit

            pts += lr * push
            pts[:, 0] = np.clip(pts[:, 0], 0, 1)
            pts[:, 1] = np.clip(pts[:, 1], 0, 1)

        pts[:, 0] = pts[:, 0] * (x_hi - x_lo) + x_lo
        pts[:, 1] = pts[:, 1] * (y_hi - y_lo) + y_lo
        return pts

    # Shrink the rectangle so a triangle of maximal radius cannot cross it.
    margin = tri_opts["max_radius"]
    point_opts["domain"] = (
        xmin + margin,
        xmax - margin,
        ymin + margin,
        ymax - margin,
    )
    centers = generate_points(n, **point_opts)

    triangles = []
    for center in centers:
        triangles.append(generate_triangle(center, **tri_opts))
    return triangles
コード
# Seed both generators so the scene and every tree built from it are reproducible:
# `random` is the default pivot source of nth_element, `np.random` drives generate_triangles.
random.seed(0)
np.random.seed(4)

num_triangles = 50
# Zero-padding width so that all primitive names have the same length.
num_digits = len(str(num_triangles - 1))
points_kwargs = {"min_dist": 4.0}
triangle_kwargs = {"max_radius": 2.0, "scale_range": (0.65, 0.95)}
triangles = generate_triangles(num_triangles, 0.0, 0.0, 50.0, 50.0, points_kwargs=points_kwargs, triangle_kwargs=triangle_kwargs)
# Shared data set used by every BVH builder below.
primitives: list[Primitive] = [Triangle(tri, name=f"T{i:0{num_digits}}") for i, tri in enumerate(triangles)]

2 構築手法

共通の定義

コード
@dataclass(init=False, slots=True)
class BVHBuildNode:
    bounds: Bounds
    split_axis: int
    left: BVHBuildNode | None
    right: BVHBuildNode | None
    first_prim_offset: int
    num_primitives: int

    def init_leaf(self, bounds: Bounds, first_prim_offset: int, num_primitives: int) -> None:
        self.bounds = bounds
        self.split_axis = 0
        self.left = self.right = None
        self.first_prim_offset = first_prim_offset
        self.num_primitives = num_primitives

    def init_interior(self, axis: int, left: BVHBuildNode, right: BVHBuildNode) -> None:
        self.bounds = Bounds.union(left.bounds, right.bounds)
        self.split_axis = axis
        self.left = left
        self.right = right
        self.first_prim_offset = 0
        self.num_primitives = 0

    def is_leaf(self) -> bool:
        return self.num_primitives > 0

    def primitives(self, ordered_primitives: list[Primitive]) -> list[Primitive]:
        offset, count = self.first_prim_offset, self.num_primitives
        return ordered_primitives[offset : offset + count]


@dataclass(slots=True)
class BVHPrimitiveInfo:
    index: int
    bounds: Bounds
    centroid: np.ndarray = field(init=False)

    def __post_init__(self) -> None:
        self.centroid = self.bounds.center


@dataclass(slots=True)
class BVHVariant:
    name: str = ""
    nodes: list[LinearBVHNode] | None = None
    primitives: list[Primitive] | None = None
    filename: str = ""

2.1 木構造の平坦化

次のような木構造があるとする.

A
├─ C
└─ B
     ├─ E
     └─ D

表にすると,

\[ \begin{array}{c|c|c|c} \text{Node} & \text{Left} & \text{Right} & \text{Type}\\ \hline \text{A} & \text{B} & \text{C} & \text{Interior}\\ \text{B} & \text{D} & \text{E} & \text{Interior}\\ \text{C} & \text{-} & \text{-} & \text{Leaf}\\ \text{D} & \text{-} & \text{-} & \text{Leaf}\\ \text{E} & \text{-} & \text{-} & \text{Leaf} \end{array} \]

深さ優先で平坦化すると,

\[ \begin{array}{c|c|c|c|c} 0 & 1 & 2 & 3 & 4\\ \hline \text{A} & \text{B} & \text{D} & \text{E} & \text{C} \end{array} \]

あるノードの左の子ノードは,存在する場合,常に次の位置に出現する.

コード
@dataclass(slots=True)
class LinearBVHNode:
    """Node of the flattened (depth-first array) BVH.

    ``offset`` is overloaded: for a leaf it is the index of the first primitive
    in the ordered primitive list, for an interior node it is the array index
    of the right child. The left child always sits directly after its parent.
    """

    bounds: Bounds = field(default_factory=Bounds.empty)
    offset: int = 0
    num_primitives: int = 0
    axis: int = 0

    def is_leaf(self) -> bool:
        """A node is a leaf exactly when it owns at least one primitive."""
        return 0 < self.num_primitives

    def primitives(self, ordered_primitives: list[Primitive]) -> list[Primitive]:
        """Slice this leaf's primitives out of the ordered primitive list."""
        begin = self.offset
        return ordered_primitives[begin : begin + self.num_primitives]

    def child_indices(self, index: int) -> tuple[int, int]:
        """Return ``(left, right)`` array indices for the node stored at ``index``."""
        left = index + 1
        right = self.offset
        return left, right


@dataclass(slots=True)
class FlattenBVHTreeContext:
    nodes: list[LinearBVHNode]
    offset: int = 0


def flatten_bvh_tree(node: BVHBuildNode, context: FlattenBVHTreeContext) -> int:
    """Write the subtree rooted at ``node`` into ``context.nodes`` depth-first.

    Returns the array index at which ``node`` itself was stored. For interior
    nodes the left subtree is emitted immediately after the node, and the index
    of the right subtree's root is recorded in the node's ``offset``.
    """
    slot = context.offset
    flat = context.nodes[slot]
    flat.bounds = node.bounds
    context.offset = slot + 1

    if not node.is_leaf():
        assert node.left is not None and node.right is not None

        flat.axis = node.split_axis
        flatten_bvh_tree(node.left, context)
        right_slot = flatten_bvh_tree(node.right, context)
        flat.offset = right_slot
        return slot

    flat.offset = node.first_prim_offset
    flat.num_primitives = node.num_primitives
    return slot

2.2 基本的な BVH

要点

  1. プリミティブの重心からなる点群に対して,軸平行境界ボックス(AABB)を作成する
  2. AABB の中で最も広がりの大きい軸を分割軸として選ぶ
  3. 分割軸の中央値でプリミティブを二分する

C++ の std::partition と std::nth_element に相当する関数群

コード
T = TypeVar("T")


def partition(lst: list[T], *, start: int, end: int, pred: Callable[[T], bool]) -> int:
    """Reorder ``lst[start:end]`` in place so items satisfying ``pred`` come first.

    Returns the index of the first item that does not satisfy ``pred``.
    Satisfying items keep their relative order; the others may be permuted.
    """
    boundary = start
    cursor = start
    while cursor < end:
        item = lst[cursor]
        if pred(item):
            # Swap the match into the growing front section.
            lst[cursor] = lst[boundary]
            lst[boundary] = item
            boundary += 1
        cursor += 1

    return boundary


class SupportsRandInt(Protocol):
    """Structural type for random sources offering ``randint``, such as the ``random`` module."""

    def randint(self, a: int, b: int) -> int:
        """Return a random integer ``N`` with ``a <= N <= b``."""


def nth_element(lst: list[T], n: int, *, lo: int = 0, hi: int | None = None, key: Callable[[T], object] = lambda x: x, rng: SupportsRandInt = random) -> None:
    """Quickselect: rearrange ``lst[lo..hi]`` so that ``lst[n]`` holds the item it would hold if sorted.

    After the call every item before ``n`` (within the range) has a key not
    greater than ``key(lst[n])`` and every item after it a key not smaller.
    ``hi`` is inclusive and defaults to the last index.

    Raises:
        IndexError: when ``n`` lies outside ``[lo, hi]``.
    """
    first = lo
    last = len(lst) - 1 if hi is None else hi

    if n < first or n > last:
        raise IndexError("index out of range")

    while first <= last:
        # Lomuto partition scheme with a random pivot
        # https://en.wikipedia.org/wiki/Quicksort
        chosen = rng.randint(first, last)
        lst[chosen], lst[last] = lst[last], lst[chosen]
        pivot_key = key(lst[last])

        store = first
        for scan in range(first, last):
            if key(lst[scan]) <= pivot_key:
                lst[store], lst[scan] = lst[scan], lst[store]
                store += 1

        # Move the pivot to its final sorted position.
        lst[store], lst[last] = lst[last], lst[store]

        if store == n:
            return
        if store < n:
            first = store + 1
        else:
            last = store - 1
コード
class HasBounds(Protocol):
    """Structural type for anything exposing a ``bounds`` attribute."""

    bounds: Bounds  # bounding box of the object


def compute_bounds(bounded_objects: Sequence[HasBounds]) -> Bounds:
    """Return the union of the ``bounds`` of all given objects.

    An empty sequence yields the empty box produced by ``Bounds.empty()``.
    """
    merged = Bounds.empty()
    for box in (item.bounds for item in bounded_objects):
        merged = Bounds.union(merged, box)

    return merged


def build_bvh(primitives: list[Primitive], split_func) -> tuple[list[LinearBVHNode], list[Primitive]]:
    """Build a BVH top-down and return it in flattened form.

    Args:
        primitives: shapes to organise; the list itself is not modified.
        split_func: partition strategy, called as
            ``split_func(primitive_info, start, end, axis, bounds=..., centroid_bounds=...)``.
            It may reorder ``primitive_info[start:end]`` in place and returns the
            index at which the range is split, or ``None`` to request a leaf.

    Returns:
        ``(nodes, ordered_primitives)``: the depth-first node array and the
        primitives in the order the leaves reference them. Both are empty when
        ``primitives`` is empty.
    """
    # Nothing to build; without this guard the split functions fail on an empty range.
    if len(primitives) == 0:
        return [], []

    primitive_info = [BVHPrimitiveInfo(i, p.bounds) for i, p in enumerate(primitives)]
    total_nodes = 0
    ordered_primitives: list[Primitive] = []

    def make_leaf(node, bounds, start, end):
        # Append the range's primitives to the ordered list and point the leaf at them.
        offset = len(ordered_primitives)
        ordered_primitives.extend(primitives[primitive_info[i].index] for i in range(start, end))
        node.init_leaf(bounds=bounds, first_prim_offset=offset, num_primitives=end - start)
        return node

    def build(start, end):
        nonlocal total_nodes

        node = BVHBuildNode()
        total_nodes += 1

        # Compute bounds of all primitives in BVH node
        bounds = compute_bounds(primitive_info[start:end])

        if end - start == 1:
            return make_leaf(node, bounds, start, end)

        # Compute bound of primitive centroids, choose split dimension
        centroid_bounds = Bounds.empty()
        for i in range(start, end):
            centroid_bounds = Bounds.union(centroid_bounds, primitive_info[i].centroid)

        extent = centroid_bounds.extent
        axis = int(np.argmax(extent))  # plain int, not np.int64, for the stored split axis

        # All centroids coincide: no position can separate them.
        if np.isclose(extent[axis], 0.0):
            return make_leaf(node, bounds, start, end)

        # Partition primitives into two sets and build children
        mid = split_func(primitive_info, start, end, axis, bounds=bounds, centroid_bounds=centroid_bounds)

        # A split that leaves one side empty would recurse on the same range forever.
        if mid is None or not (start < mid < end):
            return make_leaf(node, bounds, start, end)

        node.init_interior(axis=axis, left=build(start, mid), right=build(mid, end))
        return node

    root = build(0, len(primitives))

    # Compute representation of depth-first traversal of BVH tree
    context = FlattenBVHTreeContext(nodes=[LinearBVHNode() for _ in range(total_nodes)])
    flatten_bvh_tree(root, context)
    return context.nodes, ordered_primitives


def middle_split(primitive_info: list[BVHPrimitiveInfo], start: int, end: int, axis: int, **kwargs) -> int:
    """Split ``primitive_info[start:end]`` at the midpoint of the centroid bounds.

    Primitives whose centroid lies below the midpoint along ``axis`` go left.
    When that leaves one side empty, the range is instead split into two
    equally sized halves ordered by centroid.

    Args:
        primitive_info: build records; the range is reordered in place.
        start, end: half-open range to split.
        axis: coordinate axis to split along.
        **kwargs: must contain ``centroid_bounds`` (bounds of the range's centroids).

    Returns:
        Index of the first element of the right half.
    """
    centroid_bounds = kwargs["centroid_bounds"]

    # Evaluate the midpoint once; ``center`` is a computed property and would
    # otherwise be rebuilt for every element tested by the predicate.
    mid_pos = centroid_bounds.center[axis]

    # Partition primitives through node's midpoint
    mid = partition(primitive_info, start=start, end=end, pred=lambda pi: pi.centroid[axis] < mid_pos)
    if mid == start or mid == end:
        # Partition primitives into equally sized subsets
        mid = (start + end) // 2
        nth_element(primitive_info, mid, lo=start, hi=end - 1, key=lambda pi: pi.centroid[axis])

    return mid
コード
# Build the midpoint-split BVH over the shared scene.
basic_bvh_nodes, basic_bvh_primitives = build_bvh(primitives, split_func=middle_split)
bvh_variant = BVHVariant(name="Basic BVH", nodes=basic_bvh_nodes, primitives=basic_bvh_primitives, filename="basic_bvh.png")
# Collects every variant built in this notebook; later sections append to it.
bvh_variants = [bvh_variant]

2.2.1 木構造

コード
import anytree
import anytree.exporter
import IPython.display as display


def convert_to_anytree(nodes, ordered_primitives, parent=None):
    """Mirror a flattened BVH as an ``anytree`` tree for rendering.

    Leaves are labelled with the names of their primitives (one per line),
    interior nodes with the lower and upper corner of their bounds.

    Args:
        nodes: flattened BVH; index 0 is the root.
        ordered_primitives: primitive list the leaves index into.
        parent: optional ``anytree`` node to attach the BVH root to.

    Returns:
        The ``anytree`` node created for the BVH root, or ``None`` when
        ``nodes`` is empty.
    """
    if len(nodes) == 0:
        return None

    root = None
    # Seed the traversal with the caller's parent; it used to be discarded.
    stack = [(0, parent)]

    while stack:
        ind, at_parent = stack.pop()
        node = nodes[ind]

        if node.is_leaf():
            prims = node.primitives(ordered_primitives)
            name = "\n".join(f"{p.name}" for p in prims)
            at_node = anytree.Node(f"{name}", parent=at_parent)
        else:
            lower, upper = node.bounds.lower, node.bounds.upper
            name = f"LB({lower[0]:.2f}, {lower[1]:.2f})\nUB({upper[0]:.2f}, {upper[1]:.2f})"
            at_node = anytree.Node(name, parent=at_parent)
            l_ind, r_ind = node.child_indices(ind)
            # Push right first so the left child is popped (and attached) first.
            stack.append((r_ind, at_node))
            stack.append((l_ind, at_node))

        # Identify the root by its array index so this also works with a parent.
        if ind == 0:
            root = at_node

    return root
コード
# Convert the most recently built variant into an anytree hierarchy.
at_root = convert_to_anytree(bvh_variant.nodes, bvh_variant.primitives)

# graphviz needs to be installed for the next line!
# Render the tree to the variant's image file, then show that image inline.
anytree.exporter.UniqueDotExporter(at_root, options=[f'label="{bvh_variant.name}"', 'labelloc="t"', "fontsize=24"]).to_picture(bvh_variant.filename)
display.display(display.Image(bvh_variant.filename, width=400))

2.2.2 空間構造

コード
import matplotlib.pyplot as plt

# print(plt.style.available)
plt.style.use("seaborn-v0_8")


def plot_bounds(bounds, ax, **kwargs):
    """Draw ``bounds`` on ``ax`` as an unfilled dashed rectangle.

    Args:
        bounds: box to draw; only its first two axes are used.
        ax: Matplotlib axes to draw on.
        **kwargs: patch properties overriding the defaults below.
    """
    defaults = dict(fill=False, edgecolor="black", linestyle="--", linewidth=1, alpha=0.5)
    defaults.update(kwargs)
    extent = bounds.extent
    rect = plt.Rectangle(bounds.lower, extent[0], extent[1], **defaults)
    ax.add_patch(rect)
    # Extend the data limits with the box's real corners. Rectangle.get_path()
    # is the untransformed unit rectangle, so passing its vertices always
    # pulled (0, 0)-(1, 1) into the view regardless of where the box lies.
    ax.update_datalim(np.vstack([bounds.lower, bounds.upper]))
    ax.autoscale_view()


def plot_triangle(tri, ax, **kwargs):
    """Draw ``tri`` on ``ax`` as a closed polygon.

    ``kwargs`` are patch properties that override the default black edge and
    50 % opacity.
    """
    style = {"edgecolor": "black", "alpha": 0.5, **kwargs}
    patch = plt.Polygon(tri.vertices(), closed=True, **style)
    ax.add_patch(patch)
    # Grow the data limits to the polygon's vertices and rescale the view.
    ax.update_datalim(patch.get_path().vertices)
    ax.autoscale_view()


def plot_bvh(nodes, ordered_primitives, ax):
    """Draw a flattened BVH: every node's box, and each leaf's triangles.

    Leaves are visited depth-first, left child first, and each one receives the
    next colour of the ``tab20`` colour map (wrapping around) for both its box
    and its triangles. Interior boxes use the default box style.
    """
    palette = plt.colormaps["tab20"]
    leaves_seen = 0
    pending = [0]

    while pending:
        index = pending.pop()
        current = nodes[index]

        if not current.is_leaf():
            plot_bounds(current.bounds, ax)

            left, right = current.child_indices(index)
            # Right goes on the stack first so that left is processed next.
            pending.extend((right, left))
            continue

        color = palette(leaves_seen % palette.N)
        leaves_seen += 1
        plot_bounds(current.bounds, ax, edgecolor=color)

        for prim in current.primitives(ordered_primitives):
            plot_triangle(prim, ax, facecolor=color)
            label_pos = prim.centroid
            ax.text(label_pos[0], label_pos[1], prim.name, ha="center", va="center", fontsize=8, color="black")
コード
# One square figure per variant; constrained layout keeps the title inside the canvas.
fig, ax = plt.subplots(figsize=(5, 5), constrained_layout=True)

# Draw the most recently built variant.
plot_bvh(bvh_variant.nodes, bvh_variant.primitives, ax)

# Styling and layout
ax.set_title(bvh_variant.name)
# Equal aspect so boxes and triangles are not distorted.
ax.set_aspect("equal", adjustable="box")
ax.grid(True)

plt.show()

2.3 SAH BVH

要点

  1. 適当な数のバケットを用意する
  2. プリミティブの重心に基づいて,それらをバケットに振り分ける
  3. 各分割位置に対してコストを計算する
  4. 最もコストの低い位置でプリミティブを二分する

評価は次式で行う.

\[ c(A,B)=t_{\text{trav}}+p_{A}\sum_{i=1}^{N_{A}} t_{\text{isect}}(a_{i})+p_{B}\sum_{i=1}^{N_{B}} t_{\text{isect}}(b_{i}) \]

ここで,

  • \(t_{\text{trav}}\) は内部ノードを走査し,光線がどの子ノードを通過するかを決定するのにかかる時間
  • \(p_{A}, p_{B}\) は光線が各子ノードを通過する確率
  • \(N_{A}, N_{B}\) はそれぞれの子ノードの領域に重なり合うプリミティブの数
  • \(t_{\text{isect}}(i)\) は光線と \(i\) 番目のプリミティブとの交差を計算する時間
コード
@dataclass(slots=True)
class BucketInfo:
    bounds: Bounds | None = None
    count: int = 0

    def __post_init__(self) -> None:
        if self.bounds is None:
            self.bounds = Bounds.empty()


def evaluate_sah_cost(sa: float, l_sa: float, l_count: int, r_sa: float, r_count: int, *, traversal_cost: float = 0.125) -> float:
    """Surface-area-heuristic cost of splitting a node into two children.

    The probability of a ray entering a child is approximated by the ratio of
    the child's surface area to the parent's, and every primitive intersection
    is charged a cost of 1.

    Args:
        sa: surface area of the parent node.
        l_sa, r_sa: surface areas of the left and right child.
        l_count, r_count: number of primitives in each child.
        traversal_cost: cost of visiting the interior node, relative to one
            primitive intersection. The default keeps the former constant.

    Returns:
        ``traversal_cost + (l_sa * l_count + r_sa * r_count) / sa``.
    """
    weighted_intersections = l_sa * l_count + r_sa * r_count
    return traversal_cost + weighted_intersections / sa


def compute_sah_costs(sa: float, buckets: list[BucketInfo]) -> np.ndarray:
    """Return the SAH cost of splitting after each bucket except the last.

    Entry ``i`` is the cost of putting buckets ``0..i`` in the left child and
    the remaining buckets in the right child. A split that leaves either side
    without items costs ``inf``.

    Args:
        sa: surface area of the node being split.
        buckets: per-bucket bounds and counts along the split axis.

    Returns:
        Array of ``len(buckets) - 1`` costs (empty for fewer than two buckets).
    """
    num_splits = len(buckets) - 1
    if num_splits <= 0:
        return np.empty(0)

    # Backward sweep: bounds and count of everything to the right of each split.
    # Min/max unions are exact, so the sweep direction does not change the boxes.
    right_bounds: list[Bounds] = [Bounds.empty()] * num_splits
    right_counts = [0] * num_splits
    r_bounds, r_count = Bounds.empty(), 0
    for i in range(num_splits - 1, -1, -1):
        bucket = buckets[i + 1]
        r_bounds = Bounds.union(r_bounds, bucket.bounds)
        r_count += bucket.count
        right_bounds[i] = r_bounds
        right_counts[i] = r_count

    # Forward sweep: grow the left side one bucket at a time and price each split.
    costs = np.empty(num_splits)
    l_bounds, l_count = Bounds.empty(), 0
    for i in range(num_splits):
        bucket = buckets[i]
        l_bounds = Bounds.union(l_bounds, bucket.bounds)
        l_count += bucket.count

        if l_count == 0 or right_counts[i] == 0:
            costs[i] = np.inf
        else:
            costs[i] = evaluate_sah_cost(sa, l_bounds.surface_area, l_count, right_bounds[i].surface_area, right_counts[i])

    return costs
コード
# Leaf-size limit shared by the builders: sah_split may only return a leaf for
# ranges of at most this many primitives, and emit_lbvh makes a leaf for ranges
# with fewer primitives than this.
# NOTE(review): the "<= 255" bound presumably mirrors an 8-bit primitive counter
# in a compact node layout; confirm before raising the value.
MAX_PRIMS_IN_NODE = 5  # <= 255


def sah_split(primitive_info: list[BVHPrimitiveInfo], start: int, end: int, axis: int, **kwargs) -> int | None:
    """Split ``primitive_info[start:end]`` using the bucketed surface area heuristic.

    Ranges of four or fewer primitives are halved by centroid. Larger ranges
    are binned into 12 buckets along ``axis``; the cheapest bucket boundary is
    chosen and the range is partitioned by bucket membership.

    Args:
        primitive_info: build records; the range is reordered in place.
        start, end: half-open range to split.
        axis: coordinate axis to split along.
        **kwargs: must contain ``bounds`` (bounds of the range's primitives) and
            ``centroid_bounds`` (bounds of the range's centroids).

    Returns:
        Index of the first element of the right part, or ``None`` when the
        range should become a leaf.
    """
    bounds = kwargs["bounds"]
    centroid_bounds = kwargs["centroid_bounds"]

    num_primitives = end - start

    if num_primitives <= 4:
        # Partition primitives into equally sized subsets
        mid = (start + end) // 2
        nth_element(primitive_info, mid, lo=start, hi=end - 1, key=lambda pi: pi.centroid[axis])
        return mid

    NUM_BUCKETS = 12
    axis_lower = centroid_bounds.lower[axis]
    axis_extent = centroid_bounds.extent[axis]

    def bucket_index(pi: BVHPrimitiveInfo) -> int:
        # Map the centroid to [0, 1] along the axis, then to a bucket; the upper
        # boundary itself belongs to the last bucket.
        pos = (pi.centroid[axis] - axis_lower) / axis_extent
        return min(NUM_BUCKETS - 1, int(pos * NUM_BUCKETS))

    # Allocate and fill BucketInfo for SAH partition buckets
    buckets = [BucketInfo() for _ in range(NUM_BUCKETS)]
    for i in range(start, end):
        pi = primitive_info[i]
        bucket = buckets[bucket_index(pi)]
        bucket.bounds = Bounds.union(bucket.bounds, pi.bounds)
        bucket.count += 1

    # Compute costs for splitting after each bucket
    costs = compute_sah_costs(bounds.surface_area, buckets)

    # Find bucket to split at that minimizes SAH metric
    min_cost, min_cost_split = np.inf, 0

    for i in range(NUM_BUCKETS - 1):
        if costs[i] < min_cost:
            min_cost = costs[i]  # type: ignore
            min_cost_split = i

    # Either create leaf or split primitives at selected SAH bucket
    if np.isinf(min_cost) or (num_primitives <= MAX_PRIMS_IN_NODE and min_cost >= num_primitives):
        return None  # Create leaf BVHBuildNode

    # Partition with the same bucket function that produced the counts. A
    # recomputed floating-point split position can round differently from the
    # integer bucket assignment and move primitives to the other side, in the
    # worst case leaving one side empty.
    return partition(primitive_info, start=start, end=end, pred=lambda pi: bucket_index(pi) <= min_cost_split)
コード
# Build the SAH variant over the same scene and register it with the others.
sah_bvh_nodes, sah_bvh_primitives = build_bvh(primitives, split_func=sah_split)
bvh_variant = BVHVariant(name="SAH BVH", nodes=sah_bvh_nodes, primitives=sah_bvh_primitives, filename="sah_bvh.png")
bvh_variants.append(bvh_variant)

2.3.1 木構造

コード
# Convert the most recently built variant into an anytree hierarchy.
at_root = convert_to_anytree(bvh_variant.nodes, bvh_variant.primitives)

# graphviz needs to be installed for the next line!
# Render the tree to the variant's image file, then show that image inline.
anytree.exporter.UniqueDotExporter(at_root, options=[f'label="{bvh_variant.name}"', 'labelloc="t"', "fontsize=24"]).to_picture(bvh_variant.filename)
display.display(display.Image(bvh_variant.filename, width=400))

2.3.2 空間構造

コード
# One square figure per variant; constrained layout keeps the title inside the canvas.
fig, ax = plt.subplots(figsize=(5, 5), constrained_layout=True)

# Draw the most recently built variant.
plot_bvh(bvh_variant.nodes, bvh_variant.primitives, ax)

# Styling and layout
ax.set_title(bvh_variant.name)
# Equal aspect so boxes and triangles are not distorted.
ax.set_aspect("equal", adjustable="box")
ax.grid(True)

plt.show()

2.4 LBVH

ここでは,20bits のモートン符号を使用する.

要点

  1. プリミティブの重心からなる点群に対して,軸平行境界ボックス(AABB)を作成する
  2. 各プリミティブの重心を AABB の範囲で正規化する
  3. 正規化した重心を,モートン符号のビット数に応じてスケーリングする
  4. 各重心にモートン符号を割り当てる
  5. モートン符号に基づいてプリミティブを基数ソートする
  6. モートン符号に従ってプリミティブを二分する
コード
@dataclass(slots=True)
class MortonPrimitive:
    primitive_index: int
    morton_code: np.uint32


@dataclass(slots=True)
class LBVHTreelet:
    start_index: int
    num_primitives: int
    build_nodes: list[BVHBuildNode]


def interleave_bits(x: np.ndarray) -> np.ndarray:
    """Spread the low 16 bits of each value so that bit ``i`` moves to bit ``2 * i``.

    The input is converted to ``uint32`` first; the odd result bits are zero,
    leaving room for a second coordinate.
    """
    spread = np.asarray(x, dtype=np.uint32)
    # Each step doubles the gap between bit groups, then masks off the overlap.
    steps = ((8, 0x00FF00FF), (4, 0x0F0F0F0F), (2, 0x33333333), (1, 0x55555555))
    for shift, mask in steps:
        spread = (spread | (spread << shift)) & np.uint32(mask)
    return spread


def encode_morton(v: np.ndarray) -> np.ndarray:
    """Return the 2-D Morton code of the coordinates on the last axis of ``v``.

    Component 0 fills the even bits and component 1 the odd bits.
    """
    spread = interleave_bits(np.asarray(v))
    even_bits = np.take(spread, 0, axis=-1)
    odd_bits = np.take(spread, 1, axis=-1)
    return even_bits | (odd_bits << 1)

基数ソートの実装

  • numba での部分最適化が必須
  • In-place は若干分が悪い
  • Python なら Timsort で妥協するのもあり
コード
from numba import njit


@njit
def radix_sort_core(keys: np.ndarray, indices: np.ndarray, bits_per_pass: int, num_bits: int) -> None:
    """Least-significant-digit radix sort of ``indices`` by ``keys[indices]``.

    ``keys`` is left untouched; ``indices`` is overwritten so that reading
    ``keys`` through it yields ascending order with respect to the lowest
    ``num_bits`` bits (rounded up to whole passes). Items with equal digits
    keep their relative order in every pass.

    Args:
        keys: integer sort keys, addressed by the values in ``indices``.
        indices: positions into ``keys``; permuted in place.
        bits_per_pass: digit width of each counting-sort pass.
        num_bits: number of low key bits that take part in the ordering.
    """
    n = len(keys)
    num_passes = (num_bits + bits_per_pass - 1) // bits_per_pass
    num_buckets = 1 << bits_per_pass
    bit_mask = num_buckets - 1

    temp_indices = np.empty(n, dtype=np.int32)
    count = np.empty(num_buckets, dtype=np.int32)

    for pass_num in range(num_passes):
        low_bit = pass_num * bits_per_pass
        # Ping-pong between the caller's array and the scratch array.
        src, dst = (indices, temp_indices) if pass_num % 2 == 0 else (temp_indices, indices)

        count.fill(0)

        # Histogram of the current digit.
        for ind in src:
            key = keys[ind]
            bucket = (key >> low_bit) & bit_mask
            count[bucket] += 1

        # Inclusive prefix sum: count[b] becomes one past the last slot of bucket b.
        for i in range(1, num_buckets):
            count[i] += count[i - 1]

        # Scatter back to front so equal digits keep their relative order.
        for ind in src[::-1]:
            key = keys[ind]
            bucket = (key >> low_bit) & bit_mask
            count[bucket] -= 1
            dst[count[bucket]] = ind

    # After an odd number of passes the result sits in the scratch array.
    if num_passes % 2 == 1:
        indices[:] = temp_indices


def apply_permutation(lst: list[Any], indices: np.ndarray) -> None:
    """Reorder ``lst`` in place so that afterwards ``lst[i]`` is the old ``lst[indices[i]]``.

    Works cycle by cycle with swaps, so no copy of ``lst`` is made.
    ``indices`` is only read: it may be any integer sequence (signed or
    unsigned array, or a plain list) and is left unchanged.

    Args:
        lst: list to reorder; must have at least ``len(indices)`` items.
        indices: permutation of ``range(len(indices))``.

    Raises:
        ValueError: when ``indices`` is not a permutation.
    """
    n = len(indices)
    # Visited flags live in a side buffer instead of being encoded into
    # ``indices``, which would corrupt the caller's array and break on
    # unsigned dtypes.
    visited = bytearray(n)

    for start in range(n):
        if visited[start]:
            continue

        # Walk the cycle through ``start``; each swap drops the correct item
        # into slot ``j`` and carries the displaced one along.
        j = start
        while True:
            visited[j] = 1
            k = int(indices[j])
            if k == start:
                break
            if not 0 <= k < n or visited[k]:
                raise ValueError("indices is not a permutation")
            lst[j], lst[k] = lst[k], lst[j]
            j = k


def radix_sort(lst: list[T], *, key: Callable[[T], object] = lambda x: x, bits_per_pass: int = 6, num_bits: int = 30, out: list[T] | None = None) -> list[T]:
    """Sort ``lst`` by unsigned-integer keys using a radix sort.

    Args:
        lst: items to sort.
        key: maps an item to its key; keys are stored as ``uint32``.
        bits_per_pass: digit width of each pass.
        num_bits: number of low key bits that take part in the ordering.
        out: destination list. ``None`` returns a new list (or ``lst`` itself
            when it has at most one item); passing ``lst`` sorts in place.

    Returns:
        The sorted list (``out`` when one was given).
    """
    count = len(lst)

    # Zero or one item: already sorted, only the destination needs filling.
    if count <= 1:
        if out is None:
            return lst
        if out is not lst:
            out[:] = lst
        return out

    keys = np.fromiter(map(key, lst), dtype=np.uint32, count=count)
    order = np.arange(count, dtype=np.int32)

    radix_sort_core(keys, order, bits_per_pass, num_bits)

    # Reordering of lst using the sorted index array
    if out is None:
        return [lst[i] for i in order]

    if out is not lst:
        out[:] = lst

    apply_permutation(out, order)
    return out
コード
@dataclass(slots=True)
class EmitLBVHContext:
    primitives: list[Primitive]
    build_node_index: int = 0
    total_nodes: int = 0
    ordered_prims_offset: int = 0


def emit_lbvh(
    build_nodes: list[BVHBuildNode],
    primitive_info: list[BVHPrimitiveInfo],
    morton_primitives: list[MortonPrimitive],
    start: int,
    end: int,
    ordered_primitives: list[Primitive],
    bit_index: int,
    context: EmitLBVHContext,
) -> BVHBuildNode:
    """Emit the LBVH subtree for ``morton_primitives[start:end]``.

    The range must be sorted by Morton code. It is split where the bit
    ``bit_index`` changes; bits on which the whole range agrees are skipped.
    A leaf is produced when the bits are exhausted or the range holds fewer
    than ``MAX_PRIMS_IN_NODE`` primitives.

    Args:
        build_nodes: pre-allocated node pool; nodes are taken in emission order.
        primitive_info: build records, addressed by original primitive index.
        morton_primitives: Morton-sorted (index, code) pairs.
        start, end: half-open range to process.
        ordered_primitives: output list that leaf primitives are written into.
        bit_index: Morton bit to split on; ``-1`` means no bits are left.
        context: running counters shared by the whole recursion.

    Returns:
        The root node of the emitted subtree.
    """
    count = end - start
    large_enough = count >= MAX_PRIMS_IN_NODE

    # Advance to next subtree level while there's no LBVH split for this bit
    mask = None
    while large_enough and bit_index != -1:
        mask = np.uint32(1 << bit_index)
        if (morton_primitives[start].morton_code & mask) != (morton_primitives[end - 1].morton_code & mask):
            break
        bit_index -= 1

    if bit_index == -1 or not large_enough:
        # Create and return leaf node of LBVH treelet
        leaf = build_nodes[context.build_node_index]
        context.build_node_index += 1
        context.total_nodes += 1

        base = context.ordered_prims_offset
        context.ordered_prims_offset = base + count

        leaf_bounds = Bounds.empty()
        for k in range(count):
            prim_index = morton_primitives[start + k].primitive_index
            ordered_primitives[base + k] = context.primitives[prim_index]
            leaf_bounds = Bounds.union(leaf_bounds, primitive_info[prim_index].bounds)

        leaf.init_leaf(bounds=leaf_bounds, first_prim_offset=base, num_primitives=count)
        return leaf

    # Find LBVH split point for this dimension: binary search for the first
    # entry whose bit differs from the bit of the range's first entry.
    first_bit = morton_primitives[start].morton_code & mask
    same_side, other_side = start, end
    while same_side + 1 != other_side:
        probe = (same_side + other_side) // 2
        if (morton_primitives[probe].morton_code & mask) == first_bit:
            same_side = probe
        else:
            other_side = probe

    split_offset = other_side

    # Create and return interior LBVH node (claimed before its children, so
    # the pool is filled in pre-order)
    interior = build_nodes[context.build_node_index]
    context.build_node_index += 1
    context.total_nodes += 1

    next_bit = bit_index - 1
    left = emit_lbvh(build_nodes, primitive_info, morton_primitives, start, split_offset, ordered_primitives, next_bit, context)
    right = emit_lbvh(build_nodes, primitive_info, morton_primitives, split_offset, end, ordered_primitives, next_bit, context)
    interior.init_interior(axis=bit_index % 2, left=left, right=right)
    return interior


def build_lbvh(primitives: list[Primitive]) -> tuple[list[LinearBVHNode], list[Primitive]]:
    """Build a linear BVH by sorting primitives along a 2-D Morton curve.

    Centroids are normalised to the bounding box of all centroids, quantised
    to 10 bits per axis, Morton-encoded, radix-sorted, and split recursively on
    the code bits from the most significant one down.

    Args:
        primitives: shapes to organise; the list itself is not modified.

    Returns:
        ``(nodes, ordered_primitives)``: the depth-first node array and the
        primitives in the order the leaves reference them. Both are empty when
        ``primitives`` is empty.
    """
    MORTON_BITS = 10
    MORTON_SCALE = 1 << MORTON_BITS
    FIRST_BIT_INDEX = 2 * MORTON_BITS - 1  # highest bit of the interleaved code

    num_primitives = len(primitives)
    if num_primitives == 0:
        return [], []

    primitive_info = [BVHPrimitiveInfo(i, p.bounds) for i, p in enumerate(primitives)]

    # Compute bounding box of all primitive centroids. Using the centroids
    # (not the union of primitive bounds) spreads the codes over the full
    # quantisation range.
    centroid_bounds = Bounds.empty()
    for pi in primitive_info:
        centroid_bounds = Bounds.union(centroid_bounds, pi.centroid)

    # An axis on which all centroids coincide has zero extent; divide by 1
    # there so the quantised coordinate is 0 instead of NaN.
    extent = centroid_bounds.extent
    safe_extent = np.where(extent > 0.0, extent, 1.0)

    # Compute Morton indices of primitives
    morton_primitives: list[MortonPrimitive] = []
    for pi in primitive_info:
        scaled = (pi.centroid - centroid_bounds.lower) / safe_extent * MORTON_SCALE
        # The upper boundary maps to MORTON_SCALE; clamp it into the last cell.
        scaled = np.minimum(MORTON_SCALE - 1, scaled)
        morton_primitives.append(MortonPrimitive(pi.index, encode_morton(scaled).item()))

    # Radix sort primitive Morton indices
    radix_sort(morton_primitives, key=lambda p: p.morton_code, bits_per_pass=5, num_bits=2 * MORTON_BITS, out=morton_primitives)

    # Create LBVH treelet at bottom of BVH; a binary tree over n leaves has at
    # most 2n - 1 nodes.
    max_bvh_nodes = 2 * num_primitives - 1
    nodes = [BVHBuildNode() for _ in range(max_bvh_nodes)]
    treelet = LBVHTreelet(0, num_primitives, nodes)

    # Create LBVH for treelet
    ordered_primitives: list[Primitive] = [None] * num_primitives  # type: ignore
    emit_lbvh_context = EmitLBVHContext(primitives)
    emit_lbvh(treelet.build_nodes, primitive_info, morton_primitives, 0, treelet.num_primitives, ordered_primitives, FIRST_BIT_INDEX, emit_lbvh_context)

    # Only the nodes actually emitted are flattened; the pool's first node is the root.
    flatten_bvh_tree_context = FlattenBVHTreeContext(nodes=[LinearBVHNode() for _ in range(emit_lbvh_context.total_nodes)])
    flatten_bvh_tree(treelet.build_nodes[0], flatten_bvh_tree_context)
    return flatten_bvh_tree_context.nodes, ordered_primitives
コード
# Build the Morton-code variant over the same scene and register it with the others.
lbvh_nodes, lbvh_primitives = build_lbvh(primitives)
bvh_variant = BVHVariant(name="LBVH", nodes=lbvh_nodes, primitives=lbvh_primitives, filename="lbvh.png")
bvh_variants.append(bvh_variant)

2.4.1 木構造

コード
# Convert the most recently built variant into an anytree hierarchy.
at_root = convert_to_anytree(bvh_variant.nodes, bvh_variant.primitives)

# graphviz needs to be installed for the next line!
# Render the tree to the variant's image file, then show that image inline.
anytree.exporter.UniqueDotExporter(at_root, options=[f'label="{bvh_variant.name}"', 'labelloc="t"', "fontsize=24"]).to_picture(bvh_variant.filename)
display.display(display.Image(bvh_variant.filename, width=400))

2.4.2 空間構造

コード
# One square figure per variant; constrained layout keeps the title inside the canvas.
fig, ax = plt.subplots(figsize=(5, 5), constrained_layout=True)

# Draw the most recently built variant.
plot_bvh(bvh_variant.nodes, bvh_variant.primitives, ax)

# Styling and layout
ax.set_title(bvh_variant.name)
# Equal aspect so boxes and triangles are not distorted.
ax.set_aspect("equal", adjustable="box")
ax.grid(True)

plt.show()

2.5 HLBVH

要点

  1. プリミティブの重心からなる点群に対して,軸平行境界ボックス(AABB)を作成する
  2. 各プリミティブの重心を AABB の範囲で正規化する
  3. 正規化した重心を,モートン符号のビット数に応じてスケーリングする
  4. 各重心にモートン符号を割り当てる
  5. モートン符号に基づいてプリミティブを基数ソートする
  6. 任意の区画毎に Treelet を作成する
  7. 各 Treelet から LBVH を作成する
  8. 各 LBVH を SAH ベースで統合し,最終的な BVH を作成する
コード
@dataclass(slots=True)
class BuildUpperSAHContext:
    """Mutable state shared by the recursive `build_upper_sah` calls."""

    # Running node count; `build_upper_sah` adds 1 for every interior node it creates.
    total_nodes: int = 0


def build_upper_sah(treelet_roots: list[BVHBuildNode], start: int, end: int, context: BuildUpperSAHContext) -> BVHBuildNode:
    """Join the treelet roots in ``treelet_roots[start:end]`` into one tree using a bucketed SAH split.

    The range is split recursively: the roots are binned into buckets along the
    axis with the largest centroid extent, the cheapest bucket boundary
    according to `compute_sah_costs` is chosen, and the range is partitioned in
    place around it.

    Args:
        treelet_roots: Root nodes to join; reordered in place by `partition`.
        start: First index of the range (inclusive).
        end: End index of the range (exclusive).
        context: Counter object; ``total_nodes`` is incremented once per
            interior node created here.

    Returns:
        The single root when the range holds one node, otherwise a new interior
        node whose children cover the two halves of the range.

    Raises:
        ValueError: If the range is empty.
    """
    num_nodes = end - start
    if num_nodes <= 0:
        # Without this guard an empty range would recurse on itself forever.
        raise ValueError(f"Empty treelet range: start={start}, end={end}")
    if num_nodes == 1:
        return treelet_roots[start]

    node = BVHBuildNode()
    context.total_nodes += 1

    # Compute bounds of all nodes under this HLBVH node, and the bounds of
    # their centroids, in a single pass
    bounds = Bounds.empty()
    centroid_bounds = Bounds.empty()
    for i in range(start, end):
        root_bounds = treelet_roots[i].bounds
        bounds = Bounds.union(bounds, root_bounds)
        centroid_bounds = Bounds.union(centroid_bounds, root_bounds.center)

    # Choose split dimension: the axis along which the centroids spread the most
    extent = centroid_bounds.extent
    axis = int(np.argmax(extent))

    # Median split; used when the SAH split is impossible or degenerate
    mid = (start + end) // 2

    # A zero extent means every centroid coincides; bucketing would divide by
    # zero, so the median split above is kept in that case.
    if extent[axis] > 0:
        NUM_BUCKETS = 12
        lower = centroid_bounds.lower[axis]

        def bucket_index(root: BVHBuildNode) -> int:
            # Relative centroid position in [0, 1], mapped to a bucket; the
            # maximum centroid (position 1.0) is clamped into the last bucket.
            pos = (root.bounds.center[axis] - lower) / extent[axis]
            return min(NUM_BUCKETS - 1, int(pos * NUM_BUCKETS))

        # Initialize _BucketInfo_ for HLBVH SAH partition buckets
        buckets = [BucketInfo() for _ in range(NUM_BUCKETS)]
        for i in range(start, end):
            bucket = buckets[bucket_index(treelet_roots[i])]
            bucket.bounds = Bounds.union(bucket.bounds, treelet_roots[i].bounds)
            bucket.count += 1

        # Compute costs for splitting after each bucket
        costs = compute_sah_costs(bounds.surface_area, buckets)

        # Find bucket to split at that minimizes SAH metric (first minimum wins)
        min_cost, min_cost_split = np.inf, 0
        for i in range(NUM_BUCKETS - 1):
            if costs[i] < min_cost:
                min_cost = costs[i]  # type: ignore
                min_cost_split = i

        # Partition on the same bucket index that was used for binning, so the
        # resulting split is exactly the one whose cost was evaluated (a
        # recomputed floating-point split position can disagree at boundaries).
        split = partition(treelet_roots, start=start, end=end, pred=lambda tr: bucket_index(tr) <= min_cost_split)
        if start < split < end:
            mid = split

    # Create interior HLBVH SAH node; left subtree is built before the right one
    left = build_upper_sah(treelet_roots, start, mid, context)
    right = build_upper_sah(treelet_roots, mid, end, context)
    node.init_interior(axis=axis, left=left, right=right)
    return node


def build_hlbvh(primitives: list[Primitive]) -> tuple[list[LinearBVHNode], list[Primitive]]:
    """Build an HLBVH over `primitives` and return it in flattened form.

    Steps: assign a Morton code to every primitive centroid, radix-sort by that
    code, group runs that share the same high Morton bits into treelets, build
    an LBVH per treelet with `emit_lbvh`, join the treelet roots with
    `build_upper_sah`, and flatten the result with `flatten_bvh_tree`.

    Args:
        primitives: Primitives to organise; each must expose ``bounds``.

    Returns:
        ``(nodes, ordered_primitives)``: the flattened node list and a list of
        the same length as `primitives` that `emit_lbvh` fills in.
        Both are empty when `primitives` is empty.
    """
    MORTON_BITS = 10
    MORTON_SCALE = 1 << MORTON_BITS
    FIRST_BIT_INDEX = 15

    # Nothing to build; the recursive upper-level builder needs at least one treelet.
    if len(primitives) == 0:
        return [], []

    primitive_info = [BVHPrimitiveInfo(i, primitive.bounds) for i, primitive in enumerate(primitives)]

    # Compute bounding box of all primitive centroids
    bounds = compute_bounds(primitive_info)
    extent = bounds.extent

    # An axis with zero extent (single primitive, or all centroids aligned on
    # that axis) would divide by zero and feed NaN into the Morton encoding.
    # Dividing by 1 instead maps every centroid to coordinate 0 on that axis.
    safe_extent = extent.copy()
    safe_extent[~(extent > 0)] = 1.0

    # Compute Morton indices of primitives
    morton_primitives = []
    for pi in primitive_info:
        scaled = (pi.centroid - bounds.lower) / safe_extent * MORTON_SCALE
        # A centroid on the upper bound scales to MORTON_SCALE; clamp it into range
        scaled = np.minimum(MORTON_SCALE - 1, scaled)
        morton_primitives.append(MortonPrimitive(pi.index, encode_morton(scaled).item()))

    # Radix sort primitive Morton indices
    radix_sort(morton_primitives, key=lambda p: p.morton_code, bits_per_pass=5, num_bits=20, out=morton_primitives)

    # Create LBVH treelets at bottom of BVH.
    # The mask selects bits FIRST_BIT_INDEX + 1 .. 19 of the Morton code;
    # consecutive primitives that agree on those bits end up in the same treelet.
    treelets_to_build = []
    mask = np.uint32(((1 << (19 - FIRST_BIT_INDEX)) - 1) << (FIRST_BIT_INDEX + 1))

    start, end = 0, 1
    while end <= len(morton_primitives):
        if end == len(morton_primitives) or ((morton_primitives[start].morton_code & mask) != (morton_primitives[end].morton_code & mask)):
            num_primitives = end - start
            # Pre-allocate 2n - 1 build nodes for a treelet of n primitives
            max_bvh_nodes = 2 * num_primitives - 1
            nodes = [BVHBuildNode() for _ in range(max_bvh_nodes)]
            treelets_to_build.append(LBVHTreelet(start, num_primitives, nodes))
            start = end

        end += 1

    # Create LBVHs for treelets
    total_nodes = 0
    ordered_primitives: list[Primitive] = [None] * len(primitives)  # type: ignore
    context = EmitLBVHContext(primitives)

    for treelet in treelets_to_build:
        # The context is reused, so its per-treelet counters are reset each time
        context.build_node_index = 0
        context.total_nodes = 0
        emit_lbvh(treelet.build_nodes, primitive_info, morton_primitives, treelet.start_index, treelet.start_index + treelet.num_primitives, ordered_primitives, FIRST_BIT_INDEX, context)
        total_nodes += context.total_nodes

    # Create and return SAH BVH from LBVH treelets.
    # The upper-level builder continues counting from the treelet node total.
    finished_treelets = [treelet.build_nodes[0] for treelet in treelets_to_build]
    build_upper_sah_context = BuildUpperSAHContext(total_nodes=total_nodes)
    root = build_upper_sah(finished_treelets, 0, len(finished_treelets), build_upper_sah_context)
    total_nodes = build_upper_sah_context.total_nodes

    flatten_bvh_tree_context = FlattenBVHTreeContext(nodes=[LinearBVHNode() for _ in range(total_nodes)])
    flatten_bvh_tree(root, flatten_bvh_tree_context)
    return flatten_bvh_tree_context.nodes, ordered_primitives
コード
# Build the HLBVH from `primitives` and register it in `bvh_variants`,
# which the comparison cells iterate over.
hlbvh_nodes, hlbvh_primitives = build_hlbvh(primitives)
bvh_variant = BVHVariant(
    name="HLBVH",
    nodes=hlbvh_nodes,
    primitives=hlbvh_primitives,
    filename="hlbvh.png",
)
bvh_variants.append(bvh_variant)

2.5.1 木構造

コード
# Turn the current variant's node list into an anytree root via convert_to_anytree.
at_root = convert_to_anytree(bvh_variant.nodes, bvh_variant.primitives)

# graphviz needs to be installed for the next line!
dot_options = [f'label="{bvh_variant.name}"', 'labelloc="t"', "fontsize=24"]
anytree.exporter.UniqueDotExporter(at_root, options=dot_options).to_picture(bvh_variant.filename)

# Show the picture that was just written to disk
display.display(display.Image(bvh_variant.filename, width=400))

2.5.2 空間構造

コード
# One square figure showing the current variant through plot_bvh.
fig, ax = plt.subplots(figsize=(5, 5), constrained_layout=True)
plot_bvh(bvh_variant.nodes, bvh_variant.primitives, ax)

# Styling and layout: title, background grid, equal scaling on both axes
ax.set_title(bvh_variant.name)
ax.grid(True)
ax.set_aspect("equal", adjustable="box")

plt.show()

2.6 比較

各構築手法を比較する.

2.6.1 木構造

コード
from contextlib import ExitStack

from PIL import Image


# Paste the per-variant tree pictures side by side into "combined.png".
try:
    # Every opened image is registered on one ExitStack so its file handle is
    # released even when a later open, paste or size computation fails.
    with ExitStack() as stack:
        images = [stack.enter_context(Image.open(variant.filename)) for variant in bvh_variants]

        # Horizontal gap between pictures: 5% of the widest one
        margin = int(0.05 * max(image.width for image in images))
        width = sum(image.width for image in images) + margin * (len(images) - 1)
        height = max(image.height for image in images)

        # White canvas; pictures are placed left to right, top-aligned
        dst = Image.new("RGB", (width, height), (255, 255, 255))
        x = 0
        for image in images:
            dst.paste(image, (x, 0))
            x += image.width + margin

    dst.save("combined.png")
    display.display(display.Image("combined.png", width=800))
except Exception as e:
    # Deliberately best-effort: the input pictures are written by the graphviz
    # export cells, so they may be missing; report the problem and carry on.
    print(e)

2.6.2 空間構造

コード
num_variants = len(bvh_variants)

# One panel per variant, laid out in a single row.
# squeeze=False keeps `axes` a 2-D array for any variant count; with the default,
# a single variant yields a bare Axes object and the zip below would fail.
fig, axes = plt.subplots(nrows=1, ncols=num_variants, figsize=(5 * num_variants, 5), constrained_layout=True, squeeze=False)

for i, (ax, variant) in enumerate(zip(axes.flat, bvh_variants)):
    plot_bvh(variant.nodes, variant.primitives, ax)
    ax.set_title(variant.name)
    ax.set_aspect("equal", adjustable="box")
    ax.grid(True)
    if i != 0:
        # Only the leftmost panel keeps its y tick labels and left spine
        ax.tick_params(labelleft=False)
        ax.spines["left"].set_visible(False)

plt.show()

3 光線による BVH の走査

構築された BVH に対して光線を飛ばし,ノードを走査しながらプリミティブとの交差判定を行う.

走査について

  • スタックでアクセス
    ノードは深さ優先順に線形化されており,左の子は index + 1,右の子はあらかじめ格納された offset により特定できるため,ポインタを使わずにスタックで辿れる.
    ただし Python での実装なので,「線形化」という表現にはやや語弊がある.
  • dir_is_neg を使う理由
    光線の進行方向に応じて,より近い側のノードをスタックの後に追加することで,近い順に交差判定でき,早期終了や枝刈りの可能性が高まる.
コード
@dataclass(slots=True)
class Intersection:
    t: float
    primitive: Primitive


def intersect(nodes: list[LinearBVHNode], ordered_primitives: list[Primitive], ray: Ray) -> Intersection | None:
    isect = None
    dir_is_neg = ray.direction < 0.0
    stack = []
    stack.append(0)

    while stack:
        ind = stack.pop()
        node = nodes[ind]

        is_hit, _, _ = node.bounds.intersect(ray)
        if not is_hit:
            continue

        if node.is_leaf():
            prims = node.primitives(ordered_primitives)
            for p in prims:
                t = p.intersect(ray)
                if t is not None and (isect is None or t < isect.t):
                    isect = Intersection(t, p)
        else:
            l_ind, r_ind = node.child_indices(ind)
            if dir_is_neg[node.axis]:
                stack.append(l_ind)
                stack.append(r_ind)
            else:
                stack.append(r_ind)
                stack.append(l_ind)

    return isect
コード
# BVH variant chosen for use in subsequent traversal
bvh_variant = bvh_variants[3]

assert bvh_variant.nodes is not None
assert bvh_variant.primitives is not None
コード
# Ray from the origin along the (1, 1) diagonal, normalised to unit length
ro = np.array([0.0, 0.0])
rd = np.array([1.0, 1.0])
rd /= np.linalg.norm(rd)

# Trace it through the selected BVH and report the closest hit, if any
ray = Ray(ro, rd)
isect = intersect(bvh_variant.nodes, bvh_variant.primitives, ray)
if isect is not None:
    print(f"{isect.t=}, {isect.primitive.name=}")
isect.t=np.float64(41.20525616375596), isect.primitive.name='T24'

3.1 可視化

コード
def plot_intersection(ray, isect, ax):
    """Draw `ray` on `ax` and, when `isect` is not None, mark its hit point.

    With a hit, a dashed segment from the ray origin to the hit point and an
    "x" marker labelled "Closest Hit" are drawn first. The ray itself is always
    drawn as an arrow labelled "Ray" from its origin along its direction.
    """
    origin, direction = ray.origin, ray.direction

    if isect is not None:
        hit_point = ray.at(isect.t)
        # Stack origin and hit point, then transpose so each row holds one coordinate axis
        segment = np.vstack([origin, hit_point]).T
        ax.plot(*segment, color="crimson", linestyle="--", linewidth=1, alpha=0.25)
        ax.scatter(hit_point[0], hit_point[1], s=50, c="crimson", marker="x", linewidths=1, label="Closest Hit")

    ax.quiver(*origin, *direction, angles="xy", scale_units="xy", scale=1, width=0.003, color="crimson", label="Ray")
コード
# Show the selected BVH together with the traced ray and its closest hit.
fig, ax = plt.subplots(figsize=(6, 6), constrained_layout=True)

plot_bvh(bvh_variant.nodes, bvh_variant.primitives, ax)
plot_intersection(ray, isect, ax)

# Styling and layout
# The title is set on `ax` directly, like in the other plot cells, rather than
# through plt.title, which targets whatever the current axes happens to be.
ax.set_title(f"{bvh_variant.name}: Ray Traversal")
ax.set_aspect("equal", adjustable="box")
ax.grid(True)

# Legend is anchored just outside the right edge of the axes
ax.legend(loc="upper left", bbox_to_anchor=(1.02, 1), borderaxespad=0.0)

plt.show()

4 References

  • Pantaleoni, J., & Luebke, D. (2010, June). HLBVH: Hierarchical LBVH construction for real-time ray tracing of dynamic geometry. In Proceedings of the Conference on High Performance Graphics (HPG 2010) (pp. 87–95). Eurographics Association. Retrieved from https://dl.acm.org/doi/10.5555/1921479.1921493
  • Zellmann, S., Hellmann, M., & Lang, U. (2019, April). A linear time BVH construction algorithm for sparse volumes. In 2019 IEEE Pacific Visualization Symposium (PacificVis) (pp. 222–226). IEEE. https://doi.org/10.1109/PacificVis.2019.00033
  • Pharr, M., Jakob, W., & Humphreys, G. (2021). Physically Based Rendering: From Theory to Implementation (3rd ed.). Bounding Volume Hierarchies. https://pbr-book.org/3ed-2018/Primitives_and_Intersection_Acceleration/Bounding_Volume_Hierarchies
  • Wikipedia contributors. (2025, May 31). Quicksort. Wikipedia. https://en.wikipedia.org/w/index.php?title=Quicksort&oldid=1293231937
トップに戻る