Home
The plntree package offers a GPU-compatible PyTorch implementation of the PLN-Tree model for multivariate hierarchical count data analysis.
In this documentation, you will learn more about the functions of the package and how to use PLN-Tree for different tasks.
If you are specifically interested in data augmentation for microbiome data, check out the TaxaPLN tutorial as well.
The plntree package is available on PyPI for a quick installation.

```bash
pip install plntree
```

```python
from plntree import PLNTree

PLNTree(counts, covariates=None, offsets='zeros', latent_dynamic=None,
        variational_approx=None, covariates_params=None, level_regex='|', clade_regex='__',
        smart_init=True, device=None, seed=None)
```

Builder for PLN-Tree models as introduced by Chaussard, A., Bonnet, A., Gassiat, E., & Le Corff, S. (2024). Tree-based variational inference for Poisson log-normal models. arXiv preprint arXiv:2406.17361.
Parameters:
- `counts`: Count dataframe with samples as rows and features as columns. Column names define the hierarchical relationships, for instance `1__Entity|2__Mammal|3__Human`, based on the `level_regex` and `clade_regex` nomenclature (a construction sketch is given after this parameter list).
- `covariates`: Covariates dataframe with samples as rows and features as columns. If `None`, no covariates are used.
- `offsets`: Offset type for the counts. Can be `'zeros'` (default) or `'logsum'`.
- `latent_dynamic`: Dictionary of parameters for the latent dynamic model. If `None`, defaults to a Markov linear model with 1 layer. Parameters are described below:
| Parameter | Default | Description |
|---|---|---|
| `n_layers` | `1` | Number of hidden layers to parameterize the latent dynamic |
| `diagonal` | `False` | Whether the latent dynamic should have a diagonal covariance |
| `combining_layers` | `1` | Number of hidden layers to combine preprocessed covariates and prior latent (if applicable) |
| `markov_means` | `True` | Whether the prior should be regular Markov or degenerate in the mean parameters |
| `markov_covariance` | `True` | Whether the prior should be regular Markov or degenerate in the covariance parameters |
- `variational_approx`: Parameters for the variational approximation method. If `None`, defaults to a residual backward Markov approximation. Parameters are described below:
| Parameter | Default | Description |
|---|---|---|
| `method` | `'residual'` | Variational approximation family. Either `mean_field` or `residual`. |
| `counts_preprocessing` | `'proportion'` | Preprocessing strategy for count data in the variational networks. Either `None`, `proportion`, `log`, or `clr`. |
| `combining_layers` | `1` | Number of hidden layers to combine preprocessed covariates and prior latent (if applicable) |
| `n_layers` | `None` | Number of hidden layers in the mean-field approximation (if `mean_field` is selected) |
| `embedder_type` | `'GRU'` | Amortizing network for the residual variational approximation. Either `GRU` or `LSTM`. |
| `embedding_size` | `32` | Output size of the amortizing network |
| `n_embedding_layers` | `2` | Number of layers of the amortizing network for the residual variational approximation |
| `n_embedding_neurons` | `32` | Hidden size of the layers in the amortizing network for the residual variational approximation |
| `n_after_layers` | `1` | Number of layers after the amortizing network for the residual variational approximation, at each level |
- `covariates_params`: Parameters for the covariates processing in both the latent dynamic and the variational approximation. If `None`, defaults to a FiLM architecture with 1 layer. Parameters are described below:
| Parameter | Default | Description |
|---|---|---|
| `type` | `'film'` | Architecture for covariates processing. Either `film` or `attention`. |
| `n_layers` | `1` | Number of hidden layers in the respective architectures |
| `n_heads` | `4` | Number of attention heads if `attention` is selected |
- `level_regex`: Hierarchy regex to split the features into levels. Default is `'|'`.
- `clade_regex`: Clade regex to split the features into clades. Default is `'__'`.
- `smart_init`: Initialize parameters at the first level based on the input counts and covariates.
- `device`: Use `'cpu'` or `'cuda'` to specify the device for the model. If `None`, defaults to `'cpu'`.
- `seed`: Seed for the random number generator to ensure reproducibility. If `None`, no seed is set.
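To make the column nomenclature and the configuration dictionaries concrete, here is a minimal construction sketch; the taxa names and the `age` covariate are purely illustrative, and the configuration dictionaries only use parameters from the tables above.

```python
import pandas as pd
from plntree import PLNTree

# Toy counts whose column names encode the hierarchy with level_regex='|'
# and clade_regex='__' (illustrative taxa, not real data).
counts = pd.DataFrame(
    [[10, 5, 3], [2, 8, 1], [4, 4, 7], [9, 0, 2]],
    columns=[
        '1__Entity|2__Mammal|3__Human',
        '1__Entity|2__Mammal|3__Dog',
        '1__Entity|2__Bird|3__Crow',
    ],
)
covariates = pd.DataFrame({'age': [34.0, 51.0, 27.0, 63.0]})  # hypothetical covariate

# Customize the latent dynamic, variational approximation, and covariates
# processing using the documented parameters.
model = PLNTree(
    counts,
    covariates=covariates,
    latent_dynamic={'n_layers': 2, 'diagonal': False},
    variational_approx={'method': 'residual', 'embedder_type': 'GRU'},
    covariates_params={'type': 'film', 'n_layers': 1},
    device='cpu',
    seed=0,
)
```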
```python
from plntree import PLNTree

# Assume a predefined counts dataframe
model = PLNTree(counts)
model.fit(max_epoch=10_000, learning_rate=1e-3, grad_clip=5.,
          tolerance=1e-4, tolerance_smoothing=1000, batch_size=512,
          shuffle=True, verbose=1000, seed=None)
```

Upon declaration, PLN-Tree models can be trained with the Adam optimizer through the `fit` method.
Parameters:
- `max_epoch`: Maximum number of gradient descent epochs to perform. Default `10_000`.
- `learning_rate`: Adam optimizer learning rate. Default `1e-4`.
- `grad_clip`: Gradient clipping value to avoid exploding gradients. Default `5.`.
- `tolerance`: Tolerance threshold on the smoothed ELBO convergence. Default `1e-4`.
- `tolerance_smoothing`: Smoothing window applied to the ELBO for the tolerance computation (not mandatory). Default `1000`.
- `batch_size`: Batch size for the PyTorch optimizer. Default `512`.
- `shuffle`: Whether to shuffle the dataset prior to training. Default `True`.
- `verbose`: Number of epochs between progress bar updates. Default `1000`.
- `seed`: Seed for reproducibility. Default `None`.
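Training on GPU only requires setting the `device` at construction; a minimal sketch, assuming a CUDA device is available and a `counts` dataframe is defined:

```python
from plntree import PLNTree

# Hypothetical run: train on GPU with fixed seeds for reproducibility.
model = PLNTree(counts, device='cuda', seed=0)
model.fit(max_epoch=5_000, learning_rate=1e-3, batch_size=256, seed=0)
```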
```python
from plntree import PLNTree

# Assume a predefined counts dataframe
model = PLNTree(counts)
Z = model.encode(model.hierarchical_counts)  # Project the counts into the latent space
latent_proportion_counts = model.latent_tree_proportions(Z, clr=True, seed=None)  # Perform the LTP-CLR transform
```

PLN-Tree can preprocess counts to enhance the predictive performance of downstream analyses. Here, we propose the LTP-CLR (latent tree proportion centered log-ratio) transform, which leverages identifiable features of the model to yield improved performance. In particular, the LTP-CLR is feature-wise identifiable with the original counts, making downstream analysis easier.
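As an illustration of downstream use, the sketch below feeds LTP-CLR features to a scikit-learn classifier; the labels `y`, the conversion to a NumPy array, and the choice of classifier are assumptions, not part of the plntree API.

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score

# Assumption: latent_proportion_counts is array-like of shape (n_samples, n_features);
# if it is a torch tensor, convert it first with .detach().cpu().numpy().
features = np.asarray(latent_proportion_counts)
y = ...  # hypothetical per-sample labels (e.g., disease status)

clf = RandomForestClassifier(n_estimators=200, random_state=0)
scores = cross_val_score(clf, features, y, cv=5)
print('Mean CV accuracy:', scores.mean())
```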
```python
from plntree import PLNTree

# Assume a predefined counts dataframe
model = PLNTree(counts)
model.fit()
X, Z = model.sample(n_samples, covariates=None, offsets=None,
                    hierarchy_level=None, seed=None)
```

Upon fitting, PLN-Tree can generate synthetic count data based on the learned latent dynamic.
Parameters:
- `n_samples`: Number of synthetic samples to generate.
- `covariates`: Covariates to use for the generation (if applicable).
- `offsets`: Synthetic data offset (if imposed).
- `hierarchy_level`: Output a restriction of `X` to the given level and turn it into a DataFrame.
- `seed`: Seed for reproducibility.
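For instance, a fitted model can generate samples restricted to one level of the hierarchy and returned as a DataFrame; a minimal sketch, in which the level index is illustrative:

```python
# Draw 100 synthetic samples and restrict them to a given level of the
# hierarchy (hierarchy_level=2 is illustrative; it depends on the taxonomy depth).
X_synth, Z_synth = model.sample(100, hierarchy_level=2, seed=0)
# With hierarchy_level set, X_synth is a DataFrame whose columns follow
# the tree ordering defined by PLNTree.
```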
```python
from plntree import PLNTree

# Assume a predefined counts dataframe
model = PLNTree(counts)
model.fit()
X, Z = model.vamp_sample(n_samples=10, X_vamp=None, covariates=None, offsets=None,
                         mean=False, hierarchy_level=None, seed=None)
```

Sampling in VAE-based models can be performed in several ways. Here, we propose an alternative sampling scheme based on the Variational Mixture of Posteriors (VAMP) that can yield highly faithful data, which is typically interesting for data augmentation.
Parameters:
- `n_samples`: Number of synthetic samples to generate.
- `X_vamp`: Counts to use for VAMP sampling. If `None`, the training set is used.
- `covariates`: Covariates to use for the generation (if applicable). If `None`, the training covariates are used.
- `offsets`: Synthetic data offset (if imposed).
- `mean`: Take the mean of the encoder rather than sampling from it. Default `False`.
- `hierarchy_level`: Output a restriction of `X` to the given level and turn it into a DataFrame.
- `seed`: Seed for reproducibility.
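A sketch of a data augmentation workflow under these defaults: VAMP samples anchored on the training set are stacked with the training counts. The level index and the column alignment between the two tables are assumptions made for illustration.

```python
import pandas as pd

# Anchor VAMP sampling on the training set (X_vamp=None) and draw as many
# synthetic samples as there are training samples, returned as a DataFrame
# (hierarchy_level=2 is illustrative and assumed to match the leaf level).
X_aug, _ = model.vamp_sample(n_samples=len(model.counts), hierarchy_level=2, seed=0)

# Hypothetical augmentation: stack real and synthetic samples, assuming both
# DataFrames share the same columns at that level.
augmented = pd.concat([model.counts, X_aug], ignore_index=True)
```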
Upon training, PLN-Tree provides several variables that can be exploited for count data analysis.
```python
from plntree import PLNTree

# Assume a predefined counts dataframe
model = PLNTree(counts)
model.tree.plot(figaxis=None, colormap='viridis_r', counts=None, visual='circle')
```

The `tree` attribute is extracted from the counts dataframe to define the hierarchy on the features. The `Tree.plot` method provides a convenient way to visualize that hierarchy.
Parameters:
- `figaxis`: Tuple containing predefined `Figure` and `Axes` objects from matplotlib. If `None`, they are generated automatically.
- `colormap`: Colormap for the visualization, based on the matplotlib nomenclature.
- `counts`: If provided, associates weights to the edges and nodes to display count data on the hierarchy.
- `visual`: Visual layout of the hierarchy. Either `circle` or `top-down`.
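For example, one can overlay the training abundances on a top-down layout; a minimal sketch, assuming `model.counts` is an accepted format for the `counts` argument:

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(8, 8))
# Display the hierarchy top-down, weighting edges and nodes by the training counts.
model.tree.plot(figaxis=(fig, ax), colormap='viridis_r',
                counts=model.counts, visual='top-down')
plt.show()
```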
```python
from plntree import PLNTree

# Assume a predefined counts dataframe
model = PLNTree(counts)
counts = model.counts
```

For faster computation, the PLNTree module reorders the columns of the original counts DataFrame. Make sure to override the original object with the one held by PLN-Tree before doing oversampling or data analysis.
```python
from plntree import PLNTree

# Assume a predefined counts dataframe
model = PLNTree(counts)
X = model.tree.hierarchical_counts(counts=counts)  # Tensor (n x L x K_L)
```

You can convert the counts DataFrame into hierarchical counts embedded in a Tensor using the `model.tree.hierarchical_counts` function. The result aligns with the tree ordering defined by PLNTree.
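To sanity-check the layout, the snippet below inspects the tensor shape; reading the last axis as the width of the deepest level, with shallower levels padded to it, is an interpretation of the `(n x L x K_L)` annotation above, not documented behavior.

```python
X = model.tree.hierarchical_counts(counts=counts)
n, L, K_L = X.shape  # samples, hierarchy depth, assumed width of the deepest level
print(f'{n} samples across {L} levels, {K_L} clades at the deepest level')
```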
```python
from plntree import PLNTree

# Assume a predefined counts dataframe
model = PLNTree(counts)
# Assume predefined level, Z_l_prev and C covariates in tensor form
model.mu(level, Z_l_prev, C)
model.omega(level, Z_l_prev, C)
```

Upon training, the PLNTree module gives access to the inferred PLN-Tree parameters of the latent dynamic through the `mu` and `omega` methods.
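A sketch of how these accessors might be combined with the encoder; slicing `Z` along the level axis and passing `C=None` when no covariates were provided are assumptions, not documented behavior.

```python
# Encode the training counts, then query the latent dynamic parameters at an
# illustrative level, conditioning on the previous level's latents.
Z = model.encode(model.hierarchical_counts)
level = 1
Z_l_prev = Z[:, level - 1, :]  # assumed (n x K) slice at the previous level
mu_l = model.mu(level, Z_l_prev, None)        # prior mean parameters (no covariates)
omega_l = model.omega(level, Z_l_prev, None)  # prior precision parameters (assumed interpretation)
```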