Source code for lightplane.lightplane_splatter

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import copy
import math
import time
from typing import List, Optional, Tuple, Union

import torch

from .misc_utils import (
    assert_shape,
    check_grid,
    flatten_grid,
    process_and_flatten_grid,
    unflatten_grid,
)
from .mlp_utils import SplatterParams
from .ray_utils import Rays
from .triton_src import get_lightplane_kernels  # jit-compiles the extension
from .triton_src.shared.const import MIN_BLOCK_SIZE

PROFILE = False
DEBUG = False


[docs] def lightplane_splatter( rays: Rays, output_grid_size: list[tuple[int, int, int, int, int]], # ------ config keys ------ num_samples: int, num_samples_inf: int = 0, mask_out_of_bounds_samples: bool = False, contract_coords: bool = False, disparity_at_inf: float = 1e-5, return_list: bool = True, # whether return list or stacked tensor regenerate_code: bool = False, triton_block_size: int = 16, triton_num_warps: int = 4, ) -> torch.Tensor | list[torch.Tensor]: r""" This is the functional interface for the Lightplane Splatter. For each point along the ray, it splats `rays.encoding` to a zero-initialized output grid-list, `output_grid`, where the shape of the output grid-list is specified by `output_grid_size`. It utilizes the ray-marched splatting, where the `num_samples` 3D points are equispaced between `rays.near` and `rays.far`. It is useful for directly splatting ray encodings It follows:: ray encoding -> splat -> output_grid Args: rays: The rays to splat feature. It is an instance of `Rays`, including `directions`, `origins`, `grid_idx`, `near`, `far`, and `encoding`. `grid_idx` indicates the batch index of the 3D grid to sample features from:: x_i = rays.origins[i] + (rays.near[i] + delta_i) * rays.direction[i] `endcoding` is the feature to splat for each ray. output_grid_size: The sizes of the `output_grid` to be splatted. It is a list of tuples, where each tuple is the shape of the corresponding grid,in the form `(B, D, H, W, C)`, where `C` has to be the same for all output grids. Example:: output_grid_size = [(1, 64, 64, 64, 256), (1, 32, 1, 32, 256)] num_samples: The number of points to be splatted along the ray. The samples are equispaced between `rays.near` and `rays.far`. More specifically, the j-th 3d point `x_ij` along `i-th` ray is defined as follows:: x_ij = rays.origins[i] + (rays.near[i] + j * delta_i) * rays.direction[i], where: delta_i = (rays.far[i] - rays.near[i]) / num_samples num_samples_inf: The number of points in the background to be splatted. The first background sample is placed at `rays.far`, and the samples are spaced in the disparity space until reaching the disparity of `disparity_at_inf`. More specifically, the j-th background 3d point `b_ij` along `i-th` ray is defined as follows:: b_ij = rays.origins[i] + (rays.far[i] + j * bg_delta_ij) * rays.direction[i], where: bg_delta_ij = 1 / disparity_ij disparity_ij = linspace(1, disparity_at_inf, num_samples_inf)[j] These samples are additional to `num_samples`, i.e. the total number of samples along a ray is `num_samples + num_samples_inf`. mask_out_of_bounds_samples: Whether to mask samples that fall outside the [-1, 1] cube (does not apply when contraction with `contract_coords` is enabled). contract_coords: Whether to map the coordinates of the splatted points to always fall into the [-1, 1] cube. The contraction is implemented as in MeRF [1]:: x[k] if |x|_inf <= 1 contract(x)[k] = x[k] / |x|_inf if x_k != |x|_inf > 1 (2 - 1/x[k]) x_k / |x_k| if x_k = |x|_inf > 1 Note: The contraction is useful for representing unbounded scenes. E.g. outdoor captures where the scene extends to infinity. disparity_at_inf: The disparity value at infinity. return_list: Whether to return a list of grids containing the result of the splatting, or a tensor of stacked features. Note: Stacked features can be converted to a list of grids with the `lightplane.misc_utils.unflatten_grid` function. regenerate_code: If `True`, forces the regeneration of the triton code. triton_block_size: The block size for Triton. Has to be higher than 16. triton_num_warps: The number of warps for Triton. Returns: Union[torch.Tensor, List[torch.Tensor]]: The splatted results. References: [1] MERF: Memory-Efficient Radiance Fields for Real-time View Synthesis in Unbounded Scenes, https://arxiv.org/abs/2302.12249 """ output_grid_size = torch.tensor(output_grid_size, device=rays.device) rays, n_rays_padded = rays.pad_to_block_size(triton_block_size) splatting_feature = rays.encoding valid_mask = torch.ones_like(splatting_feature[:, 0]) if n_rays_padded > 0: valid_mask[-n_rays_padded:] = 0 splatted_results = LightplaneSplatterFunction.apply( output_grid_size, None, None, None, rays.directions, rays.origins, rays.grid_idx.to(torch.int32), rays.near, rays.far, splatting_feature, valid_mask, 0, 0, # other settings num_samples, num_samples_inf, mask_out_of_bounds_samples, contract_coords, disparity_at_inf, # BS triton_block_size, triton_num_warps, regenerate_code, ) if return_list: return unflatten_grid(splatted_results, output_grid_size.to(torch.long)) else: return splatted_results
[docs] def lightplane_mlp_splatter( rays: Rays, output_grid_size: list[tuple[int, int, int, int, int]], mlp_params: SplatterParams, input_grid: tuple[torch.Tensor, ...] | torch.Tensor, # ------ config keys ------ num_samples: int, num_samples_inf: int = 0, mask_out_of_bounds_samples: bool = False, contract_coords: bool = False, disparity_at_inf: float = 1e-5, input_grid_sizes: list[list[int]] | None = None, return_list: bool = True, # whether return list or stacked tensor regenerate_code: bool = False, triton_block_size: int = 16, triton_num_warps: int = 4, ) -> torch.Tensor | list[torch.Tensor]: r""" This is the functional interface for the Lightplane Splatter. For each point along the ray, it first samples the point feature from the corresponding prior input grid `input_grid`, adds the sampled feature to the `encoding` of the ray, passes the latter through an MLP, and splats the MLP output to the grid-list `output_grid`. It utilizes the ray-marched splatting, where the `num_samples` 3D points are equispaced between `rays.near` and `rays.far`. It follows:: input_grid -> f_i -> f_i + ray encoding -> mlp -> sf_i -> splat -> output_grid Args: rays: The rays to splat feature. It is an instance of `Rays`, including `directions`, `origins`, `grid_idx`, `near`, `far`, and `encoding`. `grid_idx` indicates the batch index of the 3D grid to sample features from:: x_i = rays.origins[i] + (rays.near[i] + delta_i) * rays.direction[i] `endcoding` is the feature to splat for each ray. output_grid_size: The sizes of the `output_grid` to be splatted. It is a list of tuples, where each tuple is the shape of the corresponding grid,in the form `(B, D, H, W, C)`, where `C` has to be the same for all output grids. Example:: output_grid_size = [(1, 64, 64, 64, 256), (1, 32, 1, 32, 256)] mlp_params: The parameters of the MLP that processes the features sampled from `input_grid`. input_grid: The grid-list (a list of 3D grids) from which to sample features. Sames as `grid` in `lightplane_renderer`, it contains `N` grids, where each grid is a tensor of shape `(B, D_i, H_i, W_i, C)`. `input_grid` should have the same feature dimension (`C`) as `rays.encoding` as they are summed together before passing through the MLP. Similar to `grid` in `lightplane_renderer`, `input_grid` could be a 2D tensor (the stacked version of the grid-list). In this case, `input_grid_sizes` must be provided to specify the shape of `input_grid`. num_samples: The number of points to be splatted along the ray. The samples are equispaced between `rays.near` and `rays.far`. More specifically, the j-th 3d point `x_ij` along `i-th` ray is defined as follows:: x_ij = rays.origins[i] + (rays.near[i] + j * delta_i) * rays.direction[i], where: delta_i = (rays.far[i] - rays.near[i]) / num_samples num_samples_inf: The number of points in the background to be splatted. The first background sample is placed at `rays.far`, and the samples are spaced in the disparity space until reaching the disparity of `disparity_at_inf`. More specifically, the j-th background 3d point `b_ij` along `i-th` ray is defined as follows:: b_ij = rays.origins[i] + (rays.far[i] + j * bg_delta_ij) * rays.direction[i], where: bg_delta_ij = 1 / disparity_ij disparity_ij = linspace(1, disparity_at_inf, num_samples_inf)[j] These samples are additional to `num_samples`, i.e. the total number of samples along a ray is `num_samples + num_samples_inf`. mask_out_of_bounds_samples: Whether to mask samples that fall outside the [-1, 1] cube (does not apply when contraction with `contract_coords` is enabled). contract_coords: Whether to map the coordinates of the splatted points to always fall into the [-1, 1] cube. The contraction is implemented as in MeRF [1]:: x[k] if |x|_inf <= 1 contract(x)[k] = x[k] / |x|_inf if x_k != |x|_inf > 1 (2 - 1/x[k]) x_k / |x_k| if x_k = |x|_inf > 1 Note: The contraction is useful for representing unbounded scenes. E.g. outdoor captures where the scene extends to infinity. disparity_at_inf: The disparity value at infinity. input_grid_sizes: It specifies the size of `input_grid`. It is optional when `input_grid` is a grid-list, but required when `input_grid`is a 2D tensor. Example:: input_grid_sizes = [[B, D_1, H_1, W_1, C], ... , [B, D_N, H_N, W_N, C]] return_list: Whether to return a list of grids containing the result of the splatting, or a tensor of stacked features. Note: Stacked features can be converted to a list of grids with the `lightplane.misc_utils.unflatten_grid` function. regenerate_code: If `True`, forces the regeneration of the triton code. triton_block_size: The block size for Triton. Has to be higher than 16. triton_num_warps: The number of warps for Triton. Returns: Union[torch.Tensor, List[torch.Tensor]]: The splatted results. References: [1] MERF: Memory-Efficient Radiance Fields for Real-time View Synthesis in Unbounded Scenes, https://arxiv.org/abs/2302.12249 """ output_grid_size = torch.tensor(output_grid_size, device=rays.device) rays, n_rays_padded = rays.pad_to_block_size(triton_block_size) splatting_feature = rays.encoding valid_mask = torch.ones_like(splatting_feature[:, 0]) if n_rays_padded > 0: valid_mask[-n_rays_padded:] = 0 assert ( len(mlp_params.n_hidden) > 1 ), f"mlp depth has to be bigger than 1 when using input_grid" assert ( input_grid is not None ), f"input_grid cannot be None when mlp_params is not None" n_hidden = mlp_params.n_hidden[1].item() n_layers = mlp_params.n_hidden.numel() - 1 check_grid(input_grid, input_grid_sizes) input_grid, _, input_grid_size, _ = process_and_flatten_grid( input_grid, None, input_grid_sizes, None ) splatted_results = LightplaneSplatterFunction.apply( output_grid_size, mlp_params.mlp_params, input_grid, input_grid_size, rays.directions, rays.origins, rays.grid_idx.to(torch.int32), rays.near, rays.far, splatting_feature, valid_mask, n_hidden, n_layers, # other settings num_samples, num_samples_inf, mask_out_of_bounds_samples, contract_coords, disparity_at_inf, # BS triton_block_size, triton_num_warps, regenerate_code, ) if return_list: return unflatten_grid(splatted_results, output_grid_size.to(torch.long)) else: return splatted_results
class LightplaneSplatterFunction(torch.autograd.Function): @staticmethod def forward( ctx, feature_grid_sizes: torch.Tensor, # NUM_GRIDS x 5: [[B_1, D_1, H_1, W_1, C_1], ...] mlp_params: torch.Tensor | None, # flattened biases and weights of all mlps input_feature_grid: None | ( torch.Tensor ), # NUM_GRIDS * B * D' * H' * W' * C', another gird with same num_grids and batch, for recursive feature updating. input_feature_grid_sizes: None | (torch.Tensor), # NUM_GRIDS x 5: [[B_1, D_1, H_1, W_1, C_1], ...] directions: torch.Tensor, # N x 3 origins: torch.Tensor, # N x 3 grid_idx: torch.Tensor, # N near: torch.Tensor, # N far: torch.Tensor, # N splatting_feature: torch.Tensor, # N x C valid_mask: torch.Tensor, # N # mlp sizes mlp_dim_hidden: int | None, mlp_n_layers: int | None, # other settings num_samples: int, num_samples_inf: int, mask_out_of_bounds_samples: bool = False, contract_coords: bool = False, disparity_at_inf: float = 1e-5, # triton block size BLOCK_SIZE: int = 16, NUM_WARPS: int = 4, # force code regeneration regenerate_code: bool = False, ): device = directions.device use_mlp = mlp_params is not None ( fw_kernel_feat, fw_kernel_weight, bw_kernel_feat, bw_kernel_weight, ) = get_lightplane_kernels( "splatter", mlp_n_layers, regenerate_code=regenerate_code, ) ctx.fw_kernel_feat = fw_kernel_feat ctx.fw_kernel_weight = fw_kernel_weight ctx.bw_kernel_feat = bw_kernel_feat if PROFILE: torch.cuda.synchronize() time_start = time.time() assert BLOCK_SIZE >= MIN_BLOCK_SIZE # initialize feature grid # All grid should have the same feature dimensions, otherwise cannot be stacked. assert all( g[-1] == feature_grid_sizes[0, -1] for g in feature_grid_sizes ), "All output grids should have the same feature dimensions." feature_grid = torch.zeros( ( sum(torch.prod(g[:-1]) for g in feature_grid_sizes), feature_grid_sizes[0, -1], ), device=device, ).contiguous() # important sizes device = feature_grid_sizes.device num_grid_channels = feature_grid.shape[-1] num_splatting_channels = splatting_feature.shape[-1] num_rays = directions.shape[0] num_grids = feature_grid_sizes.shape[0] grid_batch_size = feature_grid_sizes.shape[1] # mlp size params for kernel if use_mlp: mlp_dim_in = num_splatting_channels mlp_dim_out = num_grid_channels else: mlp_dim_in = num_splatting_channels mlp_dim_out = num_splatting_channels assert ( num_grid_channels == num_splatting_channels ), f"num_grid_channels should be the same as num_splatting_channels" # asserts assert_shape(feature_grid_sizes, (num_grids, 5)) assert_shape(directions, (num_rays, 3)) assert_shape(origins, (num_rays, 3)) assert_shape(grid_idx, (num_rays,)) assert_shape(near, (num_rays,)) assert_shape(far, (num_rays,)) assert_shape(splatting_feature, (num_rays, mlp_dim_in)) assert ( math.log2(num_grid_channels) % 1 == 0 ), f"num_grid_channels has to be divided by 2" assert ( num_grid_channels >= MIN_BLOCK_SIZE ), f"num_grid_channels has to be bigger than {MIN_BLOCK_SIZE}" assert ( math.log2(num_splatting_channels) % 1 == 0 ), f"num_splatting_channels has to be divided by 2" assert ( num_splatting_channels >= MIN_BLOCK_SIZE ), f"num_splatting_channels has to be bigger than {MIN_BLOCK_SIZE}" assert ( num_rays % BLOCK_SIZE == 0 ), "We do not support #rays!=multiple of BLOCK_SIZE." if use_mlp: assert mlp_params is not None assert input_feature_grid is not None assert input_feature_grid_sizes is not None assert mlp_dim_hidden is not None assert mlp_n_layers is not None assert mlp_params.ndim == 1 assert splatting_feature.shape[1] == mlp_dim_in assert_shape(input_feature_grid_sizes, (num_grids, 5)) kwargs = { # ---- grid ---- "feature_grid": feature_grid, "feature_grid_sizes": feature_grid_sizes, # ----- non-differentiable tensors "directions": directions, "origins": origins, "grid_idx": grid_idx, "near": near, "far": far, "splatting_feature": splatting_feature, "mask": valid_mask, # ----- config keys ---- "num_samples": num_samples, "num_samples_inf": num_samples_inf, # ----- sizes ---- "num_rays": num_rays, "grid_channel": num_grid_channels, "NUM_GRIDS": num_grids, "feature_channel": num_splatting_channels, "BLOCK_SIZE": BLOCK_SIZE, # ---- switches ---- "mask_out_of_bounds_samples": int(mask_out_of_bounds_samples), "contract_coords": contract_coords, "disparity_at_inf": disparity_at_inf, } if use_mlp: kwargs.update( { "input_feature_grid": input_feature_grid, "input_feature_grid_sizes": input_feature_grid_sizes, "mlp_params": mlp_params.contiguous(), "DIM_HIDDEN": mlp_dim_hidden, "DIM_IN": mlp_dim_in, "DIM_OUT": mlp_dim_out, } ) n_blocks = int(math.ceil(num_rays / BLOCK_SIZE)) grid = (n_blocks,) fw_kernel_feat[grid](num_warps=1 if DEBUG else NUM_WARPS, **kwargs) weight_grid = torch.zeros_like( feature_grid[..., 0:1] ) # single dimension feature grid weight_grid_sizes = copy.deepcopy(feature_grid_sizes) splating_weight = torch.ones_like(splatting_feature[..., 0:1]) weight_grid_sizes[:, 4] = 1 weight_kwargs = { # ---- grid ---- "feature_grid": weight_grid, "feature_grid_sizes": weight_grid_sizes, # ----- non-differentiable tensors "directions": directions, "origins": origins, "grid_idx": grid_idx, "near": near, "far": far, "splatting_feature": splating_weight, "mask": valid_mask, # ----- config keys ---- "num_samples": num_samples, "num_samples_inf": num_samples_inf, # ----- sizes ---- "num_rays": num_rays, "grid_channel": 1, "NUM_GRIDS": num_grids, "feature_channel": 1, "BLOCK_SIZE": BLOCK_SIZE, # ---- switches ---- "mask_out_of_bounds_samples": int(mask_out_of_bounds_samples), "contract_coords": contract_coords, "disparity_at_inf": disparity_at_inf, } fw_kernel_weight[grid](num_warps=1 if DEBUG else NUM_WARPS, **weight_kwargs) weight_grid = torch.clamp(weight_grid, min=1e-5) # save tensors for bw ctx.save_for_backward( feature_grid, feature_grid_sizes, input_feature_grid, input_feature_grid_sizes, weight_grid, mlp_params, directions, origins, grid_idx, near, far, splatting_feature, valid_mask, ) # save config keys ctx.mlp_dim_in = mlp_dim_in ctx.mlp_dim_out = mlp_dim_out ctx.mlp_n_layers = mlp_n_layers ctx.mlp_dim_hidden = mlp_dim_hidden ctx.num_samples = num_samples ctx.num_samples_inf = num_samples_inf ctx.num_splatting_channels = num_splatting_channels ctx.mask_out_of_bounds_samples = mask_out_of_bounds_samples ctx.contract_coords = contract_coords ctx.disparity_at_inf = disparity_at_inf ctx.BLOCK_SIZE = BLOCK_SIZE ctx.NUM_WARPS = NUM_WARPS ctx.num_grids = num_grids ctx.num_rays = num_rays ctx.num_grid_channels = num_grid_channels ctx.device = device ctx.grid_batch_size = grid_batch_size ctx.use_mlp = use_mlp if PROFILE: torch.cuda.synchronize() elapsed = time.time() - time_start print(f"fw time = {elapsed:1.5f}") return feature_grid / weight_grid @staticmethod def backward(ctx, grad_feature_grid): if PROFILE: torch.cuda.synchronize() time_start = time.time() ( feature_grid, feature_grid_sizes, input_feature_grid, input_feature_grid_sizes, weight_grid, mlp_params, directions, origins, grid_idx, near, far, splatting_feature, valid_mask, ) = ctx.saved_tensors grad_feature_grid = grad_feature_grid / weight_grid device = feature_grid.device grad_splatting_feature = torch.zeros_like(splatting_feature) n_blocks = int(math.ceil(ctx.num_rays / ctx.BLOCK_SIZE)) grid = (n_blocks,) kwargs = { "grad_feature_grid": grad_feature_grid, "grad_feature_grid_sizes": feature_grid_sizes, # ---- grid ---- # ----- non-differentiable tensors "directions": directions, "origins": origins, "grid_idx": grid_idx, "near": near, "far": far, "splatting_feature": splatting_feature, "mask": valid_mask, # ----- config keys ---- "num_samples": ctx.num_samples, "num_samples_inf": ctx.num_samples_inf, # ----- sizes ---- "num_rays": ctx.num_rays, "grid_channel": ctx.num_grid_channels, "NUM_GRIDS": ctx.num_grids, "feature_channel": ctx.num_splatting_channels, "BLOCK_SIZE": ctx.BLOCK_SIZE, # ---- switches ---- "mask_out_of_bounds_samples": int(ctx.mask_out_of_bounds_samples), "contract_coords": ctx.contract_coords, "disparity_at_inf": ctx.disparity_at_inf, "grad_splatting_feature": grad_splatting_feature, } if ctx.use_mlp: grad_mlp_params = torch.zeros_like(mlp_params) grad_input_feature_grid = torch.zeros_like(input_feature_grid) kwargs.update( { "feature_grid": feature_grid, "feature_grid_sizes": feature_grid_sizes, "input_feature_grid": input_feature_grid, "input_feature_grid_sizes": input_feature_grid_sizes, "mlp_params": mlp_params.contiguous(), "DIM_HIDDEN": ctx.mlp_dim_hidden, "DIM_IN": ctx.mlp_dim_in, "DIM_OUT": ctx.mlp_dim_out, "grad_mlp_params": grad_mlp_params, "grad_input_feature_grid": grad_input_feature_grid, } ) else: grad_input_feature_grid = None grad_mlp_params = None ctx.bw_kernel_feat[grid](num_warps=1 if DEBUG else ctx.NUM_WARPS, **kwargs) if PROFILE: torch.cuda.synchronize() elapsed = time.time() - time_start print(f"bw time = {elapsed:1.5f}") # TODO: remove for speed assert torch.isfinite(grad_splatting_feature).all() if grad_mlp_params is not None: assert torch.isfinite(grad_mlp_params).all() return ( None, grad_mlp_params, grad_input_feature_grid, None, None, None, None, None, None, grad_splatting_feature, None, None, None, None, None, None, None, None, None, None, None, None, )