Friday, April 26, 2013

Operator overloading in Rust

Rust offers several features that may make frustrated C++ programmers feel more at home. One of those features is operator overloading, for at least some common operators. The technique of overloading operators in Rust is handled through traits: the Rust manual describes overloading by saying, "[arithmetic operator] expressions are syntactic sugar for calls to built-in traits, defined in the core::ops module of the core library. This means that arithmetic operators can be overridden for user-defined types." So, how does one use these built-in traits?

Let's say you want to use complex numbers in your code, because you are the kind of person who knows what complex numbers are good for and can use them in a safe and appropriately sanitary fashion. Complex numbers, for those of us who do not know what they are good for and probably cannot use a spatula in a safe and sanitary manner, are an extension of more commonly seen sets of numbers such as integers and real numbers, and pair a real number and an imaginary number. An imaginary number, to back up a bit, is a multiple of a number \(i\) whose square is -1, so a complex number is expressible as \(a + bi\), where \(a\) is the real component and \(b\) is the imaginary component, multiplied by \(i\), the imaginary unit. The complex number system provides at least one root for every non-constant polynomial, in much the same way that the real numbers provide a value for every division by a nonzero number, which the integers do not. Or at least that is the impression that Wikipedia gives me. Thank you, great Wiki!

In any case, this is what a complex number looks like in Rust, at least according to me.

struct Complex {
    r : float,
    j : float
}

(In the code, \(i\), the traditional notation for the imaginary unit, is replaced by j by suggestion of englabenny and dfjkfskjhfshdfjhsdjl on reddit, because "1i" is the Rust notation for 1 as an int and because "i" is associated with current in various disciplines that commonly use complex numbers.)

This is a structure containing two fields, the real component and the imaginary component. float is the Rust machine-dependent floating point type, and represents the largest floating point type preferred on the target hardware. (Alternatively, f32 and f64 are the 32- and 64-bit machine independent floating point types.)

To overload the operator + for complex numbers, just provide an implementation of the Add trait:

impl Add<Complex,Complex> for Complex {
    fn add(&self, rhs: &Complex) -> Complex {
        Complex {
            r : self.r + rhs.r,
            j : self.j + rhs.j
        }
    }
}

The Add trait contains one method, add, that performs the operation. To unpack the impl line: implementing Add<X,Y> for Z, for types X, Y, and Z, provides an implementation of the operation where the left-hand side is a Z (the receiver of the method and the type for which the trait is being implemented), the right-hand side is an X (the argument to the method), and the result is a Y. The implementation above allows two complex numbers to be added, producing a new complex number.
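For readers on a more recent Rust than the 0.6 used in this post, the trait has since changed shape. The following is only a sketch in modern syntax, not something the 0.6 compiler would accept: the right-hand side is now the trait's single type parameter (defaulting to Self), and the result is an associated type named Output. Here f64 stands in for the old float type, and the derive line is my addition so the values can be printed and compared.

```rust
use std::ops::Add;

// Same two fields as the struct in the post, with f64 in place of `float`.
#[derive(Debug, Clone, Copy, PartialEq)]
struct Complex {
    r: f64,
    j: f64,
}

// `impl Add for Complex` is shorthand for `impl Add<Complex> for Complex`:
// the left-hand side is Complex, the right-hand side is Complex,
// and `Output` names the result type.
impl Add for Complex {
    type Output = Complex;

    fn add(self, rhs: Complex) -> Complex {
        Complex {
            r: self.r + rhs.r,
            j: self.j + rhs.j,
        }
    }
}

fn main() {
    let x = Complex { r: 1.0, j: 2.0 };
    let y = Complex { r: 3.0, j: -1.0 };
    println!("{:?}", x + y);
}
```

The modern method takes its arguments by value rather than by reference, which is why the sketch derives Copy for such a small struct.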

One thing I would like to be able to do is to provide multiple implementations of the Add trait, such as:

// impl Add<float,Complex> for Complex {
//     fn add(&self, rhs : &float) -> Complex {
//         Complex { r : self.r + *rhs, j : self.j }
//     }
// }

This implementation would allow a floating point number as the right-hand side, producing a complex result. Unfortunately, the current Rust trait implementation does not allow a second implementation of Add for the same type, and the compiler complains about conflicting implementations:

complex.rs:52:0: 56:1 error: conflicting implementations for a trait
...
complex.rs:24:0: 28:1 error: conflicting implementations for a trait
...
complex.rs:81:12: 81:17 error: multiple applicable methods in scope
...

Something to look forward to in future releases.
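Later versions of Rust did lift this restriction. As a hedged sketch in modern syntax (again, not valid for the 0.6 compiler this post targets): because the right-hand side is a type parameter of Add, Add<Complex> and Add<f64> are distinct traits, and both can be implemented for Complex side by side. The f64 type and the derive line are my substitutions.

```rust
use std::ops::Add;

#[derive(Debug, Clone, Copy, PartialEq)]
struct Complex {
    r: f64,
    j: f64,
}

// Complex + Complex
impl Add<Complex> for Complex {
    type Output = Complex;

    fn add(self, rhs: Complex) -> Complex {
        Complex { r: self.r + rhs.r, j: self.j + rhs.j }
    }
}

// Complex + f64: the float is treated as a purely real number.
impl Add<f64> for Complex {
    type Output = Complex;

    fn add(self, rhs: f64) -> Complex {
        Complex { r: self.r + rhs, j: self.j }
    }
}

fn main() {
    let x = Complex { r: 1.0, j: 2.0 };
    println!("{:?}", x + x);
    println!("{:?}", x + 3.0);
}
```

Note that this only covers a float on the right-hand side; 3.0 + x would need a separate implementation of Add<Complex> for f64.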

With that said, the following code now works and produces the expected results:

    let x = Complex { r : 1.0, j : 0.0 };
    let y = Complex { r : 3.0, j : 0.0 };
    let z = x + y;

Two further examples implement subtraction and multiplication for Complex values.

impl Sub<Complex,Complex> for Complex {
    fn sub(&self, rhs : &Complex) -> Complex {
        Complex { r : self.r - rhs.r, j : self.j - rhs.j }
    }
}

impl Mul<Complex,Complex> for Complex {
    fn mul(&self, rhs : &Complex) -> Complex {
        Complex {
            r : (self.r * rhs.r) - (self.j * rhs.j),
            j : (self.r * rhs.j) + (self.j * rhs.r)
        }
    }
}

Isn't that just lovely, hmm?
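For anyone checking the mul implementation against the algebra, here is the expansion it is based on, using \(j\) for the imaginary unit as in the code and the fact that \(j^2 = -1\):

\[
(a + bj)(c + dj) = ac + adj + bcj + bdj^2 = (ac - bd) + (ad + bc)j
\]

With \(a\) = self.r, \(b\) = self.j, \(c\) = rhs.r, and \(d\) = rhs.j, the real part \(ac - bd\) and the imaginary part \(ad + bc\) are the two fields computed in mul.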

The operators which can be overloaded, as of Rust 0.6, are:

Operator                 Trait

Arithmetic
+                        core::ops::Add<RHS,Result>
-                        core::ops::Sub<RHS,Result>
*                        core::ops::Mul<RHS,Result>
/                        core::ops::Div<RHS,Result>
%                        core::ops::Modulo<RHS,Result>
- (unary negation)       core::ops::Neg<Result>

Bitwise
&                        core::ops::BitAnd<RHS,Result>
|                        core::ops::BitOr<RHS,Result>
^ (exclusive or)         core::ops::BitXor<RHS,Result>
<< (shift left)          core::ops::Shl<RHS,Result>
>> (shift right)         core::ops::Shr<RHS,Result>

Comparison
==                       core::cmp::Eq
!=                       core::cmp::Eq
<                        core::cmp::Ord
>                        core::cmp::Ord
<=                       core::cmp::Ord
>=                       core::cmp::Ord

Miscellaneous
! (Boolean negation)     core::ops::Not<Result>
a[i] (indexing)          core::ops::Index<RHS,Result>

In order to preserve sanity, Rust limits operator overloading. As kibwen pointed out in the reddit discussion, there are restrictions on where traits, types, and the implementations of traits for types can legitimately appear. Specifically, the implementation must be in the same crate as either the type or the trait. kibwen writes,

Now note that the overloading traits are defined in libcore, which is shipped with the Rust compiler.

The implication then is that it is only possible to overload operators on types that you've defined yourself. You never have to worry about library A attempting distant overloads on types from library B; this also means that you never have to worry about libraries changing what 2+2 means.
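To make the restriction concrete, here is a small sketch in modern Rust syntax (an assumption on my part; the post itself uses 0.6, where the details of the rule may differ). Implementing a library trait for a type defined in your own crate is accepted, while the commented-out implementation, where both the trait and the type come from elsewhere, is the kind of thing the compiler rejects.

```rust
use std::ops::Neg;

// A type defined in this crate.
#[derive(Debug, Clone, Copy, PartialEq)]
struct Complex {
    r: f64,
    j: f64,
}

// Allowed: the trait (Neg) comes from the standard library,
// but the type (Complex) is defined here.
impl Neg for Complex {
    type Output = Complex;

    fn neg(self) -> Complex {
        Complex { r: -self.r, j: -self.j }
    }
}

// Not allowed: neither the trait (Add) nor the type (i32) is defined in
// this crate, so uncommenting this block is a compile error.
//
// impl std::ops::Add for i32 {
//     type Output = i32;
//     fn add(self, _rhs: i32) -> i32 { 5 }
// }

fn main() {
    let x = Complex { r: 1.0, j: -2.0 };
    println!("{:?}", -x);
    println!("{}", 2 + 2); // nobody can change what this means
}
```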

At this point, I must note that operator overloading, even with Rust's limits on it, is not something to be used without considerable thought. If you misuse overloading, doing something like C++'s 'cout << "hello world"', a Rust developer will find you. And Fix. Your. Little. Red. Wagon.

Going further with complex numbers

Complex numbers have two specialized operations that are useful to me here: conjugation, which negates the imaginary component, and magnitude, which computes the distance of a complex number from the origin of the complex number plane. (The magnitude is the closest I can come to a meaningful way of down-converting from a complex number to a floating value.)

impl Complex {
    fn conjugate(&self) -> Complex { Complex { r : self.r, j : -self.j } }
    fn magnitude(&self) -> float { float::sqrt( self.r * self.r + self.j * self.j ) }
}

This implementation provides conjugate and magnitude methods for the Complex structure; think of it as the implementation of an anonymous trait. One thing we need the conjugate method for is implementing division for complex numbers:

impl Div<Complex,Complex> for Complex {
    fn div(&self, rhs : &Complex) -> Complex {
        let rhs_conj = rhs.conjugate();
        let num      = self * rhs_conj;
        let denom    = rhs * rhs_conj;
        Complex {
            r : (num.r / denom.r),
            j : (num.j / denom.r)
        }
    }
}
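The reason the conjugate helps: multiplying a complex number by its own conjugate gives a purely real value, \((c + dj)(c - dj) = c^2 + d^2\), so multiplying both the numerator and the denominator by the conjugate of the divisor reduces complex division to two ordinary divisions. Below is a sketch of the same approach in modern Rust syntax (not 0.6), with f64 standing in for float and a derive line added for comparisons; treat it as an illustration rather than a drop-in replacement for the code above.

```rust
use std::ops::{Div, Mul};

#[derive(Debug, Clone, Copy, PartialEq)]
struct Complex {
    r: f64,
    j: f64,
}

impl Complex {
    // Negate the imaginary component.
    fn conjugate(self) -> Complex {
        Complex { r: self.r, j: -self.j }
    }

    // Distance from the origin of the complex plane.
    fn magnitude(self) -> f64 {
        (self.r * self.r + self.j * self.j).sqrt()
    }
}

impl Mul for Complex {
    type Output = Complex;

    fn mul(self, rhs: Complex) -> Complex {
        Complex {
            r: self.r * rhs.r - self.j * rhs.j,
            j: self.r * rhs.j + self.j * rhs.r,
        }
    }
}

impl Div for Complex {
    type Output = Complex;

    fn div(self, rhs: Complex) -> Complex {
        let rhs_conj = rhs.conjugate();
        let num = self * rhs_conj;
        // rhs * conjugate(rhs) has a zero imaginary part, so only
        // its real part is needed as the divisor.
        let denom = rhs * rhs_conj;
        Complex { r: num.r / denom.r, j: num.j / denom.r }
    }
}

fn main() {
    let z = Complex { r: 4.0, j: 2.0 };
    let w = Complex { r: 1.0, j: 1.0 };
    println!("{:?}", z / w);
    println!("{}", Complex { r: 3.0, j: 4.0 }.magnitude());
}
```

Like the original, this sketch does nothing special for a zero divisor; the floating point divisions simply produce infinities or NaN.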

Speaking of traits, ToStr is a trait supported by most Rust types, converting a value to a string.

impl ToStr for Complex {
    fn to_str(&self) -> ~str { fmt!("(%f + %fi)", self.r, self.j) }
}

For Complex, to_str will produce a string of the form "(a + bi)". One further thing that I would like to be able to do is to convert other values, such as floats, into Complex numbers.

trait ToComplex { fn to_complex(&self) -> Complex; }

impl ToComplex for float {
    fn to_complex(&self) -> Complex { Complex { r : *self, j : 0.0f } }
}

That trait and implementation add a to_complex method to any float value, such as 3.0f.to_complex().
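In more recent versions of Rust, the usual way to express this kind of conversion is the standard From trait rather than a hand-rolled one. The sketch below is in modern syntax and is not what the 0.6 code in this post uses; f64 replaces float, and the derive line is my addition.

```rust
#[derive(Debug, Clone, Copy, PartialEq)]
struct Complex {
    r: f64,
    j: f64,
}

// A real number converts to a complex number with a zero imaginary part.
impl From<f64> for Complex {
    fn from(r: f64) -> Complex {
        Complex { r, j: 0.0 }
    }
}

fn main() {
    // Explicit conversion through the trait's associated function.
    let a = Complex::from(3.0_f64);

    // `into` comes for free once From is implemented; the annotation
    // on `b` tells the compiler which target type is wanted.
    let b: Complex = 2.5_f64.into();

    println!("{:?} {:?}", a, b);
}
```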

There is a standard trait for converting among numeric types in Rust, NumCast. This is the implementation of it for Complex numbers:

impl NumCast for Complex {
    fn from<N:NumCast>(n: N) -> Complex { n.to_float().to_complex() }

    fn to_u8(&self)    -> u8    { self.magnitude() as u8    }
    fn to_u16(&self)   -> u16   { self.magnitude() as u16   }
    fn to_u32(&self)   -> u32   { self.magnitude() as u32   }
    fn to_u64(&self)   -> u64   { self.magnitude() as u64   }
    fn to_uint(&self)  -> uint  { self.magnitude() as uint  }

    fn to_i8(&self)    -> i8    { self.magnitude() as i8    }
    fn to_i16(&self)   -> i16   { self.magnitude() as i16   }
    fn to_i32(&self)   -> i32   { self.magnitude() as i32   }
    fn to_i64(&self)   -> i64   { self.magnitude() as i64   }
    fn to_int(&self)   -> int   { self.magnitude() as int   }

    fn to_f32(&self)   -> f32   { self.magnitude() as f32   }
    fn to_f64(&self)   -> f64   { self.magnitude() as f64   }
    fn to_float(&self) -> float { self.magnitude()          }
}

With that addition, I can write the following main:

fn main() {
    let x = Complex { r : 1.0, j : 0.0 };
    let y = Complex { r : 3.0, j : 0.0 };
    let z = x + y;
    let w = NumCast::from(2);
    println(( y + 3.0f.to_complex()   ).to_str());
    println(( x * NumCast::from(3.0f) ).to_str());
    println(( x * NumCast::from(4)    ).to_str());
    println(( z / w                   ).to_str());

    let n = Complex { r : 0.0, j : 1.0 };
    println(( n * n                   ).to_str());
}

Which produces the following output:

(6 + 0i)
(3 + 0i)
(4 + 0i)
(2 + 0i)
(-1 + 0i)

The source code for these examples is on github.

I would like to thank the commenters from Reddit, particularly englabenny, dfjkfskjhfshdfjhsdjl, and kibwen, for their help.

2 comments:

Alex said...

In the last code snippet, how does NumCast::from know that you want to get a Complex back?

Tommy McGuire said...

The compiler knows what types are expected in the expressions where NumCast::from is used. For example, in "x * NumCast::from(3.0f)", x is a Complex and the only thing the compiler knows how to multiply a Complex number by is another Complex number, so it looks for an implementation of NumCast for Complex.

The tricky one is "let w = NumCast::from(2);" because there is no type information provided for w in that line. However, the compiler also knows that w is used in the expression "z / w" where z is Complex (since z is the result of adding two Complex values), so w also has to be Complex.

If you just had the "let w = ..." line, without a type annotation or a use of w that pinned down its type, the compiler would throw an error.