Skip to content

Commit

Permalink
update config
Browse files Browse the repository at this point in the history
  • Loading branch information
cir7 committed Sep 6, 2023
1 parent 7038b9f commit 3b7d2eb
Show file tree
Hide file tree
Showing 23 changed files with 141 additions and 35 deletions.
8 changes: 5 additions & 3 deletions configs/multimodal/vindlu/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,18 @@ The last several years have witnessed remarkable progress in video-and-language

### Video Question-Answering on MSRVTT-QA

| frame sampling strategy | resolution | gpus | vision encoder | text encoder | pretraining | top1 acc | config | ckpt | log |
| :---------------------: | :--------: | :--: | :------------: | :----------: | :--------------------: | :------: | :------------------------------------------------: | :----------------------------------------------: | :-------: |
| uniform 12 | 224x224 | 8 | BEiT-Base | Bert-Base | C5M (WebVid-2M + CC3M) | xx.x | [config](/configs/multimodal/vindlu/vindlu_beit-base_8x32_vqa_msrvtt-qa.py) | [ckpt](https://download.openmmlab.com/mmaction/v1.0/multimodal/) | [log](<>) |
| frame sampling strategy | resolution | gpus | vision encoder | text encoder | pretraining | top1 acc | config | ckpt | log |
| :---------------------: | :--------: | :--: | :------------: | :----------: | :--------------------: | :------: | :-----------------------------------: | :---------------------------------: | :---------------------------------: |
| uniform 12 | 224x224 | 8 | BEiT-Base | Bert-Base | C5M (WebVid-2M + CC3M) | 43.6 | [config](/configs/multimodal/vindlu/vindlu_beit-base_8x8_vqa_msrvtt-qa.py) | [ckpt](https://download.openmmlab.com/mmaction/v1.0/multimodal/vindlu_beit-base_8x8_vqa_msrvtt-qa/vindlu/vindlu_beit-base_8x8_vqa_msrvtt-qa_20230906-6e693e64.pth) | [log](https://download.openmmlab.com/mmaction/v1.0/multimodal/vindlu_beit-base_8x8_vqa_msrvtt-qa/vindlu/vindlu_beit-base_8x8_vqa_msrvtt-qa.log) |

### Multiple-Choice Question-Answering on MSRVTT-MC (Inference)

| frame sampling strategy | resolution | gpus | vision encoder | text encoder | pretraining | top1 acc | config | ckpt |
| :---------------------: | :--------: | :--: | :------------: | :----------: | :--------------------: | :------: | :----------------------------------------------------: | :---------------------------------------------------: |
| uniform 12 | 224x224 | 8 | BEiT-Base | Bert-Base | C5M (WebVid-2M + CC3M) | 97.6 | [config](/configs/multimodal/vindlu/vindlu_beit-base_vqa-mc_msrvtt-mc.py) | [ckpt](https://download.openmmlab.com/mmaction/v1.0/multimodal/vindlu/vindlu_beit-base_8x16_retrieval_msrvtt-9k/vindlu_beit-base_8x16_retrieval_msrvtt-9k_20230905-fc36231e.pth) |

1. Currently, we only support the fine-tuning stage of VindLU models based on the pretrained checkpoint provided by the [original repo](https://github.com/klauscc/VindLU).

For more details on data preparation, you can refer to [prepare msrvtt](/tools/data/msrvtt/README.md).

## Train
Expand Down
14 changes: 7 additions & 7 deletions configs/multimodal/vindlu/metafile.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,22 +25,22 @@ Models:
Training Log: https://download.openmmlab.com/mmaction/v1.0/multimodal/vindlu/vindlu_beit-base_8x16_retrieval_msrvtt-9k/vindlu_beit-base_8x16_retrieval_msrvtt-9k.log
Weights: https://download.openmmlab.com/mmaction/v1.0/multimodal/vindlu/vindlu_beit-base_8x16_retrieval_msrvtt-9k/vindlu_beit-base_8x16_retrieval_msrvtt-9k_20230905-fc36231e.pth

- Name: vindlu_beit-base_8x32_vqa_msrvtt-qa
Config: configs/multimodal/vindlu/vindlu_beit-base_8x32_vqa_msrvtt-qa.py
- Name: vindlu_beit-base_8x8_vqa_msrvtt-qa
Config: configs/multimodal/vindlu/vindlu_beit-base_8x8_vqa_msrvtt-qa.py
In Collection: VindLU
Metadata:
Architecture: BEiT-Base
Batch Size: 16
Epochs: 5
Batch Size: 8
Epochs: 10
Training Data: MSRVTT-qa
Training Resources: 8 GPUs
Results:
Dataset: MSRVTT
Task: Video Question-Answering
Metrics:
Top 1 Accuracy:
Training Log:
Weights:
Top 1 Accuracy: 43.6
Training Log: https://download.openmmlab.com/mmaction/v1.0/multimodal/vindlu/vindlu_beit-base_8x8_vqa_msrvtt-qa/vindlu_beit-base_8x8_vqa_msrvtt-qa.log
Weights: https://download.openmmlab.com/mmaction/v1.0/multimodal/vindlu/vindlu_beit-base_8x8_vqa_msrvtt-qa/vindlu_beit-base_8x8_vqa_msrvtt-qa_20230906-6e693e64.pth

- Name: vindlu_beit-base_vqa-mc_msrvtt-mc
Config: configs/multimodal/vindlu/vindlu_beit-base_vqa-mc_msrvtt-mc.py
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
video_root = 'data/msrvtt/videos_2fps_224'
anno_file_train = 'data/msrvtt/annotations/msrvtt_ret_train9k.json'
anno_file_test = 'data/msrvtt/annotations/msrvtt_ret_test1k.json'
pretrained_ckpt_path = 'checkpoints/5M-pretrain.pth'
pretrained_ckpt_url = 'https://download.openmmlab.com/mmaction/v1.0/multimodal/vindlu/vindlu_c5m_pretrain' # noqa: E501

# model settings
model = dict(
type='VindLURetrieval',
gradient_checkpointing=True,
init_cfg=dict(type='Pretrained', checkpoint=pretrained_ckpt_path),
init_cfg=dict(type='Pretrained', checkpoint=pretrained_ckpt_url),
data_preprocessor=dict(
type='ActionDataPreprocessor',
mean=[128],
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
_base_ = ['../../_base_/default_runtime.py']

video_root = 'data/msrvtt/videos_2fps_224'
pretrained_ckpt_path = 'checkpoints/5M-pretrain.pth'
anno_file_train = 'data/msrvtt/annotations/msrvtt_qa_train.json'
anno_file_val = 'data/msrvtt/annotations/msrvtt_qa_val.json'
anno_file_test = 'data/msrvtt/annotations/msrvtt_qa_test.json'
answer_list_file = 'data/msrvtt/annotations/msrvtt_qa_answer_list.json'
pretrained_ckpt_url = 'https://download.openmmlab.com/mmaction/v1.0/multimodal/vindlu/vindlu_c5m_pretrain' # noqa: E501

# model settings
model = dict(
type='VindLUVQA',
init_cfg=dict(type='Pretrained', checkpoint=pretrained_ckpt_path),
init_cfg=dict(type='Pretrained', checkpoint=pretrained_ckpt_url),
data_preprocessor=dict(
type='ActionDataPreprocessor',
mean=[128],
Expand Down Expand Up @@ -56,10 +56,19 @@

train_pipeline = [
dict(type='DecordInit', **file_client_args),
dict(type='SampleFrames', clip_len=1, frame_interval=1, num_clips=12),
dict(
type='SampleFrames',
clip_len=1,
frame_interval=1,
num_clips=12,
out_of_bound_opt='repeat_last'),
dict(type='DecordDecode'),
dict(type='RandomResizedCrop', area_range=(0.5, 1.0)),
dict(type='Resize', scale=(224, 224), keep_ratio=False),
dict(
type='Resize',
scale=(224, 224),
keep_ratio=False,
interpolation='bicubic'),
dict(type='Flip', flip_ratio=0.5),
dict(type='FormatShape', input_format='NCHW'),
dict(
Expand All @@ -79,9 +88,14 @@
clip_len=1,
frame_interval=1,
num_clips=12,
test_mode=True),
test_mode=True,
out_of_bound_opt='repeat_last'),
dict(type='DecordDecode'),
dict(type='Resize', scale=(224, 224), keep_ratio=False),
dict(
type='Resize',
scale=(224, 224),
keep_ratio=False,
interpolation='bicubic'),
dict(type='FormatShape', input_format='NCHW'),
dict(
type='PackActionInputs',
Expand All @@ -97,7 +111,7 @@
dataset_type = 'MSRVTTVQA'

train_dataloader = dict(
batch_size=32,
batch_size=8,
num_workers=8,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=True),
Expand Down
2 changes: 1 addition & 1 deletion mmaction/evaluation/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
from .acc_metric import AccMetric, ConfusionMatrix
from .anet_metric import ANetMetric
from .ava_metric import AVAMetric
from .multimodal_metric import VQAMCACC, ReportVQA, RetrievalRecall, VQAAcc
from .multisports_metric import MultiSportsMetric
from .retrieval_metric import RetrievalMetric
from .vqa_metric import VQAMCACC, ReportVQA, RetrievalRecall, VQAAcc

__all__ = [
'AccMetric', 'AVAMetric', 'ANetMetric', 'ConfusionMatrix',
Expand Down
File renamed without changes.
20 changes: 10 additions & 10 deletions mmaction/models/multimodal/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .beit3d import BeitModel3D
from .tokenizer import BertTokenizer
from .vindlu_ret import VindLURetrieval
from .vindlu_ret_mc import VindLURetrievalMC
from .vindlu_vqa import VindLUVQA
from .xbert import BertDecoder, BertModel
from mmaction.utils.dependency import WITH_MULTIMODAL

__all__ = [
'VindLUVQA', 'BertTokenizer', 'BertModel', 'BertDecoder', 'BeitModel3D',
'VindLURetrievalMC', 'VindLURetrieval'
]
if WITH_MULTIMODAL:
from .vindlu import * # noqa: F401,F403

else:
from mmaction.registry import MODELS
from mmaction.utils.dependency import register_multimodal_placeholder

register_multimodal_placeholder(
['VindLUVQA', 'VindLURetrievalMC', 'VindLURetrieval'], MODELS)
12 changes: 12 additions & 0 deletions mmaction/models/multimodal/vindlu/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .beit3d import BeitModel3D
from .tokenizer import VindLUTokenizer
from .vindlu_ret import VindLURetrieval
from .vindlu_ret_mc import VindLURetrievalMC
from .vindlu_vqa import VindLUVQA
from .xbert import BertDecoder, BertModel

__all__ = [
'VindLUVQA', 'VindLURetrievalMC', 'VindLURetrieval', 'VindLUTokenizer',
'BeitModel3D', 'BertDecoder', 'BertModel'
]
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,5 @@ def build_inputs_with_special_tokens(
return cls + token_ids_0 + sep + token_ids_1 + sep


TOKENIZER.register_module(
'BertTokenizer', module=BertTokenizer.from_pretrained)
TOKENIZER.register_module(
'VindLUTokenizer', module=VindLUTokenizer.from_pretrained)
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,6 @@ def loss(self, inputs, data_samples):
weights = torch.cat(
[torch.tensor(sample.gt_answer_weight) for sample in data_samples],
dim=0).to(inputs.device)
# answers = [sample.gt_answer for sample in data_samples]
raw_answers = []
for sample in data_samples:
raw_answers.extend(sample.gt_answer)
Expand Down
File renamed without changes.
81 changes: 81 additions & 0 deletions mmaction/utils/dependency.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Copyright (c) OpenMMLab. All rights reserved.
import re
from functools import wraps
from inspect import isfunction

from importlib_metadata import PackageNotFoundError, distribution
from mmengine.utils import digit_version


def satisfy_requirement(dep):
pat = '(' + '|'.join(['>=', '==', '>']) + ')'
parts = re.split(pat, dep, maxsplit=1)
parts = [p.strip() for p in parts]
package = parts[0]
if len(parts) > 1:
op, version = parts[1:]
op = {
'>=': '__ge__',
'==': '__eq__',
'>': '__gt__',
'<': '__lt__',
'<=': '__le__'
}[op]
else:
op, version = None, None

try:
dist = distribution(package)
if op is None or getattr(digit_version(dist.version), op)(
digit_version(version)):
return True
except PackageNotFoundError:
pass

return False


def require(dep, install=None):
"""A wrapper of function for extra package requirements.
Args:
dep (str): The dependency package name, like ``transformers``
or ``transformers>=4.28.0``.
install (str, optional): The installation command hint. Defaults
to None, which means to use "pip install dep".
"""

def wrapper(fn):
assert isfunction(fn)

@wraps(fn)
def ask_install(*args, **kwargs):
name = fn.__qualname__.replace('.__init__', '')
ins = install or f'pip install "{dep}"'
raise ImportError(
f'{name} requires {dep}, please install it by `{ins}`.')

if satisfy_requirement(dep):
fn._verify_require = getattr(fn, '_verify_require', lambda: None)
return fn

ask_install._verify_require = ask_install
return ask_install

return wrapper


WITH_MULTIMODAL = all(
satisfy_requirement(item) for item in ['transformers>=4.28.0'])


def register_multimodal_placeholder(names, registry):
for name in names:

def ask_install(*args, **kwargs):
raise ImportError(
f'{name} requires extra multi-modal dependencies, please '
'install it by `pip install "mmaction2[multimodal]"` '
'or `pip install -e ".[multimodal]"`.')

registry.register_module(name=name, module=ask_install)
1 change: 1 addition & 0 deletions requirements/multimodal.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
transformers>=4.28.0
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,5 +191,6 @@ def add_mim_extension():
'tests': parse_requirements('requirements/tests.txt'),
'optional': parse_requirements('requirements/optional.txt'),
'mim': parse_requirements('requirements/mminstall.txt'),
'multimodal': parse_requirements('requirements/multimodal.txt'),
},
zip_safe=False)
1 change: 0 additions & 1 deletion tools/data/msrvtt/README_zh-CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
用户可参考该数据集的[官网](https://www.microsoft.com/en-us/research/publication/msr-vtt-a-large-video-description-dataset-for-bridging-video-and-language/),以获取数据集相关的基本信息。运行下面的命令准备 MSRVTT 视频文件:

```shell
cd $MMACTION2/tools/data/msrvtt/
# download original videos
bash download_msrvtt.sh
# preprocess videos to lower FPS and dimension
Expand Down
1 change: 0 additions & 1 deletion tools/data/msrvtt/compress_msrvtt.sh
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#!/usr/bin/env bash


FPS=2
SIZE=224
DATA_DIR="../../../data/msrvtt/videos"
Expand Down

0 comments on commit 3b7d2eb

Please sign in to comment.