Initial code commit
This commit is contained in:
parent
14c06787dc
commit
9842d55bbf
7 changed files with 441 additions and 32 deletions
81
lancedb_context_provider.py
Executable file
81
lancedb_context_provider.py
Executable file
|
@ -0,0 +1,81 @@
|
|||
#!/usr/bin/env python3
|
||||
from fastapi import FastAPI, Depends
|
||||
from pydantic import BaseModel
|
||||
import lancedb
|
||||
from lancedb.embeddings.ollama import OllamaEmbeddings
|
||||
import configparser
|
||||
import argparse
|
||||
import os
|
||||
import uvicorn
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
def load_config(args):
|
||||
config = configparser.ConfigParser()
|
||||
config.read(args.config)
|
||||
return config
|
||||
|
||||
def setup_database(config):
|
||||
s3_bucket = config.get("s3", "bucket_name", fallback=None)
|
||||
enable_s3 = config.get("s3", "enable", fallback="false").lower() in ['true', 'yes', '1', 'on']
|
||||
persist_directory = config.get("lancedb", "persist_directory", fallback="./lancedb")
|
||||
|
||||
if s3_bucket and enable_s3:
|
||||
storage_options = {
|
||||
"aws_access_key_id": config.get("s3", "access_key_id", fallback=None),
|
||||
"aws_secret_access_key": config.get("s3", "secret_access_key", fallback=None),
|
||||
"region": config.get("s3", "region", fallback="us-east-1"),
|
||||
"endpoint_url": config.get("s3", "endpoint", fallback=None),
|
||||
}
|
||||
db = lancedb.connect(
|
||||
f"s3://{s3_bucket}",
|
||||
storage_options=storage_options
|
||||
)
|
||||
else:
|
||||
db = lancedb.connect(persist_directory)
|
||||
|
||||
return db
|
||||
|
||||
class ContextProviderInput(BaseModel):
|
||||
query: str
|
||||
fullInput: str
|
||||
|
||||
@app.post("/retrieve")
|
||||
async def retrieve_context(input: ContextProviderInput, embedding_function: OllamaEmbeddings = Depends(lambda: ollama_ef)):
|
||||
# Generate embedding from the query
|
||||
query_embedding = embedding_function.generate_embeddings([input.query])[0]
|
||||
|
||||
# Search for similar documents
|
||||
results = table.search(query_embedding).distance_type("cosine").limit(3).to_list()
|
||||
|
||||
# Create a list of context items
|
||||
context_items = []
|
||||
for result in results:
|
||||
context_items.append({
|
||||
"name": result.get("id", "unknown"),
|
||||
"description": result.get("description", "document"),
|
||||
"content": result["text"]
|
||||
})
|
||||
return context_items
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Parse command line arguments
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--config", default="config.ini", help="Path to config file")
|
||||
args, _ = parser.parse_known_args()
|
||||
|
||||
# Load configuration and create database connection
|
||||
config = load_config(args)
|
||||
db = setup_database(config)
|
||||
|
||||
# Open table
|
||||
table = db.open_table("vectordb")
|
||||
|
||||
# Initialize Ollama embedding function
|
||||
ollama_url = config.get("ollama", "url")
|
||||
ollama_ef = OllamaEmbeddings(model="nomic-embed-text", url=ollama_url)
|
||||
|
||||
# Run the application
|
||||
host = config.get("server", "host")
|
||||
port = int(config.get("server", "port"))
|
||||
uvicorn.run(app, host=host, port=port)
|
Loading…
Add table
Add a link
Reference in a new issue