Friday, April 26, 2013

Operator overloading in Rust

With a bonus short introduction to Rust macros!

This post was updated on 17 April 2015 to cover Rust 1.0.0-beta. The comments as of this date are out of date, but I left them there to be more confusing. Thanks, and watch out for the rooster.

Rust offers several features that may make frustrated C++ programmers feel more at home. One of those features is operator overloading, for at least some common operators. Operator overloading in Rust is handled through traits: the Rust manual describes overloading by saying, "[Arithmetic] expressions are syntactic sugar for calls to built-in traits, defined in the std::ops module of the std library. This means that arithmetic operators can be overridden for user-defined types." (Technically, the description applies only to binary arithmetic operators, but similar traits exist for the three unary operators: arithmetic negation, logical negation, and pointer dereference. Also, it may not be immediately apparent that array indexing and the bit-wise operators are "arithmetic". Finally, the comparison operators have corresponding traits in std::cmp.)

So, how does one use these magic traits?

Let's say you want to use complex numbers in your code, because you are the kind of person who knows what complex numbers are good for and can use them in a safe and appropriately sanitary fashion. Complex numbers, for those of us who do not know what they are good for and probably cannot use a spatula in a safe and sanitary manner, are an extension of more commonly seen sets of numbers such as the integers and the reals; each one pairs a real number with an imaginary number. An imaginary number, to back up a bit, is a real multiple of a number \(i\) whose square is \(-1\), so a complex number is expressible as \(a + bi\), where \(a\) is the real component and \(b\) is the imaginary component, multiplied by \(i\), the imaginary unit. The complex number system provides at least one root for every non-constant polynomial, in much the same way that the real numbers, unlike the integers, provide a result for every division by a nonzero number. Or at least that is the impression that Wikipedia gives me. Thank you, great Wiki!

In any case, this is what a complex number looks like in Rust, at least according to me.

#[derive(Debug,Copy,Clone)]
pub struct Complex {
    r : f64,
    j : f64
}

(In the code, \(i\), the traditional notation for the imaginary unit, is replaced by j at the suggestion of englabenny and dfjkfskjhfshdfjhsdjl on reddit, because "1i64" is the Rust notation for 1 as a 64-bit signed integer and because "i" denotes current in various disciplines that commonly use complex numbers, such as electrical engineering.)

This is a structure containing two fields, the real component and the imaginary component. f64 is Rust's 64-bit floating point type, and f32 is the 32-bit one; both have the same size on every machine.

This structure uses Rust's ability to automatically derive implementations of some useful traits. std::fmt::Debug formats the structure as a string (seen below) with the {:?} formatting option. std::marker::Copy is a marker trait (it has no methods) indicating that a value can be copied "by simply copying bits (i.e. memcpy)"; in this case, a Complex acts like a primitive number type. std::clone::Clone provides the clone method for explicitly copying a value; Copy requires Clone, which is why both are derived.
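Deriving Copy matters for what follows: the operator methods later in this post take their operands by value, and for a Copy type that copies the value instead of moving it, so the original binding stays usable afterward. A small sketch, using a hypothetical helper named sum_parts that is not part of the post's code:

```rust
// Sketch: what deriving Debug, Copy, and Clone buys for Complex.
#[derive(Debug, Copy, Clone)]
pub struct Complex {
    r: f64,
    j: f64,
}

// Hypothetical helper that takes its argument by value, like the operator methods do.
fn sum_parts(c: Complex) -> f64 {
    c.r + c.j
}

fn main() {
    let c = Complex { r: 1.5, j: -2.0 };

    // Both calls receive a bitwise copy of `c`; without Copy, the first
    // call would move `c` and the second call would not compile.
    let a = sum_parts(c);
    let b = sum_parts(c);

    // {:?} uses the derived Debug implementation, which is still usable on `c`.
    println!("{} {} {:?}", a, b, c);
}
```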

One thing that is not automatically derivable is the ability to easily create a Complex number from, say, a 64-bit floating point number. Adding that ability is basic Rust.

trait ToComplex { fn to_complex(&self) -> Complex; }

impl ToComplex for f64 {
    fn to_complex(&self) -> Complex { Complex { r : *self, j : 0.0f64 } }
}
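The standard library's std::convert::From trait is the conventional way to express this kind of conversion in Rust 1.0 and later, and implementing it also provides the matching Into conversion automatically. A sketch of that alternative, assuming the same Complex struct (this is a swapped-in technique, not the post's ToComplex trait):

```rust
// Sketch: converting f64 to Complex with the standard From trait,
// as an alternative to a hand-written ToComplex trait.
// PartialEq is derived here only so values can be compared.
#[derive(Debug, Copy, Clone, PartialEq)]
pub struct Complex {
    r: f64,
    j: f64,
}

impl From<f64> for Complex {
    // A real number becomes a complex number with a zero imaginary part.
    fn from(r: f64) -> Complex {
        Complex { r, j: 0.0 }
    }
}

fn main() {
    let a = Complex::from(2.0f64);
    // The standard library's blanket Into impl comes along with From.
    let b: Complex = 3.5f64.into();
    println!("{:?} {:?}", a, b);
}
```

The explicit f64 suffixes keep the literals from being an ambiguous floating point type when a method is called on them.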

Algebra for complex numbers makes use of a conjugate operation, which negates the imaginary component. (Note the use of Rust's struct update syntax, "..", in this method: the fields not listed explicitly are copied from *self.) A further operation for the type is one way of converting a complex number into a real: by viewing the complex number as a vector on the 2-dimensional complex plane and returning its length, or magnitude.

impl Complex {
    fn conjugate(&self) -> Complex { Complex { j : -self.j, .. *self } }
    fn magnitude(&self) -> f64 { ( self.r * self.r + self.j * self.j ).sqrt() }
}
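Written out for \(z = a + bj\), the two methods compute the following. The last identity, that a number multiplied by its conjugate is the real number \(|z|^2\), is what makes the division below work.

\[ \bar{z} = a - bj, \qquad |z| = \sqrt{a^2 + b^2}, \qquad z\,\bar{z} = (a + bj)(a - bj) = a^2 + b^2 = |z|^2 \]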

The final preliminary is to provide an implementation of the std::fmt::Display trait, displaying the complex value as a string.

use std::fmt::{self, Display, Formatter};

impl Display for Complex {
    fn fmt(&self, formatter : &mut Formatter) -> fmt::Result {
        write!(formatter, "{} + {}j", self.r, self.j)
    }
}
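One wrinkle with this format is that a negative imaginary component prints as, for example, "1 + -2j". A possible refinement, sketched here under the assumption that the sign should be folded into the operator, checks the sign of j before writing:

```rust
use std::fmt::{self, Display, Formatter};

#[derive(Debug, Copy, Clone)]
pub struct Complex {
    r: f64,
    j: f64,
}

impl Display for Complex {
    fn fmt(&self, formatter: &mut Formatter<'_>) -> fmt::Result {
        // Print "a - bj" rather than "a + -bj" when the imaginary part is negative.
        if self.j < 0.0 {
            write!(formatter, "{} - {}j", self.r, -self.j)
        } else {
            write!(formatter, "{} + {}j", self.r, self.j)
        }
    }
}

fn main() {
    println!("{}", Complex { r: 1.0, j: -2.0 });
    println!("{}", Complex { r: 4.0, j: 0.0 });
}
```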

To overload the operators + and * for complex numbers, just provide an implementation of the Add and Mul traits:

use std::ops::{Add, Mul};

impl Add<Complex> for Complex {
    type Output = Complex;
    fn add(self, rhs : Complex) -> Complex {
        Complex { r : self.r + rhs.r, j : self.j + rhs.j }
    }
}

impl Mul<Complex> for Complex {
    type Output = Complex;
    fn mul(self, rhs : Complex) -> Complex {
        Complex {
            r : (self.r * rhs.r) - (self.j * rhs.j),
            j : (self.r * rhs.j) + (self.j * rhs.r)
        }
    }
}
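The Mul implementation is the ordinary product of two binomials with \(j^2 = -1\) substituted. Writing self as \(a + bj\) and rhs as \(c + dj\):

\[ (a + bj)(c + dj) = ac + adj + bcj + bd\,j^2 = (ac - bd) + (ad + bc)\,j \]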

The Add trait contains one method, add, which performs the operation, and an associated type, Output, which describes the result of the operation. To unpack the types involved: implementing Add<X> for Z, with an Output type of Y, for some types X, Y, and Z, provides an implementation of + whose left-hand side is a Z (the receiver of the method and the type for which the trait is implemented), whose right-hand side is an X (the argument to the method), and whose result is a Y. Implementing these two traits allows two complex numbers to be added or multiplied, producing a new complex number.
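To make the three roles concrete, here is a sketch with hypothetical types Meters and SquareMeters (not part of the post's code) in which Z and X are both Meters and Y is SquareMeters. Mul has the same shape as Add, and the sketch also shows that the operator is sugar for calling the trait method directly.

```rust
use std::ops::Mul;

// Hypothetical unit types, used only to illustrate the trait's type roles.
#[derive(Debug, Copy, Clone, PartialEq)]
pub struct Meters(f64);

#[derive(Debug, Copy, Clone, PartialEq)]
pub struct SquareMeters(f64);

// Left-hand side (Self) is Meters, right-hand side (the type parameter) is
// Meters, and the result (Output) is SquareMeters.
impl Mul<Meters> for Meters {
    type Output = SquareMeters;
    fn mul(self, rhs: Meters) -> SquareMeters {
        SquareMeters(self.0 * rhs.0)
    }
}

fn main() {
    let width = Meters(2.0);
    let height = Meters(3.0);

    // The operator form...
    let area = width * height;
    // ...is the same call as the trait method, written out in full.
    let same_area = <Meters as Mul<Meters>>::mul(width, height);

    println!("{:?} {:?}", area, same_area);
}
```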

The implementation of division is similar to Add and Mul, but also illustrates the use of the overloaded operators previously defined.

use std::ops::Div;

impl Div<Complex> for Complex {
    type Output = Complex;
    fn div(self, rhs : Complex) -> Complex {
        let rhs_conj = rhs.conjugate();
        let num = self * rhs_conj;
        let den = rhs * rhs_conj;
        Complex { r : num.r / den.r, j : num.j / den.r }
    }
}

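In the code, rhs_conj, num, and den are all Complex, and computing num and den uses the Mul implementation above. Multiplying a number by its own conjugate produces a real number, so den.j is zero and only den.r is needed for the final division.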
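The same computation as a formula, with \(a + bj\) as the dividend and \(c + dj\) as the divisor: multiplying the top and bottom by the conjugate \(c - dj\) makes the denominator real. The formula assumes \(c^2 + d^2 \neq 0\); the code above does not check for a zero divisor, so dividing by \(0 + 0j\) yields NaN components (0/0 in floating point).

\[ \frac{a + bj}{c + dj} = \frac{(a + bj)(c - dj)}{(c + dj)(c - dj)} = \frac{(ac + bd) + (bc - ad)\,j}{c^2 + d^2} \]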

Another thing I would like to be able to do is to provide multiple implementations of, say, the Add trait, for other types:

impl Add<f64> for Complex {
    type Output = Complex;
    fn add(self, rhs : f64) -> Complex {
        Complex { r : self.r + rhs, j : self.j }
    }
}

impl Add<Complex> for f64 {
    type Output = Complex;
    fn add(self, rhs : Complex) -> Complex {
        Complex { r : self + rhs.r, j : rhs.j }
    }
}

Following the discussion above, the first implementation is used when a Complex number is on the left-hand side of the + and an f64 floating point number is on the right. The second implementation is used when an f64 is on the left-hand side of the + and a Complex number is on the right. Both implementations return Complex numbers. Whereas previous versions of Rust had problems with this sort of convenience overloading, Rust as of 1.0.0-beta allows these to work just as you would expect.
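A self-contained sketch collecting the three Add implementations. One detail worth knowing: Add's type parameter defaults to Self (the trait is declared as Add<Rhs = Self>), so impl Add for Complex below means the same thing as impl Add<Complex> for Complex.

```rust
use std::ops::Add;

// PartialEq is derived here only so results can be compared.
#[derive(Debug, Copy, Clone, PartialEq)]
pub struct Complex {
    r: f64,
    j: f64,
}

// Complex + Complex (the type parameter defaults to Self).
impl Add for Complex {
    type Output = Complex;
    fn add(self, rhs: Complex) -> Complex {
        Complex { r: self.r + rhs.r, j: self.j + rhs.j }
    }
}

// Complex + f64: the real number only affects the real component.
impl Add<f64> for Complex {
    type Output = Complex;
    fn add(self, rhs: f64) -> Complex {
        Complex { r: self.r + rhs, j: self.j }
    }
}

// f64 + Complex: permitted because Complex is defined in this crate.
impl Add<Complex> for f64 {
    type Output = Complex;
    fn add(self, rhs: Complex) -> Complex {
        Complex { r: self + rhs.r, j: rhs.j }
    }
}

fn main() {
    let y = Complex { r: 3.0, j: 1.0 };
    // The compiler picks an implementation from the operand types.
    println!("{:?}", y + y);
    println!("{:?}", y + 3.0);
    println!("{:?}", 3.0 + y);
}
```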

Unfortunately, there is a cloud attached to this silver lining: there are many types and many operations to overload (see below). Having explicit implementations of every trait for every combination of types would be maddening, especially since they are all very, very similar. This is where macros can be very useful.

Rust supports a hygienic, definition-by-example-ish macro system using macro_rules!. I cannot go into the entire system here (not least because I have not explored it deeply), but the following definition provides a template implementation of the Mul trait for a given type argument:

macro_rules! scalar_impl (
    ($foo:ty) => (

        // Implementation of multiplication for Complex and $foo
        impl Mul<$foo> for Complex {
            type Output = Complex;
            fn mul(self, rhs : $foo) -> Complex {
                Complex { r : self.r * (rhs as f64), j : self.j * (rhs as f64) }
            }
        }
        impl Mul<Complex> for $foo {
            type Output = Complex;
            fn mul(self, rhs : Complex) -> Complex {
                Complex { r : (self as f64) * rhs.r, j : (self as f64) * rhs.j }
            }
        }
    )
);

This code defines a macro, scalar_impl! (the exclamation point is a necessary part of the macro's invocation). The definition has one rule, meaning it supports one form of invocation, in which the macro is given a type argument (the fragment specifier ty in $foo:ty says that $foo must be a type). The expansion of this rule supplies the two complementary implementations of Mul involving the type $foo. The first is used when a Complex number is the left-hand side and a number of the type represented by $foo is the right-hand side; the second is used when a number of type $foo is the left-hand side and a Complex number is the right-hand side.

Note that there are no conditions placed on the $foo type; the macro itself is not type checked but the expansion will be. In this case, the requirement is that the type $foo be convertible to a 64-bit floating point number (in the first implementation, (rhs as f64); in the second, (self as f64).) The macro is used as:

scalar_impl!(i8);
scalar_impl!(i16);
scalar_impl!(i32);
scalar_impl!(i64);
scalar_impl!(isize);
scalar_impl!(u8);
scalar_impl!(u16);
scalar_impl!(u32);
scalar_impl!(u64);
scalar_impl!(usize);
scalar_impl!(f64);
scalar_impl!(f32);

This list of invocations supplies implementations of the Mul trait for all of Rust's primitive numeric types.
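macro_rules! also supports repetition, which can fold those twelve invocations into one. A sketch, assuming a hypothetical macro named scalar_impls! that accepts a comma-separated list of types and expands the same pair of implementations for each one:

```rust
use std::ops::Mul;

// PartialEq is derived here only so results can be compared.
#[derive(Debug, Copy, Clone, PartialEq)]
pub struct Complex {
    r: f64,
    j: f64,
}

// Hypothetical variant of scalar_impl! taking a comma-separated list of types;
// the $( ... )* block in the expansion repeats once per type in the list.
macro_rules! scalar_impls {
    ($($foo:ty),*) => {
        $(
            impl Mul<$foo> for Complex {
                type Output = Complex;
                fn mul(self, rhs: $foo) -> Complex {
                    Complex { r: self.r * (rhs as f64), j: self.j * (rhs as f64) }
                }
            }
            impl Mul<Complex> for $foo {
                type Output = Complex;
                fn mul(self, rhs: Complex) -> Complex {
                    Complex { r: (self as f64) * rhs.r, j: (self as f64) * rhs.j }
                }
            }
        )*
    };
}

scalar_impls!(i8, i16, i32, i64, isize, u8, u16, u32, u64, usize, f32, f64);

fn main() {
    let y = Complex { r: 3.0, j: 0.0 };
    println!("{:?} {:?} {:?}", y * 3isize, 4u8 * y, y * 0.5f32);
}
```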

How are these definitions used? Here are some complete, albeit useless, examples.

    let w = 2.0f64.to_complex();

    let x = Complex { r : 1.0, j : 0.0 };
    let y = Complex { r : 3.0, j : 0.0 };
    let z = x + y;

    println!("  z: {:?}", z);
    // =>   z: Complex { r: 4, j: 0 }

This first example shows the automatically derived debugging format, which includes the structure and field names.

    println!("{}", ( z / w                  ));
    // => 2 + 0j

    println!("{}", ( y + 3.0                ));
    // => 6 + 0j

    println!("{}", ( 3.0 + y                ));
    // => 6 + 0j

    println!("{}", ( y * 3isize             ));
    // => 9 + 0j

    println!("{}", ( y * 3.0f64             ));
    // => 9 + 0j

    println!("{}", ( 4u8 * y                ));
    // => 12 + 0j

These examples demonstrate the basic arithmetic operations, and their use with various numeric types. The final examples show some of the Complex number's party tricks.

    let n = Complex { r : 0.0, j : 1.0 };
    println!("{}", (  n * n                 ));
    // => -1 + 0j

    println!("{}", ( (n * n) * 2            ));
    // => -2 + 0j

    let mu : Complex = (n * n) * 2;
    println!("{}", mu.magnitude() );
    // => 2

The first shows that \((0+i)^2 = -1\); the second, that \((0+i)^2 \times 2 = -2\); and the last, that the magnitude of \(-2+0i\) is 2.

Isn't that just lovely, hmm?

The operators which can be overloaded, as of Rust 1.0.0-beta, are:

Operator                               Trait

Arithmetic
+                                      std::ops::Add
-                                      std::ops::Sub
*                                      std::ops::Mul
/                                      std::ops::Div
%                                      std::ops::Rem
- (unary negation)                     std::ops::Neg

Bitwise
&                                      std::ops::BitAnd
|                                      std::ops::BitOr
^ (exclusive or)                       std::ops::BitXor
<< (shift left)                        std::ops::Shl
>> (shift right)                       std::ops::Shr

Miscellaneous
! (logical or bitwise negation)        std::ops::Not
a[i] (indexing, immutable context)     std::ops::Index
a[i] (indexing, mutable context)       std::ops::IndexMut
*v (dereference, immutable context)    std::ops::Deref
*v (dereference, mutable context)      std::ops::DerefMut

Comparison
==                                     std::cmp::PartialEq
!=                                     std::cmp::PartialEq
<                                      std::cmp::PartialOrd
>                                      std::cmp::PartialOrd
<=                                     std::cmp::PartialOrd
>=                                     std::cmp::PartialOrd

The comparison operators call methods of std::cmp::PartialEq and std::cmp::PartialOrd; std::cmp::Eq and std::cmp::Ord are the related traits for types whose equality and ordering are total.

The two modules involved are std::ops and std::cmp.
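As a sketch of two more rows from that table, the following adds unary negation through std::ops::Neg and equality through a derived std::cmp::PartialEq, which supplies both == and !=. It derives PartialEq rather than Eq because f64 itself is only PartialEq (NaN is not equal to itself), and it deliberately does not implement PartialOrd: complex numbers have no ordering compatible with their arithmetic, so < and > are better left undefined for this type.

```rust
use std::ops::Neg;

// Deriving PartialEq makes == and != compare the two fields.
#[derive(Debug, Copy, Clone, PartialEq)]
pub struct Complex {
    r: f64,
    j: f64,
}

// Unary minus: negate both components.
impl Neg for Complex {
    type Output = Complex;
    fn neg(self) -> Complex {
        Complex { r: -self.r, j: -self.j }
    }
}

fn main() {
    let z = Complex { r: 1.0, j: -2.0 };
    let minus_z = -z;
    println!("{:?} {}", minus_z, minus_z != z);
}
```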

Once upon a time, in order to preserve sanity, Rust limited trait implementation. As kibwen pointed out in the reddit discussion, there were restrictions on where traits, types, and the implementations of traits for types could legitimately appear. Specifically, the implementation had to be in the same crate as either the type or the trait. kibwen writes,

Now note that the overloading traits are defined in libcore, which is shipped with the Rust compiler.

The implication then is that it is only possible to overload operators on types that you've defined yourself. You never have to worry about library A attempting distant overloads on types from library B; this also means that you never have to worry about libraries changing what 2+2 means.

That restriction has since been relaxed (sometime in fall 2014, I think), as can be seen in the implementation of Add<Complex> for f64: neither Add nor f64 is defined in this crate, but Complex is. An implementation of a foreign trait such as Add must still involve a type defined in the current crate, so you still cannot change what 2+2 means. The resulting changes go by the name of multidispatch and conditional dispatch, and the simple crate-based coherence rules have been replaced by something more complex that can only be described by a blog post. Or, at least, I don't have a good description of it. It's got something to do with those traits' associated types. Here are some references:

I must note that operator overloading, even with Rust's limits on it, is not something to be used without considerable thought. If you misuse overloading, doing something like C++'s 'cout << "hello world"', a Rust developer will find you. And Fix. Your. Little. Red. Wagon.

The source code for these examples is on GitHub.

I would like to thank the commenters from Reddit, particularly englabenny, dfjkfskjhfshdfjhsdjl, and kibwen, for their help.

2 comments:

Alex said...

In the last code snippet, how does NumCast::from know that you want to get a Complex back?

Tommy McGuire said...

The compiler knows what types are expected in the expressions where NumCast::from is used. For example, in "x * NumCast::from(3.0f)", x is a Complex and the only thing the compiler knows how to multiply a Complex number by is another Complex number, so it looks for an implementation of NumCast for Complex.

The tricky one is "let w = NumCast::from(2);" because there is no type information provided for w in that line. However, the compiler also knows that w is used in the expression "z / w" where z is Complex (since z is the result of adding two Complex's), so w also has to be Complex.

If you just had the "let w = ..." line, without a use of w that pinned down its type or a type annotation, the compiler would throw an error.