#!/usr/bin/env python3
"""Index local files and git history into a LanceDB vector table.

Walks a directory tree (and the git log, when run inside a repository),
generates an embedding for each document via an Ollama model, and upserts
the results into a LanceDB table, then builds a cosine vector index.
"""

import argparse
import configparser
import fnmatch  # imported directly; the original reached into glob.fnmatch (a private detail)
import glob
import hashlib
import json
import mimetypes
import os
import signal
import subprocess
import sys
import time
from functools import wraps

import lancedb
import pyarrow as pa
from lancedb.embeddings.ollama import OllamaEmbeddings

# Silence noisy Rust-side logging from the lance storage engine.
os.environ['RUST_LOG'] = 'error'


def handle_signals(signame):
    """Return a signal handler that announces *signame* and exits cleanly."""
    def signal_handler(sig, frame):
        print(f"\nReceived {signame} signal. Exiting...")
        # sys.exit is the supported API for scripts; quit() is a site/REPL helper.
        sys.exit(0)
    return signal_handler


# Register signal handlers so Ctrl-C / termination exit gracefully.
signal.signal(signal.SIGINT, handle_signals("SIGINT"))
signal.signal(signal.SIGTERM, handle_signals("SIGTERM"))


def load_config(args):
    """Read the INI file named by --config and return the ConfigParser."""
    config = configparser.ConfigParser()
    config.read(args.config)
    return config


def setup_database(config):
    """Connect to LanceDB: S3 when the [s3] section is enabled, else local dir."""
    s3_bucket = config.get("s3", "bucket_name", fallback=None)
    enable_s3 = config.get("s3", "enable", fallback="false").lower() in ('true', 'yes', '1', 'on')
    persist_directory = config.get("lancedb", "persist_directory", fallback="./lancedb")

    if s3_bucket and enable_s3:
        storage_options = {
            "aws_access_key_id": config.get("s3", "access_key_id", fallback=None),
            "aws_secret_access_key": config.get("s3", "secret_access_key", fallback=None),
            "region": config.get("s3", "region", fallback="us-east-1"),
            "endpoint_url": config.get("s3", "endpoint", fallback=None),
        }
        return lancedb.connect(f"s3://{s3_bucket}", storage_options=storage_options)
    return lancedb.connect(persist_directory)


def setup_embedding_model(config):
    """Build the Ollama embedding client from the [ollama] config section."""
    ollama_url = config.get("ollama", "url", fallback="http://localhost:11434")
    return OllamaEmbeddings(
        host=ollama_url,
        name='nomic-embed-text',
        options=None,  # type=ollama.Options
        keep_alive=None,
        ollama_client_kwargs={
            'verify': False  # Disable certificate verification (not recommended)
        },
    )


def create_table(db):
    """Open the 'vectordb' table, creating it with the expected schema if absent."""
    table_name = "vectordb"
    # 768 dimensions matches the nomic-embed-text embedding size.
    schema = pa.schema([
        pa.field("text", pa.string()),
        pa.field("id", pa.string()),
        pa.field("description", pa.string()),
        pa.field("vector", pa.list_(pa.float64(), 768)),
    ])
    try:
        return db.open_table(table_name)
    except ValueError as e:
        # LanceDB signals a missing table via ValueError; match its message.
        if "Table '" in str(e) and "' was not found" in str(e):
            print(f"Table '{table_name}' not found. Creating...")
            # Fix: the original wrapped this pa.Schema in pa.schema() a second
            # time; pass the already-built schema directly.
            return db.create_table(table_name, schema=schema, mode="overwrite")
        sys.exit(f"An error occurred when opening table: {e}")


def is_git_directory(path="."):
    """Return True when *path* lies inside a git work tree."""
    return subprocess.call(
        ['git', 'rev-parse', '--is-inside-work-tree'],
        cwd=path, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
    ) == 0


def load_documents(root=None, exclude=None):
    """Collect text-like files under *root* as document dicts.

    Args:
        root: Directory to walk; defaults to the current working directory.
        exclude: '|'-separated glob patterns matched against full file paths.

    Returns:
        List of dicts with keys 'text', 'id' (sha256 of the path) and
        'description' (git metadata when available, else "").
    """
    documents = []
    exclude_patterns = exclude.split("|") if exclude else []
    if root is None:
        root = os.getcwd()

    # Fix: use a distinct loop variable instead of shadowing the 'root' argument.
    for dirpath, dirs, files in os.walk(root):
        # Skip files matching any exclude pattern.
        files = [
            f for f in files
            if not any(fnmatch.fnmatch(f"{dirpath}/{f}", pattern)
                       for pattern in exclude_patterns)
        ]
        for file in files:
            path = os.path.join(dirpath, file)
            try:
                # Read and close promptly; the original held the handle open
                # across the git subprocess calls below.
                with open(path, "rb") as f:
                    content_bytes = f.read()
                content_type, _ = mimetypes.guess_type(path)
                content_str = content_bytes.decode("utf-8", errors='ignore')

                # Explicitly treat application/json as text/plain
                if content_type == 'application/json':
                    content_type = "text/plain"

                # Fallback when the guessed MIME type is None or not text:
                # require some whitespace so binary blobs are skipped.
                if content_type is None or 'text' not in content_type:
                    if not any(char in content_str for char in "\n\r\t\v\f "):
                        continue

                description = ""
                if is_git_directory(dirpath):
                    try:
                        description = subprocess.check_output(
                            ["git", "show", "--no-patch", path],
                            stderr=subprocess.DEVNULL,
                        ).decode("utf-8").strip() or ""
                    except subprocess.CalledProcessError as e:
                        print(f"Error fetching git description for {path}: {e}")

                print(f"Documents found '{path}'.")
                doc_id = hashlib.sha256(
                    f"{os.path.dirname(path)}{path}".encode()).hexdigest()
                documents.append(
                    {"text": content_str, "id": doc_id, "description": description})
            except Exception as e:
                print(f"Error reading file {path}: {e}")
    return documents


def load_git_data():
    """Return one document per git commit (short hash, subject, full metadata)."""
    if not is_git_directory():
        print("Current directory is not a Git repository.")
        return []

    try:
        log_output = subprocess.check_output(
            ["git", "log", "--pretty=format:%h %s", "--no-merges"],
            stderr=subprocess.DEVNULL, text=True,
        ).strip()
    except subprocess.CalledProcessError:
        # e.g. a repository with no commits yet; git log exits non-zero.
        return []

    git_documents = []
    for entry in log_output.split("\n"):
        if not entry:
            continue
        # Fix: commits with an empty subject yield a single field; the
        # original two-target unpack raised ValueError on them.
        parts = entry.split(maxsplit=1)
        commit_hash = parts[0]
        message = parts[1] if len(parts) > 1 else ""
        try:
            description = subprocess.check_output(
                ["git", "show", "--no-patch", commit_hash],
                stderr=subprocess.DEVNULL,
            ).decode("utf-8").strip() or ""
        except subprocess.CalledProcessError:
            description = ""
        git_documents.append({
            "text": f"Commit {commit_hash}: {message}",
            "id": commit_hash,
            "description": description,
        })
    return git_documents


def generate_embeddings(documents, embedding_model):
    """Attach a 'vector' embedding to each document, mutating in place."""
    print("Generating embeddings...")
    for doc in documents:
        doc["vector"] = embedding_model.generate_embeddings([doc["text"]])[0]
    print("Done.")
    return documents


def upsert_documents(table, documents):
    """Upsert *documents* keyed on 'id': update matches, insert the rest."""
    if not documents:
        # Guard: merge_insert on an empty batch serves no purpose.
        print("No documents to upsert.")
        return
    table.merge_insert("id") \
        .when_matched_update_all() \
        .when_not_matched_insert_all() \
        .execute(documents)
    # Fix: count_rows() is the table total, not the number just inserted.
    print(f"Upserted {len(documents)} documents; table now has {table.count_rows()} rows.")


def create_vector_index(table):
    """Create a cosine ANN index on the 'vector' column; exit on failure."""
    try:
        print("Creating vector index")
        table.create_index(metric="cosine", vector_column_name="vector")
        print("Vector index created successfully.")
    except Exception as e:
        sys.exit(f"Error creating vector index: {e}")


def wait_for_index(table, index_name):
    """Block until *index_name* appears in the table's index list."""
    POLL_INTERVAL = 10  # seconds between polls
    while True:
        indices = table.list_indices()
        if indices and any(index.name == index_name for index in indices):
            break
        print(f"Waiting for {index_name} to be ready...")
        time.sleep(POLL_INTERVAL)
    print(f"Vector index {index_name} is ready!")


def main():
    """CLI entry point: load documents, embed them, and index in LanceDB."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default="config.ini", help="Path to config file")
    parser.add_argument("--root", type=str, help="Root directory to process")
    parser.add_argument("--exclude", type=str, help="Exclude patterns separated by '|'")
    args, _ = parser.parse_known_args()

    config = load_config(args)
    db = setup_database(config)
    embedding_model = setup_embedding_model(config)
    table = create_table(db)

    documents = load_documents(root=args.root, exclude=args.exclude)
    documents.extend(load_git_data())
    documents = generate_embeddings(documents, embedding_model)

    upsert_documents(table, documents)
    create_vector_index(table)
    wait_for_index(table, "vector_idx")
    print("Documents inserted.")


if __name__ == "__main__":
    main()