Introduction
VideoQA is designed to answer questions about the content of a video by retrieving the most relevant frame and generating a response. The system leverages the CLIP model to encode both the user's query and the video frames into a shared embedding space, enabling semantic matching via cosine similarity. Once the best frame is identified, the LLAVA model is used to generate a detailed answer, taking into account the selected frame, its transcription, and the chat history.
Approach & Methodology
- Frame Extraction: The video is processed to extract frames at a fixed interval (e.g., one frame per second).
- Frame Embedding: Each frame is encoded into a feature vector using the CLIP model.
- Query Embedding: The user's natural language query is also encoded using CLIP.
- Similarity Search: Cosine similarity is computed between the query embedding and all frame embeddings to find the most relevant frame (see the sketch after this list).
- Response Generation: The LLAVA model generates a response based on the selected frame, its transcription, and the chat history.
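As a compact illustration of the retrieval step (a minimal sketch, independent of the full pipeline below; the array shapes are assumptions for illustration), the query embedding is compared against every frame embedding and the highest-scoring frame wins:

import numpy as np

def most_relevant_frame(query_emb: np.ndarray, frame_embs: np.ndarray) -> int:
    """Return the index of the frame most similar to the query under cosine similarity.

    query_emb:  shape (d,)   -- CLIP text embedding of the question
    frame_embs: shape (n, d) -- CLIP image embeddings, one row per frame
    """
    q = query_emb / np.linalg.norm(query_emb)
    f = frame_embs / np.linalg.norm(frame_embs, axis=1, keepdims=True)
    scores = f @ q                 # cosine similarities, shape (n,)
    return int(scores.argmax())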
Code Structure & Implementation
1. Environment Setup
Explanation:
The environment setup installs all the Python packages and system dependencies required for the project: faiss for efficient similarity search, ffmpeg for video processing, ffmpeg-python for its Python bindings, torch and torchvision for deep learning, and CLIP for vision-language embeddings. These commands ensure that the code can process videos, extract frames, and run the deep learning models.
# Install required packages
!pip install faiss-cpu # For CPU version, use faiss-gpu for GPU
!apt-get install -y ffmpeg
!pip install ffmpeg-python
!conda install --yes -c pytorch pytorch=1.7.1 torchvision cudatoolkit=11.0  # or, without conda: !pip install torch torchvision
!pip install ftfy regex tqdm
!pip install git+https://github.com/openai/CLIP.git
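Before moving on, it can help to verify that the dependencies import cleanly. This is a small optional check, not part of the original pipeline; exact versions will vary by environment.

# Quick sanity check that the installed dependencies import cleanly
import torch
import torchvision
import faiss
import clip
import ffmpeg  # ffmpeg-python bindings

print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("torchvision:", torchvision.__version__)
print("faiss:", faiss.__version__)
print("CLIP checkpoints:", clip.available_models())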
2. Imports and Device Setup
Explanation:
This section imports all the required Python libraries for the project. It includes standard libraries like os and pickle, as well as specialized libraries for image and video processing (cv2, PIL), deep learning (torch), and vision-language models (clip, transformers). The device is set to GPU if available, otherwise CPU, to optimize computation.
import os
import pickle
import numpy as np
import cv2
from PIL import Image
import faiss
from transformers import AutoProcessor, LlavaForConditionalGeneration
import torch
import clip
device = "cuda" if torch.cuda.is_available() else "cpu"
3. Frame Extraction from Video
Explanation:
The extract_frames function takes a video file and extracts frames at a fixed interval (one frame per second). It uses OpenCV to read the video, saves each selected frame as an image file, and returns a list of the saved frame paths. This step is crucial for converting the video into a set of images that can be processed by the CLIP model.
def extract_frames(video_path, output_dir):
    os.makedirs(output_dir, exist_ok=True)
    cap = cv2.VideoCapture(video_path)
    frame_rate = int(cap.get(cv2.CAP_PROP_FPS)) or 1  # Guard against videos that report 0 FPS
    count = 0
    frames = []
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        if count % frame_rate == 0:  # Extract one frame per second
            frame_path = os.path.join(output_dir, f"frame_{count}.jpg")
            cv2.imwrite(frame_path, frame)
            frames.append(frame_path)
        count += 1
    cap.release()
    return frames

frames = extract_frames("input_video.mp4", "video_frames")
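Since ffmpeg and ffmpeg-python are already installed, the same extraction can also be delegated to ffmpeg itself. This is a minimal alternative sketch under that assumption (the fps filter value and output filename pattern are illustrative), not the function used by the pipeline above:

import os
import ffmpeg  # provided by the ffmpeg-python package

def extract_frames_ffmpeg(video_path, output_dir, fps=1):
    """Dump `fps` frames per second using the ffmpeg 'fps' filter."""
    os.makedirs(output_dir, exist_ok=True)
    (
        ffmpeg
        .input(video_path)
        .filter("fps", fps=fps)                              # keep `fps` frames per second
        .output(os.path.join(output_dir, "frame_%06d.jpg"))  # numbered JPEG files
        .run(quiet=True, overwrite_output=True)
    )
    return sorted(
        os.path.join(output_dir, f) for f in os.listdir(output_dir) if f.endswith(".jpg")
    )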
4. Frame Embedding and Indexing
Explanation: Each extracted frame is processed using the CLIP model to obtain a high-dimensional embedding vector that captures its visual semantics. These embeddings are stacked into a matrix. Dummy transcriptions are created for each frame (these could be replaced with real transcriptions). The embeddings are then indexed using FAISS, which allows for fast similarity search. Metadata for each frame (path and transcription) is saved for later retrieval.
import faiss

model, preprocess = clip.load("ViT-B/32", device=device)

frame_embeddings = []
for frame_path in frames:
    image = preprocess(Image.open(frame_path)).unsqueeze(0).to(device)
    with torch.no_grad():
        frame_embedding = model.encode_image(image).cpu().numpy()
    frame_embeddings.append(frame_embedding)

# Shape: (num_frames, embedding_dim); FAISS expects float32
frame_embeddings = np.vstack(frame_embeddings).astype("float32")

# Dummy transcriptions for each frame
transcriptions = [f"Transcription for {frame_path}" for frame_path in frames]

# Initialize FAISS index for similarity search
embedding_dim = frame_embeddings.shape[1]
index = faiss.IndexFlatL2(embedding_dim)
index.add(frame_embeddings)

# Store metadata
metadata = [{"frame_path": frames[i], "transcription": transcriptions[i]} for i in range(len(frames))]

# Save metadata and index
with open("metadata.pkl", "wb") as f:
    pickle.dump(metadata, f)
faiss.write_index(index, "faiss_index.bin")
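A note on the index type: the pipeline stores raw embeddings in an L2 index and later reconstructs them to compute cosine similarity with scikit-learn. If the frame and query embeddings are L2-normalized first, FAISS can perform the cosine search directly with an inner-product index. This is a hedged sketch of that variant, not the code used above:

import numpy as np
import faiss

def build_cosine_index(embeddings: np.ndarray) -> faiss.IndexFlatIP:
    """Build an inner-product index over L2-normalized embeddings,
    so that inner product equals cosine similarity."""
    embeddings = embeddings.astype("float32").copy()
    faiss.normalize_L2(embeddings)              # in-place row-wise normalization
    index = faiss.IndexFlatIP(embeddings.shape[1])
    index.add(embeddings)
    return index

def search_cosine(index: faiss.IndexFlatIP, query: np.ndarray, k: int = 1):
    """Return (scores, ids) of the k most similar frames for a (1, d) query."""
    query = query.astype("float32").copy()
    faiss.normalize_L2(query)
    scores, ids = index.search(query, k)        # scores are cosine similarities
    return scores[0], ids[0]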
Main Components
LLAVAChatModel Class
Explanation:
The LLAVAChatModel class encapsulates the logic for processing user queries and generating responses. It loads both the CLIP and LLAVA models and provides methods for:
- Encoding the user's query using CLIP and finding the most similar frame via cosine similarity in the FAISS index.
- Retrieving the corresponding frame and its transcription from metadata.
- Generating a detailed answer using the LLAVA model, which takes the selected frame, the query, and chat history as input.
from transformers import CLIPProcessor, CLIPModel

class LLAVAChatModel:
    def __init__(self, llava_processor, llava_model, clip_processor, clip_model, device="cuda"):
        self.device = device
        self.llava_processor = llava_processor
        self.llava_model = llava_model  # already placed on device(s) via device_map in main()
        self.clip_processor = clip_processor
        self.clip_model = clip_model.to(self.device)

    def process_query(self, user_prompt, chat_history=None):
        # Load FAISS index and metadata
        index = faiss.read_index("faiss_index.bin")
        with open("metadata.pkl", "rb") as f:
            metadata = pickle.load(f)

        # Encode the user prompt using CLIP
        inputs = self.clip_processor(text=[user_prompt], return_tensors="pt").to(self.device)
        with torch.no_grad():
            text_embedding = self.clip_model.get_text_features(**inputs).cpu().numpy()

        # Find the most similar frame using cosine similarity
        from sklearn.metrics.pairwise import cosine_similarity
        similarities = cosine_similarity(text_embedding, index.reconstruct_n(0, index.ntotal))
        best_frame_idx = similarities.argmax()

        # Get metadata for the best frame
        best_metadata = metadata[best_frame_idx]
        frame_path = best_metadata['frame_path']
        transcription = best_metadata['transcription']

        # Generate response using LLAVA
        chat_history = chat_history or []
        response = self.generate_response(
            frame_path=frame_path,
            user_prompt=user_prompt,
            chat_history=chat_history,
            transcription=transcription
        )
        return {
            'response': response,
            'frame_path': frame_path,
            'transcription': transcription,
            'similarity_score': similarities[0][best_frame_idx]
        }

    def generate_response(self, frame_path, user_prompt, chat_history, transcription):
        image = Image.open(frame_path)
        formatted_history = "\n".join(
            f"{message['role'].capitalize()}: {message['content']}" for message in chat_history
        )
        # LLAVA-1.5 prompts must contain an <image> placeholder for the frame
        prompt = f"""USER: <image>
Chat History:
{formatted_history}
Frame Transcription:
{transcription}
Question: {user_prompt}
ASSISTANT:"""
        # Cast floating-point inputs to float16 to match the model dtype used in main()
        inputs = self.llava_processor(
            text=[prompt], images=image, return_tensors="pt", padding=True
        ).to(self.device, torch.float16)
        output = self.llava_model.generate(**inputs, max_new_tokens=200)
        response = self.llava_processor.tokenizer.decode(output[0], skip_special_tokens=True)
        return response
Main Function
Explanation:
The main function initializes the LLAVA processor and model, loads the CLIP text encoder used for retrieval, creates an instance of LLAVAChatModel, and starts an interactive chat loop. For each user query, it retrieves the most relevant frame, generates a response, and maintains the chat history. The loop continues until the user types "quit" or "exit".
def main():
    print("Starting chat application...")

    # Initialize LLAVA processor and model
    processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
    model = LlavaForConditionalGeneration.from_pretrained(
        "llava-hf/llava-1.5-7b-hf",
        torch_dtype=torch.float16,
        device_map="auto"
    )

    # Initialize the CLIP text encoder used for retrieval (same ViT-B/32 weights as the frame encoder)
    clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")

    # Initialize the chat model
    chat_model = LLAVAChatModel(processor, model, clip_processor, clip_model, device=device)

    # Initialize chat history
    chat_history = []
    print("Chat model ready! You can start chatting...")

    while True:
        try:
            user_prompt = input("\nYou: ")
            if user_prompt.lower() in ['quit', 'exit']:
                print("Ending chat session...")
                break
            result = chat_model.process_query(user_prompt, chat_history)
            chat_history.append({'role': 'USER', 'content': user_prompt})
            chat_history.append({'role': 'ASSISTANT', 'content': result['response']})
            print("\nAssistant:", result['response'])
            print(f"Debug Info: Frame - {result['frame_path']}, Transcription - {result['transcription']}")
        except Exception as e:
            print(f"An error occurred: {str(e)}")

if __name__ == "__main__":
    main()
Conclusion
VideoQA demonstrates an effective pipeline for video frame retrieval and question answering using state-of-the-art vision-language models. By encoding both queries and frames with CLIP and leveraging cosine similarity, the system efficiently identifies the most relevant frame. The integration of LLAVA further enhances the system's ability to generate detailed, context-aware responses, making VideoQA a powerful tool for video understanding and interactive applications.