SnowChatSQL — A Streamlit AI assistant for analyzing a Snowflake database.

Ryan Lindbeck
May 2, 2023

SnowChatSQL is a Streamlit app that uses OpenAI, ChromaDB and Snowflake to allow users to execute queries using natural language against a Snowflake database.

TL;DR: Check out my GitHub repo to dive directly into the code: https://github.com/rycharlind/snowchatsql

Overview

SnowChatSQL takes any question from the user and passes it to OpenAI’s “text-davinci-003” model, along with the relevant database schema information, to get an AI-generated SQL query. That query is then executed against Snowflake and the results are displayed as a DataFrame in Streamlit.

How It Works

  • Initialize a vector database containing one document per table, holding that table’s schema. This lets us find the schema details related to the user’s question, so that we can give the AI the right context.
  • Given a user’s request to query the database, search the vector store for the relevant table schemas and pass them to the prompt builder, which is used to ask the AI for the SQL.
  • Execute the AI provided SQL against your Snowflake data warehouse using Snowpark.
  • Display the SQL results on a Streamlit page within a DataFrame.
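The four steps above can be sketched as a single function with each external service passed in as a plain callable. This is a minimal illustration under stated assumptions, not the repo's code: answer_question and its parameters are hypothetical names, and the prompt is a shortened version of the template shown later.

```python
from typing import Any, Callable, List, Tuple


def answer_question(
    question: str,
    search_schema: Callable[[str], List[str]],
    generate_sql: Callable[[str], str],
    run_sql: Callable[[str], Any],
) -> Tuple[str, Any]:
    """Hypothetical end-to-end flow; each dependency is injected as a function."""
    # Step 2: fetch related schema documents from the (already initialized) store.
    schema_docs = search_schema(question)
    # Pair the retrieved schema with the question (simplified prompt template).
    prompt = (
        "### Snowflake SQL tables, with their properties:\n#\n"
        + "\n".join(schema_docs)
        + f"\n#\n{question}\n#\nRespond only in Snowflake SQL.\n###\n"
    )
    # Ask the model for SQL, then (step 3) execute it against the warehouse.
    sql = generate_sql(prompt).strip()
    result = run_sql(sql)
    # Step 4: the caller (e.g. a Streamlit page) displays `result`.
    return sql, result


# Usage with stand-in functions instead of ChromaDB, OpenAI and Snowpark.
sql, result = answer_question(
    "What is the total of all orders?",
    search_schema=lambda q: ["ORDERS (ORDER_ID (NUMBER), TOTAL (FLOAT))"],
    generate_sql=lambda p: "\nSELECT SUM(TOTAL) FROM ORDERS",
    run_sql=lambda s: [(1234.5,)],
)
print(sql)     # SELECT SUM(TOTAL) FROM ORDERS
print(result)  # [(1234.5,)]
```

Injecting the services this way keeps the flow testable without a live warehouse or API key.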

Initialize the vector database

To prompt the AI to write our SQL query, we first need to give it context about our database schema. That can be a lot of information, often too much to send to OpenAI in a single prompt. This is where a vector database such as ChromaDB comes in: we store the Snowflake schema information in ChromaDB, then search it to retrieve only the relevant schema information (documents) needed to build the SQL query.
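A rough budget check shows why sending the whole schema can fail. This sketch assumes text-davinci-003's 4,097-token context window (shared by prompt and completion, as documented by OpenAI at the time), the max_tokens=250 used by the ChatSQL page, and the common rule of thumb of about four characters per token for English text; a real tokenizer would give exact counts. The helper names and the 200-table warehouse are hypothetical.

```python
CONTEXT_WINDOW = 4097   # text-davinci-003: prompt and completion share this budget
MAX_COMPLETION = 250    # the max_tokens value used by the ChatSQL page
CHARS_PER_TOKEN = 4     # rough rule of thumb for English text, not an exact count


def estimate_tokens(text: str) -> int:
    """Approximate token count as ceil(len(text) / CHARS_PER_TOKEN)."""
    return -(-len(text) // CHARS_PER_TOKEN)


def fits_in_prompt(schema_text: str, question: str) -> bool:
    """Roughly check that schema + question still leave room for the completion."""
    used = estimate_tokens(schema_text) + estimate_tokens(question)
    return used + MAX_COMPLETION <= CONTEXT_WINDOW


# Hypothetical warehouse: 200 tables with 8 columns each, one schema line per table.
full_schema = "\n".join(
    f"TABLE_{i} (" + ", ".join(f"COL_{j} (VARCHAR)" for j in range(8)) + ")"
    for i in range(200)
)
question = "How many orders were placed last month?"

print(fits_in_prompt(full_schema, question))                              # False
print(fits_in_prompt("\n".join(full_schema.splitlines()[:3]), question))  # True
```

Retrieving only the few schema lines related to the question is what keeps the prompt inside the budget.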

```python
from snowchatsql.config.config import Config
from snowchatsql.snowflake import Snowflake
from snowchatsql.vector_store import VectorStore
from snowchatsql.prompt_builder import PromptBuilder

config = Config()
prompt_builder = PromptBuilder()
snowflake = Snowflake(config)
vector_store = VectorStore(config)

# Map each table name to its list of {'name': ..., 'type': ...} field dicts.
tables = snowflake.get_tables()
schema = {table: snowflake.get_table_fields(table) for table in tables}

# Build one schema description string per table.
table_fields = []
for table, fields in schema.items():
    table_fields.append(prompt_builder.get_schema_prompt(table, fields))

# Save the descriptions to ChromaDB, one document per table.
vector_store.persist_database_schema(config.chroma.collection_name, tables, table_fields)
```

This code retrieves the list of tables from the Snowflake database with get_tables(), then calls get_table_fields() for each table to get its fields.

The results are stored in a dictionary named schema, where the keys are table names and the values are lists of field dictionaries, each holding a field's name and type.

Next, get_schema_prompt() turns each table name and its fields into a one-line schema description.

These descriptions are collected in a list named table_fields. Finally, persist_database_schema() saves them to ChromaDB, in the collection named by config.chroma.collection_name, as one document per table.
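The VectorStore class itself is not shown in this post. To illustrate what storing and searching schema documents means, here is a toy in-memory stand-in that "embeds" each document as a bag of words and ranks documents by cosine similarity to the question. ChromaDB uses learned embedding models rather than word counts, and ToyVectorStore only mirrors the method names called above; it is a hypothetical sketch, not the repo's implementation.

```python
import math
import re
from collections import Counter
from typing import Dict, List


def embed(text: str) -> Counter:
    """Toy 'embedding': counts of lower-cased alphanumeric words (underscores split words)."""
    return Counter(re.findall(r"[a-z0-9]+", text.lower()))


def cosine(a: Counter, b: Counter) -> float:
    """Cosine similarity between two sparse word-count vectors."""
    dot = sum(count * b[word] for word, count in a.items())
    norm_a = math.sqrt(sum(v * v for v in a.values()))
    norm_b = math.sqrt(sum(v * v for v in b.values()))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0


class ToyVectorStore:
    """Hypothetical in-memory stand-in for the post's VectorStore (not the repo's code)."""

    def __init__(self) -> None:
        self.collections: Dict[str, Dict[str, str]] = {}

    def persist_database_schema(
        self, collection_name: str, tables: List[str], documents: List[str]
    ) -> None:
        # One document per table; the table name serves as the document id.
        self.collections[collection_name] = dict(zip(tables, documents))

    def search(self, prompt: str, collection_name: str, n_results: int = 2) -> List[str]:
        # Rank every stored document by similarity to the question; keep the top n.
        query = embed(prompt)
        docs = list(self.collections.get(collection_name, {}).values())
        docs.sort(key=lambda doc: cosine(query, embed(doc)), reverse=True)
        return docs[:n_results]


store = ToyVectorStore()
store.persist_database_schema(
    "schema",
    ["ORDERS", "CUSTOMERS", "PRODUCTS"],
    [
        "ORDERS (ORDER_ID (NUMBER), CUSTOMER_ID (NUMBER), TOTAL (FLOAT))",
        "CUSTOMERS (CUSTOMER_ID (NUMBER), NAME (VARCHAR))",
        "PRODUCTS (PRODUCT_ID (NUMBER), PRICE (FLOAT))",
    ],
)
print(store.search("What is the total for each customer?", "schema", n_results=1))
# ['ORDERS (ORDER_ID (NUMBER), CUSTOMER_ID (NUMBER), TOTAL (FLOAT))']
```

Word counts cannot tell that "customers" and "customer" are related; that kind of semantic matching is what a learned embedding model adds.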

Prompt Builder

Next, we build a service that generates the prompt we pass to OpenAI to produce the appropriate Snowflake SQL statement.

```python
class PromptBuilder:
    """Builds the text prompts sent to OpenAI."""

    def get_prompt_template(self, prompt_schema: str, prompt: str) -> str:
        # Template lines stay at column 0 so the prompt has no stray indentation.
        return f"""### Snowflake SQL tables, with their properties:
#
{prompt_schema}
#
{prompt}
#
Respond only in Snowflake SQL.
If you don't know what value to use for a field, do your best to fill it in so that the SQL will execute properly.
###
"""

    def build_from_schema(self, schema: dict) -> str:
        # One newline-terminated schema line per table.
        schema_str = ""
        for table, fields in schema.items():
            schema_str += f"{self.get_schema_prompt(table, fields)}\n"
        return schema_str

    def build_from_documents(self, documents: list) -> str:
        # Documents from the vector store are already schema lines; join them.
        return "\n".join(documents)

    def get_schema_prompt(self, table: str, fields: list) -> str:
        # "TABLE (COL (TYPE), COL (TYPE))" with balanced parentheses.
        field_lines = ", ".join(self.get_field_line(field) for field in fields)
        return f"{table} ({field_lines})"

    def get_field_line(self, field: dict) -> str:
        # "NAME (TYPE)" for a single column.
        return f"{field['name']} ({field['type']})"
```

Here we define a Python class called PromptBuilder with several methods for generating SQL prompts for the Snowflake database. A brief summary of each method:

  1. get_prompt_template: Takes two parameters, prompt_schema and prompt, and returns a predefined prompt template with both values filled in. The template instructs the AI to respond only in Snowflake SQL and to do its best to fill in any missing field values.
  2. build_from_schema: Takes a dictionary, schema, whose keys are table names and whose values are lists of dictionaries holding each column's name and data type. It generates a line for each table with get_schema_prompt, concatenates the lines, and returns the resulting string.
  3. build_from_documents: Takes a list of documents and joins them, one per line, into a single string.
  4. get_schema_prompt: Takes a table name and a list of field dictionaries and returns a one-line description of the table: the table name followed by a parenthesized, comma-separated list of fields, each formatted by get_field_line.
  5. get_field_line: Takes a dictionary representing a field and returns the field name followed by its data type in parentheses.
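To see the exact text that gets stored in ChromaDB and later pasted into the template, here is a standalone sketch of get_field_line and get_schema_prompt as plain functions, applied to two hypothetical tables and joined with newlines the way build_from_documents does. The column types mimic what DESCRIBE TABLE reports (e.g. NUMBER(38,0)), but the tables themselves are invented.

```python
from typing import Dict, List


def get_field_line(field: Dict[str, str]) -> str:
    # One column: "NAME (TYPE)"
    return f"{field['name']} ({field['type']})"


def get_schema_prompt(table: str, fields: List[Dict[str, str]]) -> str:
    # One table: "TABLE (COL (TYPE), COL (TYPE))", with balanced parentheses
    return f"{table} ({', '.join(get_field_line(fld) for fld in fields)})"


# Hypothetical tables, shaped like the output of Snowflake.get_table_fields().
schema = {
    "ORDERS": [
        {"name": "ORDER_ID", "type": "NUMBER(38,0)"},
        {"name": "TOTAL", "type": "FLOAT"},
    ],
    "CUSTOMERS": [
        {"name": "CUSTOMER_ID", "type": "NUMBER(38,0)"},
        {"name": "NAME", "type": "VARCHAR(16777216)"},
    ],
}
documents = [get_schema_prompt(table, fields) for table, fields in schema.items()]
prompt_schema = "\n".join(documents)  # the same join build_from_documents() performs
print(prompt_schema)
# ORDERS (ORDER_ID (NUMBER(38,0)), TOTAL (FLOAT))
# CUSTOMERS (CUSTOMER_ID (NUMBER(38,0)), NAME (VARCHAR(16777216)))
```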

Snowflake / Snowpark

The following code defines a Python class called Snowflake that lets users interact with the Snowflake database through the Snowpark API. Its constructor takes a Config object, which holds the connection settings: account name, user, password, warehouse, database, and schema.

```python
from snowflake.snowpark import DataFrame, Session
from snowchatsql.config.config import Config


class Snowflake:
    """Thin wrapper around a Snowpark Session."""

    def __init__(self, config: Config) -> None:
        self.config = config
        self.session = self.get_session()

    def get_session(self) -> Session:
        connection_parameters = {
            "account": self.config.snowflake.account,
            "user": self.config.snowflake.user,
            "password": self.config.snowflake.password,
            "warehouse": self.config.snowflake.warehouse,
            "database": self.config.snowflake.database,
            "schema": self.config.snowflake.schema,
        }
        return Session.builder.configs(connection_parameters).create()

    def get_tables(self) -> list:
        # Each SHOW TABLES row has a "name" column holding the table name.
        tables = self.session.sql("SHOW TABLES").collect()
        return [table.name for table in tables]

    def get_table_fields(self, table: str) -> list:
        # DESCRIBE TABLE returns one row per column, including "name" and "type".
        rows = self.session.sql(f"DESCRIBE TABLE {table}").collect()
        return [{"name": row.name, "type": row.type} for row in rows]

    def sql(self, query: str) -> DataFrame:
        # Snowpark evaluates this DataFrame lazily: the query runs when an action
        # such as collect() or to_pandas() is called.
        return self.session.sql(query)
```

The get_session() method creates a Snowpark Session from the connection parameters in the Config object. The get_tables() method runs the SHOW TABLES command and returns the names of the tables it lists. The get_table_fields() method takes a table name, runs DESCRIBE TABLE on it, and returns a list of dictionaries containing the name and type of each field.

Finally, the sql() method takes a SQL query, passes it to the sql() method of the Session object, and returns the result as a Snowpark DataFrame. Snowpark evaluates DataFrames lazily, so the query actually runs when results are requested, for example with collect() or to_pandas().

Overall, this class provides a convenient interface to Snowflake through the Snowpark API. It retrieves information about tables and fields, and it executes arbitrary SQL queries, returning the results as Snowpark DataFrames.
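get_table_fields() interpolates the table name straight into DESCRIBE TABLE. Snowflake upper-cases unquoted identifiers, so a table created with a quoted mixed-case name, or one containing spaces, would not be found that way. Below is a minimal sketch of a hypothetical helper (not part of the repo) that wraps a name in double quotes and doubles any embedded double quotes, which is Snowflake's escaping rule for quoted identifiers. Because SHOW TABLES reports names exactly as stored, quoting them preserves their case.

```python
def quote_identifier(name: str) -> str:
    """Wrap `name` as a Snowflake double-quoted identifier, doubling embedded quotes."""
    return '"' + name.replace('"', '""') + '"'


def describe_table_sql(table: str) -> str:
    """Build a DESCRIBE TABLE statement for a name exactly as SHOW TABLES reported it."""
    return f"DESCRIBE TABLE {quote_identifier(table)}"


print(describe_table_sql("ORDERS"))           # DESCRIBE TABLE "ORDERS"
print(describe_table_sql('Sales "2023" Q1'))  # DESCRIBE TABLE "Sales ""2023"" Q1"
```

In get_table_fields(), this would mean passing describe_table_sql(table) to self.session.sql() instead of the plain f-string.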

ChatSQL Streamlit Page

The following code shows the “ChatSQL” Streamlit page. It takes a user “prompt” as input and uses the Prompt Builder to create a prompt that asks OpenAI to translate the request into a Snowflake SQL query, all based on the database schema context stored in your vector store.

```python
import openai
import streamlit as st

from snowchatsql.c_state import CState
from snowchatsql.config.config import Config
from snowchatsql.prompt_builder import PromptBuilder
from snowchatsql.snowflake import Snowflake
from snowchatsql.vector_store import VectorStore

st.set_page_config(layout="wide")
st.header("Chat SQL")

config = Config()
snowflake = Snowflake(config)
vector_store = VectorStore(config)
prompt_builder = PromptBuilder()

# Session state survives Streamlit reruns; create the history lists once per session.
for key in (CState.CHAT_GENERATED, CState.CHAT_PAST, CState.RESULT_DATAFRAME):
    if key not in st.session_state:
        st.session_state[key] = []

st.caption(
    "Use the text area below to enter your prompt. The prompt will be used to generate a SQL query. "
    "The query will be executed against Snowflake and the results will be displayed below."
)

prompt = st.text_area("Enter your prompt here", key=CState.CHAT_PROMPT)

if prompt:
    # Retrieve the schema lines most related to the question and build the prompt.
    related_documents = vector_store.search(prompt=prompt, collection_name=config.chroma.collection_name)
    prompt_schema = prompt_builder.build_from_documents(related_documents)
    prompt_final = prompt_builder.get_prompt_template(prompt_schema, prompt=prompt)

    # Legacy Completions API of the openai<1.0 SDK.
    response = openai.Completion.create(
        model="text-davinci-003",
        prompt=prompt_final,
        temperature=0,
        max_tokens=250,
        top_p=1.0,
        frequency_penalty=0.0,
        presence_penalty=0.0,
        stop=["#", ";"],  # stop at a prompt separator or the end of the statement
    )
    output = response["choices"][0]["text"].strip()

    st.session_state[CState.CHAT_GENERATED].append(output)
    st.session_state[CState.CHAT_PAST].append(prompt)

    try:
        # to_pandas() forces execution here, so SQL errors are raised inside this block.
        df = snowflake.sql(output).to_pandas()
    except Exception as e:
        df = None  # keep the three lists aligned even when the query fails
        st.warning(f"Could not execute this SQL. It may need some manual tweaking. ({e})")
    st.session_state[CState.RESULT_DATAFRAME].append(df)

# Show the conversation, newest first.
history = zip(
    st.session_state[CState.CHAT_PAST],
    st.session_state[CState.CHAT_GENERATED],
    st.session_state[CState.RESULT_DATAFRAME],
)
for past_prompt, sql, df in reversed(list(history)):
    st.info(past_prompt)
    st.code(sql, language="sql")
    if df is not None:
        st.dataframe(df)
    st.divider()
```

This code uses the Streamlit library to create a web application that generates SQL queries and executes them against a Snowflake data warehouse. The script imports streamlit (as st) and openai, along with CState, Snowflake, Config, VectorStore, and PromptBuilder from the snowchatsql package.

After setting the page layout to wide, the script adds the header “Chat SQL” to the web app. It then creates instances of Config, Snowflake, VectorStore, and PromptBuilder.

Next, the script initializes any session state variables that are not yet present: CHAT_GENERATED, CHAT_PAST, and RESULT_DATAFRAME. Session state persists across Streamlit's script reruns, so these lists hold the conversation history for the current user session.

The script then adds a caption instructing the user to enter a natural-language prompt in the text area provided. The prompt is also stored under the CHAT_PROMPT key of the session state.

If a prompt is entered, the script searches the VectorStore for related schema documents, joins them with PromptBuilder, and fills in the prompt template. The final prompt is sent to OpenAI's text-davinci-003 model, which returns the generated SQL query.
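The request also passes stop=["#", ";"]. The API stops generating at the first stop sequence and does not include it in the returned text, so the prompt's "#" separator lines end generation and the SQL comes back without its trailing semicolon. The local function below only simulates that truncation on a finished string to show the effect; apply_stop is a hypothetical name and not how the API works internally.

```python
from typing import Sequence


def apply_stop(text: str, stop: Sequence[str]) -> str:
    """Cut `text` at the earliest occurrence of any stop sequence (simulation only)."""
    cut = len(text)
    for s in stop:
        idx = text.find(s)
        if idx != -1:
            cut = min(cut, idx)
    return text[:cut]


raw = "\nSELECT COUNT(*) FROM ORDERS;\n# Explanation: counts every row"
print(repr(apply_stop(raw, ["#", ";"])))  # '\nSELECT COUNT(*) FROM ORDERS'
```

One side effect of these stop sequences: a generated query containing a literal "#" or ";" inside a string would also be cut off at that point.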

The generated SQL is appended to the CHAT_GENERATED session state list, and the user's prompt is appended to CHAT_PAST. The script then tries to execute the SQL against the Snowflake data warehouse through the Snowflake class. If it succeeds, the resulting DataFrame is stored in RESULT_DATAFRAME; if it fails, a warning suggests that the SQL may need manual tweaking.
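Because the model's text goes straight to snowflake.sql(), a small pre-flight step can reject obviously unusable output before it reaches the warehouse. The sketch below is hypothetical (prepare_sql is not in the repo): it trims whitespace, drops an optional trailing semicolon, and accepts only a single statement beginning with SELECT or WITH. A keyword check like this is a convenience, not a security control; running the app under a Snowflake role with read-only privileges is what actually prevents writes.

```python
import re
from typing import Optional

READ_ONLY_START = re.compile(r"^(SELECT|WITH)\b", re.IGNORECASE)


def prepare_sql(output: str) -> Optional[str]:
    """Return cleaned SQL if it looks like one read-only query, else None."""
    sql = output.strip().rstrip(";").strip()
    # Reject empty output or multiple statements (this also rejects ';' inside
    # string literals, a deliberate simplification).
    if not sql or ";" in sql:
        return None
    # Only allow statements that start with SELECT or WITH.
    if not READ_ONLY_START.match(sql):
        return None
    return sql


print(prepare_sql("\nSELECT * FROM ORDERS"))  # SELECT * FROM ORDERS
print(prepare_sql("DROP TABLE ORDERS"))       # None
```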

Finally, if any SQL queries have been generated in the current session, the script displays them newest first, each with its prompt and result DataFrame. The st.info, st.code, st.dataframe, and st.divider functions display the prompt, the SQL code, the DataFrame, and a divider, respectively.
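Keeping three parallel lists in sync by index is fragile: every list must receive exactly one entry per prompt, including when the query fails. One alternative, sketched below with hypothetical names (ChatTurn, record_turn, newest_first), stores a single record per turn so that the prompt, SQL, and result or error cannot drift apart.

```python
from dataclasses import dataclass
from typing import Any, Callable, List, Optional


@dataclass
class ChatTurn:
    """One prompt/response round trip (hypothetical structure)."""
    prompt: str
    sql: str
    result: Optional[Any] = None  # e.g. a pandas DataFrame when execution succeeds
    error: Optional[str] = None   # error message when execution fails


def record_turn(history: List[ChatTurn], prompt: str, sql: str,
                run_sql: Callable[[str], Any]) -> ChatTurn:
    """Execute `sql` via `run_sql` and append exactly one ChatTurn, success or failure."""
    try:
        turn = ChatTurn(prompt, sql, result=run_sql(sql))
    except Exception as exc:  # keep the turn even when the query fails
        turn = ChatTurn(prompt, sql, error=str(exc))
    history.append(turn)
    return turn


def newest_first(history: List[ChatTurn]) -> List[ChatTurn]:
    # Same display order as the page's reversed loop.
    return list(reversed(history))


history: List[ChatTurn] = []
record_turn(history, "Count orders", "SELECT COUNT(*) FROM ORDERS", run_sql=lambda s: [(42,)])
record_turn(history, "Bad query", "SELEC 1", run_sql=lambda s: 1 / 0)
for turn in newest_first(history):
    print(turn.prompt, turn.result, turn.error)
# Bad query None division by zero
# Count orders [(42,)] None
```

On the page, the history list would be created once in st.session_state like the existing keys, and the display loop would call st.dataframe(turn.result) or st.warning(turn.error) for each turn.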

In Summary

SnowChatSQL is a Streamlit app that lets users query a Snowflake database in natural language. It uses ChromaDB to find the relevant parts of the database schema, OpenAI’s “text-davinci-003” model to turn that schema and the user’s prompt into SQL, and Snowflake to execute the query. It includes code to initialize a vector database with one document per table and its schema. The PromptBuilder class generates SQL prompts for the Snowflake database from the schema and user prompts. The ChatSQL Streamlit page takes user prompts as input, generates a SQL query with PromptBuilder and OpenAI, and executes it against Snowflake. The results are displayed as a DataFrame on the Streamlit page.

Check out the full code base on GitHub: https://github.com/rycharlind/snowchatsql

Written by Ryan Lindbeck

Strategic Visionary Leader in Healthcare Analytics | Software & Data Engineer