# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
import math
import random
import time
import warnings
import torch
from .misc_utils import (
assert_shape,
check_grid_and_color_grid,
flatten_grid,
process_and_flatten_grid,
)
from .mlp_utils import DecoderParams, get_triton_function_input_dims
from .ray_utils import Rays
from .triton_src import get_lightplane_kernels
from .triton_src.shared.const import MIN_BLOCK_SIZE
PROFILE = False
DEBUG = False
def lightplane_renderer(
rays: Rays,
grid: tuple[torch.Tensor, ...] | torch.Tensor,
decoder_params: DecoderParams,
# ------ config keys ------
num_samples: int,
gain: float,
num_samples_inf: int = 0,
mask_out_of_bounds_samples: bool = False,
contract_coords: bool = False,
disparity_at_inf: float = 1e-5,
inject_noise_sigma: float = 0.0,
inject_noise_seed: int | None = None,
scaffold: torch.Tensor | None = None,
color_grid: tuple[torch.Tensor, ...] | torch.Tensor | None = None,
grid_sizes: list[list[int]] | None = None,
color_grid_sizes: list[list[int]] | None = None,
regenerate_code: bool = False,
triton_block_size: int = 16,
triton_num_warps: int = 4,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    r"""
    This is the main functional interface for the Lightplane Renderer.
    It outputs the final color `c`, the negative log transmittance `T_N`, and the
    expected ray-termination length `r` (analogous to depth) of each ray's pixel.
    For `N=num_samples` equispaced 3D points `x_i` between the `near` and
    `far` ray-lengths, it samples the feature `f_i(x_i)` of each 3D point
    `x_i` from the grid-list `grid` and computes the corresponding
    rendering results.
There are three MLPs: `trunk_mlp`, `color_mlp` and `opacity_mlp`, whose
parameters are specified in `decoder_params`.
- trunk_mlp: Regresses the features from the `f_i(x_i)`::
e_i(x_i) = trunk_mlp(f_i(x_i))
- color_mlp: Regresses target features (e.g. final colors) from the
output of `trunk_mlp` with ray encoding as additional input.
The dimension of the output is `color_chn`::
c_i(x_i) = color_mlp(e_i(x_i) + ray_encoding)
- opacity_mlp: Regresses the opacity scalar from the output of `trunk_mlp`::
o_i(x_i) = opacity_mlp(e_i(x_i))
Args:
rays: The rays to render features.
It is an instance of `Rays`, with fields `directions`, `origins`,
`grid_idx`, `near`, `far`, and `encoding`.
`grid_idx` indicates the batch index of the 3D grid to sample
features from::
x_i = rays.origins[i] + (rays.near[i] + delta_i) * rays.direction[i]
grid: Grid-list (a list of 3D grids) to sample features from.
Features are sampled from each 3D grid in the set and summed up as
the final feature.
            `grid` contains `N` tensors with the shapes::
                [[B, D_1, H_1, W_1, C], ... , [B, D_N, H_N, W_N, C]]
Each tensor must have 5 dimensions and all tensors should have
the same batch size `B` and feature dimension `C`.
Example:
                If `grid` is a single voxel grid::
                    grid = [torch.empty(B, D, H, W, C)]
                If `grid` is a triplane::
                    grid = [
                        torch.empty(B, 1, H, W, C),
                        torch.empty(B, D, 1, W, C),
                        torch.empty(B, D, H, 1, C),
                    ]
`lightplane_renderer` can also work with `grid` as a 2D tensor,
which is a stacked tensor from the grid-list `grid`, with the shape
`[sum_(i=1..N)(B * D_i * H_i * W_i), C]`.
In this case, the `grid_sizes` must be provided to specify the shape
of each grid.
Note:
The 2D tensor can be obtained from `lightplane.flatten_grid(grid)`
to flatten the list of tensors and to also obtain the `grid_sizes`
argument.
            Note:
                Using 2D tensor inputs improves memory efficiency when the
                grid-list is large in memory.
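            Note:
                A minimal sketch of this path (as documented above,
                `flatten_grid` returns the stacked 2D tensor together with the
                matching `grid_sizes`)::

                    flat_grid, grid_sizes = flatten_grid(grid)
                    out = lightplane_renderer(
                        rays, flat_grid, decoder_params,
                        num_samples=num_samples, gain=gain,
                        grid_sizes=grid_sizes,
                    )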
decoder_params: The parameters of the decoder MLPs:
`trunk_mlp`, `color_mlp`, and `opacity_mlp`.
num_samples: The number of sampled points along the ray.
The samples are equispaced between `rays.near` and `rays.far`.
            More specifically, the `j`-th 3D point `x_ij` along the `i`-th ray is
defined as follows::
x_ij = rays.origins[i] + (rays.near[i] + j * delta_i) * rays.direction[i],
where:
delta_i = (rays.far[i] - rays.near[i]) / num_samples
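            Note:
                A torch sketch of these sample positions (illustrative only;
                the kernel computes them on the fly)::

                    delta = (rays.far - rays.near) / num_samples  # [n_rays]
                    j = torch.arange(num_samples)
                    t = rays.near[:, None] + j * delta[:, None]   # [n_rays, num_samples]
                    x = rays.origins[:, None] + t[..., None] * rays.directions[:, None]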
        gain: A constant that scales the transmittance `T_i` of the `i`-th point along a ray::
T_i = exp(-gain * sum_{j=1}^{i} o(x_ij) * delta_i)
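            Note:
                In a torch-style sketch, the rendering weights implied by this
                formula are (illustrative only; `o` are the per-sample
                opacities and `delta` the per-ray step size)::

                    neg_log_T = gain * torch.cumsum(o * delta, dim=-1)
                    T = torch.exp(-neg_log_T)  # transmittance T_i
                    # per-sample weights w_i = T_{i-1} - T_i, with T_0 = 1
                    w = -torch.diff(T, prepend=torch.ones_like(T[..., :1]))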
num_samples_inf: The number of background samples along the ray.
The first background sample is placed at `rays.far`, and the samples
are spaced in the disparity space until reaching the disparity of
`disparity_at_inf`.
            More specifically, the `j`-th background 3D point `b_ij` along the
            `i`-th ray is defined as follows::
b_ij = rays.origins[i] + (rays.far[i] + j * bg_delta_ij) * rays.direction[i],
where:
bg_delta_ij = 1 / disparity_ij
disparity_ij = linspace(1, disparity_at_inf, num_samples_inf)[j]
These samples are additional to `num_samples`, i.e. the total number
of samples along a ray is `num_samples + num_samples_inf`.
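            Note:
                A torch sketch mirroring the background formula above
                (illustrative only)::

                    disparity = torch.linspace(1, disparity_at_inf, num_samples_inf)
                    j = torch.arange(num_samples_inf)
                    t_bg = rays.far[:, None] + j / disparity  # [n_rays, num_samples_inf]
                    b = rays.origins[:, None] + t_bg[..., None] * rays.directions[:, None]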
mask_out_of_bounds_samples: Whether to mask samples that
fall outside the [-1, 1] cube (does not apply when contraction with
`contract_coords` is enabled).
contract_coords: Whether to map the coordinates of the rendered
points to always fall into the [-1, 1] cube. The contraction is implemented
as in MeRF [1]::
                                  x[k]                          if |x|_inf <= 1
                contract(x)[k] =  x[k] / |x|_inf                if x[k] != |x|_inf > 1
                                  (2 - 1/|x[k]|) x[k] / |x[k]|  if x[k] = |x|_inf > 1
Note:
The contraction is useful for representing unbounded scenes.
E.g. outdoor captures where the scene extends to infinity.
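            Note:
                A torch sketch of this contraction (illustrative; the kernel
                implements it inline)::

                    def contract(x):  # x: [..., 3]
                        x_inf = x.abs().amax(dim=-1, keepdim=True)
                        x_c = torch.where(
                            x.abs() == x_inf,
                            (2 - 1 / x.abs().clamp(min=1e-12)) * x.sign(),
                            x / x_inf.clamp(min=1e-12),
                        )
                        return torch.where(x_inf <= 1, x, x_c)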
disparity_at_inf: The disparity value at infinity.
        inject_noise_sigma: The standard deviation of the opacity noise to inject.
inject_noise_seed: The seed of the random noise to inject.
scaffold: A voxel grid with shape `[B, D, H, W]`, indicating the occupancy
of the 3D space. If provided, the renderer will only render the points
that are not empty in the scaffold.
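            Example:
                A sketch of deriving a scaffold by thresholding a feature
                grid's norm (`tau` is a hypothetical threshold, not part of
                the API)::

                    scaffold = (voxel_grid.norm(dim=-1) > tau).float()  # [B, D, H, W]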
color_grid: Another grid-list (a list of 3D grids) storing color features.
If provided, the renderer will regress the color from features
sampled from `color_grid`, using `color_mlp`.
            Similar to `grid`, `color_grid` can also be a 2D tensor, in which
            case `color_grid_sizes` must be provided.
            `color_grid` must be of the same type as `grid`.
        grid_sizes: Specifies the sizes of the grids in `grid`.
It is optional when `grid` is a grid-list, but required when `grid`
is a 2D tensor. Example::
grid_sizes = [[B, D_1, H_1, W_1, C], ... , [B, D_N, H_N, W_N, C]].
        color_grid_sizes: Specifies the sizes of the grids in `color_grid`.
It is optional when `color_grid` is a grid-list, but required when
`color_grid` is a 2D tensor. Example::
color_grid_sizes = [[B, D_1, H_1, W_1, C], ... , [B, D_N, H_N, W_N, C]]
regenerate_code: If `True`, forces the regeneration of the triton code.
        triton_block_size: The block size for Triton. Must be at least 16.
triton_num_warps: The number of warps for Triton.
Returns:
ray_length_render: The rendered ray-termination length `r` (i.e. distance along the ray).
negative_log_transmittances: The negative log transmittances of the ray.
feature_render: The rendered features of the ray.
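    Example:
        A minimal, illustrative invocation (a sketch: shapes are arbitrary,
        `decoder_params` is assumed to be a pre-built `DecoderParams`, and
        `Rays` is assumed to be constructed directly from its fields)::

            n, B, D, H, W, C = 128, 2, 32, 32, 32, 32
            device = "cuda"
            grid = [torch.randn(B, D, H, W, C, device=device)]
            rays = Rays(
                directions=torch.randn(n, 3, device=device),
                origins=torch.zeros(n, 3, device=device),
                grid_idx=torch.zeros(n, dtype=torch.int32, device=device),
                near=torch.full((n,), 0.1, device=device),
                far=torch.full((n,), 4.0, device=device),
                # the encoding width must match the color MLP input width
                encoding=torch.randn(n, C, device=device),
            )
            length, neg_log_T, features = lightplane_renderer(
                rays, grid, decoder_params, num_samples=128, gain=1.0
            )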
References:
[1] MERF: Memory-Efficient Radiance Fields for Real-time View Synthesis in
Unbounded Scenes, https://arxiv.org/abs/2302.12249
"""
grid, color_grid, grid_sizes, color_grid_sizes = check_grid_and_color_grid(
grid, color_grid, grid_sizes, color_grid_sizes
)
grid, color_grid, grid_sizes, color_grid_sizes = process_and_flatten_grid(
grid, color_grid, grid_sizes, color_grid_sizes
)
rays, n_rays_padded = rays.pad_to_block_size(triton_block_size)
(
mlp_dim_hidden_trunk,
mlp_dim_hidden_opacity,
mlp_dim_hidden_color,
mlp_n_layers_trunk,
mlp_n_layers_opacity,
mlp_n_layers_color,
color_chn_triton,
) = get_triton_function_input_dims(
decoder_params.n_hidden_trunk,
decoder_params.n_hidden_opacity,
decoder_params.n_hidden_color,
)
color_chn = decoder_params.color_chn
if inject_noise_sigma > 0.0:
if inject_noise_seed is None:
            inject_noise_seed = random.randint(0, 1000000)
else:
inject_noise_seed = 0
(
ray_length_render,
negative_log_transmittances,
feature_render,
) = LightplaneFunction.apply(
grid,
grid_sizes,
decoder_params.mlp_params,
rays.directions,
rays.origins,
rays.grid_idx.to(torch.int32),
rays.near,
rays.far,
rays.encoding,
scaffold,
color_grid,
color_grid_sizes,
# mlp sizes
mlp_dim_hidden_trunk,
mlp_dim_hidden_opacity,
mlp_dim_hidden_color,
mlp_n_layers_trunk,
mlp_n_layers_opacity,
mlp_n_layers_color,
# other settings
num_samples,
num_samples_inf,
gain,
color_chn_triton,
mask_out_of_bounds_samples,
contract_coords,
disparity_at_inf,
inject_noise_sigma,
inject_noise_seed,
# BS
triton_block_size,
triton_num_warps,
regenerate_code,
)
# crop the features to the requested number of channels
if color_chn_triton > color_chn:
feature_render = feature_render[:, :color_chn]
if n_rays_padded > 0:
ray_length_render, negative_log_transmittances, feature_render = (
t[:-n_rays_padded]
for t in [ray_length_render, negative_log_transmittances, feature_render]
)
return ray_length_render, negative_log_transmittances, feature_render
class LightplaneFunction(torch.autograd.Function):
@staticmethod
def forward(
ctx,
feature_grid: torch.Tensor, # NUM_GRIDS * B * D * H * W * C
feature_grid_sizes: torch.Tensor, # NUM_GRIDS x 5: [[B_1, D_1, H_1, W_1, C_1], ...]
mlp_params: torch.Tensor, # flattened biases and weights of all mlps
directions: torch.Tensor, # N x 3
origins: torch.Tensor, # N x 3
grid_idx: torch.Tensor, # N
near: torch.Tensor, # N
far: torch.Tensor, # N
ray_encoding: torch.Tensor, # N x C
scaffold: torch.Tensor | None, # B x D x H x W x 1
color_feature_grid: torch.Tensor | None, # NUM_GRIDS * B * D * H * W * C
        color_feature_grid_sizes: torch.Tensor | None,  # NUM_GRIDS x 5: [[B_1, D_1, H_1, W_1, C_1], ...]
# mlp sizes
mlp_dim_hidden_trunk: int,
mlp_dim_hidden_opacity: int,
mlp_dim_hidden_color: int,
mlp_n_layers_trunk: int,
mlp_n_layers_opacity: int,
mlp_n_layers_color: int,
# other settings
num_samples: int,
num_samples_inf: int,
gain: float,
num_render_channels: int,
mask_out_of_bounds_samples: bool = False,
contract_coords: bool = False,
disparity_at_inf: float = 1e-5,
inject_noise_sigma: float = 0.0,
inject_noise_seed: int = 0,
# triton block size
BLOCK_SIZE: int = 16,
NUM_WARPS: int = 4,
# force code regeneration
regenerate_code: bool = False,
):
fw_kernel, bw_kernel = get_lightplane_kernels(
"renderer",
mlp_n_layers_trunk,
mlp_n_layers_opacity,
mlp_n_layers_color,
regenerate_code=regenerate_code,
)
ctx.fw_kernel = fw_kernel
ctx.bw_kernel = bw_kernel
if PROFILE:
torch.cuda.synchronize()
time_start = time.time()
use_separate_color_grid = color_feature_grid is not None
        assert (
            BLOCK_SIZE >= MIN_BLOCK_SIZE
        ), f"BLOCK_SIZE has to be at least {MIN_BLOCK_SIZE}"
        assert (
            num_render_channels >= MIN_BLOCK_SIZE
        ), f"num_render_channels has to be at least {MIN_BLOCK_SIZE}"
if mask_out_of_bounds_samples and contract_coords:
warnings.warn(
"The renderer has been configured to contract the coordinates"
" lying outside the [-1,1] cube (contract_coords=True)"
" and to also mask out all such points"
" (mask_out_of_bounds_samples=True)."
)
# important sizes
device = feature_grid.device
num_grid_channels = feature_grid.shape[-1]
num_rays = directions.shape[0]
num_grids = feature_grid_sizes.shape[0]
grid_batch_size = feature_grid_sizes[0, 0].item()
assert (feature_grid_sizes[:, 0] == grid_batch_size).all()
assert all(gs[-1] == num_grid_channels for gs in feature_grid_sizes)
grid_spatial_numel = feature_grid_sizes[:, :-1].prod(dim=1).sum()
assert grid_spatial_numel == feature_grid.numel() // num_grid_channels
        # NOTE: the following expression has no effect; see the linked Triton
        # issue: https://github.com/openai/triton/issues/2688#issue-2003537756
        feature_grid_sizes.shape[:-1].numel()
# mlp size params for kernel:
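        # When a separate color grid is provided, the trunk MLP is bypassed and
        # both the opacity and color heads consume raw grid features directly.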
        if use_separate_color_grid:
            assert (
                mlp_n_layers_trunk == 0
            ), "mlp_n_layers_trunk has to be 0 when a separate color grid is used"
            assert (
                mlp_dim_hidden_trunk == 0
            ), "mlp_dim_hidden_trunk has to be 0 when a separate color grid is used"
dim_out_trunk = 0
dim_in_trunk = 0
dim_in_color = num_grid_channels
dim_in_opacity = num_grid_channels
dim_out_color = num_render_channels
else:
dim_out_trunk = mlp_dim_hidden_trunk
dim_in_trunk = num_grid_channels
dim_in_color = dim_out_trunk
dim_in_opacity = dim_out_trunk
dim_out_color = num_render_channels
# asserts
assert_shape(feature_grid_sizes, (num_grids, 5))
assert_shape(directions, (num_rays, 3))
assert_shape(origins, (num_rays, 3))
assert_shape(grid_idx, (num_rays,))
assert_shape(near, (num_rays,))
assert_shape(far, (num_rays,))
assert_shape(ray_encoding, (num_rays, dim_in_color))
        assert (
            math.log2(num_grid_channels) % 1 == 0
        ), "num_grid_channels has to be a power of 2"
        assert (
            num_grid_channels >= MIN_BLOCK_SIZE
        ), f"num_grid_channels has to be at least {MIN_BLOCK_SIZE}"
        assert (
            math.log2(num_render_channels) % 1 == 0
        ), "num_render_channels has to be a power of 2"
        assert (
            num_render_channels >= MIN_BLOCK_SIZE
        ), f"num_render_channels has to be at least {MIN_BLOCK_SIZE}"
        assert mlp_params.ndim == 1
        assert (
            ray_encoding.shape[1] == dim_in_color
        ), "ray_encoding has to have the same dimension as dim_in_color"
        assert (
            num_rays % BLOCK_SIZE == 0
        ), "num_rays has to be a multiple of BLOCK_SIZE"
if use_separate_color_grid:
num_color_grids = color_feature_grid_sizes.shape[0]
assert_shape(color_feature_grid_sizes, (num_color_grids, 5))
assert color_feature_grid.shape[-1] == num_grid_channels
else:
assert (
color_feature_grid_sizes is None
), "color_feature_grid_sizes has to be None when use_separate_color_grid is False"
color_feature_grid_sizes = torch.empty(
(1,), dtype=torch.int32, device=device
)
color_feature_grid = torch.empty((1,), dtype=torch.float32, device=device)
num_color_grids = 0
# check the number of mlp param elems is correct
numel_params_trunk = _get_mlp_n_params(
dim_in_trunk, mlp_dim_hidden_trunk, dim_out_trunk, mlp_n_layers_trunk
)
numel_params_opacity = _get_mlp_n_params(
dim_in_opacity, mlp_dim_hidden_opacity, 1, mlp_n_layers_opacity
)
numel_params_color = _get_mlp_n_params(
dim_in_color, mlp_dim_hidden_color, dim_out_color, mlp_n_layers_color
)
expected_mlp_params_numel = (
numel_params_trunk + numel_params_opacity + numel_params_color
)
assert expected_mlp_params_numel == mlp_params.numel(), (
f"The number of elements in mlp param should be {expected_mlp_params_numel}."
f" Got {mlp_params.numel()} instead."
)
# make sure grid_idx is in the correct range
assert grid_idx.min() >= 0, f"Negative grid index: {grid_idx.min()}"
assert grid_idx.max() <= (
grid_batch_size - 1
), f"A grid index is out of bounds ({grid_idx.max()} >= {grid_batch_size})"
# init output tensors
negative_log_transmittance = torch.zeros(
num_rays, device=device, dtype=torch.float32
)
ray_length_render = torch.zeros(num_rays, device=device, dtype=torch.float32)
feature_render = torch.zeros(
num_rays, num_render_channels, device=device, dtype=torch.float32
)
        # use voxel grid scaffold: the scaffold is appended to the grid-size
        # table as an extra single-channel grid so the kernel can address it
        use_scaffold = scaffold is not None
        if use_scaffold:
scaffold_t = scaffold.reshape(grid_batch_size, -1, 1).float()
feature_grid_sizes = torch.cat(
[
feature_grid_sizes,
torch.tensor(
[[*scaffold.shape, 1]], dtype=torch.int32, device=device
),
],
dim=0,
).int()
else:
scaffold_t = feature_grid.new_empty(1)
# Random noise seed for each ray, we have to pass this in as a
# tensor otherwise triton would jit-recompile every kernel run
# with a different seed.
inject_noise = inject_noise_sigma > 0.0
inject_noise_seed_t = torch.full(
(num_rays,),
inject_noise_seed,
device=device,
dtype=torch.long,
)
n_blocks = int(math.ceil(num_rays / BLOCK_SIZE))
grid = (n_blocks,)
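        # launch one Triton program per BLOCK_SIZE-sized block of rays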
fw_kernel[grid](
# ---- output -----
_contiguous(negative_log_transmittance),
_contiguous(ray_length_render),
_contiguous(feature_render),
# ---- grid ----
_contiguous(feature_grid),
_contiguous(feature_grid_sizes),
_contiguous(color_feature_grid),
_contiguous(color_feature_grid_sizes),
# ----- non-differentiable tensors
_contiguous(directions),
_contiguous(origins),
_contiguous(grid_idx),
_contiguous(near),
_contiguous(far),
_contiguous(ray_encoding),
_contiguous(inject_noise_seed_t),
_contiguous(scaffold_t),
# ---- mlp params ----
_contiguous(mlp_params), # master ptr for the mlp params
mlp_dim_hidden_trunk,
mlp_dim_hidden_opacity,
mlp_dim_hidden_color,
dim_in_trunk,
dim_in_opacity,
dim_in_color,
dim_out_trunk,
dim_out_color,
# ----- config keys ----
num_samples,
num_samples_inf,
gain,
# ----- sizes ----
num_rays,
num_grid_channels,
num_grids,
num_color_grids,
BLOCK_SIZE,
# ---- switches ----
int(mask_out_of_bounds_samples),
int(inject_noise),
float(inject_noise_sigma),
int(contract_coords),
float(disparity_at_inf),
int(use_scaffold),
int(use_separate_color_grid),
num_warps=1 if DEBUG else NUM_WARPS,
)
# save tensors for bw
ctx.save_for_backward(
negative_log_transmittance,
feature_grid,
feature_grid_sizes,
color_feature_grid,
color_feature_grid_sizes,
mlp_params,
directions,
origins,
grid_idx,
near,
far,
ray_encoding,
inject_noise_seed_t,
scaffold_t,
)
# save config keys
ctx.mlp_dim_hidden_trunk = mlp_dim_hidden_trunk
ctx.mlp_dim_hidden_opacity = mlp_dim_hidden_opacity
ctx.mlp_dim_hidden_color = mlp_dim_hidden_color
ctx.mlp_n_layers_trunk = mlp_n_layers_trunk
ctx.mlp_n_layers_opacity = mlp_n_layers_opacity
ctx.mlp_n_layers_color = mlp_n_layers_color
ctx.dim_in_opacity = dim_in_opacity
ctx.dim_in_color = dim_in_color
ctx.dim_out_trunk = dim_out_trunk
ctx.dim_out_color = dim_out_color
ctx.num_samples = num_samples
ctx.num_samples_inf = num_samples_inf
ctx.gain = gain
ctx.num_render_channels = num_render_channels
ctx.mask_out_of_bounds_samples = mask_out_of_bounds_samples
ctx.contract_coords = contract_coords
ctx.disparity_at_inf = disparity_at_inf
ctx.BLOCK_SIZE = BLOCK_SIZE
ctx.NUM_WARPS = NUM_WARPS
ctx.num_grids = num_grids
ctx.num_color_grids = num_color_grids
ctx.num_rays = num_rays
ctx.num_grid_channels = num_grid_channels
ctx.scaffold_t = scaffold_t
ctx.inject_noise_seed_t = inject_noise_seed_t
ctx.use_scaffold = use_scaffold
ctx.use_separate_color_grid = use_separate_color_grid
ctx.inject_noise = inject_noise
ctx.inject_noise_sigma = inject_noise_sigma
if PROFILE:
torch.cuda.synchronize()
elapsed = time.time() - time_start
print(f"fw time = {elapsed:1.5f}")
return ray_length_render, negative_log_transmittance, feature_render
@staticmethod
def backward(
ctx,
grad_ray_length_render,
grad_negative_log_transmittances,
grad_feature_render,
):
if PROFILE:
torch.cuda.synchronize()
time_start = time.time()
(
negative_log_transmittances,
feature_grid,
feature_grid_sizes,
color_feature_grid,
color_feature_grid_sizes,
mlp_params,
directions,
origins,
grid_idx,
near,
far,
ray_encoding,
inject_noise_seed_t,
scaffold_t,
) = ctx.saved_tensors
device = feature_grid.device
grad_feature_grid = torch.zeros_like(feature_grid)
grad_mlp_params = torch.zeros_like(mlp_params)
grad_rays_enc = torch.zeros_like(ray_encoding)
if ctx.use_separate_color_grid:
grad_color_feature_grid = torch.zeros_like(color_feature_grid)
else:
grad_color_feature_grid = torch.empty(
(1,), dtype=torch.float32, device=device
)
n_blocks = int(math.ceil(ctx.num_rays / ctx.BLOCK_SIZE))
grid = (n_blocks,)
        debug_tensor = torch.zeros((32, 32), device=device)
ctx.bw_kernel[grid](
negative_log_transmittances,
# ----- differentiable tensors -----
_contiguous(feature_grid),
_contiguous(feature_grid_sizes),
_contiguous(color_feature_grid),
_contiguous(color_feature_grid_sizes),
# ----- non-differentiable tensors -----
_contiguous(directions),
_contiguous(origins),
_contiguous(grid_idx.to(torch.int32)),
_contiguous(near),
_contiguous(far),
_contiguous(ray_encoding),
_contiguous(inject_noise_seed_t),
_contiguous(scaffold_t),
# ----- mlp params -----
_contiguous(mlp_params),
ctx.mlp_dim_hidden_trunk,
ctx.mlp_dim_hidden_opacity,
ctx.mlp_dim_hidden_color,
ctx.dim_in_opacity,
ctx.dim_in_color,
ctx.dim_out_trunk,
ctx.dim_out_color,
# ----- config keys -----
ctx.num_samples,
ctx.num_samples_inf,
ctx.gain,
# ----- sizes -----
ctx.num_rays,
ctx.num_grid_channels,
ctx.num_grids,
ctx.num_color_grids,
ctx.BLOCK_SIZE,
# ----- switches -----
int(ctx.mask_out_of_bounds_samples),
int(ctx.inject_noise),
ctx.inject_noise_sigma,
int(ctx.contract_coords),
ctx.disparity_at_inf,
int(ctx.use_scaffold),
int(ctx.use_separate_color_grid),
# ----- gradients -----
_contiguous(grad_ray_length_render),
_contiguous(grad_negative_log_transmittances),
_contiguous(grad_feature_render),
# ----- gradients output -----
_contiguous(grad_feature_grid),
_contiguous(grad_color_feature_grid),
_contiguous(grad_mlp_params),
_contiguous(grad_rays_enc),
debug_tensor,
# num_warps=1 if DEBUG else None,
)
if PROFILE:
torch.cuda.synchronize()
elapsed = time.time() - time_start
print(f"bw time = {elapsed:1.5f}")
# TODO: remove for speed
assert torch.isfinite(grad_feature_grid).all()
assert torch.isfinite(grad_color_feature_grid).all()
assert torch.isfinite(grad_mlp_params).all()
assert torch.isfinite(grad_rays_enc).all()
        # one gradient per forward() input (30 in total); the size/config
        # arguments are non-differentiable and receive None
        return (
            grad_feature_grid,
            None,  # feature_grid_sizes
            grad_mlp_params,
            None,  # directions
            None,  # origins
            None,  # grid_idx
            None,  # near
            None,  # far
            grad_rays_enc,
            None,  # scaffold
            grad_color_feature_grid if ctx.use_separate_color_grid else None,
            None,  # color_feature_grid_sizes
            None,  # mlp_dim_hidden_trunk
            None,  # mlp_dim_hidden_opacity
            None,  # mlp_dim_hidden_color
            None,  # mlp_n_layers_trunk
            None,  # mlp_n_layers_opacity
            None,  # mlp_n_layers_color
            None,  # num_samples
            None,  # num_samples_inf
            None,  # gain
            None,  # num_render_channels
            None,  # mask_out_of_bounds_samples
            None,  # contract_coords
            None,  # disparity_at_inf
            None,  # inject_noise_sigma
            None,  # inject_noise_seed
            None,  # BLOCK_SIZE
            None,  # NUM_WARPS
            None,  # regenerate_code
        )
def _contiguous(t: torch.Tensor | None) -> torch.Tensor | None:
    return t.contiguous() if t is not None else None
def _get_mlp_n_params_weight(dim_in: int, dim_hidden: int, dim_out: int, n_layers: int):
if n_layers == 0:
return 0
if n_layers == 1:
return dim_in * dim_out
if n_layers == 2:
return dim_in * dim_hidden + dim_hidden * dim_out
# n_layers > 2
return dim_hidden * dim_hidden * (n_layers - 2) + _get_mlp_n_params_weight(
dim_in, dim_hidden, dim_out, 2
)
def _get_mlp_n_params_bias(dim_out: int, dim_hidden: int, n_layers: int):
return dim_hidden * max(n_layers - 1, 0) + dim_out
def _get_mlp_n_params(dim_in: int, dim_hidden: int, dim_out: int, n_layers: int):
return _get_mlp_n_params_weight(
dim_in, dim_hidden, dim_out, n_layers
) + _get_mlp_n_params_bias(dim_out, dim_hidden, n_layers)
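# Worked example of the parameter counting above: an MLP with dim_in=32,
# dim_hidden=64, dim_out=16 and n_layers=3 has
#   weights: 32*64 + 64*64 + 64*16 = 7168
#   biases:  64*(3-1) + 16         = 144
# so _get_mlp_n_params(32, 64, 16, 3) == 7312.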