Seeking Assistance: PyTorchStreamReader Error with Fine-tuned SAM Model on Roboflow Dataset

I am seeking assistance with an issue I’m encountering after fine-tuning a Segment Anything Model (SAM) using a Roboflow dataset. My end goal is to use this custom-trained model to segment kitchen variants from uploaded images.

I have fine-tuned the SAM model and am now attempting to load it for inference. However, when executing my Python script, I receive the following error:

Using device: cpu
Loading SAM model from models/finetuned_sam_model.pth...
Traceback (most recent call last):
  File "/home/chandani.vibhakar/Documents/Projects/ai_ml_projects/kitchen-visualizer/onlySam.py", line 29, in <module>
    checkpoint = torch.load(SAM_CHECKPOINT_PATH, map_location=torch.device('cpu'))
  File "/home/chandani.vibhakar/Documents/Projects/ai_ml_projects/kitchen-visualizer/venv/lib/python3.10/site-packages/torch/serialization.py", line 1486, in load
    with _open_zipfile_reader(opened_file) as opened_zipfile:
  File "/home/chandani.vibhakar/Documents/Projects/ai_ml_projects/kitchen-visualizer/venv/lib/python3.10/site-packages/torch/serialization.py", line 771, in __init__
    super().__init__(torch._C.PyTorchFileReader(name_or_buffer))
RuntimeError: PytorchStreamReader failed reading zip archive: failed finding central directory

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/chandani.vibhakar/Documents/Projects/ai_ml_projects/kitchen-visualizer/onlySam.py", line 39, in <module>
    raise RuntimeError(f"Failed to load SAM model. Ensure you are using the official checkpoint file. Error: {e}")
RuntimeError: Failed to load SAM model. Ensure you are using the official checkpoint file. Error: PytorchStreamReader failed reading zip archive: failed finding central directory

The failure is a RuntimeError: PytorchStreamReader failed reading zip archive: failed finding central directory, which suggests the .pth file on disk is not a valid zip archive: since PyTorch 1.6, torch.save writes checkpoints in zip format, so this error typically points to a corrupted, truncated, or otherwise non-checkpoint file rather than to the loading code.
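
From what I can tell, this usually means the file on disk is not a valid zip archive at all. A quick sanity check I can run (a minimal sketch; the path matches my inference script below, and since torch.save writes a zip archive, is_zipfile should return True for a healthy checkpoint):

import os
import zipfile

SAM_CHECKPOINT_PATH = 'models/finetuned_sam_model.pth'

# A ViT-H SAM checkpoint is roughly 2.4 GB; a much smaller file points to a
# truncated copy or a placeholder written in its place.
print(f"Size on disk: {os.path.getsize(SAM_CHECKPOINT_PATH)} bytes")

# torch.save (PyTorch >= 1.6) writes a zip archive, so this should be True.
print(f"Valid zip archive: {zipfile.is_zipfile(SAM_CHECKPOINT_PATH)}")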

I am attaching my training script and the relevant inference code snippet for your reference:

Training Script:

import os
import torch
import roboflow
import supervision as sv
from segment_anything import sam_model_registry, SamPredictor
import numpy as np
import cv2
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
import matplotlib.pyplot as plt

# --- 1. Initial Setup ---
HOME = os.getcwd()
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_TYPE = "vit_h"

print(f"Home directory: {HOME}")
print(f"Using device: {DEVICE}")

# --- 2. Download Pre-trained SAM Checkpoint ---
print("\nDownloading the pre-trained SAM model (ViT-H)...")
# Create the weights directory and download the checkpoint if it is missing.
os.makedirs(os.path.join(HOME, "weights"), exist_ok=True)
CHECKPOINT_PATH = os.path.join(HOME, "weights", "sam_vit_h_4b8939.pth")
if not os.path.exists(CHECKPOINT_PATH):
    torch.hub.download_url_to_file(
        "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
        CHECKPOINT_PATH,
    )
print(f"SAM checkpoint downloaded to: {CHECKPOINT_PATH}")

# --- 3. Download Custom Dataset from Roboflow ---
# ‼️ PASTE YOUR ROBOFLOW CREDENTIALS HERE ‼️
ROBOFLOW_API_KEY = "ROBOFLOW_API_KEY"  # Replace with your key
WORKSPACE_ID = "WORKSPACE_ID" # Replace with your workspace ID
PROJECT_ID = "PROJECT_ID"   # Replace with your project ID
VERSION_ID = 1                   # Replace with your dataset version number

try:
    rf = roboflow.Roboflow(api_key=ROBOFLOW_API_KEY)
    project = rf.workspace(WORKSPACE_ID).project(PROJECT_ID)
    dataset = project.version(VERSION_ID).download("coco")
    DATASET_PATH = dataset.location
    print(f"\nDataset downloaded successfully to: {DATASET_PATH}")
except Exception as e:
    print(f"\n❌ Error downloading dataset: {e}")
    print("Please check your Roboflow API key, workspace, project, and version details.")
    DATASET_PATH = None

# --- 4. Fine-Tuning the SAM Model ---
if DATASET_PATH:
    print(f"\nLoading SAM model of type '{MODEL_TYPE}' to '{DEVICE}'...")
    sam_model = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH)
    sam_model.to(device=DEVICE)
    print("Model loaded successfully.")

    optimizer = torch.optim.Adam(sam_model.mask_decoder.parameters(), lr=1e-5, weight_decay=0)
    loss_fn = torch.nn.MSELoss()

    # Custom Dataset Class for SAM
    class SAMDataset(Dataset):
        def __init__(self, dataset_path, model):
            self.dataset_path = dataset_path
            self.model = model
            self.predictor = SamPredictor(self.model)
            train_dir = os.path.join(dataset_path, "train")
            self.image_paths = sorted([os.path.join(train_dir, f) for f in os.listdir(train_dir) if f.endswith(('.jpg', '.png', '.jpeg'))])
            self.annotations_path = os.path.join(train_dir, "_annotations.coco.json")

            # Load COCO annotations with supervision; force_masks=True also
            # materialises segmentation masks alongside the bounding boxes.
            self.annotations = sv.DetectionDataset.from_coco(
                annotations_path=self.annotations_path,
                images_directory_path=train_dir,
                force_masks=True,
            )

        def __len__(self):
            return len(self.image_paths)

        def __getitem__(self, idx):
            image_path = self.image_paths[idx]
            image_name = os.path.basename(image_path)

            # Find the detections for this image; iterating a
            # sv.DetectionDataset yields (image_path, image, detections).
            img_detections = None
            for annotation_path, _, detections in self.annotations:
                if os.path.basename(annotation_path) == image_name and len(detections) > 0:
                    img_detections = detections
                    break

            if img_detections is None:
                return None  # no annotations for this image; skip the sample

            image = cv2.imread(image_path)
            image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

            self.predictor.set_image(image_rgb)
            image_embedding = self.predictor.get_image_embedding().squeeze()

            # First bounding box for the image as an xyxy float tensor.
            box = torch.tensor(img_detections.xyxy[0], dtype=torch.float, device=self.model.device)

            # Matching ground-truth mask (H, W); fall back to an empty mask of
            # the image size if the dataset carries no masks.
            if img_detections.mask is not None:
                ground_truth_mask = torch.as_tensor(img_detections.mask[0], device=self.model.device).float()
            else:
                ground_truth_mask = torch.zeros(image_rgb.shape[:2], device=self.model.device)

            return image_embedding, box, ground_truth_mask


    def collate_fn(batch):
        # Drop samples the dataset skipped by returning None.
        batch = [item for item in batch if item is not None]
        if not batch:
            return None, None, None
        return torch.utils.data.dataloader.default_collate(batch)

    sam_dataset = SAMDataset(dataset_path=DATASET_PATH, model=sam_model)
    # batch_size=1: the SAM mask decoder repeat-interleaves the image embedding
    # per prompt, so batching several images here leads to shape mismatches.
    train_dataloader = DataLoader(sam_dataset, batch_size=1, shuffle=True, collate_fn=collate_fn)

    print("\nStarting the fine-tuning process...")
    epochs = 10
    for epoch in range(epochs):
        epoch_losses = []
        for image_embedding, box, ground_truth_mask in tqdm(train_dataloader, desc=f"Epoch {epoch + 1}/{epochs}"):
            if image_embedding is None: continue

            sparse_embeddings, dense_embeddings = sam_model.prompt_encoder(points=None, boxes=box, masks=None)
            low_res_masks, _ = sam_model.mask_decoder(
                image_embeddings=image_embedding,
                image_pe=sam_model.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=False,
            )
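            # NOTE: this assumes the training images are 1024x1024; for other
            # sizes, pass the real (input_size, original_size) pair here and
            # make sure the ground-truth mask resolution matches.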
            upscaled_masks = sam_model.postprocess_masks(low_res_masks, (1024, 1024), (1024, 1024)).to(DEVICE)

            loss = loss_fn(upscaled_masks, ground_truth_mask.unsqueeze(1))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_losses.append(loss.item())

        print(f"Epoch {epoch + 1} | Average Loss: {np.mean(epoch_losses):.4f}")

    print("\n✅ Fine-tuning finished.")
    FINETUNED_CHECKPOINT_PATH = os.path.join(HOME, "weights", "finetuned_sam_model.pth")
    torch.save(sam_model.state_dict(), FINETUNED_CHECKPOINT_PATH)
    print(f"Fine-tuned model saved to: {FINETUNED_CHECKPOINT_PATH}")

# --- 5. Inference with the Fine-Tuned Model ---
if DATASET_PATH and 'FINETUNED_CHECKPOINT_PATH' in locals():
    print("\nRunning inference with the fine-tuned model...")

    fine_tuned_model = sam_model_registry[MODEL_TYPE]()
    fine_tuned_model.load_state_dict(torch.load(FINETUNED_CHECKPOINT_PATH))
    fine_tuned_model.to(device=DEVICE)
    predictor = SamPredictor(fine_tuned_model)

    valid_dir = os.path.join(DATASET_PATH, "valid")
    test_image_names = [f for f in os.listdir(valid_dir) if f.endswith(('.jpg', '.png', '.jpeg'))]

    if not test_image_names:
        print("No validation images found to test inference.")
    else:
        test_image_path = os.path.join(valid_dir, np.random.choice(test_image_names))
        test_image = cv2.imread(test_image_path)
        test_image_rgb = cv2.cvtColor(test_image, cv2.COLOR_BGR2RGB)

        predictor.set_image(test_image_rgb)

        valid_annotations_path = os.path.join(valid_dir, "_annotations.coco.json")
        valid_annotations = sv.DetectionDataset.from_coco(
            annotations_path=valid_annotations_path,
            images_directory_path=valid_dir
        )

        # Iterating a sv.DetectionDataset yields (image_path, image, detections);
        # find the detections belonging to the chosen test image.
        img_detections = next(det for path, _, det in valid_annotations
                              if os.path.basename(path) == os.path.basename(test_image_path))
        input_box = np.array(img_detections.xyxy[0])

        masks, _, _ = predictor.predict(box=input_box, multimask_output=False)

        detections = sv.Detections(xyxy=sv.mask_to_xyxy(masks=masks), mask=masks)

        # Annotate and display the result
        mask_annotator = sv.MaskAnnotator()
        annotated_image = mask_annotator.annotate(scene=test_image.copy(), detections=detections)

        print("\nDisplaying segmentation result from the fine-tuned model:")
        sv.plot_image(annotated_image)

print("\n✨ Script finished successfully!")

Inference Code Snippet:

from segment_anything import sam_model_registry, SamPredictor
import os
import torch 

SAM_CHECKPOINT_PATH = 'models/finetuned_sam_model.pth'
SAM_MODEL_TYPE = 'vit_h'
INPUT_IMAGE_PATH = 'images/kitchen14.jpg'

# --- Setup ---
# Load SAM model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

try:
    checkpoint = torch.load(SAM_CHECKPOINT_PATH, map_location=torch.device('cpu'))
    sam = sam_model_registry[SAM_MODEL_TYPE]()
    sam.load_state_dict(checkpoint)
    sam.to(device)

    predictor = SamPredictor(sam)
    print("SAM model loaded successfully.")
except FileNotFoundError:
    raise FileNotFoundError(f"SAM checkpoint not found at: {SAM_CHECKPOINT_PATH}. Please place the fine-tuned checkpoint file there.")
except Exception as e:
    raise RuntimeError(f"Failed to load SAM model. Ensure you are using the official checkpoint file. Error: {e}")

Could anyone shed light on what might be causing this PytorchStreamReader error when loading a fine-tuned SAM model? Specifically:

  • Could the fine-tuning process or saving method result in a corrupted .pth file?
  • Are there specific considerations when saving a fine-tuned SAM model that I might have missed?
  • Are there common pitfalls when attempting to load a custom SAM checkpoint that might lead to this “failed finding central directory” error? (See the quick check sketched below.)
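
For that last point, a checkpoint written by torch.save starts with the zip magic bytes PK\x03\x04, so inspecting the first few bytes should reveal whether the file is actually something else entirely, such as an HTML error page from a failed download or a git-lfs pointer file:

with open('models/finetuned_sam_model.pth', 'rb') as f:
    head = f.read(4)

# b'PK\x03\x04'      -> a real zip-based checkpoint
# b'vers'            -> likely a git-lfs pointer file ("version https://...")
# b'<htm' or b'<!DO' -> an HTML page saved in place of the weights
print(head)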

Any guidance or suggestions on troubleshooting this issue would be greatly appreciated.

Thank you for your time and assistance.

Hi @chandani_vibhakar!
Unfortunately I’m unable to support model training outside of the Roboflow app; there are too many potential variables to account for.

Apologies for the inconvenience!

Here is a prompt you can pass to ChatGPT to help you debug locally:

I'm trying to load a fine-tuned Segment Anything Model (SAM) in PyTorch for inference, but I'm getting this error: RuntimeError: PytorchStreamReader failed reading zip archive: failed finding central directory.

In my training script, I saved the model using torch.save(model.state_dict(), "finetuned_sam_model.pth"), but in my inference script I'm using torch.load("finetuned_sam_model.pth") and expecting a full checkpoint.

Can you help me debug why this error is occurring and how I should correctly load my fine-tuned SAM model for inference?
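
For reference, the save/load pairing that should match looks like this (a minimal sketch, assuming the segment_anything package and the 'vit_h' model type from your scripts):

import torch
from segment_anything import sam_model_registry

# Training side: save only the weights.
# torch.save(model.state_dict(), "finetuned_sam_model.pth")

# Inference side: build the architecture first, then load the weights into it.
sam = sam_model_registry['vit_h']()  # no checkpoint argument here
state_dict = torch.load('finetuned_sam_model.pth', map_location='cpu')
sam.load_state_dict(state_dict)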

Hi @Ford,

Thank you for your quick response and for clarifying the scope of support! I understand that debugging external training setups can be complex due to the many variables involved. I apologize for any inconvenience caused by my initial query.

I appreciate the ChatGPT prompt you provided; it’s a very helpful suggestion. I will definitely use it to debug the model-loading issue locally and work on resolving the PytorchStreamReader error.
