llama3_transform.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Mapping, Optional, Tuple

from torchtitan.datasets.tokenizer import Tokenizer

from ..multimodal import CLIPPreprocess, VisionCrossAttentionMask

# NOTE Inspired from torchtune.models.llama3_2_vision._transform.py

class Llama3VisionTransform:
    """
    This class combines the transforms for the different modalities of Llama 3.2 Vision. It
    performs the following transforms:
    - Tokenizing the text field using :class:`torchtitan.datasets.tokenizer.tiktoken.TikTokenizer`
    - Preprocessing the images for the CLIP encoder using :class:`torchtitan.datasets.multimodal.clip.CLIPPreprocess`
    - Generating the vision cross-attention mask for the fused cross-attention layers
      using :class:`torchtitan.datasets.multimodal.utils.VisionCrossAttentionMask`

    Args:
        tokenizer (Tokenizer):
            Tokenizer used to encode data. The tokenizer must implement `encode` and `decode` methods.
        tile_size (int): Size of the tiles to divide the image into.
        patch_size (int): Size of the patches used in the CLIP vision transformer model. This is
            used to calculate the number of image embeddings per image.
        max_num_tiles (int): Only used if possible_resolutions is NOT given.
            Maximum number of tiles to break an image into.
            This will be used to generate possible_resolutions,
            e.g. [(224, 224), (224, 448), (448, 224)] if max_num_tiles = 2 and tile_size = 224.
            Default 4.
        image_mean (Optional[Tuple[float, float, float]]): Mean values of each channel, used for normalization.
        image_std (Optional[Tuple[float, float, float]]): Standard deviations of each channel, used for normalization.

    Examples:
        >>> model_transform = Llama3VisionTransform(tokenizer, tile_size=224, patch_size=14)
        >>> transformed_data = model_transform({"messages": user_message, "images": [img1, img2]})
        >>> print(transformed_data["tokens"])
        [1, 31587, 29644, 102, 2]
        >>> print(transformed_data["images"][0].shape)
        torch.Size([4, 3, 224, 224])
    """

    def __init__(
        self,
        tokenizer: Tokenizer,
        tile_size: int,
        patch_size: int,
        max_num_tiles: int = 4,
        image_mean: Optional[Tuple[float, float, float]] = None,
        image_std: Optional[Tuple[float, float, float]] = None,
    ):
        self.tokenizer = tokenizer
        self.transform_image = CLIPPreprocess(
            image_mean=image_mean,
            image_std=image_std,
            tile_size=tile_size,
            possible_resolutions=None,
            max_num_tiles=max_num_tiles,
            resample="bilinear",
            resize_to_max_canvas=False,
        )
        self.xattn_mask = VisionCrossAttentionMask(
            tile_size=tile_size,
            patch_size=patch_size,
            image_token_id=128256,  # TODO(tj.solergibert) Hardcoded?
            max_num_tiles=max_num_tiles,
        )
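        # NOTE(assumption) image_seq_len is the number of image embeddings produced
        # per image: each tile yields patches_per_tile patch embeddings plus,
        # presumably, one CLS token from the CLIP encoder, hence the "+ 1".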
        self.image_seq_len = max_num_tiles * (self.xattn_mask.patches_per_tile + 1)
        # TODO(tj.solergibert) self.pad_id = self.tokenizer.pad_id

    def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
        """
        Apply image decoding, transformations, and tokenization to the messages in the sample.

        Args:
            sample (Mapping[str, Any]): A sample with a "messages" field.

        Returns:
            Mapping[str, Any]: The transformed sample with the following fields:
                - tokens: List[int] of tokenized messages
                - encoder_input: Dict[str, Any] of transformed images
                - encoder_mask: List[bool] of masks for the transformed images
        """
        encoder_input = {"images": [], "aspect_ratio": []}
        for image in sample["images"]:
            out = self.transform_image({"image": image})
            encoder_input["images"].append(out["image"])
            encoder_input["aspect_ratio"].append(out["aspect_ratio"])
        sample["encoder_input"] = encoder_input
        sample = self.tokenizer.encode_multimodal(sample)
        # TODO(tj.solergibert) What should we do (Include y/n & Mask y/n) with both bos & eos
        # TODO(tj.solergibert) allowed_special to this fancy set OR set it to "all"?
        sample = self.xattn_mask(sample)
        return sample