May 23, 2020

Pushing data upstream in Context

Go’s context is a hefty tool that became part of the standard library in Go 1.7. The context package carries information a goroutine may need, such as how long it should run and how and when it should stop. It can also pass key-value pairs down the call chain. But what if we need to pass information up the call chain? Pointers come to the rescue.

In a web server codebase especially, most functions and methods receive a context.Context as their first argument. It is used to set deadlines when calling HTTP services and databases, and I also use it to pass authorization information from middleware, such as userID, role, and tenant. The catch is that context values only propagate down the stack: context.WithValue returns a new child context, so anything a callee adds is invisible to its caller. If a middleware that runs earlier in the request needs information produced by later database calls, putting a plain value in the context won’t work.
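The difference is easy to see in a small, self-contained program. The key type `ctxKey` and the helper functions below are hypothetical names chosen for illustration: a value added with `context.WithValue` inside a callee never reaches the caller, while a write through a pointer that the caller stored in the context does.

```go
package main

import (
	"context"
	"fmt"
)

// ctxKey is a hypothetical key type; a named type avoids collisions with other packages.
type ctxKey string

// setValue mimics a callee: WithValue returns a NEW child context,
// so the caller's ctx is left unchanged.
func setValue(ctx context.Context) {
	_ = context.WithValue(ctx, ctxKey("result"), "from callee")
}

// setThroughPointer writes through a pointer the caller placed in the context.
func setThroughPointer(ctx context.Context) {
	if p, ok := ctx.Value(ctxKey("result")).(*string); ok {
		*p = "from callee"
	}
}

// valueVisibleToCaller reports whether a value added by the callee can be read back.
func valueVisibleToCaller() bool {
	ctx := context.Background()
	setValue(ctx)
	_, ok := ctx.Value(ctxKey("result")).(string)
	return ok
}

// pointerVisibleToCaller returns what the caller sees after the callee writes.
func pointerVisibleToCaller() string {
	var result string
	ctx := context.WithValue(context.Background(), ctxKey("result"), &result)
	setThroughPointer(ctx)
	return result
}

func main() {
	fmt.Println(valueVisibleToCaller())   // false
	fmt.Println(pointerVisibleToCaller()) // from callee
}
```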

One example is logging all database queries. To do so, I use go-pg’s AfterQuery hook. Earlier in the request, the logging middleware puts a pointer to a string slice into the context and defers logging its contents; the hook just appends to it.

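var queries []string
// The closure reads queries when it runs, after the handlers have appended to it.
defer func() {
    logger.Strs("queries", queries)
}()
// Store a pointer so code further down the stack appends to this same slice.
ctx = context.WithValue(ctx, ContextKey("query"), &queries)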

Then, in the AfterQuery method, I take the slice pointer from the context and append the formatted query to it.

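func (d dbLogger) AfterQuery(c context.Context, q *pg.QueryEvent) error {
    // Look up the slice pointer stored by the logging middleware.
    queries, ok := c.Value(ContextKey("query")).(*[]string)
    if !ok {
        // No collector in this context (e.g. a background job): nothing to do.
        return nil
    }
    query, err := q.FormattedQuery()
    if err != nil {
        return err
    }
    // Append through the pointer so the middleware sees the new entry.
    *queries = append(*queries, query)
    return nil
}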

This way, all of a request’s queries are logged once, as a single structured field, without any repetition.
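The whole round trip can be sketched without a database. This is a minimal sketch assuming plain `net/http` middleware; `recordQuery` is a hypothetical stand-in for the go-pg `AfterQuery` hook, and `withQueryLog`, `logFn`, and `serveOnce` are illustrative names, not part of any library.

```go
package main

import (
	"context"
	"fmt"
	"net/http"
	"net/http/httptest"
	"strings"
)

// ContextKey mirrors the key type used in the snippets above.
type ContextKey string

// recordQuery is a hypothetical stand-in for the go-pg AfterQuery hook:
// it appends through the *[]string the middleware stored in the context.
func recordQuery(ctx context.Context, query string) {
	if queries, ok := ctx.Value(ContextKey("query")).(*[]string); ok {
		*queries = append(*queries, query)
	}
}

// withQueryLog collects every query issued while serving the request
// and hands the final slice to logFn once the handler returns.
func withQueryLog(logFn func([]string), next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		var queries []string
		// The closure reads queries when it runs, so it sees every append.
		defer func() { logFn(queries) }()
		ctx := context.WithValue(r.Context(), ContextKey("query"), &queries)
		next.ServeHTTP(w, r.WithContext(ctx))
	})
}

// serveOnce runs one request through the middleware and returns the logged queries.
func serveOnce() []string {
	var logged []string
	handler := withQueryLog(func(q []string) { logged = q }, http.HandlerFunc(
		func(w http.ResponseWriter, r *http.Request) {
			// Pretend the handler ran two database queries.
			recordQuery(r.Context(), "SELECT * FROM users WHERE id = 1")
			recordQuery(r.Context(), "SELECT * FROM roles WHERE user_id = 1")
			fmt.Fprint(w, "ok")
		}))
	handler.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/", nil))
	return logged
}

func main() {
	fmt.Println(strings.Join(serveOnce(), "\n"))
}
```

The deferred closure matters here: `defer logFn(queries)` would evaluate `queries` when the `defer` statement runs and log an empty slice.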

I use a similar concept for transaction middleware.

func Tx(db *pg.DB) mux.MiddlewareFunc {
    return func(h http.Handler) http.Handler {
        return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
            ctx := r.Context()
            // An earlier middleware stores a *bool under this key; the central
            // error handler sets it to true when the request fails.
            hasErrored, ok := ctx.Value(ContextKey("errored")).(*bool)
            if !ok {
                http.Error(w, "invalid transaction", http.StatusInternalServerError)
                return
            }

            tx, err := db.Begin()
            if err != nil {
                http.Error(w, "error starting transaction", http.StatusInternalServerError)
                return
            }

            r = r.WithContext(context.WithValue(ctx, ContextKey("tx"), tx))

            defer func() {
                // Roll back on panic, then re-panic so the panic is not swallowed.
                if p := recover(); p != nil {
                    tx.Rollback()
                    panic(p)
                }
                // The flag is read only after the handler has run.
                if *hasErrored {
                    tx.Rollback()
                    return
                }
                tx.Commit()
            }()

            h.ServeHTTP(w, r)
        })
    }
}
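The commit-or-rollback decision can be exercised in isolation with a fake transaction. In this minimal sketch, `fakeTx`, `withErrorFlag`, `respondError`, and `withTx` are hypothetical: an earlier middleware places the `*bool` in the context, a central error handler flips it, and the transaction middleware reads it only after the handler returns.

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"net/http"
	"net/http/httptest"
)

type ContextKey string

// fakeTx is a hypothetical stand-in for *pg.Tx that records the outcome.
type fakeTx struct{ outcome string }

func (t *fakeTx) Commit() error   { t.outcome = "commit"; return nil }
func (t *fakeTx) Rollback() error { t.outcome = "rollback"; return nil }

// withErrorFlag (hypothetical) stores a *bool in the context before the transaction starts.
func withErrorFlag(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		errored := false
		ctx := context.WithValue(r.Context(), ContextKey("errored"), &errored)
		next.ServeHTTP(w, r.WithContext(ctx))
	})
}

// respondError is a hypothetical central error handler: it flips the flag
// so the transaction middleware rolls back, then writes the HTTP error.
func respondError(w http.ResponseWriter, r *http.Request, err error, code int) {
	if flag, ok := r.Context().Value(ContextKey("errored")).(*bool); ok {
		*flag = true
	}
	http.Error(w, err.Error(), code)
}

// withTx commits or rolls back tx depending on the flag after next returns.
func withTx(tx *fakeTx, next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		flag, ok := r.Context().Value(ContextKey("errored")).(*bool)
		if !ok {
			http.Error(w, "invalid transaction", http.StatusInternalServerError)
			return
		}
		defer func() {
			if *flag {
				tx.Rollback()
				return
			}
			tx.Commit()
		}()
		next.ServeHTTP(w, r)
	})
}

// run serves one request and reports what happened to the transaction.
func run(fail bool) string {
	tx := &fakeTx{}
	h := withErrorFlag(withTx(tx, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if fail {
			respondError(w, r, errors.New("boom"), http.StatusBadRequest)
			return
		}
		w.WriteHeader(http.StatusOK)
	})))
	h.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/", nil))
	return tx.outcome
}

func main() {
	fmt.Println(run(false)) // commit
	fmt.Println(run(true))  // rollback
}
```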

I have a centralized error handling mechanism (for HTTP status codes and errors) that, when an error occurs, sets this flag to true, which rolls back the transaction.

Another place I recently used this is parallelized API invocation. I have a service that needs to copy lots of data from an external API for different entities. The calls run in parallel, but if one of them fails, the others should stop as well. Instead of using channels to pass this information, I put a “failure” flag in the context and set it to true on error, which makes the other goroutines stop invoking the API. Because several goroutines read and write that flag concurrently, access to it must be synchronized (for example with sync/atomic); a plain *bool would be a data race.

2024 © Emir Ribic - Some rights reserved; please attribute properly and link back. Code snippets are MIT Licensed
