#!/usr/bin/env python3
|
|
import argparse
import configparser
import fnmatch
import glob
import hashlib
import json
import mimetypes
import os
import signal
import subprocess
import sys
import time
from functools import wraps

import lancedb
import pyarrow as pa
from lancedb.embeddings.ollama import OllamaEmbeddings
|
|
# Quiet the Rust-backed storage engine: only surface error-level logs.
os.environ['RUST_LOG'] = 'error'
|
def handle_signals(signame):
    """Build a signal handler that announces *signame* and exits cleanly.

    Args:
        signame: Human-readable signal name used in the exit message.

    Returns:
        A ``(sig, frame)`` callable suitable for ``signal.signal``.
    """
    def signal_handler(sig, frame):
        print(f"\nReceived {signame} signal. Exiting...")
        # sys.exit instead of quit(): quit() is only injected by the
        # `site` module and is absent under `python -S` or frozen apps.
        sys.exit(0)
    return signal_handler


# Register handlers so Ctrl-C / termination requests exit gracefully.
signal.signal(signal.SIGINT, handle_signals("SIGINT"))
signal.signal(signal.SIGTERM, handle_signals("SIGTERM"))
|
def load_config(args):
    """Parse the INI file named by ``args.config`` and return the parser.

    ``ConfigParser.read`` silently ignores missing files, so an absent
    config path yields an empty configuration rather than an error.
    """
    parser = configparser.ConfigParser()
    parser.read(args.config)
    return parser
|
def setup_database(config):
    """Open a LanceDB connection from config: S3 when enabled, else local.

    S3 is used only when ``[s3] bucket_name`` is set AND ``[s3] enable``
    is truthy; otherwise the local ``[lancedb] persist_directory`` is used.
    """
    bucket = config.get("s3", "bucket_name", fallback=None)
    use_s3 = config.get("s3", "enable", fallback="false").lower() in ("true", "yes", "1", "on")

    if not (bucket and use_s3):
        local_dir = config.get("lancedb", "persist_directory", fallback="./lancedb")
        return lancedb.connect(local_dir)

    storage_options = {
        "aws_access_key_id": config.get("s3", "access_key_id", fallback=None),
        "aws_secret_access_key": config.get("s3", "secret_access_key", fallback=None),
        "region": config.get("s3", "region", fallback="us-east-1"),
        "endpoint_url": config.get("s3", "endpoint", fallback=None),
    }
    return lancedb.connect(f"s3://{bucket}", storage_options=storage_options)
|
def setup_embedding_model(config):
    """Build an Ollama embedding client from the ``[ollama]`` section."""
    url = config.get("ollama", "url", fallback="http://localhost:11434")
    # NOTE(review): TLS certificate verification is disabled here —
    # confirm this is intentional (e.g. self-signed local deployments).
    client_kwargs = {'verify': False}
    return OllamaEmbeddings(
        host=url,
        name='nomic-embed-text',
        options=None,  # type=o.llama.Options
        keep_alive=None,
        ollama_client_kwargs=client_kwargs,
    )
|
def create_table(db):
    """Open the ``vectordb`` table, creating it when it does not exist.

    Args:
        db: A LanceDB connection.

    Returns:
        The opened (or freshly created) table.

    Exits the process when opening fails for any reason other than a
    missing table.
    """
    table_name = "vectordb"
    # 768 is the output dimension of the nomic-embed-text model.
    schema = pa.schema([
        pa.field("text", pa.string()),
        pa.field("id", pa.string()),
        pa.field("description", pa.string()),
        pa.field("vector", pa.list_(pa.float64(), 768)),
    ])

    try:
        return db.open_table(table_name)
    except ValueError as e:
        if "Table '" in str(e) and "' was not found" in str(e):
            print(f"Table '{table_name}' not found. Creating...")
            # `schema` is already a pyarrow.Schema — the original
            # redundantly re-wrapped it via pa.schema().
            return db.create_table(table_name, schema=schema, mode="overwrite")
        # sys.exit instead of quit(): quit() is absent under `python -S`.
        sys.exit(f"An error occurred when opening table: {e}")
|
def is_git_directory(path="."):
    """Return True when *path* lies inside a Git working tree."""
    status = subprocess.call(
        ['git', 'rev-parse', '--is-inside-work-tree'],
        cwd=path,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    return status == 0
|
def load_documents(root=None, exclude=None):
    """Recursively collect text-like files under *root* as documents.

    Args:
        root: Directory to scan; defaults to the current working directory.
        exclude: ``|``-separated glob patterns matched against each file's
            full path; matching files are skipped.

    Returns:
        A list of dicts with ``text``, ``id`` (sha256 of the path) and
        ``description`` (git metadata when available) keys.
    """
    documents = []
    exclude_patterns = exclude.split("|") if exclude else []

    if root is None:
        root = os.getcwd()

    # `dirpath` instead of re-binding `root` (the original shadowed the
    # parameter with the os.walk variable).
    for dirpath, dirs, files in os.walk(root):
        # Skip files matching any exclude pattern. fnmatch is used
        # directly — `glob.fnmatch` is an undocumented internal alias.
        files = [
            f for f in files
            if not any(fnmatch.fnmatch(f"{dirpath}/{f}", pattern)
                       for pattern in exclude_patterns)
        ]

        # One `git rev-parse` per directory instead of one per file.
        in_git = is_git_directory(dirpath) if files else False

        for file in files:
            path = os.path.join(dirpath, file)
            try:
                with open(path, "rb") as f:
                    content_bytes = f.read()
                content_type, _ = mimetypes.guess_type(path)

                # Decode leniently; undecodable bytes are dropped.
                content_str = content_bytes.decode("utf-8", errors='ignore')

                # Explicitly treat application/json as text/plain.
                if 'application/json' == content_type:
                    content_type = "text/plain"

                # Heuristic binary filter: when the MIME type is unknown
                # or non-text, skip files containing no common whitespace.
                if content_type is None or 'text' not in content_type:
                    if not any(char in content_str for char in "\n\r\t\v\f "):
                        continue

                description = ""
                if in_git:
                    try:
                        # NOTE(review): `git show --no-patch <path>` expects a
                        # revision, not a file path — verify this produces the
                        # intended metadata (failures are caught and logged).
                        description = subprocess.check_output(
                            ["git", "show", "--no-patch", path],
                            stderr=subprocess.DEVNULL,
                        ).decode("utf-8").strip() or ""
                    except subprocess.CalledProcessError as e:
                        print(f"Error fetching git description for {path}: {e}")

                print(f"Documents found '{path}'.")
                # Stable id: sha256 over dirname+path (kept as-is so ids
                # match rows already stored in the table).
                doc_id = hashlib.sha256(
                    f"{os.path.dirname(path)}{path}".encode()
                ).hexdigest()
                documents.append(
                    {"text": content_str, "id": doc_id, "description": description}
                )
            except Exception as e:
                print(f"Error reading file {path}: {e}")
    return documents
|
def load_git_data():
    """Collect one document per git commit in the current repository.

    Returns:
        A list of dicts (``text``, ``id``, ``description``) — empty when
        the CWD is not a git work tree or the repository has no commits.
    """
    if not is_git_directory():
        print("Current directory is not a Git repository.")
        return []

    log_output = subprocess.check_output([
        "git", "log", "--pretty=format:%h %s", "--no-merges"
    ], stderr=subprocess.DEVNULL, text=True).strip()

    git_documents = []
    for entry in log_output.split("\n"):
        # An empty repo yields one empty entry; skip it.
        if not entry:
            continue
        # partition() tolerates commits with an empty subject line,
        # where split(maxsplit=1) would raise ValueError on unpacking.
        commit_hash, _, message = entry.partition(" ")
        description = subprocess.check_output(
            ["git", "show", "--no-patch", commit_hash],
            stderr=subprocess.DEVNULL,
        ).decode("utf-8").strip() or ""
        git_documents.append({
            "text": f"Commit {commit_hash}: {message}",
            "id": commit_hash,
            "description": description,
        })

    return git_documents
|
def generate_embeddings(documents, embedding_model):
    """Attach a ``vector`` embedding to every document, in place.

    Args:
        documents: Dicts with at least a ``text`` key; mutated in place.
        embedding_model: Object exposing ``generate_embeddings(list[str])``.

    Returns:
        The same ``documents`` list, for chaining.
    """
    print("Generating embeddings...")
    for doc in documents:
        # One model call per document; the unused `doc_id` local from the
        # original has been dropped.
        embedding = embedding_model.generate_embeddings([doc["text"]])[0]
        doc["vector"] = embedding
    print("Done.")
    return documents
|
def upsert_documents(table, documents):
    """Upsert *documents* into *table*, keyed on ``id``, and report size."""
    builder = table.merge_insert("id")
    builder.when_matched_update_all().when_not_matched_insert_all().execute(documents)
    # NOTE: count_rows() is the table's total row count, not only the
    # rows touched by this batch.
    affected_rows = table.count_rows()
    print(f"Inserted {affected_rows} documents.")
|
def create_vector_index(table):
    """Create a cosine-distance index on the ``vector`` column.

    Exits the process with an error message when index creation fails.
    """
    try:
        print("Creating vector index")
        table.create_index(metric="cosine", vector_column_name="vector")
        print("Vector index created successfully.")
    except Exception as e:
        # sys.exit instead of quit(): quit() is absent under `python -S`.
        sys.exit(f"Error creating vector index: {e}")
|
def wait_for_index(table, index_name):
    """Block until *index_name* appears in the table's index list.

    Polls ``table.list_indices()`` every 10 seconds.
    """
    poll_seconds = 10
    while True:
        current = table.list_indices() or []
        if any(idx.name == index_name for idx in current):
            break
        print(f"Waiting for {index_name} to be ready...")
        time.sleep(poll_seconds)
    print(f"Vector index {index_name} is ready!")
|
def main():
    """CLI entry point: embed local files and git history into LanceDB."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default="config.ini", help="Path to config file")
    parser.add_argument("--root", type=str, help="Root directory to process")
    parser.add_argument("--exclude", type=str, help="Exclude patterns separated by '|'")
    args, _ = parser.parse_known_args()

    # Wire up configuration, storage, model, and target table.
    config = load_config(args)
    db = setup_database(config)
    embedding_model = setup_embedding_model(config)
    table = create_table(db)

    # Gather filesystem documents plus one document per git commit.
    documents = load_documents(root=args.root, exclude=args.exclude)
    documents.extend(load_git_data())

    # Embed, upsert, and index.
    documents = generate_embeddings(documents, embedding_model)
    upsert_documents(table, documents)
    create_vector_index(table)
    wait_for_index(table, "vector_idx")
    print("Documents inserted.")


if __name__ == "__main__":
    main()