
tinker_cookbook.hyperparam_utils.get_lora_param_count

tinker_cookbook.hyperparam_utils.get_lora_param_count(model_name, lora_rank, train_mlp, train_attn, train_unembed)

Get the number of parameters in the LoRA adapter.

Mirrors the signature of ServiceClient.create_lora_training_client: given the same arguments, the returned count reflects exactly the submodules that the training client would adapt.

Parameters:

  • model_name (str) – Tinker base model identifier.
  • lora_rank (int) – Rank of the LoRA decomposition.
  • train_mlp (bool) – Whether MLP layers are LoRA-trained.
  • train_attn (bool) – Whether attention layers are LoRA-trained.
  • train_unembed (bool) – Whether the unembedding (LM head) is LoRA-trained.

Returns: Total trainable parameter count.
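The count follows from how LoRA works. Each adapted linear layer with weight W of shape d_out × d_in gains two low-rank factors, A (r × d_in) and B (d_out × r), which adds r · (d_in + d_out) trainable parameters. The sketch below applies that rule to a hypothetical dense decoder: Llama-style gated MLP, grouped-query attention, and an untied LM head. `DecoderConfig`, `sketch_lora_param_count`, and all dimensions are illustrative assumptions and do not come from tinker_cookbook. The real function may count differently, for example for mixture-of-experts models or layers this sketch does not model.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class DecoderConfig:
    """Hypothetical shape description of a dense decoder-only transformer."""
    hidden_size: int
    intermediate_size: int
    num_layers: int
    num_heads: int
    num_kv_heads: int
    head_dim: int
    vocab_size: int


def lora_linear_params(d_in: int, d_out: int, rank: int) -> int:
    """LoRA adds W + B @ A with A: rank x d_in and B: d_out x rank,
    i.e. rank * (d_in + d_out) trainable parameters per adapted layer."""
    return rank * (d_in + d_out)


def sketch_lora_param_count(
    cfg: DecoderConfig,
    lora_rank: int,
    train_mlp: bool,
    train_attn: bool,
    train_unembed: bool,
) -> int:
    """Illustrative count only; assumes a Llama-style dense architecture."""
    d = cfg.hidden_size
    q_dim = cfg.num_heads * cfg.head_dim
    kv_dim = cfg.num_kv_heads * cfg.head_dim  # smaller than q_dim under GQA

    per_layer = 0
    if train_attn:
        # Query, key, value, and output projections.
        per_layer += lora_linear_params(d, q_dim, lora_rank)
        per_layer += 2 * lora_linear_params(d, kv_dim, lora_rank)
        per_layer += lora_linear_params(q_dim, d, lora_rank)
    if train_mlp:
        # Gated MLP: gate and up projections, then the down projection.
        per_layer += 2 * lora_linear_params(d, cfg.intermediate_size, lora_rank)
        per_layer += lora_linear_params(cfg.intermediate_size, d, lora_rank)

    total = cfg.num_layers * per_layer
    if train_unembed:
        # LM head maps hidden states to vocabulary logits.
        total += lora_linear_params(d, cfg.vocab_size, lora_rank)
    return total


if __name__ == "__main__":
    toy = DecoderConfig(
        hidden_size=64, intermediate_size=256, num_layers=2,
        num_heads=4, num_kv_heads=2, head_dim=16, vocab_size=1000,
    )
    print(sketch_lora_param_count(toy, 8, True, True, True))  # prints 31040
```

Because every term is linear in the rank, doubling `lora_rank` doubles the count. Likewise, each boolean flag contributes an independent term, so the three flags can be toggled to see how much each submodule family adds.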