@inproceedings{doersch2023tapir,
title = {TAPIR: Tracking any point with per-frame initialization and temporal refinement},
author = {Doersch, Carl and Yang, Yi and Vecerik, Mel and Gokay, Dilara and Gupta, Ankush and Aytar, Yusuf and Carreira, Joao and Zisserman, Andrew},
booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision},
pages = {10061--10072},
year = {2023}
}
In the repository, the authors shared a number of Colab notebooks. In the video, I used the standard TAPIR model with its JAX implementation, and I ran it in a Colab TPU environment. I also forked the repository to add print statements that show the tensor shapes while the method executes. To keep this markdown file small, I’m sharing the Python code along with its output. You can copy and paste the code into a Jupyter notebook.
!pip install git+https://github.com/mashaan14/tapnet.git
MODEL_TYPE = 'tapir' # 'tapir' or 'bootstapir'
# -p creates the missing parent directory and does not fail if the path exists
%mkdir -p tapnet/checkpoints
if MODEL_TYPE == "tapir":
  !wget -P tapnet/checkpoints https://storage.googleapis.com/dm-tapnet/tapir_checkpoint_panning.npy
else:
  !wget -P tapnet/checkpoints https://storage.googleapis.com/dm-tapnet/bootstap/bootstapir_checkpoint_v2.npy
%ls tapnet/checkpoints
--2025-02-02 07:37:37-- https://storage.googleapis.com/dm-tapnet/tapir_checkpoint_panning.npy
Resolving storage.googleapis.com (storage.googleapis.com)... 142.251.183.207, 64.233.179.207, 142.251.184.207, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|142.251.183.207|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 124307770 (119M) [application/octet-stream]
Saving to: ‘tapnet/checkpoints/tapir_checkpoint_panning.npy’
tapir_checkpoint_pa 100%[===================>] 118.55M 174MB/s in 0.7s
2025-02-02 07:37:38 (174 MB/s) - ‘tapnet/checkpoints/tapir_checkpoint_panning.npy’ saved [124307770/124307770]
tapir_checkpoint_panning.npy
#%matplotlib widget
from google.colab import output
import jax
import matplotlib
import matplotlib.pyplot as plt
import mediapy as media
import numpy as np
from tapnet.models import tapir_model
from tapnet.utils import model_utils
from tapnet.utils import transforms
from tapnet.utils import viz_utils
output.enable_custom_widget_manager()
if MODEL_TYPE == 'tapir':
  checkpoint_path = 'tapnet/checkpoints/tapir_checkpoint_panning.npy'
else:
  checkpoint_path = 'tapnet/checkpoints/bootstapir_checkpoint_v2.npy'
ckpt_state = np.load(checkpoint_path, allow_pickle=True).item()
params, state = ckpt_state['params'], ckpt_state['state']
kwargs = dict(bilinear_interp_with_depthwise_conv=False, pyramid_level=0)
if MODEL_TYPE == 'bootstapir':
  kwargs.update(
      dict(pyramid_level=1, extra_convs=True, softmax_temperature=10.0)
  )
tapir = tapir_model.ParameterizedTAPIR(params, state, tapir_kwargs=kwargs)
%mkdir tapnet/examplar_videos
!wget -P tapnet/examplar_videos http://storage.googleapis.com/dm-tapnet/horsejump-high.mp4
video = media.read_video("tapnet/examplar_videos/horsejump-high.mp4")
media.show_video(video, fps=10)
--2025-02-02 07:38:39-- http://storage.googleapis.com/dm-tapnet/horsejump-high.mp4
Resolving storage.googleapis.com (storage.googleapis.com)... 142.251.183.207, 64.233.179.207, 142.251.184.207, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|142.251.183.207|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 467706 (457K) [video/mp4]
Saving to: ‘tapnet/examplar_videos/horsejump-high.mp4’
horsejump-high.mp4 100%[===================>] 456.74K --.-KB/s in 0.001s
2025-02-02 07:38:39 (316 MB/s) - ‘tapnet/examplar_videos/horsejump-high.mp4’ saved [467706/467706]
def inference(frames, query_points):
  """Inference on one video.

  Args:
    frames: [num_frames, height, width, 3], [0, 255], np.uint8
    query_points: [num_points, 3], [0, num_frames/height/width], [t, y, x]

  Returns:
    tracks: [num_points, num_frames, 2], [0, width/height], [x, y]
    visibles: [num_points, num_frames], bool
  """
  # Preprocess video to match model inputs format
  frames = model_utils.preprocess_frames(frames)
  query_points = query_points.astype(np.float32)
  frames, query_points = frames[None], query_points[None]  # Add batch dimension
  # Under jax.jit these prints run when the function is traced
  print(f'frames: {frames.shape}')
  print(f'query_points: {query_points.shape}')
  outputs = tapir(
      video=frames,
      query_points=query_points,
      is_training=False,
      query_chunk_size=32,
  )
  tracks, occlusions, expected_dist = (
      outputs['tracks'],
      outputs['occlusion'],
      outputs['expected_dist'],
  )
  # Binarize occlusions
  visibles = model_utils.postprocess_occlusions(occlusions, expected_dist)
  # Drop the batch dimension
  return tracks[0], visibles[0]

inference = jax.jit(inference)
def sample_random_points(frame_max_idx, height, width, num_points):
  """Sample random points with (time, height, width) order."""
  y = np.random.randint(0, height, (num_points, 1))
  x = np.random.randint(0, width, (num_points, 1))
  t = np.random.randint(0, frame_max_idx + 1, (num_points, 1))
  points = np.concatenate((t, y, x), axis=-1).astype(
      np.int32
  )  # [num_points, 3]
  return points
resize_height = 256 # @param {type: "integer"}
resize_width = 256 # @param {type: "integer"}
num_points = 100 # @param {type: "integer"}
frames = media.resize_video(video, (resize_height, resize_width))
query_points = sample_random_points(
    0, frames.shape[1], frames.shape[2], num_points
)
tracks, visibles = inference(frames, query_points)
tracks = np.array(tracks)
visibles = np.array(visibles)
# Visualize sparse point tracks
height, width = video.shape[1:3]
tracks = transforms.convert_grid_coordinates(
    tracks, (resize_width, resize_height), (width, height)
)
video_viz = viz_utils.paint_point_track(video, tracks, visibles)
media.show_video(video_viz, fps=10)
Here’s the output that I got, after adding print statements to my fork:
frames: (1, 50, 256, 256, 3)
query_points: (1, 100, 3)
***********************************************
get_feature_grids function
***********************************************
refinement_resolutions
256
256
feature_grid: (1, 50, 32, 32, 256)
feature_grid: (1, 50, 32, 32, 256)
hires_feats: (1, 50, 64, 64, 128)
hires_feats: (1, 50, 64, 64, 128)
-----------------------------------------------
***********************************************
get_query_features function
***********************************************
hires_query_feats: (1, 100, 128)
hires_query_feats: (1, 100, 128)
query_feats: (1, 100, 256)
query_feats: (1, 100, 256)
-----------------------------------------------
***********************************************
estimate_trajectories function
***********************************************
***********************************************
estimate_trajectories-->tracks_from_cost_volume functions
points: (1, 32, 50, 2)
occlusion: (1, 32, 50)
expected_dist: (1, 32, 50)
-----------------------------------------------
***********************************************
estimate_trajectories-->refine_pips functions
pos_guess_input: (1, 32, 50, 2)
occ_guess_input: (1, 32, 50, 1)
expd_guess_input: (1, 32, 50, 1)
mlp_input_features: (1, 32, 50, 384)
corrs_chunked: (1, 32, 50, 98)
-----------------------------------------------
***********************************************
estimate_trajectories-->refine_pips functions
pos_guess_input: (1, 32, 50, 2)
occ_guess_input: (1, 32, 50, 1)
expd_guess_input: (1, 32, 50, 1)
mlp_input_features: (1, 32, 50, 384)
corrs_chunked: (1, 32, 50, 98)
-----------------------------------------------
***********************************************
estimate_trajectories-->refine_pips functions
pos_guess_input: (1, 32, 50, 2)
occ_guess_input: (1, 32, 50, 1)
expd_guess_input: (1, 32, 50, 1)
mlp_input_features: (1, 32, 50, 384)
corrs_chunked: (1, 32, 50, 98)
-----------------------------------------------
***********************************************
estimate_trajectories-->refine_pips functions
pos_guess_input: (1, 32, 50, 2)
occ_guess_input: (1, 32, 50, 1)
expd_guess_input: (1, 32, 50, 1)
mlp_input_features: (1, 32, 50, 384)
corrs_chunked: (1, 32, 50, 98)
-----------------------------------------------
***********************************************
estimate_trajectories-->tracks_from_cost_volume functions
points: (1, 32, 50, 2)
occlusion: (1, 32, 50)
expected_dist: (1, 32, 50)
-----------------------------------------------
***********************************************
estimate_trajectories-->refine_pips functions
pos_guess_input: (1, 32, 50, 2)
occ_guess_input: (1, 32, 50, 1)
expd_guess_input: (1, 32, 50, 1)
mlp_input_features: (1, 32, 50, 384)
corrs_chunked: (1, 32, 50, 98)
-----------------------------------------------
***********************************************
estimate_trajectories-->refine_pips functions
pos_guess_input: (1, 32, 50, 2)
occ_guess_input: (1, 32, 50, 1)
expd_guess_input: (1, 32, 50, 1)
mlp_input_features: (1, 32, 50, 384)
corrs_chunked: (1, 32, 50, 98)
-----------------------------------------------
***********************************************
estimate_trajectories-->refine_pips functions
pos_guess_input: (1, 32, 50, 2)
occ_guess_input: (1, 32, 50, 1)
expd_guess_input: (1, 32, 50, 1)
mlp_input_features: (1, 32, 50, 384)
corrs_chunked: (1, 32, 50, 98)
-----------------------------------------------
***********************************************
estimate_trajectories-->refine_pips functions
pos_guess_input: (1, 32, 50, 2)
occ_guess_input: (1, 32, 50, 1)
expd_guess_input: (1, 32, 50, 1)
mlp_input_features: (1, 32, 50, 384)
corrs_chunked: (1, 32, 50, 98)
-----------------------------------------------
***********************************************
estimate_trajectories-->tracks_from_cost_volume functions
points: (1, 32, 50, 2)
occlusion: (1, 32, 50)
expected_dist: (1, 32, 50)
-----------------------------------------------
***********************************************
estimate_trajectories-->refine_pips functions
pos_guess_input: (1, 32, 50, 2)
occ_guess_input: (1, 32, 50, 1)
expd_guess_input: (1, 32, 50, 1)
mlp_input_features: (1, 32, 50, 384)
corrs_chunked: (1, 32, 50, 98)
-----------------------------------------------
***********************************************
estimate_trajectories-->refine_pips functions
pos_guess_input: (1, 32, 50, 2)
occ_guess_input: (1, 32, 50, 1)
expd_guess_input: (1, 32, 50, 1)
mlp_input_features: (1, 32, 50, 384)
corrs_chunked: (1, 32, 50, 98)
-----------------------------------------------
***********************************************
estimate_trajectories-->refine_pips functions
pos_guess_input: (1, 32, 50, 2)
occ_guess_input: (1, 32, 50, 1)
expd_guess_input: (1, 32, 50, 1)
mlp_input_features: (1, 32, 50, 384)
corrs_chunked: (1, 32, 50, 98)
-----------------------------------------------
***********************************************
estimate_trajectories-->refine_pips functions
pos_guess_input: (1, 32, 50, 2)
occ_guess_input: (1, 32, 50, 1)
expd_guess_input: (1, 32, 50, 1)
mlp_input_features: (1, 32, 50, 384)
corrs_chunked: (1, 32, 50, 98)
-----------------------------------------------
***********************************************
estimate_trajectories-->tracks_from_cost_volume functions
points: (1, 4, 50, 2)
occlusion: (1, 4, 50)
expected_dist: (1, 4, 50)
-----------------------------------------------
***********************************************
estimate_trajectories-->refine_pips functions
pos_guess_input: (1, 4, 50, 2)
occ_guess_input: (1, 4, 50, 1)
expd_guess_input: (1, 4, 50, 1)
mlp_input_features: (1, 4, 50, 384)
corrs_chunked: (1, 4, 50, 98)
-----------------------------------------------
***********************************************
estimate_trajectories-->refine_pips functions
pos_guess_input: (1, 4, 50, 2)
occ_guess_input: (1, 4, 50, 1)
expd_guess_input: (1, 4, 50, 1)
mlp_input_features: (1, 4, 50, 384)
corrs_chunked: (1, 4, 50, 98)
-----------------------------------------------
***********************************************
estimate_trajectories-->refine_pips functions
pos_guess_input: (1, 4, 50, 2)
occ_guess_input: (1, 4, 50, 1)
expd_guess_input: (1, 4, 50, 1)
mlp_input_features: (1, 4, 50, 384)
corrs_chunked: (1, 4, 50, 98)
-----------------------------------------------
***********************************************
estimate_trajectories-->refine_pips functions
pos_guess_input: (1, 4, 50, 2)
occ_guess_input: (1, 4, 50, 1)
expd_guess_input: (1, 4, 50, 1)
mlp_input_features: (1, 4, 50, 384)
corrs_chunked: (1, 4, 50, 98)
-----------------------------------------------
You might notice that `estimate_trajectories` ran four times, with $32$, $32$, $32$, and $4$ points. That’s because we passed `query_chunk_size=32` to the method, so it split the $100$ query points into three chunks of $32$ and a final chunk with the remaining $4$. These notes explain a single run of the `estimate_trajectories` function; the other chunks go through the same steps.
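
To make the chunk sizes concrete, here is a minimal sketch of the arithmetic behind the split. It is not the code from the tapnet repository: `split_into_chunks` is a hypothetical helper I wrote for illustration, and it only assumes that the query points are processed in consecutive slices of at most `query_chunk_size` points.

```python
def split_into_chunks(num_points, chunk_size):
  """Return (start, stop) index pairs that cover range(num_points) in order.

  Hypothetical helper for illustration only; it mirrors the chunk sizes seen
  in the log, not the actual tapnet implementation.
  """
  return [
      (start, min(start + chunk_size, num_points))
      for start in range(0, num_points, chunk_size)
  ]


# 100 query points with query_chunk_size=32, as in the run above
bounds = split_into_chunks(num_points=100, chunk_size=32)
sizes = [stop - start for start, stop in bounds]

print(bounds)  # [(0, 32), (32, 64), (64, 96), (96, 100)]
print(sizes)   # [32, 32, 32, 4]
```

The four sizes match the second dimension of the shapes in the log: `(1, 32, 50, 2)` for the first three chunks and `(1, 4, 50, 2)` for the last one.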