コード
!pip install anytree --quiet
2025年7月16日
2025年8月2日
Setup for Running in Google Colab
import random
import numpy as np
import numpy.typing as npt
from abc import ABC, abstractmethod
from itertools import combinations
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Callable, Protocol, Sequence, TypeVar
@dataclass(slots=True)
class Ray:
origin: np.ndarray
direction: np.ndarray
def at(self, t: float) -> np.ndarray:
return self.origin + t * self.direction
@dataclass(slots=True)
class Bounds:
lower: np.ndarray
upper: np.ndarray
def __post_init__(self) -> None:
self.lower = np.array(self.lower)
self.upper = np.array(self.upper)
@classmethod
def empty(cls, dim: int = 2) -> Bounds:
return cls(np.full(dim, np.inf), np.full(dim, -np.inf))
@classmethod
def union(cls, lhs: npt.ArrayLike, rhs: npt.ArrayLike) -> Bounds:
if isinstance(lhs, cls) and isinstance(rhs, cls):
return cls(np.minimum(lhs.lower, rhs.lower), np.maximum(lhs.upper, rhs.upper))
if isinstance(lhs, cls) and isinstance(rhs, (np.ndarray, list, tuple)):
return cls(np.minimum(lhs.lower, rhs), np.maximum(lhs.upper, rhs))
if isinstance(lhs, (np.ndarray, list, tuple)) and isinstance(rhs, cls):
return cls(np.minimum(lhs, rhs.lower), np.maximum(lhs, rhs.upper))
raise TypeError(f"Unsupported types: {type(lhs)}, {type(rhs)}")
@property
def center(self) -> np.ndarray:
return 0.5 * (self.lower + self.upper)
@property
def extent(self) -> np.ndarray:
return self.upper - self.lower
@property
def surface_area(self) -> float:
extent = self.extent
if extent.size < 2:
return 0.0
return 2.0 * sum(np.prod(extent[k]) for k in combinations(range(extent.size), extent.size - 1)) # type: ignore
def intersect(self, ray: Ray) -> tuple[bool, float, float]:
return intersect_ray_aabb(ray.origin, ray.direction, self.lower, self.upper)
class Primitive(ABC):
    """Abstract interface for a geometric primitive.

    Concrete subclasses must provide ``bounds``, ``centroid``, ``name`` and
    ``intersect``.
    """

    @property
    @abstractmethod
    def bounds(self) -> Bounds:
        """Bounding box of the primitive."""

    @property
    @abstractmethod
    def centroid(self) -> np.ndarray:
        """Center point of the primitive."""

    @property
    @abstractmethod
    def name(self) -> str:
        """Label of the primitive."""

    @abstractmethod
    def intersect(self, ray: Ray) -> float | None:
        """Return a ray parameter for a hit, or ``None`` when there is none."""
@dataclass(slots=True)
class Triangle(Primitive):
    """Triangle primitive backed by a vertex array with an optional index list."""

    _vertices: np.ndarray
    _indices: list[int] | None
    _name: str

    def __init__(self, vertices, indices=None, name="") -> None:
        self._vertices = np.asarray(vertices)
        self._indices = indices
        self._name = name

    def vertices(self, copy: bool = False) -> np.ndarray:
        """Return the triangle's vertex rows.

        With an index list, the rows selected by ``_indices`` are returned.
        Otherwise the stored array is passed through ``np.array`` with the
        given ``copy`` flag.
        """
        if self._indices is not None:
            return self._vertices[self._indices]
        return np.array(self._vertices, copy=copy)

    @property
    def bounds(self) -> Bounds:
        """Box spanning the per-axis minimum and maximum of the vertices."""
        corners = self.vertices()
        return Bounds(corners.min(axis=0), corners.max(axis=0))

    @property
    def centroid(self) -> np.ndarray:
        """Mean of the vertices."""
        return self.vertices().mean(axis=0)

    @property
    def name(self) -> str:
        """Label given at construction."""
        return self._name

    def intersect(self, ray: Ray) -> float | None:
        """Intersect ``ray`` with this triangle via ``intersect_ray_triangle``."""
        corners = self.vertices()
        return intersect_ray_triangle(ray.origin, ray.direction, *corners)
交差関数群
intersect_ray_aabb
# (Continuation of the note above about intersect_ray_aabb:) it uses the slab method.
def perp_dot(a: np.ndarray, b: np.ndarray) -> float:
    """Return the 2D perp-dot (scalar cross) product ``a[0]*b[1] - a[1]*b[0]``."""
    return a[0] * b[1] - a[1] * b[0]
def intersect_ray_segment(ro: np.ndarray, rd: np.ndarray, p0: np.ndarray, p1: np.ndarray) -> float | None:
    """Intersect the 2D ray ``ro + s * rd`` with the segment ``p0``-``p1``.

    Returns:
        The ray parameter ``s`` when ``0 <= s < inf`` and the hit lies on the
        segment (segment parameter in ``[0, 1]``); otherwise ``None``.
    """
    u = p1 - p0
    v = ro - p0
    # A ray parallel to the segment makes the denominator zero.  The code
    # relies on IEEE 754 inf/nan results, which fail the range test below, so
    # NumPy's floating-point warnings are silenced for this expected case.
    with np.errstate(divide="ignore", invalid="ignore"):
        inv_denom = 1.0 / perp_dot(rd, u)  # IEEE 754
        s = perp_dot(u, v) * inv_denom
        t = perp_dot(rd, v) * inv_denom
    if 0.0 <= s < np.inf and 0.0 <= t <= 1.0:
        return s
    return None
def intersect_ray_triangle(ro: np.ndarray, rd: np.ndarray, v0: np.ndarray, v1: np.ndarray, v2: np.ndarray) -> float | None:
    """Return the smallest ray parameter at which the ray hits a triangle edge.

    The three edges ``v0-v1``, ``v1-v2`` and ``v2-v0`` are tested with
    ``intersect_ray_segment``; ``None`` is returned when none of them is hit.
    """
    edges = ((v0, v1), (v1, v2), (v2, v0))
    edge_hits = [intersect_ray_segment(ro, rd, a, b) for a, b in edges]
    valid_hits = [s for s in edge_hits if s is not None]
    if not valid_hits:
        return None
    return min(valid_hits)
def intersect_ray_aabb(ro: np.ndarray, rd: np.ndarray, lower: np.ndarray, upper: np.ndarray) -> tuple[bool, float, float]:
    """Intersect a ray with an axis-aligned box, one axis slab at a time.

    Args:
        ro: ray origin.
        rd: ray direction; components may be zero.
        lower: lower corner of the box.
        upper: upper corner of the box.

    Returns:
        ``(hit, t_min, t_max)`` where ``hit`` is ``t_min < t_max`` and the
        interval starts from ``[0, inf)`` and is clipped against every slab.
    """
    t_min, t_max = 0.0, np.inf
    # A zero direction component divides by zero on purpose: the resulting
    # inf (or nan from 0 * inf) is handled by fmin/fmax, which ignore nan.
    # Silence NumPy's floating-point warnings for this expected case.
    with np.errstate(divide="ignore", invalid="ignore"):
        for i in range(ro.size):
            inv_d = 1.0 / rd[i]  # IEEE 754
            t0 = (lower[i] - ro[i]) * inv_d
            t1 = (upper[i] - ro[i]) * inv_d
            t_min = np.fmax(t_min, np.fmin(t0, t1))
            t_max = np.fmin(t_max, np.fmax(t0, t1))
    return t_min < t_max, t_min, t_max
三角形群の生成
from scipy.spatial import cKDTree
def generate_triangles(n, xmin, ymin, xmax, ymax, *, points_kwargs=None, triangle_kwargs=None):
    """Generate ``n`` random triangles inside a rectangle.

    Centers are drawn with NumPy's global RNG, pushed apart for a number of
    relaxation steps, and mapped into the rectangle shrunk by ``max_radius``
    on every side.  One jittered triangle is then built around each center.

    Args:
        n: number of triangles.
        xmin, ymin, xmax, ymax: rectangle limits.
        points_kwargs: overrides for the center generation defaults
            (``min_dist``, ``domain``, ``steps``, ``lr``); ``domain`` is
            always replaced by the shrunk rectangle.
        triangle_kwargs: overrides for the triangle defaults
            (``angle_jitter``, ``max_radius``, ``scale_range``).

    Returns:
        A list of ``n`` arrays of shape ``(3, 2)``.
    """

    def relax_points(count, *, min_dist=1.0, domain=(0, 1, 0, 1), steps=50, lr=0.01):
        # Work in the unit square and rescale to ``domain`` at the end.
        x_lo, x_hi, y_lo, y_hi = domain
        pts = np.random.rand(count, 2)
        radius = min_dist / max(x_hi - x_lo, y_hi - y_lo)
        for _ in range(steps):
            close_pairs = cKDTree(pts).query_pairs(radius)
            push = np.zeros_like(pts)
            for a, b in close_pairs:
                offset = pts[a] - pts[b]
                # The small epsilon avoids dividing by zero for coincident points.
                unit = offset / (np.linalg.norm(offset) + 1e-10)
                push[a] += unit
                push[b] -= unit
            pts += lr * push
            pts[:, 0] = np.clip(pts[:, 0], 0, 1)
            pts[:, 1] = np.clip(pts[:, 1], 0, 1)
        pts[:, 0] = pts[:, 0] * (x_hi - x_lo) + x_lo
        pts[:, 1] = pts[:, 1] * (y_hi - y_lo) + y_lo
        return pts

    def make_triangle(center, *, angle_jitter=30.0, max_radius=1.0, scale_range=(0.5, 1.0)):
        assert abs(angle_jitter) < 60.0, "angle_jitter must be less than 60 degrees to preserve CCW order"
        # Three directions 120 degrees apart, each perturbed by the jitter.
        angles = np.linspace(0.0, 2.0 * np.pi, 4)[:-1]
        angles += np.deg2rad(np.random.uniform(-angle_jitter, angle_jitter, 3))
        directions = np.stack([np.cos(angles), np.sin(angles)], axis=-1)
        radii = np.random.uniform(scale_range[0], scale_range[1], (3, 1)) * max_radius
        return center.reshape(1, 2) + radii * directions

    point_options = {"min_dist": 1.0, "domain": (0, 1, 0, 1), "steps": 50, "lr": 0.01}
    if points_kwargs is not None:
        point_options.update(points_kwargs)

    triangle_options = {"angle_jitter": 30.0, "max_radius": 1.0, "scale_range": (0.5, 1.0)}
    if triangle_kwargs is not None:
        triangle_options.update(triangle_kwargs)

    # Keep centers at least ``max_radius`` away from the rectangle border.
    margin = triangle_options["max_radius"]
    point_options["domain"] = (xmin + margin, xmax - margin, ymin + margin, ymax - margin)

    centers = relax_points(n, **point_options)
    return [make_triangle(center, **triangle_options) for center in centers]
# Seed both the stdlib and NumPy RNGs before generating the scene.
random.seed(0)
np.random.seed(4)
num_triangles = 50
# Number of digits used to zero-pad the index in each primitive name.
num_digits = len(str(num_triangles - 1))
# Overrides for the defaults inside generate_triangles.
points_kwargs = {"min_dist": 4.0}
triangle_kwargs = {"max_radius": 2.0, "scale_range": (0.65, 0.95)}
# Generate the triangles inside the rectangle (0, 0)-(50, 50).
triangles = generate_triangles(num_triangles, 0.0, 0.0, 50.0, 50.0, points_kwargs=points_kwargs, triangle_kwargs=triangle_kwargs)
# Wrap each vertex array in a Triangle named "T" + zero-padded index.
primitives: list[Primitive] = [Triangle(tri, name=f"T{i:0{num_digits}}") for i, tri in enumerate(triangles)]
共通の定義
@dataclass(init=False, slots=True)
class BVHBuildNode:
bounds: Bounds
split_axis: int
left: BVHBuildNode | None
right: BVHBuildNode | None
first_prim_offset: int
num_primitives: int
def init_leaf(self, bounds: Bounds, first_prim_offset: int, num_primitives: int) -> None:
self.bounds = bounds
self.split_axis = 0
self.left = self.right = None
self.first_prim_offset = first_prim_offset
self.num_primitives = num_primitives
def init_interior(self, axis: int, left: BVHBuildNode, right: BVHBuildNode) -> None:
self.bounds = Bounds.union(left.bounds, right.bounds)
self.split_axis = axis
self.left = left
self.right = right
self.first_prim_offset = 0
self.num_primitives = 0
def is_leaf(self) -> bool:
return self.num_primitives > 0
def primitives(self, ordered_primitives: list[Primitive]) -> list[Primitive]:
offset, count = self.first_prim_offset, self.num_primitives
return ordered_primitives[offset : offset + count]
@dataclass(slots=True)
class BVHPrimitiveInfo:
index: int
bounds: Bounds
centroid: np.ndarray = field(init=False)
def __post_init__(self) -> None:
self.centroid = self.bounds.center
@dataclass(slots=True)
class BVHVariant:
name: str = ""
nodes: list[LinearBVHNode] | None = None
primitives: list[Primitive] | None = None
filename: str = ""
次のような木構造があるとする.
A
├─ C
└─ B
   ├─ E
   └─ D
表にすると,
\[ \begin{array}{c|c|c|c} \text{Node} & \text{Left} & \text{Right} & \text{Type}\\ \hline \text{A} & \text{B} & \text{C} & \text{Interior}\\ \text{B} & \text{D} & \text{E} & \text{Interior}\\ \text{C} & \text{-} & \text{-} & \text{Leaf}\\ \text{D} & \text{-} & \text{-} & \text{Leaf}\\ \text{E} & \text{-} & \text{-} & \text{Leaf} \end{array} \]
深さ優先で平坦化すると,
\[ \begin{array}{c|c|c|c|c} 0 & 1 & 2 & 3 & 4\\ \hline \text{A} & \text{B} & \text{D} & \text{E} & \text{C} \end{array} \]
あるノードの左の子ノードは,存在する場合,常に次の位置に出現する.
@dataclass(slots=True)
class LinearBVHNode:
bounds: Bounds = field(default_factory=Bounds.empty)
offset: int = 0
num_primitives: int = 0
axis: int = 0
def is_leaf(self) -> bool:
return self.num_primitives > 0
def primitives(self, ordered_primitives: list[Primitive]) -> list[Primitive]:
offset, count = self.offset, self.num_primitives
return ordered_primitives[offset : offset + count]
def child_indices(self, index: int) -> tuple[int, int]:
return index + 1, self.offset
@dataclass(slots=True)
class FlattenBVHTreeContext:
nodes: list[LinearBVHNode]
offset: int = 0
def flatten_bvh_tree(node: BVHBuildNode, context: FlattenBVHTreeContext) -> int:
    """Write the subtree rooted at ``node`` into ``context.nodes`` depth-first.

    Each node takes the next free slot, followed by its left subtree and then
    its right subtree.  For an interior node the slot of the right child is
    stored in ``offset``.

    Returns:
        The slot index assigned to ``node``.
    """
    slot = context.offset
    linear_node = context.nodes[slot]
    linear_node.bounds = node.bounds
    context.offset = slot + 1

    if node.is_leaf():
        linear_node.offset = node.first_prim_offset
        linear_node.num_primitives = node.num_primitives
        return slot

    assert node.left is not None and node.right is not None
    linear_node.axis = node.split_axis
    # The left child lands at slot + 1, so only the right child's slot is recorded.
    flatten_bvh_tree(node.left, context)
    linear_node.offset = flatten_bvh_tree(node.right, context)
    return slot
要点
C++ の std::partition や std::nth_element に相当する関数群
T = TypeVar("T")


def partition(lst: list[T], *, start: int, end: int, pred: Callable[[T], bool]) -> int:
    """Reorder ``lst[start:end]`` in place so elements satisfying ``pred`` come first.

    Returns:
        The index of the first element of the second group.
    """
    boundary = start
    for cursor in range(start, end):
        if not pred(lst[cursor]):
            continue
        # Move the matching element to the end of the "true" prefix.
        lst[boundary], lst[cursor] = lst[cursor], lst[boundary]
        boundary += 1
    return boundary
class SupportsRandInt(Protocol):
    """Structural type for objects providing ``randint(a, b)``."""

    def randint(self, a: int, b: int) -> int:
        """Return an integer for the given bounds ``a`` and ``b``."""
        ...
def nth_element(lst: list[T], n: int, *, lo: int = 0, hi: int | None = None, key: Callable[[T], object] = lambda x: x, rng: SupportsRandInt = random) -> None:
    """Rearrange ``lst[lo:hi + 1]`` in place so that index ``n`` holds the element a sort would put there.

    After the call, elements before ``n`` in the range have keys not greater
    than ``key(lst[n])``, and elements after it have keys not smaller.

    Args:
        lst: list to reorder in place.
        n: target index, ``lo <= n <= hi``.
        lo: first index of the range.
        hi: last index of the range (inclusive); defaults to ``len(lst) - 1``.
        key: function extracting the comparison key.
        rng: source of random pivot indices.

    Raises:
        IndexError: if ``n`` lies outside ``[lo, hi]``.
    """

    def lomuto(left: int, right: int) -> int:
        # Lomuto partition scheme (https://en.wikipedia.org/wiki/Quicksort)
        # with a randomly chosen pivot moved to the right end first.
        chosen = rng.randint(left, right)
        lst[chosen], lst[right] = lst[right], lst[chosen]
        pivot_key = key(lst[right])
        store = left
        for scan in range(left, right):
            if key(lst[scan]) <= pivot_key:
                lst[store], lst[scan] = lst[scan], lst[store]
                store += 1
        lst[store], lst[right] = lst[right], lst[store]
        return store

    if hi is None:
        hi = len(lst) - 1
    if n < lo or n > hi:
        raise IndexError("index out of range")

    # Quickselect: keep only the side of the pivot that contains ``n``.
    while lo <= hi:
        pivot_index = lomuto(lo, hi)
        if pivot_index == n:
            return
        if pivot_index < n:
            lo = pivot_index + 1
        else:
            hi = pivot_index - 1
class HasBounds(Protocol):
    """Structural type for objects exposing a ``bounds`` attribute."""

    bounds: Bounds
def compute_bounds(bounded_objects: Sequence[HasBounds]) -> Bounds:
    """Return the union of the ``bounds`` of every object in ``bounded_objects``.

    Starts from ``Bounds.empty()``, which is also the result for an empty input.
    """
    merged = Bounds.empty()
    for item in bounded_objects:
        merged = Bounds.union(merged, item.bounds)
    return merged
def build_bvh(primitives: list[Primitive], split_func) -> tuple[list[LinearBVHNode], list[Primitive]]:
    """Build a BVH top-down and return it in flattened form.

    Args:
        primitives: primitives to organise.
        split_func: callable ``(primitive_info, start, end, axis, *, bounds,
            centroid_bounds)`` that reorders ``primitive_info[start:end]`` in
            place and returns the split index, or ``None`` to request a leaf.

    Returns:
        ``(nodes, ordered_primitives)``: the depth-first flattened nodes and
        the primitives reordered so that every leaf covers a contiguous run.
        Both lists are empty when ``primitives`` is empty.
    """
    if not primitives:
        # Nothing to build; avoids recursing on an empty range.
        return [], []

    primitive_info = [BVHPrimitiveInfo(i, p.bounds) for i, p in enumerate(primitives)]
    total_nodes = 0
    ordered_primitives: list[Primitive] = []

    def make_leaf(node, bounds, start, end):
        """Append primitives [start, end) to the ordered list and turn ``node`` into a leaf."""
        offset = len(ordered_primitives)
        ordered_primitives.extend(primitives[info.index] for info in primitive_info[start:end])
        node.init_leaf(bounds=bounds, first_prim_offset=offset, num_primitives=end - start)
        return node

    def build(start, end):
        nonlocal total_nodes
        node = BVHBuildNode()
        total_nodes += 1
        # Compute bounds of all primitives in BVH node
        bounds = compute_bounds(primitive_info[start:end])
        if end - start == 1:
            return make_leaf(node, bounds, start, end)
        # Compute bound of primitive centroids, choose split dimension
        centroid_bounds = Bounds.empty()
        for info in primitive_info[start:end]:
            centroid_bounds = Bounds.union(centroid_bounds, info.centroid)
        extent = centroid_bounds.extent
        axis = int(np.argmax(extent))
        if np.isclose(extent[axis], 0.0):
            # All centroids coincide: no axis can separate them.
            return make_leaf(node, bounds, start, end)
        # Partition primitives into two sets and build children
        mid = split_func(primitive_info, start, end, axis, bounds=bounds, centroid_bounds=centroid_bounds)
        if mid is None or not (start < mid < end):
            # Either the split function asked for a leaf, or it produced an
            # empty side; recursing on that would never shrink the range.
            return make_leaf(node, bounds, start, end)
        node.init_interior(axis=axis, left=build(start, mid), right=build(mid, end))
        return node

    root = build(0, len(primitives))
    # Compute representation of depth-first traversal of BVH tree
    context = FlattenBVHTreeContext(nodes=[LinearBVHNode() for _ in range(total_nodes)])
    flatten_bvh_tree(root, context)
    return context.nodes, ordered_primitives
def middle_split(primitive_info: list[BVHPrimitiveInfo], start: int, end: int, axis: int, **kwargs) -> int:
    """Split ``primitive_info[start:end]`` at the midpoint of the centroid bounds along ``axis``.

    Requires ``kwargs["centroid_bounds"]``.  If the midpoint leaves one side
    empty, the range is split into two equally sized halves instead.

    Returns:
        The index of the first element of the second half.
    """
    centroid_bounds = kwargs["centroid_bounds"]
    threshold = centroid_bounds.center[axis]
    # Partition primitives through node's midpoint
    mid = partition(primitive_info, start=start, end=end, pred=lambda info: info.centroid[axis] < threshold)
    if mid in (start, end):
        # Partition primitives into equally sized subsets
        mid = (start + end) // 2
        nth_element(primitive_info, mid, lo=start, hi=end - 1, key=lambda info: info.centroid[axis])
    return mid
import anytree
import anytree.exporter
import IPython.display as display
def convert_to_anytree(nodes, ordered_primitives, parent=None):
    """Build an ``anytree`` tree mirroring the flattened BVH ``nodes``.

    Leaves are labelled with the names of their primitives (one per line);
    interior nodes with the lower and upper corners of their bounds.

    Args:
        nodes: flattened BVH nodes; traversal starts at index 0.
        ordered_primitives: primitive list the leaves index into.
        parent: not read; the traversal always starts without a parent.

    Returns:
        The ``anytree.Node`` created for ``nodes[0]``.
    """
    root = None
    pending = [(0, None)]
    while pending:
        index, parent_node = pending.pop()
        bvh_node = nodes[index]
        if bvh_node.is_leaf():
            leaf_prims = bvh_node.primitives(ordered_primitives)
            label = "\n".join(f"{p.name}" for p in leaf_prims)
            tree_node = anytree.Node(f"{label}", parent=parent_node)
        else:
            lower, upper = bvh_node.bounds.lower, bvh_node.bounds.upper
            label = f"LB({lower[0]:.2f}, {lower[1]:.2f})\nUB({upper[0]:.2f}, {upper[1]:.2f})"
            tree_node = anytree.Node(label, parent=parent_node)
            left_index, right_index = bvh_node.child_indices(index)
            # Push right first so the left child is processed first.
            pending.append((right_index, tree_node))
            pending.append((left_index, tree_node))
        if parent_node is None:
            root = tree_node
    return root
# Render the current BVH as a tree image and show it inline.
# NOTE(review): `bvh_variant` is not defined in the visible source; presumably
# an earlier notebook cell creates it - confirm.
at_root = convert_to_anytree(bvh_variant.nodes, bvh_variant.primitives)
# graphviz needs to be installed for the next line!
anytree.exporter.UniqueDotExporter(at_root, options=[f'label="{bvh_variant.name}"', 'labelloc="t"', "fontsize=24"]).to_picture(bvh_variant.filename)
display.display(display.Image(bvh_variant.filename, width=400))
import matplotlib.pyplot as plt
# To list the available style names, run: print(plt.style.available)
# Select the "seaborn-v0_8" matplotlib style for the figures below.
plt.style.use("seaborn-v0_8")
def plot_bounds(bounds, ax, **kwargs):
    """Draw ``bounds`` on ``ax`` as a rectangle and rescale the view.

    Keyword arguments override the default patch style (unfilled, black
    dashed outline, half transparent).
    """
    style = {"fill": False, "edgecolor": "black", "linestyle": "--", "linewidth": 1, "alpha": 0.5, **kwargs}
    size = bounds.extent
    patch = plt.Rectangle(bounds.lower, size[0], size[1], **style)
    ax.add_patch(patch)
    # Patches do not update the data limits by themselves.
    ax.update_datalim(patch.get_path().vertices)
    ax.autoscale_view()
def plot_triangle(tri, ax, **kwargs):
    """Draw the triangle ``tri`` on ``ax`` as a closed polygon and rescale the view.

    Keyword arguments override the default patch style (black edge, half
    transparent).
    """
    style = {"edgecolor": "black", "alpha": 0.5, **kwargs}
    patch = plt.Polygon(tri.vertices(), closed=True, **style)
    ax.add_patch(patch)
    # Patches do not update the data limits by themselves.
    ax.update_datalim(patch.get_path().vertices)
    ax.autoscale_view()
def plot_bvh(nodes, ordered_primitives, ax):
    """Draw every node of the flattened BVH on ``ax``.

    Interior nodes are drawn as plain bounding boxes.  Each leaf gets the
    next colour of the ``tab20`` colormap for its box and its triangles, and
    every triangle is labelled with its name at its centroid.
    """
    palette = plt.colormaps["tab20"]
    leaves_seen = 0
    pending = [0]
    while pending:
        index = pending.pop()
        node = nodes[index]
        if not node.is_leaf():
            plot_bounds(node.bounds, ax)
            left_index, right_index = node.child_indices(index)
            # Push right first so the left child is drawn first.
            pending.append(right_index)
            pending.append(left_index)
            continue
        color = palette(leaves_seen % palette.N)
        leaves_seen += 1
        plot_bounds(node.bounds, ax, edgecolor=color)
        for prim in node.primitives(ordered_primitives):
            plot_triangle(prim, ax, facecolor=color)
            center = prim.centroid
            ax.text(center[0], center[1], prim.name, ha="center", va="center", fontsize=8, color="black")
要点
評価は次式で行う.
\[ c(A,B)=t_{\text{trav}}+p_{A}\sum_{i=1}^{N_{A}} t_{\text{isect}}(a_{i})+p_{B}\sum_{i=1}^{N_{B}} t_{\text{isect}}(b_{i}) \]
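ここで,$t_{\text{trav}}$ は内部ノードの走査コスト,$p_{A}$,$p_{B}$ は光線が子ノード $A$,$B$ を通過する確率(親ノードに対する表面積の比),$N_{A}$,$N_{B}$ は各子ノードに含まれるプリミティブ数,$t_{\text{isect}}$ はプリミティブとの交差判定コストである.実装では $t_{\text{trav}}=0.125$,$t_{\text{isect}}=1$ としている.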
@dataclass(slots=True)
class BucketInfo:
bounds: Bounds | None = None
count: int = 0
def __post_init__(self) -> None:
if self.bounds is None:
self.bounds = Bounds.empty()
def evaluate_sah_cost(sa: float, l_sa: float, l_count: int, r_sa: float, r_count: int) -> float:
    """Return the SAH cost of a split.

    The cost is a fixed traversal term (0.125) plus the child counts weighted
    by child surface area and divided by the parent surface area ``sa``.
    """
    traversal_cost = 0.125
    weighted_area = l_sa * l_count + r_sa * r_count
    return traversal_cost + weighted_area / sa
def compute_sah_costs(sa: float, buckets: list[BucketInfo]) -> np.ndarray:
    """Return the SAH cost of splitting after each bucket except the last.

    Entry ``i`` is the cost of putting buckets ``0..i`` on the left and the
    rest on the right; it is ``inf`` when either side holds no primitives.
    """

    def merge(group):
        # Union of the bounds and total count of a run of buckets.
        box, total = Bounds.empty(), 0
        for bucket in group:
            box = Bounds.union(box, bucket.bounds)
            total += bucket.count
        return box, total

    num_splits = len(buckets) - 1
    costs = np.empty(num_splits)
    for split in range(num_splits):
        left_box, left_count = merge(buckets[: split + 1])
        right_box, right_count = merge(buckets[split + 1 :])
        if left_count != 0 and right_count != 0:
            costs[split] = evaluate_sah_cost(sa, left_box.surface_area, left_count, right_box.surface_area, right_count)
        else:
            costs[split] = np.inf
    return costs
MAX_PRIMS_IN_NODE = 5  # <= 255


def sah_split(primitive_info: list[BVHPrimitiveInfo], start: int, end: int, axis: int, **kwargs) -> int | None:
    """Choose a split of ``primitive_info[start:end]`` with the bucketed surface area heuristic.

    Requires ``kwargs["bounds"]`` (bounds of the primitives) and
    ``kwargs["centroid_bounds"]`` (bounds of their centroids).  The range is
    reordered in place.

    Returns:
        The split index, or ``None`` when a leaf should be created instead.
    """
    bounds = kwargs["bounds"]
    centroid_bounds = kwargs["centroid_bounds"]
    num_primitives = end - start
    if num_primitives <= 4:
        # Partition primitives into equally sized subsets
        mid = (start + end) // 2
        nth_element(primitive_info, mid, lo=start, hi=end - 1, key=lambda pi: pi.centroid[axis])
        return mid

    NUM_BUCKETS = 12
    axis_lower = centroid_bounds.lower[axis]
    axis_extent = centroid_bounds.extent[axis]

    def bucket_index(pi: BVHPrimitiveInfo) -> int:
        # Map the centroid to one of NUM_BUCKETS equal slices of the extent;
        # the upper end is clamped into the last bucket.
        pos = (pi.centroid[axis] - axis_lower) / axis_extent
        return min(NUM_BUCKETS - 1, int(pos * NUM_BUCKETS))

    # Initialize BucketInfo for SAH partition buckets
    buckets = [BucketInfo() for _ in range(NUM_BUCKETS)]
    for pi in primitive_info[start:end]:
        bucket = buckets[bucket_index(pi)]
        bucket.bounds = Bounds.union(bucket.bounds, pi.bounds)
        bucket.count += 1

    # Compute costs for splitting after each bucket
    costs = compute_sah_costs(bounds.surface_area, buckets)

    # Find bucket to split at that minimizes SAH metric
    min_cost, min_cost_split = np.inf, 0
    for i in range(NUM_BUCKETS - 1):
        if costs[i] < min_cost:
            min_cost = costs[i]  # type: ignore
            min_cost_split = i

    # Either create leaf or split primitives at selected SAH bucket
    if np.isinf(min_cost) or (num_primitives <= MAX_PRIMS_IN_NODE and min_cost >= num_primitives):
        return None  # Create leaf BVHBuildNode

    # Partition with the same bucket function used for binning.  Comparing
    # against a recomputed split position can disagree with the binning at
    # bucket boundaries because of rounding and leave one side empty.
    return partition(primitive_info, start=start, end=end, pred=lambda pi: bucket_index(pi) <= min_cost_split)
# Render the current BVH as a tree image and show it inline.
# NOTE(review): `bvh_variant` is not defined in the visible source; presumably
# an earlier notebook cell creates it - confirm.
at_root = convert_to_anytree(bvh_variant.nodes, bvh_variant.primitives)
# graphviz needs to be installed for the next line!
anytree.exporter.UniqueDotExporter(at_root, options=[f'label="{bvh_variant.name}"', 'labelloc="t"', "fontsize=24"]).to_picture(bvh_variant.filename)
display.display(display.Image(bvh_variant.filename, width=400))
ここでは,20bits のモートン符号を使用する.
要点
@dataclass(slots=True)
class MortonPrimitive:
primitive_index: int
morton_code: np.uint32
@dataclass(slots=True)
class LBVHTreelet:
start_index: int
num_primitives: int
build_nodes: list[BVHBuildNode]
def interleave_bits(x: np.ndarray) -> np.ndarray:
    """Spread the bits of ``x`` so that a zero bit follows every input bit.

    ``x`` is converted to ``uint32`` first; input bit ``k`` ends up at bit
    ``2 * k`` of the result.
    """
    spread_steps = (
        (8, 0x00FF00FF),
        (4, 0x0F0F0F0F),
        (2, 0x33333333),
        (1, 0x55555555),
    )
    x = np.asarray(x, dtype=np.uint32)
    for shift, mask in spread_steps:
        x = (x | (x << shift)) & np.uint32(mask)
    return x
def encode_morton(v: np.ndarray) -> np.ndarray:
    """Return the 2D Morton code of the coordinates in the last axis of ``v``.

    Both coordinates are bit-spread with ``interleave_bits``; the first one
    fills the even bits and the second one the odd bits.
    """
    spread = interleave_bits(np.asarray(v))
    x_bits, y_bits = (np.take(spread, component, axis=-1) for component in (0, 1))
    return (y_bits << 1) | x_bits
基数ソートの実装
from numba import njit
@njit
def radix_sort_core(keys: np.ndarray, indices: np.ndarray, bits_per_pass: int, num_bits: int) -> None:
    """Reorder ``indices`` in place so that ``keys[indices]`` is ascending.

    An LSD radix sort that processes ``bits_per_pass`` bits per pass over the
    lowest ``num_bits`` bits.  ``keys`` itself is not modified.
    """
    n = len(keys)
    num_passes = (num_bits + bits_per_pass - 1) // bits_per_pass
    num_buckets = 1 << bits_per_pass
    bit_mask = num_buckets - 1
    scratch = np.empty(n, dtype=np.int32)
    histogram = np.empty(num_buckets, dtype=np.int32)
    for pass_num in range(num_passes):
        low_bit = pass_num * bits_per_pass
        # Ping-pong between the caller's array and the scratch buffer.
        if pass_num % 2 == 0:
            src = indices
            dst = scratch
        else:
            src = scratch
            dst = indices
        # Count how many keys fall into each bucket for this digit.
        histogram.fill(0)
        for ind in src:
            bucket = (keys[ind] >> low_bit) & bit_mask
            histogram[bucket] += 1
        # Prefix sums turn the counts into end positions per bucket.
        for i in range(1, num_buckets):
            histogram[i] += histogram[i - 1]
        # Walk backwards so entries with equal digits keep their order.
        for ind in src[::-1]:
            bucket = (keys[ind] >> low_bit) & bit_mask
            histogram[bucket] -= 1
            dst[histogram[bucket]] = ind
    # After an odd number of passes the result sits in the scratch buffer.
    if num_passes % 2 == 1:
        indices[:] = scratch
def apply_permutation(lst: list[Any], indices: np.ndarray) -> None:
    """Reorder ``lst`` in place so that position ``i`` receives the old ``lst[indices[i]]``.

    ``indices`` is used as scratch space: each processed entry has
    ``len(indices)`` subtracted from it, so every entry is negative afterwards.
    """
    n = len(indices)
    for cycle_start in range(n):
        if indices[cycle_start] < 0:
            # Already handled as part of an earlier cycle.
            continue
        pos = cycle_start
        while True:
            source = indices[pos] % n
            if source == cycle_start:
                break
            lst[pos], lst[source] = lst[source], lst[pos]
            indices[pos] -= n
            pos = source
        indices[pos] -= n
def radix_sort(lst: list[T], *, key: Callable[[T], object] = lambda x: x, bits_per_pass: int = 6, num_bits: int = 30, out: list[T] | None = None) -> list[T]:
    """Sort ``lst`` by an unsigned integer key using a radix sort.

    Args:
        lst: items to sort.
        key: function returning each item's key; keys are stored as ``uint32``.
        bits_per_pass: number of key bits handled per pass.
        num_bits: number of low key bits that take part in the sort.
        out: optional destination list; pass ``lst`` itself to sort in place.

    Returns:
        ``out`` when given.  Otherwise a new sorted list, except that a list
        of at most one element is returned as is.
    """
    n = len(lst)
    needs_copy = out is not None and out is not lst

    if n <= 1:
        if out is None:
            return lst
        if needs_copy:
            out[:] = lst
        return out

    keys = np.fromiter((key(item) for item in lst), dtype=np.uint32, count=len(lst))
    order = np.arange(n, dtype=np.int32)
    radix_sort_core(keys, order, bits_per_pass, num_bits)

    # Reordering of lst using the sorted index array
    if out is None:
        return [lst[i] for i in order]
    if needs_copy:
        out[:] = lst
    apply_permutation(out, order)
    return out
@dataclass(slots=True)
class EmitLBVHContext:
primitives: list[Primitive]
build_node_index: int = 0
total_nodes: int = 0
ordered_prims_offset: int = 0
def emit_lbvh(
    build_nodes: list[BVHBuildNode],
    primitive_info: list[BVHPrimitiveInfo],
    morton_primitives: list[MortonPrimitive],
    start: int,
    end: int,
    ordered_primitives: list[Primitive],
    bit_index: int,
    context: EmitLBVHContext,
) -> BVHBuildNode:
    """Build the LBVH subtree for ``morton_primitives[start:end]``.

    The range is split where bit ``bit_index`` of the Morton codes changes,
    recursing with the next lower bit.  A leaf is created when the bits are
    exhausted or the range holds fewer than ``MAX_PRIMS_IN_NODE`` primitives.
    Nodes are taken from ``build_nodes`` in allocation order and the leaf
    primitives are written into ``ordered_primitives``.

    Returns:
        The root build node of the subtree.
    """
    count = end - start
    if bit_index == -1 or count < MAX_PRIMS_IN_NODE:
        # Create and return leaf node of LBVH treelet
        leaf = build_nodes[context.build_node_index]
        context.build_node_index += 1
        context.total_nodes += 1
        first_slot = context.ordered_prims_offset
        context.ordered_prims_offset += count
        leaf_bounds = Bounds.empty()
        for slot, morton_prim in enumerate(morton_primitives[start:end], start=first_slot):
            prim_index = morton_prim.primitive_index
            ordered_primitives[slot] = context.primitives[prim_index]
            leaf_bounds = Bounds.union(leaf_bounds, primitive_info[prim_index].bounds)
        leaf.init_leaf(bounds=leaf_bounds, first_prim_offset=first_slot, num_primitives=count)
        return leaf

    mask = np.uint32(1 << bit_index)

    def masked_bit(i):
        return morton_primitives[i].morton_code & mask

    # Advance to next subtree level if there's no LBVH split for this bit
    if masked_bit(start) == masked_bit(end - 1):
        return emit_lbvh(build_nodes, primitive_info, morton_primitives, start, end, ordered_primitives, bit_index - 1, context)

    # Find LBVH split point for this dimension (binary search for the bit change)
    low, high = start, end
    while low + 1 != high:
        probe = (low + high) // 2
        if masked_bit(low) == masked_bit(probe):
            low = probe
        else:
            high = probe
    split_offset = high

    # Create and return interior LBVH node; it is claimed before its children.
    interior = build_nodes[context.build_node_index]
    context.build_node_index += 1
    context.total_nodes += 1
    next_bit = bit_index - 1
    left_child = emit_lbvh(build_nodes, primitive_info, morton_primitives, start, split_offset, ordered_primitives, next_bit, context)
    right_child = emit_lbvh(build_nodes, primitive_info, morton_primitives, split_offset, end, ordered_primitives, next_bit, context)
    interior.init_interior(axis=bit_index % 2, left=left_child, right=right_child)
    return interior
def build_lbvh(primitives: list[Primitive]) -> tuple[list[LinearBVHNode], list[Primitive]]:
    """Build a BVH by sorting primitives by Morton code and splitting on code bits.

    Returns:
        ``(nodes, ordered_primitives)``: the flattened nodes and the
        primitives in leaf order.
    """
    MORTON_BITS = 10
    MORTON_SCALE = 1 << MORTON_BITS
    FIRST_BIT_INDEX = 19

    primitive_info = [BVHPrimitiveInfo(i, prim.bounds) for i, prim in enumerate(primitives)]

    # Bounds of all primitive bounds; centroids are normalised against this box.
    bounds = compute_bounds(primitive_info)
    extent = bounds.extent

    def morton_code_of(info):
        # Scale the centroid to the [0, MORTON_SCALE) grid before encoding.
        scaled = (info.centroid - bounds.lower) / extent * MORTON_SCALE
        scaled = np.minimum(MORTON_SCALE - 1, scaled)
        return encode_morton(scaled).item()

    # Compute Morton indices of primitives
    morton_primitives: list[MortonPrimitive] = [MortonPrimitive(info.index, morton_code_of(info)) for info in primitive_info]

    # Radix sort primitive Morton indices (in place via ``out``)
    radix_sort(morton_primitives, key=lambda p: p.morton_code, bits_per_pass=5, num_bits=20, out=morton_primitives)

    # Create a single LBVH treelet covering every primitive
    num_primitives = len(primitive_info)
    build_nodes = [BVHBuildNode() for _ in range(2 * num_primitives - 1)]
    treelet = LBVHTreelet(0, num_primitives, build_nodes)

    # Create LBVH for treelet
    ordered_primitives: list[Primitive] = [None] * num_primitives  # type: ignore
    emit_context = EmitLBVHContext(primitives)
    emit_lbvh(treelet.build_nodes, primitive_info, morton_primitives, 0, treelet.num_primitives, ordered_primitives, FIRST_BIT_INDEX, emit_context)

    flatten_context = FlattenBVHTreeContext(nodes=[LinearBVHNode() for _ in range(emit_context.total_nodes)])
    flatten_bvh_tree(treelet.build_nodes[0], flatten_context)
    return flatten_context.nodes, ordered_primitives
# Render the current BVH as a tree image and show it inline.
# NOTE(review): `bvh_variant` is not defined in the visible source; presumably
# an earlier notebook cell creates it - confirm.
at_root = convert_to_anytree(bvh_variant.nodes, bvh_variant.primitives)
# graphviz needs to be installed for the next line!
anytree.exporter.UniqueDotExporter(at_root, options=[f'label="{bvh_variant.name}"', 'labelloc="t"', "fontsize=24"]).to_picture(bvh_variant.filename)
display.display(display.Image(bvh_variant.filename, width=400))
要点
@dataclass(slots=True)
class BuildUpperSAHContext:
total_nodes: int = 0
def build_upper_sah(treelet_roots: list[BVHBuildNode], start: int, end: int, context: BuildUpperSAHContext) -> BVHBuildNode:
    """Join ``treelet_roots[start:end]`` into one tree using bucketed SAH splits.

    The range is reordered in place.  ``context.total_nodes`` is incremented
    for every interior node created.

    Returns:
        The root of the joined tree (the treelet root itself for a single node).

    Raises:
        ValueError: if the range is empty.
    """
    num_nodes = end - start
    if num_nodes <= 0:
        # An empty range can never shrink, so recursing on it would not terminate.
        raise ValueError("build_upper_sah requires a non-empty range of treelet roots")
    if num_nodes == 1:
        return treelet_roots[start]
    node = BVHBuildNode()
    context.total_nodes += 1

    # Bounds of all nodes under this HLBVH node, and bounds of their centroids
    bounds = Bounds.empty()
    centroid_bounds = Bounds.empty()
    for root in treelet_roots[start:end]:
        bounds = Bounds.union(bounds, root.bounds)
        centroid_bounds = Bounds.union(centroid_bounds, root.bounds.center)
    extent = centroid_bounds.extent
    axis = int(np.argmax(extent))
    axis_lower = centroid_bounds.lower[axis]
    axis_extent = extent[axis]

    # Median split, used when the SAH cannot separate the nodes.
    mid = (start + end) // 2

    # With a zero extent every centroid coincides on the axis and the bucket
    # position would be 0/0, so bucketing is skipped in that case.
    if axis_extent > 0.0:
        NUM_BUCKETS = 12

        def bucket_index(root: BVHBuildNode) -> int:
            pos = (root.bounds.center[axis] - axis_lower) / axis_extent
            return min(NUM_BUCKETS - 1, int(pos * NUM_BUCKETS))

        # Initialize BucketInfo for HLBVH SAH partition buckets
        buckets = [BucketInfo() for _ in range(NUM_BUCKETS)]
        for root in treelet_roots[start:end]:
            bucket = buckets[bucket_index(root)]
            bucket.bounds = Bounds.union(bucket.bounds, root.bounds)
            bucket.count += 1

        # Compute costs for splitting after each bucket
        costs = compute_sah_costs(bounds.surface_area, buckets)

        # Find bucket to split at that minimizes SAH metric
        min_cost, min_cost_split = np.inf, 0
        for i in range(NUM_BUCKETS - 1):
            if costs[i] < min_cost:
                min_cost = costs[i]  # type: ignore
                min_cost_split = i

        if np.isfinite(min_cost):
            # Partition with the same bucket function used for binning, so the
            # chosen split always has nodes on both sides.
            sah_mid = partition(treelet_roots, start=start, end=end, pred=lambda tr: bucket_index(tr) <= min_cost_split)
            if start < sah_mid < end:
                mid = sah_mid

    # Split nodes and create interior HLBVH SAH node
    node.init_interior(axis=axis, left=build_upper_sah(treelet_roots, start, mid, context), right=build_upper_sah(treelet_roots, mid, end, context))
    return node
def build_hlbvh(primitives: list[Primitive]) -> tuple[list[LinearBVHNode], list[Primitive]]:
    """Build a BVH from Morton-code treelets joined by an upper SAH tree.

    Primitives are sorted by Morton code and grouped into treelets by the
    high bits of the code.  Each treelet is built with ``emit_lbvh`` and the
    treelet roots are joined with ``build_upper_sah``.

    Returns:
        ``(nodes, ordered_primitives)``: the flattened nodes and the
        primitives in leaf order.
    """
    MORTON_BITS = 10
    MORTON_SCALE = 1 << MORTON_BITS
    FIRST_BIT_INDEX = 15

    primitive_info = [BVHPrimitiveInfo(i, prim.bounds) for i, prim in enumerate(primitives)]

    # Bounds of all primitive bounds; centroids are normalised against this box.
    bounds = compute_bounds(primitive_info)
    extent = bounds.extent

    def morton_code_of(info):
        # Scale the centroid to the [0, MORTON_SCALE) grid before encoding.
        scaled = (info.centroid - bounds.lower) / extent * MORTON_SCALE
        scaled = np.minimum(MORTON_SCALE - 1, scaled)
        return encode_morton(scaled).item()

    # Compute Morton indices of primitives
    morton_primitives = [MortonPrimitive(info.index, morton_code_of(info)) for info in primitive_info]

    # Radix sort primitive Morton indices (in place via ``out``)
    radix_sort(morton_primitives, key=lambda p: p.morton_code, bits_per_pass=5, num_bits=20, out=morton_primitives)

    # Create LBVH treelets at bottom of BVH: a new treelet starts whenever
    # the bits above FIRST_BIT_INDEX change.
    mask = np.uint32(((1 << (19 - FIRST_BIT_INDEX)) - 1) << (FIRST_BIT_INDEX + 1))
    treelets_to_build = []
    num_sorted = len(morton_primitives)
    run_start = 0
    for run_end in range(1, num_sorted + 1):
        at_boundary = run_end == num_sorted or ((morton_primitives[run_start].morton_code & mask) != (morton_primitives[run_end].morton_code & mask))
        if not at_boundary:
            continue
        run_length = run_end - run_start
        reserved_nodes = [BVHBuildNode() for _ in range(2 * run_length - 1)]
        treelets_to_build.append(LBVHTreelet(run_start, run_length, reserved_nodes))
        run_start = run_end

    # Create LBVHs for treelets
    total_nodes = 0
    ordered_primitives: list[Primitive] = [None] * len(primitives)  # type: ignore
    context = EmitLBVHContext(primitives)
    for treelet in treelets_to_build:
        # Node bookkeeping restarts per treelet; the primitive offset carries over.
        context.build_node_index = 0
        context.total_nodes = 0
        first = treelet.start_index
        emit_lbvh(treelet.build_nodes, primitive_info, morton_primitives, first, first + treelet.num_primitives, ordered_primitives, FIRST_BIT_INDEX, context)
        total_nodes += context.total_nodes

    # Create and return SAH BVH from LBVH treelets
    finished_treelets = [treelet.build_nodes[0] for treelet in treelets_to_build]
    upper_context = BuildUpperSAHContext(total_nodes=total_nodes)
    root = build_upper_sah(finished_treelets, 0, len(finished_treelets), upper_context)

    flatten_context = FlattenBVHTreeContext(nodes=[LinearBVHNode() for _ in range(upper_context.total_nodes)])
    flatten_bvh_tree(root, flatten_context)
    return flatten_context.nodes, ordered_primitives
# Render the current BVH as a tree image and show it inline.
# NOTE(review): `bvh_variant` is not defined in the visible source; presumably
# an earlier notebook cell creates it - confirm.
at_root = convert_to_anytree(bvh_variant.nodes, bvh_variant.primitives)
# graphviz needs to be installed for the next line!
anytree.exporter.UniqueDotExporter(at_root, options=[f'label="{bvh_variant.name}"', 'labelloc="t"', "fontsize=24"]).to_picture(bvh_variant.filename)
display.display(display.Image(bvh_variant.filename, width=400))
各構築手法を比較する.
from PIL import Image
# Paste the tree images of all variants side by side into "combined.png".
# NOTE(review): `bvh_variants` is not defined in the visible source;
# presumably an earlier notebook cell creates it - confirm.
try:
    images = [Image.open(variant.filename) for variant in bvh_variants]
    # Gap between neighbouring images: 5% of the widest image.
    margin = int(max(image.width for image in images) * 0.05)
    width = margin * (len(images) - 1) + sum(image.width for image in images)
    height = max(image.height for image in images)
    dst = Image.new("RGB", (width, height), (255, 255, 255))
    x = 0
    for image in images:
        dst.paste(image, (x, 0))
        x = x + image.width + margin
    dst.save("combined.png")
    display.display(display.Image("combined.png", width=800))
except Exception as e:
    # Best effort: report the problem and carry on.
    print(e)
# Plot every BVH variant next to each other for comparison.
num_variants = len(bvh_variants)
fig, axes = plt.subplots(nrows=1, ncols=num_variants, figsize=(5 * num_variants, 5), constrained_layout=True)
# plt.subplots returns a single Axes object instead of an array when there is
# only one subplot; normalise so the zip() below also works for one variant.
axes = np.atleast_1d(axes)
for i, (ax, variant) in enumerate(zip(axes, bvh_variants)):
    plot_bvh(variant.nodes, variant.primitives, ax)
    ax.set_title(variant.name)
    ax.set_aspect("equal", adjustable="box")
    ax.grid(True)
    if i != 0:
        # Only the leftmost plot keeps its y tick labels and left spine.
        ax.tick_params(labelleft=False)
        ax.spines["left"].set_visible(False)
plt.show()
構築された BVH に対して光線を飛ばし,ノードを走査しながらプリミティブとの交差判定を行う.
走査について
左の子は index + 1,右の子はあらかじめ格納された offset により特定できるため,ポインタを使わずにスタックで辿れる.
dir_is_neg
を使う理由@dataclass(slots=True)
class Intersection:
t: float
primitive: Primitive
def intersect(nodes: list[LinearBVHNode], ordered_primitives: list[Primitive], ray: Ray) -> Intersection | None:
    """Return the closest hit of ``ray`` against the flattened BVH, or ``None``.

    Nodes are visited with an explicit stack starting at index 0.  Subtrees
    whose bounds the ray misses are skipped, and every primitive of each
    visited leaf is tested.
    """
    closest = None
    dir_is_neg = ray.direction < 0.0
    pending = [0]
    while pending:
        index = pending.pop()
        node = nodes[index]
        box_hit, _, _ = node.bounds.intersect(ray)
        if not box_hit:
            continue
        if node.is_leaf():
            for prim in node.primitives(ordered_primitives):
                t = prim.intersect(ray)
                if t is None:
                    continue
                if closest is None or t < closest.t:
                    closest = Intersection(t, prim)
            continue
        left_index, right_index = node.child_indices(index)
        # The child pushed last is popped first: the right child goes first
        # when the ray travels in the negative direction along the split
        # axis, otherwise the left child (presumably the nearer one).
        if dir_is_neg[node.axis]:
            visit_order = (left_index, right_index)
        else:
            visit_order = (right_index, left_index)
        pending.extend(visit_order)
    return closest
isect.t=np.float64(41.20525616375596), isect.primitive.name='T24'
def plot_intersection(ray, isect, ax):
    """Draw ``ray`` on ``ax`` and, when ``isect`` is not ``None``, its closest hit.

    With a hit, a dashed line from the ray origin to the hit point and a
    cross marker at the hit point are drawn.  The ray arrow is always drawn.
    """
    origin, direction = ray.origin, ray.direction
    if isect is not None:
        hit_point = ray.at(isect.t)
        segment = np.vstack([origin, hit_point])
        ax.plot(*segment.T, color="crimson", linestyle="--", linewidth=1, alpha=0.25)
        ax.scatter(hit_point[0], hit_point[1], s=50, c="crimson", marker="x", linewidths=1, label="Closest Hit")
    ax.quiver(*origin, *direction, angles="xy", scale_units="xy", scale=1, width=0.003, color="crimson", label="Ray")
# Draw the BVH together with the ray and its closest hit.
# NOTE(review): `bvh_variant`, `ray` and `isect` are not defined in the visible
# source; presumably earlier notebook cells create them - confirm.
fig, ax = plt.subplots(figsize=(6, 6), constrained_layout=True)
plot_bvh(bvh_variant.nodes, bvh_variant.primitives, ax)
plot_intersection(ray, isect, ax)
# Styling and layout
plt.title(f"{bvh_variant.name}: Ray Traversal")
ax.set_aspect("equal", adjustable="box")
ax.grid(True)
# Place the legend outside the axes, to the right of the plot.
ax.legend(loc="upper left", bbox_to_anchor=(1.02, 1), borderaxespad=0.0)
plt.show()