Skip to content

Commit

Permalink
rewrite try
Browse files Browse the repository at this point in the history
  • Loading branch information
ybrs committed Dec 31, 2023
1 parent 312f4a4 commit 445871f
Show file tree
Hide file tree
Showing 3 changed files with 154 additions and 39 deletions.
22 changes: 22 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,2 +1,24 @@
# pgduckdb
PostgreSQL wire protocol proxy for DuckDB. In other words, this is a network proxy that allows access to your DuckDB files over the PostgreSQL wire protocol.

# Why

DuckDB is an embedded database, and in some circumstances you may want to share it or connect to it over the network.

Some clients/applications might not have a DuckDB option but do have a way to connect to a PostgreSQL server (e.g. Grafana).

# Installation

As usual, you can install it with:

```bash
go install github.com/ybrs/pgduckdb@latest
```

# Usage

You can run pgduckdb, pointing it at your DuckDB file:

```
```
2 changes: 2 additions & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ func main() {
bootQueries := []string{
"INSTALL 'json'",
"LOAD 'json'",
"INSTALL 'icu'",
"LOAD 'icu'",
}

for _, qry := range bootQueries {
Expand Down
169 changes: 130 additions & 39 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@ import (
_ "github.com/marcboeker/go-duckdb"
"io"
"net"
"strings"
)

// ServerVersion is the PostgreSQL version string this proxy advertises to
// clients; handleStartup sends it as the "server_version" ParameterStatus.
const ServerVersion = "13.0.0"

type DuckDbBackend struct {
Expand Down Expand Up @@ -110,7 +116,7 @@ func scanRow(rows *sql.Rows, cols []*sql.ColumnType) (*pgproto3.DataRow, error)
refs[i] = &values[i]
}

// Scan from SQLite database.
// Scan from Duckdb database.
if err := rows.Scan(refs...); err != nil {
return nil, fmt.Errorf("scan: %w", err)
}
Expand All @@ -123,6 +129,76 @@ func scanRow(rows *sql.Rows, cols []*sql.ColumnType) (*pgproto3.DataRow, error)
return &row, nil
}

// HandleQuery runs query against p.db and writes the complete response to
// p.conn: a RowDescription, one DataRow per result row, a CommandComplete and
// a final ReadyForQuery.
//
// If the database rejects the statement (either while preparing it or while
// executing it), an ErrorResponse followed by ReadyForQuery is written to the
// client and HandleQuery returns nil. A non-nil error is returned only when
// reading the result set or writing the response fails; in that case nothing
// has been sent to the client for this query.
func (p *DuckDbBackend) HandleQuery(query string) error {
	// Prepare first so that statements the database cannot parse are reported
	// to the client before anything is executed.
	stmt, err := p.db.Prepare(query)
	if err != nil {
		fmt.Println("coulnt handle query", query)
		writeMessages(p.conn,
			&pgproto3.ErrorResponse{Message: err.Error()},
			&pgproto3.ReadyForQuery{TxStatus: 'I'},
		)
		return nil
	}
	// The prepared statement is used for validation only; release it right
	// away instead of leaking one statement per query. A failure to close a
	// statement that was never executed is not actionable, so it is ignored.
	// NOTE(review): assumes p.db is a database/sql handle whose Prepare
	// returns *sql.Stmt — confirm against the DuckDbBackend definition.
	_ = stmt.Close()

	rows, err := p.db.QueryContext(p.ctx, query)
	if err != nil {
		fmt.Println("couldnt handle query 2", query)
		writeMessages(p.conn,
			&pgproto3.ErrorResponse{Message: err.Error()},
			&pgproto3.ReadyForQuery{TxStatus: 'I'},
		)
		return nil
	}
	defer rows.Close()

	cols, err := rows.ColumnTypes()
	if err != nil {
		return fmt.Errorf("column types: %w", err)
	}

	// The whole response is accumulated in buf and sent with a single write.
	buf := toRowDescription(cols).Encode(nil)

	// Iterate over each row and encode it to the wire protocol, counting the
	// rows so the completion tag can report the real number.
	rowCount := 0
	for rows.Next() {
		row, err := scanRow(rows, cols)
		if err != nil {
			return fmt.Errorf("scan: %w", err)
		}
		buf = row.Encode(buf)
		rowCount++
	}
	if err := rows.Err(); err != nil {
		return fmt.Errorf("rows: %w", err)
	}

	// The CommandComplete tag for a SELECT is "SELECT <rows>"; report the
	// number of rows actually streamed rather than a fixed "SELECT 1".
	tag := fmt.Sprintf("SELECT %d", rowCount)
	buf = (&pgproto3.CommandComplete{CommandTag: []byte(tag)}).Encode(buf)
	buf = (&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(buf)
	if _, err := p.conn.Write(buf); err != nil {
		return fmt.Errorf("error writing query response: %w", err)
	}

	return nil
}

// rewriteQuery applies a fixed, ordered list of literal (case-sensitive)
// substring substitutions to query and returns the result. The substitutions
// strip or replace PostgreSQL-specific constructs (pg_catalog qualification,
// ::regclass / ::regproc casts, SHOW TRANSACTION ISOLATION LEVEL and a few
// catalog expressions) so the statement can be handed to the backing database.
//
// The rewritten SQL refers to helper functions that are expected to exist on
// the database side (NOTE(review): where they get created is not visible
// here — confirm):
//
//	CREATE OR REPLACE FUNCTION _pg_expandarray(anyarray) AS (SELECT [unnest(anyarray), generate_subscripts(anyarray, 1)]);
//	CREATE OR REPLACE FUNCTION array_upper(anyarray, f) AS len(anyarray) + 1;
//
// This is a poor man's attempt to get Parse commands through and make the
// PyCharm/GoLand/JetBrains database tool work. Ideally a real SQL parser such
// as https://pkg.go.dev/github.com/dallen66/pg_query_go/v2#section-readme
// would be used to map things properly (or a pg_catalog database provided).
func rewriteQuery(query string) string {
	// Order matters: each rule sees the output of the previous ones. In
	// particular the generic "pg_catalog." strip runs before the rules that
	// match unqualified names such as pg_get_indexdef.
	rewrites := [...]struct{ from, to string }{
		// Fixed: the pattern used to be spelled "pgcatalog." (no underscore).
		{"pg_catalog.current_database(", "current_database("},
		{"pg_catalog.current_schema(", "current_schema("},
		{"pg_catalog.", ""},
		{"::regclass", "::oid"},
		{"SHOW TRANSACTION ISOLATION LEVEL", "select 'read committed'"},
		{` trim(both '"' from pg_get_indexdef(tmp.CI_OID, tmp.ORDINAL_POSITION, false)) `, `'' `},
		{`(information_schema._pg_expandarray(i.indkey)).n `, `1`},
		{`information_schema._pg_expandarray(`, `_pg_expandarray( `},
		{` (result.KEYS).x `, ` result.KEYS[0] `},
		{`::regproc`, ``},
	}
	for _, r := range rewrites {
		query = strings.ReplaceAll(query, r.from, r.to)
	}
	return query
}

func (p *DuckDbBackend) Run() error {
defer p.Close()

Expand All @@ -139,54 +215,68 @@ func (p *DuckDbBackend) Run() error {

switch msg.(type) {
case *pgproto3.Query:
fmt.Println("msg", msg)
//fmt.Println("msg", msg)
query := msg.(*pgproto3.Query)
p.HandleQuery(query.String)
case *pgproto3.Terminate:
return nil

_, err = p.db.Prepare(query.String)
if err != nil {
writeMessages(p.conn,
&pgproto3.ErrorResponse{Message: err.Error()},
&pgproto3.ReadyForQuery{TxStatus: 'I'},
)
case *pgproto3.Parse:
// For now we simply ignore parse messages
// https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-EXT-QUERY
q := msg.(*pgproto3.Parse)
//fmt.Println("parse >>>", q)
query := q.Query
if strings.HasPrefix(q.Query, "SET ") ||
strings.Contains(q.Query, "proargmodes") ||
strings.Contains(q.Query, "prokind") ||
strings.Contains(q.Query, "from pg_catalog.pg_locks") {
//fmt.Println("skipped query returning ok")
buf := (&pgproto3.CommandComplete{CommandTag: []byte("SELECT 1")}).Encode(nil)
buf = (&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(buf)
//fmt.Println("parse complete write buf")
_, err = p.conn.Write(buf)
if err != nil {
return fmt.Errorf("error writing query response: %w", err)
}
continue

}
//fmt.Println("run parse query", query)
if query == "" {
buf := (&pgproto3.CommandComplete{CommandTag: []byte("SELECT 1")}).Encode(nil)
buf = (&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(buf)
buf = (&pgproto3.ParseComplete{}).Encode(buf)
//fmt.Println("parse complete write buf")
_, err = p.conn.Write(buf)
if err != nil {
return fmt.Errorf("error writing query response: %w", err)
}

rows, err := p.db.QueryContext(p.ctx, query.String)
if err != nil {
writeMessages(p.conn,
&pgproto3.ErrorResponse{Message: err.Error()},
&pgproto3.ReadyForQuery{TxStatus: 'I'},
)
continue
}
defer rows.Close()

cols, err := rows.ColumnTypes()
if err != nil {
return fmt.Errorf("column types: %w", err)
}
buf := toRowDescription(cols).Encode(nil)
query = rewriteQuery(query)
//fmt.Println("running rewritten >", query)
p.HandleQuery(query)
//buf := (&pgproto3.CommandComplete{CommandTag: []byte("SELECT 1")}).Encode(nil)
//buf = (&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(buf)
//fmt.Println("parse complete write buf")
//_, err = p.conn.Write(buf)
//if err != nil {
// return fmt.Errorf("error writing query response: %w", err)
//}

// Iterate over each row and encode it to the wire protocol.
for rows.Next() {
row, err := scanRow(rows, cols)
if err != nil {
return fmt.Errorf("scan: %w", err)
}
buf = row.Encode(buf)
}
if err := rows.Err(); err != nil {
return fmt.Errorf("rows: %w", err)
}
buf = (&pgproto3.CommandComplete{CommandTag: []byte("SELECT 1")}).Encode(buf)
buf = (&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(buf)
_, err = p.conn.Write(buf)
if err != nil {
return fmt.Errorf("error writing query response: %w", err)
}
case *pgproto3.Terminate:
return nil
case *pgproto3.Bind:
continue
case *pgproto3.Execute:
continue
case *pgproto3.Sync:
continue
case *pgproto3.Describe:
continue
default:
print("coming to default ??", msg)
return fmt.Errorf("received message other than Query from client: %#v", msg)
}
}
Expand All @@ -201,6 +291,7 @@ func (p *DuckDbBackend) handleStartup() error {
switch startupMessage.(type) {
case *pgproto3.StartupMessage:
buf := (&pgproto3.AuthenticationOk{}).Encode(nil)
buf = (&pgproto3.ParameterStatus{Name: "server_version", Value: ServerVersion}).Encode(buf)
buf = (&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(buf)
_, err = p.conn.Write(buf)
if err != nil {
Expand Down

0 comments on commit 445871f

Please sign in to comment.