# Operator overloading in Rust

**With a bonus short introduction to Rust macros!**

*This post has been updated on 17 April, 2015, to cover Rust 1.0.0-beta. The comments as of this date are out of date, but I left them there to be more confusing. Thanks, and watch out for the rooster.*

Rust offers several features that may make frustrated C++ programmers feel more at home. One of those features is operator overloading, for at least some common operators. The technique of overloading operators in Rust is handled through traits: the Rust manual describes overloading by saying, "[Arithmetic] expressions are syntactic sugar for calls to built-in traits, defined in the **std::ops** module of the std library. This means that arithmetic operators can be overridden for user-defined types." (Technically, the description applies only to binary arithmetic operators, but similar traits exist for the three unary operators, arithmetic negation, logical negation, and pointer dereference. Also, it may not be immediately apparent that array indexing and the bit-wise operators are "arithmetic". Finally, comparison operators also have linked traits, in **std::cmp**.)

So, how does one use these magic traits?

Let's say you want to use complex numbers in your code, because you are the kind of person who knows what complex numbers are good for and can use them in a safe and appropriately sanitary fashion. Complex numbers, for those of us who do not know what they are good for and probably cannot use a spatula in a safe and sanitary manner, are an extension of more commonly seen sets of numbers such as integers and reals, and pair a real number and an imaginary number. An imaginary number, to back up a bit, is a multiple of a number \(i\) whose square is -1, so a complex number is expressible as \(a + bi\), where \(a\) is the real component and \(b\) is the imaginary component, multiplied by \(i\), the imaginary unit. The complex number system provides at least one root for every polynomial expression in much the same way that real numbers provide a value for every division, unlike the integers. Or at least that is the impression that Wikipedia gives me. Thank you, great Wiki!

In any case, this is what a complex number looks like in Rust, at least according to me.

    #[derive(Debug, Copy, Clone)]
    pub struct Complex {
        r: f64,
        j: f64,
    }

(In the code, \(i\), the traditional notation for the imaginary unit, is replaced by j by suggestion of englabenny and dfjkfskjhfshdfjhsdjl on reddit, because "1i64" is the Rust notation for 1 as a 64-bit integer and because "i" is associated with current in various disciplines that commonly use complex numbers.)

This is a structure containing two fields, the real component and the imaginary component. **f64** is Rust's 64-bit, machine-independent floating point type. (Alternatively, **f32** is the 32-bit machine-independent floating point type.)

This structure uses Rust's ability to automatically derive an implementation for some useful traits; **std::fmt::Debug** converts the structure to a string (seen below) with the **{:?}** formatting option, **std::marker::Copy** is a marker trait (with no implementation interfaces) indicating that it can be copied "by simply copying bits (i.e. memcpy)"---it acts like a primitive number type, in this case---, and **std::clone::Clone** provides methods to safely copy an object.
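As a minimal sketch of what those three derived traits buy you, the following repeats the structure so that it stands alone. The `real_part` and `demo` functions are hypothetical names introduced only for this illustration.

```rust
// Same derives as the article's Complex, repeated so the sketch is self-contained.
#[derive(Debug, Copy, Clone)]
pub struct Complex {
    r: f64,
    j: f64,
}

// Takes its argument by value. Because Complex is Copy, the caller keeps
// its own value after the call.
fn real_part(c: Complex) -> f64 {
    c.r
}

// Hypothetical helper that exercises Copy, Clone, and Debug in turn.
pub fn demo() -> (f64, f64, String) {
    let a = Complex { r: 1.5, j: -2.0 };
    let first = real_part(a); // `a` is copied into the call, not moved
    let second = real_part(a); // so `a` is still usable here
    let b = a.clone(); // Clone makes the copy explicit
    (first, second, format!("{:?}", b)) // Debug supplies the {:?} output
}
```

Without the **Copy** derive, the second call to `real_part(a)` would be rejected, because the first call would have moved `a`.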

One thing that is not automatically derivable is the ability to easily create a **Complex** number from, say, a 64-bit floating point number. Adding that ability is basic Rust.

    trait ToComplex {
        fn to_complex(&self) -> Complex;
    }

    impl ToComplex for f64 {
        fn to_complex(&self) -> Complex {
            Complex { r: *self, j: 0.0f64 }
        }
    }

Algebra for complex numbers makes use of a **conjugate** operation, which negates the imaginary component. (Note the use of the structure updating ".." syntax in this method.) A further operation for the type is one way of converting a complex number into a real: by viewing the complex as a vector on the 2-dimensional complex plane and returning its length or **magnitude**.

    impl Complex {
        fn conjugate(&self) -> Complex {
            Complex { j: -self.j, ..*self }
        }

        fn magnitude(&self) -> f64 {
            (self.r * self.r + self.j * self.j).sqrt()
        }
    }
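A quick sanity check of the two methods, as a sketch: for \(3 + 4i\), the conjugate should be \(3 - 4i\) and the magnitude should be 5. The type and methods are repeated so the block stands alone; `demo` is a hypothetical helper.

```rust
#[derive(Debug, Copy, Clone)]
pub struct Complex {
    r: f64,
    j: f64,
}

impl Complex {
    // Negate the imaginary part; `..*self` copies the remaining field (r).
    fn conjugate(&self) -> Complex {
        Complex { j: -self.j, ..*self }
    }

    // Length of the vector (r, j) on the complex plane.
    fn magnitude(&self) -> f64 {
        (self.r * self.r + self.j * self.j).sqrt()
    }
}

// Hypothetical helper: returns the conjugate's fields and the magnitude.
pub fn demo() -> (f64, f64, f64) {
    let z = Complex { r: 3.0, j: 4.0 };
    let c = z.conjugate();
    (c.r, c.j, z.magnitude())
}
```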

The final preliminary is to provide an implementation of the **std::fmt::Display** trait, displaying the complex value as a string.

    use std::fmt::{self, Display, Formatter};

    impl Display for Complex {
        fn fmt(&self, formatter: &mut Formatter) -> fmt::Result {
            write!(formatter, "{} + {}j", self.r, self.j)
        }
    }
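Implementing **Display** is what makes the **{}** formatting option work, and the standard library then supplies `to_string()` for free. A small sketch, with the type repeated so it stands alone and a hypothetical `demo` helper:

```rust
use std::fmt::{self, Display, Formatter};

#[derive(Debug, Copy, Clone)]
pub struct Complex {
    r: f64,
    j: f64,
}

impl Display for Complex {
    fn fmt(&self, formatter: &mut Formatter) -> fmt::Result {
        write!(formatter, "{} + {}j", self.r, self.j)
    }
}

// Hypothetical helper: formats the same value two ways.
pub fn demo() -> (String, String) {
    let z = Complex { r: 1.5, j: 2.0 };
    // format! with {} goes through Display; to_string() comes from the
    // blanket ToString implementation for every Display type.
    (format!("{}", z), z.to_string())
}
```

Both strings come out as `1.5 + 2j`. Note that this simple format prints a negative imaginary part as `+ -2j`; handling the sign more gracefully is left alone here.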

To overload the operators **+** and **\*** for complex numbers, just provide implementations of the **Add** and **Mul** traits:

    use std::ops::{Add, Div, Mul};

    // (a + bj) + (c + dj) = (a + c) + (b + d)j
    impl Add<Complex> for Complex {
        type Output = Complex;

        fn add(self, rhs: Complex) -> Complex {
            Complex { r: self.r + rhs.r, j: self.j + rhs.j }
        }
    }

    // (a + bj) * (c + dj) = (ac - bd) + (ad + bc)j
    impl Mul<Complex> for Complex {
        type Output = Complex;

        fn mul(self, rhs: Complex) -> Complex {
            Complex {
                r: (self.r * rhs.r) - (self.j * rhs.j),
                j: (self.r * rhs.j) + (self.j * rhs.r),
            }
        }
    }

The **Add** trait contains one method, **add**, that performs the operation, and an associated type, **Output**, that describes the type of the value the operation returns. To unpack the types involved: implementing **Add**<X> **for** Z with an **Output** type of Y, for types X, Y, and Z, provides an implementation of the operation where the left-hand side is a Z (the receiver of the method and the type for which the trait is being implemented), the right-hand side is an X (the argument to the method), and the result is a Y. The implementation of these two traits allows two complex numbers to be added or multiplied, producing a new complex number.
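To make the roles of X, Y, and Z concrete, here is a sketch in which all three types differ. `Hours`, `Minutes`, and `TotalMinutes` are hypothetical types invented for this illustration; they are not part of the complex number code. The sketch also shows that the operator and the explicit trait call are the same thing.

```rust
use std::ops::Add;

// Z: the left-hand side, and the type the trait is implemented for.
#[derive(Debug, Clone, Copy, PartialEq)]
struct Hours(u32);

// X: the right-hand side, named as the trait's type parameter.
#[derive(Debug, Clone, Copy, PartialEq)]
struct Minutes(u32);

// Y: the result, named by the associated Output type.
#[derive(Debug, Clone, Copy, PartialEq)]
struct TotalMinutes(u32);

impl Add<Minutes> for Hours {
    type Output = TotalMinutes;

    fn add(self, rhs: Minutes) -> TotalMinutes {
        TotalMinutes(self.0 * 60 + rhs.0)
    }
}

// The operator form: `+` is resolved through the Add trait.
fn with_operator(h: Hours, m: Minutes) -> TotalMinutes {
    h + m
}

// The same call written out as an ordinary trait method call.
fn with_method(h: Hours, m: Minutes) -> TotalMinutes {
    Add::add(h, m)
}
```

With only this implementation, `Minutes(30) + Hours(2)` does not compile: the reversed order would need its own `impl Add<Hours> for Minutes`.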

The implementation of division is similar to **Add** and **Mul**, but also illustrates the use of the overloaded operators previously defined.

    impl Div<Complex> for Complex {
        type Output = Complex;

        fn div(self, rhs: Complex) -> Complex {
            let rhs_conj = rhs.conjugate();
            let num = self * rhs_conj;
            let den = rhs * rhs_conj;
            Complex { r: num.r / den.r, j: num.j / den.r }
        }
    }

In the code, **rhs_conj**, **num**, and **den** are **Complex** and the calculations of **num** and **den** involve the **Mul** trait above.
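For readers who want the algebra behind the code, here is the derivation as I read it, writing the left-hand side as \(a + bi\) and the right-hand side as \(c + di\) (symbols chosen for this note only). Multiplying the numerator and denominator by the conjugate of the divisor makes the denominator purely real:

\[
\frac{a + bi}{c + di}
= \frac{(a + bi)(c - di)}{(c + di)(c - di)}
= \frac{(ac + bd) + (bc - ad)\,i}{c^2 + d^2}
= \frac{ac + bd}{c^2 + d^2} + \frac{bc - ad}{c^2 + d^2}\,i
\]

In the code, **num** is the numerator \((ac + bd) + (bc - ad)i\) and **den** is \((c^2 + d^2) + 0i\), which is why only **den.r** appears in the final division.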

Another thing I would like to be able to do is to provide multiple implementations of, say, the **Add** trait, for other types:

    impl Add<f64> for Complex {
        type Output = Complex;

        fn add(self, rhs: f64) -> Complex {
            Complex { r: self.r + rhs, j: self.j }
        }
    }

    impl Add<Complex> for f64 {
        type Output = Complex;

        fn add(self, rhs: Complex) -> Complex {
            Complex { r: self + rhs.r, j: rhs.j }
        }
    }

Following the discussion above, the first implementation is used when a **Complex** number is on the left-hand side of the **+** and an **f64** floating point number is on the right. The second implementation is used when an **f64** is on the left-hand side of the **+** and a **Complex** number is on the right. Both implementations return **Complex** numbers. Whereas previous versions of Rust had problems with this sort of convenience-overloading, Rust as of 1.0.0-beta allows these to work just like you think they would.

Unfortunately, there is a cloud attached to this silver lining: there are many types and many operations to overload (see below). Having explicit implementations of every trait for every combination of types would be maddening, especially since they are all very, very similar. This is where macros can be very useful.

Rust supports a hygienic, definition-by-example-ish macro system using **macro_rules!**. I cannot go into the entire system here (not least because I have not explored it deeply), but the following definition provides a template implementation of the **Mul** trait for a given type argument:

    macro_rules! scalar_impl (
        ($foo:ty) => (
            // Implementation of multiplication for Complex and $foo
            impl Mul<$foo> for Complex {
                type Output = Complex;

                fn mul(self, rhs: $foo) -> Complex {
                    Complex { r: self.r * (rhs as f64), j: self.j * (rhs as f64) }
                }
            }

            // Implementation of multiplication for $foo and Complex
            impl Mul<Complex> for $foo {
                type Output = Complex;

                fn mul(self, rhs: Complex) -> Complex {
                    Complex { r: (self as f64) * rhs.r, j: (self as f64) * rhs.j }
                }
            }
        )
    );

This code defines a macro, **scalar_impl**! (the exclamation point is a necessary part of the macro's invocation). The definition uses one rule, meaning it supports one form where the macro is given a type argument (note that the type of **$foo** is **ty**). The expansion of this rule supplies the two, complementary, implementations of **Mul** involving the type **$foo**. The first is used when a **Complex** number is the left-hand side and a number of the type represented by **$foo** is the right-hand side; the second is used when a number of the type represented by **$foo** is the left-hand side.

Note that there are no conditions placed on the **$foo** type; the macro itself is not type checked but the expansion will be. In this case, the requirement is that the type **$foo** be convertible to a 64-bit floating point number (in the first implementation, **(rhs as f64)**; in the second, **(self as f64)**.) The macro is used as:

    scalar_impl!(i8);
    scalar_impl!(i16);
    scalar_impl!(i32);
    scalar_impl!(i64);
    scalar_impl!(isize);
    scalar_impl!(u8);
    scalar_impl!(u16);
    scalar_impl!(u32);
    scalar_impl!(u64);
    scalar_impl!(usize);
    scalar_impl!(f64);
    scalar_impl!(f32);

This list of invocations supplies implementations of the **Mul** trait for all of Rust's primitive numeric types.
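The twelve invocations above could probably be collapsed into one by using macro repetition. The following is a sketch of that idea rather than the code from this post: `scalar_impl_all` and `demo` are hypothetical names, and the type is repeated so the block stands alone.

```rust
use std::ops::Mul;

#[derive(Debug, Copy, Clone, PartialEq)]
pub struct Complex {
    r: f64,
    j: f64,
}

// `$($t:ty),+` matches a comma-separated list of one or more types, and
// `$( ... )+` in the expansion repeats the enclosed impls once per type.
macro_rules! scalar_impl_all {
    ($($t:ty),+) => {
        $(
            impl Mul<$t> for Complex {
                type Output = Complex;

                fn mul(self, rhs: $t) -> Complex {
                    Complex { r: self.r * (rhs as f64), j: self.j * (rhs as f64) }
                }
            }

            impl Mul<Complex> for $t {
                type Output = Complex;

                fn mul(self, rhs: Complex) -> Complex {
                    Complex { r: (self as f64) * rhs.r, j: (self as f64) * rhs.j }
                }
            }
        )+
    };
}

scalar_impl_all!(i8, i16, i32, i64, isize, u8, u16, u32, u64, usize, f32, f64);

// Hypothetical helper: one scalar on each side of the operator.
pub fn demo() -> (Complex, Complex) {
    let y = Complex { r: 3.0, j: 1.0 };
    (y * 3isize, 4u8 * y)
}
```

The trade-off is readability: the single-type macro is easier to follow, while the repeating form keeps the list of supported types in one place.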

How are these definitions used? Here are some complete, albeit useless, examples.

    let w = 2.0.to_complex();
    let x = Complex { r: 1.0, j: 0.0 };
    let y = Complex { r: 3.0, j: 0.0 };
    let z = x + y;
    println!(" z: {:?}", z);
    // => z: Complex { r: 4, j: 0 }

This first example shows the automatically derived debugging format, which includes the structure and field names.

    println!("{}", (z / w));
    // => 2 + 0j
    println!("{}", (y + 3.0));
    // => 6 + 0j
    println!("{}", (3.0 + y));
    // => 6 + 0j
    println!("{}", (y * 3isize));
    // => 9 + 0j
    println!("{}", (y * 3.0f64));
    // => 9 + 0j
    println!("{}", (4u8 * y));
    // => 12 + 0j

These examples demonstrate the basic arithmetic operations, and their use with various numeric types. The final examples show some of the **Complex** number's party tricks.

    let n = Complex { r: 0.0, j: 1.0 };
    println!("{}", (n * n));
    // => -1 + 0j
    println!("{}", ((n * n) * 2));
    // => -2 + 0j
    let mu: Complex = (n * n) * 2;
    println!("{}", mu.magnitude());
    // => 2

The first shows that \((0+i)^2\) is \(-1\), the second that \((0+i)^2 \times 2\) is \(-2\), and the last that the magnitude of \(-2+0i\) is 2.

Isn't that just lovely, hmm?

The operators which can be overloaded, as of Rust 1.0.0-beta, are:

Operator | Trait |
---|---|
**Arithmetic** | |
`+` | std::ops::Add |
`-` | std::ops::Sub |
`*` | std::ops::Mul |
`/` | std::ops::Div |
`%` | std::ops::Rem |
`-` (unary negation) | std::ops::Neg |
**Bitwise** | |
`&` | std::ops::BitAnd |
`\|` | std::ops::BitOr |
`^` (exclusive or) | std::ops::BitXor |
`<<` (shift left) | std::ops::Shl |
`>>` (shift right) | std::ops::Shr |
**Miscellaneous** | |
`!` (Boolean negation) | std::ops::Not |
`a[i]` (indexing, immutable context) | std::ops::Index |
`a[i]` (indexing, mutable context) | std::ops::IndexMut |
`*v` (dereference, immutable context) | std::ops::Deref |
`*v` (dereference, mutable context) | std::ops::DerefMut |
**Comparison** | |
`==` | std::cmp::PartialEq and std::cmp::Eq |
`!=` | std::cmp::PartialEq and std::cmp::Eq |
`<` | std::cmp::PartialOrd and std::cmp::Ord |
`>` | std::cmp::PartialOrd and std::cmp::Ord |
`<=` | std::cmp::PartialOrd and std::cmp::Ord |
`>=` | std::cmp::PartialOrd and std::cmp::Ord |

The two modules there are **std::ops** and **std::cmp**.
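As a sketch of how a few of the other traits in the table look in practice, the following adds unary negation, equality, and immutable indexing to a standalone copy of **Complex**. Treating index 0 as the real part and index 1 as the imaginary part is a convention invented for this illustration, not something the post's code does.

```rust
use std::ops::{Index, Neg};

// Deriving PartialEq is enough to get `==` and `!=`.
#[derive(Debug, Copy, Clone, PartialEq)]
pub struct Complex {
    r: f64,
    j: f64,
}

// Unary minus: -z negates both components.
impl Neg for Complex {
    type Output = Complex;

    fn neg(self) -> Complex {
        Complex { r: -self.r, j: -self.j }
    }
}

// Indexing in an immutable context: z[0] is the real part, z[1] the
// imaginary part (a hypothetical convention for this sketch).
impl Index<usize> for Complex {
    type Output = f64;

    fn index(&self, idx: usize) -> &f64 {
        match idx {
            0 => &self.r,
            1 => &self.j,
            _ => panic!("Complex index out of range: {}", idx),
        }
    }
}
```

Note that **Index** returns a reference and names the element type through **Output**, while **Neg**, like **Add**, takes its operand by value.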

Once upon a time, in order to preserve sanity, Rust limited trait implementation. As kibwen pointed out in the reddit discussion, there were restrictions on where traits, types, and the implementations of traits for types could legitimately appear. Specifically, the implementation had to be in the same crate as either the type or the trait. kibwen writes,

> Now note that the overloading traits are defined in libcore, which is shipped with the Rust compiler.
>
> The implication then is that it is only possible to overload operators on types that you've defined yourself. You never have to worry about library A attempting distant overloads on types from library B; this also means that you never have to worry about libraries changing what 2+2 means.

That previous restriction has been removed (sometime in fall 2014, I think), as can be seen in the use of Add<Complex> for f64. The resulting changes go by the name of multidispatch and conditional dispatch, and the simple crate-based coherence rules have been replaced by something more complex that can only be described by a blog post. Or, at least, I don't have a good description of it. It's got something to do with those traits' associated types. Here are some references:

- Multi- and Conditional Dispatch in Traits
- A RFC 24 "trait reform rfc"
- RFC 195 Associated items
- Little Orphan Impls
- Implement associated items #17307
- RFC Associated items and multidispatch traits
- Some multiple dispatch magic
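As a rough sketch of what the relaxed rule permits, based on my reading of the coherence rules rather than on anything in the references: an implementation of a foreign trait for a foreign type is accepted when a type local to the crate appears as one of the trait's type parameters, and rejected when every type involved is foreign. `Wrapper` and `demo` are hypothetical names for this example.

```rust
use std::ops::Add;

// A local type: it is defined in this crate.
#[derive(Debug, Copy, Clone, PartialEq)]
pub struct Wrapper(f64);

// Accepted: Add and f64 both come from the standard library, but the
// trait's type parameter (Wrapper) is local to this crate.
impl Add<Wrapper> for f64 {
    type Output = Wrapper;

    fn add(self, rhs: Wrapper) -> Wrapper {
        Wrapper(self + rhs.0)
    }
}

// Rejected if uncommented: the trait and both types are all foreign, so
// this crate has no standing to define what `f64 + i32` means.
// impl Add<i32> for f64 {
//     type Output = f64;
//     fn add(self, rhs: i32) -> f64 { self + rhs as f64 }
// }

// Hypothetical helper: an f64 on the left, the local type on the right.
pub fn demo() -> Wrapper {
    1.5f64 + Wrapper(2.0)
}
```

This keeps kibwen's guarantee intact: no library can change what 2+2 means, because at least one local type has to take part in every overload.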

I must note that operator overloading, even with Rust's limits on it, is not something to be used without considerable thought. If you misuse overloading, doing something like C++'s 'cout << "hello world"', a Rust developer will find you. And Fix. Your. Little. Red. Wagon.

The source code for these examples is on GitHub.

I would like to thank the commenters from Reddit, particularly englabenny, dfjkfskjhfshdfjhsdjl, and kibwen, for their help.

### Comments

In the last code snippet, how does NumCast::from know that you want to get a Complex back?

Alex

'2013-07-19T18:28:11.230-05:00'

The compiler knows what types are expected in the expressions where NumCast::from is used. For example, in "x * NumCast::from(3.0f)", x is a Complex and the only thing the compiler knows how to multiply a Complex number by is another Complex number, so it looks for an implementation of NumCast for Complex.

The tricky one is "let w = NumCast::from(2);" because there is no type information provided for w in that line. However, the compiler also knows that w is used in the expression "z / w" where z is Complex (since z is the result of adding two Complex's), so w also has to be Complex.

If you just had the "let w = ..." line, without a use of w that pinned-down its type or type annotations, the compiler would throw an error.

Tommy McGuire

'2013-07-22T11:32:32.184-05:00'