
Segmentation annotation: the parameters of `Sam.set_video_list` do not match the original SAM2 repository, and problems in `forward_sam_multi_stage` in sam_tools.py #2

@FelixxLuo

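Hello, and thank you for open-sourcing this. I ran into some errors while trying to use the segmentation annotation tool, and while tracking them down I found a few problems in the code that I would like to discuss here.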

1. Number of parameters of the `set_video_list` method in the `Sam` class

The call here passes four arguments (https://github.com/InternRobotics/RoboInter/blob/main/RoboInterTools/tools/sam.py#L96), but the original SAM2 repository expects only three (https://github.com/facebookresearch/sam2/blob/main/sam2/sam2_video_predictor.py#L42).

2. Problems in the `forward_sam_multi_stage` function in sam_tools.py

`forward_sam_multi_stage` in sam_tools.py calls `set_video_list` (https://github.com/InternRobotics/RoboInter/blob/main/RoboInterTools/tools/sam_tools.py#L97), so it cannot run successfully either.

3. Discussion and a possible solution (just my limited view, for reference only)

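`forward_sam_multi_stage` splits the video into segments depending on whether the direction is forward or backward, but the original SAM2 repository can only load a video by reading it from `video_path` (https://github.com/facebookresearch/sam2/blob/main/sam2/utils/misc.py#L172), so the segmented `video_part` cannot be passed to `sam_model` directly.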
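
To make the frame selection concrete, the sketch below mirrors the `is_video`/`direction` branches of `forward_sam_multi_stage` on frame indices instead of images. `select_frame_order` is a hypothetical helper written for illustration, not part of RoboInter. It shows that each of the three branches puts the prompted frame `select_frame` at position 0, which is presumably why the original code passes `0` as the prompt frame to `model_sam`.

```python
import numpy as np


def select_frame_order(num_frames, select_frame, is_video, direction):
    """Hypothetical helper mirroring the slicing in forward_sam_multi_stage.

    Returns the original frame indices in the order they would be fed to SAM2.
    """
    idx = np.arange(num_frames)
    if not is_video:
        # Single-image annotation: only the prompted frame
        return idx[select_frame:select_frame + 1]
    if direction == "forward":
        # Track from the prompted frame to the end of the video
        return idx[select_frame:]
    if direction == "backward":
        # Reverse so the prompted frame comes first and tracking runs back in time
        return idx[:select_frame + 1][::-1]
    # Any other direction: the original code leaves the frames untouched
    return idx


print(select_frame_order(6, 2, True, "forward"))   # [2 3 4 5]
print(select_frame_order(6, 2, True, "backward"))  # [2 1 0]
print(select_frame_order(6, 2, False, "forward"))  # [2]
```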

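My approach is therefore to modify `Sam.set_video_list` to drop the `video_list` parameter, and to modify `forward_sam_multi_stage` in sam_tools.py so that it writes each `video_part` to a temporary .mp4 file and then loads that file.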
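
The write-then-load idea can be sketched with `tempfile.TemporaryDirectory`, which deletes the cache directory only after every part has been processed, or as soon as an exception escapes. This is a minimal sketch under stated assumptions: `write_part` and `load_part` are hypothetical callables standing in for the `cv2.VideoWriter` loop and the `model_sam.set_video_list(...)` plus `model_sam(...)` step.

```python
import os
import tempfile


def run_parts_via_temp_files(parts, write_part, load_part, suffix=".mp4"):
    """Write each part to a temp file, pass its path to load_part, collect results.

    write_part(part, path) and load_part(path) are hypothetical stand-ins for
    the video writer and the SAM2 set_video_list/predict step.
    """
    results = []
    # The directory and all part files are removed when the with-block exits:
    # after the last part, or immediately if any part raises.
    with tempfile.TemporaryDirectory(suffix="_cache") as temp_dir:
        for i, part in enumerate(parts):
            part_path = os.path.join(temp_dir, f"part_{i}{suffix}")
            write_part(part, part_path)
            results.append(load_part(part_path))
    return results
```

Compared with a fixed `*_cache` directory next to the source video, a uniquely named temporary directory also avoids collisions if two jobs process the same video at the same time.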

Modify `Sam.set_video_list` to drop the `video_list` parameter:

```python
def set_video_list(self, video_path):
    # SAM2 loads the video from a path, so append the .mp4 extension if it is missing
    if video_path[-3:] != "mp4":
        video_path = video_path + ".mp4"
    self.inference_state = self.predictor.init_state(video_path, offload_video_to_cpu=True, offload_state_to_cpu=True)
```
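
The extension check in `set_video_list` compares the last three characters against `"mp4"`. A slightly stricter variant, sketched here as a hypothetical helper `ensure_mp4_suffix` (not part of RoboInter or SAM2), checks for the full `.mp4` suffix case-insensitively, so a name such as `clip_mp4` still gets the extension appended.

```python
def ensure_mp4_suffix(video_path: str) -> str:
    """Hypothetical helper: append ".mp4" unless the path already ends with it.

    Unlike video_path[-3:] != "mp4", this does not treat a name such as
    "clip_mp4" as already having the extension, and it accepts ".MP4".
    """
    if video_path.lower().endswith(".mp4"):
        return video_path
    return video_path + ".mp4"


print(ensure_mp4_suffix("cache/part_0"))      # cache/part_0.mp4
print(ensure_mp4_suffix("cache/part_0.mp4"))  # cache/part_0.mp4
print(ensure_mp4_suffix("clip_mp4"))          # clip_mp4.mp4
```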

Modify `forward_sam_multi_stage` in sam_tools.py to write each `video_part` to a temporary .mp4 file and then load that file (this requires `import os, shutil` at the top of the file):

```python
import os
import shutil

import cv2
import numpy as np

# extract_frames is the existing helper already used by sam_tools.py


def forward_sam_multi_stage(model_config, model_sam):
    video_path = model_config["video_path"]
    is_video = model_config["is_video"]
    select_frame = model_config["select_frame"]
    direction = model_config["direction"]

    # ---------------------- change begins ----------------------
    # Cache directory for the temporary per-part .mp4 files
    temp_save_dir = video_path.rsplit(".", 1)[0] + "_cache"
    os.makedirs(temp_save_dir, exist_ok=True)
    # ---------------------- change ends ------------------------

    if 'ann_human' in video_path:
        video_path = model_config['origin_video_path']
    video = extract_frames(video_path)

    # In each of these branches the prompted frame (select_frame) becomes frame 0
    if not is_video:
        video = video[select_frame:select_frame + 1]
    elif direction == "forward":
        video = video[select_frame:]
    elif direction == "backward":
        video = video[:select_frame + 1][::-1]

    positive_points_dict = model_config["positive_points"][select_frame]
    negative_points_dict = model_config["negative_points"][select_frame]
    labels_dict = model_config["labels"][select_frame]

    positive_points = [np.array(positive_points_dict[obj_idx]) for obj_idx in positive_points_dict.keys()]
    negative_points = [np.array(negative_points_dict[obj_idx]) for obj_idx in positive_points_dict.keys()]
    labels = [labels_dict[obj_idx] for obj_idx in positive_points_dict.keys()]

    for i in range(len(positive_points)):
        if len(positive_points[i]) == 0:
            raise ValueError("No positive points in the frame")
        if len(negative_points[i]) != 0:
            positive_points[i] = np.concatenate([positive_points[i], negative_points[i]], axis=0)

    # If the video is longer than 800 frames, split it into len(video) // 800 + 1 interleaved parts
    num_parts = int(len(video) / 800) + 1
    masks_all = np.zeros((len(positive_points), len(video), 1, video[0].shape[0], video[0].shape[1]))
    ind_all = np.arange(len(video))

    for i in range(num_parts):
        video_part = video[i::num_parts]
        ind_part = ind_all[i::num_parts]

        # Every part must contain the first (prompted) frame
        video_part = np.concatenate([video[:1], video_part], axis=0)
        ind_part = np.concatenate([np.zeros_like(ind_part[:1]), ind_part], axis=0)

        # ---------------------- change begins ----------------------
        # Write this part to a temporary .mp4 so SAM2 can load it from a path.
        # cv2.VideoWriter expects BGR uint8 frames; if extract_frames returns RGB,
        # convert each frame with cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) first.
        # Note that mp4v re-encoding is lossy.
        video_part_path = os.path.join(temp_save_dir, f"part_{i}.mp4")
        height, width, _ = video_part[0].shape
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(video_part_path, fourcc, 20, (width, height))
        for frame in video_part:
            out.write(frame)
        out.release()

        model_sam.set_video_list(video_part_path)
        masks_all[:, ind_part] = model_sam(positive_points, labels, 0, list(positive_points_dict.keys()))
        # ---------------------- change ends ------------------------

    # Remove the cache only after all parts are done: deleting it inside the loop
    # makes every part after the first fail to write its .mp4 file
    shutil.rmtree(temp_save_dir, ignore_errors=True)

    return masks_all
```
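
The part-splitting logic is easiest to check on indices alone. The sketch below uses a hypothetical helper `split_into_parts` (NumPy only, not part of RoboInter) that reproduces the `i::num_parts` interleaving with frame 0 prepended to every part, then scatters per-part results back by fancy indexing as `masks_all[:, ind_part]` does. A dummy "model" that returns each frame's own index stands in for SAM2, and `max_len=3` replaces 800 only to keep the output short.

```python
import numpy as np


def split_into_parts(num_frames, max_len=800):
    """Hypothetical helper mirroring the part split in forward_sam_multi_stage.

    Part i gets frames i, i + num_parts, i + 2 * num_parts, ...; frame 0 (the
    prompted frame) is prepended to every part so each run starts from the prompt.
    """
    num_parts = num_frames // max_len + 1
    ind_all = np.arange(num_frames)
    parts = []
    for i in range(num_parts):
        ind_part = ind_all[i::num_parts]
        parts.append(np.concatenate([ind_all[:1], ind_part]))
    return parts


num_frames = 7
parts = split_into_parts(num_frames, max_len=3)
print([p.tolist() for p in parts])  # [[0, 0, 3, 6], [0, 1, 4], [0, 2, 5]]

# Scatter back: the dummy "model" output for each part is just its frame indices
result = np.full(num_frames, -1)
for ind_part in parts:
    result[ind_part] = ind_part
print(result.tolist())  # [0, 1, 2, 3, 4, 5, 6]
```

Every frame is covered by exactly one part apart from frame 0, and within a part consecutive frames are `num_parts` apart in the original video.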
