# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import asdict, dataclass, fields
from enum import Enum
from logging import getLogger
from typing import Optional, Tuple
import torch
from .triton_src.shared.const import MIN_BLOCK_SIZE
logger = getLogger(__name__)
@dataclass
class DecoderParams:
    r"""
    Class configuring the learnable parameters of the decoder of the Lightplane Renderer.
    The decoder comprises a learnable function that predicts a color and an opacity value
    given a grid feature sampled at every point along a rendering ray.
    Specifically, the decoder function consists of three MLPs: `trunk_mlp`, `opacity_mlp`,
    and `color_mlp`.
    The three MLPs predict opacity and color as follows:
        1) `use_separate_color_grid==False`::
            grid -> f_i -> trunk_mlp -> e_i -> e_i + ray_encoding -> color_mlp -> c_i
                                            -> opacity_mlp -> o_i
        If the renderer uses a single grid for both opacity and color, an MLP
        `trunk_mlp` maps the grid-sampled feature `f_i` to a trunk feature `e_i`,
        which is later converted to opacity and color with a pair of additional
        color and opacity MLP heads `color_mlp` and `opacity_mlp`.
        The trunk feature `e_i` is summed with `ray_encoding` before `color_mlp`
        to make the predicted color viewpoint dependent.
        2) `use_separate_color_grid==True`::
            grid       -> f_i  -> opacity_mlp -> o_i
            color_grid -> cf_i -> cf_i + ray_encoding -> color_mlp -> c_i
        If the renderer uses a separate color grid (`use_separate_color_grid==True`),
        the trunk MLP is omitted.
        The `opacity_mlp` and `color_mlp` predict the opacity `o_i` and color
        values `c_i`, respectively,
        given the opacity/color features (`f_i` and `cf_i`) sampled from the
        corresponding grids `grid` and `color_grid`.
    The parameters of the three MLPs are stored in the `mlp_params` attribute.
    Here, `mlp_params` is a 1D tensor which concatenates the flattened weight matrices
    and bias vectors of the three MLPs in the following order::
        mlp_params = torch.cat(
            [
                weights_trunk[0].flatten(),
                ...
                weights_trunk[-1].flatten(),
                biases_trunk[0],
                ...
                biases_trunk[-1],
                weights_opacity[0].flatten(),
                ...
                weights_opacity[-1].flatten(),
                biases_opacity[0],
                ...
                biases_opacity[-1],
                weights_color[0].flatten(),
                ...
                weights_color[-1].flatten(),
                biases_color[0],
                ...
                biases_color[-1],
            ]
        )
    Here, `weights_XXX[i]` corresponds to an `(M, N)` tensor storing the weight matrix
    of the i-th MLP layer. Similarly, `biases_XXX[i]` is a `(N,)` tensor storing
    the bias vector.
    The MLP multiplies the input features from the right, i.e.::
        output[i+1] = input[i] @ weights_XXX[i] + biases_XXX[i]
    Hence, `M` / `N` is the input / output channel dimension.
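    For reference, this is the transpose of the `torch.nn.Linear` weight layout;
    a sketch of the equivalence::
        y = x @ weights_XXX[i] + biases_XXX[i]
        # same as: torch.nn.functional.linear(x, weights_XXX[i].t(), biases_XXX[i])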
    In addition to the `mlp_params`, the `DecoderParams` class stores the number
    of hidden units in each MLP. Specifically, `n_hidden_trunk`, `n_hidden_opacity`, and
    `n_hidden_color` are tensors of shape `(n_layers+1,)` that store the number of
    input channels followed by the output channel number of each layer in
    the trunk, opacity, and color MLPs, respectively.
    Note:
        One can convert the 1D `mlp_params` tensor to the more-interpretable
        list of weight matrices and bias tensors using the `flattened_decoder_params_to_list`
        function.
    Note:
        Since the Triton language of Lightplane's GPU kernels constrains the number
        of rendering channels to at least 16, the `color_chn` attribute is used to store
        the effective number of rendered output channels. If the effective number of
        rendered channels is less than 16, the MLP parameters are padded with zeros
        to match the minimum size.
    Attributes:
        mlp_params: The parameters for the Lightplane Rendering decoder.
        n_hidden_trunk: `(n_layers+1,)` Long tensor storing the number of
            input channels followed by the number of hidden units in each layer of the
            `trunk_mlp`. Note that this tensor can be empty if the trunk MLP is not used.
        n_hidden_opacity: `(n_layers+1,)` Long tensor storing the number of
            input channels followed by the number of hidden units in each layer of the
            `opacity_mlp`.
        n_hidden_color: `(n_layers+1,)` Long tensor storing the number of
            input channels followed by the number of hidden units in each layer of the
            `color_mlp`.
        color_chn: The number of rendered channels.
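    Note:
        A minimal construction/round-trip sketch (the layer counts below are
        illustrative)::
            params = init_decoder_params(
                device=torch.device("cpu"),
                n_layers_opacity=2,
                n_layers_trunk=2,
                n_layers_color=2,
            )
            # recover the per-layer weight matrices and bias vectors
            (
                weights_trunk, biases_trunk,
                weights_opacity, biases_opacity,
                weights_color, biases_color,
            ) = flattened_decoder_params_to_list(
                params.mlp_params,
                params.n_hidden_trunk,
                params.n_hidden_opacity,
                params.n_hidden_color,
            )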
    """
    mlp_params: torch.Tensor
    n_hidden_trunk: torch.Tensor
    n_hidden_opacity: torch.Tensor
    n_hidden_color: torch.Tensor
    color_chn: int 
@dataclass
class SplatterParams:
    """
    Class representing learnable parameters of the MLP from Lightplane Splatter.
    The splatter comprises a learnable function that predicts a vector splatted
    to the output 3D feature grid. Specifically, the function is defined as follows::
        MLP(feature_grid[x] + splatting_feature[u]) -> splat_vector[x]
    where `x` corresponds to the 3D point along the ray of pixel `u`,
    `feature_grid[x]` is the input shape grid sampled at point `x`, and
    `splatting_feature[u]` is the splatted feature at pixel `u`.
    The splatting MLP outputs `splat_vector[x]` which is pushed back into the
    output grid.
    The parameters of the MLP are stored in the `mlp_params` attribute.
    Here, `mlp_params` is a 1D tensor which concatenates the flattened weight matrices
    and bias vectors of the MLP in the following order::
        mlp_params = torch.cat(
            [
                weights[0].flatten(),
                ...
                weights[-1].flatten(),
                biases[0],
                ...
                biases[-1],
            ]
        )
    Here, `weights[i]` corresponds to an `(M, N)` tensor storing the weight matrix
    of the i-th MLP layer. Similarly, `biases[i]` is a `(N,)` tensor storing
    the bias vector.
    The MLP multiplies the input features from the right, i.e.::
        output[i+1] = input[i] @ weights[i] + biases[i]
    Hence, `M` / `N` is the input / output channel dimension.
    In addition to the `mlp_params`, the `SplatterParams` class stores the number
    of hidden units in the MLP. Specifically, the `n_hidden` field is a tensor of shape
    `(n_layers+1,)` that stores the number of input channels followed by
    the output channel number of each layer in the MLP.
    Attributes:
        mlp_params: The parameters of the Lightplane Splatter MLP.
        n_hidden: `(n_layers+1,)` Long tensor storing the number of
            input channels followed by the number of hidden units in each layer of the
            splatting MLP.
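    Note:
        A minimal sketch of the packed layout (the shapes are illustrative)::
            weights = [torch.randn(32, 32), torch.randn(32, 64)]
            biases = [torch.zeros(32), torch.zeros(64)]
            mlp_params, n_hidden = flatten_splatter_params(weights, biases)
            sp = SplatterParams(mlp_params, n_hidden)
            # mlp_params.numel() == 32*32 + 32*64 + 32 + 64
            # n_hidden holds [32, 32, 64]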
    """
    mlp_params: torch.Tensor
    n_hidden: torch.Tensor 
def init_decoder_params(
    device: torch.device,
    n_layers_opacity: int,
    n_layers_trunk: int,
    n_layers_color: int,
    input_chn: int = 32,
    hidden_chn: int = 32,
    color_chn: int = 3,
    opacity_init_bias: float = 0.0,
    pad_color_channels_to_min_block_size: bool = True,
    use_separate_color_grid: bool = False,
) -> DecoderParams:
    """
    The function initializes the learnable parameters of the Lightplane Renderer
    decoder given the MLP configurations.
    Weights and biases of the three MLPs inside the decoder (`trunk_mlp`, `opacity_mlp`,
    and `color_mlp`) are initialized with Xavier initialization by `_xavier_init_mlp_params`
    and flattened into a single tensor `mlp_params` by `flatten_decoder_params`.
    Since the Triton language of Lightplane's GPU kernels constrains the number
    of rendering channels to at least 16, the `color_chn` attribute stores
    the effective number of rendered output channels. If the effective number of
    rendered channels is less than 16, the MLP parameters are padded with zeros
    to match the minimum size.
    Args:
        device: The device to store the parameters.
        n_layers_opacity: The number of layers in `opacity_mlp`.
        n_layers_trunk: The number of layers in `trunk_mlp`. Set to 0 to omit
            the trunk MLP (required when `use_separate_color_grid==True`).
        n_layers_color: The number of layers in `color_mlp`.
        input_chn: The number of input channels, i.e. the number of channels of
            the feature grid.
        hidden_chn: The number of hidden units in the MLP layers.
        color_chn: The number of rendered channels.
        opacity_init_bias: The initial bias value for the opacity MLP.
        pad_color_channels_to_min_block_size: If True, the MLP parameters are padded
            with zeros to match the Triton minimum block size (`MIN_BLOCK_SIZE`).
        use_separate_color_grid: If True, the renderer uses a separate color grid.
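    Note:
        A usage sketch illustrating the zero-padding of the color head (the
        layer counts are illustrative)::
            params = init_decoder_params(
                device=torch.device("cpu"),
                n_layers_opacity=2,
                n_layers_trunk=2,
                n_layers_color=2,
                color_chn=3,
            )
            # the color head output is zero-padded up to MIN_BLOCK_SIZE, while
            # color_chn keeps the effective (unpadded) channel count
            assert int(params.n_hidden_color[-1]) == max(3, MIN_BLOCK_SIZE)
            assert params.color_chn == 3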
    """
    if n_layers_trunk > 0:
        assert not use_separate_color_grid, (
            "Cannot use trunk MLP with a separate color grid."
            " Please set n_layers_trunk==0."
        )
        (weights_trunk, biases_trunk,) = _xavier_init_mlp_params(
            n_layers_trunk,
            input_chn,
            hidden_chn,
            hidden_chn,
            device,
        )
    else:
        weights_trunk = []
        biases_trunk = []
    (weights_opacity, biases_opacity,) = _xavier_init_mlp_params(
        n_layers_opacity,
        input_chn if use_separate_color_grid else hidden_chn,
        hidden_chn,
        1,
        device,
        last_bias=opacity_init_bias,
    )
    (weights_color, biases_color,) = _xavier_init_mlp_params(
        n_layers_color,
        input_chn if use_separate_color_grid else hidden_chn,
        hidden_chn,
        color_chn,
        device,
    )
    # set the mlp params
    (
        mlp_params,
        n_hidden_trunk,
        n_hidden_opacity,
        n_hidden_color,
    ) = flatten_decoder_params(
        weights_trunk,
        biases_trunk,
        weights_opacity,
        biases_opacity,
        weights_color,
        biases_color,
        pad_color_channels_to_min_block_size,
    )
    return DecoderParams(
        mlp_params,
        n_hidden_trunk,
        n_hidden_opacity,
        n_hidden_color,
        color_chn,
    ) 
def init_splatter_params(
    device: torch.device,
    n_layers: int,
    input_chn: int = 32,
    hidden_chn: int = 32,
    out_chn: int = 16,
) -> SplatterParams:
    """
    The function initializes the learnable parameters of the Lightplane Splatter
    given the MLP configuration.
    Weights and biases of the MLP inside the Lightplane Splatter are initialized
    with Xavier initialization by `_xavier_init_mlp_params`
    and flattened into a single tensor `mlp_params` by `flatten_splatter_params`.
    Since the output of the MLP is a vector splatted to the output 3D feature grid,
    its number of channels matches that of `output_grid`, which is typically at
    least 16, so the MLP parameters do not need to be padded to the Triton
    minimum block size.
    Args:
        device: The device to store the parameters.
        n_layers: The number of layers in the splatting MLP.
        input_chn: The number of input channels.
        hidden_chn: The number of hidden units in the MLP layers.
        out_chn: The number of output channels.
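    Note:
        A minimal usage sketch (the layer and channel counts are illustrative)::
            sp = init_splatter_params(
                torch.device("cpu"),
                n_layers=3,
                input_chn=32,
                hidden_chn=32,
                out_chn=64,
            )
            # n_hidden stores the input width followed by each layer's output width
            assert sp.n_hidden.tolist() == [32, 32, 32, 64]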
    """
    (weights, biases) = _xavier_init_mlp_params(
        n_layers, input_chn, hidden_chn, out_chn, device
    )
    mlp_params, n_hidden = flatten_splatter_params(
        weights,
        biases,
    )
    return SplatterParams(
        mlp_params,
        n_hidden,
    ) 
# ------------------------
# --- Helper functions ---
# ------------------------
def flatten_decoder_params(
    weights_trunk: Tuple[torch.Tensor, ...],
    biases_trunk: Tuple[torch.Tensor, ...],
    weights_opacity: Tuple[torch.Tensor, ...],
    biases_opacity: Tuple[torch.Tensor, ...],
    weights_color: Tuple[torch.Tensor, ...],
    biases_color: Tuple[torch.Tensor, ...],
    pad_color_channels_to_min_block_size: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    A helper function that flattens the decoder parameters into a single tensor
    and computes the number of hidden units for each layer of each MLP (`n_hidden_XXX`).
    Args:
        weights_trunk: Tuple of weight matrices for `trunk_mlp`.
        biases_trunk: Tuple of bias vectors for `trunk_mlp`.
        weights_opacity: Tuple of weight matrices for `opacity_mlp`.
        biases_opacity: Tuple of bias vectors for `opacity_mlp`.
        weights_color: Tuple of weight matrices for `color_mlp`.
        biases_color: Tuple of bias vectors for `color_mlp`.
        pad_color_channels_to_min_block_size: If True, the MLP parameters are padded
            with zeros to match the Triton minimum block size (`MIN_BLOCK_SIZE`).
    """
    # TODO: return the flattened param vector from DecoderParams directly
    num_pad_channels_color = 0
    if pad_color_channels_to_min_block_size:
        color_chn = biases_color[-1].numel()
        num_pad_channels_color = max(MIN_BLOCK_SIZE - color_chn, 0)
    if num_pad_channels_color > 0:
        weights_color, biases_color = _pad_color_mlp_params(
            weights_color,
            biases_color,
            num_pad_channels_color,
        )
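    # flatten and concatenate in the documented order: trunk weights, trunk biases,
    # opacity weights, opacity biases, color weights, color biases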
    mlp_params = torch.cat(
        [
            t_elem.reshape(-1).contiguous()
            for t in [
                weights_trunk,
                biases_trunk,
                weights_opacity,
                biases_opacity,
                weights_color,
                biases_color,
            ]
            for t_elem in t
        ],
        dim=0,
    ).contiguous()
    # set the numbers of hidden units in each mlp
    n_hidden_trunk, n_hidden_opacity, n_hidden_color = (
        _get_n_hidden(w, device=mlp_params.device)
        for w in [weights_trunk, weights_opacity, weights_color]
    )
    _validate_flattened_mlp_params(
        mlp_params,
        n_hidden_trunk,
        n_hidden_opacity,
        n_hidden_color,
        pad_color_channels_to_min_block_size=pad_color_channels_to_min_block_size,
    )
    return mlp_params, n_hidden_trunk, n_hidden_opacity, n_hidden_color 
def flatten_splatter_params(
    weights: Tuple[torch.Tensor, ...],
    biases: Tuple[torch.Tensor, ...],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    A helper function that flattens the splatter parameters into a single tensor
    and computes the number of hidden units for each layer in the MLP (`n_hidden`).
    Args:
        weights: Tuple of weight matrices for the MLP.
        biases: Tuple of bias vectors for the MLP.
    """
    mlp_params = torch.cat(
        [
            t_elem.reshape(-1).contiguous()
            for t in [
                weights,
                biases,
            ]
            for t_elem in t
        ],
        dim=0,
    ).contiguous()
    # set the number of hidden units in the mlp
    n_hidden = _get_n_hidden(weights, device=mlp_params.device)
    return mlp_params, n_hidden 
def flattened_decoder_params_to_list(
    mlp_params: torch.Tensor,
    n_hidden_trunk: torch.Tensor,
    n_hidden_opacity: torch.Tensor,
    n_hidden_color: torch.Tensor,
    transpose: bool = False,
) -> Tuple[
    Tuple[torch.Tensor, ...],
    Tuple[torch.Tensor, ...],
    Tuple[torch.Tensor, ...],
    Tuple[torch.Tensor, ...],
    Tuple[torch.Tensor, ...],
    Tuple[torch.Tensor, ...],
]:
    """
    This function converts the flattened MLP parameters into a list of weight matrices,
    and bias vectors for each MLP.
    It is the inverse function of `flatten_decoder_params`.
    Args:
        mlp_params: The flattened MLP parameters, i.e. 1D tensor.
        n_hidden_trunk: `(n_layers+1,)` Long tensor storing the number of
            input channels followed by the number of hidden units in each layer of the
            `trunk_mlp`. Note that this tensor can be empty if the trunk MLP is not used.
        n_hidden_opacity: `(n_layers+1,)` Long tensor storing the number of
            input channels followed by the number of hidden units in each layer of the
            `opacity_mlp`.
        n_hidden_color: `(n_layers+1,)` Long tensor storing the number of
            input channels followed by the number of hidden units in each layer of the
            `color_mlp`.
        transpose: If True, the weight matrices are transposed.
        
    Returns:
        weights_trunk: Weight matrices of the trunk MLP.
        biases_trunk: Bias vectors of the trunk MLP.
        weights_opacity: Weight matrices of the opacity MLP.
        biases_opacity: Bias vectors of the opacity MLP.
        weights_color: Weight matrices of the color MLP.
        biases_color: Bias vectors of the color MLP.
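    Note:
        With `transpose=True`, each recovered weight matrix has the
        `torch.nn.Linear` layout `(out_features, in_features)`. A sketch of
        loading the first color layer into a linear module::
            *_, weights_color, biases_color = flattened_decoder_params_to_list(
                mlp_params, n_hidden_trunk, n_hidden_opacity, n_hidden_color,
                transpose=True,
            )
            out_chn, in_chn = weights_color[0].shape
            layer = torch.nn.Linear(in_chn, out_chn)
            layer.weight.data.copy_(weights_color[0])
            layer.bias.data.copy_(biases_color[0])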
    """
    numel_trunk, numel_opacity, numel_color = (
        (nh[:-1].to(torch.float) @ nh[1:].to(torch.float)).to(torch.int32)
        + nh[1:].sum()
        for nh in [n_hidden_trunk, n_hidden_opacity, n_hidden_color]
    )
    weights_trunk, biases_trunk = _flattened_one_mlp_params_to_list(
        mlp_params[:numel_trunk],
        n_hidden_trunk,
        transpose,
    )
    weights_opacity, biases_opacity = _flattened_one_mlp_params_to_list(
        mlp_params[numel_trunk : (numel_trunk + numel_opacity)],
        n_hidden_opacity,
        transpose,
    )
    weights_color, biases_color = _flattened_one_mlp_params_to_list(
        mlp_params[(numel_trunk + numel_opacity) :],
        n_hidden_color,
        transpose,
    )
    return (
        weights_trunk,
        biases_trunk,
        weights_opacity,
        biases_opacity,
        weights_color,
        biases_color,
    ) 
def flattened_triton_decoder_to_list(
    mlp_params: torch.Tensor,
    n_layers_trunk: int,
    n_layers_opacity: int,
    n_layers_color: int,
    input_chn: int,
    hidden_chn: int,
    color_chn: int,
):
    """
    Another helper function to convert the flattened MLP parameters into lists
    of weight matrices and bias vectors for each MLP.
    Given `mlp_params`, the number of layers of each MLP, the input/output channel
    counts, and the number of hidden units, this function returns the lists of
    weight matrices and bias vectors for each MLP.
    Args:
        mlp_params: The flattened MLP parameters, i.e. 1D tensor.
        n_layers_trunk: The number of layers in the `trunk_mlp`.
        n_layers_opacity: The number of layers in the `opacity_mlp`.
        n_layers_color: The number of layers in the `color_mlp`.
        input_chn: The number of input channels.
        hidden_chn: The number of hidden units in the MLP layers.
        color_chn: The number of rendered channels.
    """
    def _make_n_hidden(dim_in, dim_hidden, dim_out, n_layers):
        n_hidden = [dim_in]
        for _ in range(n_layers - 1):
            n_hidden.append(dim_hidden)
        n_hidden.append(dim_out)
        return torch.tensor(n_hidden, dtype=torch.int32, device=mlp_params.device)
    n_hidden_trunk = _make_n_hidden(input_chn, hidden_chn, hidden_chn, n_layers_trunk)
    n_hidden_opacity = _make_n_hidden(hidden_chn, hidden_chn, 1, n_layers_opacity)
    n_hidden_color = _make_n_hidden(hidden_chn, hidden_chn, color_chn, n_layers_color)
    return flattened_decoder_params_to_list(
        mlp_params,
        n_hidden_trunk,
        n_hidden_opacity,
        n_hidden_color,
        transpose=False,
    ) 
# --------------------------------
# --- Helper private functions ---
# --------------------------------
def _validate_flattened_mlp_params(
    mlp_params: torch.Tensor,
    n_hidden_trunk: torch.Tensor,
    n_hidden_opacity: torch.Tensor,
    n_hidden_color: torch.Tensor,
    pad_color_channels_to_min_block_size: bool = False,
):
    """
    A helper function to validate whether the size of `mlp_params` satisfies the
    configuration specified by `n_hidden_trunk`, `n_hidden_opacity`, and
    `n_hidden_color`.
    Args:
        mlp_params: The flattened MLP parameters, i.e. 1D tensor.
        n_hidden_trunk: `(n_layers+1,)` Long tensor storing the number of
            input channels followed by the number of hidden units in each layer of the
            `trunk_mlp`. Note that this tensor can be empty if the trunk MLP is not used.
        n_hidden_opacity: `(n_layers+1,)` Long tensor storing the number of
            input channels followed by the number of hidden units in each layer of the
            `opacity_mlp`.
        n_hidden_color: `(n_layers+1,)` Long tensor storing the number of
            input channels followed by the number of hidden units in each layer of the
            `color_mlp`.
        pad_color_channels_to_min_block_size: If True, the MLP parameters are padded
            with zeros to match the Triton minimum block size (`MIN_BLOCK_SIZE`).
    """
    assert n_hidden_trunk.dtype == torch.int32
    assert n_hidden_opacity.dtype == torch.int32
    assert n_hidden_color.dtype == torch.int32
    assert mlp_params.dtype == torch.float
    (
        weights_trunk,
        biases_trunk,
        weights_opacity,
        biases_opacity,
        weights_color,
        biases_color,
    ) = flattened_decoder_params_to_list(
        mlp_params,
        n_hidden_trunk,
        n_hidden_opacity,
        n_hidden_color,
        transpose=False,
    )
    for w, b in (
        (weights_trunk, biases_trunk),
        (weights_opacity, biases_opacity),
        (weights_color, biases_color),
    ):
        _validate_mlp_params_list(w, b)
    if pad_color_channels_to_min_block_size:
        assert biases_color[-1].numel() >= MIN_BLOCK_SIZE
def _validate_mlp_params_list(
    weights_list: Tuple[torch.Tensor, ...],
    biases_list: Tuple[torch.Tensor, ...],
):
    """
    Helper function to validate the weight matrices and bias vectors of an MLP.
    It checks the shape and device of the weights and biases, and the consistency
    of the dimensions between the layers.
    """
    for l, (w, b) in enumerate(zip(weights_list, biases_list)):
        dim_in = w.shape[0]
        dim_out = w.shape[1]
        assert w.device == b.device
        assert b.ndim == 1
        assert w.ndim == 2
        assert dim_out == b.shape[0]
        if l > 0:
            w_prev = weights_list[l - 1]
            assert w_prev.shape[1] == dim_in
def _flattened_one_mlp_params_to_list(
    mlp_params: torch.Tensor,
    n_hidden: torch.Tensor,
    transpose: bool = False,
) -> Tuple[Tuple[torch.Tensor, ...], Tuple[torch.Tensor, ...]]:
    """
    Helper function to convert the flattened parameters of a single MLP into
    lists of weight matrices and bias vectors.
    """
    nl = n_hidden.shape[0] - 1
    indims = n_hidden[:nl].tolist()
    outdims = n_hidden[1:].tolist()
    numels = n_hidden[:-1] * n_hidden[1:]
    tot_numel = numels.sum()
    w_mlp_params, b_mlp_params = mlp_params[:tot_numel], mlp_params[tot_numel:]
    assert w_mlp_params.numel() == tot_numel
    assert b_mlp_params.numel() == sum(outdims)
    weights = [
        w.reshape(indim, outdim)
        for w, indim, outdim in zip(
            w_mlp_params.split(numels.tolist()),
            indims,
            outdims,
        )
    ]
    biases = b_mlp_params.split(outdims)
    if transpose:
        weights = [w.t().contiguous() for w in weights]
    return weights, biases
def _get_n_hidden(
    w: Tuple[torch.Tensor, ...],
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    """
    Helper function to get the number of hidden units for each layer in the MLP.
    """
    if len(w) == 0:
        return torch.tensor([], dtype=torch.int32, device=device)
    n_hidden = [w_.shape[1] for w_ in w]
    n_hidden.insert(0, w[0].shape[0])
    n_hidden = torch.tensor(
        n_hidden,
        dtype=torch.int32,
        device=device,
    ).contiguous()
    return n_hidden
def _pad_color_mlp_params(
    weights: Tuple[torch.Tensor, ...],
    biases: Tuple[torch.Tensor, ...],
    n_pad: int,
):
    """
    Helper function to pad the MLP parameters with zeros to match the minimum output
    size.
    """
    # convert to lists so the last layer can be replaced in place
    weights, biases = list(weights), list(biases)
    weights[-1] = torch.nn.functional.pad(weights[-1], [0, n_pad])
    biases[-1] = torch.nn.functional.pad(biases[-1], [0, n_pad])
    return weights, biases
def _xavier_init_mlp_params(
    n_layers: int,
    input_chn: int,
    hidden_chn: int,
    output_chn: int,
    device: torch.device,
    last_bias: float = 0.0,
    last_num_pad_channels: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Helper function to initialize the weights and biases of an MLP with Xavier
    initialization and zero padding for output channels.
    """
    weights = [
        torch.empty(
            input_chn if l == 0 else hidden_chn,
            output_chn if l == n_layers - 1 else hidden_chn,
            device=device,
        )
        for l in range(n_layers)
    ]
    for w in weights:  # xavier init the weights in place
        torch.nn.init.xavier_uniform_(w, gain=torch.nn.init.calculate_gain("relu"))
    biases = [
        (
            torch.full((output_chn,), device=device, fill_value=last_bias)
            if l == n_layers - 1
            else torch.zeros(hidden_chn, device=device)
        )
        for l in range(n_layers)
    ]
    if last_num_pad_channels > 0:
        weights[-1] = torch.cat(
            [
                weights[-1],
                torch.zeros(
                    weights[-1].shape[0],  # match the row count of the last weight
                    last_num_pad_channels,
                    device=device,
                ),
            ],
            dim=1,
        )
        biases[-1] = torch.cat(
            [
                biases[-1],
                torch.zeros(
                    last_num_pad_channels,
                    device=device,
                ),
            ],
            dim=0,
        )
    return weights, biases