CodeRecall/lancedb_ingest.py

#!/usr/bin/env python3
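"""Ingest a source tree plus its git history into a LanceDB vector table,
embedding each document with a local Ollama model (nomic-embed-text)."""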
import argparse
import subprocess
import os
import sys
import pyarrow as pa
import time
import hashlib
import signal
import configparser
import lancedb
import mimetypes
import fnmatch
from lancedb.embeddings.ollama import OllamaEmbeddings
# Silence non-error log output from LanceDB's native (Rust) layer
os.environ['RUST_LOG'] = 'error'

def handle_signals(signame):
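    """Build a signal handler that announces `signame` and exits cleanly."""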
    def signal_handler(sig, frame):
        print(f"\nReceived {signame} signal. Exiting...")
        sys.exit(0)
    return signal_handler

# Register signal handlers
signal.signal(signal.SIGINT, handle_signals("SIGINT"))
signal.signal(signal.SIGTERM, handle_signals("SIGTERM"))
def load_config(args):
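    """Read the INI file named by --config into a ConfigParser."""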
    config = configparser.ConfigParser()
    config.read(args.config)
    return config

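# For reference, a config.ini of the following shape satisfies every
# config.get() call in this script (section and key names come from the
# code below; the values are placeholders, not shipped defaults):
#
#   [lancedb]
#   persist_directory = ./lancedb
#
#   [s3]
#   enable = false
#   bucket_name = my-bucket
#   access_key_id = AKIA...
#   secret_access_key = ...
#   region = us-east-1
#   endpoint = https://s3.example.com
#
#   [ollama]
#   url = http://localhost:11434
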
def setup_database(config):
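    """Connect to LanceDB: an S3 bucket when [s3] is enabled, else a local directory."""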
    s3_bucket = config.get("s3", "bucket_name", fallback=None)
    enable_s3 = config.get("s3", "enable", fallback="false").lower() in ['true', 'yes', '1', 'on']
    persist_directory = config.get("lancedb", "persist_directory", fallback="./lancedb")
    if s3_bucket and enable_s3:
        storage_options = {
            "aws_access_key_id": config.get("s3", "access_key_id", fallback=None),
            "aws_secret_access_key": config.get("s3", "secret_access_key", fallback=None),
            "region": config.get("s3", "region", fallback="us-east-1"),
            "endpoint_url": config.get("s3", "endpoint", fallback=None),
        }
        db = lancedb.connect(
            f"s3://{s3_bucket}",
            storage_options=storage_options
        )
    else:
        db = lancedb.connect(persist_directory)
    return db

def setup_embedding_model(config):
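    """Build an Ollama embedding function (nomic-embed-text) pointed at the configured server."""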
    ollama_url = config.get("ollama", "url", fallback="http://localhost:11434")
    embedding_model = OllamaEmbeddings(
        host=ollama_url,
        name='nomic-embed-text',
        options=None,  # optional ollama.Options
        keep_alive=None,
        ollama_client_kwargs={
            'verify': False  # Disable certificate verification (not recommended)
        }
    )
    return embedding_model

def create_table(db):
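    """Open the 'vectordb' table, creating it with the expected schema if missing."""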
    table_name = "vectordb"
    schema = pa.schema([
        pa.field("text", pa.string()),
        pa.field("id", pa.string()),
        pa.field("description", pa.string()),
        pa.field("vector", pa.list_(pa.float64(), 768))  # nomic-embed-text is 768-dim
    ])
    try:
        table = db.open_table(table_name)
    except ValueError as e:
        if "Table '" in str(e) and "' was not found" in str(e):
            print(f"Table '{table_name}' not found. Creating...")
            table = db.create_table(table_name, schema=schema, mode="overwrite")
        else:
            sys.exit(f"An error occurred when opening table: {e}")
    return table

def is_git_directory(path="."):
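    """Return True if `path` lies inside a git working tree."""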
    return subprocess.call(
        ['git', 'rev-parse', '--is-inside-work-tree'],
        cwd=path, stdout=subprocess.PIPE, stderr=subprocess.PIPE
    ) == 0

def load_documents(root=None, exclude=None):
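    """Walk `root`, reading every text-like file into a document dict
    (text, id, description), honoring '|'-separated exclude patterns."""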
    documents = []
    # Split exclude into glob patterns if provided
    exclude_patterns = exclude.split("|") if exclude else []
    # Default to the current directory if no root was given
    if root is None:
        root = os.getcwd()
    # Iterate through directories and files
    for dirpath, dirs, files in os.walk(root):
        # Skip files whose full path matches any exclude pattern
        files = [
            f for f in files
            if not any(fnmatch.fnmatch(os.path.join(dirpath, f), pattern) for pattern in exclude_patterns)
        ]
        for file in files:
            path = os.path.join(dirpath, file)
            try:
                with open(path, "rb") as f:
                    content_bytes = f.read()
                content_type, _ = mimetypes.guess_type(path)
                # Decode the content to UTF-8, dropping undecodable bytes
                content_str = content_bytes.decode("utf-8", errors='ignore')
                # Explicitly treat application/json as text/plain
                if content_type == "application/json":
                    content_type = "text/plain"
                # Fallback check if the guessed MIME type is None or not text:
                # skip content with no whitespace at all, a crude binary heuristic
                if content_type is None or 'text' not in content_type:
                    if not any(char in content_str for char in "\n\r\t\v\f "):
                        continue
                description = ""
                if is_git_directory(dirpath):
                    try:
                        # Use the last commit that touched the file as its description
                        description = subprocess.check_output(
                            ["git", "log", "-1", "--no-patch", "--", file],
                            cwd=dirpath, stderr=subprocess.DEVNULL
                        ).decode("utf-8").strip()
                    except subprocess.CalledProcessError as e:
                        print(f"Error fetching git description for {path}: {e}")
                print(f"Document found: '{path}'.")
                # Derive a stable document ID from the file path
                doc_id = hashlib.sha256(path.encode()).hexdigest()
                documents.append({"text": content_str, "id": doc_id, "description": description})
            except Exception as e:
                print(f"Error reading file {path}: {e}")
    return documents

def load_git_data():
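    """Turn each non-merge commit in the current repository into a document."""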
    if not is_git_directory():
        print("Current directory is not a Git repository.")
        return []
    log_entries = subprocess.check_output(
        ["git", "log", "--pretty=format:%h %s", "--no-merges"],
        stderr=subprocess.DEVNULL, text=True
    ).strip().split("\n")
    git_documents = []
    for entry in log_entries:
        # partition() tolerates commits with an empty subject line
        commit_hash, _, message = entry.partition(" ")
        if not commit_hash:
            continue
        description = subprocess.check_output(
            ["git", "show", "--no-patch", commit_hash],
            stderr=subprocess.DEVNULL
        ).decode("utf-8").strip()
        git_documents.append({"text": f"Commit {commit_hash}: {message}", "id": commit_hash, "description": description})
    return git_documents

def generate_embeddings(documents, embedding_model):
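    """Attach an embedding vector to each document, one Ollama call per document."""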
print("Generating embeddings...")
for doc in documents:
text = doc["text"]
doc_id = doc["id"]
embedding = embedding_model.generate_embeddings([text])[0]
doc["vector"] = embedding
print("Done.")
return documents
def upsert_documents(table, documents):
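    """Upsert documents into the table, matching on the 'id' column."""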
    (
        table.merge_insert("id")
        .when_matched_update_all()
        .when_not_matched_insert_all()
        .execute(documents)
    )
    total_rows = table.count_rows()
    print(f"Upserted {len(documents)} documents; table now holds {total_rows} rows.")

def create_vector_index(table):
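    """Create a cosine-distance vector index over the 'vector' column."""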
    try:
        print("Creating vector index...")
        table.create_index(metric="cosine", vector_column_name="vector")
        print("Vector index created successfully.")
    except Exception as e:
        sys.exit(f"Error creating vector index: {e}")

def wait_for_index(table, index_name):
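    """Block until an index named `index_name` appears on the table."""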
    POLL_INTERVAL = 10  # seconds between polls
    while True:
        indices = table.list_indices()
        if indices and any(index.name == index_name for index in indices):
            break
        print(f"Waiting for {index_name} to be ready...")
        time.sleep(POLL_INTERVAL)
    print(f"Vector index {index_name} is ready!")

def main():
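    """Ingest the file tree and git history, then embed, upsert, and index."""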
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default="config.ini", help="Path to config file")
    parser.add_argument("--root", type=str, help="Root directory to process")
    parser.add_argument("--exclude", type=str, help="Exclude patterns separated by '|'")
    args, _ = parser.parse_known_args()
    config = load_config(args)
    db = setup_database(config)
    embedding_model = setup_embedding_model(config)
    table = create_table(db)
    documents = load_documents(root=args.root, exclude=args.exclude)
    git_documents = load_git_data()
    documents.extend(git_documents)
    documents = generate_embeddings(documents, embedding_model)
    upsert_documents(table, documents)
    create_vector_index(table)
    wait_for_index(table, "vector_idx")
    print("Documents inserted.")

if __name__ == "__main__":
    main()
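
# Example invocation (paths and patterns are illustrative):
#   python lancedb_ingest.py --config config.ini --root ./myrepo --exclude "*.lock|*.min.js"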