Skip to content

Multiple Improvements for mmengine #1629

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 25 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
4f62c98
Fix torch FutureWarning
MGAMZ Jul 21, 2024
b6b4224
Fix torch FutureWarning
MGAMZ Jul 21, 2024
4c7a5d4
Optimize the prompt for compile
MGAMZ Jul 26, 2024
28d47f8
Fix Incorrect Optim Param Resume Method
MGAMZ Aug 21, 2024
91d945f
Update runner.py to support pure-python style model wrapper configura…
MGAMZ Aug 28, 2024
0934d75
Merge branch 'open-mmlab:main' into main
MGAMZ Sep 22, 2024
7103c3e
reconstruct
MGAMZ Sep 23, 2024
eecaa92
PyTorch Profiler within IterBasedTrainLoop
MGAMZ Nov 3, 2024
698ad5e
enable hook error exception traceback
MGAMZ Nov 3, 2024
8c80332
Merge branch 'main' of github.com:MGAMZ/mmengine
MGAMZ Nov 3, 2024
3cf1003
Merge branch 'open-mmlab:main' into main
MGAMZ Nov 4, 2024
1e4c2ed
improve codes
MGAMZ Nov 15, 2024
2a5a1fe
Merge branch 'open-mmlab:main' into main
MGAMZ Nov 23, 2024
29e3a08
KeyError: 'Adafactor is already registered in optimizer at torch.optim'.
MGAMZ Jan 3, 2025
896576b
Merge branch 'main' of https://github.com/MGAMZ/mmengine
MGAMZ Jan 3, 2025
be86710
Update support for deep speed and multiple improvements.
MGAMZ Jan 11, 2025
dadedbb
Merge branch 'main' of gitee.com:MGAM/mmengine
MGAMZ Jan 11, 2025
861fc1b
improve multiple mmengine undeveloped issues.
MGAMZ Jan 12, 2025
8f37dd2
Multiple improvements
MGAMZ Jan 17, 2025
bed2660
Merge branch 'open-mmlab:main' into main
MGAMZ Jan 17, 2025
46cfdbb
Revert "Multiple improvements"
Jan 17, 2025
5376661
Revert "Update support for deep speed and multiple improvements."
Jan 17, 2025
f038e5e
Revert "reconstruct"
Jan 17, 2025
e36e2f1
Revert "PyTorch Profiler within IterBasedTrainLoop"
Jan 17, 2025
834cf9a
fix
Jan 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion mmengine/_strategy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,8 @@ def compile_model(
Returns:
nn.Module: Compiled model.
"""
if isinstance(compile, bool) and not compile:
if isinstance(compile, bool) and not compile or \
isinstance(compile, dict) and not compile.get('disable', False):
return model

assert digit_version(TORCH_VERSION) >= digit_version('2.0.0'), (
Expand Down
2 changes: 1 addition & 1 deletion mmengine/_strategy/fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@ def load_optim_state_dict(self, state_dict: dict) -> None:
``optimizer.state_dict()``
"""
optim_state_dict = FSDP.optim_state_dict_to_load(
state_dict, self.model, self.optim_wrapper.optimizer)
self.model, self.optim_wrapper.optimizer, state_dict)
self.optim_wrapper.load_state_dict(optim_state_dict)

def _init_state_dict_cfg(self, state_dict_cfg: Union[str, dict]) -> None:
Expand Down
1 change: 0 additions & 1 deletion mmengine/model/base_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ def is_init(self, value):

def init_weights(self):
"""Initialize the weights."""

is_top_level_module = False
# check if it is top-level module
if not hasattr(self, '_params_init_info'):
Expand Down
4 changes: 3 additions & 1 deletion mmengine/optim/optimizer/amp_optimizer_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from contextlib import contextmanager
from functools import partial
from typing import Union

import torch
Expand All @@ -17,7 +18,8 @@
elif is_mlu_available():
from torch.mlu.amp import GradScaler
else:
from torch.cuda.amp import GradScaler
from torch.amp import GradScaler as amp_GradScaler
GradScaler = partial(amp_GradScaler, device='cuda')


@OPTIM_WRAPPERS.register_module()
Expand Down
23 changes: 15 additions & 8 deletions mmengine/optim/optimizer/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@

from mmengine.config import Config, ConfigDict
from mmengine.device import is_npu_available, is_npu_support_full_precision
from mmengine.logging.logger import print_log
from mmengine.registry import OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS
from .default_constructor import DefaultOptimWrapperConstructor
from .optimizer_wrapper import OptimWrapper


Expand Down Expand Up @@ -170,7 +172,10 @@ def register_transformers_optimizers():
except ImportError:
pass
else:
OPTIMIZERS.register_module(name='Adafactor', module=Adafactor)
try:
OPTIMIZERS.register_module(name='Adafactor', module=Adafactor)
except KeyError as e:
pass
transformer_optimizers.append('Adafactor')
return transformer_optimizers

Expand All @@ -196,20 +201,22 @@ def build_optim_wrapper(model: nn.Module,
OptimWrapper: The built optimizer wrapper.
"""
optim_wrapper_cfg = copy.deepcopy(cfg)
constructor_type = optim_wrapper_cfg.pop('constructor',
'DefaultOptimWrapperConstructor')
constructor_cfg = optim_wrapper_cfg.pop('constructor', None)
if constructor_cfg is None:
constructor_cfg = dict(type=DefaultOptimWrapperConstructor)
paramwise_cfg = optim_wrapper_cfg.pop('paramwise_cfg', None)

# Since the current generation of NPU(Ascend 910) only supports
# mixed precision training, here we turn on mixed precision
# to make the training normal
if is_npu_available() and not is_npu_support_full_precision():
optim_wrapper_cfg['type'] = 'AmpOptimWrapper'

constructor_cfg.update(dict(
optim_wrapper_cfg=optim_wrapper_cfg,
paramwise_cfg=paramwise_cfg
))

optim_wrapper_constructor = OPTIM_WRAPPER_CONSTRUCTORS.build(
dict(
type=constructor_type,
optim_wrapper_cfg=optim_wrapper_cfg,
paramwise_cfg=paramwise_cfg))
optim_wrapper_constructor = OPTIM_WRAPPER_CONSTRUCTORS.build(constructor_cfg)
optim_wrapper = optim_wrapper_constructor(model)
return optim_wrapper
2 changes: 1 addition & 1 deletion mmengine/runner/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ def load_from_local(filename, map_location):
filename = osp.expanduser(filename)
if not osp.isfile(filename):
raise FileNotFoundError(f'{filename} can not be found.')
checkpoint = torch.load(filename, map_location=map_location)
checkpoint = torch.load(filename, map_location=map_location, weights_only=False)
return checkpoint


Expand Down
1 change: 0 additions & 1 deletion mmengine/runner/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,6 @@ def run(self) -> dict:
self.val_loss.clear()
for idx, data_batch in enumerate(self.dataloader):
self.run_iter(idx, data_batch)

# compute metrics
metrics = self.evaluator.evaluate(len(self.dataloader.dataset))

Expand Down
15 changes: 13 additions & 2 deletions mmengine/runner/runner.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import inspect
import logging
import os
import os.path as osp
Expand Down Expand Up @@ -902,8 +903,18 @@ def wrap_model(
find_unused_parameters=find_unused_parameters)
else:
model_wrapper_cfg.setdefault('type', 'MMDistributedDataParallel')
model_wrapper_type = MODEL_WRAPPERS.get(
model_wrapper_cfg.get('type')) # type: ignore
model_wrapper_type = model_wrapper_cfg.get('type')
if isinstance(model_wrapper_type, str):
model_wrapper_type = MODEL_WRAPPERS.get(model_wrapper_type) # type: ignore
elif inspect.isclass(model_wrapper_type):
pass
else:
raise KeyError(
f'{model_wrapper_type} is not in the '
'registry. Please check whether the value of '
f'`{model_wrapper_type}` is correct or it was registered '
'as expected. More details can be found at https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#import-the-custom-module' # noqa: E501
)
default_args: dict = dict()
if issubclass(
model_wrapper_type, # type: ignore
Expand Down
3 changes: 2 additions & 1 deletion mmengine/visualization/vis_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,7 +604,8 @@ def add_scalar(self,
(int, float, torch.Tensor, np.ndarray, np.number)):
self._tensorboard.add_scalar(name, value, step)
else:
warnings.warn(f'Got {type(value)}, but numpy array, torch tensor, '
warnings.warn(f'Got type {type(value)} with name {name}, '
'but numpy array, torch tensor, '
f'int or float are expected. skip it!')

@force_init_env
Expand Down