From 7a2b81b09956776101e35d21e1524938aa1bb421 Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Tue, 23 May 2023 14:04:56 +0200
Subject: [PATCH 1/9] change some properties to private properties (i.e. add a
 blank _)

---
 pyPLNmodels/models.py | 26 ++++++++++++--------------
 1 file changed, 12 insertions(+), 14 deletions(-)

diff --git a/pyPLNmodels/models.py b/pyPLNmodels/models.py
index 739bf308..960cdd0d 100644
--- a/pyPLNmodels/models.py
+++ b/pyPLNmodels/models.py
@@ -327,7 +327,7 @@ class _Pln(ABC):
             if abs(criterion) < tol:
                 stop_condition = True
             if verbose and self.nb_iteration_done % 50 == 0:
-                self.print_stats()
+                self._print_stats()
         self._print_end_of_fitting_message(stop_condition, tol)
         self._fitted = True
 
@@ -406,7 +406,7 @@ class _Pln(ABC):
                 np.round(self._plotargs.criterions[-1], 8),
             )
 
-    def print_stats(self):
+    def _print_stats(self):
         """
         Print the training statistics.
         """
@@ -491,7 +491,7 @@ class _Pln(ABC):
         delimiter = "=" * NB_CHARACTERS_FOR_NICE_PLOT
         string = f"A multivariate Poisson Lognormal with {self._description} \n"
         string += f"{delimiter}\n"
-        string += _nice_string_of_dict(self.dict_for_printing)
+        string += _nice_string_of_dict(self._dict_for_printing)
         string += f"{delimiter}\n"
         string += "* Useful properties\n"
         string += f"    {self._useful_properties_string}\n"
@@ -762,7 +762,7 @@ class _Pln(ABC):
         path_of_directory : str, optional
             The path of the directory to save the parameters, by default "./".
         """
-        path = f"{path_of_directory}/{self.path_to_directory}{self.directory_name}"
+        path = f"{path_of_directory}/{self._path_to_directory}{self._directory_name}"
         os.makedirs(path, exist_ok=True)
         for key, value in self._dict_parameters.items():
             filename = f"{path}/{key}.csv"
@@ -901,7 +901,7 @@ class _Pln(ABC):
         self._coef = coef
 
     @property
-    def dict_for_printing(self):
+    def _dict_for_printing(self):
         """
         Property representing the dictionary for printing.
 
@@ -1007,7 +1007,7 @@ class _Pln(ABC):
         return covariates @ self.coef
 
     @property
-    def directory_name(self):
+    def _directory_name(self):
         """
         Property representing the directory name.
 
@@ -1019,7 +1019,7 @@ class _Pln(ABC):
         return f"{self._NAME}_nbcov_{self.nb_cov}_dim_{self.dim}"
 
     @property
-    def path_to_directory(self):
+    def _path_to_directory(self):
         """
         Property representing the path to the directory.
 
@@ -1641,12 +1641,10 @@ class PlnPCAcollection:
             )
             if i < len(self.values()) - 1:
                 next_model = self[self.ranks[i + 1]]
-                self.init_next_model_with_previous_parameters(next_model, model)
+                self._init_next_model_with_current_model(next_model, model)
         self._print_ending_message()
 
-    def init_next_model_with_previous_parameters(
-        self, next_model: Any, current_model: Any
-    ):
+    def _init_next_model_with_current_model(self, next_model: Any, current_model: Any):
         """
         Initialize the next model with the parameters of the current model.
 
@@ -1922,7 +1920,7 @@ class PlnPCAcollection:
                 model.save(path_of_directory)
 
     @property
-    def directory_name(self) -> str:
+    def _directory_name(self) -> str:
         """
         Property representing the directory name.
 
@@ -2152,7 +2150,7 @@ class PlnPCA(_Pln):
         self._latent_var = latent_var
 
     @property
-    def directory_name(self) -> str:
+    def _directory_name(self) -> str:
         """
         Property representing the directory name.
 
@@ -2192,7 +2190,7 @@ class PlnPCA(_Pln):
         self._smart_init_coef()
 
     @property
-    def path_to_directory(self) -> str:
+    def _path_to_directory(self) -> str:
         """
         Property representing the path to the directory.
 
-- 
GitLab


From a30e4f00a61a31cb6c977058951777848b6e5104 Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Tue, 23 May 2023 15:31:47 +0200
Subject: [PATCH 2/9] error in the tests, did not update the blank properties,
 and add to one test other fixtures.

---
 tests/conftest.py    | 4 ++--
 tests/test_common.py | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/conftest.py b/tests/conftest.py
index f70e6aca..5dd35b20 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -83,10 +83,10 @@ def convenientpln(*args, **kwargs):
 
 
 def generate_new_model(model, *args, **kwargs):
-    name_dir = model.directory_name
+    name_dir = model._directory_name
     name = model._NAME
     if name in ("Pln", "PlnPCA"):
-        path = model.path_to_directory + name_dir
+        path = model._path_to_directory + name_dir
         init = load_model(path)
         if name == "Pln":
             new = convenientpln(*args, **kwargs, dict_initialization=init)
diff --git a/tests/test_common.py b/tests/test_common.py
index 64cb6cae..19d8673e 100644
--- a/tests/test_common.py
+++ b/tests/test_common.py
@@ -9,7 +9,7 @@ from tests.utils import MSE, filter_models
 from tests.import_data import true_sim_0cov, true_sim_2cov
 
 
-@pytest.mark.parametrize("any_pln", dict_fixtures["fitted_pln"])
+@pytest.mark.parametrize("any_pln", dict_fixtures["loaded_and_fitted_pln"])
 @filter_models(["Pln", "PlnPCA"])
 def test_properties(any_pln):
     assert hasattr(any_pln, "latent_parameters")
-- 
GitLab


From c2188375d54c764f1d93aa11ec0e2a7b2175f096 Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Tue, 23 May 2023 19:33:49 +0200
Subject: [PATCH 3/9] useless since in the pyproject.toml

---
 tests/requirements.txt | 8 --------
 1 file changed, 8 deletions(-)
 delete mode 100644 tests/requirements.txt

diff --git a/tests/requirements.txt b/tests/requirements.txt
deleted file mode 100644
index 19e8bedd..00000000
--- a/tests/requirements.txt
+++ /dev/null
@@ -1,8 +0,0 @@
-numpy
-pandas
-pyPLNmodels
-pytest
-pytest_lazy_fixture
-scanpy
-scikit_learn
-torch
-- 
GitLab


From 495fa96d8a278ea8e37a1fdcd174d886b7a2e157 Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Tue, 23 May 2023 19:57:42 +0200
Subject: [PATCH 4/9] re arranged docstrings for chatgpt to _utils and elbos. I
 went to _init_S in _utils.

---
 pyPLNmodels/_closed_forms.py |  4 +--
 pyPLNmodels/_utils.py        | 70 +++++++++++-------------------------
 pyPLNmodels/elbos.py         | 19 +++++-----
 3 files changed, 31 insertions(+), 62 deletions(-)

diff --git a/pyPLNmodels/_closed_forms.py b/pyPLNmodels/_closed_forms.py
index 1405b6df..a08196e1 100644
--- a/pyPLNmodels/_closed_forms.py
+++ b/pyPLNmodels/_closed_forms.py
@@ -11,7 +11,7 @@ def _closed_formula_covariance(
     n_samples: int,
 ) -> torch.Tensor:
     """
-    Compute the closed-form covariance for the M step of the noPCA model.
+    Compute the closed-form covariance for the M step of the Pln model.
 
     Parameters:
     ----------
@@ -46,7 +46,7 @@ def _closed_formula_coef(
     covariates: torch.Tensor, latent_mean: torch.Tensor
 ) -> Optional[torch.Tensor]:
     """
-    Compute the closed-form coef for the M step of the noPCA model.
+    Compute the closed-form coef for the M step of the Pln model.
 
     Parameters:
     ----------
diff --git a/pyPLNmodels/_utils.py b/pyPLNmodels/_utils.py
index 0206b3b4..8fcff06c 100644
--- a/pyPLNmodels/_utils.py
+++ b/pyPLNmodels/_utils.py
@@ -29,17 +29,17 @@ class _PlotArgs:
         Parameters
         ----------
         window : int
-            The size of the window for running statistics.
+            The size of the window for computing the criterion.
         """
         self.window = window
         self.running_times = []
-        self.criterions = [1] * window
+        self.criterions = [1] * window  # the first window criterion won't be computed.
         self._elbos_list = []
 
     @property
     def iteration_number(self) -> int:
         """
-        Get the number of iterations.
+        Numer of iterations done when fitting the model.
 
         Returns
         -------
@@ -48,16 +48,14 @@ class _PlotArgs:
         """
         return len(self._elbos_list)
 
-    def _show_loss(self, ax=None, name_doss=""):
+    def _show_loss(self, ax=None):
         """
-        Show the loss plot.
+        Show the loss of the model (i.e. the negative ELBO).
 
         Parameters
         ----------
         ax : matplotlib.axes.Axes, optional
-            The axes object to plot on. If not provided, the current axes will be used.
-        name_doss : str, optional
-            The name of the loss. Default is an empty string.
+            The axes object to plot on. If not provided, will be created.
         """
         ax = plt.gca() if ax is None else ax
         ax.plot(self.running_times, -np.array(self._elbos_list), label="Negative ELBO")
@@ -70,12 +68,13 @@ class _PlotArgs:
 
     def _show_stopping_criterion(self, ax=None):
         """
-        Show the stopping criterion plot.
+        Show the stopping criterion plot. The gradient ascent
+        stops according to this critertion.
 
         Parameters
         ----------
         ax : matplotlib.axes.Axes, optional
-            The axes object to plot on. If not provided, the current axes will be used.
+            The axes object to plot on. If not provided, will be created.
         """
         ax = plt.gca() if ax is None else ax
         ax.plot(
@@ -94,9 +93,9 @@ def _init_covariance(
     counts: torch.Tensor, covariates: torch.Tensor, coef: torch.Tensor
 ) -> torch.Tensor:
     """
-    Initialization for covariance for the Pln model. Take the log of counts
-    (careful when counts=0), remove the covariates effects X@coef and
-    then do as a MLE for Gaussians samples.
+    Initialization for the covariance for the Pln model. Take the log of counts
+    (careful when counts=0), and computes the Maximum Likelihood
+    Estimator in the gaussian case.
 
     Parameters
     ----------
@@ -162,7 +161,8 @@ def _init_latent_mean(
     eps=7e-3,
 ) -> torch.Tensor:
     """
-    Initialization for the variational parameter M. Basically, the mode of the log_posterior is computed.
+    Initialization for the variational parameter latent_mean.
+    Basically, the mode of the log_posterior is computed.
 
     Parameters
     ----------
@@ -181,8 +181,8 @@ def _init_latent_mean(
     lr : float, optional
         The learning rate of the optimizer. Default is 0.01.
     eps : float, optional
-        The tolerance. The algorithm will stop if the maximum of |W_t-W_{t-1}| is lower than eps,
-        where W_t is the t-th iteration of the algorithm. Default is 7e-3.
+        The tolerance. The algorithm will stop as soon as the criterion is lower than the tolerance.
+        Default is 7e-3.
 
     Returns
     -------
@@ -211,19 +211,6 @@ def _init_latent_mean(
 
 
 def _sigmoid(tens: torch.Tensor) -> torch.Tensor:
-    """
-    Compute the sigmoid function of x element-wise.
-
-    Parameters
-    ----------
-    tens : torch.Tensor
-        Input tensor
-
-    Returns
-    -------
-    torch.Tensor
-        Output tensor with sigmoid applied element-wise
-    """
     return 1 / (1 + torch.exp(-tens))
 
 
@@ -288,8 +275,7 @@ def sample_pln(
 
 def _components_from_covariance(covariance: torch.Tensor, rank: int) -> torch.Tensor:
     """
-    Get the best matrix of size (p, rank) when covariance is of size (p, p),
-    i.e., reduce norm(covariance - components @ components.T).
+    Get the PCA with rank components of covariance.
 
     Parameters
     ----------
@@ -315,7 +301,7 @@ def _init_coef(
     counts: torch.Tensor, covariates: torch.Tensor, offsets: torch.Tensor
 ) -> torch.Tensor:
     """
-    Initialize the coefficient for the Poisson regression model.
+    Initialize the coefficient for the Pln model using Poisson regression model.
 
     Parameters
     ----------
@@ -341,8 +327,7 @@ def _init_coef(
 
 def _log_stirling(integer: torch.Tensor) -> torch.Tensor:
     """
-    Compute log(n!) even for large n using the Stirling formula to avoid numerical
-    infinite values of n!.
+    Compute log(n!) using the Stirling formula.
 
     Parameters
     ----------
@@ -389,7 +374,7 @@ def log_posterior(
     Returns
     -------
     torch.Tensor
-        Log posterior of size (N_samples, batch_size) or (batch_size).
+        Log posterior of size n_samples.
     """
     length = len(posterior_mean.shape)
     rank = posterior_mean.shape[-1]
@@ -414,21 +399,6 @@ def log_posterior(
 
 
 def _trunc_log(tens: torch.Tensor, eps: float = 1e-16) -> torch.Tensor:
-    """
-    Compute the truncated logarithm of the input tensor.
-
-    Parameters
-    ----------
-    tens : torch.Tensor
-        Input tensor
-    eps : float, optional
-        Truncation value, default is 1e-16
-
-    Returns
-    -------
-    torch.Tensor
-        Truncated logarithm of the input tensor.
-    """
     integer = torch.min(torch.max(tens, torch.tensor([eps])), torch.tensor([1 - eps]))
     return torch.log(integer)
 
diff --git a/pyPLNmodels/elbos.py b/pyPLNmodels/elbos.py
index 68b73247..158fe1fe 100644
--- a/pyPLNmodels/elbos.py
+++ b/pyPLNmodels/elbos.py
@@ -2,7 +2,6 @@ import torch  # pylint:disable=[C0114]
 from ._utils import _log_stirling, _trunc_log
 from ._closed_forms import _closed_formula_covariance, _closed_formula_coef
 
-
 from typing import Optional
 
 
@@ -38,7 +37,7 @@ def elbo_pln(
     Returns:
     -------
     torch.Tensor
-        The ELBO (Evidence Lower Bound) with size 1, with a gradient.
+        The ELBO (Evidence Lower Bound), of size one.
     """
     n_samples, dim = counts.shape
     s_rond_s = torch.square(latent_var)
@@ -63,10 +62,6 @@ def elbo_pln(
     return elbo / n_samples
 
 
-import torch
-from typing import Optional
-
-
 def profiled_elbo_pln(
     counts: torch.Tensor,
     covariates: torch.Tensor,
@@ -75,7 +70,9 @@ def profiled_elbo_pln(
     latent_var: torch.Tensor,
 ) -> torch.Tensor:
     """
-    Compute the ELBO (Evidence Lower Bound) for the Pln model with profiled parameters.
+    Compute the ELBO (Evidence Lower Bound) for the Pln model with profiled
+    model parameters (i.e the model parameters are derived directly from the
+    latent parameters).
 
     Parameters:
     ----------
@@ -93,7 +90,7 @@ def profiled_elbo_pln(
     Returns:
     -------
     torch.Tensor
-        The ELBO (Evidence Lower Bound) with size 1, with a gradient.
+        The ELBO (Evidence Lower Bound) with size 1.
     """
     n_samples, _ = counts.shape
     s_squared = torch.square(latent_var)
@@ -122,7 +119,8 @@ def elbo_plnpca(
     coef: torch.Tensor,
 ) -> torch.Tensor:
     """
-    Compute the ELBO (Evidence Lower Bound) for the Pln model with PCA parametrization.
+    Compute the ELBO (Evidence Lower Bound) for the Pln model
+    with PCA parametrization.
 
     Parameters:
     ----------
@@ -135,7 +133,8 @@ def elbo_plnpca(
     latent_mean : torch.Tensor
         Variational parameter with size (n, p).
     latent_var : torch.Tensor
-        Variational parameter with size (n, p).
+        Variational parameter with size (n, p). More precisely it is the unsigned
+        square root of the variational variance.
     components : torch.Tensor
         Model parameter with size (p, q).
     coef : torch.Tensor
-- 
GitLab


From 03608424d560dd1a6801c9694a5bcd6de8bd0bfe Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Wed, 24 May 2023 10:10:59 +0200
Subject: [PATCH 5/9] add GPU support and python < 3.9 support. test pass for
 all the models.

---
 pyPLNmodels/_utils.py  | 50 +++++++++++++++++++++---------------------
 pyPLNmodels/models.py  | 13 ++++++-----
 tests/requirements.txt |  8 -------
 tests/test_pln_full.py |  2 +-
 4 files changed, 33 insertions(+), 40 deletions(-)
 delete mode 100644 tests/requirements.txt

diff --git a/pyPLNmodels/_utils.py b/pyPLNmodels/_utils.py
index 0206b3b4..9a3db0b2 100644
--- a/pyPLNmodels/_utils.py
+++ b/pyPLNmodels/_utils.py
@@ -9,7 +9,7 @@ from matplotlib import transforms
 from matplotlib.patches import Ellipse
 import matplotlib.pyplot as plt
 from patsy import dmatrices
-from typing import Optional, Dict, Any, Union
+from typing import Optional, Dict, Any, Union, Tuple, List
 import pkg_resources
 
 
@@ -234,7 +234,7 @@ def sample_pln(
     offsets: torch.Tensor,
     _coef_inflation: torch.Tensor = None,
     seed: int = None,
-) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """
     Sample from the Poisson Log-Normal (Pln) model.
 
@@ -255,7 +255,7 @@ def sample_pln(
 
     Returns
     -------
-    tuple[torch.Tensor, torch.Tensor, torch.Tensor]
+    Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
         Tuple containing counts (torch.Tensor), gaussian (torch.Tensor), and ksi (torch.Tensor)
     """
     prev_state = torch.random.get_rng_state()
@@ -270,7 +270,7 @@ def sample_pln(
     else:
         XB = torch.matmul(covariates, coef)
 
-    gaussian = torch.mm(torch.randn(n_samples, rank, device=DEVICE), components.T) + XB
+    gaussian = torch.mm(torch.randn(n_samples, rank, device="cpu"), components.T) + XB
     parameter = torch.exp(offsets + gaussian)
 
     if _coef_inflation is not None:
@@ -588,7 +588,7 @@ def _format_data(data: pd.DataFrame) -> torch.Tensor or None:
     if isinstance(data, np.ndarray):
         return torch.from_numpy(data).double().to(DEVICE)
     if isinstance(data, torch.Tensor):
-        return data
+        return data.to(DEVICE)
     raise AttributeError(
         "Please insert either a numpy.ndarray, pandas.DataFrame or torch.Tensor"
     )
@@ -600,7 +600,7 @@ def _format_model_param(
     offsets: torch.Tensor,
     offsets_formula: str,
     take_log_offsets: bool,
-) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """
     Format the model parameters.
 
@@ -768,12 +768,12 @@ def _get_components_simulation(dim: int, rank: int) -> torch.Tensor:
         ] = 1
     components += torch.randn(dim, rank) / 8
     torch.random.set_rng_state(prev_state)
-    return components.to(DEVICE)
+    return components.to("cpu")
 
 
 def get_simulation_offsets_cov_coef(
     n_samples: int, nb_cov: int, dim: int
-) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """
     Get simulation offsets, covariance coefficients.
 
@@ -788,7 +788,7 @@ def get_simulation_offsets_cov_coef(
 
     Returns:
     --------
-    tuple[torch.Tensor, torch.Tensor, torch.Tensor]
+    Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
         Tuple containing offsets, covariates, and coefficients.
     """
     prev_state = torch.random.get_rng_state()
@@ -801,11 +801,11 @@ def get_simulation_offsets_cov_coef(
             high=2,
             size=(n_samples, nb_cov),
             dtype=torch.float64,
-            device=DEVICE,
+            device="cpu",
         )
-    coef = torch.randn(nb_cov, dim, device=DEVICE)
+    coef = torch.randn(nb_cov, dim, device="cpu")
     offsets = torch.randint(
-        low=0, high=2, size=(n_samples, dim), dtype=torch.float64, device=DEVICE
+        low=0, high=2, size=(n_samples, dim), dtype=torch.float64, device="cpu"
     )
     torch.random.set_rng_state(prev_state)
     return offsets, covariates, coef
@@ -818,7 +818,7 @@ def get_simulated_count_data(
     nb_cov: int = 1,
     return_true_param: bool = False,
     seed: int = 0,
-) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """
     Get simulated count data.
 
@@ -839,7 +839,7 @@ def get_simulated_count_data(
 
     Returns:
     --------
-    tuple[torch.Tensor, torch.Tensor, torch.Tensor]
+    Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
         Tuple containing counts, covariates, and offsets.
     """
     components = _get_components_simulation(dim, rank)
@@ -886,13 +886,13 @@ def get_real_count_data(n_samples: int = 270, dim: int = 100) -> np.ndarray:
     return counts
 
 
-def _closest(lst: list[float], element: float) -> float:
+def _closest(lst: List[float], element: float) -> float:
     """
     Find the closest element in a list to a given element.
 
     Parameters:
     -----------
-    lst : list[float]
+    lst : List[float]
         List of float values.
     element : float
         Element to find the closest value to.
@@ -958,7 +958,7 @@ def load_pln(path_of_directory: str) -> Dict[str, Any]:
 
 
 def load_plnpcacollection(
-    path_of_directory: str, ranks: Optional[list[int]] = None
+    path_of_directory: str, ranks: Optional[List[int]] = None
 ) -> Dict[int, Dict[str, Any]]:
     """
     Load PlnPCAcollection models from the given directory.
@@ -967,8 +967,8 @@ def load_plnpcacollection(
     ----------
     path_of_directory : str
         The path to the directory containing the PlnPCAcollection models.
-    ranks : list[int], optional
-        A list of ranks specifying which models to load. If None, all models in the directory will be loaded.
+    ranks : List[int], optional
+        A List of ranks specifying which models to load. If None, all models in the directory will be loaded.
 
     Returns
     -------
@@ -1025,7 +1025,7 @@ def _check_right_rank(data: Dict[str, Any], rank: int) -> None:
         )
 
 
-def _extract_data_from_formula(formula: str, data: Dict[str, Any]) -> tuple:
+def _extract_data_from_formula(formula: str, data: Dict[str, Any]) -> Tuple:
     """
     Extract data from the given formula and data dictionary.
 
@@ -1038,7 +1038,7 @@ def _extract_data_from_formula(formula: str, data: Dict[str, Any]) -> tuple:
 
     Returns
     -------
-    tuple
+    Tuple
         A tuple containing the extracted counts, covariates, and offsets.
 
     """
@@ -1066,7 +1066,7 @@ def _is_dict_of_dict(dictionary: Dict[Any, Any]) -> bool:
         True if the dictionary is a dictionary of dictionaries, False otherwise.
 
     """
-    return isinstance(dictionary[list(dictionary.keys())[0]], dict)
+    return isinstance(dictionary[List(dictionary.keys())[0]], dict)
 
 
 def _get_dict_initialization(
@@ -1116,11 +1116,11 @@ def _to_tensor(
     if obj is None:
         return None
     if isinstance(obj, np.ndarray):
-        return torch.from_numpy(obj)
+        return torch.from_numpy(obj).to(DEVICE)
     if isinstance(obj, torch.Tensor):
-        return obj
+        return obj.to(DEVICE)
     if isinstance(obj, pd.DataFrame):
-        return torch.from_numpy(obj.values)
+        return torch.from_numpy(obj.values).to(DEVICE)
     raise TypeError(
         "Please give either an np.ndarray or torch.Tensor or pd.DataFrame or None"
     )
diff --git a/pyPLNmodels/models.py b/pyPLNmodels/models.py
index 960cdd0d..6a9c27f4 100644
--- a/pyPLNmodels/models.py
+++ b/pyPLNmodels/models.py
@@ -2,8 +2,7 @@ import time
 from abc import ABC, abstractmethod
 import warnings
 import os
-from collections.abc import Iterable
-from typing import Optional, Dict, List, Type, Any
+from typing import Optional, Dict, List, Type, Any, Iterable
 
 import pandas as pd
 import torch
@@ -472,9 +471,11 @@ class _Pln(ABC):
         if self.dim > 400:
             warnings.warn("Only displaying the first 400 variables.")
             sigma = sigma[:400, :400]
-            sns.heatmap(self.covariance[:400, :400], ax=ax)
+            sns.heatmap(self.covariance[:400, :400].cpu(), ax=ax)
         else:
-            sns.heatmap(self.covariance, ax=ax)
+            sns.heatmap(self.covariance.cpu(), ax=ax)
+        ax.set_title("Covariance Matrix")
+        plt.legend()
         if savefig:
             plt.savefig(name_file + self._NAME)
         plt.show()  # to avoid displaying a blank screen
@@ -1238,7 +1239,7 @@ class Pln(_Pln):
                 "n_samples",
             ]
         ):
-            return self._covariance.detach()
+            return self._covariance.cpu().detach()
         return None
 
     @covariance.setter
@@ -2373,7 +2374,7 @@ class PlnPCA(_Pln):
             cov_latent = self._latent_mean.T @ self._latent_mean
             cov_latent += torch.diag(torch.sum(torch.square(self._latent_var), dim=0))
             cov_latent /= self.n_samples
-            return (self._components @ cov_latent @ self._components.T).detach()
+            return (self._components @ cov_latent @ self._components.T).cpu().detach()
         return None
 
     @property
diff --git a/tests/requirements.txt b/tests/requirements.txt
deleted file mode 100644
index 19e8bedd..00000000
--- a/tests/requirements.txt
+++ /dev/null
@@ -1,8 +0,0 @@
-numpy
-pandas
-pyPLNmodels
-pytest
-pytest_lazy_fixture
-scanpy
-scikit_learn
-torch
diff --git a/tests/test_pln_full.py b/tests/test_pln_full.py
index 0fa0e7ea..90e9868e 100644
--- a/tests/test_pln_full.py
+++ b/tests/test_pln_full.py
@@ -8,7 +8,7 @@ from tests.utils import filter_models
 @filter_models(["Pln"])
 def test_number_of_iterations_pln_full(fitted_pln):
     nb_iterations = len(fitted_pln._elbos_list)
-    assert 50 < nb_iterations < 300
+    assert 50 < nb_iterations < 400
 
 
 @pytest.mark.parametrize("pln", dict_fixtures["loaded_and_fitted_pln"])
-- 
GitLab


From 5d48d2ffed2f64effe38e6c6f5761f7e40a5cf30 Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Fri, 26 May 2023 22:06:54 +0200
Subject: [PATCH 6/9] add docstrings

---
 pyPLNmodels/_utils.py | 102 +++++++++++-------------------------------
 1 file changed, 27 insertions(+), 75 deletions(-)

diff --git a/pyPLNmodels/_utils.py b/pyPLNmodels/_utils.py
index 8fcff06c..a02c4c31 100644
--- a/pyPLNmodels/_utils.py
+++ b/pyPLNmodels/_utils.py
@@ -495,51 +495,10 @@ def _check_two_dimensions_are_equal(
         )
 
 
-def _init_S(
-    counts: torch.Tensor,
-    covariates: torch.Tensor,
-    offsets: torch.Tensor,
-    beta: torch.Tensor,
-    C: torch.Tensor,
-    M: torch.Tensor,
-) -> torch.Tensor:
-    """
-    Initialize the S matrix.
-
-    Parameters
-    ----------
-    counts : torch.Tensor, shape (n, )
-        Count data.
-    covariates : torch.Tensor or None, shape (n, d) or None
-        Covariate data.
-    offsets : torch.Tensor or None, shape (n, ) or None
-        Offset data.
-    beta : torch.Tensor, shape (d, )
-        Beta parameter.
-    C : torch.Tensor, shape (r, d)
-        C parameter.
-    M : torch.Tensor, shape (r, k)
-        M parameter.
-
-    Returns
-    -------
-    torch.Tensor, shape (r, r)
-        Initialized S matrix.
-    """
-    n, rank = M.shape
-    batch_matrix = torch.matmul(C[:, None, :], C[:, :, None])[None]
-    CW = torch.matmul(C[None], M[:, None, :]).squeeze()
-    common = torch.exp(offsets + covariates @ beta + CW)[:, None, None]
-    prod = batch_matrix * common
-    hess_posterior = torch.sum(prod, dim=1) + torch.eye(rank, device=DEVICE)
-    inv_hess_posterior = -torch.inverse(hess_posterior)
-    hess_posterior = torch.diagonal(inv_hess_posterior, dim1=-2, dim2=-1)
-    return hess_posterior
-
-
 def _format_data(data: pd.DataFrame) -> torch.Tensor or None:
     """
-    Format the input data.
+    Transforms the data in a torch.tensor if the input is an array, and None if the input is None.
+    Raises an error if the input is not an array or None.
 
     Parameters
     ----------
@@ -550,6 +509,11 @@ def _format_data(data: pd.DataFrame) -> torch.Tensor or None:
     -------
     torch.Tensor or None
         Formatted data.
+
+    Raises
+    ------
+    AttributeError
+        If the value is not an array or None.
     """
     if data is None:
         return None
@@ -572,7 +536,7 @@ def _format_model_param(
     take_log_offsets: bool,
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """
-    Format the model parameters.
+    Format each of the model parameters to an array or None if None.
 
     Parameters
     ----------
@@ -595,13 +559,11 @@ def _format_model_param(
     ------
     ValueError
         If counts has negative values.
-
     """
     counts = _format_data(counts)
     if torch.min(counts) < 0:
-        raise ValueError("Counts should be only non negavtive values.")
-    if covariates is not None:
-        covariates = _format_data(covariates)
+        raise ValueError("Counts should be only non negative values.")
+    covariates = _format_data(covariates)
     if offsets is None:
         if offsets_formula == "logsum":
             print("Setting the offsets as the log of the sum of counts")
@@ -621,7 +583,7 @@ def _check_data_shape(
     counts: torch.Tensor, covariates: torch.Tensor, offsets: torch.Tensor
 ) -> None:
     """
-    Check the shape of the input data.
+    Check if the shape of the input data is valid.
 
     Parameters
     ----------
@@ -631,11 +593,6 @@ def _check_data_shape(
         Covariate data.
     offsets : torch.Tensor or None, shape (n, p) or None
         Offset data.
-
-    Raises
-    ------
-    ValueError
-        If the dimensions of the input data do not match.
     """
     n_counts, p_counts = counts.shape
     n_offsets, p_offsets = offsets.shape
@@ -670,23 +627,18 @@ def _nice_string_of_dict(dictionnary: dict) -> str:
 
 def _plot_ellipse(mean_x: float, mean_y: float, cov: np.ndarray, ax) -> float:
     """
-    Plot an ellipse on the given axes.
+    Plot an ellipse given two coordinates and the covariance.
 
     Parameters:
     -----------
     mean_x : float
-        Mean value of x-coordinate.
+        x-coordinate of the mean.
     mean_y : float
-        Mean value of y-coordinate.
+        y-coordinate of the mean.
     cov : np.ndarray
-        Covariance matrix.
+        Covariance matrix of the 2d vector.
     ax : object
         Axes object to plot the ellipse on.
-
-    Returns:
-    --------
-    float
-        Pearson correlation coefficient.
     """
     pearson = cov[0, 1] / np.sqrt(cov[0, 0] * cov[1, 1])
     ell_radius_x = np.sqrt(1 + pearson)
@@ -709,24 +661,24 @@ def _plot_ellipse(mean_x: float, mean_y: float, cov: np.ndarray, ax) -> float:
     )
     ellipse.set_transform(transf + ax.transData)
     ax.add_patch(ellipse)
-    return pearson
 
 
 def _get_components_simulation(dim: int, rank: int) -> torch.Tensor:
     """
-    Get the components for simulation.
+    Get the components for simulation. The resulting covariance matrix
+    will be a matrix per blocks plus a little noise.
 
     Parameters:
     -----------
     dim : int
-        Dimension.
+        Dimension of the data.
     rank : int
-        Rank.
+        Rank of the resulting covariance matrix (i.e. number of components).
 
     Returns:
     --------
     torch.Tensor
-        Components for simulation.
+        Components.
     """
     block_size = dim // rank
     prev_state = torch.random.get_rng_state()
@@ -745,16 +697,16 @@ def get_simulation_offsets_cov_coef(
     n_samples: int, nb_cov: int, dim: int
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """
-    Get simulation offsets, covariance coefficients.
+    Get offsets, covariance coefficients with right shapes.
 
     Parameters:
     -----------
     n_samples : int
         Number of samples.
     nb_cov : int
-        Number of covariates.
+        Number of covariates. If 0, covariates will be None.
     dim : int
-        Dimension.
+        Dimension required of the data.
 
     Returns:
     --------
@@ -790,7 +742,7 @@ def get_simulated_count_data(
     seed: int = 0,
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """
-    Get simulated count data.
+    Get simulated count data from the Pln model.
 
     Parameters:
     -----------
@@ -799,11 +751,11 @@ def get_simulated_count_data(
     dim : int, optional
         Dimension, by default 25.
     rank : int, optional
-        Rank, by default 5.
+        Rank of the covariance matrix, by default 5.
     nb_cov : int, optional
         Number of covariates, by default 1.
     return_true_param : bool, optional
-        Whether to return true parameters, by default False.
+        Whether to return the true parameters of the model, by default False.
     seed : int, optional
         Seed value for random number generation, by default 0.
 
@@ -823,7 +775,7 @@ def get_simulated_count_data(
 
 def get_real_count_data(n_samples: int = 270, dim: int = 100) -> np.ndarray:
     """
-    Get real count data.
+    Get real count data from the scMARK dataset.
 
     Parameters:
     -----------
-- 
GitLab


From e8f816cbcdbc74c4b32cef53cd3d93dc211313b0 Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Wed, 31 May 2023 23:59:46 +0200
Subject: [PATCH 7/9] added functionality such as load pln in __init__, and
 change the way we simulate the data to take as parameters pln_param, a new
 class containing all the parameters (components, coef, offsets and
 covariates.).

---
 pyPLNmodels/__init__.py |   6 +
 pyPLNmodels/_utils.py   | 622 +++++++++++-----------------------------
 pyPLNmodels/models.py   |  46 +--
 tests/test_pln_full.py  |   2 +-
 4 files changed, 199 insertions(+), 477 deletions(-)

diff --git a/pyPLNmodels/__init__.py b/pyPLNmodels/__init__.py
index e619625e..529f8b0a 100644
--- a/pyPLNmodels/__init__.py
+++ b/pyPLNmodels/__init__.py
@@ -7,8 +7,12 @@ from ._utils import (
     load_model,
     load_plnpcacollection,
     load_pln,
+    sample_pln,
+    get_simulation_parameters,
 )
 
+from ._initialization import log_posterior
+
 __all__ = (
     "PlnPCAcollection",
     "Pln",
@@ -21,4 +25,6 @@ __all__ = (
     "load_model",
     "load_plnpcacollection",
     "load_pln",
+    "sample_pln",
+    "log_posterior",
 )
diff --git a/pyPLNmodels/_utils.py b/pyPLNmodels/_utils.py
index a02c4c31..084d736b 100644
--- a/pyPLNmodels/_utils.py
+++ b/pyPLNmodels/_utils.py
@@ -89,180 +89,53 @@ class _PlotArgs:
         ax.legend()
 
 
-def _init_covariance(
-    counts: torch.Tensor, covariates: torch.Tensor, coef: torch.Tensor
-) -> torch.Tensor:
-    """
-    Initialization for the covariance for the Pln model. Take the log of counts
-    (careful when counts=0), and computes the Maximum Likelihood
-    Estimator in the gaussian case.
-
-    Parameters
-    ----------
-    counts : torch.Tensor
-        Samples with size (n,p)
-    offsets : torch.Tensor
-        Offset, size (n,p)
-    covariates : torch.Tensor
-        Covariates, size (n,d)
-    coef : torch.Tensor
-        Coefficient of size (d,p)
-
-    Returns
-    -------
-    torch.Tensor
-        Covariance matrix of size (p,p)
-    """
-    log_y = torch.log(counts + (counts == 0) * math.exp(-2))
-    log_y_centered = log_y - torch.mean(log_y, axis=0)
-    n_samples = counts.shape[0]
-    sigma_hat = 1 / (n_samples - 1) * (log_y_centered.T) @ log_y_centered
-    return sigma_hat
-
-
-def _init_components(
-    counts: torch.Tensor, covariates: torch.Tensor, coef: torch.Tensor, rank: int
-) -> torch.Tensor:
-    """
-    Initialization for components for the Pln model. Get a first guess for covariance
-    that is easier to estimate and then takes the rank largest eigenvectors to get components.
-
-    Parameters
-    ----------
-    counts : torch.Tensor
-        Samples with size (n,p)
-    offsets : torch.Tensor
-        Offset, size (n,p)
-    covariates : torch.Tensor
-        Covariates, size (n,d)
-    coef : torch.Tensor
-        Coefficient of size (d,p)
-    rank : int
-        The dimension of the latent space, i.e. the reduced dimension.
-
-    Returns
-    -------
-    torch.Tensor
-        Initialization of components of size (p,rank)
-    """
-    sigma_hat = _init_covariance(counts, covariates, coef).detach()
-    components = _components_from_covariance(sigma_hat, rank)
-    return components
-
-
-def _init_latent_mean(
-    counts: torch.Tensor,
-    covariates: torch.Tensor,
-    offsets: torch.Tensor,
-    coef: torch.Tensor,
-    components: torch.Tensor,
-    n_iter_max=500,
-    lr=0.01,
-    eps=7e-3,
-) -> torch.Tensor:
-    """
-    Initialization for the variational parameter latent_mean.
-    Basically, the mode of the log_posterior is computed.
-
-    Parameters
-    ----------
-    counts : torch.Tensor
-        Samples with size (n,p)
-    offsets : torch.Tensor
-        Offset, size (n,p)
-    covariates : torch.Tensor
-        Covariates, size (n,d)
-    coef : torch.Tensor
-        Coefficient of size (d,p)
-    components : torch.Tensor
-        Components of size (p,rank)
-    n_iter_max : int, optional
-        The maximum number of iterations in the gradient ascent. Default is 500.
-    lr : float, optional
-        The learning rate of the optimizer. Default is 0.01.
-    eps : float, optional
-        The tolerance. The algorithm will stop as soon as the criterion is lower than the tolerance.
-        Default is 7e-3.
-
-    Returns
-    -------
-    torch.Tensor
-        The initialized latent mean with size (n,rank)
-    """
-    mode = torch.randn(counts.shape[0], components.shape[1], device=DEVICE)
-    mode.requires_grad_(True)
-    optimizer = torch.optim.Rprop([mode], lr=lr)
-    crit = 2 * eps
-    old_mode = torch.clone(mode)
-    keep_condition = True
-    i = 0
-    while i < n_iter_max and keep_condition:
-        batch_loss = log_posterior(counts, covariates, offsets, mode, components, coef)
-        loss = -torch.mean(batch_loss)
-        loss.backward()
-        optimizer.step()
-        crit = torch.max(torch.abs(mode - old_mode))
-        optimizer.zero_grad()
-        if crit < eps and i > 2:
-            keep_condition = False
-        old_mode = torch.clone(mode)
-        i += 1
-    return mode
-
-
 def _sigmoid(tens: torch.Tensor) -> torch.Tensor:
     return 1 / (1 + torch.exp(-tens))
 
 
-def sample_pln(
-    components: torch.Tensor,
-    coef: torch.Tensor,
-    covariates: torch.Tensor,
-    offsets: torch.Tensor,
-    _coef_inflation: torch.Tensor = None,
-    seed: int = None,
-) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+def sample_pln(pln_param, seed: int = None, return_latent=False) -> torch.Tensor:
     """
     Sample from the Poisson Log-Normal (Pln) model.
 
     Parameters
     ----------
-    components : torch.Tensor
-        Components of size (p, rank)
-    coef : torch.Tensor
-        Coefficient of size (d, p)
-    covariates : torch.Tensor or None
-        Covariates, size (n, d) or None
-    offsets : torch.Tensor
-        Offset, size (n, p)
-    _coef_inflation : torch.Tensor or None, optional
-        Coefficient for zero-inflation model, size (d, p) or None. Default is None.
+    pln_param : PlnParameters object
+        parameters of the model, containing the coeficient, the covariates,
+        the components and the offsets.
     seed : int or None, optional
         Random seed for reproducibility. Default is None.
+    return_latent : bool, optional
+        If True will return also the latent variables. Default is False.
 
     Returns
     -------
-    tuple[torch.Tensor, torch.Tensor, torch.Tensor]
+    tuple[torch.Tensor, torch.Tensor, torch.Tensor] if return_latent is True
         Tuple containing counts (torch.Tensor), gaussian (torch.Tensor), and ksi (torch.Tensor)
+    torch.Tensor if return_latent is False
     """
     prev_state = torch.random.get_rng_state()
     if seed is not None:
         torch.random.manual_seed(seed)
 
-    n_samples = offsets.shape[0]
-    rank = components.shape[1]
+    n_samples = pln_param.offsets.shape[0]
+    rank = pln_param.components.shape[1]
 
-    if covariates is None:
+    if pln_param.covariates is None:
         XB = 0
     else:
-        XB = torch.matmul(covariates, coef)
+        XB = torch.matmul(pln_param.covariates, pln_param.coef)
 
-    gaussian = torch.mm(torch.randn(n_samples, rank, device=DEVICE), components.T) + XB
-    parameter = torch.exp(offsets + gaussian)
+    gaussian = (
+        torch.mm(torch.randn(n_samples, rank, device=DEVICE), pln_param.components.T)
+        + XB
+    )
+    parameter = torch.exp(pln_param.offsets + gaussian)
 
-    if _coef_inflation is not None:
+    if pln_param.coef_inflation is not None:
         print("ZIPln is sampled")
-        zero_inflated_mean = torch.matmul(covariates, _coef_inflation)
+        zero_inflated_mean = torch.matmul(
+            pln_param.covariates, pln_param.coef_inflation
+        )
         ksi = torch.bernoulli(1 / (1 + torch.exp(-zero_inflated_mean)))
     else:
         ksi = 0
@@ -270,59 +143,9 @@ def sample_pln(
     counts = (1 - ksi) * torch.poisson(parameter)
 
     torch.random.set_rng_state(prev_state)
-    return counts, gaussian, ksi
-
-
-def _components_from_covariance(covariance: torch.Tensor, rank: int) -> torch.Tensor:
-    """
-    Get the PCA with rank components of covariance.
-
-    Parameters
-    ----------
-    covariance : torch.Tensor
-        Covariance matrix of size (p, p)
-    rank : int
-        The number of columns wanted for components
-
-    Returns
-    -------
-    torch.Tensor
-        Requested components of size (p, rank) containing the rank eigenvectors
-        with largest eigenvalues.
-    """
-    eigenvalues, eigenvectors = TLA.eigh(covariance)
-    requested_components = eigenvectors[:, -rank:] @ torch.diag(
-        torch.sqrt(eigenvalues[-rank:])
-    )
-    return requested_components
-
-
-def _init_coef(
-    counts: torch.Tensor, covariates: torch.Tensor, offsets: torch.Tensor
-) -> torch.Tensor:
-    """
-    Initialize the coefficient for the Pln model using Poisson regression model.
-
-    Parameters
-    ----------
-    counts : torch.Tensor
-        Samples with size (n, p)
-    covariates : torch.Tensor
-        Covariates, size (n, d)
-    offsets : torch.Tensor
-        Offset, size (n, p)
-
-    Returns
-    -------
-    torch.Tensor or None
-        Coefficient of size (d, p) or None if covariates is None.
-    """
-    if covariates is None:
-        return None
-
-    poiss_reg = _PoissonReg()
-    poiss_reg.fit(counts, covariates, offsets)
-    return poiss_reg.beta
+    if return_latent is True:
+        return counts, gaussian, ksi
+    return counts
 
 
 def _log_stirling(integer: torch.Tensor) -> torch.Tensor:
@@ -345,59 +168,6 @@ def _log_stirling(integer: torch.Tensor) -> torch.Tensor:
     )
 
 
-def log_posterior(
-    counts: torch.Tensor,
-    covariates: torch.Tensor,
-    offsets: torch.Tensor,
-    posterior_mean: torch.Tensor,
-    components: torch.Tensor,
-    coef: torch.Tensor,
-) -> torch.Tensor:
-    """
-    Compute the log posterior of the Poisson Log-Normal (Pln) model.
-
-    Parameters
-    ----------
-    counts : torch.Tensor
-        Samples with size (batch_size, p)
-    covariates : torch.Tensor or None
-        Covariates, size (batch_size, d) or (d)
-    offsets : torch.Tensor
-        Offset, size (batch_size, p)
-    posterior_mean : torch.Tensor
-        Posterior mean with size (N_samples, N_batch, rank) or (batch_size, rank)
-    components : torch.Tensor
-        Components with size (p, rank)
-    coef : torch.Tensor
-        Coefficient with size (d, p)
-
-    Returns
-    -------
-    torch.Tensor
-        Log posterior of size n_samples.
-    """
-    length = len(posterior_mean.shape)
-    rank = posterior_mean.shape[-1]
-    components_posterior_mean = torch.matmul(
-        components.unsqueeze(0), posterior_mean.unsqueeze(2)
-    ).squeeze()
-
-    if covariates is None:
-        XB = 0
-    else:
-        XB = torch.matmul(covariates, coef)
-
-    log_lambda = offsets + components_posterior_mean + XB
-    first_term = (
-        -rank / 2 * math.log(2 * math.pi)
-        - 1 / 2 * torch.norm(posterior_mean, dim=-1) ** 2
-    )
-    second_term = torch.sum(
-        -torch.exp(log_lambda) + log_lambda * counts - _log_stirling(counts), axis=-1
-    )
-    return first_term + second_term
-
-
 def _trunc_log(tens: torch.Tensor, eps: float = 1e-16) -> torch.Tensor:
     integer = torch.min(torch.max(tens, torch.tensor([eps])), torch.tensor([1 - eps]))
     return torch.log(integer)
@@ -426,7 +196,8 @@ def _raise_wrong_dimension_error(
     str_second_array: str,
     dim_first_array: int,
     dim_second_array: int,
-    dim_of_error: int,
+    dim_order_first: int,
+    dim_order_second: int,
 ) -> None:
     """
     Raise an error for mismatched dimensions between two tensors.
@@ -441,60 +212,23 @@ def _raise_wrong_dimension_error(
         Dimension of the first tensor
     dim_second_array : int
         Dimension of the second tensor
-    dim_of_error : int
-        Dimension causing the error
-
+    dim_order_first : int
+        Dimension causing the error for the first tensor.
+    dim_order_second : int
+        Dimension causing the error for the second tensor.
     Raises
     ------
     ValueError
-        If the dimensions of the two tensors do not match at the non-singleton dimension.
+        If the dimensions of the two tensors do not match.
     """
     msg = (
-        f"The size of tensor {str_first_array} ({dim_first_array}) must match "
+        f"The size of tensor {str_first_array} at non-singleton dimension {dim_order_first} ({dim_first_array}) must match "
         f"the size of tensor {str_second_array} ({dim_second_array}) at "
-        f"non-singleton dimension {dim_of_error}"
+        f"non-singleton dimension {dim_order_second}"
     )
     raise ValueError(msg)
 
 
-def _check_two_dimensions_are_equal(
-    str_first_array: str,
-    str_second_array: str,
-    dim_first_array: int,
-    dim_second_array: int,
-    dim_of_error: int,
-) -> None:
-    """
-    Check if two dimensions are equal.
-
-    Parameters
-    ----------
-    str_first_array : str
-        Name of the first array.
-    str_second_array : str
-        Name of the second array.
-    dim_first_array : int
-        Dimension of the first array.
-    dim_second_array : int
-        Dimension of the second array.
-    dim_of_error : int
-        Dimension of the error.
-
-    Raises
-    ------
-    ValueError
-        If the dimensions of the two arrays are not equal.
-    """
-    if dim_first_array != dim_second_array:
-        _raise_wrong_dimension_error(
-            str_first_array,
-            str_second_array,
-            dim_first_array,
-            dim_second_array,
-            dim_of_error,
-        )
-
-
 def _format_data(data: pd.DataFrame) -> torch.Tensor or None:
     """
     Transforms the data in a torch.tensor if the input is an array, and None if the input is None.
@@ -596,11 +330,11 @@ def _check_data_shape(
     """
     n_counts, p_counts = counts.shape
     n_offsets, p_offsets = offsets.shape
-    _check_two_dimensions_are_equal("counts", "offsets", n_counts, n_offsets, 0)
+    _check_two_dimensions_are_equal("counts", "offsets", n_counts, n_offsets, 0, 0)
     if covariates is not None:
         n_cov, _ = covariates.shape
-        _check_two_dimensions_are_equal("counts", "covariates", n_counts, n_cov, 0)
-    _check_two_dimensions_are_equal("counts", "offsets", p_counts, p_offsets, 1)
+        _check_two_dimensions_are_equal("counts", "covariates", n_counts, n_cov, 0, 0)
+    _check_two_dimensions_are_equal("counts", "offsets", p_counts, p_offsets, 1, 1)
 
 
 def _nice_string_of_dict(dictionnary: dict) -> str:
@@ -663,7 +397,7 @@ def _plot_ellipse(mean_x: float, mean_y: float, cov: np.ndarray, ax) -> float:
     ax.add_patch(ellipse)
 
 
-def _get_components_simulation(dim: int, rank: int) -> torch.Tensor:
+def _get_simulation_components(dim: int, rank: int) -> torch.Tensor:
     """
     Get the components for simulation. The resulting covariance matrix
     will be a matrix per blocks plus a little noise.
@@ -693,7 +427,7 @@ def _get_components_simulation(dim: int, rank: int) -> torch.Tensor:
     return components.to(DEVICE)
 
 
-def get_simulation_offsets_cov_coef(
+def _get_simulation_coef_cov_offsets(
     n_samples: int, nb_cov: int, dim: int
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """
@@ -730,7 +464,117 @@ def get_simulation_offsets_cov_coef(
         low=0, high=2, size=(n_samples, dim), dtype=torch.float64, device=DEVICE
     )
     torch.random.set_rng_state(prev_state)
-    return offsets, covariates, coef
+    return coef, covariates, offsets
+
+
+class PlnParameters:
+    def __init__(self, components, coef, covariates, offsets, coef_inflation=None):
+        """
+        Instantiate all the needed parameters to sample from the PLN model.
+
+        Parameters:
+        -----------
+        components : torch.Tensor
+            Components of size (p, rank)
+        coef : torch.Tensor
+            Coefficient of size (d, p)
+        covariates : torch.Tensor or None
+            Covariates, size (n, d) or None
+        offsets : torch.Tensor
+            Offset, size (n, p)
+        _coef_inflation : torch.Tensor or None, optional
+            Coefficient for zero-inflation model, size (d, p) or None. Default is None.
+
+        """
+        self.components = _format_data(components)
+        self.coef = _format_data(coef)
+        self.covariates = _format_data(covariates)
+        self.offsets = _format_data(offsets)
+        self.coef_inflation = _format_data(coef_inflation)
+        _check_two_dimensions_are_equal(
+            "components", "coef", self.components.shape[0], self.coef.shape[1], 0, 1
+        )
+        if self.offsets is not None:
+            _check_two_dimensions_are_equal(
+                "components",
+                "offsets",
+                self.components.shape[0],
+                self.offsets.shape[1],
+                0,
+                1,
+            )
+        if self.covariates is not None:
+            _check_two_dimensions_are_equal(
+                "offsets",
+                "covariates",
+                self.offsets.shape[0],
+                self.covariates.shape[0],
+                0,
+                0,
+            )
+            _check_two_dimensions_are_equal(
+                "covariates", "coef", self.covariates.shape[1], self.coef.shape[0], 1, 0
+            )
+        for array in [self.components, self.coef, self.covariates, self.offsets]:
+            if array is not None:
+                if len(array.shape) != 2:
+                    raise RuntimeError(
+                        f"Expected all arrays to be 2-dimensional, got {len(array.shape)}"
+                    )
+
+    @property
+    def covariance(self):
+        return self.components @ self.components.T
+
+
+def _check_two_dimensions_are_equal(
+    str_first_array: str,
+    str_second_array: str,
+    dim_first_array: int,
+    dim_second_array: int,
+    dim_order_first: int,
+    dim_order_second: int,
+) -> None:
+    """
+    Check if two dimensions are equal.
+
+    Parameters
+    ----------
+    str_first_array : str
+        Name of the first array.
+    str_second_array : str
+        Name of the second array.
+    dim_first_array : int
+        Dimension of the first array.
+    dim_second_array : int
+        Dimension of the second array.
+    dim_order_first : int
+        Dimension causing the error for the first tensor.
+    dim_order_second : int
+        Dimension causing the error for the second tensor.
+
+    Raises
+    ------
+    ValueError
+        If the dimensions of the two arrays are not equal.
+    """
+    if dim_first_array != dim_second_array:
+        _raise_wrong_dimension_error(
+            str_first_array,
+            str_second_array,
+            dim_first_array,
+            dim_second_array,
+            dim_order_first,
+            dim_order_second,
+        )
+
+
+def get_simulation_parameters(
+    n_samples: int = 100, dim: int = 25, nb_cov: int = 1, rank: int = 5
+) -> PlnParameters:
+    coef, covariates, offsets = _get_simulation_coef_cov_offsets(n_samples, nb_cov, dim)
+    components = _get_simulation_components(dim, rank)
+    return PlnParameters(components, coef, covariates, offsets)
 
 
 def get_simulated_count_data(
@@ -742,7 +586,7 @@ def get_simulated_count_data(
     seed: int = 0,
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """
-    Get simulated count data from the Pln model.
+    Get simulated count data from the PlnPCA model.
 
     Parameters:
     -----------
@@ -764,13 +608,17 @@ def get_simulated_count_data(
     tuple[torch.Tensor, torch.Tensor, torch.Tensor]
         Tuple containing counts, covariates, and offsets.
     """
-    components = _get_components_simulation(dim, rank)
-    offsets, cov, true_coef = get_simulation_offsets_cov_coef(n_samples, nb_cov, dim)
-    true_covariance = torch.matmul(components, components.T)
-    counts, _, _ = sample_pln(components, true_coef, cov, offsets, seed=seed)
+    pln_param = get_simulation_parameters(n_samples, dim, nb_cov, rank)
+    counts = sample_pln(pln_param, seed=seed, return_latent=False)
     if return_true_param is True:
-        return counts, cov, offsets, true_covariance, true_coef
-    return counts, cov, offsets
+        return (
+            counts,
+            pln_param.covariates,
+            pln_param.offsets,
+            pln_param.covariance,
+            pln_param.coef,
+        )
+    return pln_param.counts, pln_param.cov, pln_param.offsets
 
 
 def get_real_count_data(n_samples: int = 270, dim: int = 100) -> np.ndarray:
@@ -801,47 +649,23 @@ def get_real_count_data(n_samples: int = 270, dim: int = 100) -> np.ndarray:
         dim = 100
     counts_stream = pkg_resources.resource_stream(__name__, "data/scRT/Y_mark.csv")
     counts = pd.read_csv(counts_stream).values[:n_samples, :dim]
-    # counts = pd.read_csv("./pyPLNmodels/data/scRT/Y_mark.csv").values[
-    # :n_samples, :dim
-    # ]
     print(f"Returning dataset of size {counts.shape}")
     return counts
 
 
-def _closest(lst: list[float], element: float) -> float:
-    """
-    Find the closest element in a list to a given element.
-
-    Parameters:
-    -----------
-    lst : list[float]
-        List of float values.
-    element : float
-        Element to find the closest value to.
-
-    Returns:
-    --------
-    float
-        Closest element in the list.
-    """
-    lst = np.asarray(lst)
-    idx = (np.abs(lst - element)).argmin()
-    return lst[idx]
-
-
 def load_model(path_of_directory: str) -> Dict[str, Any]:
     """
-    Load models from the given directory.
+    Load model from the given directory for future initialization.
 
     Parameters
     ----------
     path_of_directory : str
-        The path to the directory containing the models.
+        The path to the directory containing the model.
 
     Returns
     -------
     Dict[str, Any]
-        A dictionary containing the loaded models.
+        A dictionary containing the loaded model.
 
     """
     working_dir = os.getcwd()
@@ -853,9 +677,9 @@ def load_model(path_of_directory: str) -> Dict[str, Any]:
             parameter = filename[:-4]
             try:
                 data[parameter] = pd.read_csv(filename, header=None).values
-            except pd.errors.EmptyDataError as err:
+            except pd.errors.EmptyDataError:
                 print(
-                    f"Can't load {parameter} since empty. Standard initialization will be performed"
+                    f"Can't load {parameter} since empty. Standard initialization will be performed for this parameter"
                 )
     os.chdir(working_dir)
     return data
@@ -863,18 +687,7 @@ def load_model(path_of_directory: str) -> Dict[str, Any]:
 
 def load_pln(path_of_directory: str) -> Dict[str, Any]:
     """
-    Load Pln models from the given directory.
-
-    Parameters
-    ----------
-    path_of_directory : str
-        The path to the directory containing the Pln models.
-
-    Returns
-    -------
-    Dict[str, Any]
-        A dictionary containing the loaded Pln models.
-
+    Alias for :func:`~pyPLNmodels._utils.load_model`.
     """
     return load_model(path_of_directory)
 
@@ -1048,108 +861,7 @@ def _to_tensor(
     )
 
 
-class _PoissonReg:
-    """
-    Poisson regression model.
-
-    Attributes
-    ----------
-    beta : torch.Tensor
-        The learned regression coefficients.
-
-    Methods
-    -------
-    fit(Y, covariates, O, Niter_max=300, tol=0.001, lr=0.005, verbose=False)
-        Fit the Poisson regression model to the given data.
-
-    """
-
-    def __init__(self) -> None:
-        self.beta: Optional[torch.Tensor] = None
-
-    def fit(
-        self,
-        Y: torch.Tensor,
-        covariates: torch.Tensor,
-        O: torch.Tensor,
-        Niter_max: int = 300,
-        tol: float = 0.001,
-        lr: float = 0.005,
-        verbose: bool = False,
-    ) -> None:
-        """
-        Fit the Poisson regression model to the given data.
-
-        Parameters
-        ----------
-        Y : torch.Tensor
-            The dependent variable of shape (n_samples, n_features).
-        covariates : torch.Tensor
-            The covariates of shape (n_samples, n_covariates).
-        O : torch.Tensor
-            The offset term of shape (n_samples, n_features).
-        Niter_max : int, optional
-            The maximum number of iterations (default is 300).
-        tol : float, optional
-            The tolerance for convergence (default is 0.001).
-        lr : float, optional
-            The learning rate (default is 0.005).
-        verbose : bool, optional
-            Whether to print intermediate information during fitting (default is False).
-
-        """
-        beta = torch.rand(
-            (covariates.shape[1], Y.shape[1]), device=DEVICE, requires_grad=True
-        )
-        optimizer = torch.optim.Rprop([beta], lr=lr)
-        i = 0
-        grad_norm = 2 * tol  # Criterion
-        while i < Niter_max and grad_norm > tol:
-            loss = -compute_poissreg_log_like(Y, O, covariates, beta)
-            loss.backward()
-            optimizer.step()
-            grad_norm = torch.norm(beta.grad)
-            beta.grad.zero_()
-            i += 1
-            if verbose:
-                if i % 10 == 0:
-                    print("log like : ", -loss)
-                    print("grad_norm : ", grad_norm)
-                if i < Niter_max:
-                    print("Tolerance reached in {} iterations".format(i))
-                else:
-                    print("Maximum number of iterations reached")
-        self.beta = beta
-
-
-def compute_poissreg_log_like(
-    Y: torch.Tensor, O: torch.Tensor, covariates: torch.Tensor, beta: torch.Tensor
-) -> torch.Tensor:
-    """
-    Compute the log likelihood of a Poisson regression model.
-
-    Parameters
-    ----------
-    Y : torch.Tensor
-        The dependent variable of shape (n_samples, n_features).
-    O : torch.Tensor
-        The offset term of shape (n_samples, n_features).
-    covariates : torch.Tensor
-        The covariates of shape (n_samples, n_covariates).
-    beta : torch.Tensor
-        The regression coefficients of shape (n_covariates, n_features).
-
-    Returns
-    -------
-    torch.Tensor
-        The log likelihood of the Poisson regression model.
-
-    """
-    XB = torch.matmul(covariates.unsqueeze(1), beta.unsqueeze(0)).squeeze()
-    return torch.sum(-torch.exp(O + XB) + torch.multiply(Y, O + XB))
-
-
-def array2tensor(func):
+def _array2tensor(func):
     def setter(self, array_like):
         array_like = _to_tensor(array_like)
         func(self, array_like)
diff --git a/pyPLNmodels/models.py b/pyPLNmodels/models.py
index 960cdd0d..190022d5 100644
--- a/pyPLNmodels/models.py
+++ b/pyPLNmodels/models.py
@@ -21,10 +21,6 @@ from ._closed_forms import (
 from .elbos import elbo_plnpca, elbo_zi_pln, profiled_elbo_pln
 from ._utils import (
     _PlotArgs,
-    _init_covariance,
-    _init_components,
-    _init_coef,
-    _init_latent_mean,
     _format_data,
     _format_model_param,
     _nice_string_of_dict,
@@ -32,12 +28,19 @@ from ._utils import (
     _check_data_shape,
     _extract_data_from_formula,
     _get_dict_initialization,
-    array2tensor,
+    _array2tensor,
+)
+
+from ._initialization import (
+    _init_covariance,
+    _init_components,
+    _init_coef,
+    _init_latent_mean,
 )
 
 if torch.cuda.is_available():
     DEVICE = "cuda"
-    print("Using a GPU")
+    print("Using a GPU.")
 else:
     DEVICE = "cpu"
 # shoudl add a good init for M. for pln we should not put
@@ -106,6 +109,7 @@ class _Pln(ABC):
     ):
         """
         Create a _Pln instance from a formula and data.
+        See also :func:`~pyPLNmodels.PlnPCAcollection.__init__`
 
         Parameters
         ----------
@@ -689,7 +693,7 @@ class _Pln(ABC):
         return self._cpu_attribute_or_none("_latent_var")
 
     @latent_mean.setter
-    @array2tensor
+    @_array2tensor
     def latent_mean(self, latent_mean):
         """
         Setter for the latent mean property.
@@ -711,7 +715,7 @@ class _Pln(ABC):
         self._latent_mean = latent_mean
 
     @latent_var.setter
-    @array2tensor
+    @_array2tensor
     def latent_var(self, latent_var):
         """
         Setter for the latent variance property.
@@ -812,7 +816,7 @@ class _Pln(ABC):
         return self._cpu_attribute_or_none("_covariates")
 
     @counts.setter
-    @array2tensor
+    @_array2tensor
     def counts(self, counts):
         """
         Setter for the counts property.
@@ -836,7 +840,7 @@ class _Pln(ABC):
         self._counts = counts
 
     @offsets.setter
-    @array2tensor
+    @_array2tensor
     def offsets(self, offsets):
         """
         Setter for the offsets property.
@@ -858,7 +862,7 @@ class _Pln(ABC):
         self._offsets = offsets
 
     @covariates.setter
-    @array2tensor
+    @_array2tensor
     def covariates(self, covariates):
         """
         Setter for the covariates property.
@@ -877,7 +881,7 @@ class _Pln(ABC):
         self._covariates = covariates
 
     @coef.setter
-    @array2tensor
+    @_array2tensor
     def coef(self, coef):
         """
         Setter for the coef property.
@@ -989,7 +993,7 @@ class _Pln(ABC):
         Notes
         -----
         - If `covariates` is not provided and there are no covariates in the model, None is returned.
-        - If `covariates` is provided, it should have the shape `(n_samples, nb_cov)`, where `n_samples` is the number of samples and `nb_cov` is the number of covariates.
+        - If `covariates` is provided, it should have the shape `(_, nb_cov)`, where `nb_cov` is the number of covariates.
         - The predicted values are obtained by multiplying the covariates by the coefficients.
 
         """
@@ -1439,7 +1443,7 @@ class PlnPCAcollection:
         return {model.rank: model.latent_var for model in self.values()}
 
     @counts.setter
-    @array2tensor
+    @_array2tensor
     def counts(self, counts: torch.Tensor):
         """
         Setter for the counts property.
@@ -1453,7 +1457,7 @@ class PlnPCAcollection:
             model.counts = counts
 
     @coef.setter
-    @array2tensor
+    @_array2tensor
     def coef(self, coef: torch.Tensor):
         """
         Setter for the coef property.
@@ -1467,7 +1471,7 @@ class PlnPCAcollection:
             model.coef = coef
 
     @covariates.setter
-    @array2tensor
+    @_array2tensor
     def covariates(self, covariates: torch.Tensor):
         """
         Setter for the covariates property.
@@ -1493,7 +1497,7 @@ class PlnPCAcollection:
         return self[self.ranks[0]].offsets
 
     @offsets.setter
-    @array2tensor
+    @_array2tensor
     def offsets(self, offsets: torch.Tensor):
         """
         Setter for the offsets property.
@@ -2116,7 +2120,7 @@ class PlnPCA(_Pln):
         return self._cpu_attribute_or_none("_latent_var")
 
     @latent_mean.setter
-    @array2tensor
+    @_array2tensor
     def latent_mean(self, latent_mean: torch.Tensor):
         """
         Setter for the latent mean.
@@ -2133,7 +2137,7 @@ class PlnPCA(_Pln):
         self._latent_mean = latent_mean
 
     @latent_var.setter
-    @array2tensor
+    @_array2tensor
     def latent_var(self, latent_var: torch.Tensor):
         """
         Setter for the latent variance.
@@ -2174,7 +2178,7 @@ class PlnPCA(_Pln):
         return self._cpu_attribute_or_none("_covariates")
 
     @covariates.setter
-    @array2tensor
+    @_array2tensor
     def covariates(self, covariates: torch.Tensor):
         """
         Setter for the covariates.
@@ -2455,7 +2459,7 @@ class PlnPCA(_Pln):
         return self._cpu_attribute_or_none("_components")
 
     @components.setter
-    @array2tensor
+    @_array2tensor
     def components(self, components: torch.Tensor):
         """
         Setter for the components.
diff --git a/tests/test_pln_full.py b/tests/test_pln_full.py
index 0fa0e7ea..db0185f1 100644
--- a/tests/test_pln_full.py
+++ b/tests/test_pln_full.py
@@ -8,7 +8,7 @@ from tests.utils import filter_models
 @filter_models(["Pln"])
 def test_number_of_iterations_pln_full(fitted_pln):
     nb_iterations = len(fitted_pln._elbos_list)
-    assert 50 < nb_iterations < 300
+    assert 50 < nb_iterations < 500
 
 
 @pytest.mark.parametrize("pln", dict_fixtures["loaded_and_fitted_pln"])
-- 
GitLab


From 01aa247f5de7eb7439557642e6c29608c708c740 Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Thu, 1 Jun 2023 00:01:01 +0200
Subject: [PATCH 8/9] added functionality such as load_pln in the docs

---
 docs/source/index.rst  | 25 ++++++++++++++++++++-----
 docs/source/module.rst | 23 -----------------------
 2 files changed, 20 insertions(+), 28 deletions(-)
 delete mode 100644 docs/source/module.rst

diff --git a/docs/source/index.rst b/docs/source/index.rst
index bf6f6e0d..98f3e0a6 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -6,16 +6,31 @@
 Welcome to pyPLNmodels's documentation!
 =======================================
 
+API documentation
+=================
+
 .. toctree::
-   :maxdepth: 2
-   :caption: Contents
+   :maxdepth: 1
+   :caption: Classes
+
+   ./plnpcacollection.rst
+   ./plnpca.rst
+   ./pln.rst
 
-   ./module.rst
+.. toctree::
+   :maxdepth: 1
+   :caption: Load saved models
 
+   ./load.rst
+
+.. toctree::
+   :maxdepth: 1
+   :caption: Sampling PLN data
 
+   ./sampling.rst
 Indices and tables
 ==================
 
 * :ref:`genindex`
-* :ref:`modindex`
-* :ref:`search`
+.. * :ref:`modindex`
+.. * :ref:`search`
diff --git a/docs/source/module.rst b/docs/source/module.rst
deleted file mode 100644
index de0c03d0..00000000
--- a/docs/source/module.rst
+++ /dev/null
@@ -1,23 +0,0 @@
-API documentation
-=================
-
-.. autoclass:: pyPLNmodels.PlnPCAcollection
-   :members:
-   :show-inheritance:
-   :special-members: __init__
-   :undoc-members:
-
-
-.. autoclass:: pyPLNmodels.PlnPCA
-   :members:
-   :inherited-members:
-   :special-members: __init__
-   :undoc-members:
-   :show-inheritance:
-
-.. autoclass:: pyPLNmodels.Pln
-   :members:
-   :inherited-members:
-   :special-members: __init__
-   :undoc-members:
-   :show-inheritance:
-- 
GitLab


From a90face83e3ada09842eb2ad734260f79fdb68ff Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Thu, 1 Jun 2023 00:02:27 +0200
Subject: [PATCH 9/9] add all the initialization of beta etc in the same file

---
 pyPLNmodels/_initialization.py | 335 +++++++++++++++++++++++++++++++++
 1 file changed, 335 insertions(+)
 create mode 100644 pyPLNmodels/_initialization.py

diff --git a/pyPLNmodels/_initialization.py b/pyPLNmodels/_initialization.py
new file mode 100644
index 00000000..2510b963
--- /dev/null
+++ b/pyPLNmodels/_initialization.py
@@ -0,0 +1,335 @@
+import torch
+import math
+from typing import Optional
+from ._utils import _log_stirling
+
+if torch.cuda.is_available():
+    DEVICE = torch.device("cuda")
+else:
+    DEVICE = torch.device("cpu")
+
+
+def _init_covariance(
+    counts: torch.Tensor, covariates: torch.Tensor, coef: torch.Tensor
+) -> torch.Tensor:
+    """
+    Initialization for the covariance for the Pln model. Take the log of counts
+    (careful when counts=0), and computes the Maximum Likelihood
+    Estimator in the gaussian case.
+
+    Parameters
+    ----------
+    counts : torch.Tensor
+        Samples with size (n,p)
+    offsets : torch.Tensor
+        Offset, size (n,p)
+    covariates : torch.Tensor
+        Covariates, size (n,d)
+    coef : torch.Tensor
+        Coefficient of size (d,p)
+
+    Returns
+    -------
+    torch.Tensor
+        Covariance matrix of size (p,p)
+    """
+    log_y = torch.log(counts + (counts == 0) * math.exp(-2))
+    log_y_centered = log_y - torch.mean(log_y, axis=0)
+    n_samples = counts.shape[0]
+    sigma_hat = 1 / (n_samples - 1) * (log_y_centered.T) @ log_y_centered
+    return sigma_hat
+
+
+def _init_components(
+    counts: torch.Tensor, covariates: torch.Tensor, coef: torch.Tensor, rank: int
+) -> torch.Tensor:
+    """
+    Initialization for components for the Pln model. Get a first guess for covariance
+    that is easier to estimate and then takes the rank largest eigenvectors to get components.
+
+    Parameters
+    ----------
+    counts : torch.Tensor
+        Samples with size (n,p)
+    offsets : torch.Tensor
+        Offset, size (n,p)
+    covariates : torch.Tensor
+        Covariates, size (n,d)
+    coef : torch.Tensor
+        Coefficient of size (d,p)
+    rank : int
+        The dimension of the latent space, i.e. the reduced dimension.
+
+    Returns
+    -------
+    torch.Tensor
+        Initialization of components of size (p,rank)
+    """
+    sigma_hat = _init_covariance(counts, covariates, coef).detach()
+    components = _components_from_covariance(sigma_hat, rank)
+    return components
+
+
+def _init_latent_mean(
+    counts: torch.Tensor,
+    covariates: torch.Tensor,
+    offsets: torch.Tensor,
+    coef: torch.Tensor,
+    components: torch.Tensor,
+    n_iter_max=500,
+    lr=0.01,
+    eps=7e-3,
+) -> torch.Tensor:
+    """
+    Initialization for the variational parameter latent_mean.
+    Basically, the mode of the log_posterior is computed.
+
+    Parameters
+    ----------
+    counts : torch.Tensor
+        Samples with size (n,p)
+    offsets : torch.Tensor
+        Offset, size (n,p)
+    covariates : torch.Tensor
+        Covariates, size (n,d)
+    coef : torch.Tensor
+        Coefficient of size (d,p)
+    components : torch.Tensor
+        Components of size (p,rank)
+    n_iter_max : int, optional
+        The maximum number of iterations in the gradient ascent. Default is 500.
+    lr : float, optional
+        The learning rate of the optimizer. Default is 0.01.
+    eps : float, optional
+        The tolerance. The algorithm will stop as soon as the criterion is lower than the tolerance.
+        Default is 7e-3.
+
+    Returns
+    -------
+    torch.Tensor
+        The initialized latent mean with size (n,rank)
+    """
+    mode = torch.randn(counts.shape[0], components.shape[1], device=DEVICE)
+    mode.requires_grad_(True)
+    optimizer = torch.optim.Rprop([mode], lr=lr)
+    crit = 2 * eps
+    old_mode = torch.clone(mode)
+    keep_condition = True
+    i = 0
+    while i < n_iter_max and keep_condition:
+        batch_loss = log_posterior(counts, covariates, offsets, mode, components, coef)
+        loss = -torch.mean(batch_loss)
+        loss.backward()
+        optimizer.step()
+        crit = torch.max(torch.abs(mode - old_mode))
+        optimizer.zero_grad()
+        if crit < eps and i > 2:
+            keep_condition = False
+        old_mode = torch.clone(mode)
+        i += 1
+    return mode
+
+
+def _components_from_covariance(covariance: torch.Tensor, rank: int) -> torch.Tensor:
+    """
+    Get the PCA with rank components of covariance.
+
+    Parameters
+    ----------
+    covariance : torch.Tensor
+        Covariance matrix of size (p, p)
+    rank : int
+        The number of columns wanted for components
+
+    Returns
+    -------
+    torch.Tensor
+        Requested components of size (p, rank) containing the rank eigenvectors
+        with largest eigenvalues.
+    """
+    eigenvalues, eigenvectors = torch.linalg.eigh(covariance)
+    requested_components = eigenvectors[:, -rank:] @ torch.diag(
+        torch.sqrt(eigenvalues[-rank:])
+    )
+    return requested_components
+
+
+def _init_coef(
+    counts: torch.Tensor, covariates: torch.Tensor, offsets: torch.Tensor
+) -> torch.Tensor:
+    """
+    Initialize the coefficient for the Pln model using Poisson regression model.
+
+    Parameters
+    ----------
+    counts : torch.Tensor
+        Samples with size (n, p)
+    covariates : torch.Tensor
+        Covariates, size (n, d)
+    offsets : torch.Tensor
+        Offset, size (n, p)
+
+    Returns
+    -------
+    torch.Tensor or None
+        Coefficient of size (d, p) or None if covariates is None.
+    """
+    if covariates is None:
+        return None
+
+    poiss_reg = _PoissonReg()
+    poiss_reg.fit(counts, covariates, offsets)
+    return poiss_reg.beta
+
+
+def log_posterior(
+    counts: torch.Tensor,
+    covariates: torch.Tensor,
+    offsets: torch.Tensor,
+    posterior_mean: torch.Tensor,
+    components: torch.Tensor,
+    coef: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Compute the log posterior of the Poisson Log-Normal (Pln) model.
+
+    Parameters
+    ----------
+    counts : torch.Tensor
+        Samples with size (batch_size, p)
+    covariates : torch.Tensor or None
+        Covariates, size (batch_size, d) or (d)
+    offsets : torch.Tensor
+        Offset, size (batch_size, p)
+    posterior_mean : torch.Tensor
+        Posterior mean with size (N_samples, N_batch, rank) or (batch_size, rank)
+    components : torch.Tensor
+        Components with size (p, rank)
+    coef : torch.Tensor
+        Coefficient with size (d, p)
+
+    Returns
+    -------
+    torch.Tensor
+        Log posterior of size n_samples.
+    """
+    rank = posterior_mean.shape[-1]
+    components_posterior_mean = torch.matmul(
+        components.unsqueeze(0), posterior_mean.unsqueeze(2)
+    ).squeeze()
+
+    if covariates is None:
+        XB = 0
+    else:
+        XB = torch.matmul(covariates, coef)
+
+    log_lambda = offsets + components_posterior_mean + XB
+    first_term = (
+        -rank / 2 * math.log(2 * math.pi)
+        - 1 / 2 * torch.norm(posterior_mean, dim=-1) ** 2
+    )
+    second_term = torch.sum(
+        -torch.exp(log_lambda) + log_lambda * counts - _log_stirling(counts), axis=-1
+    )
+    return first_term + second_term
+
+
+class _PoissonReg:
+    """
+    Poisson regression model.
+
+    Attributes
+    ----------
+    beta : torch.Tensor
+        The learned regression coefficients.
+
+    Methods
+    -------
+    fit(Y, covariates, O, Niter_max=300, tol=0.001, lr=0.005, verbose=False)
+        Fit the Poisson regression model to the given data.
+
+    """
+
+    def __init__(self) -> None:
+        self.beta: Optional[torch.Tensor] = None
+
+    def fit(
+        self,
+        Y: torch.Tensor,
+        covariates: torch.Tensor,
+        offsets: torch.Tensor,
+        Niter_max: int = 300,
+        tol: float = 0.001,
+        lr: float = 0.005,
+        verbose: bool = False,
+    ) -> None:
+        """
+        Fit the Poisson regression model to the given data.
+
+        Parameters
+        ----------
+        Y : torch.Tensor
+            The dependent variable of shape (n_samples, n_features).
+        covariates : torch.Tensor
+            The covariates of shape (n_samples, n_covariates).
+        offsets : torch.Tensor
+            The offset term of shape (n_samples, n_features).
+        Niter_max : int, optional
+            The maximum number of iterations (default is 300).
+        tol : float, optional
+            The tolerance for convergence (default is 0.001).
+        lr : float, optional
+            The learning rate (default is 0.005).
+        verbose : bool, optional
+            Whether to print intermediate information during fitting (default is False).
+
+        """
+        beta = torch.rand(
+            (covariates.shape[1], Y.shape[1]), device=DEVICE, requires_grad=True
+        )
+        optimizer = torch.optim.Rprop([beta], lr=lr)
+        i = 0
+        grad_norm = 2 * tol  # Criterion
+        while i < Niter_max and grad_norm > tol:
+            loss = -compute_poissreg_log_like(Y, offsets, covariates, beta)
+            loss.backward()
+            optimizer.step()
+            grad_norm = torch.norm(beta.grad)
+            beta.grad.zero_()
+            i += 1
+            if verbose:
+                if i % 10 == 0:
+                    print("log like : ", -loss)
+                    print("grad_norm : ", grad_norm)
+                if i < Niter_max:
+                    print(f"Tolerance reached in {i} iterations")
+                else:
+                    print("Maximum number of iterations reached")
+        self.beta = beta
+
+
+def compute_poissreg_log_like(
+    Y: torch.Tensor, O: torch.Tensor, covariates: torch.Tensor, beta: torch.Tensor
+) -> torch.Tensor:
+    """
+    Compute the log likelihood of a Poisson regression model.
+
+    Parameters
+    ----------
+    Y : torch.Tensor
+        The dependent variable of shape (n_samples, n_features).
+    O : torch.Tensor
+        The offset term of shape (n_samples, n_features).
+    covariates : torch.Tensor
+        The covariates of shape (n_samples, n_covariates).
+    beta : torch.Tensor
+        The regression coefficients of shape (n_covariates, n_features).
+
+    Returns
+    -------
+    torch.Tensor
+        The log likelihood of the Poisson regression model.
+
+    """
+    XB = torch.matmul(covariates.unsqueeze(1), beta.unsqueeze(0)).squeeze()
+    return torch.sum(-torch.exp(O + XB) + torch.multiply(Y, O + XB))
-- 
GitLab