# Pseudo-code sketch: a causal LM whose layer stack is split ("cleaved") in two.
class CleavedAutoModelForCausalLM(nn.Module):
    """Wrap a causal LM and expose its layer stack as two halves.

    The wrapped ``model`` is expected to provide ``model.model.embed_tokens``,
    ``model.model.layers`` (sliceable) and ``model.lm_head``; these are the
    only attributes read here.

    Layers ``[:cleave_point]`` become ``lower_half`` and layers
    ``[cleave_point:]`` become ``upper_half``. Both are ``nn.Sequential``
    containers over the *same* layer objects as the wrapped model, so no
    weights are copied.

    NOTE(review): ``nn.Sequential`` feeds each layer's return value straight
    into the next layer as its only argument. This assumes every layer maps
    a hidden-state tensor to a hidden-state tensor with no extra inputs
    (masks, position information) -- confirm against the real layer type.
    NOTE(review): no final normalisation is applied before ``lm_head`` in
    any path of this class; confirm whether the wrapped model has one that
    should be applied.
    """

    def __init__(self, model: "AutoModelForCausalLM", cleave_point: int):
        """Split ``model``'s layers at index ``cleave_point``.

        Args:
            model: The causal LM to wrap (see class docstring for the
                attributes it must expose).
            cleave_point: Index of the first layer that belongs to the
                upper half; standard slice semantics apply.
        """
        super().__init__()
        self.model = model
        self.embed_tokens = self.model.model.embed_tokens

        # Break into lower and upper halves (shared layer objects, no copy).
        layers = self.model.model.layers
        self.lower_half = nn.Sequential(*layers[:cleave_point])
        self.upper_half = nn.Sequential(*layers[cleave_point:])

        self.lm_head = self.model.lm_head

    def forward(self, x):
        """Embed ``x``, run the lower half only, and project with ``lm_head``.

        Args:
            x: Input accepted by ``embed_tokens`` (presumably token ids --
                confirm with callers).

        Returns:
            ``lm_head`` applied to the lower half's output.
        """
        hidden_states = self.embed_tokens(x)
        # Fix: feed the embeddings (not the raw input ``x``) into the layers;
        # previously the embedding result was computed and then discarded.
        hidden_states = self.lower_half(hidden_states)
        return self.lm_head(hidden_states)

    def full_forward(self, x):
        """Run the complete, un-cleaved wrapped model on ``x``."""
        return self.model(x)

    def upper_forward(self, input_features):
        """Run only the upper half on ``input_features``, then ``lm_head``.

        Args:
            input_features: Hidden states to continue from (presumably the
                output of ``lower_half`` -- confirm with callers).
        """
        hidden_states = self.upper_half(input_features)
        return self.lm_head(hidden_states)

    def verify_step(self, input_features):
        """Alias for :meth:`upper_forward`."""
        return self.upper_forward(input_features)