Mashaan blog

Point Tracking using TAPIR


Reference: Doersch, Carl; Yang, Yi; Vecerik, Mel; Gokay, Dilara; Gupta, Ankush; Aytar, Yusuf; Carreira, Joao; Zisserman, Andrew. "TAPIR: Tracking any point with per-frame initialization and temporal refinement." Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 10061-10072, 2023.

Running TAPIR notebook

The authors shared several Colab notebooks in the repository. In the video, I used the standard TAPIR notebook with the JAX implementation and ran it in a Colab TPU environment. I also forked the repository and added print statements to check the tensor shapes while the method executes. To keep this markdown file small, I’m sharing the Python code together with its output. You can copy and paste the code into a Jupyter notebook.

Install dependencies:

!pip install git+
MODEL_TYPE = 'tapir'  # 'tapir' or 'bootstapir'

Download model

%mkdir -p tapnet/checkpoints

if MODEL_TYPE == "tapir":
  !wget -P tapnet/checkpoints
else:
  !wget -P tapnet/checkpoints

%ls tapnet/checkpoints
--2025-02-02 07:37:37--
Resolving (,,, ...
Connecting to (||:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 124307770 (119M) [application/octet-stream]
Saving to: ‘tapnet/checkpoints/tapir_checkpoint_panning.npy’

tapir_checkpoint_pa 100%[===================>] 118.55M   174MB/s    in 0.7s    

2025-02-02 07:37:38 (174 MB/s) - ‘tapnet/checkpoints/tapir_checkpoint_panning.npy’ saved [124307770/124307770]



#%matplotlib widget
from google.colab import output
import jax
import matplotlib
import matplotlib.pyplot as plt
import mediapy as media
import numpy as np
from tapnet.models import tapir_model
from tapnet.utils import model_utils
from tapnet.utils import transforms
from tapnet.utils import viz_utils


Load checkpoint

if MODEL_TYPE == 'tapir':
  checkpoint_path = 'tapnet/checkpoints/tapir_checkpoint_panning.npy'
else:
  checkpoint_path = 'tapnet/checkpoints/bootstapir_checkpoint_v2.npy'
ckpt_state = np.load(checkpoint_path, allow_pickle=True).item()
params, state = ckpt_state['params'], ckpt_state['state']

kwargs = dict(bilinear_interp_with_depthwise_conv=False, pyramid_level=0)
if MODEL_TYPE == 'bootstapir':
  # BootsTAPIR overrides some of the default TAPIR settings
  kwargs.update(
      dict(pyramid_level=1, extra_convs=True, softmax_temperature=10.0)
  )
tapir = tapir_model.ParameterizedTAPIR(params, state, tapir_kwargs=kwargs)

Load the video

%mkdir tapnet/examplar_videos

!wget -P tapnet/examplar_videos

video = media.read_video("tapnet/examplar_videos/horsejump-high.mp4")
media.show_video(video, fps=10)
--2025-02-02 07:38:39--
Resolving (,,, ...
Connecting to (||:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 467706 (457K) [video/mp4]
Saving to: ‘tapnet/examplar_videos/horsejump-high.mp4’

horsejump-high.mp4  100%[===================>] 456.74K  --.-KB/s    in 0.001s  

2025-02-02 07:38:39 (316 MB/s) - ‘tapnet/examplar_videos/horsejump-high.mp4’ saved [467706/467706]

Utility functions

def inference(frames, query_points):
  """Inference on one video.

  Args:
    frames: [num_frames, height, width, 3], [0, 255], np.uint8
    query_points: [num_points, 3], [0, num_frames/height/width], [t, y, x]

  Returns:
    tracks: [num_points, num_frames, 2], [x, y] in pixels of the input frames
    visibles: [num_points, num_frames], bool
  """
  # Preprocess video to match model inputs format
  frames = model_utils.preprocess_frames(frames)
  query_points = query_points.astype(np.float32)
  frames, query_points = frames[None], query_points[None]  # Add batch dimension

  # Under jax.jit these prints run once, while the function is traced
  print(f'frames: {frames.shape}')
  print(f'query_points: {query_points.shape}')
  # Query points are processed in chunks of 32 (see the shapes printed below)
  outputs = tapir(
      video=frames,
      query_points=query_points,
      is_training=False,
      query_chunk_size=32,
  )
  tracks, occlusions, expected_dist = (
      outputs['tracks'],
      outputs['occlusion'],
      outputs['expected_dist'],
  )

  # Binarize occlusions
  visibles = model_utils.postprocess_occlusions(occlusions, expected_dist)
  return tracks[0], visibles[0]

inference = jax.jit(inference)

def sample_random_points(frame_max_idx, height, width, num_points):
  """Sample random points with (time, height, width) order."""
  y = np.random.randint(0, height, (num_points, 1))
  x = np.random.randint(0, width, (num_points, 1))
  t = np.random.randint(0, frame_max_idx + 1, (num_points, 1))
  points = np.concatenate((t, y, x), axis=-1).astype(
      np.int32
  )  # [num_points, 3]
  return points
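
The "Binarize occlusions" step in inference turns two per-point, per-frame outputs (occlusion and expected_dist) into a single boolean visibility mask. As far as I understand it, both are logits, and a point counts as visible when it is likely unoccluded and the position estimate is likely accurate. The NumPy sketch below paraphrases that idea under this assumption. It is not the library code, and binarize_visibility and its threshold argument are hypothetical names that I made up for illustration.

```python
import numpy as np


def sigmoid(x):
    """Logistic function, maps logits to probabilities in (0, 1)."""
    return 1.0 / (1.0 + np.exp(-x))


def binarize_visibility(occlusion_logits, expected_dist_logits, threshold=0.5):
    """Hypothetical helper: combine two logit maps into a boolean mask.

    Assumption: a point is visible when P(not occluded) * P(prediction is
    accurate) exceeds the threshold.
    """
    p_not_occluded = 1.0 - sigmoid(occlusion_logits)
    p_accurate = 1.0 - sigmoid(expected_dist_logits)
    return p_not_occluded * p_accurate > threshold


# Toy input shaped [num_points, num_frames] = [1, 3]
occlusion = np.array([[-5.0, 5.0, 0.0]])
expected_dist = np.array([[-5.0, -5.0, 0.0]])
visibles = binarize_visibility(occlusion, expected_dist)
print(visibles)
```

Only the first frame is marked visible here: the second is confidently occluded, and in the third both probabilities are 0.5, so their product (0.25) falls below the threshold.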

Predict sparse point tracks

resize_height = 256  # @param {type: "integer"}
resize_width = 256  # @param {type: "integer"}
num_points = 100  # @param {type: "integer"}

frames = media.resize_video(video, (resize_height, resize_width))
query_points = sample_random_points(
    0, frames.shape[1], frames.shape[2], num_points
)
tracks, visibles = inference(frames, query_points)
tracks = np.array(tracks)
visibles = np.array(visibles)

# Visualize sparse point tracks
height, width = video.shape[1:3]
tracks = transforms.convert_grid_coordinates(
    tracks, (resize_width, resize_height), (width, height)
)
video_viz = viz_utils.paint_point_track(video, tracks, visibles)
media.show_video(video_viz, fps=10)
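
The model runs on the resized 256 x 256 frames, so the predicted tracks have to be mapped back to the original resolution before they are painted on the video. That is what the convert_grid_coordinates call does. Here is a minimal NumPy sketch of the idea, assuming the tracks are stored as (x, y) and the sizes are given as (width, height). The helper rescale_xy and the 854 x 480 target size are hypothetical and only serve the illustration. This is not the library's implementation.

```python
import numpy as np


def rescale_xy(tracks, from_size, to_size):
    """Hypothetical helper: scale (x, y) points between two grids.

    tracks: array with a trailing dimension of 2, ordered (x, y).
    from_size, to_size: (width, height) of the source and target grids.
    """
    scale = np.asarray(to_size, dtype=np.float32) / np.asarray(
        from_size, dtype=np.float32
    )
    return tracks * scale  # broadcasts over the trailing (x, y) axis


# One point tracked over two frames, in 256 x 256 coordinates
tracks_small = np.array([[[128.0, 64.0], [256.0, 256.0]]], dtype=np.float32)
tracks_full = rescale_xy(tracks_small, (256, 256), (854, 480))
print(tracks_full)
```

The x coordinate is scaled by 854 / 256 and the y coordinate by 480 / 256, so (128, 64) becomes (427, 120).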

Here’s the output that I got, after adding print statements to my fork:

frames: (1, 50, 256, 256, 3)
query_points: (1, 100, 3)
get_feature_grids function
feature_grid: (1, 50, 32, 32, 256)
feature_grid: (1, 50, 32, 32, 256)
hires_feats: (1, 50, 64, 64, 128)
hires_feats: (1, 50, 64, 64, 128)
get_query_features function
hires_query_feats: (1, 100, 128)
hires_query_feats: (1, 100, 128)
query_feats: (1, 100, 256)
query_feats: (1, 100, 256)
estimate_trajectories function
estimate_trajectories-->tracks_from_cost_volume  functions
points: (1, 32, 50, 2)
occlusion: (1, 32, 50)
expected_dist: (1, 32, 50)
estimate_trajectories-->refine_pips  functions
pos_guess_input: (1, 32, 50, 2)
occ_guess_input: (1, 32, 50, 1)
expd_guess_input: (1, 32, 50, 1)
mlp_input_features: (1, 32, 50, 384)
corrs_chunked: (1, 32, 50, 98)
estimate_trajectories-->refine_pips  functions
pos_guess_input: (1, 32, 50, 2)
occ_guess_input: (1, 32, 50, 1)
expd_guess_input: (1, 32, 50, 1)
mlp_input_features: (1, 32, 50, 384)
corrs_chunked: (1, 32, 50, 98)
estimate_trajectories-->refine_pips  functions
pos_guess_input: (1, 32, 50, 2)
occ_guess_input: (1, 32, 50, 1)
expd_guess_input: (1, 32, 50, 1)
mlp_input_features: (1, 32, 50, 384)
corrs_chunked: (1, 32, 50, 98)
estimate_trajectories-->refine_pips  functions
pos_guess_input: (1, 32, 50, 2)
occ_guess_input: (1, 32, 50, 1)
expd_guess_input: (1, 32, 50, 1)
mlp_input_features: (1, 32, 50, 384)
corrs_chunked: (1, 32, 50, 98)
estimate_trajectories-->tracks_from_cost_volume  functions
points: (1, 32, 50, 2)
occlusion: (1, 32, 50)
expected_dist: (1, 32, 50)
estimate_trajectories-->refine_pips  functions
pos_guess_input: (1, 32, 50, 2)
occ_guess_input: (1, 32, 50, 1)
expd_guess_input: (1, 32, 50, 1)
mlp_input_features: (1, 32, 50, 384)
corrs_chunked: (1, 32, 50, 98)
estimate_trajectories-->refine_pips  functions
pos_guess_input: (1, 32, 50, 2)
occ_guess_input: (1, 32, 50, 1)
expd_guess_input: (1, 32, 50, 1)
mlp_input_features: (1, 32, 50, 384)
corrs_chunked: (1, 32, 50, 98)
estimate_trajectories-->refine_pips  functions
pos_guess_input: (1, 32, 50, 2)
occ_guess_input: (1, 32, 50, 1)
expd_guess_input: (1, 32, 50, 1)
mlp_input_features: (1, 32, 50, 384)
corrs_chunked: (1, 32, 50, 98)
estimate_trajectories-->refine_pips  functions
pos_guess_input: (1, 32, 50, 2)
occ_guess_input: (1, 32, 50, 1)
expd_guess_input: (1, 32, 50, 1)
mlp_input_features: (1, 32, 50, 384)
corrs_chunked: (1, 32, 50, 98)
estimate_trajectories-->tracks_from_cost_volume  functions
points: (1, 32, 50, 2)
occlusion: (1, 32, 50)
expected_dist: (1, 32, 50)
estimate_trajectories-->refine_pips  functions
pos_guess_input: (1, 32, 50, 2)
occ_guess_input: (1, 32, 50, 1)
expd_guess_input: (1, 32, 50, 1)
mlp_input_features: (1, 32, 50, 384)
corrs_chunked: (1, 32, 50, 98)
estimate_trajectories-->refine_pips  functions
pos_guess_input: (1, 32, 50, 2)
occ_guess_input: (1, 32, 50, 1)
expd_guess_input: (1, 32, 50, 1)
mlp_input_features: (1, 32, 50, 384)
corrs_chunked: (1, 32, 50, 98)
estimate_trajectories-->refine_pips  functions
pos_guess_input: (1, 32, 50, 2)
occ_guess_input: (1, 32, 50, 1)
expd_guess_input: (1, 32, 50, 1)
mlp_input_features: (1, 32, 50, 384)
corrs_chunked: (1, 32, 50, 98)
estimate_trajectories-->refine_pips  functions
pos_guess_input: (1, 32, 50, 2)
occ_guess_input: (1, 32, 50, 1)
expd_guess_input: (1, 32, 50, 1)
mlp_input_features: (1, 32, 50, 384)
corrs_chunked: (1, 32, 50, 98)
estimate_trajectories-->tracks_from_cost_volume  functions
points: (1, 4, 50, 2)
occlusion: (1, 4, 50)
expected_dist: (1, 4, 50)
estimate_trajectories-->refine_pips  functions
pos_guess_input: (1, 4, 50, 2)
occ_guess_input: (1, 4, 50, 1)
expd_guess_input: (1, 4, 50, 1)
mlp_input_features: (1, 4, 50, 384)
corrs_chunked: (1, 4, 50, 98)
estimate_trajectories-->refine_pips  functions
pos_guess_input: (1, 4, 50, 2)
occ_guess_input: (1, 4, 50, 1)
expd_guess_input: (1, 4, 50, 1)
mlp_input_features: (1, 4, 50, 384)
corrs_chunked: (1, 4, 50, 98)
estimate_trajectories-->refine_pips  functions
pos_guess_input: (1, 4, 50, 2)
occ_guess_input: (1, 4, 50, 1)
expd_guess_input: (1, 4, 50, 1)
mlp_input_features: (1, 4, 50, 384)
corrs_chunked: (1, 4, 50, 98)
estimate_trajectories-->refine_pips  functions
pos_guess_input: (1, 4, 50, 2)
occ_guess_input: (1, 4, 50, 1)
expd_guess_input: (1, 4, 50, 1)
mlp_input_features: (1, 4, 50, 384)
corrs_chunked: (1, 4, 50, 98)

Understanding tensor shapes

You might notice that estimate_trajectories ran four times, with $32$, $32$, $32$, and $4$ points. That’s because we passed query_chunk_size=32 to the model, so the $100$ query points were split into four chunks: three of size $32$ and a final chunk holding the remaining $4$. These notes walk through a single run of the estimate_trajectories function; the other chunks are processed the same way.
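
To make the chunk sizes concrete, here is a small sketch that splits a dummy array of $100$ query points with a chunk size of $32$. It only illustrates the arithmetic behind the printed shapes. The helper chunk_sizes is a hypothetical name, and the slicing is my assumption about how such a split can be done, not a copy of the TAPIR code.

```python
import numpy as np


def chunk_sizes(num_points, chunk_size):
    """Hypothetical helper: number of query points in each chunk."""
    return [
        min(chunk_size, num_points - start)
        for start in range(0, num_points, chunk_size)
    ]


sizes = chunk_sizes(100, 32)
print(sizes)  # [32, 32, 32, 4]

# Same split applied to a dummy [batch, num_points, 3] query array
query_points = np.zeros((1, 100, 3), dtype=np.float32)
chunks = [query_points[:, start:start + 32] for start in range(0, 100, 32)]
print([chunk.shape for chunk in chunks])
```

The second print shows the per-chunk shapes (1, 32, 3) three times and then (1, 4, 3), which matches the point dimension in the logged shapes above.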