"""sam3 tracking command-line interface."""
import argparse
import os
from pathlib import Path
import torch
from psifx.utils.constants import SAM3_PATH
from psifx.utils.command import Command, register_command
from psifx.video.tracking.sam3.tool import Sam3TrackingTool
[docs]
class Sam3Command(Command):
"""
Command-line interface for running SAM3.
"""
[docs]
@staticmethod
def setup(parser: argparse.ArgumentParser):
"""
Sets up the command.
:param parser: The argument parser.
:return:
"""
from psifx.video.tracking.command import VisualizationTrackingCommand
subparsers = parser.add_subparsers(title="available commands")
register_command(subparsers, "inference", Sam3InferenceCommand)
register_command(subparsers, "visualization", VisualizationTrackingCommand)
[docs]
@staticmethod
def execute(parser: argparse.ArgumentParser, args: argparse.Namespace):
"""
Executes the command.
:param parser: The argument parser.
:param args: The arguments.
:return:
"""
parser.print_help()
[docs]
class Sam3InferenceCommand(Command):
"""
Command-line interface for tracking video elements with SAM3.
"""
[docs]
@staticmethod
def setup(parser: argparse.ArgumentParser):
"""
Sets up the command.
:param parser: The argument parser.
:return:
"""
parser.add_argument(
"--video",
type=Path,
required=True,
help="path to the input video file, such as ``/path/to/video.mp4`` (or .avi, .mkv, etc.)",
)
parser.add_argument(
"--mask_dir",
type=Path,
required=True,
help="path to the output mask directory, such as ``/path/to/mask_dir``",
)
parser.add_argument(
"--text_prompt",
type=str,
default="people",
help="text description of objects to track (e.g., 'people', 'cars', 'dogs')",
)
parser.add_argument(
"--chunk_size",
type=int,
default=300,
help="number of frames to process at once (lower values use less memory)",
)
parser.add_argument(
"--iou_threshold",
type=float,
default=0.3,
help="IoU threshold for stitching chunks together (0.0 to 1.0)",
)
parser.add_argument(
"--max_num_objects",
type=int,
default=None,
help=(
"optional cap on tracked object count "
"(e.g., set to 2 for a two-person interaction)"
),
)
parser.add_argument(
"--device",
type=str,
default="cuda" if torch.cuda.is_available() else "cpu",
help="device on which to run the inference, either 'cpu' or 'cuda'",
)
parser.add_argument(
"--model_path",
type=str,
default=SAM3_PATH,
help="SAM3 model id or local path (e.g. 'facebook/sam3' or '/path/to/sam3')",
)
parser.add_argument(
"--api_token",
type=str,
default=os.environ.get("HF_TOKEN"),
help="Hugging Face token (defaults to HF_TOKEN env var when available)",
)
parser.add_argument(
"--overwrite",
default=False,
action=argparse.BooleanOptionalAction,
help="overwrite existing files, otherwise raises an error",
)
parser.add_argument(
"--verbose",
default=True,
action=argparse.BooleanOptionalAction,
help="verbosity of the script",
)
[docs]
@staticmethod
def execute(parser: argparse.ArgumentParser, args: argparse.Namespace):
"""
Executes the command.
:param parser: The argument parser.
:param args: The arguments.
:return:
"""
tool = Sam3TrackingTool(
device=args.device,
model_path=args.model_path,
api_token=args.api_token,
max_num_objects=args.max_num_objects,
overwrite=args.overwrite,
verbose=args.verbose,
)
tool.infer(
video_path=args.video,
mask_dir=args.mask_dir,
text_prompt=args.text_prompt,
chunk_size=args.chunk_size,
iou_threshold=args.iou_threshold,
)
del tool