TrainingClient
class tinker.TrainingClient(holder, model_seq_id, model_id)
Client for training ML models with forward/backward passes and optimization.
The TrainingClient corresponds to a fine-tuned model that you can train and sample from.
You typically get one by calling service_client.create_lora_training_client().
Key methods:
- forward_backward() - compute gradients for training
- optim_step() - update model parameters with Adam optimizer
- save_weights_and_get_sampling_client() - export trained model for inference
import tinker
from tinker import types

service_client = tinker.ServiceClient()
training_client = service_client.create_lora_training_client(base_model="Qwen/Qwen3-8B")
# training_data is a List[types.Datum] (see forward_backward below)
fwdbwd_future = training_client.forward_backward(training_data, "cross_entropy")
optim_future = training_client.optim_step(types.AdamParams(learning_rate=1e-4))
fwdbwd_result = fwdbwd_future.result()  # Wait for gradients
optim_result = optim_future.result()  # Wait for parameter update
sampling_client = training_client.save_weights_and_get_sampling_client()
Parameters:
- holder (InternalClientHolder) – Internal client managing HTTP connections and async operations
- model_seq_id (int)
- model_id (types.ModelID) – Unique identifier for the model to train. Required for training operations.
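The call sequence in the class example can be wrapped into a loop. The sketch below relies only on what this section shows: forward_backward and optim_step return futures with a .result() method, and both requests are submitted before either is waited on. The train_loop function and the FakeTrainingClient stand-in are hypothetical helpers for illustration, not part of the tinker package; with a real client you would pass types.AdamParams(...) as adam_params.

```python
from concurrent.futures import Future
from typing import Any, Iterable, List


def train_loop(training_client: Any, batches: Iterable[list], adam_params: Any,
               loss_fn: str = "cross_entropy") -> List[Any]:
    """Hypothetical helper: one forward_backward + optim_step per batch.

    Both requests are submitted before either result is waited on,
    mirroring the pattern in the class example.
    """
    results = []
    for batch in batches:
        fwdbwd_future = training_client.forward_backward(batch, loss_fn)
        optim_future = training_client.optim_step(adam_params)
        results.append(fwdbwd_future.result())  # wait for gradients
        optim_future.result()  # wait for the parameter update
    return results


class FakeTrainingClient:
    """Hypothetical stand-in that records calls and returns resolved futures."""

    def __init__(self) -> None:
        self.calls: List[str] = []

    @staticmethod
    def _resolved(value: Any) -> Future:
        future: Future = Future()
        future.set_result(value)
        return future

    def forward_backward(self, data: list, loss_fn: str) -> Future:
        self.calls.append("forward_backward")
        return self._resolved({"num_examples": len(data), "loss_fn": loss_fn})

    def optim_step(self, adam_params: Any) -> Future:
        self.calls.append("optim_step")
        return self._resolved(None)


client = FakeTrainingClient()
outputs = train_loop(client, [["a", "b"], ["c"]], adam_params=None)
print(client.calls)
print(outputs)
```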
forward(data, loss_fn, loss_fn_config=None)
Compute forward pass without gradients.
Parameters:
- data (List[types.Datum]) – List of training data samples
- loss_fn (types.LossFnType) – Loss function type (e.g., "cross_entropy")
- loss_fn_config (Dict[str, float] | None, default: None) – Optional configuration for the loss function
Returns: APIFuture containing the forward pass outputs and loss
data = [types.Datum(
model_input=types.ModelInput.from_ints(tokenizer.encode("Hello")),
loss_fn_inputs={"target_tokens": types.ModelInput.from_ints(tokenizer.encode("world"))}
)]
future = training_client.forward(data, "cross_entropy")
result = await future
print(f"Loss: {result.loss}")
Async variant: forward_async()
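forward is useful for computing an evaluation loss without touching gradients. As a mental model of what the "cross_entropy" loss measures, the sketch below assumes it is the weighted sum of negative log-probabilities of the target tokens, with one weight per position (0 to ignore it, 1 to count it). Treat that formula, the per-token weights, and the function names as assumptions to check against the loss-function documentation.

```python
import math
from typing import Sequence


def weighted_cross_entropy(target_logprobs: Sequence[float], weights: Sequence[float]) -> float:
    """Sum of -logprob * weight over positions (assumed cross_entropy form)."""
    if len(target_logprobs) != len(weights):
        raise ValueError("need one weight per target token")
    return -sum(lp * w for lp, w in zip(target_logprobs, weights))


def mean_token_loss(target_logprobs: Sequence[float], weights: Sequence[float]) -> float:
    """Normalize the summed loss by the total weight (0.0 if nothing is weighted)."""
    total_weight = sum(weights)
    if not total_weight:
        return 0.0
    return weighted_cross_entropy(target_logprobs, weights) / total_weight


# The model assigned probability 0.5 and 0.25 to the two target tokens.
logprobs = [math.log(0.5), math.log(0.25)]
print(weighted_cross_entropy(logprobs, [1.0, 1.0]))  # ln 8
print(weighted_cross_entropy(logprobs, [0.0, 1.0]))  # ln 4, first position masked
print(mean_token_loss(logprobs, [1.0, 1.0]))         # ln 8 / 2
```

Whether a reported loss is a sum or a per-token mean is not stated in this section, so the sketch computes both.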
forward_backward(data, loss_fn, loss_fn_config=None)
Compute forward pass and backward pass to calculate gradients.
Parameters:
- data (List[types.Datum]) – List of training data samples
- loss_fn (types.LossFnType) – Loss function type (e.g., "cross_entropy")
- loss_fn_config (Dict[str, float] | None, default: None) – Optional configuration for the loss function
Returns: APIFuture containing the forward/backward outputs, loss, and gradients
data = [types.Datum(
model_input=types.ModelInput.from_ints(tokenizer.encode("Hello")),
loss_fn_inputs={"target_tokens": types.ModelInput.from_ints(tokenizer.encode("world"))}
)]
# Compute gradients
fwdbwd_future = training_client.forward_backward(data, "cross_entropy")
# Update parameters
optim_future = training_client.optim_step(
types.AdamParams(learning_rate=1e-4)
)
fwdbwd_result = await fwdbwd_future
print(f"Loss: {fwdbwd_result.loss}")
Async variant: forward_backward_async()
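For next-token training, each datum usually pairs a token sequence with the same sequence shifted left by one, plus a per-position weight that masks out the prompt. The helper below builds those three lists from already-tokenized prompt and completion ids. It is a hypothetical sketch: it assumes cross_entropy expects exactly one target token per input position, and how target_tokens and weights are packed into loss_fn_inputs should be checked against the types.Datum reference.

```python
from typing import List, Tuple


def make_next_token_example(
    prompt_tokens: List[int], completion_tokens: List[int]
) -> Tuple[List[int], List[int], List[float]]:
    """Return (input_tokens, target_tokens, weights) for next-token prediction."""
    tokens = prompt_tokens + completion_tokens
    # Weight 0 for prompt tokens, 1 for completion tokens (the ones we train on).
    token_weights = [0.0] * len(prompt_tokens) + [1.0] * len(completion_tokens)
    input_tokens = tokens[:-1]    # the model sees tokens 0..n-2
    target_tokens = tokens[1:]    # and predicts tokens 1..n-1
    weights = token_weights[1:]   # each weight belongs to the token being predicted
    return input_tokens, target_tokens, weights


inputs, targets, weights = make_next_token_example([101, 102], [7, 8])
print(inputs)   # [101, 102, 7]
print(targets)  # [102, 7, 8]
print(weights)  # [0.0, 1.0, 1.0]
```

The input_tokens list would become the model_input (for example via types.ModelInput.from_ints), and the shift guarantees inputs and targets have the same length.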
forward_backward_custom(data, loss_fn, loss_type_input='logprobs')
Compute forward/backward with a custom loss function.
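Allows you to define custom loss functions that operate on log probabilities. The custom function receives the data and a list of logprob tensors (one per datum) and returns a loss together with a metrics dict; gradients are then computed from that loss.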
Parameters:
- data (List[types.Datum]) – List of training data samples
- loss_fn (CustomLossFnV1) – Custom loss function that takes (data, logprobs) and returns (loss, metrics)
- loss_type_input (Literal['logprobs'], default: 'logprobs') – Input space for loss_fn. Currently the only supported value is "logprobs".
Returns: APIFuture containing the forward/backward outputs with custom loss
import torch

def custom_loss(data, logprobs_list):
    # Negative mean log-probability: minimizing it raises the likelihood of the targets
    loss = -torch.mean(torch.stack([torch.mean(lp) for lp in logprobs_list]))
    metrics = {"custom_metric": loss.item()}
    return loss, metrics
future = training_client.forward_backward_custom(data, custom_loss)
result = future.result()
print(f"Custom loss: {result.loss}")
print(f"Metrics: {result.metrics}")
Async variant: forward_backward_custom_async()
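Whatever loss the custom function returns has to be differentiable with respect to the logprobs. To make that concrete without PyTorch, the sketch below takes one simple choice, the negative mean over sequences of each sequence's mean log-probability, writes out its gradient by hand, and checks it with finite differences. It illustrates the math only; how forward_backward_custom propagates these gradients to the model is not covered here, and the function names are hypothetical.

```python
from typing import List


def neg_mean_logprob(logprobs_list: List[List[float]]) -> float:
    """Loss = -(1/N) * sum_i mean_t(logprobs[i][t])."""
    n = len(logprobs_list)
    return -sum(sum(lp) / len(lp) for lp in logprobs_list) / n


def neg_mean_logprob_grad(logprobs_list: List[List[float]]) -> List[List[float]]:
    """Analytic gradient: dLoss/dlogprobs[i][t] = -1 / (N * T_i)."""
    n = len(logprobs_list)
    return [[-1.0 / (n * len(lp))] * len(lp) for lp in logprobs_list]


def finite_difference_grad(logprobs_list: List[List[float]], eps: float = 1e-6) -> List[List[float]]:
    """Central differences, perturbing one logprob at a time."""
    grads = []
    for i, seq in enumerate(logprobs_list):
        row = []
        for t in range(len(seq)):
            plus = [list(s) for s in logprobs_list]
            minus = [list(s) for s in logprobs_list]
            plus[i][t] += eps
            minus[i][t] -= eps
            row.append((neg_mean_logprob(plus) - neg_mean_logprob(minus)) / (2 * eps))
        grads.append(row)
    return grads


logprobs = [[-0.5, -1.0], [-2.0, -0.1, -0.3]]
print(neg_mean_logprob(logprobs))
print(neg_mean_logprob_grad(logprobs))  # -0.25 per token in sequence 1, -1/6 in sequence 2
print(finite_difference_grad(logprobs))
```

Every gradient entry is negative, so a descent step raises every target log-probability; shorter sequences get larger per-token gradients because each sequence is averaged before the batch mean.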
optim_step(adam_params)
Update model parameters using Adam optimizer.
The Adam optimizer used by Tinker is identical to torch.optim.AdamW. Note that unlike PyTorch, Tinker's default weight decay value is 0.0 (no weight decay).
Parameters:
- adam_params (types.AdamParams) – Adam optimizer parameters (learning_rate, betas, eps, weight_decay)
Returns: APIFuture containing optimizer step response
# First compute gradients
fwdbwd_future = training_client.forward_backward(data, "cross_entropy")
# Then update parameters
optim_future = training_client.optim_step(
types.AdamParams(
learning_rate=1e-4,
weight_decay=0.01
)
)
# Wait for both to complete
fwdbwd_result = await fwdbwd_future
optim_result = await optim_future
Async variant: optim_step_async()
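Since the optimizer is described as identical to torch.optim.AdamW, a single scalar AdamW update can be written out by hand to show how learning_rate, betas, eps and weight_decay interact. This follows the PyTorch AdamW algorithm (decoupled weight decay, bias-corrected moments); the betas and eps defaults below are PyTorch's, not necessarily Tinker's, and adamw_step is a hypothetical illustration.

```python
import math
from dataclasses import dataclass


@dataclass
class AdamState:
    m: float = 0.0  # first-moment (momentum) estimate
    v: float = 0.0  # second-moment estimate
    t: int = 0      # number of steps taken so far


def adamw_step(param: float, grad: float, state: AdamState, lr: float,
               betas=(0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.0) -> float:
    """Return the updated parameter; mutates `state` in place."""
    beta1, beta2 = betas
    state.t += 1
    # Decoupled weight decay: shrink the parameter directly, not via the gradient.
    param *= 1.0 - lr * weight_decay
    state.m = beta1 * state.m + (1.0 - beta1) * grad
    state.v = beta2 * state.v + (1.0 - beta2) * grad * grad
    m_hat = state.m / (1.0 - beta1 ** state.t)  # bias correction
    v_hat = state.v / (1.0 - beta2 ** state.t)
    return param - lr * m_hat / (math.sqrt(v_hat) + eps)


# First step: bias correction makes the update roughly lr * sign(grad).
print(adamw_step(1.0, 0.5, AdamState(), lr=0.1))                     # about 0.9
print(adamw_step(1.0, 0.5, AdamState(), lr=0.1, weight_decay=0.01))  # about 0.899
```

With weight_decay=0.0, which the text above gives as Tinker's default, the decay line leaves the parameter unchanged.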
save_state(name, ttl_seconds=None, overwrite=False)
Save model weights to persistent storage.
Parameters:
- name (str) – Name for the saved checkpoint
- ttl_seconds (int | None, default: None) – Optional TTL in seconds for the checkpoint (None = never expires)
- overwrite (bool, default: False) – If True, overwrite any existing checkpoint with the same name
Returns: APIFuture containing the save response with checkpoint path
# Save after training
save_future = training_client.save_state("checkpoint-001")
result = await save_future
print(f"Saved to: {result.path}")
Async variant: save_state_async()
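A common way to use save_state is to checkpoint every few steps and once at the end. The hypothetical helper below only decides when to save and what to call the checkpoint; zero-padding the step number keeps names sorting in step order. Passing each name to training_client.save_state(name) is left to the caller, and the naming scheme is an illustrative choice, not something Tinker requires.

```python
from typing import List, Optional


def checkpoint_name(step: int, prefix: str = "checkpoint", width: int = 6) -> str:
    """Zero-padded name, e.g. checkpoint-000100, so names sort by step."""
    return f"{prefix}-{step:0{width}d}"


def maybe_checkpoint(step: int, total_steps: int, every: int) -> Optional[str]:
    """Return a checkpoint name if this step should be saved, else None.

    Steps are 1-indexed: save every `every` steps and always at the last step.
    """
    if step % every == 0 or step == total_steps:
        return checkpoint_name(step)
    return None


candidates = [maybe_checkpoint(s, total_steps=25, every=10) for s in range(1, 26)]
names: List[str] = [n for n in candidates if n is not None]
print(names)  # ['checkpoint-000010', 'checkpoint-000020', 'checkpoint-000025']
```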
load_state(path, weights_access_token=None)
Load model weights from a saved checkpoint.
This loads only the model weights, not optimizer state (e.g., Adam momentum). To also restore optimizer state, use load_state_with_optimizer.
Parameters:
- path (str) – Tinker path to saved weights (e.g., "tinker://run-id/weights/checkpoint-001")
- weights_access_token (str | None, default: None) – Optional access token for loading checkpoints under a different account.
Returns: APIFuture containing the load response
# Load checkpoint to continue training (weights only, optimizer resets)
load_future = training_client.load_state("tinker://run-id/weights/checkpoint-001")
await load_future
# Continue training from loaded state
Async variant: load_state_async()
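The checkpoint paths in these examples have the shape tinker://<run-id>/weights/<name>. To pull the run id or checkpoint name back out of such a path (for logging, say), the standard library's URL parser is enough. This is a hypothetical helper that assumes that three-part layout, inferred from the example paths in this section; it does not check that the checkpoint exists.

```python
from typing import NamedTuple
from urllib.parse import urlparse


class TinkerPath(NamedTuple):
    run_id: str
    kind: str   # e.g. "weights" in the example paths
    name: str   # checkpoint name, e.g. "checkpoint-001"


def parse_tinker_path(path: str) -> TinkerPath:
    """Split tinker://<run-id>/<kind>/<name> into its parts (assumed layout)."""
    parsed = urlparse(path)
    if parsed.scheme != "tinker":
        raise ValueError(f"not a tinker:// path: {path!r}")
    parts = parsed.path.strip("/").split("/")
    if not parsed.netloc or len(parts) != 2:
        raise ValueError(f"expected tinker://<run-id>/<kind>/<name>, got {path!r}")
    return TinkerPath(parsed.netloc, parts[0], parts[1])


print(parse_tinker_path("tinker://run-id/weights/checkpoint-001"))
```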
load_state_with_optimizer(path, weights_access_token=None)
Load model weights and optimizer state from a checkpoint.
Parameters:
- path (str) – Tinker path to saved weights (e.g., "tinker://run-id/weights/checkpoint-001")
- weights_access_token (str | None, default: None) – Optional access token for loading checkpoints under a different account.
Returns: APIFuture containing the load response
# Resume training with optimizer state
load_future = training_client.load_state_with_optimizer(
"tinker://run-id/weights/checkpoint-001"
)
await load_future
# Continue training with restored optimizer momentum
Async variant: load_state_with_optimizer_async()
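The difference between load_state and load_state_with_optimizer shows up in the first update after resuming. With Adam-style optimizers, freshly reset moments plus bias correction make the first step roughly learning_rate times the sign of the gradient, regardless of the gradient's size, whereas restored moments continue the previous trajectory. The scalar sketch below (plain Adam without weight decay, PyTorch's default betas, and an invented saved state; adam_update is hypothetical) compares the two.

```python
import math
from typing import Tuple


def adam_update(grad: float, m: float, v: float, t: int, lr: float,
                beta1: float = 0.9, beta2: float = 0.999,
                eps: float = 1e-8) -> Tuple[float, float, float, int]:
    """Return (update_to_subtract, new_m, new_v, new_t) for one Adam step."""
    t += 1
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad * grad
    m_hat = m / (1 - beta1 ** t)  # bias-corrected first moment
    v_hat = v / (1 - beta2 ** t)  # bias-corrected second moment
    return lr * m_hat / (math.sqrt(v_hat) + eps), m, v, t


grad, lr = 0.5, 0.1
# Like load_state: optimizer moments start from zero.
fresh_update, *_ = adam_update(grad, m=0.0, v=0.0, t=0, lr=lr)
# Like load_state_with_optimizer: hypothetical moments saved after 100 steps.
restored_update, *_ = adam_update(grad, m=0.01, v=0.04, t=100, lr=lr)
print(f"fresh: {fresh_update:.4f}  restored: {restored_update:.4f}")
```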
save_weights_for_sampler(name, ttl_seconds=None)
Save model weights for use with a SamplingClient.
Parameters:
- name (str) – Name for the saved sampler weights
- ttl_seconds (int | None, default: None) – Optional TTL in seconds for the checkpoint (None = never expires)
Returns: APIFuture containing the save response with sampler path
# Save weights for inference
save_future = training_client.save_weights_for_sampler("sampler-001")
result = await save_future
print(f"Sampler weights saved to: {result.path}")
# Use the path to create a sampling client
sampling_client = service_client.create_sampling_client(
model_path=result.path
)
Async variant: save_weights_for_sampler_async()
get_info()
Get information about the current model.
Returns: GetInfoResponse with model configuration and metadata
info = training_client.get_info()
print(f"Model ID: {info.model_data.model_id}")
print(f"Base model: {info.model_data.model_name}")
print(f"LoRA rank: {info.model_data.lora_rank}")
Async variant: get_info_async()
get_tokenizer()
Get the tokenizer for the current model.
Returns: PreTrainedTokenizer compatible with the model
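Because the return value is a Hugging Face PreTrainedTokenizer, encode(text, add_special_tokens=...) is available. A typical use is turning a prompt and completion into token ids plus a mask of which tokens to train on. In the sketch, encode_prompt_completion is a hypothetical helper and WhitespaceTokenizer is a toy stand-in so the example runs without downloading a model; with Tinker you would pass training_client.get_tokenizer() instead. Encoding the completion with add_special_tokens=False avoids inserting a second BOS token mid-sequence for tokenizers that add one.

```python
from typing import Dict, List, Tuple


def encode_prompt_completion(tokenizer, prompt: str, completion: str) -> Tuple[List[int], List[int]]:
    """Return (token_ids, mask): mask is 1 for completion tokens, 0 for prompt tokens."""
    prompt_ids = tokenizer.encode(prompt)  # may include special tokens such as BOS
    completion_ids = tokenizer.encode(completion, add_special_tokens=False)
    mask = [0] * len(prompt_ids) + [1] * len(completion_ids)
    return prompt_ids + completion_ids, mask


class WhitespaceTokenizer:
    """Toy stand-in with an encode() shaped like the Hugging Face one."""

    def __init__(self) -> None:
        self.vocab: Dict[str, int] = {"<bos>": 0}

    def encode(self, text: str, add_special_tokens: bool = True) -> List[int]:
        ids = [0] if add_special_tokens else []  # prepend <bos>
        for word in text.split():
            ids.append(self.vocab.setdefault(word, len(self.vocab)))
        return ids


ids, mask = encode_prompt_completion(WhitespaceTokenizer(), "Hello", "big world")
print(ids)   # [0, 1, 2, 3]
print(mask)  # [0, 0, 1, 1]
```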
create_sampling_client(model_path, retry_config=None)
Create a SamplingClient from saved weights.
Parameters:
- model_path (str) – Tinker path to saved weights
- retry_config (RetryConfig | None, default: None) – Optional configuration for retrying failed requests
Returns: SamplingClient configured with the specified weights
sampling_client = training_client.create_sampling_client(
"tinker://run-id/weights/checkpoint-001"
)
# Use sampling_client for inference
Async variant: create_sampling_client_async()
save_weights_and_get_sampling_client(name=None, retry_config=None)
Save current weights and create a SamplingClient for inference.
Parameters:
- name (str | None, default: None) – Deprecated, has no effect. Will be removed in a future release.
- retry_config (RetryConfig | None, default: None) – Optional configuration for retrying failed requests
Returns: SamplingClient configured with the current model weights
# After training, create a sampling client directly
sampling_client = training_client.save_weights_and_get_sampling_client()
# Now use it for inference
tokenizer = training_client.get_tokenizer()
prompt = types.ModelInput.from_ints(tokenizer.encode("Hello"))
params = types.SamplingParams(max_tokens=20)
result = sampling_client.sample(prompt, 1, params).result()
Async variant: save_weights_and_get_sampling_client_async()