@justheuristic
Collaborator

@justheuristic justheuristic commented Aug 17, 2023

NB: this pull request makes several drastic changes to the backend, block_functions and pools. It might be better if I walk you through it before the review. On a related note, if it interferes with long-term plans for the codebase, please raise a concern - I'm happy to roll back any detrimental changes.

Why this exists:

  1. So that a user would be able to run
outputs = model.forward(hidden_states,
                        attention_mask=something,
                        layer_past=my_learned_deep_prompts,
                        position_ids=..., **anything_else)
loss(outputs).backward()
assert my_learned_deep_prompts[0][0].grad is not None

and expect that the outputs are the same

  2. So that we can integrate all peft tuners
    • LoRA: output_with_lora = internal_model_interface.forward(inputs, **lora_adapters)
    • prefix- and P-tuning: output = internal_model_interface.forward(inputs, layer_past=make_method_dependent_tensors())
    • IA3: output_with_ia3 = internal_model_interface.forward(inputs, **ia3_state_dict)
    • that other method they're gonna add in the future

What does this PR contain

New functionality

  • servers will now support arbitrary kwargs for transformers blocks
    • user may provide either one set of kwargs or a different set of kwargs for each block
  • servers can now backprop w.r.t. additional keyword args and return gradients
  • low-level client interface supports forwarding arbitrary kwargs
  • user-facing client interface supports passing kwargs directly
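
To make the "one set of kwargs or a different set for each block" option concrete, here is a minimal sketch of how a client could normalize both forms into one dict per block. The helper name `normalize_block_kwargs` and its signature are hypothetical and not part of this PR; plain Python values stand in for tensors.

```python
from typing import Any, Dict, List, Sequence, Union

BlockKwargs = Dict[str, Any]


def normalize_block_kwargs(
    block_kwargs: Union[BlockKwargs, Sequence[BlockKwargs]], num_blocks: int
) -> List[BlockKwargs]:
    """Return one kwargs dict per block (hypothetical helper, not the PR's API)."""
    if isinstance(block_kwargs, dict):
        # One shared set: every block receives a shallow copy of the same kwargs.
        return [dict(block_kwargs) for _ in range(num_blocks)]
    per_block = [dict(kw) for kw in block_kwargs]
    if len(per_block) != num_blocks:
        raise ValueError(f"expected {num_blocks} kwargs dicts, got {len(per_block)}")
    return per_block


shared = normalize_block_kwargs({"attention_mask": "mask"}, num_blocks=3)
distinct = normalize_block_kwargs([{"scale": 1}, {"scale": 2}], num_blocks=2)
```

Rejecting a per-block list of the wrong length up front keeps the mismatch from surfacing later as a confusing server-side error.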

Internal codebase changes:

  • RemoteSequenceManager.get_request_metadata now always accepts (server_id, protocol, block_uids, args, kwargs) in that order

    • This potentially breaks backward compatibility; gotta check with @borzunov on how many people it would affect
  • client-side code: packing args/kwargs and forming metadata was moved from sequential_autograd to remote_forward_backward

    • why: to reduce the number of low-level things (like serialization) that sequence-level code cares about
  • Task size is now specified explicitly in block_functions

  • Task and PrioritizedTaskPool support kwargs

    • note: if we (eventually) implement server-side batching, we can only batch queries if they have the same input schema anyway, and therefore this pull request does not make server-side batching any more complicated than it already is
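
The note on server-side batching can be illustrated with a small sketch: queries are only candidates for the same batch if their kwargs share a schema. The `schema_key` below is an assumption made for illustration (argument names plus Python type names); a real schema would presumably also cover shapes and dtypes.

```python
from collections import defaultdict
from typing import Any, Dict, Hashable, List


def schema_key(kwargs: Dict[str, Any]) -> Hashable:
    """Hypothetical schema: sorted (name, type name) pairs."""
    return tuple(sorted((name, type(value).__name__) for name, value in kwargs.items()))


def group_by_schema(queries: List[Dict[str, Any]]) -> Dict[Hashable, List[int]]:
    """Map each schema to the indices of the queries that could share a batch."""
    groups = defaultdict(list)
    for index, kwargs in enumerate(queries):
        groups[schema_key(kwargs)].append(index)
    return dict(groups)


queries = [
    {"attention_mask": [1, 1, 0]},
    {"attention_mask": [1, 0, 0]},
    {"attention_mask": [1, 1, 1], "layer_past": ([0.1], [0.2])},
]
groups = group_by_schema(queries)
```

Here the first two queries land in one group and the third, which carries an extra `layer_past`, lands in its own.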

Notable missing functionality

  • (implementation issue) _RemoteSequentialAutogradFunction can't split sub-batches with kwargs

    • problem: how do we know which kwargs to split across batch dimension
    • layer_past and attention_mask should be split across batch dimension
    • adapters and head_mask should not be split across batch dimension
    • possible convention: if something.shape[0] == batch_size, it should be split over dim 0
      • bad: if batch_size == LoRA rank and we split the adapters >.<
    • possible convention: user provides batch_kwargs and global_kwargs, manually specifying which kwarg is split
    • possible convention: split if kwargs_schema[key] is BatchTensorProto; replicate if it is just a TensorProto
  • (implementation issue) InferenceSession only accepts kwargs during its creation

    • to match the HF behavior, we should support kwargs during any inference steps
    • problem 1: if a given server fails and we find a replacement, how should we play back the previous kwargs?
      • what if different previous steps used different sets of kwargs?
      • possible solution: merge any two steps that have (1) matching kwargs up to IS and (2) no single kwarg has batch dimension
    • problem 2: how do we handle pushes (direct server-to-server communication) if next server requires kwargs?
      • possible solution: as a client, disable push for [this specific inference step] if the next server has non-empty kwargs
      • possible solution (2): implement "partial messages" so the client can broadcast all kwargs ahead of time, but the servers would still await
      • possible solution (3): implement a special inference type message that does NOT trigger inference step, but can change metadata/kwargs; this type of message is sent by client to all servers in parallel.
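
For problem 1 above, one way to picture the replay is a client-side history that records the inputs and kwargs of every inference step and re-sends them to a replacement server. This is a sketch under assumptions: `StepRecord` and `ReplayHistory` are hypothetical names, lists stand in for hidden states, and "matching kwargs" is read here as "the very same kwargs object", which is only one possible interpretation of the merge condition. The second merge condition from the list is not modelled.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List


@dataclass
class StepRecord:
    inputs: List[Any]  # stand-in for hidden states, concatenated along the sequence axis
    kwargs: Dict[str, Any]


@dataclass
class ReplayHistory:
    records: List[StepRecord] = field(default_factory=list)

    def append(self, inputs: List[Any], kwargs: Dict[str, Any]) -> None:
        last = self.records[-1] if self.records else None
        # Merge only when both steps were called with the very same kwargs object.
        if last is not None and last.kwargs is kwargs:
            last.inputs = last.inputs + list(inputs)
        else:
            self.records.append(StepRecord(list(inputs), kwargs))

    def replay(self, send_step: Callable[[List[Any], Dict[str, Any]], None]) -> None:
        # Re-send every recorded step, in order, to a replacement server.
        for record in self.records:
            send_step(record.inputs, record.kwargs)


shared_kwargs = {"attention_mask": "mask_a"}
history = ReplayHistory()
history.append(["tok0", "tok1"], shared_kwargs)
history.append(["tok2"], shared_kwargs)  # merged into the previous record
history.append(["tok3"], {"attention_mask": "mask_b"})  # different kwargs: new record

replayed = []
history.replay(lambda inputs, kwargs: replayed.append((inputs, kwargs)))
```

Steps with different kwargs stay as separate records, so the replacement server sees them as separate steps in the original order.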

Tests & sanity checks

Sanity checks:

  • no perceivable slowdown on gpu vs main
  • check that main notebooks work with NF4 servers
  • check that forward/backward partitioning works properly with extra kwargs
  • per-block extra args work for non-merged inference pools (NOT the same as using merge inference pool with a single block)
  • force re-balancing and check that pools are properly deleted (nothing would suggest otherwise; testing just in case)
  • forward-backward still works on GPU
  • check that inference (specifically) works on GPU when user specifies a different dtype (i.e. not the same as server's torch_dtype)
  • check old client vs new server
  • check old server vs new client
  • check memory leaks under training queries
  • check memory leaks under inference queries
  • test that inference correctly forwards kwargs in case one of servers forcibly fails midway through inference

CI tests

  • everything that used to work works again
  • forward and backward with attention_mask - check exact match
  • inference with attention mask chunking - check exact match
  • full model exact match with attention mask
  • block-level backward with non-differentiable kwargs
  • block-level backward with differentiable kwargs
  • find a CI test that splits forward/backward batch into sub-batches or write a new one

@justheuristic justheuristic marked this pull request as draft August 17, 2023 01:45
@justheuristic justheuristic changed the title from "Forward arbitrary kwargs" to "Forward arbitrary kwargs to remote blocks" Aug 17, 2023
@justheuristic justheuristic marked this pull request as ready for review August 22, 2023 13:03
@justheuristic
Collaborator Author

note 2 self: old client runs backward with inputs that do not require_grad, we must support that!

@justheuristic
Collaborator Author

justheuristic commented Sep 6, 2023

note 2self: on wake up, do

  • add args/kwargs partitioning in _RemoteSequentialAutogradFunction
    • strategy: if tensor.shape[0] == batch_size, split; otherwise, replicate
    • add some documentation on how that works

  • modify inference_session.py to save past args/kwargs and re-send them on server failure
    • from first step
    • from intermediate steps
    • Q: how do we merge _ServerInferenceSession.history if it uses different kwargs between steps

  • handle prefix length correctly if first input batch contains layer_past
    • in petals.server.handler
    • in petals.client.inference_session
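
The split-or-replicate strategy from the first to-do item can be sketched as follows. This is an illustration only: nested lists stand in for tensors (so `len(value)` plays the role of `tensor.shape[0]`), and `split_kwargs_by_shape` is a hypothetical name. The last kwarg shows the failure mode already raised in the PR description, where an adapter's leading dimension happens to equal the batch size.

```python
from typing import Any, Dict, List


def split_kwargs_by_shape(kwargs: Dict[str, Any], batch_size: int, sub_batch_size: int) -> List[Dict[str, Any]]:
    """Split values whose leading dimension equals batch_size; replicate everything else."""
    offsets = range(0, batch_size, sub_batch_size)
    parts: List[Dict[str, Any]] = [{} for _ in offsets]
    for key, value in kwargs.items():
        splittable = isinstance(value, list) and len(value) == batch_size
        for part, offset in zip(parts, offsets):
            part[key] = value[offset : offset + sub_batch_size] if splittable else value
    return parts


kwargs = {
    "attention_mask": [[1, 1], [1, 0], [1, 1], [0, 0]],  # leading dim 4 == batch size: split
    "head_mask": [1.0, 0.0],  # leading dim 2: replicated
    "lora_A": [[0.1], [0.2], [0.3], [0.4]],  # rank 4 == batch size: split by mistake
}
parts = split_kwargs_by_shape(kwargs, batch_size=4, sub_batch_size=2)
```

The heuristic needs no extra input from the user, but `lora_A` is silently cut in half here, so any documentation for it would need to call out that collision.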

    if attempt_no >= 1:
        _, backup_inputs, backup_sequences = await sequential_forward(
-           inputs, prompts, sequence_manager, start_index=span.start, end_index=span.end
+           sequence_manager, inputs, prompts, start_index=span.start, end_index=span.end
Collaborator Author

subjective matter: sequence_manager is the first parameter to most internal functions; can roll back if the reviewer disagrees.

value = value[:, offset : offset + max_chunk_length]
kwargs_chunk[key] = value
return kwargs_chunk

Collaborator Author

Note: this is a potential problem; not all tensors where shape[-2] == seq_len can be time-sliced.

Counter-example: a LoRA adapter might accidentally have its rank equal to sequence length
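
One way around the counter-example is to slice only the kwargs that are explicitly declared time-sliceable instead of inferring it from shapes. The sketch below is not what the PR implements: `TIME_SLICED_KEYS` and `chunk_kwargs` are hypothetical names, the choice of keys is an assumption, and nested `[batch][seq]` lists stand in for tensors.

```python
from typing import Any, Dict, FrozenSet

# Hypothetical allowlist; which kwargs are actually safe to slice is an assumption here.
TIME_SLICED_KEYS: FrozenSet[str] = frozenset({"attention_mask", "position_ids"})


def chunk_kwargs(kwargs: Dict[str, Any], offset: int, max_chunk_length: int) -> Dict[str, Any]:
    """Slice allowlisted [batch][seq] values along the sequence axis; pass the rest through."""
    chunk = {}
    for key, value in kwargs.items():
        if key in TIME_SLICED_KEYS:
            value = [row[offset : offset + max_chunk_length] for row in value]
        chunk[key] = value
    return chunk


seq_len = 4
kwargs = {
    "attention_mask": [[1, 1, 1, 0]],  # [batch=1][seq_len=4]
    "lora_A": [[0.1] * seq_len, [0.2] * seq_len],  # inner dimension happens to equal seq_len
}
first = chunk_kwargs(kwargs, offset=0, max_chunk_length=2)
```

The adapter passes through untouched even though one of its dimensions matches the sequence length; the cost is that new sliceable kwargs have to be registered by hand.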

    @staticmethod
-   def forward(ctx, inputs: torch.Tensor, prompts: torch.Tensor, sequence_manager: RemoteSequenceManager):
+   def forward(ctx, sequence_manager: RemoteSequenceManager, inputs: torch.Tensor, prompts: torch.Tensor):
        # TODO add kwargs here; figure out a way to split kwargs across servers
Collaborator Author

problem: how do we split args/kwargs into sub-batches?
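
As one possible answer, the explicit `batch_kwargs` / `global_kwargs` convention floated in the PR description could look roughly like the sketch below. The function name `iter_sub_batches` and its signature are hypothetical, and nested lists again stand in for tensors.

```python
from typing import Any, Dict, Iterator, Sequence


def iter_sub_batches(
    batch_kwargs: Dict[str, Sequence[Any]],
    global_kwargs: Dict[str, Any],
    batch_size: int,
    sub_batch_size: int,
) -> Iterator[Dict[str, Any]]:
    """Yield kwargs per sub-batch: batch_kwargs are sliced on dim 0, global_kwargs are replicated."""
    overlap = batch_kwargs.keys() & global_kwargs.keys()
    if overlap:
        raise ValueError(f"keys given as both batch and global kwargs: {sorted(overlap)}")
    for key, value in batch_kwargs.items():
        if len(value) != batch_size:
            raise ValueError(f"{key!r} has leading dimension {len(value)}, expected {batch_size}")
    for offset in range(0, batch_size, sub_batch_size):
        sliced = {key: value[offset : offset + sub_batch_size] for key, value in batch_kwargs.items()}
        yield {**sliced, **global_kwargs}


sub_batches = list(
    iter_sub_batches(
        batch_kwargs={"attention_mask": [[1, 1], [1, 0], [0, 0]]},
        global_kwargs={"lora_A": [[0.1], [0.2], [0.3]]},  # rank 3 == batch size, still replicated
        batch_size=3,
        sub_batch_size=2,
    )
)
```

Because the caller states the intent, a coincidence between adapter rank and batch size no longer matters; the trade-off is a more verbose user-facing call.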

# Conflicts:
#	src/petals/__init__.py
#	src/petals/client/inference_session.py
@dvmazur
Collaborator

dvmazur commented Dec 2, 2023

@justheuristic solemnly swears to

  • show a proof that forwarding kwargs works in a basic test that is easy to follow
  • show an example of how new clients can work with old servers, at least for all supported basic ops
  • show an example of how old clients can work with new servers

2 participants