Keysort: The Schwartzian Transform in Go

This article talks about a problem I encountered while sorting complex types in Go, and the solution I came up with, keysort.


keysort is available on GitHub!

Encountering the Problem

They say necessity is the mother of invention. I say some of the most interesting problems to solve are those you inflict upon yourself. In one of my recent toy projects[1] I found myself needing to repeatedly sort slices of a complex type by a property that was expensive to compute.

In Go, sort.Interface looks like this:

type Interface interface {
        // Len is the number of elements in the collection.
        Len() int
        // Less reports whether the element with
        // index i should sort before the element with index j.
        Less(i, j int) bool
        // Swap swaps the elements with indexes i and j.
        Swap(i, j int)
}

Any collection that implements sort.Interface can be sorted by the sorting algorithms provided in the sort package and elsewhere. The trouble I ran into is that even an optimal comparison sort calls Less() on the order of n log n times.

A naive (if syntactically simple) implementation of Less() computes the expensive function for both items and then discards both results as soon as it returns a boolean, so each item's value is recomputed every time that item takes part in a comparison.

As this Go Playground snippet[2] shows, the number of calls to the "expensive function" TrueValue grows faster than linearly with the number of items.
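A minimal sketch of the naive approach, in the spirit of that snippet, looks like the following. The HardToSort type, its Weight field, the trivially cheap TrueValue body, and the countCalls helper are assumptions made for illustration; the call counter is the point.

```go
package main

import (
	"fmt"
	"sort"
)

// HardToSort is a hypothetical stand-in for a complex type.
type HardToSort struct {
	Weight int
}

// calls counts every evaluation of the "expensive" function.
var calls int

// TrueValue stands in for an expensive computation. Here it is cheap,
// but it records each call so we can see how often sorting triggers it.
func (h HardToSort) TrueValue() int {
	calls++
	return h.Weight * h.Weight
}

// ByTrueValue implements sort.Interface naively: Less recomputes
// TrueValue for both elements on every comparison and discards the results.
type ByTrueValue []HardToSort

func (b ByTrueValue) Len() int           { return len(b) }
func (b ByTrueValue) Swap(i, j int)      { b[i], b[j] = b[j], b[i] }
func (b ByTrueValue) Less(i, j int) bool { return b[i].TrueValue() < b[j].TrueValue() }

// countCalls sorts n scrambled items and reports how many times TrueValue ran.
func countCalls(n int) int {
	hs := make(ByTrueValue, n)
	for i := range hs {
		// 7919 is prime and coprime to these n, so this is a permutation of 0..n-1.
		hs[i] = HardToSort{Weight: (i * 7919) % n}
	}
	calls = 0
	sort.Sort(hs)
	return calls
}

func main() {
	for _, n := range []int{10, 100, 1000, 10000} {
		fmt.Printf("n=%5d  TrueValue calls=%d\n", n, countCalls(n))
	}
}
```

Each comparison costs two TrueValue calls, so the count should grow roughly like n log n rather than n.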

Now, since I encountered this problem while I was in the middle of another project, I made the decision not to chase this particular yak down the rabbit hole with my safety razor. To get that project out the metaphorical door, I decided to memoize the expensive calculation by hand, and then use the memoized value in the code. It wasn't pretty, but it got the job done.
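Hand memoization can look something like the following sketch. The cached field and the sortMemoized helper are hypothetical names, not code from that project; the idea is simply to compute TrueValue once per element up front and have Less compare the stored results.

```go
package main

import (
	"fmt"
	"sort"
)

// HardToSort is a hypothetical complex type. The cached field holds the
// result of TrueValue so that Less never has to recompute it.
type HardToSort struct {
	Weight int
	cached int
}

// calls counts evaluations of the "expensive" function.
var calls int

// TrueValue stands in for the expensive computation.
func (h HardToSort) TrueValue() int {
	calls++
	return h.Weight * h.Weight
}

// ByCachedValue compares the memoized values instead of calling TrueValue.
type ByCachedValue []HardToSort

func (b ByCachedValue) Len() int           { return len(b) }
func (b ByCachedValue) Swap(i, j int)      { b[i], b[j] = b[j], b[i] }
func (b ByCachedValue) Less(i, j int) bool { return b[i].cached < b[j].cached }

// sortMemoized fills in cached once per element, then sorts. Swap moves
// whole structs, so each cached value travels with its element.
func sortMemoized(hs []HardToSort) {
	for i := range hs {
		hs[i].cached = hs[i].TrueValue()
	}
	sort.Sort(ByCachedValue(hs))
}

func main() {
	hs := []HardToSort{{Weight: 3}, {Weight: -5}, {Weight: 1}, {Weight: 4}}
	sortMemoized(hs)
	for _, h := range hs {
		fmt.Print(h.Weight, " ") // 1 3 4 -5
	}
	fmt.Println("\nTrueValue calls:", calls) // 4
}
```

It works, but the cache field and the priming loop have to be written again for every type and every key.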

Still, the dissatisfaction stayed with me, because I had met the idiom designed to solve exactly this problem before. This problem cried out for a Schwartzian transform[3]. In Python, the Schwartzian transform is built directly into the sorted built-in, and you can access it with something as simple as

import operator
sorted(hard_to_sorts, key=operator.methodcaller('true_value'))

Python calls true_value() exactly once per item. I wanted an equally lightweight idiom in Go: one that took only a little more work than implementing sort.Interface and handled the caching automatically.
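For comparison, here is roughly what that idiom looks like when decorate-sort-undecorate is written by hand in Go. This is a hypothetical helper (sortByKey, with trueValue as a stand-in key function) for a slice of strings with an int key; it is not part of keysort, and it relies on sort.SliceStable from the standard library.

```go
package main

import (
	"fmt"
	"sort"
)

// trueValue is a hypothetical stand-in for an expensive key function.
func trueValue(s string) int {
	return len(s)
}

// sortByKey returns items ordered by key(item), calling key exactly once
// per item: decorate each item with its key, sort the pairs on the key
// alone, then strip the keys off again.
func sortByKey(items []string, key func(string) int) []string {
	type pair struct {
		key  int
		item string
	}
	// Decorate.
	decorated := make([]pair, len(items))
	for i, it := range items {
		decorated[i] = pair{key: key(it), item: it}
	}
	// Sort. SliceStable keeps items with equal keys in their input order.
	sort.SliceStable(decorated, func(i, j int) bool {
		return decorated[i].key < decorated[j].key
	})
	// Undecorate.
	out := make([]string, len(decorated))
	for i, p := range decorated {
		out[i] = p.item
	}
	return out
}

func main() {
	words := []string{"banana", "fig", "apple", "kiwi"}
	fmt.Println(sortByKey(words, trueValue)) // [fig kiwi apple banana]
}
```

Written this way, the helper is tied to one element type and one key type.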

The solution

I knew that I wanted to provide a type that implemented sort.Interface, so that it could be used with the sorting algorithms provided in package sort.

The implementations of Len() and Swap() would barely need to change. However, instead of clients implementing Less() themselves, I needed the user to implement a Key() method that returns the value they wish to sort by.

Since the value returned by Key() could be of any type (numbers, strings, slices, maps, or even structs), I wanted the user to provide a comparator to do the work. This is why the user also needs to implement LessVal(). I essentially adopted the same strategy as the sort package, which leaves the comparison itself to the user's Less().

One aspect of this solution that I'm not pleased with is that both the return type of Key() and the argument types of LessVal() are unknown to me as the library author, so I'm forced to use interface{}. In practice, this should not be much of a hindrance: the two types must always agree, and a type assertion is all that's needed to straighten things out.

To make use of the finished product, the user simply needs to implement keysort.Interface, which is:

type Interface interface {
        LessVal(i, j interface{}) bool
        Key(i int) (interface{}, error)
        Swap(i, j int)
        Len() int
}

I'm okay with this being just one method larger than sort.Interface. Once you've implemented those methods, the actual sort step looks like:

sort.Sort(keysort.Keysort(ByTrueValue(hs)))

The code for sorting with keysort is also available in the Go Playground.
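The internals of keysort aren't shown here, so the following is only a guess at how such a wrapper could work, not the library's code. lazyKeysort, newLazyKeysort, sortByKeys, HardToSort, and its Weight field are all hypothetical. The wrapper implements sort.Interface on top of an interface shaped like keysort.Interface, computes each key the first time a comparison needs it, and swaps cached keys in lockstep with the data.

```go
package main

import (
	"fmt"
	"sort"
)

// Interface mirrors the keysort.Interface shown above.
type Interface interface {
	LessVal(i, j interface{}) bool
	Key(i int) (interface{}, error)
	Swap(i, j int)
	Len() int
}

// lazyKeysort is a hypothetical wrapper, not the library's actual code.
// It satisfies sort.Interface and caches each element's key the first
// time a comparison needs it, so Key runs at most once per element.
type lazyKeysort struct {
	data Interface
	keys []interface{}
	have []bool // have[i] reports whether keys[i] has been computed
}

func newLazyKeysort(data Interface) *lazyKeysort {
	n := data.Len()
	return &lazyKeysort{data: data, keys: make([]interface{}, n), have: make([]bool, n)}
}

// key returns the cached key for position i, computing it on first use.
// Errors from Key are ignored in this sketch.
func (k *lazyKeysort) key(i int) interface{} {
	if !k.have[i] {
		k.keys[i], _ = k.data.Key(i)
		k.have[i] = true
	}
	return k.keys[i]
}

func (k *lazyKeysort) Len() int           { return k.data.Len() }
func (k *lazyKeysort) Less(i, j int) bool { return k.data.LessVal(k.key(i), k.key(j)) }

// Swap moves the cached keys along with the underlying elements.
func (k *lazyKeysort) Swap(i, j int) {
	k.data.Swap(i, j)
	k.keys[i], k.keys[j] = k.keys[j], k.keys[i]
	k.have[i], k.have[j] = k.have[j], k.have[i]
}

// sortByKeys sorts data using the caching wrapper.
func sortByKeys(data Interface) { sort.Sort(newLazyKeysort(data)) }

// Client side: a hypothetical type and key function.
type HardToSort struct{ Weight int }

var calls int // counts TrueValue evaluations

func (h HardToSort) TrueValue() int {
	calls++
	return h.Weight * h.Weight
}

type ByTrueValue []HardToSort

func (b ByTrueValue) Len() int                        { return len(b) }
func (b ByTrueValue) Swap(i, j int)                   { b[i], b[j] = b[j], b[i] }
func (b ByTrueValue) Key(i int) (interface{}, error) { return b[i].TrueValue(), nil }

// LessVal only ever receives values produced by Key, so asserting int is safe.
func (b ByTrueValue) LessVal(x, y interface{}) bool { return x.(int) < y.(int) }

func main() {
	hs := ByTrueValue{{3}, {-5}, {1}, {4}, {2}}
	sortByKeys(hs)
	fmt.Println(hs, "calls:", calls) // [{1} {2} {3} {4} {-5}] calls: 5
}
```

Keeping the cache in the wrapper means the client type only supplies Key and LessVal; the price is one interface{} value per element and a type assertion in every LessVal call.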

Moar speed!

This solution works, and it gave me a noticeable speedup on my problem. However, I spent good money on the multiple cores in my computer, and they were not being used to compute the results of Key(). Since sort.Sort() runs all the comparisons within one goroutine, any parallelism in computing Key() has to happen at an earlier step. What I really wanted was an easy way for the library to pre-cache the key values before sorting begins.

I implemented this using PrimedKeysort, which is almost a drop-in replacement for Keysort.

sort.Sort(keysort.PrimedKeysort(ByTrueValue(hs), 0))

Calling this function kicks off a goroutine for each item to memoize the result of calling Key() on it, which lets you take advantage of whatever parallelism is available. When the key function is I/O bound, this may net you a significant increase in throughput. When it is computationally intensive, you may still get a speedup of up to roughly GOMAXPROCS times.

There's no online demo for this, because I don't want to run a slow version of TrueValue() on the playground. However, the example file in my repo shows how much of a speedup it may bring you, depending on the kind of key function you provide.
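A minimal sketch of the priming idea follows, under stated assumptions: prime, primed, sortPrimed, and the Slow client are hypothetical; this version waits for every key before sort.Sort starts; and it does not model the second argument passed to PrimedKeysort. Like the version described here, it spawns one goroutine per element and ignores errors.

```go
package main

import (
	"fmt"
	"sort"
	"sync"
	"time"
)

// Interface mirrors the keysort.Interface described in this article.
type Interface interface {
	LessVal(i, j interface{}) bool
	Key(i int) (interface{}, error)
	Swap(i, j int)
	Len() int
}

// primed is a hypothetical sort.Interface whose keys are all computed
// before sorting begins.
type primed struct {
	data Interface
	keys []interface{}
}

// prime calls Key for every element in its own goroutine and waits for
// all of them. Each goroutine writes a distinct slice element, so no
// lock is needed; Key itself must be safe to call concurrently.
func prime(data Interface) *primed {
	keys := make([]interface{}, data.Len())
	var wg sync.WaitGroup
	for i := range keys {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			keys[i], _ = data.Key(i) // errors ignored, as in PrimedKeysort
		}(i)
	}
	wg.Wait()
	return &primed{data: data, keys: keys}
}

func (p *primed) Len() int           { return p.data.Len() }
func (p *primed) Less(i, j int) bool { return p.data.LessVal(p.keys[i], p.keys[j]) }
func (p *primed) Swap(i, j int) {
	p.data.Swap(i, j)
	p.keys[i], p.keys[j] = p.keys[j], p.keys[i]
}

// sortPrimed primes every key concurrently, then sorts sequentially.
func sortPrimed(data Interface) { sort.Sort(prime(data)) }

// Slow is a hypothetical client whose Key simulates an I/O-bound call.
type Slow []int

func (s Slow) Len() int      { return len(s) }
func (s Slow) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
func (s Slow) Key(i int) (interface{}, error) {
	time.Sleep(50 * time.Millisecond) // pretend this is a network request
	return -s[i], nil                 // negated, so the sort is descending
}
func (s Slow) LessVal(x, y interface{}) bool { return x.(int) < y.(int) }

func main() {
	s := Slow{5, 2, 9, 1, 7}
	start := time.Now()
	sortPrimed(s)
	// The five sleeps overlap, so this should take roughly one sleep, not five.
	fmt.Println(s, time.Since(start))
}
```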

Currently, this solution is pretty simplistic, and there are two avenues for improvement that I'm already designing for:

  1. it spawns a goroutine for every item and doesn't allow the caller to control how many such goroutines can run at a time, and
  2. it ignores any errors that may crop up on executing the key function.

The latter is especially important, because if the key function fails for an element, that element's key is silently left at its zero value and sorted as though it were real. Imagine trying to sort URLs by the length of their page titles. The key function may need to fire an HTTP request for each page, and that request may fail. Although no one would usually sort URLs by title length, the point stands: key functions can be complex, and they can fail.

PrimedKeysort needs a way to communicate to its caller that it was unsuccessful, and that the result of the sort is bogus. I have a couple of ideas on how I want to approach this, and I'm eager to hear of other approaches.
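One possible approach to both open issues is sketched below, under assumptions: sortBounded, primed, and the Checked client are hypothetical and are not keysort's API. A fixed number of worker goroutines drain a channel of indexes, which caps concurrency, and any Key failure is returned to the caller instead of sorting with zero-value keys. Error wrapping with %w needs Go 1.13 or later.

```go
package main

import (
	"errors"
	"fmt"
	"sort"
	"sync"
)

// Interface mirrors the keysort.Interface described in this article.
type Interface interface {
	LessVal(i, j interface{}) bool
	Key(i int) (interface{}, error)
	Swap(i, j int)
	Len() int
}

// primed holds precomputed keys alongside the data (hypothetical).
type primed struct {
	data Interface
	keys []interface{}
}

func (p *primed) Len() int           { return p.data.Len() }
func (p *primed) Less(i, j int) bool { return p.data.LessVal(p.keys[i], p.keys[j]) }
func (p *primed) Swap(i, j int) {
	p.data.Swap(i, j)
	p.keys[i], p.keys[j] = p.keys[j], p.keys[i]
}

// sortBounded computes keys with at most `workers` goroutines. If any Key
// call fails it returns an error and leaves data untouched, rather than
// sorting with zero-value keys.
func sortBounded(data Interface, workers int) error {
	if workers < 1 {
		workers = 1
	}
	keys := make([]interface{}, data.Len())
	indexes := make(chan int)
	var (
		wg       sync.WaitGroup
		once     sync.Once
		firstErr error // first failure recorded, not necessarily the lowest index
	)
	for w := 0; w < workers; w++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for i := range indexes {
				k, err := data.Key(i)
				if err != nil {
					once.Do(func() { firstErr = fmt.Errorf("key %d: %w", i, err) })
					continue
				}
				keys[i] = k
			}
		}()
	}
	for i := range keys {
		indexes <- i
	}
	close(indexes)
	wg.Wait()
	if firstErr != nil {
		return firstErr
	}
	sort.Sort(&primed{data: data, keys: keys})
	return nil
}

// Checked is a hypothetical client whose Key fails on negative values,
// standing in for, say, an HTTP request that can fail.
type Checked []int

var errNegative = errors.New("negative input")

func (c Checked) Len() int      { return len(c) }
func (c Checked) Swap(i, j int) { c[i], c[j] = c[j], c[i] }
func (c Checked) Key(i int) (interface{}, error) {
	if c[i] < 0 {
		return nil, errNegative
	}
	return c[i], nil
}
func (c Checked) LessVal(x, y interface{}) bool { return x.(int) < y.(int) }

func main() {
	good := Checked{4, 1, 3}
	fmt.Println(sortBounded(good, 2), good) // <nil> [1 3 4]

	bad := Checked{4, -1, 3}
	err := sortBounded(bad, 2)
	fmt.Println(errors.Is(err, errNegative), bad) // true [4 -1 3]
}
```

Returning the error before sorting keeps the caller's slice in its original order, so a failed key can never produce a silently bogus result.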

With those caveats, my code for keysort is live on GitHub. Feel free to take a look and use it if it makes your life better. Pull requests are always encouraged!

I'd love to hear your comments on Hacker News.


  1. Which I may write about later.
  2. All the examples can also be run via this file in my github repo.
  3. I know, I know: Wikipedia says that the term Schwartzian transform applies to that specific idiom, and not to the algorithm in general. I think I can still use it for three reasons. One, the next best name that could apply, decorate-sort-undecorate, doesn't exactly roll off the tongue. Two, when used the way it is in the examples, it's pretty much the Go version of map/sort/map, and even though there is a named array created, it is private to a temporarily created struct, so it may as well be anonymous. Three, terminology and language are ever evolving.