#!/usr/bin/env python3
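"""A small FastAPI service that serves retrieval context from LanceDB.

POST /retrieve embeds the incoming query with an Ollama embedding model and
returns the three closest documents from a LanceDB table as context items
(name, description, content).

Settings come from an INI file (default: config.ini). A minimal sketch of the
layout, based on the keys read below; the values shown are only illustrative
assumptions:

    [lancedb]
    persist_directory = ./lancedb

    [s3]
    enable = false
    ; bucket_name, access_key_id, secret_access_key, region and endpoint
    ; are read only when enable is true

    [ollama]
    url = http://localhost:11434

    [server]
    host = 0.0.0.0
    port = 8000
"""
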
from fastapi import FastAPI, Depends
from pydantic import BaseModel
import lancedb
from lancedb.embeddings.ollama import OllamaEmbeddings
import configparser
import argparse
import os
import uvicorn

app = FastAPI()
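
# NOTE: the /retrieve handler relies on two module-level globals, `table` and
# `ollama_ef`, which are assigned in the __main__ block before uvicorn starts.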

def load_config(args):
    """Read settings from the INI file given on the command line."""
    config = configparser.ConfigParser()
    config.read(args.config)
    return config

def setup_database(config):
    """Connect to LanceDB, using S3 when enabled in the config, else a local directory."""
    s3_bucket = config.get("s3", "bucket_name", fallback=None)
    # Treat common truthy strings ("true", "yes", "1", "on") as enabling S3 storage.
    enable_s3 = config.get("s3", "enable", fallback="false").lower() in ["true", "yes", "1", "on"]
    persist_directory = config.get("lancedb", "persist_directory", fallback="./lancedb")

    if s3_bucket and enable_s3:
        storage_options = {
            "aws_access_key_id": config.get("s3", "access_key_id", fallback=None),
            "aws_secret_access_key": config.get("s3", "secret_access_key", fallback=None),
            "region": config.get("s3", "region", fallback="us-east-1"),
            "endpoint_url": config.get("s3", "endpoint", fallback=None),
        }
        db = lancedb.connect(
            f"s3://{s3_bucket}",
            storage_options=storage_options,
        )
    else:
        db = lancedb.connect(persist_directory)

    return db

class ContextProviderInput(BaseModel):
    query: str
    fullInput: str

@app.post("/retrieve")
async def retrieve_context(input: ContextProviderInput, embedding_function: OllamaEmbeddings = Depends(lambda: ollama_ef)):
    # Generate embedding from the query
    query_embedding = embedding_function.generate_embeddings([input.query])[0]

    # Search for similar documents by cosine distance
    results = table.search(query_embedding).distance_type("cosine").limit(3).to_list()

    # Create a list of context items
    context_items = []
    for result in results:
        context_items.append({
            "name": result.get("id", "unknown"),
            "description": result.get("description", "document"),
            "content": result["text"],
        })
    return context_items
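
# Example request against the running service (host/port come from the [server]
# section of config.ini; the values below are illustrative):
#   curl -X POST http://localhost:8000/retrieve \
#     -H "Content-Type: application/json" \
#     -d '{"query": "how is S3 storage enabled?", "fullInput": "..."}'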

if __name__ == "__main__":
    # Parse command line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default="config.ini", help="Path to config file")
    args, _ = parser.parse_known_args()

    # Load configuration and create the database connection
    config = load_config(args)
    db = setup_database(config)

    # Open the table of embedded documents (assumed to be populated by a
    # separate ingestion step with id/description/text columns and vectors)
    table = db.open_table("vectordb")

    # Initialize the Ollama embedding function
    # (lancedb's OllamaEmbeddings takes the model as `name` and the server URL as `host`)
    ollama_url = config.get("ollama", "url")
    ollama_ef = OllamaEmbeddings(name="nomic-embed-text", host=ollama_url)

    # Run the application
    host = config.get("server", "host")
    port = int(config.get("server", "port"))
    uvicorn.run(app, host=host, port=port)