# Source code for lightplane.splatter_module

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import copy
import logging
from dataclasses import asdict
from typing import List, Optional, Tuple

import torch

logger = logging.getLogger(__name__)


from .lightplane_splatter import lightplane_mlp_splatter, lightplane_splatter
from .misc_utils import if_not_none_else, unflatten_grid
from .mlp_utils import SplatterParams, init_splatter_params
from .naive_splatter import lightplane_splatter_naive
from .ray_utils import Rays, jitter_near_far


class LightplaneSplatter(torch.nn.Module):
    def __init__(
        self,
        num_samples: int,
        grid_chn: int,
        num_samples_inf: int = 0,
        mask_out_of_bounds_samples: bool = False,
        contract_coords: bool = False,
        disparity_at_inf: float = 1e-5,
        rays_jitter_near_far: bool = False,
        triton_block_size: int = 16,
        triton_num_warps: int = 4,
        use_naive_impl: bool = False,
    ):
        r"""
        This is the PyTorch Module for the Lightplane Splatter.

        It uses `lightplane_splatter` as the core function and directly splats
        `rays.encoding` to a zero-initialized output grid-list, `output_grid`.

        Args:
            num_samples: Number of samples to splat.
            grid_chn: Number of channels in the `output_grid`. Since the
                encoding is splatted directly, this is also the required
                number of channels of `rays.encoding`.
            num_samples_inf: Number of samples beyond the `far` plane.
            mask_out_of_bounds_samples: Whether to mask out-of-bounds samples.
            contract_coords: Whether to contract the coordinates as in MeRF.
            disparity_at_inf: The beyond-far samples (their number-per-ray is
                determined by `num_samples_inf`) are sampled in the disparity
                space in `range[far, 1 / disparity_at_inf]`.
            rays_jitter_near_far: Whether to jitter the `near` and `far` planes
                uniformly in range `[-delta, delta]`.
            triton_block_size: Block size for triton.
            triton_num_warps: Number of warps for triton.
            use_naive_impl: Whether to use the naive pytorch implementation.
        """
        super().__init__()
        self.num_samples = num_samples
        self.num_samples_inf = num_samples_inf
        self.mask_out_of_bounds_samples = mask_out_of_bounds_samples
        self.contract_coords = contract_coords
        self.disparity_at_inf = disparity_at_inf
        self.rays_jitter_near_far = rays_jitter_near_far
        self.triton_block_size = triton_block_size
        self.triton_num_warps = triton_num_warps
        self.use_naive_impl = use_naive_impl
        # Without an MLP the ray encoding is splatted as-is, so its width
        # must equal the number of output-grid channels.
        self.rays_encoding_dim = grid_chn

    def get_splatter_params(self) -> SplatterParams | None:
        r"""
        Helper function to get the splatter parameters.

        The plain splatter has no MLP, so this always returns `None`. It
        mirrors `LightplaneMLPSplatter.get_splatter_params` so that both
        splatter modules expose the same method.
        """
        return None

    def forward(
        self,
        rays: Rays,
        grid_size: list[tuple[int, int, int, int, int]],
        # If set, the following args override the module's default values:
        num_samples: int | None = None,
        num_samples_inf: int | None = None,
        mask_out_of_bounds_samples: bool | None = None,
        contract_coords: bool | None = None,
        disparity_at_inf: float | None = None,
        rays_jitter_near_far: bool | None = None,
        return_list: bool = True,  # return grid list instead of a stacked tensor
        regenerate_code: bool = False,
    ):
        r"""
        Forward function for splatting rays into an `output_grid`.

        Args:
            rays: `Rays` to splat. `rays.encoding` is splatted to the
                `output_grid`.
            grid_size: List of tuples specifying the grid sizes of
                `output_grid`.
            num_samples: Number of samples for splatting.
            num_samples_inf: Number of samples beyond the `far` plane.
            mask_out_of_bounds_samples: Whether to mask out-of-bounds samples.
            contract_coords: Whether to contract the coordinates as in MeRF.
            disparity_at_inf: The beyond-far samples (their number-per-ray is
                determined by `num_samples_inf`) are sampled in the disparity
                space in `range[far, 1 / disparity_at_inf]`.
            rays_jitter_near_far: Whether to jitter the `near` and `far` planes
                uniformly in range `[-delta, delta]`.
            return_list: Whether to return a list of grids or a stacked tensor.
            regenerate_code: Whether to regenerate the code for the splatter.

        Returns:
            output_grid: splatted grid.

        Raises:
            ValueError: If `rays.encoding` is missing or has the wrong number
                of channels.
        """
        # Per-call arguments take precedence over the module defaults.
        num_samples = if_not_none_else(num_samples, self.num_samples)
        num_samples_inf = if_not_none_else(num_samples_inf, self.num_samples_inf)
        mask_out_of_bounds_samples = if_not_none_else(
            mask_out_of_bounds_samples,
            self.mask_out_of_bounds_samples,
        )
        contract_coords = if_not_none_else(contract_coords, self.contract_coords)
        disparity_at_inf = if_not_none_else(disparity_at_inf, self.disparity_at_inf)
        rays_jitter_near_far = if_not_none_else(
            rays_jitter_near_far, self.rays_jitter_near_far
        )

        # handle ray encoding
        _check_splatter_ray_encoding_input(
            rays.encoding,
            self.rays_encoding_dim,
        )

        # Shallow copy so that reassigning `near`/`far` below does not mutate
        # the caller's `Rays` object.
        rays = copy.copy(rays)
        if rays_jitter_near_far:
            rays.near, rays.far = jitter_near_far(
                rays.near,
                rays.far,
                num_samples,
            )

        kwargs = {
            "rays": rays,
            "output_grid_size": grid_size,
            "num_samples": num_samples,
            "num_samples_inf": num_samples_inf,
            "mask_out_of_bounds_samples": mask_out_of_bounds_samples,
            "contract_coords": contract_coords,
            "disparity_at_inf": disparity_at_inf,
            "return_list": return_list,
            "triton_block_size": self.triton_block_size,
            "triton_num_warps": self.triton_num_warps,
            "regenerate_code": regenerate_code,
        }

        # run lightplane
        if self.use_naive_impl:
            lightplane_fn = lightplane_splatter_naive
        else:
            lightplane_fn = lightplane_splatter
        out = lightplane_fn(**kwargs)
        return out
class LightplaneMLPSplatter(torch.nn.Module):
    def __init__(
        self,
        num_samples: int,
        grid_chn: int,
        input_grid_chn: int = 32,
        mlp_hidden_chn: int = 32,
        mlp_n_layers: int = 2,
        num_samples_inf: int = 0,
        mask_out_of_bounds_samples: bool = False,
        contract_coords: bool = False,
        disparity_at_inf: float = 1e-5,
        rays_jitter_near_far: bool = False,
        triton_block_size: int = 16,
        triton_num_warps: int = 4,
        use_naive_impl: bool = False,
    ):
        r"""
        This is the PyTorch Module for the Lightplane MLP Splatter.

        It uses `lightplane_mlp_splatter` as the core function. It samples the
        point feature from the corresponding prior input grid `input_grid`,
        adds the sampled feature to the `encoding` of the ray, passes the
        result through an MLP, and splats the MLP output to the grid-list
        `output_grid`.

        Args:
            num_samples: Number of samples to splat.
            grid_chn: Number of channels in the `output_grid` (the MLP output
                channels).
            input_grid_chn: Number of channels in the `input_grid`. It should
                be the same as the number of channels for `rays.encoding`.
            mlp_hidden_chn: Number of hidden channels in the MLP.
            mlp_n_layers: Number of layers in the MLP.
            num_samples_inf: Number of samples beyond the `far` plane.
            mask_out_of_bounds_samples: Whether to mask out-of-bounds samples.
            contract_coords: Whether to contract the coordinates as in MeRF.
            disparity_at_inf: The beyond-far samples (their number-per-ray is
                determined by `num_samples_inf`) are sampled in the disparity
                space in `range[far, 1 / disparity_at_inf]`.
            rays_jitter_near_far: Whether to jitter the `near` and `far` planes
                uniformly in range `[-delta, delta]`.
            triton_block_size: Block size for triton.
            triton_num_warps: Number of warps for triton.
            use_naive_impl: Whether to use the naive pytorch implementation.
        """
        super().__init__()
        self.num_samples = num_samples
        self.num_samples_inf = num_samples_inf
        self.mask_out_of_bounds_samples = mask_out_of_bounds_samples
        self.contract_coords = contract_coords
        self.disparity_at_inf = disparity_at_inf
        self.rays_jitter_near_far = rays_jitter_near_far
        self.triton_block_size = triton_block_size
        self.triton_num_warps = triton_num_warps
        self.use_naive_impl = use_naive_impl

        assert input_grid_chn is not None, "input_grid_chn must be provided"
        # The MLP maps input_grid_chn -> grid_chn. Its parameters are created
        # on CPU; move the module with `.to(...)` as needed.
        splatter_params = init_splatter_params(
            device="cpu",
            n_layers=mlp_n_layers,
            input_chn=input_grid_chn,
            hidden_chn=mlp_hidden_chn,
            out_chn=grid_chn,
        )
        self.mlp_params = torch.nn.Parameter(splatter_params.mlp_params)
        # Register the n_hidden field of splatter_params as a non-persistent
        # buffer: it moves with the module but is not saved in the state dict.
        self.register_buffer("n_hidden", splatter_params.n_hidden, persistent=False)

        # The ray encoding is added to the feature sampled from `input_grid`,
        # so its width must equal the MLP *input* channels (input_grid_chn).
        self.rays_encoding_dim = input_grid_chn

    def get_splatter_params(self) -> SplatterParams:
        r"""
        Helper function to get the splatter parameters.

        Returns:
            `SplatterParams` bundling the learnable `mlp_params` with the
            `n_hidden` buffer.
        """
        return SplatterParams(self.mlp_params, self.n_hidden)

    def forward(
        self,
        rays: Rays,
        grid_size: list[tuple[int, int, int, int, int]],
        input_grid: tuple[torch.Tensor, ...] | torch.Tensor,
        num_samples: int | None = None,
        num_samples_inf: int | None = None,
        mask_out_of_bounds_samples: bool | None = None,
        contract_coords: bool | None = None,
        disparity_at_inf: float | None = None,
        input_grid_sizes: list[list[int]] | None = None,
        rays_jitter_near_far: bool | None = None,
        return_list: bool = True,  # return grid list instead of a stacked tensor
        regenerate_code: bool = False,
    ):
        r"""
        Forward function for splatting rays into an `output_grid` with an MLP
        and `input_grid` as the prior grid.

        Args:
            rays: `Rays` to splat. `rays.encoding` is combined with the feature
                sampled from `input_grid`, passed through the MLP, and the
                result is splatted to the `output_grid`.
            grid_size: List of tuples specifying the grid sizes of
                `output_grid`.
            input_grid: Grids to sample the point feature from.
            num_samples: Number of samples for splatting.
            num_samples_inf: Number of samples beyond the `far` plane.
            mask_out_of_bounds_samples: Whether to mask out-of-bounds samples.
            contract_coords: Whether to contract the coordinates as in MeRF.
            disparity_at_inf: The beyond-far samples (their number-per-ray is
                determined by `num_samples_inf`) are sampled in the disparity
                space in `range[far, 1 / disparity_at_inf]`.
            input_grid_sizes: The size of the `input_grid`. Only required if
                `input_grid` is a 2D tensor. Forwarded to the core function
                when provided.
            rays_jitter_near_far: Whether to jitter the `near` and `far` planes
                uniformly in range `[-delta, delta]`.
            return_list: Whether to return a list of grids or a stacked tensor.
            regenerate_code: Whether to regenerate the code for the splatter.

        Returns:
            output_grid: splatted grid.

        Raises:
            ValueError: If `rays.encoding` is missing or has the wrong number
                of channels.
        """
        # Per-call arguments take precedence over the module defaults.
        num_samples = if_not_none_else(num_samples, self.num_samples)
        num_samples_inf = if_not_none_else(num_samples_inf, self.num_samples_inf)
        mask_out_of_bounds_samples = if_not_none_else(
            mask_out_of_bounds_samples,
            self.mask_out_of_bounds_samples,
        )
        contract_coords = if_not_none_else(contract_coords, self.contract_coords)
        disparity_at_inf = if_not_none_else(disparity_at_inf, self.disparity_at_inf)
        rays_jitter_near_far = if_not_none_else(
            rays_jitter_near_far, self.rays_jitter_near_far
        )

        # handle ray encoding
        _check_splatter_ray_encoding_input(
            rays.encoding,
            self.rays_encoding_dim,
        )
        assert input_grid is not None, "input_grid must be provided"

        # Shallow copy so that reassigning `near`/`far` below does not mutate
        # the caller's `Rays` object.
        rays = copy.copy(rays)
        if rays_jitter_near_far:
            rays.near, rays.far = jitter_near_far(
                rays.near,
                rays.far,
                num_samples,
            )

        kwargs = {
            "rays": rays,
            "output_grid_size": grid_size,
            "mlp_params": self.get_splatter_params(),
            "input_grid": input_grid,
            "num_samples": num_samples,
            "num_samples_inf": num_samples_inf,
            "mask_out_of_bounds_samples": mask_out_of_bounds_samples,
            "contract_coords": contract_coords,
            "disparity_at_inf": disparity_at_inf,
            "return_list": return_list,
            "triton_block_size": self.triton_block_size,
            "triton_num_warps": self.triton_num_warps,
            "regenerate_code": regenerate_code,
        }
        # `input_grid_sizes` was previously accepted but silently dropped,
        # which breaks the documented 2D-tensor `input_grid` case. Forward it
        # only when given, so list-of-grids calls are unchanged.
        # NOTE(review): assumes the selected core function accepts an
        # `input_grid_sizes` keyword; confirm for `lightplane_splatter_naive`.
        if input_grid_sizes is not None:
            kwargs["input_grid_sizes"] = input_grid_sizes

        # run lightplane
        if self.use_naive_impl:
            lightplane_fn = lightplane_splatter_naive
        else:
            lightplane_fn = lightplane_mlp_splatter
        out = lightplane_fn(**kwargs)
        return out
def _check_splatter_ray_encoding_input(
    ray_encoding: torch.Tensor | None,
    ray_encoding_dim: int,
) -> None:
    """
    Validate the `encoding` field of the rays passed to a splatter.

    Args:
        ray_encoding: The `rays.encoding` tensor. Its channels are read
            along dim 1.
        ray_encoding_dim: The number of channels the splatter expects.

    Raises:
        ValueError: If `ray_encoding` is None, has fewer than 2 dimensions,
            or its size along dim 1 differs from `ray_encoding_dim`.
    """
    if ray_encoding is None:
        raise ValueError(
            "The encoding field of input rays is None."
            " However, the Splatter requires an encoding for input rays."
        )
    # Guard the `.shape[1]` access below: without this, a 0-D or 1-D tensor
    # would surface as an opaque IndexError instead of a ValueError.
    if ray_encoding.ndim < 2:
        raise ValueError(
            "Ray encoding must have at least 2 dimensions"
            f" (channels along dim 1), got shape {tuple(ray_encoding.shape)}."
        )
    if ray_encoding.shape[1] != ray_encoding_dim:
        raise ValueError(
            "Ray encoding has a wrong dimension."
            f" Expected: {ray_encoding_dim}, got: {ray_encoding.shape[1]}"
        )