Implementing SetDiff in Crystal

I'm bringing attention to this issue because it might help improve Crystal.

It concerns the fastest way to remove the elements of one array (list) from another.

It arose from comparing how different languages performed this task, as presented in this thread.

I have two arrays, lhr and lhr_del.
I want to remove the elements of lhr_del from lhr.

In Crystal (and Ruby) you can just do: lhr -= lhr_del

In D it’s: lhr_del.sort!("a < b"); lhr = setDifference(lhr, lhr_del).array;

In Rust, it was originally done as: lhr.retain(|&m| !lhr_del.contains(&m));

Conceptually, this keeps the members of lhr that are not in lhr_del.

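This is slow in Rust, because contains does a linear scan of lhr_del for every element of lhr. It was greatly improved by sorting lhr_del and using a binary search: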

lhr_del.sort_unstable();
lhr.retain(|&m| !lhr_del.binary_search(&m).is_ok());
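To make the difference between the two retain-based versions concrete, here is a minimal self-contained sketch that wraps each one in a function. The names diff_contains and diff_binary_search are hypothetical, chosen only for illustration. With n = lhr.len() and m = lhr_del.len(), the first does roughly O(n·m) work because of the linear scans; the second pays for one sort and then an O(log m) lookup per element, roughly O((n + m) log m). Both keep the original order of lhr, and neither needs lhr itself to be sorted.

```rust
/// Naive version: `contains` scans all of `del` for every element of `keep`.
fn diff_contains(mut keep: Vec<u32>, del: &[u32]) -> Vec<u32> {
    keep.retain(|m| !del.contains(m));
    keep
}

/// Sort `del` once, then binary-search it for every element of `keep`.
/// `is_err()` means "not found", the same test as `!...is_ok()`.
fn diff_binary_search(mut keep: Vec<u32>, mut del: Vec<u32>) -> Vec<u32> {
    del.sort_unstable();
    keep.retain(|m| del.binary_search(m).is_err());
    keep
}
```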

Someone on the Rust forum then wrote a Rust equivalent of D's setDifference, which allows this:

lhr_del.sort_unstable();
let lhr: Vec<u32> = SetDiff::new(&lhr, &lhr_del).collect();

This is even faster. Here is the Rust implementation for u32 values.
To handle 64-bit values, change the u32 references to u64 (or usize on 64-bit targets); these use twice the memory per element.

struct SetDiff<'a> {
  r1: &'a [u32],
  r2: &'a [u32],
}

impl<'a> SetDiff<'a> {
  fn adjust_pos(&mut self) {
    while !self.r1.is_empty() {
      if self.r2.is_empty() || self.r1[0] < self.r2[0] {
        break;
      } else if self.r2[0] < self.r1[0] {
        self.r2 = &self.r2[1..];
      } else {
        self.r1 = &self.r1[1..];
        self.r2 = &self.r2[1..];
  } } }

  fn new(r1: &'a [u32], r2: &'a [u32]) -> Self {
    let mut s = SetDiff{ r1, r2 };
    s.adjust_pos();
    s
} }

impl<'a> Iterator for SetDiff<'a> {
  type Item = u32;

  fn next(&mut self) -> Option<Self::Item> {
    let val = self.r1.first().copied();   // next kept value, if any
    if val.is_some() {
      self.r1 = &self.r1[1..];
    }
    self.adjust_pos();                     // skip past any values that are in r2
    val
} }
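Rather than editing every u32 reference by hand, the iterator could be made generic over the element type. The following is a minimal sketch of that idea, not the code posted on the Rust forum: it assumes T: Copy + Ord, folds adjust_pos into next, and, like D's setDifference, assumes both slices are already sorted in ascending order (unsorted input silently gives wrong results).

```rust
/// Lazily yields the elements of sorted `r1` that do not appear in sorted `r2`.
struct SetDiff<'a, T> {
    r1: &'a [T],
    r2: &'a [T],
}

impl<'a, T: Copy + Ord> SetDiff<'a, T> {
    fn new(r1: &'a [T], r2: &'a [T]) -> Self {
        SetDiff { r1, r2 }
    }
}

impl<'a, T: Copy + Ord> Iterator for SetDiff<'a, T> {
    type Item = T;

    fn next(&mut self) -> Option<T> {
        // Walk both sorted slices in step, like a merge.
        while let Some((&a, rest1)) = self.r1.split_first() {
            match self.r2.first() {
                // r2 value is smaller than a: it can't match anything else in r1, skip it
                Some(&b) if b < a => self.r2 = &self.r2[1..],
                // equal values: drop a (and the matching r2 value)
                Some(&b) if b == a => {
                    self.r1 = rest1;
                    self.r2 = &self.r2[1..];
                }
                // r2 is exhausted or its head is larger than a: a is kept
                _ => {
                    self.r1 = rest1;
                    return Some(a);
                }
            }
        }
        None // r1 is exhausted
    }
}
```

With this version, the same SetDiff::new(&lhr, &lhr_del).collect() call would work on sorted Vec<u64> data without any source edits.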

Here are u32 timing comparisons for the last two Rust versions, D, and Crystal.
Crystal is shown only for the first value because of an indexing limitation that is fixed in Crystal 1.16.

--------------------------------------------------------------
       n      | Rust b-srch | Rust setdif |   D    | Crystal
--------------|-------------|-------------|--------|---------
  100,000,000 |    3.35     |    2.96     |  7.21  |  10.25
  200,000,000 |    7.77     |    6.19     | 15.86  |
  300,000,000 |    6.40     |    5.73     | 10.89  |
  400,000,000 |   14.44     |   12.80     | 33.13  |
  500,000,000 |   18.44     |   16.32     | 42.58  |
  600,000,000 |   13.47     |   12.05     | 22.22  |
  700,000,000 |   21.72     |   19.51     | 46.19  |
  800,000,000 |   30.97     |   27.51     | 71.44  |
  900,000,000 |   22.95     |   18.30     | 35.66  |
1,000,000,000 |   38.99     |   34.81     | 88.74  |

Here are the source files and compile settings.

# Crystal 1.15.1
# Compile: $ crystal build --release --mcpu native prime_pairs_lohi.cr
# Run as:  $ ./prime_pairs_lohi 123_456_780

def prime_pairs_lohi(n)
  return puts "Input not even n > 2" unless n.even? && n > 2
  return (pp [n, 1]; pp [n//2, n//2]; pp [n//2, n//2]) if n <= 6

  # generate the low-half-residues (lhr) r < n/2
  lhr = 3u64.step(to: n//2, by: 2).select { |r| r.gcd(n) == 1 }.to_a
  ndiv2, rhi = n//2, n-2           # lhr:hhr midpoint, max residue limit

  # store all powers and cross-products of the lhr members < n-2
  lhr_mults = [] of typeof(n)      # lhr multiples, not part of a pcp
  lhr_dup = lhr.dup                # make copy of the lhr members list
  while (r = lhr_dup.shift) && !lhr_dup.empty? # do mults of current r w/others
    rmax = rhi // r                # ri can't multiply r with values > this
    lhr_mults << r * r if r < rmax # for r^2 multiples
    break if lhr_dup[0] > rmax     # exit if product of consecutive r’s > n-2
    lhr_dup.each do |ri|           # for each residue in reduced list
      break if ri > rmax           # exit for r if cross-product with ri > n-2
      lhr_mults << r * ri          # store value if < n-2
    end                            # check cross-products of next lhr member
  end

  # remove from lhr its lhr_mults, convert vals > n/2 to lhr complements first
  lhr -= lhr_mults.map { |r_del| r_del > ndiv2 ? n - r_del : r_del }

  pp [n, lhr.size]                 # show n and pcp prime pairs count
  pp [lhr.first, n-lhr.first]      # show first pcp prime pair of n
  pp [lhr.last,  n-lhr.last]       # show last  pcp prime pair of n
end

def gen_pcp
  n = (ARGV[0].to_u64 underscore: true) # get n input from terminal
  t1 = Time.monotonic             # start execution timing
  prime_pairs_lohi(n)             # execute code
  pp Time.monotonic - t1          # show execution time
end

gen_pcp
---------------------------------------------------------------
/*
  D LDC2 1.40
  Compile with ldc2: $ ldc2 --release -O3 -mcpu native prime_pairs_lohi.d
  Run as: $ ./prime_pairs_lohi 123_456
*/
module prime_pairs;

import std;
import std.datetime.stopwatch : StopWatch;

void prime_pairs_lohi(uint n) {     // inputs can be of size u32
  if ((n&1) == 1 || n < 4) { return writeln("Input not even n > 2"); }
  if (n <= 6) { writeln([n, 1]); writeln([n/2, n/2]); writeln([n/2, n/2]); return; }

  // generate the low-half-residues (lhr) r < n/2
  auto ndiv2 = n/2;                 // lhr:hhr midpoint
  auto rhi   = n-2;                 // max residue limit
  uint[] lhr = iota(3, ndiv2, 2).filter!(e => gcd(e, n) == 1).array;

  // identify and store all powers and cross-products of the lhr members < n-2
  uint[] lhr_del;                   // lhr multiples, not part of a pcp
  foreach(i, r; lhr) {              // iterate thru lhr to find prime multiples
    auto rmax = rhi / r;            // ri can't multiply r with values > this
    if (r < rmax) lhr_del ~= (r*r < ndiv2) ? r*r : n - r*r; // for r^2 multiples
    if (lhr[i+1] > rmax) break;     // exit if product of consecutive r’s > n-2
    foreach(ri; lhr[i+1..$]) {      // for each residue in reduced list
      if (ri > rmax) break;         // exit for r if cross-product with ri > n-2
      lhr_del ~= (r*ri < ndiv2) ? r*ri : n - r*ri;         // store value if < n-2
  } }

  // remove from lhr its lhr multiples, the pcp remain
  lhr_del.sort!("a < b");
  lhr = setDifference(lhr, lhr_del).array;

  writeln([n,     lhr.length]);     // show n and pcp prime pairs count
  writeln([lhr[0],  n-lhr[0]]);     // show first pcp prime pair of n
  writeln([lhr[$-1],n-lhr[$-1]]);   // show last  pcp prime pair of n
}

void main(string[] args) {          // directly receive input from terminal
  string[] inputs = args[1..$];     // can include '_': 123_456
  auto nums = inputs.map!(i => i.filter!(n => n != '_'));
  auto n    = nums.map!(f => f.to!uint())[0];

  auto timer = StopWatch();         // create execution timer
  timer.start();                    // start it
  prime_pairs_lohi(n);              // run routine
  writeln(timer.peek());            // show timer results
}
------------------------------------------------------------------------------------

/*
   Rust 1.84.1
   Can compile as: $ cargo build --release
   or: $ RUSTFLAGS="-C opt-level=3 -C debuginfo=0 -C target-cpu=native" cargo build --release
   Run as: $ ./prime_pairs_lohi 123_456
*/

use std::time::Instant;

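// true if gcd(m, n) == 1, using the Euclidean algorithm
fn coprime(mut m: u32, mut n: u32) -> bool {
  while m|1 != 1 { let t = m; m = n % m; n = t }  // loop while m > 1 (m|1 == 1 only for 0 and 1)
  m > 0                                           // m ends at 1 if coprime, at 0 otherwise
}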

fn prime_pairs_lohi(n : u32) {               // for u32 input values
  if n&1 == 1 || n < 4 { return println!("Input not even n > 2"); }
  if n <= 6 { return println!("[{}, {}]\n[{}, {}]\n[{}, {}]",n,1,n/2,n/2,n/2,n/2); };

  // generate the low-half-residues (lhr) r < n/2
  let (ndiv2, rhi) = (n/2, n-2);             // lhr:hhr midpoint, max residue limit
  let mut lhr: Vec<u32> = (3..ndiv2).step_by(2).filter(|&r| coprime(r, n)).collect();

  // identify and store all powers and cross-products of the lhr members < n-2
  let mut lhr_del = Vec::with_capacity(lhr.len()); // lhr multiples, not part of a pcp
  for i in 1..lhr.len()-1 {                  // iterate thru lhr to find prime multiples
    let (mut j, r) = (i, lhr[i-1]);          // for current residue
    let rmax = rhi / r;                      // ri can't multiply r with values > this
    if r < rmax { lhr_del.push(if r*r < ndiv2 {r*r} else {n - r*r} ); } // for r^2 multiples
    if lhr[i] > rmax { break }               // exit if product of consecutive r’s > n-2
    while lhr[j] <= rmax {                   // stop for r if cross-product with ri > n-2
      lhr_del.push(if r*lhr[j] < ndiv2 {r*lhr[j]} else {n - r*lhr[j]}); // store value if < n-2
      j += 1;                                // get next lhr value
  } }

  lhr_del.sort_unstable();                   // remove from lhr its lhr mults, pcp remain
  lhr.retain(|&m| !lhr_del.binary_search(&m).is_ok());
  let lcnt = lhr.len();                      // number of pcp prime pairs
  println!("[{}, {}]", n, lcnt);             // show n and pcp prime pairs count
  println!("[{}, {}]", lhr[0],n-lhr[0]);     // show first pcp prime pair of n
  println!("[{}, {}]", lhr[lcnt-1], n-lhr[lcnt-1]); // show last  pcp prime pair of n
}

fn main() {
  let n: u32 = std::env::args()
    .nth(1).expect("missing count argument")
    .replace('_', "").parse().expect("one input");

  let start = Instant::now();
  prime_pairs_lohi(n);
  println!("{:?}", start.elapsed());
}
---------------------------------------------------------------------------------------------
/*
   Rust 1.84.1
   Can compile as: $ cargo build --release
   or: $ RUSTFLAGS="-C opt-level=3 -C debuginfo=0 -C target-cpu=native" cargo build --release
   Run as: $ ./prime_pairs_lohi 123_456
*/

use std::time::Instant;

struct SetDiff<'a> {
  r1: &'a [u32],
  r2: &'a [u32],
}

impl<'a> SetDiff<'a> {
  fn adjust_pos(&mut self) {
    while !self.r1.is_empty() {
      if self.r2.is_empty() || self.r1[0] < self.r2[0] {
        break;
      } else if self.r2[0] < self.r1[0] {
        self.r2 = &self.r2[1..];
      } else {
        self.r1 = &self.r1[1..];
        self.r2 = &self.r2[1..];
  } } }

  fn new(r1: &'a [u32], r2: &'a [u32]) -> Self {
    let mut s = SetDiff{ r1, r2 };
    s.adjust_pos();
    s
} }

impl<'a> Iterator for SetDiff<'a> {
  type Item = u32;

  fn next(&mut self) -> Option<Self::Item> {
    let val = self.r1.first().copied();   // next kept value, if any
    if val.is_some() {
      self.r1 = &self.r1[1..];
    }
    self.adjust_pos();                     // skip past any values that are in r2
    val
} }

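// true if gcd(m, n) == 1, using the Euclidean algorithm
fn coprime(mut m: u32, mut n: u32) -> bool {
  while m|1 != 1 { let t = m; m = n % m; n = t }  // loop while m > 1 (m|1 == 1 only for 0 and 1)
  m > 0                                           // m ends at 1 if coprime, at 0 otherwise
}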

fn prime_pairs_lohi(n : u32) {               // for u32 input values
  if n&1 == 1 || n < 4 { return println!("Input not even n > 2"); }
  if n <= 6 { return println!("[{}, {}]\n[{}, {}]\n[{}, {}]",n,1,n/2,n/2,n/2,n/2); };

  // generate the low-half-residues (lhr) r < n/2
  let (ndiv2, rhi) = (n/2, n-2);             // lhr:hhr midpoint, max residue limit
  let mut lhr: Vec<u32> = (3..ndiv2).step_by(2).filter(|&r| coprime(r, n)).collect();

  // identify and store all powers and cross-products of the lhr members < n-2
  let mut lhr_del = Vec::with_capacity(lhr.len()); // lhr multiples, not part of a pcp
  for i in 1..lhr.len()-1 {                  // iterate thru lhr to find prime multiples
    let (mut j, r) = (i, lhr[i-1]);          // for current residue
    let rmax = rhi / r;                      // ri can't multiply r with values > this
    if r < rmax { lhr_del.push(if r*r < ndiv2 {r*r} else {n - r*r} ); } // for r^2 multiples
    if lhr[i] > rmax { break }               // exit if product of consecutive r’s > n-2
    while lhr[j] <= rmax {                   // stop for r if cross-product with ri > n-2
      lhr_del.push(if r*lhr[j] < ndiv2 {r*lhr[j]} else {n - r*lhr[j]}); // store value if < n-2
      j += 1;                                // get next lhr value
  } }

  lhr_del.sort_unstable();                   // remove from lhr its lhr mults, pcp remain
  let lhr: Vec<u32> = SetDiff::new(&lhr, &lhr_del).collect();
  let lcnt = lhr.len();                      // number of pcp prime pairs
  println!("[{}, {}]", n, lcnt);             // show n and pcp prime pairs count
  println!("[{}, {}]", lhr[0],n-lhr[0]);     // show first pcp prime pair of n
  println!("[{}, {}]", lhr[lcnt-1], n-lhr[lcnt-1]); // show last  pcp prime pair of n
}

fn main() {
  let n: u32 = std::env::args()
    .nth(1).expect("missing count argument")
    .replace('_', "").parse().expect("one input");

  let start = Instant::now();
  prime_pairs_lohi(n);
  println!("{:?}", start.elapsed());
}

Removing one dataset from another is a frequent and important operation in machine learning, statistics, data analytics, etc. Having a very fast, memory-efficient way to do this in Crystal would make it even more attractive for use in those fields.


Sadly, the SetDiff algorithm cannot be used as a direct replacement for lhr -= lhr_del, because it requires the elements to be sorted. So I'm not sure it is a good addition to the standard library.
As a sketch, here is an implementation for Indexable values (with some macro magic it could be modified to also accept Iterable, but creating an Iterator adds overhead anyway).

struct SetDiff(T)
  @r1 : Indexable(T)
  @r2 : Indexable(T)

  def initialize(@r1, @r2)
  end

  include Enumerable(T)

  def each(&)
    r1i = 0
    r2i = 0
    r1size = @r1.size
    r2size = @r2.size
    loop do
      # adjust pos
      loop do
        return if r1i >= r1size
        break if r2i >= r2size || @r1[r1i] < @r2[r2i]
        if @r2[r2i] < @r1[r1i]
          r2i += 1
        else
          r1i += 1
          r2i += 1
        end
      end
      yield @r1[r1i]
      r1i += 1
    end
  end
end

def set_diff(a, b, &)
  SetDiff.new(a, b).each { |x| yield x }  # pass each kept element on to the caller's block
end

def set_diff(a, b)
  SetDiff.new(a, b).to_a
end

puts set_diff([1, 2, 3], [2, 4])

https://carc.in/#/r/hpel
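That sorting requirement is what separates SetDiff from lhr -= lhr_del, which accepts unsorted arrays. For contrast, here is a minimal Rust sketch of an order-preserving difference that needs no sorting. It is illustrative only: diff_hashed is a hypothetical name, and this is not a claim about how Crystal's Array#- is implemented. It collects the values to delete into a HashSet and filters the other list against it, roughly O(n + m) expected time, paid for with the memory of the extra set. Unlike the merge-based SetDiff, it removes every occurrence of a deleted value from the kept list.

```rust
use std::collections::HashSet;
use std::hash::Hash;

/// Order-preserving difference: keeps each element of `keep` that is not in `del`.
/// Neither input needs to be sorted; every matching occurrence in `keep` is removed.
fn diff_hashed<T: Copy + Eq + Hash>(keep: &[T], del: &[T]) -> Vec<T> {
    let del_set: HashSet<T> = del.iter().copied().collect();
    keep.iter().copied().filter(|x| !del_set.contains(x)).collect()
}
```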

I wasn't suggesting it replace lhr -= lhr_del, only asking whether it could have a place in Crystal. The Set class was my first thought for where it could be added.

Also, I see this as something for Crystal 2.0 anyway.

Rust and D are currently faster at the same tasks. I suspect it has to do with their memory models, but I'm not a compiler person, so that's just my speculation. Still, Crystal needs to get closer to their performance and memory efficiency. Since Rust and Crystal both use LLVM as their backend (as does LDC for D), hopefully that can be done in 2.0.