Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Integrate SwanLab for offline/online experiment tracking and local visualization #36433

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/source/en/main_classes/callback.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ By default, `TrainingArguments.report_to` is set to `"all"`, so a [`Trainer`] wi
- [`~integrations.DagsHubCallback`] if [dagshub](https://dagshub.com/) is installed.
- [`~integrations.FlyteCallback`] if [flyte](https://flyte.org/) is installed.
- [`~integrations.DVCLiveCallback`] if [dvclive](https://dvc.org/doc/dvclive) is installed.
- [`~integrations.SwanLabCallback`] if [swanlab](http://swanlab.cn/) is installed.

If a package is installed but you don't wish to use the accompanying integration, you can change `TrainingArguments.report_to` to a list of just those integrations you want to use (e.g. `["azure_ml", "wandb"]`).

Expand Down Expand Up @@ -92,6 +93,9 @@ Here is the list of the available [`TrainerCallback`] in the library:
[[autodoc]] integrations.DVCLiveCallback
- setup

[[autodoc]] integrations.SwanLabCallback
- setup

## TrainerCallback

[[autodoc]] TrainerCallback
Expand Down
4 changes: 4 additions & 0 deletions docs/source/ja/main_classes/callback.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ rendered properly in your Markdown viewer.
- [`~integrations.DagsHubCallback`] [dagshub](https://dagshub.com/) がインストールされている場合。
- [`~integrations.FlyteCallback`] [flyte](https://flyte.org/) がインストールされている場合。
- [`~integrations.DVCLiveCallback`] [dvclive](https://www.dvc.org/doc/dvclive) がインストールされている場合。
- [`~integrations.SwanLabCallback`] [swanlab](http://swanlab.cn/) がインストールされている場合。

パッケージがインストールされているが、付随する統合を使用したくない場合は、`TrainingArguments.report_to` を、使用したい統合のみのリストに変更できます (例: `["azure_ml", "wandb"]`) 。

Expand Down Expand Up @@ -92,6 +93,9 @@ rendered properly in your Markdown viewer.
[[autodoc]] integrations.DVCLiveCallback
- setup

[[autodoc]] integrations.SwanLabCallback
- setup

## TrainerCallback

[[autodoc]] TrainerCallback
Expand Down
4 changes: 4 additions & 0 deletions docs/source/ko/main_classes/callback.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ rendered properly in your Markdown viewer.
- [`~integrations.DagsHubCallback`]는 [dagshub](https://dagshub.com/)이 설치되어 있으면 사용됩니다.
- [`~integrations.FlyteCallback`]는 [flyte](https://flyte.org/)가 설치되어 있으면 사용됩니다.
- [`~integrations.DVCLiveCallback`]는 [dvclive](https://dvc.org/doc/dvclive)가 설치되어 있으면 사용됩니다.
- [`~integrations.SwanLabCallback`]는 [swanlab](https://swanlab.cn)가 설치되어 있으면 사용됩니다.

패키지가 설치되어 있지만 해당 통합 기능을 사용하고 싶지 않다면, `TrainingArguments.report_to`를 사용하고자 하는 통합 기능 목록으로 변경할 수 있습니다 (예: `["azure_ml", "wandb"]`).

Expand Down Expand Up @@ -92,6 +93,9 @@ rendered properly in your Markdown viewer.
[[autodoc]] integrations.DVCLiveCallback
- setup

[[autodoc]] integrations.SwanLabCallback
- setup

## TrainerCallback [[trainercallback]]

[[autodoc]] TrainerCallback
Expand Down
4 changes: 4 additions & 0 deletions docs/source/zh/main_classes/callback.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ Callbacks是“只读”的代码片段,除了它们返回的[TrainerControl]
- [`~integrations.DagsHubCallback`],如果安装了[dagshub](https://dagshub.com/)。
- [`~integrations.FlyteCallback`],如果安装了[flyte](https://flyte.org/)。
- [`~integrations.DVCLiveCallback`],如果安装了[dvclive](https://dvc.org/doc/dvclive)。
- [`~integrations.SwanLabCallback`],如果安装了[swanlab](http://swanlab.cn/)。

如果安装了一个软件包,但您不希望使用相关的集成,您可以将 `TrainingArguments.report_to` 更改为仅包含您想要使用的集成的列表(例如 `["azure_ml", "wandb"]`)。

Expand Down Expand Up @@ -81,6 +82,9 @@ Callbacks是“只读”的代码片段,除了它们返回的[TrainerControl]
[[autodoc]] integrations.DVCLiveCallback
- setup

[[autodoc]] integrations.SwanLabCallback
- setup

## TrainerCallback

[[autodoc]] TrainerCallback
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@
"is_ray_available",
"is_ray_tune_available",
"is_sigopt_available",
"is_swanlab_available",
"is_tensorboard_available",
"is_wandb_available",
],
Expand Down Expand Up @@ -5265,6 +5266,7 @@
is_ray_available,
is_ray_tune_available,
is_sigopt_available,
is_swanlab_available,
is_tensorboard_available,
is_wandb_available,
)
Expand Down
11 changes: 10 additions & 1 deletion src/transformers/integrations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,12 @@
"load_dequant_gguf_tensor",
"load_gguf",
],
"higgs": ["HiggsLinear", "dequantize_higgs", "quantize_with_higgs", "replace_with_higgs_linear"],
"higgs": [
"HiggsLinear",
"dequantize_higgs",
"quantize_with_higgs",
"replace_with_higgs_linear",
],
"hqq": ["prepare_for_hqq_linear"],
"integration_utils": [
"INTEGRATION_TO_CALLBACK",
Expand All @@ -77,6 +82,7 @@
"MLflowCallback",
"NeptuneCallback",
"NeptuneMissingConfiguration",
"SwanLabCallback",
"TensorBoardCallback",
"WandbCallback",
"get_available_reporting_integrations",
Expand All @@ -96,6 +102,7 @@
"is_ray_available",
"is_ray_tune_available",
"is_sigopt_available",
"is_swanlab_available",
"is_tensorboard_available",
"is_wandb_available",
"rewrite_logs",
Expand Down Expand Up @@ -182,6 +189,7 @@
MLflowCallback,
NeptuneCallback,
NeptuneMissingConfiguration,
SwanLabCallback,
TensorBoardCallback,
WandbCallback,
get_available_reporting_integrations,
Expand All @@ -201,6 +209,7 @@
is_ray_available,
is_ray_tune_available,
is_sigopt_available,
is_swanlab_available,
is_tensorboard_available,
is_wandb_available,
rewrite_logs,
Expand Down
163 changes: 163 additions & 0 deletions src/transformers/integrations/integration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,10 @@ def is_dvclive_available():
return importlib.util.find_spec("dvclive") is not None


def is_swanlab_available():
return importlib.util.find_spec("swanlab") is not None


def hp_params(trial):
if is_optuna_available():
import optuna
Expand Down Expand Up @@ -610,6 +614,8 @@ def get_available_reporting_integrations():
integrations.append("codecarbon")
if is_clearml_available():
integrations.append("clearml")
if is_swanlab_available():
integrations.append("swanlab")
return integrations


Expand Down Expand Up @@ -2141,6 +2147,162 @@ def on_train_end(self, args, state, control, **kwargs):
self.live.end()


class SwanLabCallback(TrainerCallback):
"""
A [`TrainerCallback`] that logs metrics, media, model checkpoints to [SwanLab](https://swanlab.cn/).
"""

def __init__(self):
if not is_swanlab_available():
raise RuntimeError("SwanLabCallback requires swanlab to be installed. Run `pip install swanlab`.")
import swanlab

self._swanlab = swanlab
self._initialized = False
self._log_model = os.getenv("SWANLAB_LOG_MODEL", None)

def setup(self, args, state, model, **kwargs):
"""
Setup the optional SwanLab (*swanlab*) integration.

One can subclass and override this method to customize the setup if needed. Find more information
[here](https://docs.swanlab.cn/guide_cloud/integration/integration-huggingface-transformers.html).

You can also override the following environment variables. Find more information about environment
variables [here](https://docs.swanlab.cn/en/api/environment-variable.html#environment-variables)

Environment:
- **SWANLAB_API_KEY** (`str`, *optional*, defaults to `None`):
Cloud API Key. During login, this environment variable is checked first. If it doesn't exist, the system
checks if the user is already logged in. If not, the login process is initiated.

- If a string is passed to the login interface, this environment variable is ignored.
- If the user is already logged in, this environment variable takes precedence over locally stored
login information.

- **SWANLAB_PROJECT** (`str`, *optional*, defaults to `None`):
Set this to a custom string to store results in a different project. If not specified, the name of the current
running directory is used.

- **SWANLAB_LOG_DIR** (`str`, *optional*, defaults to `swanlog`):
This environment variable specifies the storage path for log files when running in local mode.
By default, logs are saved in a folder named swanlog under the working directory.

- **SWANLAB_MODE** (`Literal["local", "cloud", "disabled"]`, *optional*, defaults to `cloud`):
SwanLab's parsing mode, which involves callbacks registered by the operator. Currently, there are three modes:
local, cloud, and disabled. Note: Case-sensitive. Find more information
[here](https://docs.swanlab.cn/en/api/py-init.html#swanlab-init)

- **SWANLAB_LOG_MODEL** (`str`, *optional*, defaults to `None`):
SwanLab does not currently support the save mode functionality.This feature will be available in a future
release

- **SWANLAB_WEB_HOST** (`str`, *optional*, defaults to `None`):
Web address for the SwanLab cloud environment for private version (its free)

- **SWANLAB_API_HOST** (`str`, *optional*, defaults to `None`):
API address for the SwanLab cloud environment for private version (its free)

"""
self._initialized = True

if state.is_world_process_zero:
logger.info('Automatic SwanLab logging enabled, to disable set os.environ["SWANLAB_MODE"] = "disabled"')
combined_dict = {**args.to_dict()}

if hasattr(model, "config") and model.config is not None:
model_config = model.config if isinstance(model.config, dict) else model.config.to_dict()
combined_dict = {**model_config, **combined_dict}
if hasattr(model, "peft_config") and model.peft_config is not None:
peft_config = model.peft_config
combined_dict = {**{"peft_config": peft_config}, **combined_dict}
trial_name = state.trial_name
init_args = {}
if trial_name is not None:
init_args["experiment_name"] = f"{args.run_name}-{trial_name}"
elif args.run_name is not None:
init_args["experiment_name"] = args.run_name
init_args["project"] = os.getenv("SWANLAB_PROJECT", None)

if self._swanlab.get_run() is None:
self._swanlab.init(
**init_args,
)
# show transformers logo!
self._swanlab.config["FRAMEWORK"] = "🤗transformers"
# add config parameters (run may have been created manually)
self._swanlab.config.update(combined_dict)

# add number of model parameters to swanlab config
try:
self._swanlab.config.update({"model_num_parameters": model.num_parameters()})
# get peft model parameters
if type(model).__name__ == "PeftModel" or type(model).__name__ == "PeftMixedModel":
trainable_params, all_param = model.get_nb_trainable_parameters()
self._swanlab.config.update({"peft_model_trainable_params": trainable_params})
self._swanlab.config.update({"peft_model_all_param": all_param})
except AttributeError:
logger.info("Could not log the number of model parameters in SwanLab due to an AttributeError.")

# log the initial model architecture to an artifact
if self._log_model is not None:
logger.warning(
"SwanLab does not currently support the save mode functionality. "
"This feature will be available in a future release."
)
badge_markdown = (
f'[<img src="https://raw.githubusercontent.com/SwanHubX/assets/main/badge1.svg"'
f' alt="Visualize in SwanLab" height="28'
f'0" height="32"/>]({self._swanlab.get_run().public.cloud.exp_url})'
)

modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"

def on_train_begin(self, args, state, control, model=None, **kwargs):
if not self._initialized:
self.setup(args, state, model, **kwargs)

def on_train_end(self, args, state, control, model=None, processing_class=None, **kwargs):
if self._log_model is not None and self._initialized and state.is_world_process_zero:
logger.warning(
"SwanLab does not currently support the save mode functionality. "
"This feature will be available in a future release."
)

def on_log(self, args, state, control, model=None, logs=None, **kwargs):
single_value_scalars = [
"train_runtime",
"train_samples_per_second",
"train_steps_per_second",
"train_loss",
"total_flos",
]

if not self._initialized:
self.setup(args, state, model)
if state.is_world_process_zero:
for k, v in logs.items():
if k in single_value_scalars:
self._swanlab.log({f"single_value/{k}": v})
non_scalar_logs = {k: v for k, v in logs.items() if k not in single_value_scalars}
non_scalar_logs = rewrite_logs(non_scalar_logs)
self._swanlab.log({**non_scalar_logs, "train/global_step": state.global_step})

def on_save(self, args, state, control, **kwargs):
if self._log_model is not None and self._initialized and state.is_world_process_zero:
logger.warning(
"SwanLab does not currently support the save mode functionality. "
"This feature will be available in a future release."
)

def on_predict(self, args, state, control, metrics, **kwargs):
if not self._initialized:
self.setup(args, state, **kwargs)
if state.is_world_process_zero:
metrics = rewrite_logs(metrics)
self._swanlab.log(metrics)


INTEGRATION_TO_CALLBACK = {
"azure_ml": AzureMLCallback,
"comet_ml": CometCallback,
Expand All @@ -2153,6 +2315,7 @@ def on_train_end(self, args, state, control, **kwargs):
"dagshub": DagsHubCallback,
"flyte": FlyteCallback,
"dvclive": DVCLiveCallback,
"swanlab": SwanLabCallback,
}


Expand Down
11 changes: 11 additions & 0 deletions src/transformers/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
is_optuna_available,
is_ray_available,
is_sigopt_available,
is_swanlab_available,
is_tensorboard_available,
is_wandb_available,
)
Expand Down Expand Up @@ -1098,6 +1099,16 @@ def require_sigopt(test_case):
return unittest.skipUnless(is_sigopt_available(), "test requires SigOpt")(test_case)


def require_swanlab(test_case):
"""
Decorator marking a test that requires swanlab.
These tests are skipped when swanlab isn't installed.
"""
return unittest.skipUnless(is_swanlab_available(), "test requires swanlab")(test_case)


def require_wandb(test_case):
"""
Decorator marking a test that requires wandb.
Expand Down
16 changes: 9 additions & 7 deletions src/transformers/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,8 +451,8 @@ class TrainingArguments:
training step under the keyword argument `mems`.
run_name (`str`, *optional*, defaults to `output_dir`):
A descriptor for the run. Typically used for [wandb](https://www.wandb.com/),
[mlflow](https://www.mlflow.org/) and [comet](https://www.comet.com/site) logging. If not specified, will
be the same as `output_dir`.
[mlflow](https://www.mlflow.org/), [comet](https://www.comet.com/site) and [swanlab](https://swanlab.cn)
logging. If not specified, will be the same as `output_dir`.
disable_tqdm (`bool`, *optional*):
Whether or not to disable the tqdm progress bars and table of metrics produced by
[`~notebook.NotebookTrainingTracker`] in Jupyter Notebooks. Will default to `True` if the logging level is
Expand Down Expand Up @@ -642,8 +642,8 @@ class TrainingArguments:
report_to (`str` or `List[str]`, *optional*, defaults to `"all"`):
The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,
`"clearml"`, `"codecarbon"`, `"comet_ml"`, `"dagshub"`, `"dvclive"`, `"flyte"`, `"mlflow"`, `"neptune"`,
`"tensorboard"`, and `"wandb"`. Use `"all"` to report to all integrations installed, `"none"` for no
integrations.
`"swanlab"`, `"tensorboard"`, and `"wandb"`. Use `"all"` to report to all integrations installed, `"none"`
for no integrations.
ddp_find_unused_parameters (`bool`, *optional*):
When using distributed training, the value of the flag `find_unused_parameters` passed to
`DistributedDataParallel`. Will default to `False` if gradient checkpointing is used, `True` otherwise.
Expand Down Expand Up @@ -1187,7 +1187,9 @@ class TrainingArguments:

run_name: Optional[str] = field(
default=None,
metadata={"help": "An optional descriptor for the run. Notably used for wandb, mlflow and comet logging."},
metadata={
"help": "An optional descriptor for the run. Notably used for wandb, mlflow comet and swanlab logging."
},
)
disable_tqdm: Optional[bool] = field(
default=None, metadata={"help": "Whether or not to disable the tqdm progress bars."}
Expand Down Expand Up @@ -2848,8 +2850,8 @@ def set_logging(
report_to (`str` or `List[str]`, *optional*, defaults to `"all"`):
The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,
`"clearml"`, `"codecarbon"`, `"comet_ml"`, `"dagshub"`, `"dvclive"`, `"flyte"`, `"mlflow"`,
`"neptune"`, `"tensorboard"`, and `"wandb"`. Use `"all"` to report to all integrations installed,
`"none"` for no integrations.
`"neptune"`, `"swanlab"`, `"tensorboard"`, and `"wandb"`. Use `"all"` to report to all integrations
installed, `"none"` for no integrations.
first_step (`bool`, *optional*, defaults to `False`):
Whether to log and evaluate the first `global_step` or not.
nan_inf_filter (`bool`, *optional*, defaults to `True`):
Expand Down
Loading