Mashaan blog

Point Tracking using TAPIR

References

@inproceedings{doersch2023tapir,
  title      = {TAPIR: Tracking any point with per-frame initialization and temporal refinement},
  author     = {Doersch, Carl and Yang, Yi and Vecerik, Mel and Gokay, Dilara and Gupta, Ankush and Aytar, Yusuf and Carreira, Joao and Zisserman, Andrew},
  booktitle  = {Proceedings of the IEEE/CVF International Conference on Computer Vision},
  pages      = {10061--10072},
  year       = {2023}
}

Running TAPIR notebook

In their repository, the authors shared a number of Colab notebooks. In the video, I used the standard TAPIR model with its JAX implementation and ran it in a Colab TPU environment. I also forked the repository and added print statements to check the tensor shapes while the method executes. To keep this markdown file small, I'm sharing the Python code along with its output; you can copy and paste the code into a Jupyter notebook.

Install dependencies:

!pip install git+https://github.com/mashaan14/tapnet.git
MODEL_TYPE = 'tapir'  # 'tapir' or 'bootstapir'

Download model

%mkdir tapnet/checkpoints

if MODEL_TYPE == "tapir":
  !wget -P tapnet/checkpoints https://storage.googleapis.com/dm-tapnet/tapir_checkpoint_panning.npy
else:
  !wget -P tapnet/checkpoints https://storage.googleapis.com/dm-tapnet/bootstap/bootstapir_checkpoint_v2.npy

%ls tapnet/checkpoints
mkdir: cannot create directory ‘tapnet/checkpoints’: No such file or directory
--2025-02-02 07:37:37--  https://storage.googleapis.com/dm-tapnet/tapir_checkpoint_panning.npy
Resolving storage.googleapis.com (storage.googleapis.com)... 142.251.183.207, 64.233.179.207, 142.251.184.207, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|142.251.183.207|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 124307770 (119M) [application/octet-stream]
Saving to: ‘tapnet/checkpoints/tapir_checkpoint_panning.npy’

tapir_checkpoint_pa 100%[===================>] 118.55M   174MB/s    in 0.7s    

2025-02-02 07:37:38 (174 MB/s) - ‘tapnet/checkpoints/tapir_checkpoint_panning.npy’ saved [124307770/124307770]

tapir_checkpoint_panning.npy

Imports

#%matplotlib widget
from google.colab import output
import jax
import matplotlib
import matplotlib.pyplot as plt
import mediapy as media
import numpy as np
from tapnet.models import tapir_model
from tapnet.utils import model_utils
from tapnet.utils import transforms
from tapnet.utils import viz_utils

output.enable_custom_widget_manager()

Load checkpoint

if MODEL_TYPE == 'tapir':
  checkpoint_path = 'tapnet/checkpoints/tapir_checkpoint_panning.npy'
else:
  checkpoint_path = 'tapnet/checkpoints/bootstapir_checkpoint_v2.npy'
ckpt_state = np.load(checkpoint_path, allow_pickle=True).item()
params, state = ckpt_state['params'], ckpt_state['state']

kwargs = dict(bilinear_interp_with_depthwise_conv=False, pyramid_level=0)
if MODEL_TYPE == 'bootstapir':
  kwargs.update(
      dict(pyramid_level=1, extra_convs=True, softmax_temperature=10.0)
  )
tapir = tapir_model.ParameterizedTAPIR(params, state, tapir_kwargs=kwargs)
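The two MODEL_TYPE branches (download URL, local checkpoint path, and extra tapir_kwargs) can be kept in one lookup table so they cannot drift apart. The sketch below uses a hypothetical helper, checkpoint_config, which is not part of tapnet; the URLs, paths, and keyword arguments are copied from the cells above.

```python
# Hypothetical helper (not part of tapnet): everything that depends on
# MODEL_TYPE in one table. URLs, paths and kwargs are copied from the cells above.
_BASE_URL = "https://storage.googleapis.com/dm-tapnet"
_CHECKPOINT_DIR = "tapnet/checkpoints"

_MODEL_CONFIGS = {
    "tapir": {
        "url": f"{_BASE_URL}/tapir_checkpoint_panning.npy",
        "extra_kwargs": {},
    },
    "bootstapir": {
        "url": f"{_BASE_URL}/bootstap/bootstapir_checkpoint_v2.npy",
        "extra_kwargs": dict(
            pyramid_level=1, extra_convs=True, softmax_temperature=10.0
        ),
    },
}


def checkpoint_config(model_type):
  """Returns (download_url, local_checkpoint_path, tapir_kwargs)."""
  if model_type not in _MODEL_CONFIGS:
    raise ValueError(f"Unknown model type: {model_type!r}")
  cfg = _MODEL_CONFIGS[model_type]
  filename = cfg["url"].rsplit("/", 1)[-1]
  local_path = f"{_CHECKPOINT_DIR}/{filename}"
  # Shared defaults first, then the model-specific overrides.
  tapir_kwargs = dict(bilinear_interp_with_depthwise_conv=False, pyramid_level=0)
  tapir_kwargs.update(cfg["extra_kwargs"])
  return cfg["url"], local_path, tapir_kwargs


bs_url, bs_path, bs_kwargs = checkpoint_config("bootstapir")
print(bs_path)                     # tapnet/checkpoints/bootstapir_checkpoint_v2.npy
print(bs_kwargs["pyramid_level"])  # 1
```

In the notebook, the download cell could then run !wget -P tapnet/checkpoints {url}, the load cell could call np.load(path, allow_pickle=True), and the returned kwargs would go to tapir_model.ParameterizedTAPIR.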

Load the video

%mkdir tapnet/examplar_videos

!wget -P tapnet/examplar_videos http://storage.googleapis.com/dm-tapnet/horsejump-high.mp4

video = media.read_video("tapnet/examplar_videos/horsejump-high.mp4")
media.show_video(video, fps=10)
--2025-02-02 07:38:39--  http://storage.googleapis.com/dm-tapnet/horsejump-high.mp4
Resolving storage.googleapis.com (storage.googleapis.com)... 142.251.183.207, 64.233.179.207, 142.251.184.207, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|142.251.183.207|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 467706 (457K) [video/mp4]
Saving to: ‘tapnet/examplar_videos/horsejump-high.mp4’

horsejump-high.mp4  100%[===================>] 456.74K  --.-KB/s    in 0.001s  

2025-02-02 07:38:39 (316 MB/s) - ‘tapnet/examplar_videos/horsejump-high.mp4’ saved [467706/467706]

Utility functions

def inference(frames, query_points):
  """Inference on one video.

  Args:
    frames: [num_frames, height, width, 3], [0, 255], np.uint8
    query_points: [num_points, 3], [0, num_frames/height/width], [t, y, x]

  Returns:
    tracks: [num_points, num_frames, 2], pixel coordinates of the resized frames, [x, y]
    visibles: [num_points, num_frames], bool
  """
  # Preprocess video to match model inputs format
  frames = model_utils.preprocess_frames(frames)
  query_points = query_points.astype(np.float32)
  frames, query_points = frames[None], query_points[None]  # Add batch dimension

  print(f'frames: {frames.shape}')
  print(f'query_points: {query_points.shape}')
  outputs = tapir(
      video=frames,
      query_points=query_points,
      is_training=False,
      query_chunk_size=32,
  )
  tracks, occlusions, expected_dist = (
      outputs['tracks'],
      outputs['occlusion'],
      outputs['expected_dist'],
  )

  # Binarize occlusions
  visibles = model_utils.postprocess_occlusions(occlusions, expected_dist)
  return tracks[0], visibles[0]


inference = jax.jit(inference)
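Around the model call, inference does two small array transformations: model_utils.preprocess_frames converts the uint8 frames to floats, and model_utils.postprocess_occlusions turns the raw occlusion and expected_dist outputs into a boolean visibility mask. The NumPy sketch below mirrors both steps under two assumptions about those tapnet helpers: frames are rescaled from [0, 255] to [-1, 1], and a point counts as visible when (1 - sigmoid(occlusion)) * (1 - sigmoid(expected_dist)) > 0.5. The functions ending in _sketch are hypothetical stand-ins, not the tapnet API.

```python
import numpy as np


def preprocess_frames_sketch(frames):
  """Assumed equivalent of model_utils.preprocess_frames: uint8 [0, 255] -> float32 [-1, 1]."""
  return frames.astype(np.float32) / 255.0 * 2.0 - 1.0


def sigmoid(x):
  return 1.0 / (1.0 + np.exp(-x))


def postprocess_occlusions_sketch(occlusions, expected_dist):
  """Assumed equivalent of model_utils.postprocess_occlusions.

  A point is marked visible only when the model is confident that it is not
  occluded AND confident that its position estimate is accurate.
  """
  visible_prob = (1.0 - sigmoid(occlusions)) * (1.0 - sigmoid(expected_dist))
  return visible_prob > 0.5


# Shapes from the run in this post: 50 frames at 256x256, 100 query points.
frames = np.random.randint(0, 256, size=(50, 256, 256, 3), dtype=np.uint8)
query_points = np.zeros((100, 3), dtype=np.int32)

model_frames = preprocess_frames_sketch(frames)[None]  # add batch dimension
model_queries = query_points.astype(np.float32)[None]  # add batch dimension
print(model_frames.shape, model_queries.shape)  # (1, 50, 256, 256, 3) (1, 100, 3)

# Random stand-ins for the raw per-frame logits of 100 points over 50 frames.
occlusions = np.random.randn(1, 100, 50)
expected_dist = np.random.randn(1, 100, 50)
visibles = postprocess_occlusions_sketch(occlusions, expected_dist)
print(visibles.shape, visibles.dtype)  # (1, 100, 50) bool
```

Under this rule, two logits of 0 give a visibility score of 0.25, so a point the model is unsure about on both counts is treated as hidden.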


def sample_random_points(frame_max_idx, height, width, num_points):
  """Sample random points with (time, height, width) order."""
  y = np.random.randint(0, height, (num_points, 1))
  x = np.random.randint(0, width, (num_points, 1))
  t = np.random.randint(0, frame_max_idx + 1, (num_points, 1))
  points = np.concatenate((t, y, x), axis=-1).astype(
      np.int32
  )  # [num_points, 3]
  return points
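The prediction cell below calls sample_random_points with frame_max_idx=0, so every query point sits on the first frame and only its (y, x) location is random. A quick check with a copy of the function above; the fixed seed is an addition for reproducibility only.

```python
import numpy as np


def sample_random_points(frame_max_idx, height, width, num_points):
  """Sample random points with (time, height, width) order."""
  y = np.random.randint(0, height, (num_points, 1))
  x = np.random.randint(0, width, (num_points, 1))
  t = np.random.randint(0, frame_max_idx + 1, (num_points, 1))
  points = np.concatenate((t, y, x), axis=-1).astype(np.int32)  # [num_points, 3]
  return points


np.random.seed(0)  # fixed seed, only so the check is reproducible
query_points = sample_random_points(0, 256, 256, 100)
print(query_points.shape)             # (100, 3)
print(np.unique(query_points[:, 0]))  # [0]: every query starts on frame 0
print(query_points[:, 1:].min() >= 0, query_points[:, 1:].max() < 256)  # True True

# Queries can also start on later frames, e.g. anywhere in the 50-frame clip.
later = sample_random_points(49, 256, 256, 5)
print(later[:, 0])  # frame indices between 0 and 49
```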

Predict sparse point tracks

resize_height = 256  # @param {type: "integer"}
resize_width = 256  # @param {type: "integer"}
num_points = 100  # @param {type: "integer"}

frames = media.resize_video(video, (resize_height, resize_width))
query_points = sample_random_points(
    0, frames.shape[1], frames.shape[2], num_points
)
tracks, visibles = inference(frames, query_points)
tracks = np.array(tracks)
visibles = np.array(visibles)

# Visualize sparse point tracks
height, width = video.shape[1:3]
tracks = transforms.convert_grid_coordinates(
    tracks, (resize_width, resize_height), (width, height)
)
video_viz = viz_utils.paint_point_track(video, tracks, visibles)
media.show_video(video_viz, fps=10)
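The tracks come back in the pixel coordinates of the 256x256 resized frames, in (x, y) order, which is why transforms.convert_grid_coordinates is given (resize_width, resize_height) and (width, height) before the tracks are painted on the original video. A minimal NumPy sketch, assuming that for (x, y) coordinates the helper simply scales each axis by output size / input size; rescale_tracks is a hypothetical name and the 854x480 resolution is a hypothetical example.

```python
import numpy as np


def rescale_tracks(tracks, input_size, output_size):
  """Scale (x, y) tracks from one grid size to another.

  Assumed to match transforms.convert_grid_coordinates for (x, y) input:
  each axis is multiplied by output_size / input_size. Sizes are given as
  (width, height), matching the (x, y) order of the last axis of tracks.
  """
  scale = np.asarray(output_size, dtype=np.float64) / np.asarray(input_size, dtype=np.float64)
  return tracks * scale  # broadcasts over [num_points, num_frames, 2]


# Hypothetical original resolution; the model itself ran on 256x256 frames.
width, height = 854, 480
tracks_256 = np.array([[[128.0, 128.0], [0.0, 256.0]]])  # [1 point, 2 frames, (x, y)]
tracks_orig = rescale_tracks(tracks_256, (256, 256), (width, height))
# The centre of the resized frame maps to the centre of the original frame,
# and the bottom edge (y = 256) maps to y = 480.
print(tracks_orig)
```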

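Here’s the output I got after adding print statements to my fork. Because inference is wrapped in jax.jit, these prints run while JAX traces the function, so they appear on the first call and are skipped when the compiled version is reused: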

frames: (1, 50, 256, 256, 3)
query_points: (1, 100, 3)
***********************************************
get_feature_grids function
***********************************************
refinement_resolutions
256
256
feature_grid: (1, 50, 32, 32, 256)
feature_grid: (1, 50, 32, 32, 256)
hires_feats: (1, 50, 64, 64, 128)
hires_feats: (1, 50, 64, 64, 128)
-----------------------------------------------
***********************************************
get_query_features function
***********************************************
hires_query_feats: (1, 100, 128)
hires_query_feats: (1, 100, 128)
query_feats: (1, 100, 256)
query_feats: (1, 100, 256)
-----------------------------------------------
***********************************************
estimate_trajectories function
***********************************************
***********************************************
estimate_trajectories-->tracks_from_cost_volume  functions
points: (1, 32, 50, 2)
occlusion: (1, 32, 50)
expected_dist: (1, 32, 50)
-----------------------------------------------
***********************************************
estimate_trajectories-->refine_pips  functions
pos_guess_input: (1, 32, 50, 2)
occ_guess_input: (1, 32, 50, 1)
expd_guess_input: (1, 32, 50, 1)
mlp_input_features: (1, 32, 50, 384)
corrs_chunked: (1, 32, 50, 98)
-----------------------------------------------
***********************************************
estimate_trajectories-->refine_pips  functions
pos_guess_input: (1, 32, 50, 2)
occ_guess_input: (1, 32, 50, 1)
expd_guess_input: (1, 32, 50, 1)
mlp_input_features: (1, 32, 50, 384)
corrs_chunked: (1, 32, 50, 98)
-----------------------------------------------
***********************************************
estimate_trajectories-->refine_pips  functions
pos_guess_input: (1, 32, 50, 2)
occ_guess_input: (1, 32, 50, 1)
expd_guess_input: (1, 32, 50, 1)
mlp_input_features: (1, 32, 50, 384)
corrs_chunked: (1, 32, 50, 98)
-----------------------------------------------
***********************************************
estimate_trajectories-->refine_pips  functions
pos_guess_input: (1, 32, 50, 2)
occ_guess_input: (1, 32, 50, 1)
expd_guess_input: (1, 32, 50, 1)
mlp_input_features: (1, 32, 50, 384)
corrs_chunked: (1, 32, 50, 98)
-----------------------------------------------
***********************************************
estimate_trajectories-->tracks_from_cost_volume  functions
points: (1, 32, 50, 2)
occlusion: (1, 32, 50)
expected_dist: (1, 32, 50)
-----------------------------------------------
***********************************************
estimate_trajectories-->refine_pips  functions
pos_guess_input: (1, 32, 50, 2)
occ_guess_input: (1, 32, 50, 1)
expd_guess_input: (1, 32, 50, 1)
mlp_input_features: (1, 32, 50, 384)
corrs_chunked: (1, 32, 50, 98)
-----------------------------------------------
***********************************************
estimate_trajectories-->refine_pips  functions
pos_guess_input: (1, 32, 50, 2)
occ_guess_input: (1, 32, 50, 1)
expd_guess_input: (1, 32, 50, 1)
mlp_input_features: (1, 32, 50, 384)
corrs_chunked: (1, 32, 50, 98)
-----------------------------------------------
***********************************************
estimate_trajectories-->refine_pips  functions
pos_guess_input: (1, 32, 50, 2)
occ_guess_input: (1, 32, 50, 1)
expd_guess_input: (1, 32, 50, 1)
mlp_input_features: (1, 32, 50, 384)
corrs_chunked: (1, 32, 50, 98)
-----------------------------------------------
***********************************************
estimate_trajectories-->refine_pips  functions
pos_guess_input: (1, 32, 50, 2)
occ_guess_input: (1, 32, 50, 1)
expd_guess_input: (1, 32, 50, 1)
mlp_input_features: (1, 32, 50, 384)
corrs_chunked: (1, 32, 50, 98)
-----------------------------------------------
***********************************************
estimate_trajectories-->tracks_from_cost_volume  functions
points: (1, 32, 50, 2)
occlusion: (1, 32, 50)
expected_dist: (1, 32, 50)
-----------------------------------------------
***********************************************
estimate_trajectories-->refine_pips  functions
pos_guess_input: (1, 32, 50, 2)
occ_guess_input: (1, 32, 50, 1)
expd_guess_input: (1, 32, 50, 1)
mlp_input_features: (1, 32, 50, 384)
corrs_chunked: (1, 32, 50, 98)
-----------------------------------------------
***********************************************
estimate_trajectories-->refine_pips  functions
pos_guess_input: (1, 32, 50, 2)
occ_guess_input: (1, 32, 50, 1)
expd_guess_input: (1, 32, 50, 1)
mlp_input_features: (1, 32, 50, 384)
corrs_chunked: (1, 32, 50, 98)
-----------------------------------------------
***********************************************
estimate_trajectories-->refine_pips  functions
pos_guess_input: (1, 32, 50, 2)
occ_guess_input: (1, 32, 50, 1)
expd_guess_input: (1, 32, 50, 1)
mlp_input_features: (1, 32, 50, 384)
corrs_chunked: (1, 32, 50, 98)
-----------------------------------------------
***********************************************
estimate_trajectories-->refine_pips  functions
pos_guess_input: (1, 32, 50, 2)
occ_guess_input: (1, 32, 50, 1)
expd_guess_input: (1, 32, 50, 1)
mlp_input_features: (1, 32, 50, 384)
corrs_chunked: (1, 32, 50, 98)
-----------------------------------------------
***********************************************
estimate_trajectories-->tracks_from_cost_volume  functions
points: (1, 4, 50, 2)
occlusion: (1, 4, 50)
expected_dist: (1, 4, 50)
-----------------------------------------------
***********************************************
estimate_trajectories-->refine_pips  functions
pos_guess_input: (1, 4, 50, 2)
occ_guess_input: (1, 4, 50, 1)
expd_guess_input: (1, 4, 50, 1)
mlp_input_features: (1, 4, 50, 384)
corrs_chunked: (1, 4, 50, 98)
-----------------------------------------------
***********************************************
estimate_trajectories-->refine_pips  functions
pos_guess_input: (1, 4, 50, 2)
occ_guess_input: (1, 4, 50, 1)
expd_guess_input: (1, 4, 50, 1)
mlp_input_features: (1, 4, 50, 384)
corrs_chunked: (1, 4, 50, 98)
-----------------------------------------------
***********************************************
estimate_trajectories-->refine_pips  functions
pos_guess_input: (1, 4, 50, 2)
occ_guess_input: (1, 4, 50, 1)
expd_guess_input: (1, 4, 50, 1)
mlp_input_features: (1, 4, 50, 384)
corrs_chunked: (1, 4, 50, 98)
-----------------------------------------------
***********************************************
estimate_trajectories-->refine_pips  functions
pos_guess_input: (1, 4, 50, 2)
occ_guess_input: (1, 4, 50, 1)
expd_guess_input: (1, 4, 50, 1)
mlp_input_features: (1, 4, 50, 384)
corrs_chunked: (1, 4, 50, 98)
-----------------------------------------------
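The shapes in the first two parts of the log can be checked with some arithmetic. The grids are laid out as (batch, frames, height, width, channels): the 256x256 frames give a 32x32 feature_grid and a 64x64 hires_feats, i.e. spatial strides of 8 and 4, and taking one feature vector per query point gives query_feats of shape (1, 100, 256) and hires_query_feats of shape (1, 100, 128). The NumPy sketch below reproduces those shapes with zero-filled stand-ins. The lookup function, nearest_query_features, is hypothetical and uses a nearest-cell lookup as a simplification; only the shapes are meant to match the log.

```python
import numpy as np

batch, num_frames, size, num_points = 1, 50, 256, 100

# Zero-filled stand-ins with the shapes printed by get_feature_grids.
feature_grid = np.zeros((batch, num_frames, 32, 32, 256), dtype=np.float32)
hires_feats = np.zeros((batch, num_frames, 64, 64, 128), dtype=np.float32)

# Spatial stride of each grid relative to the 256x256 input frames.
print(size // feature_grid.shape[2], size // hires_feats.shape[2])  # 8 4


def nearest_query_features(grid, query_points, image_size):
  """Hypothetical nearest-cell lookup of one feature vector per query point.

  grid: [batch, frames, h, w, c]; query_points: [batch, num_points, 3] as
  (t, y, x) in pixels of image_size x image_size frames.
  Returns [batch, num_points, c].
  """
  h, w = grid.shape[2:4]
  t = query_points[..., 0].astype(int)
  y = np.clip((query_points[..., 1] * h / image_size).astype(int), 0, h - 1)
  x = np.clip((query_points[..., 2] * w / image_size).astype(int), 0, w - 1)
  b = np.arange(grid.shape[0])[:, None]  # batch index, broadcast over points
  return grid[b, t, y, x]


# All queries on frame 0 with random (y, x), as in this post.
query_points = np.zeros((batch, num_points, 3), dtype=np.float32)
query_points[..., 1:] = np.random.uniform(0, size, (batch, num_points, 2))

query_feats = nearest_query_features(feature_grid, query_points, size)
hires_query_feats = nearest_query_features(hires_feats, query_points, size)
print(query_feats.shape, hires_query_feats.shape)  # (1, 100, 256) (1, 100, 128)
```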

Understanding tensor shapes

You might notice that estimate_trajectories ran four times, with $32$, $32$, $32$, and $4$ query points. That’s because we passed query_chunk_size=32 to the model, so it split the $100$ query points into chunks of $32$, leaving $100 - 3 \times 32 = 4$ points for the last chunk. Within each run, tracks_from_cost_volume produces the initial estimate, and refine_pips is then called four times to refine it. In these notes I explain a single run of the estimate_trajectories function; the other chunks go through the same steps.
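The chunk sizes follow directly from splitting 100 queries into groups of at most 32. A minimal sketch of that split, using a hypothetical helper chunk_queries that slices along the point axis (tapnet's own loop may be written differently), and assuming the per-chunk outputs are concatenated back along the same axis:

```python
import numpy as np


def chunk_queries(query_points, chunk_size):
  """Split [batch, num_points, 3] queries into chunks along the point axis."""
  num_points = query_points.shape[1]
  return [
      query_points[:, start:start + chunk_size]
      for start in range(0, num_points, chunk_size)
  ]


query_points = np.zeros((1, 100, 3), dtype=np.float32)
chunks = chunk_queries(query_points, chunk_size=32)
print([c.shape[1] for c in chunks])  # [32, 32, 32, 4]

# Per-chunk outputs (e.g. tracks of shape [1, chunk, 50, 2]) joined back
# along the point axis give one result for all 100 points.
tracks = np.concatenate([np.zeros((1, c.shape[1], 50, 2)) for c in chunks], axis=1)
print(tracks.shape)  # (1, 100, 50, 2)
```

Chunking bounds the memory used by the per-point tensors, such as the (1, 32, 50, 98) correlation features, at the cost of a few sequential steps.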

[Figures 001, 002, 003: notes on a single run of estimate_trajectories]
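The refine_pips shapes printed for each chunk can be decomposed too. One reading, stated as an assumption: mlp_input_features has 384 = 256 + 128 channels because each point's coarse and high-resolution query features are concatenated and repeated for all 50 frames, and corrs_chunked has 98 = 2 x 49 channels because features in a 7x7 window around the current position guess are scored on each of the two feature grids. The NumPy sketch checks that arithmetic for a chunk of 32 points; all arrays are zero-filled placeholders.

```python
import numpy as np

batch, chunk, num_frames = 1, 32, 50
query_feats = np.zeros((batch, chunk, 256), dtype=np.float32)        # coarse query features
hires_query_feats = np.zeros((batch, chunk, 128), dtype=np.float32)  # high-res query features

# Assumed: concatenate both feature vectors, then repeat them for every frame.
both = np.concatenate([hires_query_feats, query_feats], axis=-1)  # [1, 32, 384]
mlp_input_features = np.tile(both[:, :, None, :], (1, 1, num_frames, 1))
print(mlp_input_features.shape)  # (1, 32, 50, 384)

# Assumed: a 7x7 window of correlation scores on each of the two grids.
window = 7
corrs_per_grid = [np.zeros((batch, chunk, num_frames, window * window)) for _ in range(2)]
corrs_chunked = np.concatenate(corrs_per_grid, axis=-1)
print(corrs_chunked.shape)  # (1, 32, 50, 98)

# Position, occlusion and expected-distance guesses: 2, 1 and 1 channels per frame.
pos_guess_input = np.zeros((batch, chunk, num_frames, 2))
occ_guess_input = np.zeros((batch, chunk, num_frames))[..., None]
expd_guess_input = np.zeros((batch, chunk, num_frames))[..., None]
print(pos_guess_input.shape, occ_guess_input.shape, expd_guess_input.shape)
```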