import torch
from run_on_video.data_utils import ClipFeatureExtractor
from run_on_video.model_utils import build_inference_model
from utils.tensor_utils import pad_sequences_1d
from moment_detr.span_utils import span_cxw_to_xx
from utils.basic_utils import l2_normalize_np_array
import torch.nn.functional as F
import numpy as np
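
# Inference-only wrapper around a trained Moment-DETR checkpoint: CLIP video/text
# features are extracted with ClipFeatureExtractor, then the model predicts moment
# spans (in seconds) and per-clip saliency scores for each text query.
# Typically run from the repository root with the repo on PYTHONPATH so the
# `run_on_video`, `utils`, and `moment_detr` imports resolve, e.g.:
#   PYTHONPATH=$PYTHONPATH:. python run_on_video/run.py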
class MomentDETRPredictor:
    def __init__(self, ckpt_path, clip_model_name_or_path="ViT-B/32", device="cuda"):
        self.clip_len = 2  # seconds
        self.device = device
        print("Loading feature extractors...")
        self.feature_extractor = ClipFeatureExtractor(
            framerate=1 / self.clip_len, size=224, centercrop=True,
            model_name_or_path=clip_model_name_or_path, device=device
        )
        print("Loading trained Moment-DETR model...")
        self.model = build_inference_model(ckpt_path).to(self.device)
    @torch.no_grad()
    def localize_moment(self, video_path, query_list):
        """
        Args:
            video_path: str, path to the video file
            query_list: List[str], each str is a query for this video

        Returns:
            List[dict], one entry per query, with keys `query`, `vid`,
            `pred_relevant_windows` and `pred_saliency_scores`.
        """
        # construct model inputs
        n_query = len(query_list)
        video_feats = self.feature_extractor.encode_video(video_path)
        video_feats = F.normalize(video_feats, dim=-1, eps=1e-5)
        n_frames = len(video_feats)
        # add temporal endpoint features (tef): normalized [start, end] of each 2-sec clip
        tef_st = torch.arange(0, n_frames, 1.0) / n_frames
        tef_ed = tef_st + 1.0 / n_frames
        tef = torch.stack([tef_st, tef_ed], dim=1).to(self.device)  # (n_frames, 2)
        video_feats = torch.cat([video_feats, tef], dim=1)
        assert n_frames <= 75, "The positional embedding of this pretrained Moment-DETR only supports videos up " \
                               "to 150 secs (i.e., 75 2-sec clips) in length"
        video_feats = video_feats.unsqueeze(0).repeat(n_query, 1, 1)  # (#text, T, d)
        video_mask = torch.ones(n_query, n_frames).to(self.device)
        query_feats = self.feature_extractor.encode_text(query_list)  # #text * (L, d)
        query_feats, query_mask = pad_sequences_1d(
            query_feats, dtype=torch.float32, device=self.device, fixed_length=None)
        query_feats = F.normalize(query_feats, dim=-1, eps=1e-5)
        model_inputs = dict(
            src_vid=video_feats,
            src_vid_mask=video_mask,
            src_txt=query_feats,
            src_txt_mask=query_mask
        )
        # decode outputs
        outputs = self.model(**model_inputs)
        # #moment_queries refers to the positional embeddings in Moment-DETR's decoder, not the input text query
        prob = F.softmax(outputs["pred_logits"], -1)  # (batch_size, #moment_queries=10, #classes=2)
        scores = prob[..., 0]  # (batch_size, #moment_queries); foreground label is 0, so we take its probability directly
        pred_spans = outputs["pred_spans"]  # (bsz, #moment_queries, 2)
        _saliency_scores = outputs["saliency_scores"].half()  # (bsz, L)
        saliency_scores = []
        valid_vid_lengths = model_inputs["src_vid_mask"].sum(1).cpu().tolist()
        for j in range(len(valid_vid_lengths)):
            _score = _saliency_scores[j, :int(valid_vid_lengths[j])].tolist()
            _score = [round(e, 4) for e in _score]
            saliency_scores.append(_score)

        # compose predictions
        predictions = []
        video_duration = n_frames * self.clip_len
        for idx, (spans, score) in enumerate(zip(pred_spans.cpu(), scores.cpu())):
            spans = span_cxw_to_xx(spans) * video_duration
            # (#queries, 3), each row is [st(float), ed(float), score(float)], ranked by score
            cur_ranked_preds = torch.cat([spans, score[:, None]], dim=1).tolist()
            cur_ranked_preds = sorted(cur_ranked_preds, key=lambda x: x[2], reverse=True)
            cur_ranked_preds = [[float(f"{e:.4f}") for e in row] for row in cur_ranked_preds]
            cur_query_pred = dict(
                query=query_list[idx],  # str
                vid=video_path,
                pred_relevant_windows=cur_ranked_preds,  # List([st(float), ed(float), score(float)])
                pred_saliency_scores=saliency_scores[idx]  # List(float), len==n_frames, one score per 2-sec clip
            )
            predictions.append(cur_query_pred)
        return predictions
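

# Optional sketch (not part of the original script): persist the predictions
# returned by `localize_moment` to disk as JSONL, one JSON object per query.
# The default output path below is only an example; adjust as needed.
def save_predictions_jsonl(predictions, save_path="run_on_video/example/predictions.jsonl"):
    import json
    with open(save_path, "w") as f:
        for pred in predictions:
            f.write(json.dumps(pred) + "\n")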


def run_example():
    # load example data
    from utils.basic_utils import load_jsonl
    video_path = "run_on_video/example/RoripwjYFp8_60.0_210.0.mp4"
    query_path = "run_on_video/example/queries.jsonl"
    queries = load_jsonl(query_path)
    query_text_list = [e["query"] for e in queries]
    ckpt_path = "run_on_video/moment_detr_ckpt/model_best.ckpt"
    # build models and run predictions
    print("Build models...")
    clip_model_name_or_path = "ViT-B/32"
    # clip_model_name_or_path = "tmp/ViT-B-32.pt"
    moment_detr_predictor = MomentDETRPredictor(
        ckpt_path=ckpt_path,
        clip_model_name_or_path=clip_model_name_or_path,
        device="cuda"
    )
    print("Run prediction...")
    predictions = moment_detr_predictor.localize_moment(
        video_path=video_path, query_list=query_text_list)
    # print predictions alongside ground truth
    for idx, query_data in enumerate(queries):
        print("-" * 30 + f"idx{idx}")
        print(f">> query: {query_data['query']}")
        print(f">> video_path: {video_path}")
        print(f">> GT moments: {query_data['relevant_windows']}")
        print(f">> Predicted moments ([start_in_seconds, end_in_seconds, score]): "
              f"{predictions[idx]['pred_relevant_windows']}")
        print(f">> GT saliency scores (only for the GT-localized 2-sec clips): {query_data['saliency_scores']}")
        print(f">> Predicted saliency scores (for all 2-sec clips): "
              f"{predictions[idx]['pred_saliency_scores']}")


if __name__ == "__main__":
    run_example()