Skip to content

Commit

Permalink
Merge pull request #451 from trueagi-io/sql_space
Browse files Browse the repository at this point in the history
Sql space
  • Loading branch information
Necr0x0Der authored Oct 4, 2023
2 parents a4470cc + 47ce46f commit 7bfa532
Show file tree
Hide file tree
Showing 2 changed files with 253 additions and 0 deletions.
237 changes: 237 additions & 0 deletions python/sandbox/sql_space/sql_space.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,237 @@
import psycopg2
from hyperon import *
from hyperon.ext import register_atoms
import re


def results2bindings(vars, values):
new_bindings_set = BindingsSet.empty()
if len(values) == 0 or len(vars) != len(values[0]):
return new_bindings_set

for value in values:
bindings = Bindings()
for i in range(len(vars)):
bindings.add_var_binding(vars[i], ValueAtom(str(value[i])))
new_bindings_set.push(bindings)

return new_bindings_set


class SqlHelper:
colums_word = "ColumnNames"
insert_command_sql = "INSERT INTO"

@staticmethod
def get_query_atoms(query_atom):
children = query_atom.get_children()
new_query_atoms = []
for ch in children:
if 'limit' not in repr(ch).lower():
new_query_atoms.append(ch)
return new_query_atoms

@staticmethod
def get_fields_and_conditions(query_atom):
''' parse sql query and get columns to select and conditions for filtering '''
atoms = query_atom.get_children()
fields = {}
conditions = {}
limit = ""
vars_map = {}
for atom in atoms:
if isinstance(atom, ExpressionAtom):
items = atom.get_children()
if len(items) == 3:
id_fields = items[1].get_children()
current_field_info = items[2].get_children()
if len(id_fields) != 2 or len(current_field_info) != 2:
raise SyntaxError("Incorrect number of arguments")
# (musicbrainz.artist (id $id) (name $name))
# identification field
id_name = repr(id_fields[0])
vars_map[id_name] = repr(id_fields[1])
# field to select
field_name = repr(current_field_info[0])
vars_map[field_name] = repr(current_field_info[1])
# table
table = repr(items[0])
if table not in fields:
fields[table] = set()
if table not in conditions:
conditions[table] = set()
# add id field to corresponding category (filed/condition)
if isinstance(id_fields[1], VariableAtom):
fields[table].add(id_name)
else:
conditions[table].add(id_name)
# add selected field to corresponding category (filed/condition)
if isinstance(current_field_info[1], VariableAtom):
fields[table].add(field_name)
else:
conditions[table].add(field_name)

if len(items) == 2 and ("limit" in repr(items[0]).lower()):
limit = repr(items[1])
return fields, conditions, limit, vars_map

@staticmethod
def get_fields_and_values(query_atom):
''' parse sql query and get columns to select and conditions for filtering '''
atoms = query_atom.get_children()
fields_map = {}
for atom in atoms:
if isinstance(atom, ExpressionAtom):
items = atom.get_children()
if len(items) != 2:
raise SyntaxError("Incorrect number of arguments")
# (musicbrainz.artist (id $id) (name $name)
# field to select
field_name = repr(items[0])
fields_map[field_name] = repr(items[1])
return fields_map

def save_query_result(self, sql_space, space, query_atom):
# if no fields provided get them from information_schema.columns
res = sql_space.query(query_atom)
variables = []
for val in res:
temp_dict = {}
for k, v in val.items():
temp_dict['$' + str(k)] = str(v)
variables.append(temp_dict)
atoms = self.get_query_atoms(query_atom)
new_atoms = []
for var in variables:
for atom in atoms:
if isinstance(atom, ExpressionAtom):
temp = repr(atom)
for k, v in var.items():
temp = temp.replace(k, v)
new_atoms.append(temp)
for atom in new_atoms:
space.add_atom(E(S(atom)))
return res

def insert(self, space, query_atom):
fields_map = SqlHelper.get_fields_and_values(query_atom)
res = []
table = fields_map.pop("table")
values = []
for field_name, field_value in fields_map.items():
values.append(field_value.replace('"', "") if "(" in field_value and field_value[-2] == ')'
else field_value.replace('"', "'"))
fields_str = ", ".join(list(fields_map.keys()))
values_str = ", ".join(list(values))
query = f'''{self.insert_command_sql} {table} ({fields_str}) VALUES ({values_str}) RETURNING 0;'''
res.extend(space.query(E(S(query))))
return res


class SqlSpace(GroundingSpace):
def __init__(self, database, host, user, password, port):
super().__init__()
self.conn = psycopg2.connect(database=database,
host=host,
user=user,
password=password,
port=port)
self.cursor = self.conn.cursor()

def from_space(self, cspace):
self.gspace = GroundingSpaceRef(cspace)

def construct_query(self, query_atom):
fields, conditions, limit, vars_map = SqlHelper.get_fields_and_conditions(query_atom)
sql_query = "SELECT"

vars_names = []
for k, values in fields.items():
for val in values:
sql_query = sql_query + f" {k}.{val},"
vars_names.append(vars_map[val])
sql_query = sql_query[:-1] + " FROM "
for k in fields.keys():
sql_query = sql_query + f"{k},"

sql_condition = " WHERE"
for k, values in conditions.items():
for val in values:
if val in vars_map:
sql_condition = sql_condition + f" {k}.{val} = {vars_map[val]} AND"
if len(sql_condition) > 6:
sql_query = sql_query[:-1] + sql_condition[:-4]
else:
sql_query = sql_query[:-1]
if len(limit) > 0:
sql_query = sql_query + f" LIMIT {limit}"
return sql_query, vars_names

def insert(self, sql_query):
try:
if len(sql_query) > 6:
self.cursor.execute(sql_query)
self.conn.commit()
except (Exception, psycopg2.DatabaseError) as error:
bindings_set = BindingsSet.empty()
bindings = Bindings()
bindings.add_var_binding("error on insert: ", ValueAtom(error))
bindings_set.push(bindings)
return bindings_set
return BindingsSet.empty()

def query(self, query_atom):
try:
atoms = query_atom.get_children()
if len(atoms) > 0 and SqlHelper.insert_command_sql in repr(atoms[0]):
return self.insert(repr(atoms[0]))
else:
new_bindings_set = BindingsSet.empty()
sql_query, vars_names = self.construct_query(query_atom)
if len(sql_query) > 6:
self.cursor.execute(sql_query)
values = self.cursor.fetchall()
if len(vars_names) == 0 and len(values) > 0:
vars = [f"var{i + 1}" for i in range(len(values[0]))]
else:
vars = [v[1:] for v in vars_names]
if len(vars) > 0 and len(values) > 0:
return results2bindings(vars, values)
return new_bindings_set
except (Exception, psycopg2.DatabaseError) as error:
print(error)


def wrapsqlop(func):
def wrapper(*args):
if len(args) > 1:
if isinstance(args[0], GroundedAtom):
space1 = args[0].get_object()
if isinstance(space1, SpaceRef):
if isinstance(args[1], GroundedAtom):
space2 = args[1].get_object()
if isinstance(space2, SpaceRef):
args = args[2:]
res = func(space1, space2, *args)
return [ValueAtom(val) for val in res]
else:
args = args[1:]
res = func(space1, *args)
return [ValueAtom(val) for val in res]
return []

return wrapper


@register_atoms
def sql_space_atoms():
helper = SqlHelper()
newSQLSpaceAtom = OperationAtom('new-sql-space', lambda database, host, user, password, port: [
G(SpaceRef(SqlSpace(database, host, user, password, port)))], unwrap=False)
saveQueryResult = G(OperationObject('sql.save-query-result', wrapsqlop(helper.save_query_result), unwrap=False))
sqlInsert = G(OperationObject('sql.insert', wrapsqlop(helper.insert), unwrap=False))
return {
r"new-sql-space": newSQLSpaceAtom,
r"sql.save-query-result": saveQueryResult,
r"sql.insert": sqlInsert
}
16 changes: 16 additions & 0 deletions python/sandbox/sql_space/sql_space_test.metta
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
!(extend-py! sql_space)
; database can be installed with use of https://github.com/metabrainz/musicbrainz-docker#publish-ports-of-all-services
; the description https://musicbrainz.org/doc/MusicBrainz_Database
!(bind! &sql_space (new-sql-space musicbrainz_db localhost musicbrainz musicbrainz 5432))
; save sql query results into given space
!(sql.save-query-result &sql_space &self (, (musicbrainz.artist (id $id) (name $name)) (musicbrainz.artist (id $id) (begin_date_year 1977)) (limit 3)))
!(get-atoms &self)
;result : [GroundingSpace, ((musicbrainz.artist (id "127482") (name "Kanye West"))), ((musicbrainz.artist (id "127482") (begin_date_year 1977))), ((musicbrainz.artist (id "23366") (name "The Dirty Dozen Brass Band"))), ((musicbrainz.artist (id "23366") (begin_date_year 1977))), ((musicbrainz.artist (id "35629") (name "Fabolous"))), ((musicbrainz.artist (id "35629") (begin_date_year 1977)))]

!(match &sql_space (, (musicbrainz.artist (id $id) (name $name)) (musicbrainz.artist (id $id) (begin_date_year 1983)) (limit 3)) $name)
;result ["NOFX", "Red Hot Chili Peppers", "Bon Jovi"]

!(sql.insert &sql_space ((table musicbrainz.artist) (gid "uuid_generate_v4()")
(name "some name3") (sort_name "name some3") (begin_date_year 1988) (begin_date_month 1)
(begin_date_day 1) (type 1) (area 222) (gender 2)))

0 comments on commit 7bfa532

Please sign in to comment.