Training function

celltypist.train(X=None, labels: str | list | tuple | ndarray | Series | Index | None = None, genes: str | list | tuple | ndarray | Series | Index | None = None, transpose_input: bool = False, with_mean: bool = True, check_expression: bool = True, C: float = 1.0, solver: str | None = None, max_iter: int | None = None, n_jobs: int | None = None, use_SGD: bool = False, alpha: float = 0.0001, use_GPU: bool = False, mini_batch: bool = False, batch_number: int = 100, batch_size: int = 1000, epochs: int = 10, balance_cell_type: bool = False, feature_selection: bool = False, top_genes: int = 300, date: str = '', details: str = '', url: str = '', source: str = '', version: str = '', **kwargs) → Model

Train a celltypist model using a logistic classifier fitted with either a global solver or stochastic gradient descent (SGD) learning, the latter optionally with mini-batch training.
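As a reading aid, the sketch below restates how the use_SGD, use_GPU and mini_batch flags documented under Parameters select among these training modes. The helper names estimator_for and uses_mini_batch are hypothetical and not part of celltypist; the mapping only mirrors the parameter descriptions and is not the library's own code.

```python
def estimator_for(use_SGD: bool = False, use_GPU: bool = False) -> str:
    """Name of the estimator that receives **kwargs, per the parameter notes.

    Hypothetical helper for illustration; not a celltypist function.
    """
    if use_SGD:
        # use_GPU is documented as ignored when SGD learning is enabled.
        return "SGDClassifier"
    return "cuml.LogisticRegression" if use_GPU else "LogisticRegression"


def uses_mini_batch(use_SGD: bool = False, mini_batch: bool = False) -> bool:
    """mini_batch is documented as ignored unless use_SGD is True."""
    return use_SGD and mini_batch


print(estimator_for())                            # default mode
print(estimator_for(use_GPU=True))                # GPU logistic classifier
print(estimator_for(use_SGD=True, use_GPU=True))  # SGD takes precedence
print(uses_mini_batch(use_SGD=False, mini_batch=True))
```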

Parameters:
  • X – Path to the input count matrix (supported types are csv, txt, tsv, tab and mtx) or AnnData file (h5ad). Also accepts an AnnData object or any array-like object already loaded in memory. See check_expression for detailed format requirements. A cell-by-gene format is expected (see transpose_input for more information).

  • labels – Path to the file containing one cell type label per line, corresponding to the cells in X. Also accepts any list-like object already loaded in memory (such as an array). If X is specified as an AnnData, this argument can also be set to a column name in the cell metadata.

  • genes – Path to the file containing one gene per line, corresponding to the genes in X. Also accepts any list-like object already loaded in memory (such as an array). Note that genes are extracted from X where possible (e.g., when X is an AnnData or data frame).

  • transpose_input – Whether to transpose the input matrix. Set to True if X is provided in a gene-by-cell format. (Default: False)

  • with_mean – Whether to subtract the mean values during data scaling. Setting to False can lower the memory usage when the input is a sparse matrix but may slightly reduce the model performance. (Default: True)

  • check_expression – Whether to check that the expression matrix in the input data is supplied in the required format. Except when a path to a raw count table file is specified, all other inputs for X should be log1p-normalized expression scaled to 10,000 counts per cell. Set to False to train on the data regardless of its expression format. (Default: True)

  • C – Inverse of L2 regularization strength for the traditional logistic classifier. A smaller value may improve model generalization at the cost of decreased accuracy. This argument is ignored if SGD learning is enabled (use_SGD = True). (Default: 1.0)

  • solver – Algorithm to use in the optimization problem for traditional logistic classifier. The default behavior is to choose the solver according to the size of the input data. This argument is ignored if SGD learning is enabled (use_SGD = True).

  • max_iter – Maximum number of iterations before reaching the minimum of the cost function. Try decreasing max_iter if the cost function does not converge for a long time. This argument applies to both traditional and SGD logistic classifiers, and is ignored if mini-batch SGD training is conducted (use_SGD = True and mini_batch = True). Defaults to 200, 500, and 1000 for large (>500k cells), medium (50-500k), and small (<50k) datasets, respectively.

  • n_jobs – Number of CPUs used. Defaults to one CPU. -1 means all CPUs are used. This argument applies to both traditional and SGD logistic classifiers.

  • use_SGD – Whether to implement SGD learning for the logistic classifier. (Default: False)

  • alpha – L2 regularization strength for the SGD logistic classifier. A larger value may improve model generalization at the cost of decreased accuracy. This argument is ignored if SGD learning is disabled (use_SGD = False). (Default: 0.0001)

  • use_GPU – Whether to use GPU for logistic classifier. This argument is ignored if SGD learning is enabled (use_SGD = True). (Default: False)

  • mini_batch – Whether to implement mini-batch training for the SGD logistic classifier. Setting to True may improve the training efficiency for large datasets (for example, >100k cells). This argument is ignored if SGD learning is disabled (use_SGD = False). (Default: False)

  • batch_number – The number of batches used for training in each epoch. Each batch contains batch_size cells. For datasets which cannot be binned into batch_number batches, all batches will be used. This argument is relevant only if mini-batch SGD training is conducted (use_SGD = True and mini_batch = True). (Default: 100)

  • batch_size – The number of cells within each batch. This argument is relevant only if mini-batch SGD training is conducted (use_SGD = True and mini_batch = True). (Default: 1000)

  • epochs – The number of epochs for the mini-batch training procedure. The default values of batch_number, batch_size, and epochs together allow observing ~10^6 training cells. This argument is relevant only if mini-batch SGD training is conducted (use_SGD = True and mini_batch = True). (Default: 10)

  • balance_cell_type – Whether to balance the cell type frequencies in mini-batches during each epoch. Setting to True will sample rare cell types with a higher probability, ensuring close-to-even cell type distributions in mini-batches. This argument is relevant only if mini-batch SGD training is conducted (use_SGD = True and mini_batch = True). (Default: False)

  • feature_selection – Whether to perform two-pass data training where the first round is used for selecting important features/genes using SGD learning. If True, the training time will be longer. (Default: False)

  • top_genes – The number of top genes selected from each class/cell-type based on their absolute regression coefficients. The final feature set is combined across all classes (i.e., union). (Default: 300)

  • date – Free text of the date of the model. Defaults to the time when the training is completed.

  • details – Free text of the description of the model.

  • url – Free text of the (possible) download url of the model.

  • source – Free text of the source (publication, database, etc.) of the model.

  • version – Free text of the version of the model.

  • **kwargs – Other keyword arguments passed to LogisticRegression (use_SGD = False and use_GPU = False), cuml.LogisticRegression (use_SGD = False and use_GPU = True), or SGDClassifier (use_SGD = True).
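
The size-dependent max_iter defaults and the mini-batch training budget described above can be checked with a little arithmetic. The sketch below is illustrative only: default_max_iter and cells_observed are hypothetical names, not celltypist functions, and the handling of datasets with exactly 50k or 500k cells is an assumption read off the documented ranges.

```python
def default_max_iter(n_cells: int) -> int:
    """Documented max_iter default for a dataset of n_cells cells.

    Hypothetical helper; boundary handling at exactly 50k and 500k
    cells is an assumption based on the documented ranges.
    """
    if n_cells > 500_000:   # large datasets
        return 200
    if n_cells >= 50_000:   # medium datasets (50-500k)
        return 500
    return 1000             # small datasets (<50k)


def cells_observed(batch_number: int = 100, batch_size: int = 1000, epochs: int = 10) -> int:
    """Number of training cells seen across all epochs of mini-batch SGD."""
    return batch_number * batch_size * epochs


print(default_max_iter(20_000))
print(default_max_iter(200_000))
print(default_max_iter(1_000_000))
print(cells_observed())  # defaults give about 10^6 cells
```

With the default batch_number, batch_size and epochs, the product is 100 × 1000 × 10 = 1,000,000, which matches the ~10^6 training cells stated for epochs.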

Returns:

An instance of the Model trained by celltypist.

Return type:

Model
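
A minimal usage sketch, assuming celltypist is installed and adata is an in-memory AnnData object holding log1p-normalized expression (10,000 counts per cell). The wrapper name train_minibatch_sgd and the label column name "cell_type" are hypothetical; only the keyword arguments come from the signature above.

```python
# Keyword arguments for mini-batch SGD training, taken from the signature above.
MINI_BATCH_KWARGS = dict(
    use_SGD=True,            # enable SGD learning
    mini_batch=True,         # train on mini-batches (only used with use_SGD=True)
    batch_number=100,        # batches per epoch (documented default)
    batch_size=1000,         # cells per batch (documented default)
    epochs=10,               # documented default
    balance_cell_type=True,  # sample rare cell types with higher probability
)


def train_minibatch_sgd(adata, label_column="cell_type"):
    """Train a model from an AnnData whose cell metadata has `label_column`.

    Hypothetical wrapper for illustration; "cell_type" is an assumed column name.
    """
    import celltypist  # imported lazily so this sketch loads without the package

    return celltypist.train(adata, labels=label_column, **MINI_BATCH_KWARGS)
```

Calling train_minibatch_sgd(adata) would return the trained Model instance described under Returns.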