VideoQA: Frame Retrieval from Video using CLIP

Ruhaan Choudhary
Shivang Nagta

GitHub Repository

Abstract

VideoQA is a tool that retrieves the most relevant frame from a video given a natural language query. The query and the video's frames are both encoded with the CLIP model, and cosine similarity is used to find the best-matching frame. The system then uses the LLaVA model to generate a detailed answer or description based on the selected frame and the user's query.

Introduction

VideoQA is designed to answer questions about the content of a video by retrieving the most relevant frame and generating a response. The system leverages the CLIP model to encode both the user's query and the video frames into a shared embedding space, enabling semantic matching via cosine similarity. Once the best frame is identified, the LLaVA model generates a detailed answer, taking into account the selected frame, its transcription, and the chat history.
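
As a quick illustration of the underlying idea, the minimal sketch below (illustrative only, assuming the openai CLIP package and an example image file frame.jpg) encodes a text query and a single frame with CLIP and compares them with cosine similarity:


# Minimal sketch of CLIP text-to-frame matching (illustrative, not the project pipeline)
import torch
import clip
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

# Example inputs: frame.jpg is a placeholder path, the query is an arbitrary example
image = preprocess(Image.open("frame.jpg")).unsqueeze(0).to(device)
text = clip.tokenize(["a person riding a bicycle"]).to(device)

with torch.no_grad():
    image_emb = model.encode_image(image)
    text_emb = model.encode_text(text)

# Cosine similarity in CLIP's shared embedding space; higher means a better match
similarity = torch.nn.functional.cosine_similarity(image_emb, text_emb).item()
print(f"Cosine similarity: {similarity:.3f}")

Applied to every extracted frame, the frame with the highest similarity to the query is the one returned by the retrieval step described below.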

Approach & Methodology

Code Structure & Implementation

1. Environment Setup

Explanation: The environment setup section installs all the necessary Python packages and system dependencies required for the project. This includes faiss for efficient similarity search, ffmpeg for video processing, ffmpeg-python for Python bindings, torch and torchvision for deep learning, and CLIP for vision-language embedding. These commands ensure that the code can process videos, extract frames, and run deep learning models.


# Install required packages
!pip install faiss-cpu # For CPU version, use faiss-gpu for GPU
!apt-get install -y ffmpeg
!pip install ffmpeg-python
!pip install torch torchvision  # PyTorch and torchvision for running CLIP
!pip install ftfy regex tqdm
!pip install git+https://github.com/openai/CLIP.git
            

2. Imports and Device Setup

Explanation: This section imports all the required Python libraries for the project. It includes standard libraries like os and pickle, as well as specialized libraries for image and video processing (cv2, PIL), deep learning (torch), and vision-language models (clip, transformers). The device is set to GPU if available, otherwise CPU, to optimize computation.


import os
import pickle
import numpy as np
import cv2
from PIL import Image
import faiss  # similarity index for the frame embeddings
from transformers import AutoProcessor, LlavaForConditionalGeneration  # LLaVA processor and model used later
import torch
import clip

device = "cuda" if torch.cuda.is_available() else "cpu"
            

3. Frame Extraction from Video

Explanation: The extract_frames function takes a video file and extracts frames at a fixed interval (one frame per second). It uses OpenCV to read the video, saves each selected frame as an image file, and returns a list of the saved frame paths. This step is crucial for converting the video into a set of images that can be processed by the CLIP model.


def extract_frames(video_path, output_dir):
    os.makedirs(output_dir, exist_ok=True)
    cap = cv2.VideoCapture(video_path)
    frame_rate = max(int(cap.get(cv2.CAP_PROP_FPS)), 1)  # guard against missing/zero FPS metadata
    count = 0
    frames = []
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        if count % frame_rate == 0:  # Extract one frame per second
            frame_path = os.path.join(output_dir, f"frame_{count}.jpg")
            cv2.imwrite(frame_path, frame)
            frames.append(frame_path)
        count += 1
    cap.release()
    return frames

frames = extract_frames("input_video.mp4", "video_frames")
            

4. Frame Embedding and Indexing

Explanation: Each extracted frame is passed through the CLIP image encoder to obtain a high-dimensional embedding vector that captures its visual semantics, and these embeddings are stacked into a single matrix. Dummy transcriptions are created for each frame (a sketch of replacing them with real speech transcriptions follows the code below). The embeddings are then stored in a flat FAISS index, and metadata for each frame (file path and transcription) is saved alongside the index so the best-matching frame can be looked up at query time.


import faiss

model, preprocess = clip.load("ViT-B/32", device=device)
frame_embeddings = []

for frame_path in frames:
    image = preprocess(Image.open(frame_path)).unsqueeze(0).to(device)
    with torch.no_grad():
        frame_embedding = model.encode_image(image).cpu().numpy().astype("float32")  # FAISS requires float32 (CLIP on GPU returns float16)
        frame_embeddings.append(frame_embedding)

frame_embeddings = np.vstack(frame_embeddings)  # Shape: (num_frames, embedding_dim)

# Dummy transcriptions for each frame
transcriptions = [f"Transcription for {frame_path}" for frame_path in frames]

# Initialize FAISS index for similarity search
embedding_dim = frame_embeddings.shape[1]
index = faiss.IndexFlatL2(embedding_dim)
index.add(frame_embeddings)

# Store metadata
metadata = [{"frame_path": frames[i], "transcription": transcriptions[i]} for i in range(len(frames))]

# Save metadata and index
with open("metadata.pkl", "wb") as f:
    pickle.dump(metadata, f)
faiss.write_index(index, "faiss_index.bin")
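
The placeholder transcriptions above can be replaced with real speech transcriptions. As one possible approach (not part of the original code), the sketch below uses OpenAI's Whisper to transcribe the video's audio and assigns each saved frame the transcript segments that overlap its timestamp; the helper name align_transcriptions and the "base" model size are illustrative choices, and the video's frame rate must be passed in (it is computed inside extract_frames but not returned).

# Hypothetical sketch: real per-frame transcriptions via Whisper (pip install openai-whisper)
import os
import whisper

def align_transcriptions(video_path, frames, frame_rate):
    """Assign each frame the Whisper segments that overlap its timestamp."""
    asr_model = whisper.load_model("base")        # small model for speed; larger ones are more accurate
    result = asr_model.transcribe(video_path)     # Whisper extracts the audio track itself via ffmpeg
    segments = result["segments"]                 # each segment has "start", "end", and "text" in seconds

    transcriptions = []
    for frame_path in frames:
        # frame_{count}.jpg encodes the raw frame index; convert it to a timestamp in seconds
        count = int(os.path.splitext(os.path.basename(frame_path))[0].split("_")[1])
        timestamp = count / frame_rate
        text = " ".join(s["text"].strip() for s in segments if s["start"] <= timestamp <= s["end"])
        transcriptions.append(text or f"No speech around {timestamp:.0f}s")
    return transcriptions

The resulting list can then be dropped in place of the dummy transcriptions before the metadata is built.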
            

Main Components

LLAVAChatModel Class

Explanation: The LLAVAChatModel class encapsulates the logic for processing user queries and generating responses. It holds the CLIP and LLaVA models passed to its constructor and provides two methods: process_query, which encodes the user prompt with CLIP, finds the most similar stored frame via cosine similarity, and returns the generated answer along with the frame path, transcription, and similarity score; and generate_response, which builds a prompt from the chat history, the frame's transcription, and the question, then has LLaVA answer it conditioned on the retrieved frame.

This modular design allows for easy extension and integration with different models or retrieval strategies.


from transformers import CLIPProcessor, CLIPModel

class LLAVAChatModel:
    def __init__(self, llava_processor, llava_model, clip_processor, clip_model, device="cuda"):
        self.device = device
        self.llava_processor = llava_processor
        self.llava_model = llava_model.to(self.device)
        self.clip_processor = clip_processor
        self.clip_model = clip_model.to(self.device)

    def process_query(self, user_prompt, chat_history=None):
        # Load FAISS index and metadata
        index = faiss.read_index("faiss_index.bin")
        with open("metadata.pkl", "rb") as f:
            metadata = pickle.load(f)

        # Encode the user prompt using CLIP
        inputs = self.clip_processor(text=[user_prompt], return_tensors="pt", padding=True, truncation=True).to(self.device)  # truncate to CLIP's 77-token limit
        with torch.no_grad():
            text_embedding = self.clip_model.get_text_features(**inputs).cpu().numpy()

        # Find the most similar frame using cosine similarity
        from sklearn.metrics.pairwise import cosine_similarity
        similarities = cosine_similarity(text_embedding, index.reconstruct_n(0, index.ntotal))
        best_frame_idx = similarities.argmax()

        # Get metadata for the best frame
        best_metadata = metadata[best_frame_idx]
        frame_path = best_metadata['frame_path']
        transcription = best_metadata['transcription']

        # Generate response using LLAVA
        chat_history = chat_history or []
        response = self.generate_response(
            frame_path=frame_path,
            user_prompt=user_prompt,
            chat_history=chat_history,
            transcription=transcription
        )

        return {
            'response': response,
            'frame_path': frame_path,
            'transcription': transcription,
            'similarity_score': similarities[0][best_frame_idx]
        }

    def generate_response(self, frame_path, user_prompt, chat_history, transcription):
        from PIL import Image
        image = Image.open(frame_path)
        formatted_history = "\n".join(
            f"{message['role'].capitalize()}: {message['content']}" for message in chat_history
        )
        # LLaVA-1.5 expects an <image> placeholder and USER/ASSISTANT turns in its prompt
        prompt = f"""USER: <image>
        Chat History:
        {formatted_history}

        Frame Transcription:
        {transcription}

        Question: {user_prompt}
        ASSISTANT:"""
        inputs = self.llava_processor(
            text=[prompt], images=image, return_tensors="pt", padding=True
        ).to(self.device, self.llava_model.dtype)  # move to device and cast floating-point inputs to the model's dtype
        output = self.llava_model.generate(**inputs, max_new_tokens=200)
        response = self.llava_processor.tokenizer.decode(output[0], skip_special_tokens=True)
        return response
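
Note that process_query reconstructs all stored vectors from the FAISS index and computes cosine similarity with scikit-learn, so FAISS serves only as storage. If you want FAISS to perform the search itself, a minimal alternative sketch (an optional variation, not part of the original code) is to L2-normalize the embeddings and use an inner-product index, since inner product over unit vectors equals cosine similarity:

# Alternative retrieval sketch: let FAISS compute cosine similarity directly.
# Assumes frame_embeddings is the (num_frames, embedding_dim) float32 matrix built earlier
# and text_embedding is the (1, embedding_dim) float32 CLIP query embedding.
import faiss
import numpy as np

def build_cosine_index(frame_embeddings):
    embeddings = np.array(frame_embeddings, dtype="float32", copy=True)
    faiss.normalize_L2(embeddings)                  # unit-normalize rows in place
    index = faiss.IndexFlatIP(embeddings.shape[1])  # inner product == cosine on unit vectors
    index.add(embeddings)
    return index

def search_best_frame(index, text_embedding):
    query = np.array(text_embedding, dtype="float32", copy=True)
    faiss.normalize_L2(query)
    scores, ids = index.search(query, 1)            # top-1 cosine similarity and frame index
    return int(ids[0][0]), float(scores[0][0])

The returned index can then be used to look up metadata[best_frame_idx] exactly as process_query does.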
            

Main Function

Explanation: The main function initializes the LLaVA processor and model together with the CLIP processor and model used to encode queries, creates an instance of LLAVAChatModel, and starts an interactive chat loop. For each user query, it retrieves the most relevant frame, generates a response, and appends both turns to the chat history. The loop continues until the user types "quit" or "exit".


def main():
    print("Starting chat application...")

    # Initialize LLaVA processor and model
    processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
    model = LlavaForConditionalGeneration.from_pretrained(
        "llava-hf/llava-1.5-7b-hf",
        torch_dtype=torch.float16,
        device_map="auto"
    )

    # Initialize the CLIP processor and model used to encode user queries
    # (openai/clip-vit-base-patch32 is the HF port of the ViT-B/32 weights used for the frames)
    clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")

    # Initialize the chat model with both model families
    chat_model = LLAVAChatModel(processor, model, clip_processor, clip_model, device=device)

    # Initialize chat history
    chat_history = []

    print("Chat model ready! You can start chatting...")

    while True:
        try:
            user_prompt = input("\nYou: ")
            if user_prompt.lower() in ['quit', 'exit']:
                print("Ending chat session...")
                break

            result = chat_model.process_query(user_prompt, chat_history)

            chat_history.append({'role': 'USER', 'content': user_prompt})
            chat_history.append({'role': 'ASSISTANT', 'content': result['response']})

            print("\nAssistant:", result['response'])
            print(f"Debug Info: Frame - {result['frame_path']}, Transcription - {result['transcription']}")
        except Exception as e:
            print(f"An error occurred: {str(e)}")


if __name__ == "__main__":
    main()
            

Conclusion

VideoQA demonstrates an effective pipeline for video frame retrieval and question answering using state-of-the-art vision-language models. By encoding both queries and frames with CLIP and comparing them via cosine similarity, the system efficiently identifies the most relevant frame. The integration of LLaVA further enhances the system's ability to generate detailed, context-aware responses, making VideoQA a powerful tool for video understanding and interactive applications.