diff --git a/.github/workflows/cpu_ci.yml b/.github/workflows/cpu_ci.yml index 6910b8a1c..ca44a2c5c 100644 --- a/.github/workflows/cpu_ci.yml +++ b/.github/workflows/cpu_ci.yml @@ -12,7 +12,7 @@ jobs: - name: Install Python uses: actions/setup-python@v4 with: - python-version: "3.8" + python-version: "3.9" cache: "pip" cache-dependency-path: "**/requirements*.txt" diff --git a/.gitignore b/.gitignore index dbc83e949..36aa83489 100644 --- a/.gitignore +++ b/.gitignore @@ -155,3 +155,6 @@ src/ # test data files tests/data/*.bin tests/data/*.idx + +# evaluation results +*eval_results*.json diff --git a/configs/neox_arguments.md b/configs/neox_arguments.md index 231418025..e9376391d 100644 --- a/configs/neox_arguments.md +++ b/configs/neox_arguments.md @@ -14,19 +14,23 @@ LR Scheduler Arguments Learning rate decay function. Choose from 'constant', 'linear', 'cosine', 'exponential'. + - **lr_decay_iters**: int Default = None - Number of iterations to decay learning rate over. If None, defaults to - --train-iters or the equivalent inferred value from train_epochs. + Number of iterations to decay learning rate over, If None defaults to + --train-iters or the equivalent inferred valued from train_epochs. + + - **lr_decay_fraction**: float Default = None - Effective fraction of training over which to decay lr. Overrides lr_decay_iters. - Useful when specifying train_epochs. + Effective fraction of training over which to decay lr, overrides lr_decay_iters, useful when specifying train_epochs + + - **min_lr**: float @@ -82,6 +86,14 @@ Logging Arguments +- **wandb_run_name**: str + + Default = None + + Weights and Biases run name for the current experiment + + + - **wandb_team**: str Default = None @@ -116,7 +128,7 @@ Logging Arguments - **git_hash**: str - Default = 62c9738a + Default = bb881f3b current git hash of repository @@ -186,6 +198,22 @@ Logging Arguments +- **comet_experiment**: Any + + Default = None + + Initialized comet experiment object used to log data + + + +- **peak_theoretical_tflops**: float + + Default = None + + The peak hardware flops with which to compute MFU and HFU, in units of teraflops. Automatic detection is more trouble than it's worth, so this is left to the user. Helpful table listed at https://github.com/stas00/ml-engineering/tree/master/compute/accelerator#tflops-comparison-table + + + - **log_interval**: int Default = 100 @@ -215,8 +243,7 @@ Logging Arguments Default = False Log the frob norm of the gradients to wandb / tensorboard (useful for debugging). - (N.B - this will only work with pp = 0 for now, as we don't have access to the gradients of the model because - deepspeed.) + (N.B - this will only work with pp = 0 for now, as we don't have access to the gradients of the model because deepspeed.) @@ -272,8 +299,8 @@ Logging Arguments Default = False - Enable nsys profiling. When using this option, - nsys options should be specified in commandline. + Enable nsys and pytorch profiling. When using this option with nsys, + nsys options should be directly specified in commandline. An example nsys commandline is ``` nsys profile -s none -t nvtx,cuda -o @@ -402,11 +429,11 @@ Model Arguments -- **norm**: typing.Literal['layernorm', 'rmsnorm', 'scalenorm', 'te_rmsnorm', 'te_layernorm'] +- **norm**: typing.Literal['layernorm', 'rmsnorm', 'non_parametric_layernorm', 'scalenorm', 'te_rmsnorm', 'te_layernorm'] Default = layernorm - Normalization layer to use. Choose from "layernorm", "rmsnorm", "scalenorm", "te_rmsnorm", "te_layernorm". + Normalization layer to use. Choose from "layernorm", "rmsnorm", "non_parametric_layernorm", "scalenorm", "te_rmsnorm", "te_layernorm". @@ -843,6 +870,124 @@ Model Arguments +- **serve_model_weights**: bool + + Default = False + + If true, serve model weight pointers over a socket connection + + + +- **weight_server_port**: typing.Union[int, typing.List[int]] + + Default = 6000 + + Port(s) to serve model weights over + If an integer is provided, the port for each GPU will be 6000 + global rank + If a list is provided, the ports will be used in order, e.g. rank0 will be weight_server_port[0] + + + +- **online_dataserver_ips**: typing.Union[str, typing.List[str]] + + Default = localhost + + ip addresses to connect to for online data serving, defaults to localhost + + + +- **online_dataserver_ports**: typing.Union[int, typing.List[int]] + + Default = 10000 + + Port(s) to connect to for online data serving, defaults to 10000 + + + +- **te_columnparallel**: bool + + Default = False + + Use TransformerEngine for RowParallelLinear layer. + + + +- **te_rowparallel**: bool + + Default = False + + Use TransformerEngine for ColumnParallelLinear layer. + + + +- **te_layernorm_mlp**: bool + + Default = False + + Use TransformerEngine for LayerNormMLP layer. + + + +- **te_mha**: bool + + Default = False + + Use TransformerEngine for MultiheadAttention layer. + + + +- **te_fp8_format**: typing.Literal['e4m3', 'hybrid'] + + Default = hybrid + + Controls the FP8 data format used during forward and backward pass by TransformerEngine. + Hybrid uses E4M3 during forward pass, E5M2 during backward pass. + + + +- **te_fp8_wgrad**: bool + + Default = True + + When set to False, override FP8 config options and do the wgrad computation + in higher precision. + + + +- **te_fp8_amax_history_len**: int + + Default = 1 + + The length of the amax history window used for scaling factor computation. + + + +- **te_fp8_amax_compute_algo**: str + + Default = most_recent + + Algorithm used for choosing the `amax` value for the scaling factor computation. There are 2 + predefined choices: `max` chooses the largest `amax` in the history window, while `most_recent` + always chooses the most recently seen value. + + + +- **te_fp8_margin**: int + + Default = 0 + + Margin for the scaling factor computation. + + + +- **te_fp8_mha**: bool + + Default = False + + When set to True, use the FP8 implementation of Multi Head Attention. + + + - **dim_att**: int Default = None @@ -866,6 +1011,7 @@ Model Arguments Dimension of the feed-forward network for RWKV. If not set, calculated based on hidden_size and expansion_factor. + ## NeoXArgsOptimizer Optimizer Arguments @@ -1095,14 +1241,6 @@ Misc. Arguments -- **save_iters**: list - - Default = None - - Set during training - - - - **global_num_gpus**: int Default = None @@ -1307,6 +1445,14 @@ Text Generation arguments +- **eval_task_limit**: int + + Default = None + + Limit the number of examples per lm_eval_harness task + + + - **moe_top_k**: int Default = 1 @@ -1727,19 +1873,19 @@ Training Arguments -- **dataset_impl**: typing.Literal['gpt2', 'pairwise'] +- **dataset_impl**: typing.Literal['gpt2', 'pairwise', 'online'] Default = gpt2 - Dataset implementation, can be one of "gpt2" or "pairwise" + Dataset implementation, can be one of "gpt2", "pairwise", or "online" -- **train_impl**: typing.Literal['normal', 'dpo', 'rm', 'kto'] +- **train_impl**: typing.Literal['normal', 'dpo', 'rm', 'kto', 'reinforce'] Default = normal - Training implementation, can be one of "normal", "dpo", "kto", or "rm" + Training implementation, can be one of "normal", "dpo", "kto", "reinforce", or "rm" @@ -1791,6 +1937,16 @@ Training Arguments +- **z_loss**: float + + Default = 0.0 + + Z-loss parameter, only implemented for RM training currently. + https://arxiv.org/pdf/2204.02311 + https://arxiv.org/pdf/2309.10305 + + + - **kto_beta**: float Default = 0.1 @@ -1799,6 +1955,39 @@ Training Arguments +- **fp32_reinforce**: bool + + Default = True + + Whether to cast logits to fp32 for Reinforce loss calculation. + + + +- **kl_impl**: typing.Literal['abs', 'mse', 'kl', 'full'] + + Default = mse + + KL divergence implementation, can be one of "abs", "mse", "kl", or "full" + + + +- **kl_div_beta**: float + + Default = 0.1 + + Beta value for KL divergence in Reinforce loss calculation. + + + +- **reinforce_leave_one_out**: bool + + Default = False + + Whether to use reinforce leave one out for training + (from https://arxiv.org/abs/2402.14740 and https://api.semanticscholar.org/CorpusID:198489118) + + + - **allow_chopped**: bool Default = True @@ -1875,7 +2064,7 @@ Training Arguments -- **checkpoint_factor**: int +- **checkpoint_factor**: typing.Union[int, float] Default = None diff --git a/eval_tasks/eval_adapter.py b/eval_tasks/eval_adapter.py index abbd5ca8d..fda6c3b1d 100644 --- a/eval_tasks/eval_adapter.py +++ b/eval_tasks/eval_adapter.py @@ -17,6 +17,7 @@ import copy import os import sys +import itertools import dataclasses from functools import partial @@ -27,7 +28,10 @@ import torch import torch.nn.functional as F +from lm_eval.models.utils import chunks from lm_eval.models.huggingface import HFLM +from lm_eval.api.group import ConfigurableGroup +from lm_eval.loggers.utils import get_git_commit_hash from lm_eval import tasks, evaluator, utils, api from megatron.text_generation_utils import generate_samples_from_prompt from megatron import mpu @@ -219,7 +223,7 @@ def _collate(x): return (-len(toks), tuple(toks)) reord = utils.Reorderer(requests, _collate) - for chunk in utils.chunks( + for chunk in chunks( tqdm(reord.get_reordered(), disable=disable_tqdm), self.batch_size ): inps, contlens, inplens, padding_length = [], [], [], None @@ -412,7 +416,8 @@ def run_eval( ] # register all the default tasks bundled with lm-evaluation-harness repository - tasks.initialize_tasks() + task_manager = tasks.TaskManager() + task_manager.initialize_tasks() # Returns a list containing all values of the task registry that # match at least one of the patterns @@ -425,7 +430,8 @@ def pattern_match(patterns, source_list): task_names.add(matching) return list(task_names) - eval_tasks = pattern_match(eval_tasks, tasks.ALL_TASKS) + all_tasks = task_manager._all_tasks + eval_tasks = pattern_match(eval_tasks, all_tasks) print(f"Found tasks: {eval_tasks}") assert len(eval_tasks) > 0, "Must run at least one task" @@ -465,32 +471,45 @@ def pattern_match(patterns, source_list): # from simple_evaluate: # override fewshot values for all tasks we can for task_name in task_dict.keys(): - task_obj = task_dict[task_name] - if type(task_obj) == tuple: - group, task_obj = task_obj - if task_obj is None: - continue - - config = task_obj._config - - if num_fewshot is not None: - if config["num_fewshot"] == 0: - utils.eval_logger.info( - f"num_fewshot has been set to 0 for {task_name} in its config. Manual configuration will be ignored." - ) - else: - default_num_fewshot = config["num_fewshot"] - if not default_num_fewshot: - utils.eval_logger.warning( - f"Overwriting default num_fewshot of {task_name} from {default_num_fewshot} to {num_fewshot}" + group_task_objects = [] + top_level_task = task_dict[task_name] + if isinstance(task_name, ConfigurableGroup): + for task_group in list(task_dict[task_name].values()): + group_task_objects.extend(list(task_group.values())) + elif isinstance(task_name, str): + group_task_objects.append(top_level_task) + else: + raise ValueError( + "The task object is of an unhandled type. Unable to override fewshot values." + ) + + for task_obj in group_task_objects: + if type(task_obj) == tuple: + group, task_obj = task_obj + if task_obj is None: + continue + + config = task_obj._config + + utils.setup_logging() + if num_fewshot is not None: + if config["num_fewshot"] == 0: + utils.logging.info( + f"num_fewshot has been set to 0 for {config.task} in its config. Manual configuration will be ignored." ) + else: + default_num_fewshot = config["num_fewshot"] + if not default_num_fewshot: + utils.logging.warning( + f"Overwriting default num_fewshot of {config.task} from {default_num_fewshot} to {num_fewshot}" + ) - task_obj._config["num_fewshot"] = num_fewshot + task_obj._config["num_fewshot"] = num_fewshot results = evaluator.evaluate( lm=lm, task_dict=task_dict, - limit=10, # limit, + limit=limit, bootstrap_iters=bootstrap_iters, log_samples=False, ) @@ -504,12 +523,21 @@ def pattern_match(patterns, source_list): "limit": limit, "bootstrap_iters": bootstrap_iters, } - results["git_hash"] = utils.get_git_commit_hash() + results["git_hash"] = get_git_commit_hash() print(results.keys()) for task_name in task_dict.keys(): - if "alias" in results["results"][task_name]: - results["results"][task_name].pop("alias") + sub_task_names = [] + if isinstance(task_name, ConfigurableGroup): + task_groups = task_dict[task_name] + tasks_by_group = [list(group.keys()) for group in task_groups.values()] + sub_task_names = list(itertools.chain(*tasks_by_group)) + else: + sub_task_names.append(task_name) + + for sub_task in sub_task_names: + if "alias" in results["results"][sub_task]: + results["results"][sub_task].pop("alias") if was_training: self.model.train() @@ -535,4 +563,5 @@ def run_eval_harness( num_fewshot=num_fewshot, bootstrap_iters=bootstrap_iters, use_cache=False, + limit=neox_args.eval_task_limit, ) diff --git a/megatron/fused_kernels/type_shim.h b/megatron/fused_kernels/type_shim.h index 6b1c8927d..c8dbd1f10 100644 --- a/megatron/fused_kernels/type_shim.h +++ b/megatron/fused_kernels/type_shim.h @@ -277,7 +277,7 @@ reduce_block_into_lanes(T* x, final = x[tid] + x[tid + 32]; else final = val; - // __SYNCWARP(); + // __SYNCWARP(); #pragma unroll for (int i = 16; i >= lanes; i >>= 1) @@ -321,7 +321,7 @@ reduce_block_into_lanes_max_op(T* x, final = fmaxf(fabsf(x[tid]), fabsf(x[tid + 32])); else final = val; - // __SYNCWARP(); + // __SYNCWARP(); #pragma unroll for (int i = 16; i >= lanes; i >>= 1) diff --git a/megatron/model/transformer_engine.py b/megatron/model/transformer_engine.py index aa6558e99..1443df613 100644 --- a/megatron/model/transformer_engine.py +++ b/megatron/model/transformer_engine.py @@ -646,8 +646,8 @@ def __init__(self, neox_args): override_linear_precision = (False, False, not neox_args.te_fp8_wgrad) super().__init__( - margin=neox_args.fp8_margin, - fp8_format=te_fp8_format, + margin=neox_args.te_fp8_margin, + fp8_format=neox_args.te_fp8_format, amax_compute_algo=neox_args.te_fp8_amax_compute_algo, amax_history_len=neox_args.te_fp8_amax_history_len, override_linear_precision=override_linear_precision, diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py index d93cac6e3..64e7dd196 100644 --- a/megatron/neox_arguments/neox_args.py +++ b/megatron/neox_arguments/neox_args.py @@ -1605,6 +1605,11 @@ class NeoXArgsTextgen(NeoXArgsTemplate): NOTE: Requires internet connection """ + eval_task_limit: int = None + """ + Limit the number of examples per lm_eval_harness task + """ + moe_top_k: int = 1 """ Activate top K experts in MoE diff --git a/megatron/training.py b/megatron/training.py index 3def74860..1d6ca2fbd 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -1626,11 +1626,13 @@ def train( lr_scheduler=lr_scheduler, ) # Evaluation - if ( - neox_args.eval_interval - and iteration % neox_args.eval_interval == 0 - and neox_args.do_valid - ): + is_eval_internal = ( + neox_args.eval_interval and iteration % neox_args.eval_interval == 0 + ) + is_validation_configured = bool(neox_args.do_valid) or ( + isinstance(neox_args.eval_tasks, list) and len(neox_args.eval_tasks) > 0 + ) + if is_eval_internal and is_validation_configured: prefix = "iteration {}".format(iteration) evaluate_and_print_results( neox_args=neox_args, @@ -1683,46 +1685,49 @@ def evaluate( if neox_args.char_level_ppl: data_iterator = CharCounter(data_iterator, neox_args.tokenizer) - with torch.no_grad(): - iteration = 0 - while iteration < neox_args.eval_iters: - iteration += 1 - if verbose and iteration % neox_args.log_interval == 0: - print_rank_0( - "Evaluating iter {}/{}".format(iteration, neox_args.eval_iters) - ) + eval_results = {} + if data_iterator is not None: + with torch.no_grad(): + iteration = 0 + while iteration < neox_args.eval_iters: + iteration += 1 + if verbose and iteration % neox_args.log_interval == 0: + print_rank_0( + "Evaluating iter {}/{}".format(iteration, neox_args.eval_iters) + ) - # although we're not accumulating gradients here, we count one iter as train_batch_size_per_gpu * g.a.s - # to be consistent with deepspeed's pipe parallel engine - # since pipe parallel already takes gradient_accumulation_steps into account - default to 1 here if pipe parallel is true - for _ in range( - 1 - if neox_args.is_pipe_parallel - else neox_args.gradient_accumulation_steps - ): - # Forward evaluation - loss, metric_dict = forward_step_fn( - model=model, - data_iterator=data_iterator, - neox_args=neox_args, - timers=timers, - reference_model=reference_model, - ) - losses.append(loss) - for key in metric_dict.keys(): - metric_dicts[key].append(metric_dict[key]) - # When contiguous memory optimizations are enabled, the buffers - # allocated by the optimizations are deallocated during backward pass - # in the absence of backward pass the buffers should be reset after each - # forward pass - if neox_args.deepspeed and neox_args.deepspeed_activation_checkpointing: - deepspeed.checkpointing.reset() - - # reduces losses across processes for logging & run eval harness tasks - eval_results = {"lm_loss": reduce_losses(losses).mean().item()} - for key in metric_dicts.keys(): - eval_results[key] = reduce_losses(metric_dicts[key]).mean().item() - eval_results["lm_loss_ppl"] = math.exp(eval_results["lm_loss"]) + # although we're not accumulating gradients here, we count one iter as train_batch_size_per_gpu * g.a.s + # to be consistent with deepspeed's pipe parallel engine + # since pipe parallel already takes gradient_accumulation_steps into account - default to 1 here if pipe parallel is true + for _ in range( + 1 + if neox_args.is_pipe_parallel + else neox_args.gradient_accumulation_steps + ): + # Forward evaluation + loss, metric_dict = forward_step_fn( + model=model, + data_iterator=data_iterator, + neox_args=neox_args, + timers=timers, + reference_model=reference_model, + ) + losses.append(loss) + for key in metric_dict.keys(): + metric_dicts[key].append(metric_dict[key]) + # When contiguous memory optimizations are enabled, the buffers + # allocated by the optimizations are deallocated during backward pass + # in the absence of backward pass the buffers should be reset after each + # forward pass + if neox_args.deepspeed and neox_args.deepspeed_activation_checkpointing: + deepspeed.checkpointing.reset() + + # reduces losses across processes for logging & run eval harness tasks + eval_results["lm_loss"] = reduce_losses(losses).mean().item() + for key in metric_dicts.keys(): + eval_results[key] = reduce_losses(metric_dicts[key]).mean().item() + + eval_results["lm_loss_ppl"] = math.exp(eval_results["lm_loss"]) if neox_args.char_level_ppl: # calculate character level perplexity, if specified diff --git a/megatron/utils.py b/megatron/utils.py index fc2f80dad..5ba988b3b 100644 --- a/megatron/utils.py +++ b/megatron/utils.py @@ -183,7 +183,7 @@ def init_wandb(neox_args): "Skipping wandb. Execute `wandb login` on local or main node machine to enable.", flush=True, ) - wandb.config.update(neox_args.all_config) + wandb.config.update(neox_args.all_config, allow_val_change=True) def obtain_resource_pool( diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 21906f0e2..365826fb4 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -3,7 +3,7 @@ ftfy>=6.0.1 huggingface_hub>=0.11.0 jinja2==3.1.4 lm_dataformat@git+https://github.com/EleutherAI/lm_dataformat.git@4eec05349977071bf67fc072290b95e31c8dd836 -lm_eval>=0.4.0,<=0.4.1 +lm_eval==0.4.8 mpi4py>=3.0.3 numpy<2.0 pybind11>=2.6.2