Initial code commit
This commit is contained in:
parent
14c06787dc
commit
9842d55bbf
7 changed files with 441 additions and 32 deletions
217
lancedb_ingest.py
Normal file
217
lancedb_ingest.py
Normal file
|
|
@ -0,0 +1,217 @@
|
|||
#!/usr/bin/env python3
|
||||
import argparse
|
||||
import subprocess
|
||||
import os
|
||||
import json
|
||||
import pyarrow as pa
|
||||
import time
|
||||
import hashlib
|
||||
import signal
|
||||
import configparser
|
||||
import lancedb
|
||||
import mimetypes
|
||||
import glob
|
||||
from lancedb.embeddings.ollama import OllamaEmbeddings
|
||||
from functools import wraps
|
||||
|
||||
# Ignore warnings
|
||||
os.environ['RUST_LOG'] = 'error'
|
||||
|
||||
def handle_signals(signame):
|
||||
def signal_handler(sig, frame):
|
||||
print(f"\nReceived {signame} signal. Exiting...")
|
||||
quit(0)
|
||||
return signal_handler
|
||||
|
||||
# Register signal handlers
|
||||
signal.signal(signal.SIGINT, handle_signals("SIGINT"))
|
||||
signal.signal(signal.SIGTERM, handle_signals("SIGTERM"))
|
||||
|
||||
def load_config(args):
|
||||
config = configparser.ConfigParser()
|
||||
config.read(args.config)
|
||||
return config
|
||||
|
||||
def setup_database(config):
|
||||
s3_bucket = config.get("s3", "bucket_name", fallback=None)
|
||||
enable_s3 = config.get("s3", "enable", fallback="false").lower() in ['true', 'yes', '1', 'on']
|
||||
persist_directory = config.get("lancedb", "persist_directory", fallback="./lancedb")
|
||||
|
||||
if s3_bucket and enable_s3:
|
||||
storage_options = {
|
||||
"aws_access_key_id": config.get("s3", "access_key_id", fallback=None),
|
||||
"aws_secret_access_key": config.get("s3", "secret_access_key", fallback=None),
|
||||
"region": config.get("s3", "region", fallback="us-east-1"),
|
||||
"endpoint_url": config.get("s3", "endpoint", fallback=None),
|
||||
}
|
||||
db = lancedb.connect(
|
||||
f"s3://{s3_bucket}",
|
||||
storage_options=storage_options
|
||||
)
|
||||
else:
|
||||
db = lancedb.connect(persist_directory)
|
||||
|
||||
return db
|
||||
|
||||
def setup_embedding_model(config):
|
||||
ollama_url = config.get("ollama", "url", fallback="http://localhost:11434")
|
||||
embedding_model = OllamaEmbeddings(
|
||||
host=ollama_url,
|
||||
name='nomic-embed-text',
|
||||
options=None, # type=o.llama.Options
|
||||
keep_alive=None,
|
||||
ollama_client_kwargs={
|
||||
'verify': False # Disable certificate verification (not recommended)
|
||||
}
|
||||
)
|
||||
return embedding_model
|
||||
|
||||
def create_table(db):
|
||||
table_name = "vectordb"
|
||||
schema_dict = pa.schema([
|
||||
pa.field("text", pa.string()),
|
||||
pa.field("id", pa.string()),
|
||||
pa.field("description", pa.string()),
|
||||
pa.field("vector", pa.list_(pa.float64(), 768))
|
||||
])
|
||||
|
||||
try:
|
||||
table = db.open_table(table_name)
|
||||
except ValueError as e:
|
||||
if "Table '" in str(e) and "' was not found" in str(e):
|
||||
print(f"Table '{table_name}' not found. Creating...")
|
||||
# Convert dictionary schema to PyArrow schema
|
||||
schema = pa.schema(schema_dict)
|
||||
table = db.create_table(table_name, schema=schema, mode="overwrite")
|
||||
else:
|
||||
quit(f"A error occurred when opening table: {e}")
|
||||
return table
|
||||
|
||||
def is_git_directory(path="."):
|
||||
#print(f"path: {path}")
|
||||
return subprocess.call(['git', 'rev-parse', '--is-inside-work-tree'], cwd=path, stdout=subprocess.PIPE, stderr=subprocess.PIPE) == 0
|
||||
|
||||
def load_documents(root=None, exclude=None):
|
||||
documents = []
|
||||
|
||||
# Split exclude into patterns if provided
|
||||
exclude_patterns = (exclude or "").split("|") if exclude else []
|
||||
|
||||
# Check if root is None and set it to the current directory if not provided
|
||||
if root is None:
|
||||
root = os.getcwd()
|
||||
|
||||
# Iterate through directories and files
|
||||
for root, dirs, files in os.walk(root):
|
||||
|
||||
# Skip files specified in exclude
|
||||
files = [f for f in files if not any(glob.fnmatch.fnmatch(f"{root}/{f}", pattern) for pattern in exclude_patterns)]
|
||||
|
||||
for file in files:
|
||||
path = os.path.join(root, file)
|
||||
try:
|
||||
with open(path, "rb") as f:
|
||||
content_bytes = f.read()
|
||||
content_type, _ = mimetypes.guess_type(f.name)
|
||||
|
||||
# Decode the content to UTF-8
|
||||
content_str = content_bytes.decode("utf-8", errors='ignore')
|
||||
|
||||
# Explicitly treat application/json as text/plain
|
||||
if 'application/json' == content_type:
|
||||
content_type = "text/plain"
|
||||
|
||||
# Fallback check if the guessed MIME type is None or not text
|
||||
if content_type is None or 'text' not in content_type:
|
||||
if not any(char in content_str for char in "\n\r\t\v\f "):
|
||||
continue
|
||||
|
||||
description = ""
|
||||
#print(f"path: {f.name}; root: {root}; f.name: {f.name}; content_type: {content_type}; file: {file};")
|
||||
if is_git_directory(root):
|
||||
try:
|
||||
description = subprocess.check_output(["git", "show", "--no-patch", path], stderr=subprocess.DEVNULL).decode("utf-8").strip() or ""
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Error fetching git description for {path}: {e}")
|
||||
|
||||
print(f"Documents found '{f.name}'.")
|
||||
doc_id = hashlib.sha256(f"{os.path.dirname(path)}{path}".encode()).hexdigest()
|
||||
documents.append({"text": content_str, "id": doc_id, "description": description})
|
||||
except Exception as e:
|
||||
print(f"Error reading file {path}: {e}")
|
||||
return documents
|
||||
|
||||
def load_git_data():
|
||||
if not is_git_directory():
|
||||
print("Current directory is not a Git repository.")
|
||||
return []
|
||||
|
||||
log_entries = subprocess.check_output([
|
||||
"git", "log", "--pretty=format:%h %s", "--no-merges"
|
||||
], stderr=subprocess.DEVNULL, text=True).strip().split("\n")
|
||||
|
||||
git_documents = []
|
||||
for entry in log_entries:
|
||||
commit_hash, message = entry.split(maxsplit=1)
|
||||
description = subprocess.check_output(["git", "show", "--no-patch", f"{commit_hash}"], stderr=subprocess.DEVNULL).decode("utf-8").strip() or ""
|
||||
git_documents.append({"text": f"Commit {commit_hash}: {message}", "id": commit_hash, "description": description})
|
||||
|
||||
return git_documents
|
||||
|
||||
def generate_embeddings(documents, embedding_model):
|
||||
print("Generating embeddings...")
|
||||
for doc in documents:
|
||||
text = doc["text"]
|
||||
doc_id = doc["id"]
|
||||
embedding = embedding_model.generate_embeddings([text])[0]
|
||||
doc["vector"] = embedding
|
||||
print("Done.")
|
||||
return documents
|
||||
|
||||
def upsert_documents(table, documents):
|
||||
table.merge_insert("id").when_matched_update_all().when_not_matched_insert_all().execute(documents)
|
||||
affected_rows = table.count_rows()
|
||||
print(f"Inserted {affected_rows} documents.")
|
||||
|
||||
def create_vector_index(table):
|
||||
try:
|
||||
print(f"Creating vector index")
|
||||
table.create_index(metric="cosine", vector_column_name="vector")
|
||||
print("Vector index created successfully.")
|
||||
except Exception as e:
|
||||
quit(f"Error creating vector index: {e}")
|
||||
|
||||
def wait_for_index(table, index_name):
|
||||
POLL_INTERVAL = 10
|
||||
while True:
|
||||
indices = table.list_indices()
|
||||
if indices and any(index.name == index_name for index in indices):
|
||||
break
|
||||
print(f"Waiting for {index_name} to be ready...")
|
||||
time.sleep(POLL_INTERVAL)
|
||||
print(f"Vector index {index_name} is ready!")
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--config", default="config.ini", help="Path to config file")
|
||||
parser.add_argument("--root", type=str, help="Root directory to process")
|
||||
parser.add_argument("--exclude", type=str, help="Exclude patterns separated by '|'")
|
||||
args, _ = parser.parse_known_args()
|
||||
|
||||
config = load_config(args)
|
||||
db = setup_database(config)
|
||||
embedding_model = setup_embedding_model(config)
|
||||
table = create_table(db)
|
||||
|
||||
documents = load_documents(root=args.root, exclude=args.exclude)
|
||||
git_documents = load_git_data()
|
||||
documents.extend(git_documents)
|
||||
|
||||
documents = generate_embeddings(documents, embedding_model)
|
||||
upsert_documents(table, documents)
|
||||
create_vector_index(table)
|
||||
wait_for_index(table, "vector_idx")
|
||||
print("Documents inserted.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Loading…
Add table
Add a link
Reference in a new issue