-
Notifications
You must be signed in to change notification settings - Fork 20
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 2038e94
Showing
47 changed files
with
8,849 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
# Byte-compiled / optimized / DLL files | ||
__pycache__/ | ||
*.py[cod] | ||
*$py.class | ||
|
||
# C extensions | ||
*.so | ||
|
||
# Distribution / packaging | ||
.Python | ||
build/ | ||
develop-eggs/ | ||
dist/ | ||
downloads/ | ||
eggs/ | ||
.eggs/ | ||
lib/ | ||
lib64/ | ||
parts/ | ||
sdist/ | ||
var/ | ||
wheels/ | ||
share/python-wheels/ | ||
*.egg-info/ | ||
.installed.cfg | ||
*.egg | ||
MANIFEST | ||
|
||
# PyInstaller | ||
# Usually these files are written by a python script from a template | ||
# before PyInstaller builds the exe, so as to inject date/other infos into it. | ||
*.manifest | ||
*.spec | ||
|
||
# Installer logs | ||
pip-log.txt | ||
pip-delete-this-directory.txt | ||
|
||
# Unit test / coverage reports | ||
htmlcov/ | ||
.tox/ | ||
.nox/ | ||
.coverage | ||
.coverage.* | ||
.cache | ||
nosetests.xml | ||
coverage.xml | ||
*.cover | ||
*.py,cover | ||
.hypothesis/ | ||
.pytest_cache/ | ||
cover/ | ||
|
||
# Translations | ||
*.mo | ||
*.pot | ||
|
||
# Django stuff: | ||
*.log | ||
local_settings.py | ||
db.sqlite3 | ||
db.sqlite3-journal | ||
|
||
# Flask stuff: | ||
instance/ | ||
.webassets-cache | ||
|
||
# Scrapy stuff: | ||
.scrapy | ||
|
||
# Sphinx documentation | ||
docs/_build/ | ||
|
||
# PyBuilder | ||
.pybuilder/ | ||
target/ | ||
|
||
# Jupyter Notebook | ||
.ipynb_checkpoints | ||
|
||
# IPython | ||
profile_default/ | ||
ipython_config.py | ||
|
||
# pyenv | ||
# For a library or package, you might want to ignore these files since the code is | ||
# intended to run in multiple environments; otherwise, check them in: | ||
# .python-version | ||
|
||
# pipenv | ||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. | ||
# However, in case of collaboration, if having platform-specific dependencies or dependencies | ||
# having no cross-platform support, pipenv may install dependencies that don't work, or not | ||
# install all needed dependencies. | ||
#Pipfile.lock | ||
|
||
# poetry | ||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. | ||
# This is especially recommended for binary packages to ensure reproducibility, and is more | ||
# commonly ignored for libraries. | ||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control | ||
#poetry.lock | ||
|
||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow | ||
__pypackages__/ | ||
|
||
# Celery stuff | ||
celerybeat-schedule | ||
celerybeat.pid | ||
|
||
# SageMath parsed files | ||
*.sage.py | ||
|
||
# Environments | ||
.env | ||
.venv | ||
env/ | ||
venv/ | ||
ENV/ | ||
env.bak/ | ||
venv.bak/ | ||
|
||
# Spyder project settings | ||
.spyderproject | ||
.spyproject | ||
|
||
# Rope project settings | ||
.ropeproject | ||
|
||
# mkdocs documentation | ||
/site | ||
|
||
# mypy | ||
.mypy_cache/ | ||
.dmypy.json | ||
dmypy.json | ||
|
||
# Pyre type checker | ||
.pyre/ | ||
|
||
# pytype static type analyzer | ||
.pytype/ | ||
|
||
# Cython debug symbols | ||
cython_debug/ | ||
|
||
# project-specific | ||
output/ | ||
debug*/ | ||
*.bak | ||
*.dir | ||
*.dat | ||
*.tsv | ||
*.gz | ||
|
||
# cache root | ||
cache/ | ||
|
||
# DS_Store | ||
**/.DS_Store |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
""" | ||
Copyright (c) 2022, salesforce.com, inc. | ||
All rights reserved. | ||
SPDX-License-Identifier: BSD-3-Clause | ||
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause | ||
""" | ||
|
||
import argparse | ||
import random | ||
|
||
import numpy as np | ||
import torch | ||
import torch.backends.cudnn as cudnn | ||
|
||
import lavis.tasks as tasks | ||
from lavis.common.config import Config | ||
from lavis.common.dist_utils import get_rank, init_distributed_mode | ||
from lavis.common.logger import setup_logger | ||
from lavis.common.optims import ( | ||
LinearWarmupCosineLRScheduler, | ||
LinearWarmupStepLRScheduler, | ||
) | ||
from lavis.common.utils import now | ||
|
||
# imports modules for registration | ||
from lavis.datasets.builders import * | ||
from lavis.models import * | ||
from lavis.processors import * | ||
from lavis.runners.runner_base import RunnerBase | ||
from lavis.tasks import * | ||
|
||
|
||
def parse_args(): | ||
parser = argparse.ArgumentParser(description="Training") | ||
|
||
parser.add_argument("--cfg-path", required=True, help="path to configuration file.") | ||
parser.add_argument( | ||
"--options", | ||
nargs="+", | ||
help="override some settings in the used config, the key-value pair " | ||
"in xxx=yyy format will be merged into config file (deprecate), " | ||
"change to --cfg-options instead.", | ||
) | ||
|
||
args = parser.parse_args() | ||
# if 'LOCAL_RANK' not in os.environ: | ||
# os.environ['LOCAL_RANK'] = str(args.local_rank) | ||
|
||
return args | ||
|
||
|
||
def setup_seeds(config): | ||
seed = config.run_cfg.seed + get_rank() | ||
|
||
random.seed(seed) | ||
np.random.seed(seed) | ||
torch.manual_seed(seed) | ||
|
||
cudnn.benchmark = False | ||
cudnn.deterministic = True | ||
|
||
|
||
def main(): | ||
# allow auto-dl completes on main process without timeout when using NCCL backend. | ||
# os.environ["NCCL_BLOCKING_WAIT"] = "1" | ||
|
||
# set before init_distributed_mode() to ensure the same job_id shared across all ranks. | ||
job_id = now() | ||
|
||
cfg = Config(parse_args()) | ||
|
||
init_distributed_mode(cfg.run_cfg) | ||
|
||
setup_seeds(cfg) | ||
|
||
# set after init_distributed_mode() to only log on master. | ||
setup_logger() | ||
|
||
cfg.pretty_print() | ||
|
||
task = tasks.setup_task(cfg) | ||
datasets = task.build_datasets(cfg) | ||
model = task.build_model(cfg) | ||
|
||
runner = RunnerBase( | ||
cfg=cfg, job_id=job_id, task=task, model=model, datasets=datasets | ||
) | ||
runner.evaluate(skip_reload=True) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
""" | ||
Copyright (c) 2022, salesforce.com, inc. | ||
All rights reserved. | ||
SPDX-License-Identifier: BSD-3-Clause | ||
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause | ||
""" | ||
|
||
import os | ||
import sys | ||
|
||
from omegaconf import OmegaConf | ||
|
||
from lavis.common.registry import registry | ||
|
||
from lavis.datasets.builders import * | ||
from lavis.models import * | ||
from lavis.processors import * | ||
from lavis.tasks import * | ||
|
||
|
||
root_dir = os.path.dirname(os.path.abspath(__file__)) | ||
default_cfg = OmegaConf.load(os.path.join(root_dir, "configs/default.yaml")) | ||
|
||
registry.register_path("library_root", root_dir) | ||
repo_root = os.path.join(root_dir, "..") | ||
registry.register_path("repo_root", repo_root) | ||
cache_root = os.path.join(repo_root, default_cfg.env.cache_root) | ||
registry.register_path("cache_root", cache_root) | ||
|
||
registry.register("MAX_INT", sys.maxsize) | ||
registry.register("SPLIT_NAMES", ["train", "val", "test"]) |
Oops, something went wrong.