Skip to content

Commit 11d5e08

Browse files
Merge pull request #2684 from AI-Hypercomputer:sft_metrics
PiperOrigin-RevId: 832084340
2 parents 60345b6 + 26de309 commit 11d5e08

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

src/MaxText/sft/sft_trainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646

4747
from orbax import checkpoint as ocp
4848

49-
from tunix.sft import peft_trainer, profiler
49+
from tunix.sft import metrics_logger, peft_trainer, profiler
5050

5151
from MaxText import max_utils
5252
from MaxText import max_logging
@@ -80,7 +80,7 @@ def get_tunix_config(mt_config):
8080
)
8181

8282
# Metrics configurations
83-
metrics_logging_options = peft_trainer.metrics_logger.MetricsLoggerOptions(log_dir=mt_config.tensorboard_dir)
83+
metrics_logging_options = metrics_logger.MetricsLoggerOptions(log_dir=mt_config.tensorboard_dir)
8484

8585
# Profiler configurations
8686
profiler_options = None

src/MaxText/utils/ckpt_conversion/to_maxtext.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@
6161
import threading
6262
from functools import partial
6363
from typing import Sequence, List, Any, Callable
64-
from MaxText.inference_utils import str2bool
6564
import numpy as np
6665
import jax
6766
import psutil
@@ -80,6 +79,7 @@
8079
from MaxText import maxtext_utils
8180
from MaxText import pyconfig
8281
from MaxText.common_types import MODEL_MODE_TRAIN
82+
from MaxText.inference_utils import str2bool
8383
from MaxText.layers import models, quantizations
8484
from MaxText.checkpointing import save_checkpoint
8585
from MaxText.utils.ckpt_conversion.utils.param_mapping import HOOK_FNS, PARAM_MAPPING

0 commit comments

Comments
 (0)