55# LICENSE file in the root directory of this source tree.
66
77
8+ from typing import Optional
9+
810import torch
911import torch .nn .functional as F
12+
1013from pytorch3d .common .compat import meshgrid_ij
14+
1115from pytorch3d .structures import Meshes
1216
1317
@@ -50,7 +54,14 @@ def ravel_index(idx, dims) -> torch.Tensor:
5054
5155
5256@torch .no_grad ()
53- def cubify (voxels , thresh , device = None , align : str = "topleft" ) -> Meshes :
57+ def cubify (
58+ voxels : torch .Tensor ,
59+ thresh : float ,
60+ * ,
61+ feats : Optional [torch .Tensor ] = None ,
62+ device = None ,
63+ align : str = "topleft"
64+ ) -> Meshes :
5465 r"""
5566 Converts a voxel to a mesh by replacing each occupied voxel with a cube
5667 consisting of 12 faces and 8 vertices. Shared vertices are merged, and
@@ -59,6 +70,9 @@ def cubify(voxels, thresh, device=None, align: str = "topleft") -> Meshes:
5970 voxels: A FloatTensor of shape (N, D, H, W) containing occupancy probabilities.
6071 thresh: A scalar threshold. If a voxel occupancy is larger than
6172 thresh, the voxel is considered occupied.
73+ feats: A FloatTensor of shape (N, K, D, H, W) containing the color information
74+ of each voxel. K is the number of channels. This is supported only when
75+ align == "center"
6276 device: The device of the output meshes
6377 align: Defines the alignment of the mesh vertices and the grid locations.
6478 Has to be one of {"topleft", "corner", "center"}. See below for explanation.
@@ -177,6 +191,7 @@ def cubify(voxels, thresh, device=None, align: str = "topleft") -> Meshes:
177191 # boolean to linear index
178192 # NF x 2
179193 linind = torch .nonzero (faces_idx , as_tuple = False )
194+
180195 # NF x 4
181196 nyxz = unravel_index (linind [:, 0 ], (N , H , W , D ))
182197
@@ -238,6 +253,21 @@ def cubify(voxels, thresh, device=None, align: str = "topleft") -> Meshes:
238253 grid_verts .index_select (0 , (idleverts [n ] == 0 ).nonzero (as_tuple = False )[:, 0 ])
239254 for n in range (N )
240255 ]
241- faces_list = [nface - idlenum [n ][nface ] for n , nface in enumerate (faces_list )]
242256
243- return Meshes (verts = verts_list , faces = faces_list )
257+ textures_list = None
258+ if feats is not None and align == "center" :
259+ # We return a TexturesAtlas containing one color for each face
260+ # N x K x D x H x W -> N x H x W x D x K
261+ feats = feats .permute (0 , 3 , 4 , 2 , 1 )
262+
263+ # (NHWD) x K
264+ feats = feats .reshape (- 1 , feats .size (4 ))
265+ feats = torch .index_select (feats , 0 , linind [:, 0 ])
266+ feats = feats .reshape (- 1 , 1 , 1 , feats .size (1 ))
267+ feats_list = list (torch .split (feats , split_size .tolist (), 0 ))
268+ from pytorch3d .renderer .mesh .textures import TexturesAtlas
269+
270+ textures_list = TexturesAtlas (feats_list )
271+
272+ faces_list = [nface - idlenum [n ][nface ] for n , nface in enumerate (faces_list )]
273+ return Meshes (verts = verts_list , faces = faces_list , textures = textures_list )
0 commit comments