# Copyright (c) Alibaba, Inc. and its affiliates.
import base64
import math
import os
import re
from io import BytesIO
from typing import Any, Callable, Dict, List, Optional, TypeVar, Union

import numpy as np
import requests
import torch
from PIL import Image

from swift.utils import get_env_args

# Try to import lmdb, but don't fail if it's not available
try:
    import lmdb
    LMDB_AVAILABLE = True
except ImportError:
    LMDB_AVAILABLE = False

# >>> internvl
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def _build_transform(input_size):
    import torchvision.transforms as T
    from torchvision.transforms.functional import InterpolationMode
    MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=MEAN, std=STD)
    ])
    return transform


def _find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    best_ratio_diff = float('inf')
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio


def _dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # calculate the existing image aspect ratio
    target_ratios = set((i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1)
                        if min_num <= i * j <= max_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # find the closest aspect ratio to the target
    target_aspect_ratio = _find_closest_aspect_ratio(aspect_ratio, target_ratios, orig_width, orig_height, image_size)

    # calculate the target width and height
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # resize the image
    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        box = ((i % (target_width // image_size)) * image_size,
               (i // (target_width // image_size)) * image_size,
               ((i % (target_width // image_size)) + 1) * image_size,
               ((i // (target_width // image_size)) + 1) * image_size)
        # split the image
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images

# <<< internvl
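
# Worked example for _dynamic_preprocess (illustrative comment, not executed):
# an 800x600 image with image_size=448 and max_num=12 has aspect ratio 4:3, so
# _find_closest_aspect_ratio selects (4, 3). The image is resized to 1792x1344
# and cropped into 4 * 3 = 12 tiles of 448x448; with use_thumbnail=True, a 13th
# 448x448 thumbnail of the whole image is appended.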


def rescale_image(img: Image.Image, max_pixels: int) -> Image.Image:
    import torchvision.transforms as T
    width = img.width
    height = img.height
    if max_pixels is None or max_pixels <= 0 or width * height <= max_pixels:
        return img

    ratio = width / height
    height_scaled = math.sqrt(max_pixels / ratio)
    width_scaled = height_scaled * ratio
    return T.Resize((int(height_scaled), int(width_scaled)))(img)
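
# Worked example for rescale_image (illustrative comment; numbers follow the
# formulas above): a 2000x1000 image with max_pixels=1_000_000 has ratio 2.0,
# so height_scaled = sqrt(1_000_000 / 2.0) ~= 707.1 and width_scaled ~= 1414.2,
# giving a 1414x707 result (~999k pixels), just under the budget.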


_T = TypeVar('_T')

# Cache for LMDB environments and read transactions to avoid reopening
_LMDB_ENV_CACHE: Dict[str, Any] = {}
_LMDB_TXN_CACHE: Dict[str, Any] = {}


def load_file(path: Union[str, bytes, _T]) -> Union[BytesIO, _T]:
    res = path
    if isinstance(path, str):
        path = path.strip()
        if path.startswith('http'):
            request_kwargs = {}
            timeout = float(os.getenv('TIMEOUT', '300'))
            if timeout > 0:
                request_kwargs['timeout'] = timeout
            content = requests.get(path, **request_kwargs).content
            res = BytesIO(content)
        elif path.startswith('lmdb://'):
            if not LMDB_AVAILABLE:
                raise ImportError("LMDB support requires the 'lmdb' package to be installed. "
                                  "Please install it with 'pip install lmdb'.")
            # Parse LMDB path format: lmdb://key@path_to_lmdb
            _, _, lmdb_url = path.partition('lmdb://')
            key, sep, lmdb_dir = lmdb_url.partition('@')
            # Verify format validity with a single check
            if not sep or not key or not lmdb_dir or '@' in lmdb_dir:
                raise ValueError("LMDB path must be in format: lmdb://key@path_to_lmdb (with exactly one '@')")
            # Use cached environment or create a new one
            env = _LMDB_ENV_CACHE.get(lmdb_dir)
            if env is None:
                env = lmdb.open(lmdb_dir, readonly=True, lock=False, max_readers=1024, max_spare_txns=2)
                _LMDB_ENV_CACHE[lmdb_dir] = env
            # Get or create read transaction
            txn = _LMDB_TXN_CACHE.get(lmdb_dir)
            if txn is None:
                txn = env.begin(write=False)
                _LMDB_TXN_CACHE[lmdb_dir] = txn
            # Get data using the cached transaction
            encoded_key = key.encode()
            data = txn.get(encoded_key)
            if data is None:
                raise KeyError(f"Key '{key}' not found in LMDB at '{lmdb_dir}'")
            res = BytesIO(data)
        elif os.path.exists(path) or (not path.startswith('data:') and len(path) <= 200):
            path = os.path.abspath(os.path.expanduser(path))
            with open(path, 'rb') as f:
                res = BytesIO(f.read())
        else:  # base64_str
            data = path
            if data.startswith('data:'):
                match_ = re.match(r'data:(.+?);base64,(.+)', data)
                assert match_ is not None
                data = match_.group(2)
            data = base64.b64decode(data)
            res = BytesIO(data)
    elif isinstance(path, bytes):
        res = BytesIO(path)
    return res
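
# Usage sketch for load_file; every input below is an illustrative placeholder:
#     load_file('https://example.com/cat.jpg')        # URL -> downloaded BytesIO
#     load_file('lmdb://img_0001@/data/images.lmdb')  # key 'img_0001' from LMDB
#     load_file('~/images/cat.jpg')                   # local file -> BytesIO
#     load_file('data:image/jpeg;base64,/9j/...')     # data URI -> decoded BytesIO
#     load_file(b'\x89PNG...')                        # raw bytes -> BytesIO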


def load_image(image: Union[str, bytes, Image.Image]) -> Image.Image:
    image = load_file(image)
    if isinstance(image, BytesIO):
        image = Image.open(image)
    if image.mode != 'RGB':
        image = image.convert('RGB')
    return image


def load_batch(path_list: List[Union[str, None, Any, BytesIO]],
               load_func: Callable[[Any], _T] = load_image) -> List[_T]:
    res = []
    assert isinstance(path_list, (list, tuple)), f'path_list: {path_list}'
    for path in path_list:
        if path is None:  # ignore None
            continue
        res.append(load_func(path))
    return res
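
# Usage sketch (paths are placeholders): None entries are skipped, so the result
# may be shorter than the input list:
#     images = load_batch(['cat.jpg', None, 'https://example.com/dog.png'])
#     # -> [<PIL.Image.Image>, <PIL.Image.Image>]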


def _get_index(bound, fps, max_frame, first_idx=0, num_segments=32):
    if bound:
        start, end = bound[0], bound[1]
    else:
        start, end = -100000, 100000
    start_idx = max(first_idx, round(start * fps))
    end_idx = min(round(end * fps), max_frame)
    seg_size = float(end_idx - start_idx) / num_segments
    frame_indices = np.array(
        [int(start_idx + (seg_size / 2) + np.round(seg_size * idx)) for idx in range(num_segments)])
    return frame_indices
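
# Worked example for _get_index (illustrative): with bound=(0, 8), fps=1.0,
# max_frame=100 and num_segments=4, start_idx=0, end_idx=8 and seg_size=2.0,
# so the sampled indices are [1, 3, 5, 7] -- the midpoint of each segment.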


def transform_image(image, input_size=448, max_num=12):
    transform = _build_transform(input_size=input_size)
    images = _dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
    pixel_values = [transform(image) for image in images]
    pixel_values = torch.stack(pixel_values)
    return pixel_values
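
# Usage sketch ('cat.jpg' is a placeholder path):
#     pixel_values = transform_image(load_image('cat.jpg'))
#     # -> float tensor of shape (num_tiles, 3, 448, 448), ImageNet-normalized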


def load_video_internvl(video: Union[str, bytes], bound=None, num_segments=32):
    from decord import VideoReader, cpu
    video_io = load_file(video)
    vr = VideoReader(video_io, ctx=cpu(0), num_threads=1)
    max_frame = len(vr) - 1
    fps = float(vr.get_avg_fps())
    images = []
    frame_indices = _get_index(bound, fps, max_frame, first_idx=0, num_segments=num_segments)
    for frame_index in frame_indices:
        images.append(Image.fromarray(vr[frame_index].asnumpy()).convert('RGB'))
    return images
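
# Usage sketch ('clip.mp4' is a placeholder; the torch.cat pattern is one common
# way to batch the per-frame tiles, not the only one):
#     frames = load_video_internvl('clip.mp4', num_segments=32)  # 32 PIL images
#     pixel_values = torch.cat([transform_image(f) for f in frames])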


def load_video_cogvlm2(video: Union[str, bytes]) -> torch.Tensor:
    from decord import cpu, VideoReader, bridge
    video_io = load_file(video)
    # with the 'torch' bridge, decord's get_batch returns torch tensors
    bridge.set_bridge('torch')
    clip_end_sec = 60
    clip_start_sec = 0
    num_frames = get_env_args('num_frames', int, 24)
    decord_vr = VideoReader(video_io, ctx=cpu(0))
    duration = len(decord_vr)  # duration in terms of frames
    start_frame = int(clip_start_sec * decord_vr.get_avg_fps())
    end_frame = min(duration, int(clip_end_sec * decord_vr.get_avg_fps())) if \
        clip_end_sec is not None else duration
    frame_id_list = np.linspace(start_frame, end_frame - 1, num_frames, dtype=int)
    video_data = decord_vr.get_batch(frame_id_list)
    video_data = video_data.permute(3, 0, 1, 2)  # (T, H, W, C) -> (C, T, H, W)
    return video_data


def load_video_llava(video: Union[str, bytes]) -> np.ndarray:
    import av
    video_io = load_file(video)
    container = av.open(video_io)
    total_frames = container.streams.video[0].frames
    num_frames = get_env_args('num_frames', int, 16)
    indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
    frames = []
    container.seek(0)
    start_index = indices[0]
    end_index = indices[-1]
    for i, frame in enumerate(container.decode(video=0)):
        if i > end_index:
            break
        if i >= start_index and i in indices:
            frames.append(frame)
    return np.stack([x.to_ndarray(format='rgb24') for x in frames])


def load_video_minicpmv_mplug_owl3(video: Union[str, bytes], max_num_frames):
    from decord import VideoReader, cpu  # pip install decord

    def uniform_sample(_l, _n):
        gap = len(_l) / _n
        idxs = [int(i * gap + gap / 2) for i in range(_n)]
        return [_l[i] for i in idxs]

    video_io = load_file(video)
    vr = VideoReader(video_io, ctx=cpu(0))
    sample_fps = round(vr.get_avg_fps() / 1)  # FPS
    frame_idx = [i for i in range(0, len(vr), sample_fps)]
    if len(frame_idx) > max_num_frames:
        frame_idx = uniform_sample(frame_idx, max_num_frames)
    frames = vr.get_batch(frame_idx).asnumpy()
    frames = [Image.fromarray(v.astype('uint8')) for v in frames]
    return frames


def load_audio(audio: Union[str, bytes], sampling_rate: int, return_sr: bool = False):
    import librosa
    audio_io = load_file(audio)
    res = librosa.load(audio_io, sr=sampling_rate)
    return res if return_sr else res[0]
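
# Usage sketch ('speech.wav' is a placeholder): resample to 16 kHz mono, and
# optionally recover the sampling rate as well:
#     waveform = load_audio('speech.wav', sampling_rate=16000)
#     waveform, sr = load_audio('speech.wav', sampling_rate=16000, return_sr=True)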


def load_video_valley(video: Union[str, bytes]):
    import decord
    from torchvision import transforms
    video_io = load_file(video)
    video_reader = decord.VideoReader(video_io)
    decord.bridge.set_bridge('torch')
    video = video_reader.get_batch(np.linspace(0, len(video_reader) - 1, 8).astype(np.int_)).byte()
    images = [transforms.ToPILImage()(image.permute(2, 0, 1)).convert('RGB') for image in video]
    return images


def load_video_ovis2(video_path, num_frames):
    from moviepy.editor import VideoFileClip
    with VideoFileClip(video_path) as clip:
        total_frames = int(clip.fps * clip.duration)
        if total_frames <= num_frames:
            sampled_indices = range(total_frames)
        else:
            stride = total_frames / num_frames
            sampled_indices = [
                min(total_frames - 1, int((stride * i + stride * (i + 1)) / 2)) for i in range(num_frames)
            ]
        frames = [clip.get_frame(index / clip.fps) for index in sampled_indices]
        frames = [Image.fromarray(frame, mode='RGB') for frame in frames]
    return frames
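
# Usage sketch ('clip.mp4' is a placeholder): sample at most 12 evenly spaced
# frames as RGB PIL images:
#     frames = load_video_ovis2('clip.mp4', num_frames=12)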