CodeRecall/lancedb_context_provider.py

81 lines
No EOL
2.6 KiB
Python
Executable file

#!/usr/bin/env python3
from fastapi import FastAPI, Depends
from pydantic import BaseModel
import lancedb
from lancedb.embeddings.ollama import OllamaEmbeddings
import configparser
import argparse
import os
import uvicorn
# ASGI application object; the /retrieve route below is registered on it.
app = FastAPI()
def load_config(args):
    """Read the INI configuration file named by ``args.config``.

    Args:
        args: Object exposing a ``config`` attribute that holds the path of
            the INI file (for example the namespace produced by ``argparse``).

    Returns:
        A ``configparser.ConfigParser`` populated from that file.

    Raises:
        FileNotFoundError: If the file could not be opened and parsed.
    """
    config = configparser.ConfigParser()
    # ConfigParser.read() silently skips files it cannot open and returns the
    # list of files it actually parsed; an empty list means nothing was loaded.
    # Fail here with the offending path instead of surfacing later as an
    # unrelated missing-section error.
    if not config.read(args.config):
        raise FileNotFoundError(f"Could not read config file: {args.config!r}")
    return config
def setup_database(config):
    """Open the LanceDB connection described by *config*.

    When the ``[s3]`` section provides a ``bucket_name`` and its ``enable``
    option is one of ``true``/``yes``/``1``/``on`` (case-insensitive), the
    database is opened at ``s3://<bucket_name>``. Otherwise it is opened at
    the ``persist_directory`` option of the ``[lancedb]`` section, which
    defaults to ``./lancedb``.

    Args:
        config: A ``configparser.ConfigParser`` (or compatible object).

    Returns:
        The connection object returned by ``lancedb.connect``.
    """
    s3_bucket = config.get("s3", "bucket_name", fallback=None)
    enable_s3 = config.get("s3", "enable", fallback="false").lower() in [
        "true",
        "yes",
        "1",
        "on",
    ]

    if not (s3_bucket and enable_s3):
        persist_directory = config.get(
            "lancedb", "persist_directory", fallback="./lancedb"
        )
        return lancedb.connect(persist_directory)

    configured = {
        "aws_access_key_id": config.get("s3", "access_key_id", fallback=None),
        "aws_secret_access_key": config.get("s3", "secret_access_key", fallback=None),
        "region": config.get("s3", "region", fallback="us-east-1"),
        "endpoint_url": config.get("s3", "endpoint", fallback=None),
    }
    # storage_options is documented as a str -> str mapping, so options that
    # are absent (None) or blank in the config are left out rather than
    # forwarded; omitted credentials then fall back to the backend defaults.
    storage_options = {key: value for key, value in configured.items() if value}
    return lancedb.connect(f"s3://{s3_bucket}", storage_options=storage_options)
class ContextProviderInput(BaseModel):
    """Request body accepted by the ``/retrieve`` endpoint."""

    # Text that the handler embeds and uses for the similarity search.
    query: str
    # Required by the schema; the handler in this file does not read it.
    fullInput: str
@app.post("/retrieve")
async def retrieve_context(
    input: ContextProviderInput,
    embedding_function: OllamaEmbeddings = Depends(lambda: ollama_ef),
):
    """Return the three stored documents closest to ``input.query``.

    The query is embedded with ``embedding_function`` and matched against the
    module-level ``table`` using cosine distance. ``table`` and ``ollama_ef``
    are assigned in this file's ``__main__`` section, so they exist only when
    the file is run as a script.

    Args:
        input: Parsed request body; only ``query`` is used.
        embedding_function: Embedding backend, injected by FastAPI.

    Returns:
        A list of dicts with keys ``name`` (row ``id``, default ``"unknown"``),
        ``description`` (default ``"document"``) and ``content`` (row ``text``).

    Raises:
        KeyError: If a result row has no ``text`` field.
    """
    # Local import keeps this block self-contained; fastapi is already a
    # dependency of this file.
    from fastapi.concurrency import run_in_threadpool

    def _embed_and_search():
        # Generate embedding from the query, then search for similar documents.
        query_embedding = embedding_function.generate_embeddings([input.query])[0]
        return (
            table.search(query_embedding)
            .distance_type("cosine")
            .limit(3)
            .to_list()
        )

    # Both calls above are synchronous; running them directly inside an
    # ``async def`` handler would block the event loop for their duration, so
    # they are moved to a worker thread.
    results = await run_in_threadpool(_embed_and_search)

    return [
        {
            "name": result.get("id", "unknown"),
            "description": result.get("description", "document"),
            "content": result["text"],
        }
        for result in results
    ]
if __name__ == "__main__":
    # Command line: only --config is defined; any other arguments are ignored.
    cli = argparse.ArgumentParser()
    cli.add_argument("--config", default="config.ini", help="Path to config file")
    args = cli.parse_known_args()[0]

    # Configuration first, then the database connection built from it.
    config = load_config(args)
    db = setup_database(config)

    # `table` and `ollama_ef` are read as globals by the /retrieve handler,
    # so they must be bound before the server starts.
    table = db.open_table("vectordb")
    ollama_ef = OllamaEmbeddings(
        model="nomic-embed-text",
        url=config.get("ollama", "url"),
    )

    # Serve the app on the host/port given in the [server] section.
    uvicorn.run(
        app,
        host=config.get("server", "host"),
        port=config.getint("server", "port"),
    )