This is to some degree a continuation of my previous post, where I built an egraph implementation using DuckDB.

In that post, I already had the core that translated datalog-style rules to SQL queries. In this post, I took out the complicated treatment of terms and generating fresh identifiers in the database, but added back in semi-naive evaluation and cleaned up the implementation a little.

Once I had that, I added stratified execution and negation.

I had originally used DuckDB because I thought that I had an OLAP workload. Now I’m not so sure that datalog or egraphs really are an OLAP workload: there is a lot of updating and deleting. It also turns out that the nonstandard UPSERT feature of SQLite really improved performance when reconciling the bag semantics of SQL with the set semantics of datalog. I could not find a way of phrasing my SQL such that DuckDB could beat SQLite. In fact, SQLite destroys DuckDB. Profiling a bit, it appears that the almost trivial administrative shuffling and deduplication is where DuckDB spends almost all its time.
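
To make the trick concrete, here is a minimal standalone illustration (the table and values are made up for the example): with a primary key over the tuple columns, INSERT OR IGNORE makes SQLite silently drop duplicate rows, so deduplication never has to be phrased as a separate query.

import sqlite3

# Minimal illustration of the unique-constraint + INSERT OR IGNORE combination.
con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE edge(x0 INTEGER NOT NULL, x1 INTEGER NOT NULL, "
            "PRIMARY KEY (x0, x1)) WITHOUT ROWID")
con.execute("INSERT OR IGNORE INTO edge VALUES (0, 1)")
con.execute("INSERT OR IGNORE INTO edge VALUES (0, 1)")  # duplicate, silently ignored
print(con.execute("SELECT COUNT(*) FROM edge").fetchone())  # (1,)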

The full source of datalite is at the bottom of the post and also here.

Seminaive Evaluation

Seminaive evaluation is a strategy by which you can avoid a lot of duplicate work during datalog execution. In datalog, you are running your rules over and over until you find no more new tuples. The basic idea is that you need to supply a rule with at least one new tuple, or else you will only derive things that you’ve already seen.

What this means is you convert a rule like Z :- A,B,C

into 3 rules

new_Z :- dA,B,C
new_Z :- A,dB,C
new_Z :- A,B,dC

It turns out that when you implement seminaive evaluation, you almost certainly need 3 separate buckets for each relation: old, delta, and new. From typical descriptions, you’d think you need only 2 buckets, old and delta. The “new” bucket gives you a place to temporarily hold tuples until you have completed your iteration and can clear out the delta relation.
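
In set terms, one iteration runs the rules reading old and delta and writing derived tuples into new, then computes delta = new - old, merges old = old ∪ delta, and finally clears new.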

So when we define a relation, we create these 3 tables

The UPSERT semantics were really, really convenient. By defining the relations with a unique constraint, SQLite detects duplicates and ignores the insert when one occurs. This is the function that is called to define a new relation. It creates tables with the unique constraints in SQLite and returns a helper function to generate Atoms.


def delta(name):
    return f"datalite_delta_{name}"


def new(name):
    return f"datalite_new_{name}"

def Relation(self, name, *types):
    if name not in self.rels:
        self.rels[name] = types
        args = ", ".join(
            [f"x{n} {typ} NOT NULL" for n, typ in enumerate(types)])
        bareargs = ", ".join(
            [f"x{n}" for n, _typ in enumerate(types)])
        self.execute(
            f"CREATE TABLE {name}({args}, PRIMARY KEY ({bareargs})) WITHOUT ROWID")
        self.execute(
            f"CREATE TABLE {new(name)}({args}, PRIMARY KEY ({bareargs})) WITHOUT ROWID")
        self.execute(
            f"CREATE TABLE {delta(name)}({args}, PRIMARY KEY ({bareargs})) WITHOUT ROWID")
    else:
        assert self.rels[name] == types
    return lambda *args: Atom(name, args)

# ---------------------------------------------------
# https://www.sqlite.org/datatype3.html
INTEGER = "INTEGER"
TEXT = "TEXT"
REAL = "REAL"
BLOB = "BLOB"
JSON = "TEXT"
# example usage

edge = s.Relation("edge", INTEGER, INTEGER)
path = s.Relation("path", INTEGER, INTEGER)
N = 999

x, y, z = Vars("x y z")
s.add_fact(edge(0, 1))
s.add_rule(edge(0, 1), [])
s.add_rule(edge(y, Expr("{y} + 1")), [edge(x, y), "{y} < 1000"])
s.add_rule(path(x, y), [edge(x, y)])
s.add_rule(path(x, z), [edge(x, y), path(y, z)])

For seminaive evaluation, each rule with a body of N atoms is compiled to N statements, where each statement has one of the relations replaced with its delta relation. The froms are the SQL-compiled versions of the body relations.

# in def compile():
            stmts = []
            for n in range(len(froms)):
                froms1 = copy(froms)
                # cheating a little here. froms actually contains "name AS alias"
                froms1[n] = delta(froms1[n])
                froms1 = ", ".join(froms1)
                stmts.append(
                    f"INSERT OR IGNORE INTO {new(head.name)} SELECT DISTINCT {selects} FROM {froms1}{wheres} ")
            return stmts
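
For example, the recursive path rule path(x,z) :- edge(x,y), path(y,z) from the example above compiles to roughly the following two statements, one per body atom, with that atom’s table swapped for its delta table (the alias names are produced by fresh, so the real output may differ slightly):

# Roughly what compile() produces for path(x,z) :- edge(x,y), path(y,z)
stmts = [
    "INSERT OR IGNORE INTO datalite_new_path "
    "SELECT DISTINCT dataduck_edge1.x0, dataduck_path2.x1 "
    "FROM datalite_delta_edge AS dataduck_edge1, path AS dataduck_path2 "
    "WHERE dataduck_edge1.x1 = dataduck_path2.x0",
    "INSERT OR IGNORE INTO datalite_new_path "
    "SELECT DISTINCT dataduck_edge1.x0, dataduck_path2.x1 "
    "FROM edge AS dataduck_edge1, datalite_delta_path AS dataduck_path2 "
    "WHERE dataduck_edge1.x1 = dataduck_path2.x0",
]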

Here is the run loop

    def run(self):
        stmts = []
        for head, body in self.rules:
            stmts += self.compile(head, body)
        while True:
            for stmt in stmts:
                self.execute(stmt)
            num_new = 0
            for name, types in self.rels.items():
                self.execute(f"DELETE FROM {delta(name)}")
                self.execute(
                    f"INSERT OR IGNORE INTO {delta(name)} SELECT DISTINCT * FROM {new(name)}")
                wheres = " AND ".join(
                    [f"{delta(name)}.x{n} = {name}.x{n}" for n in range(len(types))])
                self.execute(
                    f"DELETE FROM {delta(name)} WHERE EXISTS (SELECT * FROM {name} WHERE {wheres})")
                self.execute(
                    f"INSERT OR IGNORE INTO {name} SELECT * FROM {delta(name)}")
                self.execute(f"DELETE FROM {new(name)}")
                self.execute(f"SELECT COUNT(*) FROM {delta(name)}")
                n = self.cur.fetchone()[0]
                num_new += n
            if num_new == 0:
                break

At the end there is the administrative part of the seminaive loop. We

  1. delete the current delta relations, having used them
  2. copy the new relation into the delta relation
  3. delete anything in delta that is already in the old relation
  4. merge the delta into the old relation
  5. clear the new relation
  6. count how many new tuples were discovered
  7. break the loop when no new tuples were discovered

Stratification

One possibility (the simplest) is to just run the datalog program as one big mess.

There are however both performance and semantic benefits to determining which relations depend on which. The relations in the heads of rules depend on the relations in the bodies. Consider the following dependencies

A :- B
B :- A
C :- A, B
C :- C

A depends on B, B depends on A, and C depends on A, B, and itself. What we can do is run the rules with A and B in the head until termination, then run the rules with C in the head. We can do this because we know no rules will generate new A or B tuples at that point. They can’t. We’ve derived everything we can for those relations.

Stratification is a nice optimization, but more interestingly it is an approach to putting negation into the language of datalog, which increases its expressive power. Rules may ask about tuples not existing in relations of previous strata.

I used networkx, a python graph library, to do the dependency analysis. This makes it quite easy. I get the strongly connected components, build the condensation graph where each component is collapsed to a vertex, and get a topological dependency order.

    def stratify(self):
        G = nx.DiGraph()

        for head,  body in self.rules:
            for rel in body:
                if isinstance(rel, Atom):
                    G.add_edge(rel.name, head.name)
                elif isinstance(rel, Not):
                    assert isinstance(rel.val, Atom)
                    G.add_edge(rel.val.name, head.name)

        scc = list(nx.strongly_connected_components(G))
        cond = nx.condensation(G, scc=scc)
        for n in nx.topological_sort(cond):
            # TODO: negation check
            yield scc[n]
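
Applied to the A/B/C example above, this yields the stratum containing A and B followed by the stratum containing C. Here is a standalone sketch of the same analysis done directly with networkx (edges point from body relation to head relation):

import networkx as nx

# Dependency graph for the A/B/C example: A :- B, B :- A, C :- A,B, C :- C
G = nx.DiGraph()
G.add_edges_from([("B", "A"), ("A", "B"), ("A", "C"), ("B", "C"), ("C", "C")])
scc = list(nx.strongly_connected_components(G))
cond = nx.condensation(G, scc=scc)
print([scc[n] for n in nx.topological_sort(cond)])  # [{'A', 'B'}, {'C'}] (set ordering may vary)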

It’s a little more complicated to actually perform the stratification. There are probably many ways to arrange this. What I chose is that the rules that depend only upon previous strata are run once naively, rather than seminaively, at the start of the stratum. The other rules, which are recursive within the stratum, are run in a seminaive loop. I chose not to make the optimization of exploiting the fact that the delta relations of previous strata will always be empty. SQLite performs these trivial rules very quickly anyway, and in benchmarking they were not very relevant.

def run(self):
    for strata in self.stratify():
        stmts = []
        for head, body in self.rules:
            if head.name in strata:
                if len(body) == 0:
                    self.add_fact(head)
                elif any([rel.name in strata for rel in body if isinstance(rel, Atom)]):
                    stmts += self.compile(head, body)
                else:
                    # These are not recursive rules
                    # They need to be run once naively and can then be forgotten
                    stmt = self.compile(head, body, naive=True)
                    self.execute(stmt)
        # Prepare initial delta relation
        for name in strata:
            self.execute(
                f"INSERT OR IGNORE INTO {delta(name)} SELECT DISTINCT * FROM {new(name)}")
            self.execute(
                f"INSERT OR IGNORE INTO {name} SELECT DISTINCT * FROM {new(name)}")
            self.execute(
                f"DELETE FROM {new(name)}")
        iter = 0
        # Seminaive loop
        while True:
            iter += 1
            # print(iter)
            for stmt in stmts:
                self.execute(stmt)
            num_new = 0
            # Administrative stuff. dR = nR - R
            for name, types in self.rels.items():
                self.execute(f"DELETE FROM {delta(name)}")
                self.execute(
                    f"INSERT OR IGNORE INTO {delta(name)} SELECT DISTINCT * FROM {new(name)}")
                wheres = " AND ".join(
                    [f"{delta(name)}.x{n} = {name}.x{n}" for n in range(len(types))])
                self.execute(
                    f"DELETE FROM {delta(name)} WHERE EXISTS (SELECT * FROM {name} WHERE {wheres})")
                self.execute(
                    f"INSERT OR IGNORE INTO {name} SELECT * FROM {delta(name)}")
                self.execute(f"DELETE FROM {new(name)}")

                self.execute(f"SELECT COUNT(*) FROM {delta(name)}")
                n = self.cur.fetchone()[0]
                num_new += n
            if num_new == 0:
                break

Negation is handled by translating Not(foo(1,2)) into a clause like NOT EXISTS (SELECT * FROM foo WHERE foo.x0 = 1 AND foo.x1 = 2) in the WHERE clause of the rule. These rules appear to be rather slow, so perhaps there is another encoding that would be better, but this one is very straightforward.

        for c in body:
            # ...
            if isinstance(c, Not) and isinstance(c.val, Atom):
                args = c.val.args
                fname = fresh(c.val.name)
                conds = " AND ".join(
                    [f"{fname}.x{n} = {conv_arg(arg)}" for n, arg in enumerate(args)])
                wheres.append(
                    f"NOT EXISTS (SELECT * FROM {c.val.name} AS {fname} WHERE {conds})")
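
As a usage example, here is the nopath rule from the example at the bottom of the post, along with roughly the WHERE fragment the Not atom turns into (the alias names are generated by fresh, so they may differ):

verts = s.Relation("verts", INTEGER)
nopath = s.Relation("nopath", INTEGER, INTEGER)
s.add_rule(verts(x), [edge(x, y)])
s.add_rule(verts(y), [edge(x, y)])
s.add_rule(nopath(x, y), [verts(x), verts(y), Not(path(x, y))])
# The Not(path(x, y)) atom becomes roughly this fragment of the WHERE clause:
#   NOT EXISTS (SELECT * FROM path AS dataduck_path3
#               WHERE dataduck_path3.x0 = dataduck_verts1.x0
#                 AND dataduck_path3.x1 = dataduck_verts2.x0)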

Injecting SQL

SQL is already a pretty powerful language. That is what enables this translation to be so small. We are translating from one high level language to a slightly different one.

The main thing datalite does is translate logical-style queries with variables into the column-and-row-oriented style of SQL and run the fixpoint loop.

There is nothing really to add on top of SQL in terms of filtering expressions. It feels like a lot of duplicated work to build an AST datatype that translates directly to SQL expressions, so instead I went with kind of a hacky solution of allowing directly injected SQL format strings. They are format strings because you want to be able to refer to datalog variables rather than the underlying compiled SQL expression constructs.

These injected expressions can go in two places: as arguments of relations and as atoms in the body of a rule. In both cases they are translated into part of the WHERE clause of the SQL query.


@dataclass(frozen=True)
class Expr:
    expr: str

# In def compile():
# ************************
        wheres = []
        formatvarmap = {k.name: i[0]
                        for k, i in varmap.items() if isinstance(k, Var)}
        for c in body:
            if isinstance(c, str):  # Injected SQL constraint expressions
                wheres.append(c.format(**formatvarmap))
# ******************************
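
The edge-generating rule from the example near the top uses both forms: Expr("{y} + 1") as a head argument and the bare string "{y} < 1000" as a body constraint.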

Oh yeah, since the entire construction of datalite is via munging together SQL strings from the user input, it is wildly unsafe to expose to untrusted users. Nuff said.

Ballpark Benchmarks

I ran some fishy benchmarks to see where we are in the scheme of things. The example problem I used is transitive closure of a graph that is one big line of 1000 nodes.

Here is the example in Souffle:

.decl edge(x : number, y : number)

edge(0,1).
edge(m, m+1) :- edge(_n, m), m < 999.


.decl path(x : number, y : number)
path(x,y) :- edge(x,y).
path(x,z) :- edge(x,y), path(y,z).

.printsize path

  • Souffle 0.3s
  • Datalite 1.2s
  • Custom python 0.3s
  • WITH RECURSIVE 1s

All in all, let’s say that datalite has a 5x slowdown on this particular example (it is not at all clear that this example is indicative of common datalog workloads, but it is what it is.). I’ll take it. I started at 40x. I think UPSERT semantics and not using EXCEPT for set difference were the big wins over that initial version.
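
(For concreteness: the set difference delta = new - old is now done in the run loop with a DELETE ... WHERE EXISTS against the old table; the initial version phrased it with EXCEPT.)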

Some caveats:

  • Souffle can also be run in parallel mode with -j . It didn’t help that much.
  • Souffle can be run in a compiled rather than interpreted mode. I didn’t try this.
  • The custom python is actually going about as fast as python can go. Just inserting exactly these tuples into a python set takes the same time.

Custom python baseline

from itertools import groupby


def strata1(edge):
    # index edges by their target node: y -> [(x, y), ...]
    # (groupby only groups adjacent keys, which is fine here since each target is unique)
    edge_y = {y: list(g) for (y, g) in groupby(edge, key=lambda t: t[1])}
    # path(x,y) :- edge(x,y)
    path = {(x, y) for (x, y) in edge}
    deltapath = path
    while True:
        newpath = set()
        # path(x,z) :- edge(x,y), path(y,z).
        newpath.update({(x, z)
                       for (y, z) in deltapath for (x, _) in edge_y.get(y, [])})
        # if we have not discovered any new tuples return
        if newpath.issubset(path):
            return path
        else:
            # merge new tuples into path for next iteration
            deltapath = newpath.difference(path)
            path.update(newpath)


edge = {(i, i+1) for i in range(1000)}
#edge.add((999, 0))
path = strata1(edge)
print(len(path))

WITH RECURSIVE

CREATE TABLE edge(i INTEGER NOT NULL, j INTEGER NOT NULL); -- , PRIMARY KEY (i,j)) -- WITHOUT ROWID;
CREATE TABLE path(i INTEGER NOT NULL, j INTEGER NOT NULL); --, PRIMARY KEY (i,j)); 

WITH RECURSIVE
    temp(i,j) AS 
    (SELECT 0,1
    UNION
    SELECT j, j+1 FROM temp WHERE temp.j < 1000)
INSERT INTO edge SELECT * FROM temp;

SELECT COUNT(*) FROM edge;

WITH RECURSIVE
    temp(i,j) AS 
    (SELECT * FROM edge
    UNION
    SELECT edge.i, temp.j FROM edge, temp WHERE edge.j = temp.i)
INSERT INTO path SELECT * FROM temp;

SELECT COUNT(*) FROM path;

There are restrictions on WITH RECURSIVE (https://www.sqlite.org/lang_with.html) that make it unclear how to build general datalog queries: only one recursive use is allowed, and I’m not sure whether mutual recursion is allowed (probably not). It is probably possible at some encoding cost.

Bits and Bobbles

I sort of like this as a simple test bed for making reasonably performing unusual datalog variants.

The JSON support of SQLite is intriguing. It could be used to implement ADTs, arrays, and I think sets as datalog objects. A nice interface would be to use Python lists and dicts as datalog patterns.
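
As a speculative sketch of the ADT idea (not part of datalite, and assuming your SQLite build includes the JSON1 functions, which recent builds do): a term like add(1, 2) could be stored as a JSON array in a TEXT column and destructured with json_extract.

import sqlite3

# Speculative sketch: store the term add(1, 2) as a JSON array and pattern match on it.
con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE term(t TEXT NOT NULL, PRIMARY KEY (t)) WITHOUT ROWID")
con.execute("INSERT OR IGNORE INTO term VALUES (json_array('add', 1, 2))")
print(con.execute(
    "SELECT json_extract(t, '$[1]'), json_extract(t, '$[2]') "
    "FROM term WHERE json_extract(t, '$[0]') = 'add'").fetchall())  # [(1, 2)]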

By adding a union find to the compiler, we could have slightly more expressive equality constraints.

Full Source

from pprint import pprint
import cProfile
from typing import Any, List
import sqlite3
from dataclasses import dataclass
from collections import defaultdict
from copy import copy
import networkx as nx
import time


@dataclass(frozen=True)
class Var:
    name: str

    def __eq__(self, rhs):
        return Eq(self, rhs)


@dataclass(frozen=True)
class Expr:
    expr: str


@dataclass(frozen=True)
class Eq:
    lhs: Any
    rhs: Any


@dataclass
class Not:
    val: Any


fresh_counter = 0


def FreshVar():
    global fresh_counter
    fresh_counter += 1
    return Var(f"duckegg_x{fresh_counter}")


def Vars(xs):
    return [Var(x) for x in xs.split()]


@dataclass
class Atom:
    name: str
    args: List[Any]

    def __repr__(self):
        args = ",".join(map(repr, self.args))
        return f"{self.name}({args})"


# https://www.sqlite.org/datatype3.html
INTEGER = "INTEGER"
TEXT = "TEXT"
REAL = "REAL"
BLOB = "BLOB"
JSON = "TEXT"


def delta(name):
    return f"datalite_delta_{name}"


def new(name):
    return f"datalite_new_{name}"


class Solver():
    def __init__(self, debug=False, database=":memory:"):
        self.con = sqlite3.connect(database=database)
        self.cur = self.con.cursor()
        self.rules = []
        self.rels = {}
        self.debug = debug
        self.stats = defaultdict(int)

    def execute(self, stmt):
        if self.debug:
            # print(stmt)
            start_time = time.time()
        self.cur.execute(stmt)
        if self.debug:
            end_time = time.time()
            self.stats[stmt] += end_time - start_time

    def Relation(self, name, *types):
        if name not in self.rels:
            self.rels[name] = types
            args = ", ".join(
                [f"x{n} {typ} NOT NULL" for n, typ in enumerate(types)])
            bareargs = ", ".join(
                [f"x{n}" for n, _typ in enumerate(types)])
            self.execute(
                f"CREATE TABLE {name}({args}, PRIMARY KEY ({bareargs})) WITHOUT ROWID")
            self.execute(
                f"CREATE TABLE {new(name)}({args}, PRIMARY KEY ({bareargs})) WITHOUT ROWID")
            self.execute(
                f"CREATE TABLE {delta(name)}({args}, PRIMARY KEY ({bareargs})) WITHOUT ROWID")
        else:
            assert self.rels[name] == types
        return lambda *args: Atom(name, args)

    def add_rule(self, head, body):
        #assert len(body) > 0
        self.rules.append((head, body))

    def add_fact(self, fact: Atom):
        args = ", ".join(map(str, fact.args))
        self.execute(
            f"INSERT OR IGNORE INTO {new(fact.name)} VALUES ({args})")

    def compile(self, head, body, naive=False):
        assert isinstance(head, Atom)
        # map from variables to columns where they appear
        # We use WHERE clauses and let SQL do the heavy lifting
        counter = 0

        def fresh(name):
            nonlocal counter
            counter += 1
            return f"dataduck_{name}{counter}"
        varmap = defaultdict(list)
        froms = []
        for rel in body:
            # Every relation in the body creates a new FROM term bound to
            # a fresh alias, because we may have multiple instances of the same relation.
            # The arguments are processed to fill a varmap, which becomes the WHERE clause.
            if (isinstance(rel, Atom)):
                name = rel.name
                args = rel.args
                freshname = fresh(name)
                froms.append(f"{name} AS {freshname}")
                for n, arg in enumerate(args):
                    varmap[arg] += [f"{freshname}.x{n}"]

        wheres = []
        formatvarmap = {k.name: i[0]
                        for k, i in varmap.items() if isinstance(k, Var)}

        def conv_arg(arg):
            if isinstance(arg, int):
                return str(arg)
            if isinstance(arg, str):
                return f"'{arg}'"
            if isinstance(arg, Expr):
                return arg.expr.format(**formatvarmap)
            elif arg in varmap:
                return varmap[arg][0]
            else:
                print("Invalid head arg", arg)
                assert False
        for c in body:
            if isinstance(c, str):  # Injected SQL constraint expressions
                wheres.append(c.format(**formatvarmap))
            if isinstance(c, Not) and isinstance(c.val, Atom):
                args = c.val.args
                fname = fresh(c.val.name)
                conds = " AND ".join(
                    [f"{fname}.x{n} = {conv_arg(arg)}" for n, arg in enumerate(args)])
                wheres.append(
                    f"NOT EXISTS (SELECT * FROM {c.val.name} AS {fname} WHERE {conds})")

        # implicit equality constraints
        for v, argset in varmap.items():
            for arg in argset:
                if isinstance(v, int):  # a literal constraint
                    wheres.append(f"{arg} = {v}")
                elif isinstance(v, str):  # a literal string
                    wheres.append(f"{arg} = '{v}'")
                elif isinstance(v, Var):  # a variable constraint
                    if argset[0] != arg:
                        wheres.append(f"{argset[0]} = {arg}")
                elif isinstance(v, Expr):  # Injected SQL expression argument
                    wheres.append(f"{v.expr.format(**formatvarmap)} = {arg}")
                else:
                    print(v, argset)
                    assert False
        if len(wheres) > 0:
            wheres = " WHERE " + " AND ".join(wheres)
        else:
            wheres = ""
        # Semi-naive bodies
        selects = ", ".join([conv_arg(arg) for arg in head.args])
        if naive:
            froms = ", ".join(froms)
            return f"INSERT OR IGNORE INTO {new(head.name)} SELECT DISTINCT {selects} FROM {froms}{wheres}"
        else:
            stmts = []
            for n in range(len(froms)):
                froms1 = copy(froms)
                # cheating a little here. froms actually contains "name AS alias"
                froms1[n] = delta(froms1[n])
                froms1 = ", ".join(froms1)
                stmts.append(
                    f"INSERT OR IGNORE INTO {new(head.name)} SELECT DISTINCT {selects} FROM {froms1}{wheres} ")
            return stmts

    def stratify(self):
        G = nx.DiGraph()

        for head,  body in self.rules:
            for rel in body:
                if isinstance(rel, Atom):
                    G.add_edge(rel.name, head.name)
                elif isinstance(rel, Not):
                    assert isinstance(rel.val, Atom)
                    G.add_edge(rel.val.name, head.name)

        scc = list(nx.strongly_connected_components(G))
        cond = nx.condensation(G, scc=scc)
        for n in nx.topological_sort(cond):
            # TODO: negation check
            yield scc[n]

    def run(self):
        for strata in self.stratify():
            stmts = []
            for head, body in self.rules:
                if head.name in strata:
                    if len(body) == 0:
                        self.add_fact(head)
                    elif any([rel.name in strata for rel in body if isinstance(rel, Atom)]):
                        stmts += self.compile(head, body)
                    else:
                        # These are not recursive rules
                        # They need to be run once naively and can then be forgotten
                        stmt = self.compile(head, body, naive=True)
                        self.execute(stmt)
            # Prepare initial delta relation
            for name in strata:
                self.execute(
                    f"INSERT OR IGNORE INTO {delta(name)} SELECT DISTINCT * FROM {new(name)}")
                self.execute(
                    f"INSERT OR IGNORE INTO {name} SELECT DISTINCT * FROM {new(name)}")
                self.execute(
                    f"DELETE FROM {new(name)}")
            iter = 0
            # Seminaive loop
            while True:
                iter += 1
                # print(iter)
                for stmt in stmts:
                    self.execute(stmt)
                num_new = 0
                for name, types in self.rels.items():
                    self.execute(f"DELETE FROM {delta(name)}")
                    self.execute(
                        f"INSERT OR IGNORE INTO {delta(name)} SELECT DISTINCT * FROM {new(name)}")
                    wheres = " AND ".join(
                        [f"{delta(name)}.x{n} = {name}.x{n}" for n in range(len(types))])
                    self.execute(
                        f"DELETE FROM {delta(name)} WHERE EXISTS (SELECT * FROM {name} WHERE {wheres})")
                    self.execute(
                        f"INSERT OR IGNORE INTO {name} SELECT * FROM {delta(name)}")
                    self.execute(f"DELETE FROM {new(name)}")

                    self.execute(f"SELECT COUNT(*) FROM {delta(name)}")
                    n = self.cur.fetchone()[0]
                    num_new += n
                if num_new == 0:
                    break


s = Solver(debug=True)
#s.con.execute("SET preserve_insertion_order TO False;")
#s.con.execute("SET checkpoint_threshold TO '10GB';")
#s.con.execute("SET wal_autocheckpoint TO '10GB';")
#s.con.execute("pragma mmap_size = 30000000000;")
#s.con.execute("pragma page_size=32768")
edge = s.Relation("edge", INTEGER, INTEGER)
path = s.Relation("path", INTEGER, INTEGER)
N = 999
# for i in range(N):
#    s.add_fact(edge(i, i+1))
#s.add_fact(edge(N, 0))
# s.add_fact(edge(2, 3))

x, y, z = Vars("x y z")
s.add_fact(edge(0, 1))
s.add_rule(edge(0, 1), [])
#s.add_rule(edge(y, "{y} + 1"), [edge(x, y), "{z} = {y} + 1"])
s.add_rule(edge(y, Expr("{y} + 1")), [edge(x, y), "{y} < 1000"])
s.add_rule(path(x, y), [edge(x, y)])
s.add_rule(path(x, z), [edge(x, y), path(y, z)])

verts = s.Relation("verts", INTEGER)
s.add_rule(verts(x), [edge(x, y)])
s.add_rule(verts(y), [edge(x, y)])
nopath = s.Relation("nopath", INTEGER, INTEGER)
s.add_rule(nopath(x, y), [verts(x), verts(y), Not(path(x, y))])

cProfile.run('s.run()', sort='cumtime')
# s.run()

pprint(sorted(s.stats.items(), key=lambda s: s[1], reverse=True)[:10])

s.cur.execute("SELECT COUNT(*) FROM path")
print(s.cur.fetchone())

#s.cur.execute("SELECT COUNT(*) FROM nopath")
# print(s.cur.fetchone())

#s.cur.execute("SELECT COUNT(*) FROM verts")
# print(s.cur.fetchone())