# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
import logging
import random
from typing import Optional, Tuple
import torch
from torch.utils.checkpoint import checkpoint
from .misc_utils import check_grid_and_color_grid, is_in_bounds, unflatten_grid
from .mlp_utils import DecoderParams, flattened_decoder_params_to_list
from .ray_utils import Rays
from .triton_src.shared.const import MIN_BLOCK_SIZE
from .triton_src.shared.rand_util import int_to_randn_naive
logger = logging.getLogger(__name__)

# Flip to True to enable the (very chatty) debug prints scattered through
# this module; also widens torch's print options for readable tensor dumps.
VERBOSE = False
if VERBOSE:
    torch.set_printoptions(
        precision=4,
        threshold=None,
        edgeitems=None,
        linewidth=120,
        profile=None,
        sci_mode=False,
    )
def lightplane_renderer_naive(
    rays: Rays,
    grid: tuple[torch.Tensor, ...] | torch.Tensor,
    decoder_params: DecoderParams,
    # ------ config keys ------
    num_samples: int,
    gain: float,
    mask_out_of_bounds_samples: bool = False,
    num_samples_inf: int = 0,
    contract_coords: bool = False,
    inject_noise_sigma: float = 0.0,
    inject_noise_seed: int | None = None,
    disparity_at_inf: float = 1e-5,
    scaffold: torch.Tensor | None = None,
    color_grid: tuple[torch.Tensor, ...] | torch.Tensor | None = None,
    grid_sizes: list[list[int]] | None = None,
    color_grid_sizes: list[list[int]] | None = None,
    triton_num_warps: int = -1,  # ignored, but kept for compatibility with triton api
    triton_block_size: int = -1,  # ignored, but kept for compatibility with triton api
    regenerate_code: bool = False,  # ignored, but kept for compatibility with triton api
    checkpointing: bool = False,  # whether or not use pytorch checkpoint for MLP eval
):
    r"""
    This is the naive implementation of the Lightplane Renderer
    (`lightplane_renderer`), which gives the same numeric results as the
    Triton implementation with less memory efficiency.
    It is useful for debugging and understanding the Triton implementation.
    It outputs the final color `c`, negative_log_transmittance `T_N` and the
    expected ray-termination length `r` (analogous to depth) of each ray's pixel.
    Its arguments are the same as the Triton implementation in
    `lightplane_renderer`.
    Additionally, it can work with `torch.utils.checkpoint` by setting
    `checkpointing=True`.

    Args:
        rays: The rays to render features.
            It is an instance of `Rays`, including `directions`, `origins`,
            `grid_idx`, `near`, `far`, and `encoding`.
            `grid_idx` indicates the batch index of the 3D grid to sample
            features from::

                x_i = rays.origins[i] + (rays.near[i] + delta_i) * rays.direction[i]

        grid: Grid-list (a list of 3D grids) to sample features from.
            Features are sampled from each 3D grid in the set and summed up as
            the final feature.
            `grid` contains `N` tensors, each with the shape::

                [[B, D_1, H_1, W_1, C], ... , [B, D_N, H_N, W_N, C]]

            Each tensor must have 5 dimensions and all tensors should have
            the same batch size `B` and feature dimension `C`.

            Example:
                If `grid` is a single Voxel grid::

                    grid = [torch.tensor([B, D, H, W, C])]

                If `grid` is a triplane::

                    grid = [
                        torch.tensor([B, 1, H, W, C]),
                        torch.tensor([B, D, 1, W, C]),
                        torch.tensor([B, D, H, 1, C]),
                    ]

            `lightplane_renderer` can also work with `grid` as a 2D tensor,
            which is a stacked tensor from the grid-list `grid`, with the shape
            `[sum_(i=1..N)(B * D_i * H_i * W_i), C]`.
            In this case, the `grid_sizes` must be provided to specify the
            shape of each grid.
            The 2D tensor can be obtained from `lightplane.flatten_grid(grid)`
            to flatten the list of tensors and to also obtain the `grid_sizes`
            argument.
            This is useful for improving memory-efficiency when the grid-list
            is giant since we internally flatten the grid-list to a 2D tensor.
        decoder_params: The parameters of the decoder, including the MLP
            parameters of `trunk_mlp`, `color_mlp`, and `opacity_mlp`.
        num_samples: The number of sampling points along the ray.
            The samples are equispaced between `rays.near` and `rays.far`.
            More specifically, the j-th 3d point `x_ij` along `i-th` ray is
            defined as follows::

                x_ij = rays.origins[i] + (rays.near[i] + j * delta_i) * rays.direction[i],
                where:
                    delta_i = (rays.far[i] - rays.near[i]) / num_samples

        gain: A constant to scale the transmittance::

                T_i = exp(-gain * sum_{j=1}^{i} o(x_ij) * delta_i)

        num_samples_inf: The number of background samples along the ray.
            The first background sample is placed at `rays.far`, and the
            samples are spaced in the disparity space until reaching the
            disparity of `disparity_at_inf`.
            More specifically, the j-th background 3d point `b_ij` along `i-th`
            ray is defined as follows::

                b_ij = rays.origins[i] + (rays.far[i] + j * bg_delta_ij) * rays.direction[i],
                where:
                    bg_delta_ij = 1 / disparity_ij
                    disparity_ij = linspace(1, disparity_at_inf, num_samples_inf)[j]

            These samples are additional to `num_samples`, i.e. the total
            number of samples along a ray is `num_samples + num_samples_inf`.
        mask_out_of_bounds_samples: Whether to mask samples that
            fall outside the [-1, 1] cube (does not apply when contraction
            with `contract_coords` is enabled).
        contract_coords: Whether to map the coordinates of the splatted
            points to always fall into the [-1, 1] cube. The contraction is
            implemented as in MeRF [1]::

                                x[k]                       if |x|_inf <= 1
                contract(x)[k] = x[k] / |x|_inf             if x_k != |x|_inf > 1
                                (2 - 1/x[k]) x_k / |x_k|   if x_k = |x|_inf > 1

            Note: The contraction is useful for representing unbounded scenes.
            E.g. outdoor captures where the scene extends to infinity.
        disparity_at_inf: The disparity value at infinity.
        inject_noise_sigma: The variance of opacity noise to inject.
        inject_noise_seed: The seed of the random noise to inject.
        scaffold: A voxel grid with shape `[B, D, H, W]`, indicating the
            occupancy of the 3D space. If provided, the renderer will only
            render the points that are not empty in the scaffold.
        color_grid: Another grid-list (a list of 3D grids) storing color
            features. If provided, the renderer will regress the color from
            features sampled from `color_grid`, using `color_mlp`.
            Similar to `grid`, `color_grid` could also be a 2D tensor with
            `color_grid_sizes` provided.
            `color_grid` should be the same type as `grid`.
        grid_sizes: It specifies the size of `grid`.
            It is optional when `grid` is a grid-list, but required when
            `grid` is a 2D tensor. Example::

                grid_sizes = [[B, D_1, H_1, W_1, C], ... , [B, D_N, H_N, W_N, C]]

        color_grid_sizes: It specifies the size of `color_grid` when
            `color_grid` is a 2D tensor.
            It is optional when `color_grid` is a grid-list, but required
            when `color_grid` is a 2D tensor. Example::

                color_grid_sizes = [[B, D_1, H_1, W_1, C], ... , [B, D_N, H_N, W_N, C]]

        regenerate_code: Ignored, but kept for compatibility with triton api.
        triton_block_size: Ignored, but kept for compatibility with triton api.
        triton_num_warps: Ignored, but kept for compatibility with triton api.
        checkpointing: Whether or not use `torch.utils.checkpoint` for
            checkpointing.

    Returns:
        ray_length_render: The rendered ray-termination length `r` (i.e.
            distance along the ray).
        negative_log_transmittances: The negative log transmittances of the ray.
        feature_render: The expected features of the ray.

    References:
        [1] MERF: Memory-Efficient Radiance Fields for Real-time View Synthesis
            in Unbounded Scenes, https://arxiv.org/abs/2302.12249
    """
    grid, color_grid, grid_sizes, color_grid_sizes = check_grid_and_color_grid(
        grid, color_grid, grid_sizes, color_grid_sizes
    )
    # If `grid` is a flattened 2D tensor, unflatten it back to a list of
    # 5-dim tensors so that pytorch can interpolate on them.
    # Unflattening uses split operations, which should need no additional
    # memory as they create tensor views instead of allocating new storage.
    if isinstance(grid, torch.Tensor):
        grid_sizes_tensor = torch.tensor(
            grid_sizes, device=grid.device, dtype=torch.long
        )
        grid = unflatten_grid(grid, grid_sizes_tensor)
        if color_grid is not None:
            color_grid_sizes_tensor = torch.tensor(
                color_grid_sizes, device=color_grid.device, dtype=torch.long
            )
            color_grid = unflatten_grid(color_grid, color_grid_sizes_tensor)
    device = rays.device
    num_rays = rays.directions.shape[0]
    # Equispaced sample depths in [near, far] for every ray: (num_rays, num_samples).
    lsp = torch.linspace(0.0, 1.0, num_samples).to(device)
    depths = rays.near[:, None] + lsp[None, :] * (rays.far - rays.near)[:, None]
    tot_num_samples = num_samples + num_samples_inf
    if inject_noise_seed is None:
        if inject_noise_sigma > 0.0:
            inject_noise_seed = int(random.randint(0, 1000000))
        else:
            inject_noise_seed = 0
    if inject_noise_sigma > 0.0:
        # Seed-deterministic per-sample noise, mirroring the Triton kernel's
        # noise generation so both implementations match numerically.
        inject_opacity_noise = _get_sample_randn(
            tot_num_samples,
            num_rays,
            device,
            inject_noise_seed,
        )
        inject_opacity_noise = inject_opacity_noise * inject_noise_sigma
    else:
        inject_opacity_noise = None
    if num_samples_inf > 0:
        # Extra background samples past rays.far, spaced linearly in disparity
        # down to disparity_at_inf (for unbounded scenes).
        sph = torch.stack(
            [
                _depth_inv_sphere(rays.far, disparity_at_inf, num_samples_inf, step)
                for step in range(num_samples_inf)
            ],
            dim=-1,
        )
        depths = torch.cat([depths, sph], dim=-1)
    # World-space sample locations: origins + depth * directions.
    points = depths[..., None] * rays.directions[:, None]
    points = points + rays.origins[..., None, :]
    # Per-sample segment lengths; the first segment reuses the uniform spacing
    # (or 1 when there is a single sample, to keep the quadrature well-defined).
    delta_one = (
        (rays.far - rays.near) / (num_samples - 1)
        if num_samples > 1
        else torch.ones_like(rays.near)
    )
    delta = torch.cat([delta_one[:, None], depths.diff(dim=-1)], dim=-1)
    if VERBOSE:
        print("near")
        print(rays.near)
        print("far")
        print(rays.far)
        print("depths")
        print(depths)
        print("delta")
        print(delta)
        print("centers")
        print(rays.origins)
    # Decode per-sample opacity and color; checkpointing (if enabled) is
    # applied inside lightplane_eval_mlp.
    opacity, color = lightplane_eval_mlp(
        points,
        grid,
        rays.grid_idx,
        decoder_params,
        rays.encoding,  # ..., C
        gain,
        mask_out_of_bounds_samples=mask_out_of_bounds_samples,
        inject_opacity_noise=inject_opacity_noise,
        scaffold=scaffold,
        color_grid=color_grid,
        checkpointing=checkpointing,
        contract_coords=contract_coords,
    )
    # Emission-absorption raymarching: accumulate opacity * segment length,
    # turn it into transmittance, and derive per-sample rendering weights.
    delta_opacity = opacity * delta
    delta_opacity = torch.nn.functional.pad(delta_opacity, (1, 0))
    negative_log_transmittances = torch.cumsum(delta_opacity, dim=-1)
    transmittance = torch.exp(-negative_log_transmittances)
    rweights = -transmittance.diff(dim=-1)
    if VERBOSE:
        print("weight")
        print(rweights)
    ray_length_render = (depths * rweights).sum(dim=-1)
    feature_render = (color * rweights[..., None]).sum(-2)
    negative_log_transmittance = negative_log_transmittances[..., -1]
    if decoder_params.color_chn < feature_render.shape[-1]:
        # NOTE(review): the color MLP output can carry more channels than
        # requested (presumably padding); crop to the declared color_chn.
        feature_render = feature_render[..., : decoder_params.color_chn]
    return (
        ray_length_render,
        negative_log_transmittance,
        feature_render,
    )
def lightplane_eval_mlp(
    points: torch.Tensor,  # R x N x 3; packed with ray_grid_idx
    grid: tuple[torch.Tensor, ...],
    ray_grid_idx: torch.Tensor,
    decoder_params: DecoderParams,
    rays_encoding: torch.Tensor,
    gain: float,
    mask_out_of_bounds_samples: bool = False,
    inject_opacity_noise: torch.Tensor | None = None,
    scaffold: torch.Tensor | None = None,
    color_grid: tuple[torch.Tensor, ...] | None = None,
    checkpointing: bool = False,  # whether or not use pytorch checkpoint for MLP eval
    contract_coords: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Sample grid features at `points` and decode them into opacity and color.

    Two decoding modes are supported:

    * `color_grid is None` (single feature grid): the sampled features go
      through the trunk MLP; the opacity MLP decodes opacity from the trunk
      output, and the color MLP decodes color from the trunk output summed
      with the per-ray encoding.
    * `color_grid` given (relu-field mode): the trunk MLP must be empty;
      opacity is decoded directly from relu-ed `grid` features and color from
      relu-ed `color_grid` features plus the per-ray encoding.

    Args:
        points: `R x N x 3` sample locations, grouped per-ray via `ray_grid_idx`.
        grid: Grid-list to sample (trunk/opacity) features from.
        ray_grid_idx: Per-ray batch index selecting which grid to sample.
        decoder_params: Flattened MLP parameters of trunk/opacity/color MLPs.
        rays_encoding: Per-ray encoding added to features before the color MLP.
        gain: Multiplier applied to the softplus-activated opacity.
        mask_out_of_bounds_samples: Zero features sampled outside the [-1, 1] cube.
        inject_opacity_noise: Optional noise added to the raw opacity logits.
        scaffold: Optional occupancy grid; zeroes opacity and color in empty space.
        color_grid: Optional grid-list with color features (relu-field mode).
        checkpointing: Use `torch.utils.checkpoint` for sampling and MLP eval.
        contract_coords: Apply the MeRF-style contraction to `points` first.

    Returns:
        Tuple `(opacity, color)`; `opacity` has shape `R x N`, `color` has
        shape `R x N x C`.
    """
    assert points.ndim >= 3
    (
        weights_trunk,
        biases_trunk,
        weights_opacity,
        biases_opacity,
        weights_color,
        biases_color,
    ) = flattened_decoder_params_to_list(
        decoder_params.mlp_params,
        decoder_params.n_hidden_trunk,
        decoder_params.n_hidden_opacity,
        decoder_params.n_hidden_color,
    )
    if VERBOSE:
        print("w_trunk")
        for w in weights_trunk:
            print(w)
        print("b_trunk")
        for b in biases_trunk:
            print(b)
    if contract_coords:
        points = _contract_pi(points)
    feature_sampled = sample_grid_list_checkpointed(
        grid,
        points,
        ray_grid_idx,
        mask_out_of_bounds_samples,
        checkpointing=checkpointing,
    )
    if color_grid is not None:
        feature_sampled_color = sample_grid_list_checkpointed(
            color_grid,
            points,
            ray_grid_idx,
            mask_out_of_bounds_samples,
            checkpointing=checkpointing,
        )
    else:
        feature_sampled_color = None
    if VERBOSE:
        print("feature_sampled")
        print(feature_sampled)
    if feature_sampled_color is None:
        # We have a single feature grid: trunk MLP -> {opacity MLP, color MLP}.
        feature_trunk = _eval_mlp(
            feature_sampled,
            weights_trunk,
            biases_trunk,
            mlp_name="trunk",
            checkpointing=checkpointing,
        )
        feature_trunk = torch.relu(feature_trunk)
        if VERBOSE:
            print("feature_trunk")
            print(feature_trunk)
        opacity_raw = _eval_mlp(
            feature_trunk,
            weights_opacity,
            biases_opacity,
            mlp_name="opacity",
            checkpointing=checkpointing,
        )
        if VERBOSE:
            print("opacity_raw")
            print(opacity_raw)
        # Per-ray encoding is broadcast over the samples of each ray.
        ray_feature_trunk = feature_trunk + rays_encoding[:, None]
        if VERBOSE:
            print("ray_feature_trunk")
            print(ray_feature_trunk)
        log_color = _eval_mlp(
            ray_feature_trunk,
            weights_color,
            biases_color,
            mlp_name="color",
            checkpointing=checkpointing,
        )
        if VERBOSE:
            print("log_color")
            print(log_color)
    else:
        # Relu-field mode: relu right after sampling, no trunk MLP allowed.
        feature_sampled = torch.relu(feature_sampled)
        feature_sampled_color = torch.relu(feature_sampled_color)
        ray_feature_sampled_color = feature_sampled_color + rays_encoding[:, None]
        if VERBOSE:
            print("feature_sampled")
            print(feature_sampled)
            print("feature_sampled_color")
            print(feature_sampled_color)
        assert len(weights_trunk) == 0
        assert len(biases_trunk) == 0
        opacity_raw = _eval_mlp(
            feature_sampled,
            weights_opacity,
            biases_opacity,
            mlp_name="opacity",
            checkpointing=checkpointing,
        )
        if VERBOSE:
            print("opacity_raw")
            print(opacity_raw)
        log_color = _eval_mlp(
            ray_feature_sampled_color,
            weights_color,
            biases_color,
            mlp_name="color",
            checkpointing=checkpointing,
        )
        if VERBOSE:
            print("log_color")
            print(log_color)
    # The opacity MLP emits a single channel; drop the trailing dim.
    assert opacity_raw.shape[-1] == 1
    opacity_raw = opacity_raw[..., 0]
    if inject_opacity_noise is not None:
        if VERBOSE:
            print("inject_opacity_noise")
            print(inject_opacity_noise)
        opacity_raw = opacity_raw + inject_opacity_noise
    opacity = gain * torch.nn.functional.softplus(opacity_raw)
    # TODO: allow for output without activation
    feature_out = torch.sigmoid(log_color)
    if scaffold is not None:
        # Nearest-neighbor lookup of the occupancy scaffold; empty cells
        # zero out both opacity and color.
        scaffold_value = sample_grid_list_checkpointed(
            (scaffold[..., None],),
            points,
            ray_grid_idx,
            True,
            mode="nearest",
            checkpointing=checkpointing,
        )
        if VERBOSE:
            print("scaffold_value")
            print(scaffold_value)
        opacity = opacity * scaffold_value[..., 0]
        feature_out = feature_out * scaffold_value
    return opacity, feature_out
def lightplane_eval_mlp_opacity_only(
    points: torch.Tensor,  # R x N x 3; packed with ray_grid_idx
    grid: tuple[torch.Tensor, ...],
    ray_grid_idx: torch.Tensor,
    decoder_params: DecoderParams,
    gain: float,
    mask_out_of_bounds_samples: bool = False,
    inject_opacity_noise: torch.Tensor | None = None,
    scaffold: torch.Tensor | None = None,
    checkpointing: bool = False,  # whether or not use pytorch checkpoint for MLP eval
    contract_coords: bool = False,
) -> torch.Tensor:
    """Sample grid features at `points` and decode only the opacity.

    Same decoding path as the single-grid branch of `lightplane_eval_mlp`
    (trunk MLP followed by the opacity MLP), but skips the color MLP entirely.

    Args:
        points: `R x N x 3` sample locations, grouped per-ray via `ray_grid_idx`.
        grid: Grid-list to sample features from.
        ray_grid_idx: Per-ray batch index selecting which grid to sample.
        decoder_params: Flattened MLP parameters of trunk/opacity/color MLPs.
        gain: Multiplier applied to the softplus-activated opacity.
        mask_out_of_bounds_samples: Zero features sampled outside the [-1, 1] cube.
        inject_opacity_noise: Optional noise added to the raw opacity logits.
        scaffold: Optional occupancy grid; zeroes opacity in empty space.
        checkpointing: Use `torch.utils.checkpoint` for sampling and MLP eval.
        contract_coords: Unused here; kept for signature parity with
            `lightplane_eval_mlp`.

    Returns:
        Opacity tensor of shape `R x N`.
    """
    assert points.ndim >= 3
    (
        weights_trunk,
        biases_trunk,
        weights_opacity,
        biases_opacity,
        weights_color,
        biases_color,
    ) = flattened_decoder_params_to_list(
        decoder_params.mlp_params,
        decoder_params.n_hidden_trunk,
        decoder_params.n_hidden_opacity,
        decoder_params.n_hidden_color,
    )
    if VERBOSE:
        print("w_trunk")
        for w in weights_trunk:
            print(w)
        print("b_trunk")
        for b in biases_trunk:
            print(b)
    feature_sampled = sample_grid_list_checkpointed(
        grid,
        points,
        ray_grid_idx,
        mask_out_of_bounds_samples,
        checkpointing=checkpointing,
    )
    if VERBOSE:
        print("feature_sampled")
        print(feature_sampled)
    # We have a single feature grid: trunk MLP, then the opacity head.
    feature_trunk = _eval_mlp(
        feature_sampled,
        weights_trunk,
        biases_trunk,
        mlp_name="trunk",
        checkpointing=checkpointing,
    )
    feature_trunk = torch.relu(feature_trunk)
    if VERBOSE:
        print("feature_trunk")
        print(feature_trunk)
    opacity_raw = _eval_mlp(
        feature_trunk,
        weights_opacity,
        biases_opacity,
        mlp_name="opacity",
        checkpointing=checkpointing,
    )
    if VERBOSE:
        print("opacity_raw")
        print(opacity_raw)
    # The opacity MLP emits a single channel; drop the trailing dim.
    assert opacity_raw.shape[-1] == 1
    opacity_raw = opacity_raw[..., 0]
    if inject_opacity_noise is not None:
        if VERBOSE:
            print("inject_opacity_noise")
            print(inject_opacity_noise)
        opacity_raw = opacity_raw + inject_opacity_noise
    opacity = gain * torch.nn.functional.softplus(opacity_raw)
    if scaffold is not None:
        # Nearest-neighbor lookup of the occupancy scaffold; empty cells
        # zero out the opacity.
        scaffold_value = sample_grid_list_checkpointed(
            (scaffold[..., None],),
            points,
            ray_grid_idx,
            True,
            mode="nearest",
            checkpointing=checkpointing,
        )
        if VERBOSE:
            print("scaffold_value")
            print(scaffold_value)
        opacity = opacity * scaffold_value[..., 0]
        # BUGFIX: the original additionally did
        #   feature_out = feature_out * scaffold_value
        # but `feature_out` is never defined in this opacity-only function,
        # which raised a NameError whenever `scaffold` was provided.
    return opacity
def sample_grid_list_checkpointed(
    grid: tuple[torch.Tensor, ...],
    points: torch.Tensor,  # B x N x 3
    grid_idx: torch.Tensor,  # B
    mask_out_of_bounds_samples: bool,
    mode="bilinear",
    checkpointing=False,
) -> torch.Tensor:  # B x N x C
    """Sample the grid-list at `points`, optionally under activation checkpointing.

    Thin wrapper over `_sample_grid_list`: when `checkpointing` is set, the
    call is routed through `torch.utils.checkpoint` (non-reentrant) so the
    sampled activations are recomputed during the backward pass instead of
    being stored.
    """
    sample_args = (grid, points, grid_idx, mask_out_of_bounds_samples, mode)
    if not checkpointing:
        return _sample_grid_list(*sample_args)
    return checkpoint(_sample_grid_list, *sample_args, use_reentrant=False)
def _sample_grid_list(
    grid: tuple[torch.Tensor, ...],
    points: torch.Tensor,  # B x N x 3
    grid_idx: torch.Tensor,  # B
    mask_out_of_bounds_points: bool,
    mode="bilinear",
) -> torch.Tensor:  # B x N x C
    """Sample every grid in `grid` at `points` and sum the per-grid features.

    Rays are grouped by their `grid_idx` (the batch element of the grid they
    sample), padded to a common length so a single batched `grid_sample` call
    can be issued per grid, then scattered back to the original ray order.
    """
    # Group ray indices by the grid batch element they sample from.
    used_grids = grid_idx.unique()
    batch_to_idx = [torch.where(grid_idx == i)[0] for i in used_grids]
    points_list = [points[idx] for idx in batch_to_idx]
    # Pad groups to the same ray count so they form one batched tensor.
    points_padded = torch.nn.utils.rnn.pad_sequence(
        points_list,
        batch_first=True,
    )
    # Features from all grids in the list are summed (lightplane convention).
    sampled_padded = sum(
        _sample_one_grid(
            g[used_grids],
            points_padded,
            mask_out_of_bounds_points,
            mode=mode,
        )
        for g in grid
    )
    assert sampled_padded.shape[:-1] == points_padded.shape[:-1]
    # Undo the padding: recover one variable-length chunk per grid group.
    sampled_list = torch.nn.utils.rnn.unpad_sequence(
        sampled_padded,
        torch.tensor([len(l) for l in batch_to_idx]),
        batch_first=True,
    )
    if VERBOSE:
        print("points")
        print(points)
    # Scatter the grouped results back into the original ray order.
    sampled = torch.zeros(
        points.shape[0],
        points.shape[1],
        sampled_list[0].shape[-1],
        device=points.device,
        dtype=points.dtype,
    )
    sampled[torch.cat(batch_to_idx)] = torch.cat(sampled_list, dim=0)
    return sampled
def _sample_one_grid(
    g: torch.Tensor, points: torch.Tensor, mask_out_of_bounds_samples: bool, mode: str
):
    """Sample a single `B x D x H x W x C` grid at normalized `points`.

    Voxel grids (all three spatial dims > 1) use 3D `grid_sample`; grids with
    exactly one singular spatial dim are treated as triplane-style planes and
    sampled in 2D on the two non-singular axes.
    """
    assert g.ndim == 5, "We support only B x D x H x W x C grids for now."
    # Count spatial dims with extent > 1 to tell voxel grids (3) apart from
    # triplane-style planes (2).
    n_non_singular_dim = sum(int(s > 1) for s in g.shape[1:-1])
    if n_non_singular_dim == 3:  # 3d voxel grid
        # grid_sample wants channels-first input and an extra sampling dim;
        # the permutes move C to dim 1 and back to the trailing dim.
        sampled = torch.nn.functional.grid_sample(
            g.permute(0, 4, 1, 2, 3),
            points[..., None, :],
            align_corners=False,
            mode=mode,
        )[..., 0].permute(0, 2, 3, 1)
    elif n_non_singular_dim == 2:  # triplane
        # Which spatial dim is singular determines the plane and therefore
        # which two point coordinates are used for the 2D lookup.
        singular_dim = [i for i, s in enumerate(g.shape[1:-1]) if s == 1][0]
        if singular_dim == 0:
            plane = "xy"
        elif singular_dim == 1:
            plane = "xz"
        elif singular_dim == 2:
            plane = "yz"
        else:
            raise ValueError()
        sample_coords = ["xyz".index(c) for c in plane]
        # Drop the singular spatial dim and run a 2D grid_sample over the
        # remaining two axes.
        sampled = torch.nn.functional.grid_sample(
            g.squeeze(singular_dim + 1).permute(0, 3, 1, 2),
            points[..., sample_coords],
            align_corners=False,
            mode=mode,
        ).permute(0, 2, 3, 1)
        if VERBOSE:
            print(f"sampled {plane}[0]:")
            print(sampled[0])
    else:
        raise ValueError(
            f"Unexpected n non-singulare dim of input grid ({n_non_singular_dim})"
        )
    if mask_out_of_bounds_samples:
        # Zero out features for points outside the [-1, 1] cube.
        in_bounds_mask = is_in_bounds(points)
        sampled = sampled * in_bounds_mask.float()
    return sampled
def _eval_mlp(
    vec: torch.Tensor,
    weights: tuple[torch.Tensor, ...],
    biases: tuple[torch.Tensor, ...],
    mlp_name: str = "",
    checkpointing: bool = False,
):
    """Run the MLP given by `weights`/`biases` on `vec`.

    Dispatches to the checkpointed variant when `checkpointing` is set,
    otherwise evaluates the layers directly.
    """
    runner = _eval_mlp_checkpointing if checkpointing else _eval_mlp_org
    return runner(vec, weights, biases, mlp_name)
def _eval_mlp_checkpointing(
    vec: torch.Tensor,
    weights: tuple[torch.Tensor, ...],
    biases: tuple[torch.Tensor, ...],
    mlp_name: str = "",
):
    """`_eval_mlp_org` wrapped in non-reentrant activation checkpointing,
    so intermediate activations are recomputed on backward instead of stored."""
    return checkpoint(
        _eval_mlp_org,
        vec,
        weights,
        biases,
        mlp_name,
        use_reentrant=False,
    )
def _eval_mlp_org(
    vec: torch.Tensor,
    weights: tuple[torch.Tensor, ...],
    biases: tuple[torch.Tensor, ...],
    mlp_name: str = "",
):
    """Plain MLP forward pass: affine layers with ReLU between them
    (no activation after the final layer)."""
    num_layers = len(weights)
    assert num_layers == len(biases)
    for layer, (w, b) in enumerate(zip(weights, biases)):
        vec = vec @ w + b
        if VERBOSE:
            if mlp_name == "trunk" and layer == 0:
                print(w)
                print(b)
            print(f"x{layer}@w+b")
            print(vec[0])
        # ReLU on every layer except the last one.
        if layer < num_layers - 1:
            vec = torch.relu(vec)
    return vec
def _get_sample_randn(
    num_samples,
    num_rays,
    device,
    inject_noise_seed,
):
    """Seed-deterministic standard-normal noise of shape (num_rays, num_samples).

    Builds two disjoint integer index grids (offset by the padded ray count,
    matching the Triton kernel's indexing) and feeds them to
    `int_to_randn_naive` so the same (ray, sample, seed) triple always maps
    to the same noise value.
    """
    padded_num_rays = max(num_rays, MIN_BLOCK_SIZE)
    ray_offsets = num_samples * torch.arange(num_rays, device=device)[:, None]
    sample_offsets = torch.arange(num_samples, device=device)[None]
    idx_lo = (ray_offsets + sample_offsets + 1).long()
    idx_hi = idx_lo + padded_num_rays * num_samples
    noise = int_to_randn_naive(
        idx_lo.reshape(-1), idx_hi.reshape(-1), inject_noise_seed
    )
    return noise.reshape(num_rays, num_samples)
def _contract_pi(x):
n = x.abs().max(dim=-1).values[..., None]
x_contract = torch.where(
n <= 1.0,
x,
torch.where(
(x.abs() - n).abs() <= 1e-7,
(2 - 1 / x.abs()) * (x / x.abs()),
x / n,
),
)
return x_contract / 2
def _depth_inv_sphere(far, disparity_at_inf, n, step):
frac_step = (step + 1) / n
n_disp = (disparity_at_inf - 1) * frac_step + 1
return far * (1 / n_disp)