Schnelles AVX512 Modulo bei gleichem Divisor

Lesezeit: 14 Minuten

Benutzeravatar von Nuutti
Nuutti

Ich habe versucht, Teiler für potenzielle Primzahlen (Zahl der Form n!+-1) zu finden, und weil ich kürzlich eine Skylake-X-Workstation gekauft habe, dachte ich, dass ich mit den AVX512-Anweisungen etwas schneller werden könnte.

Der Algorithmus ist einfach und der Hauptschritt besteht darin, Modulo wiederholt in Bezug auf denselben Divisor zu nehmen. Die Hauptsache ist, einen großen Bereich von n Werten zu durchlaufen. Hier ist ein naiver Ansatz, der in c geschrieben ist (P ist eine Primzahltabelle):

uint64_t factorial_naive(uint64_t const nmin, uint64_t const nmax, const uint64_t *restrict P)
{
uint64_t n, i, residue;
for (i = 0; i < APP_BUFLEN; i++){
    residue = 2;
    for (n=3; n <= nmax; n++){
        residue *=  n;
        residue %= P[i];
        // Lets check if we found factor
        if (nmin <= n){
            if( residue == 1){
                report_factor(n, -1, P[i]);
            }
            if(residue == P[i]- 1){
                report_factor(n, 1, P[i]);
            }
        }
    }
}
return EXIT_SUCCESS;
}

Hier besteht die Idee darin, einen großen Bereich von n, zB 1.000.000 -> 10.000.000, mit demselben Teilersatz zu vergleichen. Also werden wir den Modulo-Respekt mehrere Millionen Mal auf denselben Divisor setzen. Die Verwendung von DIV ist sehr langsam, daher gibt es je nach Umfang der Berechnungen mehrere mögliche Ansätze. Hier ist in meinem Fall n höchstwahrscheinlich kleiner als 10 ^ 7 und der potenzielle Teiler p ist kleiner als 10.000 G (< 10 ^ 13). Zahlen sind also kleiner als 64 Bit und auch kleiner als 53 Bit!, aber das Produkt von der maximale Rest (p-1) mal n ist größer als 64 Bit. Also dachte ich, dass die einfachste Version der Montgomery-Methode nicht funktioniert, weil wir Modulo von einer Zahl nehmen, die größer als 64-Bit ist.

Ich habe einen alten Code für Power PC gefunden, bei dem FMA verwendet wurde, um ein genaues Produkt mit bis zu 106 Bit (schätze ich) zu erhalten, wenn Doubles verwendet werden. Also habe ich diesen Ansatz auf AVX 512 Assembler (Intel Intrinsics) umgestellt. Hier ist eine einfache Version der FMA-Methode, die auf der Arbeit von Dekker (1971), dem Dekker-Produkt und der FMA-Version von TwoProduct basiert. Dies sind nützliche Wörter, wenn Sie versuchen, Gründe dafür zu finden / zu googeln. Auch dieser Ansatz wurde in diesem Forum (zB hier) diskutiert.

int64_t factorial_FMA(uint64_t const nmin, uint64_t const nmax, const uint64_t *restrict P)
{
uint64_t n, i;
double prime_double, prime_double_reciprocal, quotient, residue;
double nr, n_double, prime_times_quotient_high, prime_times_quotient_low;

for (i = 0; i < APP_BUFLEN; i++){
    residue = 2.0;
    prime_double = (double)P[i];
    prime_double_reciprocal = 1.0 / prime_double;
    n_double = 3.0;
    for (n=3; n <= nmax; n++){
        nr =  n_double * residue;
        quotient = fma(nr, prime_double_reciprocal, rounding_constant);
        quotient -= rounding_constant;
        prime_times_quotient_high= prime_double * quotient;
        prime_times_quotient_low = fma(prime_double, quotient, -prime_times_quotient_high);
        residue = fma(residue, n, -prime_times_quotient_high) - prime_times_quotient_low;

        if (residue < 0.0) residue += prime_double;
        n_double += 1.0;

        // Lets check if we found factor
        if (nmin <= n){
            if( residue == 1.0){
                report_factor(n, -1, P[i]);
            }
            if(residue == prime_double - 1.0){
                report_factor(n, 1, P[i]);
            }
        }
    }
}
return EXIT_SUCCESS;
}

Hier habe ich eine magische Konstante verwendet

static const double rounding_constant = 6755399441055744.0; 

das ist 2^51 + 2^52 magische Zahl für Doppel.

Ich habe dies in AVX512 (32 potenzielle Teiler pro Schleife) konvertiert und das Ergebnis mit IACA analysiert. Es wurde mitgeteilt, dass Durchsatzengpass: Backend und Backend-Zuweisung aufgrund nicht verfügbarer Zuweisungsressourcen ins Stocken geraten waren. Ich bin nicht sehr erfahren mit Assembler, also ist meine Frage, ob ich irgendetwas tun kann, um dies zu beschleunigen und diesen Backend-Engpass zu lösen?

AVX512-Code ist hier und kann auch von gefunden werden github

uint64_t factorial_AVX512_unrolled_four(uint64_t const nmin, uint64_t const nmax, const uint64_t *restrict P)
{
// we are trying to find a factor for a factorial numbers : n! +-1
//nmin is minimum n we want to report and nmax is maximum. P is table of primes
// we process 32 primes in one loop.
// naive version of the algorithm is int he function factorial_naive
// and simple version of the FMA based approach in the function factorial_simpleFMA

const double one_table[8] __attribute__ ((aligned(64))) ={1.0, 1.0, 1.0,1.0,1.0,1.0,1.0,1.0};


uint64_t n;

__m512d zero, rounding_const, one, n_double;

__m512i prime1, prime2, prime3, prime4;

__m512d residue1, residue2, residue3, residue4;
__m512d prime_double_reciprocal1, prime_double_reciprocal2, prime_double_reciprocal3, prime_double_reciprocal4;
__m512d quotient1, quotient2, quotient3, quotient4;
__m512d prime_times_quotient_high1, prime_times_quotient_high2, prime_times_quotient_high3, prime_times_quotient_high4;
__m512d prime_times_quotient_low1, prime_times_quotient_low2, prime_times_quotient_low3, prime_times_quotient_low4;
__m512d nr1, nr2, nr3, nr4;
__m512d prime_double1, prime_double2, prime_double3, prime_double4;
__m512d prime_minus_one1, prime_minus_one2, prime_minus_one3, prime_minus_one4;

__mmask8 negative_reminder_mask1, negative_reminder_mask2, negative_reminder_mask3, negative_reminder_mask4;
__mmask8 found_factor_mask11, found_factor_mask12, found_factor_mask13, found_factor_mask14;
__mmask8 found_factor_mask21, found_factor_mask22, found_factor_mask23, found_factor_mask24;

// load data and initialize cariables for loop
rounding_const = _mm512_set1_pd(rounding_constant);
one = _mm512_load_pd(one_table);
zero = _mm512_setzero_pd ();

// load primes used to sieve
prime1 = _mm512_load_epi64((__m512i *) &P[0]);
prime2 = _mm512_load_epi64((__m512i *) &P[8]);
prime3 = _mm512_load_epi64((__m512i *) &P[16]);
prime4 = _mm512_load_epi64((__m512i *) &P[24]);

// convert primes to double
prime_double1 = _mm512_cvtepi64_pd (prime1); // vcvtqq2pd
prime_double2 = _mm512_cvtepi64_pd (prime2); // vcvtqq2pd
prime_double3 = _mm512_cvtepi64_pd (prime3); // vcvtqq2pd
prime_double4 = _mm512_cvtepi64_pd (prime4); // vcvtqq2pd

// calculates 1.0/ prime
prime_double_reciprocal1 = _mm512_div_pd(one, prime_double1);
prime_double_reciprocal2 = _mm512_div_pd(one, prime_double2);
prime_double_reciprocal3 = _mm512_div_pd(one, prime_double3);
prime_double_reciprocal4 = _mm512_div_pd(one, prime_double4);

// for comparison if we have found factors for n!+1
prime_minus_one1 = _mm512_sub_pd(prime_double1, one);
prime_minus_one2 = _mm512_sub_pd(prime_double2, one);
prime_minus_one3 = _mm512_sub_pd(prime_double3, one);
prime_minus_one4 = _mm512_sub_pd(prime_double4, one);

// residue init
residue1 =  _mm512_set1_pd(2.0);
residue2 =  _mm512_set1_pd(2.0);
residue3 =  _mm512_set1_pd(2.0);
residue4 =  _mm512_set1_pd(2.0);

// double counter init
n_double = _mm512_set1_pd(3.0);

// main loop starts here. typical value for nmax can be 5,000,000 -> 10,000,000

for (n=3; n<=nmax; n++) // main loop
{

    // timings for instructions:
    // _mm512_load_epi64 = vmovdqa64 : L 1, T 0.5
    // _mm512_load_pd = vmovapd : L 1, T 0.5
    // _mm512_set1_pd
    // _mm512_div_pd = vdivpd : L 23, T 16
    // _mm512_cvtepi64_pd = vcvtqq2pd : L 4, T 0,5

    // _mm512_mul_pd = vmulpd :  L 4, T 0.5
    // _mm512_fmadd_pd = vfmadd132pd, vfmadd213pd, vfmadd231pd :  L 4, T 0.5
    // _mm512_fmsub_pd = vfmsub132pd, vfmsub213pd, vfmsub231pd : L 4, T 0.5
    // _mm512_sub_pd = vsubpd : L 4, T 0.5
    // _mm512_cmplt_pd_mask = vcmppd : L ?, Y 1
    // _mm512_mask_add_pd = vaddpd :  L 4, T 0.5
    // _mm512_cmpeq_pd_mask = vcmppd L ?, Y 1
    // _mm512_kor = korw L 1, T 1

    // nr = residue *  n
    nr1 = _mm512_mul_pd (residue1, n_double);
    nr2 = _mm512_mul_pd (residue2, n_double);
    nr3 = _mm512_mul_pd (residue3, n_double);
    nr4 = _mm512_mul_pd (residue4, n_double);

    // quotient = nr * 1.0/ prime_double + rounding_constant
    quotient1 = _mm512_fmadd_pd(nr1, prime_double_reciprocal1, rounding_const);
    quotient2 = _mm512_fmadd_pd(nr2, prime_double_reciprocal2, rounding_const);
    quotient3 = _mm512_fmadd_pd(nr3, prime_double_reciprocal3, rounding_const);
    quotient4 = _mm512_fmadd_pd(nr4, prime_double_reciprocal4, rounding_const);

    // quotient -= rounding_constant, now quotient is rounded to integer
    // countient should be at maximum nmax (10,000,000)
    quotient1 = _mm512_sub_pd(quotient1, rounding_const);
    quotient2 = _mm512_sub_pd(quotient2, rounding_const);
    quotient3 = _mm512_sub_pd(quotient3, rounding_const);
    quotient4 = _mm512_sub_pd(quotient4, rounding_const);

    // now we calculate high and low for prime * quotient using decker product (FMA).
    // quotient is calculated using approximation but this is accurate for given quotient
    prime_times_quotient_high1 = _mm512_mul_pd(quotient1, prime_double1);
    prime_times_quotient_high2 = _mm512_mul_pd(quotient2, prime_double2);
    prime_times_quotient_high3 = _mm512_mul_pd(quotient3, prime_double3);
    prime_times_quotient_high4 = _mm512_mul_pd(quotient4, prime_double4);


    prime_times_quotient_low1 = _mm512_fmsub_pd(quotient1, prime_double1, prime_times_quotient_high1);
    prime_times_quotient_low2 = _mm512_fmsub_pd(quotient2, prime_double2, prime_times_quotient_high2);
    prime_times_quotient_low3 = _mm512_fmsub_pd(quotient3, prime_double3, prime_times_quotient_high3);
    prime_times_quotient_low4 = _mm512_fmsub_pd(quotient4, prime_double4, prime_times_quotient_high4);

    // now we calculate new reminder using decker product and using original values
    // we subtract above calculated prime * quotient (quotient is aproximation)

    residue1 = _mm512_fmsub_pd(residue1, n_double, prime_times_quotient_high1);
    residue2 = _mm512_fmsub_pd(residue2, n_double, prime_times_quotient_high2);
    residue3 = _mm512_fmsub_pd(residue3, n_double, prime_times_quotient_high3);
    residue4 = _mm512_fmsub_pd(residue4, n_double, prime_times_quotient_high4);

    residue1 = _mm512_sub_pd(residue1, prime_times_quotient_low1);
    residue2 = _mm512_sub_pd(residue2, prime_times_quotient_low2);
    residue3 = _mm512_sub_pd(residue3, prime_times_quotient_low3);
    residue4 = _mm512_sub_pd(residue4, prime_times_quotient_low4);

    // lets check if reminder < 0
    negative_reminder_mask1 = _mm512_cmplt_pd_mask(residue1,zero);
    negative_reminder_mask2 = _mm512_cmplt_pd_mask(residue2,zero);
    negative_reminder_mask3 = _mm512_cmplt_pd_mask(residue3,zero);
    negative_reminder_mask4 = _mm512_cmplt_pd_mask(residue4,zero);

    // we and prime back to reminder using mask if it was < 0
    residue1 = _mm512_mask_add_pd(residue1, negative_reminder_mask1, residue1, prime_double1);
    residue2 = _mm512_mask_add_pd(residue2, negative_reminder_mask2, residue2, prime_double2);
    residue3 = _mm512_mask_add_pd(residue3, negative_reminder_mask3, residue3, prime_double3);
    residue4 = _mm512_mask_add_pd(residue4, negative_reminder_mask4, residue4, prime_double4);

    n_double = _mm512_add_pd(n_double,one);

    // if we are below nmin then we continue next iteration
    if (n < nmin) continue;

    // Lets check if we found any factors, residue 1 == n!-1
    found_factor_mask11 = _mm512_cmpeq_pd_mask(one, residue1);
    found_factor_mask12 = _mm512_cmpeq_pd_mask(one, residue2);
    found_factor_mask13 = _mm512_cmpeq_pd_mask(one, residue3);
    found_factor_mask14 = _mm512_cmpeq_pd_mask(one, residue4);

    // residue prime -1  == n!+1
    found_factor_mask21 = _mm512_cmpeq_pd_mask(prime_minus_one1, residue1);
    found_factor_mask22 = _mm512_cmpeq_pd_mask(prime_minus_one2, residue2);
    found_factor_mask23 = _mm512_cmpeq_pd_mask(prime_minus_one3, residue3);
    found_factor_mask24 = _mm512_cmpeq_pd_mask(prime_minus_one4, residue4);     

    if (found_factor_mask12 | found_factor_mask11 | found_factor_mask13 | found_factor_mask14 |
    found_factor_mask21 | found_factor_mask22 | found_factor_mask23|found_factor_mask24)
    { // we find factor very rarely

        double *residual_list1 = (double *) &residue1;
        double *residual_list2 = (double *) &residue2;
        double *residual_list3 = (double *) &residue3;
        double *residual_list4 = (double *) &residue4;

        double *prime_list1 = (double *) &prime_double1;
        double *prime_list2 = (double *) &prime_double2;
        double *prime_list3 = (double *) &prime_double3;
        double *prime_list4 = (double *) &prime_double4;



        for (int i=0; i <8; i++){
            if( residual_list1[i] == 1.0)
            {
                report_factor((uint64_t) n, -1, (uint64_t) prime_list1[i]);
            }
            if( residual_list2[i] == 1.0)
            {
                report_factor((uint64_t) n, -1, (uint64_t) prime_list2[i]);
            }
            if( residual_list3[i] == 1.0)
            {
                report_factor((uint64_t) n, -1, (uint64_t) prime_list3[i]);
            }
            if( residual_list4[i] == 1.0)
            {
                report_factor((uint64_t) n, -1, (uint64_t) prime_list4[i]);
            }

            if(residual_list1[i] == (prime_list1[i] - 1.0))
            {
                report_factor((uint64_t) n, 1, (uint64_t) prime_list1[i]);
            }
            if(residual_list2[i] == (prime_list2[i] - 1.0))
            {
                report_factor((uint64_t) n, 1, (uint64_t) prime_list2[i]);
            }
            if(residual_list3[i] == (prime_list3[i] - 1.0))
            {
                report_factor((uint64_t) n, 1, (uint64_t) prime_list3[i]);
            }
            if(residual_list4[i] == (prime_list4[i] - 1.0))
            {
                report_factor((uint64_t) n, 1, (uint64_t) prime_list4[i]);
            }
        }
    }

}

return EXIT_SUCCESS;
}

  • Stimmen Sie für eine detaillierte und gut gestellte Frage ab. Willkommen bei Stapelüberlauf!

    – fuz

    17. Dezember 2017 um 14:07 Uhr

  • Nur aus Neugier, tut dies if(residue == prime_double - 1.0) arbeiten zuverlässig (==)? Es ist mir nicht klar, wenn ich nur die Quelle lese, dass die Werte nur ganzzahlig und innerhalb der doppelten Mantissengrenzen bleiben, sodass keine niedrigen Ziffern verloren gehen. Aber es kann sein, hängt davon ab fma Implementierung … fühlt sich für mich immer noch zerbrechlich genug an, um einen zusätzlichen Quellenkommentar wert zu sein, warum es funktionieren sollte.

    – Ped7g

    17. Dezember 2017 um 18:33 Uhr

  • @Nuutti: Ein Back-End-Engpass beim FMA-Durchsatz ist gut, es bedeutet, dass Sie den FMA-Durchsatz der Maschine sättigen, anstatt einen Engpass bei der Latenz oder dem Front-End zu verursachen. (Ich denke, das meinen Sie mit “Ressourcenzuweisung”, aber posten Sie die IACA-Zusammenfassungsausgabe.) Es wird immer einen Engpass irgendeiner Art geben. Für die korrekte Anwendung von Brute-Force ist der FMA-Durchsatz (Port0 / Port5 gesättigt) der Engpass, den Sie erreichen möchten. Wenn Sie schneller laufen, müssen Sie Ihre Operationen neu kombinieren, um mehr FMA und weniger Add / Mul auszuführen oder auf andere Weise Operationen zu sparen, aber das ist mit exakten Ergebnissen möglicherweise nicht möglich.

    – Peter Cordes

    17. Dezember 2017 um 19:03 Uhr

  • IACA_Trace_Analyse: github.com/NudeSurfer/Factoring/blob/master/… IACA-Analyse: github.com/NudeSurfer/Factoring/blob/master/IACA_analysis.txt

    – Nuutti

    17. Dezember 2017 um 20:57 Uhr


  • Außerdem müssen Sie nicht so schnell verzweigen. Unter der Annahme, dass die Wahrscheinlichkeit, dass ein bestimmter Faktor erfolgreich ist, extrem gering ist, können Sie einfach alle Masken zusammen verodern und sie alle 1000 überprüfen? Iterationen? Wenn es dann einen Erfolg zeigt, können Sie den Block erneut ausführen, um genau herauszufinden, um welchen Faktor es sich handelt.

    – Mystisch

    18. Dezember 2017 um 16:47 Uhr


Wie einige Kommentatoren vorgeschlagen haben: Ein “Backend” -Engpass ist das, was Sie für diesen Code erwarten würden. Das deutet darauf hin, dass Sie die Dinge ziemlich gut ernähren, was Sie wollen.

Wenn man sich den Bericht ansieht, sollte es in diesem Abschnitt eine Gelegenheit geben:

    // Lets check if we found any factors, residue 1 == n!-1
    found_factor_mask11 = _mm512_cmpeq_pd_mask(one, residue1);
    found_factor_mask12 = _mm512_cmpeq_pd_mask(one, residue2);
    found_factor_mask13 = _mm512_cmpeq_pd_mask(one, residue3);
    found_factor_mask14 = _mm512_cmpeq_pd_mask(one, residue4);

    // residue prime -1  == n!+1
    found_factor_mask21 = _mm512_cmpeq_pd_mask(prime_minus_one1, residue1);
    found_factor_mask22 = _mm512_cmpeq_pd_mask(prime_minus_one2, residue2);
    found_factor_mask23 = _mm512_cmpeq_pd_mask(prime_minus_one3, residue3);
    found_factor_mask24 = _mm512_cmpeq_pd_mask(prime_minus_one4, residue4);     

    if (found_factor_mask12 | found_factor_mask11 | found_factor_mask13 | found_factor_mask14 |
    found_factor_mask21 | found_factor_mask22 | found_factor_mask23|found_factor_mask24)

Aus der IACA-Analyse:

|   1      | 1.0         |      |             |             |      |      |      |      | kmovw r11d, k0
|   1      | 1.0         |      |             |             |      |      |      |      | kmovw eax, k1
|   1      | 1.0         |      |             |             |      |      |      |      | kmovw ecx, k2
|   1      | 1.0         |      |             |             |      |      |      |      | kmovw esi, k3
|   1      | 1.0         |      |             |             |      |      |      |      | kmovw edi, k4
|   1      | 1.0         |      |             |             |      |      |      |      | kmovw r8d, k5
|   1      | 1.0         |      |             |             |      |      |      |      | kmovw r9d, k6
|   1      | 1.0         |      |             |             |      |      |      |      | kmovw r10d, k7
|   1      |             | 1.0  |             |             |      |      |      |      | or r11d, eax
|   1      |             |      |             |             |      |      | 1.0  |      | or r11d, ecx
|   1      |             | 1.0  |             |             |      |      |      |      | or r11d, esi
|   1      |             |      |             |             |      |      | 1.0  |      | or r11d, edi
|   1      |             | 1.0  |             |             |      |      |      |      | or r11d, r8d
|   1      |             |      |             |             |      |      | 1.0  |      | or r11d, r9d
|   1*     |             |      |             |             |      |      |      |      | or r11d, r10d

Der Prozessor verschiebt die resultierenden Vergleichsmasken (k0-k7) für die “oder”-Operation in reguläre Register. Sie sollten in der Lage sein, diese Züge zu eliminieren UND das „oder“-Rollup in 6ops vs. 8 durchzuführen.

HINWEIS: Die Typen found_factor_mask sind definiert als __mmask8wo sie sein sollten __mask16 (16x Double Floats in einem 512-Bit-Effekt). Dadurch könnte der Compiler einige Optimierungen vornehmen. Wenn nicht, gehen Sie zur Versammlung, wie ein Kommentator bemerkte.

Und verwandt: Welcher Bruchteil der Iterationen löst diese Or-Mask-Klausel aus? Wie ein anderer Kommentator bemerkte, sollten Sie in der Lage sein, dies mit einer kumulierenden “oder” -Operation aufzurollen. Überprüfen Sie den akkumulierten “oder”-Wert am Ende jeder abgewickelten Iteration (oder nach N Iterationen), und wenn er “wahr” ist, gehen Sie zurück und wiederholen Sie die Werte, um herauszufinden, welche n-Werte ihn ausgelöst haben.

(Und Sie können innerhalb der “Rolle” binär suchen, um den passenden n-Wert zu finden – das könnte einen gewissen Gewinn bringen).

Als nächstes sollten Sie in der Lage sein, diesen Mid-Loop-Check loszuwerden:

    // if we are below nmin then we continue next iteration, we
    if (n < nmin) continue;

Was hier auftaucht:

|   1*     |             |      |             |             |      |      |      |      | cmp r14, 0x3e8
|   0*F    |             |      |             |             |      |      |      |      | jb 0x229

Es ist vielleicht kein großer Gewinn, da der Prädiktor (wahrscheinlich) diesen (meistens) richtig macht, aber Sie sollten einige Gewinne erzielen, indem Sie zwei unterschiedliche Schleifen für zwei “Phasen” haben:

  • n=3 bis n=nmin-1
  • n=nmin und darüber hinaus

Selbst wenn Sie einen Zyklus gewinnen, sind das 3%. Und da dies im Allgemeinen mit der großen „Oder“-Operation oben zusammenhängt, ist möglicherweise mehr Cleverness darin zu finden.

  • Das Entfernen des Zweigs und das Aufteilen der Schleife in zwei Phasen wird wahrscheinlich überhaupt nicht helfen, wenn der Code wirklich ist Backend gebunden, selbst wenn es genommen wird und möglicherweise einige Front-End-Blasen erzeugt. cmp/jcc läuft auf Port 6, der keine Vektor-ALUs hat. Aber einen Versuch ist es wert, und ein geringerer uop-Durchsatz macht es etwas hyperthreading-freundlicher, zu den sehr geringen Kosten eines etwas größeren uop-Cache-Fußabdrucks.

    – Peter Cordes

    25. September 2018 um 0:35 Uhr


1390610cookie-checkSchnelles AVX512 Modulo bei gleichem Divisor

This website is using cookies to improve the user-friendliness. You agree by using the website further.

Privacy policy