Skip to content

Create a CleavedAutoModelForCausalLM to implement the early-exit mechanism from the Kangaroo paper #1

@nasheedyasin

Description

@nasheedyasin

Here is some pseudo-code:

class CleavedAutoModelForCausalLM(nn.Module):
    """Wrap a causal LM and split its layer stack into a lower and an upper half.

    The wrapped model's layers are cut at ``cleave_point``: the lower half
    (together with the embedding and the LM head) serves as an early-exit
    path, while the upper half can be run separately on hidden states that
    were already computed by the lower half.

    NOTE(review): both halves are wrapped in ``nn.Sequential``, which assumes
    each layer maps a single tensor to a single tensor. Decoder layers that
    need extra arguments (attention mask, position ids) or return tuples
    would need an explicit loop instead -- confirm against the target model.
    NOTE(review): no final normalisation is applied before ``lm_head`` in
    either path -- confirm whether the wrapped model expects one.
    """

    def __init__(self, model: "AutoModelForCausalLM", cleave_point: int):
        """Store ``model`` and slice its layers at ``cleave_point``.

        Args:
            model: Object exposing ``model.embed_tokens``, ``model.layers``
                and ``lm_head``.
            cleave_point: Slice index into ``model.model.layers``; layers
                before it form the lower half, the rest form the upper half.
        """
        super().__init__()
        self.model = model

        self.embed_tokens = self.model.model.embed_tokens

        # Break into lower and upper halves
        self.lower_half = nn.Sequential(*self.model.model.layers[:cleave_point])
        self.upper_half = nn.Sequential(*self.model.model.layers[cleave_point:])
        self.lm_head = self.model.lm_head

    def forward(self, x):
        """Early-exit pass: embed ``x``, run the lower half, apply ``lm_head``.

        Args:
            x: Input accepted by ``embed_tokens`` (presumably token ids --
                confirm against caller).

        Returns:
            The output of ``lm_head`` on the lower-half hidden states.
        """
        hidden_states = self.embed_tokens(x)
        # Feed the embeddings -- not the raw input -- through the lower layers.
        hidden_states = self.lower_half(hidden_states)

        return self.lm_head(hidden_states)

    def full_forward(self, x):
        """Run the complete, un-cleaved wrapped model on ``x``."""
        return self.model(x)

    def upper_forward(self, input_features):
        """Run the upper half on ``input_features`` and apply ``lm_head``.

        Args:
            input_features: Hidden states to continue from (presumably the
                lower half's output -- confirm against caller).
        """
        hidden_states = self.upper_half(input_features)
        return self.lm_head(hidden_states)

    def verify_step(self, input_features):
        """Verification pass; currently identical to ``upper_forward``."""
        return self.upper_forward(input_features)
    

Metadata

Metadata

Assignees

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions