Computing function return type at compile-time

I’d like to write a function that takes 2 arrays of Numbers and returns an array of the smallest common type they both can be coerced to.

For example, if Array(Int32) and Array(Int16) are passed in, then it should return Array(Int32), because every Int16 value can be safely converted to Int32.

So I need to create the output Array with an element type computed from T and U, i.e. write something like:

def merge(a : Array(T), b : Array(U)) forall T, U
  res = Array(???).new()

What should the ??? be? Or is this not possible right now?

I’ve tried to write a macro commonType and use it like this:

  res = Array(commonType(T,U)).new()

but could not make it work.
My attempts to use conditional compilation haven’t succeeded so far either.


There are (at least) two paths if you definitely want to handle that scenario without too much typing:

  1. Use macros to generate all the variants.
  2. Use some typeof tricks to be able to express what you want.

For the first path, something like the following could help: build the data needed to expand the code, and let the macro generate one overload per entry.

{% for d in [
              {res: Int32, left: Int32, right: Int16, left_method: :itself, right_method: :to_i32},
              {res: Int32, left: Int16, right: Int32, left_method: :to_i32, right_method: :itself},
              {res: Int16, left: Int8, right: Int16, left_method: :to_i16, right_method: :itself},
            ] %}
def merge(a : Array({{d[:left]}}), b : Array({{d[:right]}}))
  res = Array({{d[:res]}}).new(a.size + b.size)
  a.each { |e| res << e.{{d[:left_method].id}} }
  b.each { |e| res << e.{{d[:right_method].id}} }
  res
end
{% end %}
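A possible variation on the same idea, as a sketch: each entry lists an unordered pair of element types plus the result type, and an inner macro loop emits both argument orders, so (Int32, Int16) and (Int16, Int32) don't have to be written out separately. It assumes that converting with the result type's own `.new(value)` constructor (e.g. `Int32.new(e)`) is acceptable instead of naming a conversion method per side; the `a`, `b` and `res` keys are illustrative.

```crystal
# Each entry: an unordered pair {a, b} and the type both should be widened to.
# The inner loop emits merge(Array(a), Array(b)) and merge(Array(b), Array(a)).
{% for d in [
              {a: Int32, b: Int16, res: Int32},
              {a: Int16, b: Int8, res: Int16},
            ] %}
  {% for pair in [[d[:a], d[:b]], [d[:b], d[:a]]] %}
    def merge(a : Array({{pair[0]}}), b : Array({{pair[1]}}))
      res = Array({{d[:res]}}).new(a.size + b.size)
      # Int32.new(x) / Int16.new(x) convert any integer to that type
      a.each { |e| res << {{d[:res]}}.new(e) }
      b.each { |e| res << {{d[:res]}}.new(e) }
      res
    end
  {% end %}
{% end %}

merged = merge([1_i16, 2_i16], [3, 4])
p merged         # => [1, 2, 3, 4]
p typeof(merged) # => Array(Int32)
```

The trade-off is that the table stays half as long, at the cost of always converting through `.new` rather than skipping the conversion with `itself` when a side already has the result type.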

For the second path, we first need something that makes the following work:

typeof(join(Int32, Int16)) # => Int32
typeof(join(Int16, Int32)) # => Int32
typeof(join(Int8, Int16)) # => Int16

We could use the same macro technique as before, but here we simply write out each case directly:

def join(a : Int32.class, b : Int16.class)
  0i32
end

def join(a : Int16.class, b : Int32.class)
  0i32
end

def join(a : Int8.class, b : Int16.class)
  0i16
end

# ... etc ...

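We are never going to actually call join at runtime: it only appears inside typeof, which asks the compiler for the type of an expression without evaluating it. That is why only the return types of these overloads matter, not the values.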

NB: I call it join instead of commonType because "common" sounds more like a meet (see Join & Meet), whereas what we want here is the least upper bound of T and U, i.e. their join.
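To check that `join` really never runs, here is a small sketch (the `JoinCalls` counter is hypothetical, added only for this demonstration): each overload bumps a counter, yet after asking for `typeof(join(...))` the counter is still zero, because `typeof` is resolved entirely at compile time.

```crystal
# Hypothetical counter, used only to show that `join` is never executed.
class JoinCalls
  class_property count : Int32 = 0
end

def join(a : Int32.class, b : Int16.class)
  JoinCalls.count += 1 # would only run if `join` were actually called
  0_i32
end

def join(a : Int8.class, b : Int16.class)
  JoinCalls.count += 1
  0_i16
end

# typeof computes the static type of the call without evaluating it.
t = typeof(join(Int32, Int16))
p t               # => Int32
p JoinCalls.count # => 0
```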

We will also need a way to convert each element to that type, for example:

struct Int
  def to_num(t : Int16.class)
    self.to_i16
  end

  def to_num(t : Int32.class)
    self.to_i32
  end
end
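If writing one `to_num` overload per target type gets tedious, a single generic method could do the same job. This is a sketch assuming every target type has a `.new(value)` constructor that converts from other integers, as Crystal's integer types do; the name `convert_to` is hypothetical and chosen so it doesn't clash with `to_num`.

```crystal
struct Int
  # Hypothetical generic alternative to the per-type `to_num` overloads:
  # `t` is a metaclass such as Int32, and T.new(self) converts to it.
  def convert_to(t : T.class) forall T
    T.new(self)
  end
end

x = 5_i8.convert_to(Int32)
p x         # => 5
p typeof(x) # => Int32
```

As with `to_i16`/`to_i32`, the conversion is checked in current Crystal versions: a value that does not fit in the target type raises `OverflowError`.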

And finally

def merge2(a : Array(T), b : Array(U)) forall T, U
  res = Array(typeof(join(T, U))).new(a.size + b.size)
  a.each { |e| res << e.to_num(typeof(join(T, U))) }
  b.each { |e| res << e.to_num(typeof(join(T, U))) }
  res
end
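Putting the pieces together, here is a minimal runnable sketch with just three `join` cases. It swaps the `to_num` overloads for the target type's own `.new(value)` constructor (an assumption that keeps the example short; the same trick appears later in the thread).

```crystal
# Only the return *types* of these overloads matter; they are used via typeof.
def join(a : Int32.class, b : Int16.class)
  0_i32
end

def join(a : Int16.class, b : Int32.class)
  0_i32
end

def join(a : Int8.class, b : Int16.class)
  0_i16
end

def merge2(a : Array(T), b : Array(U)) forall T, U
  # typeof(join(T, U)) is the widened element type, e.g. Int32 for (Int32, Int16)
  res = Array(typeof(join(T, U))).new(a.size + b.size)
  a.each { |e| res << typeof(join(T, U)).new(e) }
  b.each { |e| res << typeof(join(T, U)).new(e) }
  res
end

m = merge2([1_i8, 2_i8], [300_i16])
p m         # => [1, 2, 300]
p typeof(m) # => Array(Int16)
```

Calling `merge2` with a pair that has no `join` overload, such as `(Int8, Int32)`, fails at compile time rather than at runtime.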


But when modeling real-world things, these kinds of unions are hardly ever needed, IMO. Still, it is a good exercise with types.


Brian, thanks a lot for the quick response!
I definitely like approach #2 better, and I made it work, but I don’t understand why it works, so I have a few more questions:

  1. What is the type of the T.class expression? If the type of both Int16.class and Int32.class is just Class, then how does method overloading work for lct?
  2. It looks like the order of the lct definitions does not matter. Then why are the more specific overloads preferred over the Commutativity rule?
  3. How is it that the Commutativity rule does not cause infinite recursion during compilation if the opposite rule is not defined? Do you have a recursion limit during resolution? How big?

Thanks again! Crystal has a truly unique type system that I enjoy exploring.

The code:

# lct stands for Least Common Type
def lct(a :  Int16.class, b : UInt16.class) 0_i32 end
def lct(a :  Int16.class, b :  Int8 .class) 0_i16 end
def lct(a :  Int16.class, b : UInt8 .class) 0_i16 end
def lct(a : UInt16.class, b :  Int8 .class) 0_u16 end
def lct(a : UInt16.class, b : UInt8 .class) 0_u16 end
# ... similar for all other numeric types

# Identity
def lct(a : T.class, b : T.class) forall T T.new(0) end
# Commutativity
def lct(a : T.class, b : U.class) forall T, U lct(b, a) end

def merge_sorted(a : Iterable(T), b : Iterable(U)) forall T, U
  res = Array(typeof(lct(T,U))).new()

  ia = a.each
  ib = b.each

  ea = ia.next.as?(T)
  eb = ib.next.as?(U)
  loop do
    if ea && (!eb || ea <= eb)
      res << typeof(lct(T,U)).new(ea)
      ea = ia.next.as?(T)
    elsif eb
      res << typeof(lct(T,U)).new(eb)
      eb = ib.next.as?(U)
    else
      return res
    end
  end
end

About question 1: Int32.class is the metaclass of Int32. When a method has a restriction m(a : Int32), that overload is used when the argument is of type Int32. The same rule applies when the restriction is Int32.class, but the only value of type Int32.class is Int32 itself. Int16.class and Int32.class are distinct types, not just Class, so it’s more like defining a function case by case.
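A small sketch of that point (the `width` method is made up for illustration): `Int16` and `Int32`, used as values, have the distinct types `Int16.class` and `Int32.class`, so overloads restricted to those metaclasses are selected just like overloads on ordinary types.

```crystal
# Each overload accepts exactly one value: the type named in its restriction.
def width(t : Int16.class)
  16
end

def width(t : Int32.class)
  32
end

# A type can be stored in a variable like any other value;
# the static type of `klass` here is the metaclass Int16.class.
klass = Int16
p width(klass) # => 16
p width(Int32) # => 32
```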

About question 2: Because we put a lot of love into the compiler. But also because, due to how method dispatch and overloads work in Crystal, a later definition does not hide an earlier one: overloads are ordered by how specific their restrictions are, not by the order in which they appear in the source. To actually redefine a method, you need to match exactly the same restrictions.
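To illustrate the point about ordering, here is a sketch (the `kind` method is hypothetical): the generic overload is defined first, yet the more specific `Int16.class` overload is still chosen for `Int16`, because it is ranked by its restriction rather than by its position in the file.

```crystal
# Generic overload, defined first: matches any type.
def kind(t : T.class) forall T
  "generic"
end

# More specific overload, defined later: still preferred for Int16.
def kind(t : Int16.class)
  "specific"
end

p kind(Int16) # => "specific"
p kind(Int32) # => "generic"
```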

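About question 3: Because of the same fact I mentioned in 2, method lookup picks the more specific definitions first. The Commutativity rule only applies when no specific overload matches; once it swaps the arguments, the specific overload for the reversed pair matches, so the recursion stops after a single step.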
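A sketch of that situation, using a hypothetical `lub` with a single specific rule plus the commutativity rule: `lub(UInt8, Int16)` has no direct overload, so the generic rule swaps the arguments, and the swapped call resolves to the specific `(Int16, UInt8)` overload, which ends the chain.

```crystal
# The one specific rule: Int16 can hold every UInt8 value.
def lub(a : Int16.class, b : UInt8.class)
  0_i16
end

# Commutativity: if no specific rule matches (a, b), try (b, a).
def lub(a : T.class, b : U.class) forall T, U
  lub(b, a)
end

p typeof(lub(Int16, UInt8)) # => Int16 (direct match)
p typeof(lub(UInt8, Int16)) # => Int16 (one swap, then direct match)
```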
