NN-512

Back

Index

Files

Top || go.mod

module NN-512

go 1.16

Top || main.go

// NN-512 (https://NN-512.com)
//
// Copyright (C) 2019 [
// 37ef ced3 3727 60b4
// 3c29 f9c6 dc30 d518
// f4f3 4106 6964 cab4
// a06f c1a3 83fd 090e
// ]
//
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
//
// 1. Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in
// the documentation and/or other materials provided with the
// distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

package main

import (
"NN-512/internal/compile"
"NN-512/internal/doc"
"NN-512/internal/example"
"NN-512/internal/serve"
"NN-512/internal/version"
"errors"
"io/ioutil"
"os"
"path/filepath"
"strconv"
"strings"
)

// Text fragments shared by the help and error messages of every command.
const (
	newline = "\n"
	space   = " "
	indent  = "    " // four spaces
	usage   = "\n" + "Usage:" + "\n\n" + indent + "NN-512" + " "
)

// cmdCompile implements the "compile" command. It expects exactly two
// arguments after the command name, GRAPH and DIR: it reads the graph
// language text from GRAPH, compiles it, and writes the generated header and
// source to DIR as <name>.h and <name>.c, where <name> is reported by the
// compiler. With any other argument count it returns the usage text as an
// error.
//
// A GRAPH of "-" means standard input. It is read directly from os.Stdin
// rather than by opening the path /dev/stdin, which is not available on
// every platform.
func cmdCompile() error {
	if len(os.Args) == 4 {
		var (
			text []byte
			err  error
		)
		if from := os.Args[2]; from == "-" {
			text, err = ioutil.ReadAll(os.Stdin)
		} else {
			text, err = ioutil.ReadFile(from)
		}
		if err != nil {
			return err
		}
		result, err := compile.Compile(string(text))
		if err != nil {
			return err
		}
		prefix := filepath.Join(os.Args[3], result.Name)
		// The process umask still applies to these permission bits.
		const perm os.FileMode = 0666
		if err := ioutil.WriteFile(prefix+".h", result.H, perm); err != nil {
			return err
		}
		return ioutil.WriteFile(prefix+".c", result.C, perm)
	}
	return errors.New(usage +
		os.Args[1] + space + "GRAPH" + space + "DIR" + newline +
		newline +
		"The GRAPH argument specifies an input file that contains a" + newline +
		"graph language description of a neural net. - means stdin." + newline +
		newline +
		indent + "Example: densenet.graph" + newline +
		indent + "Example: ../graphs/denseNet" + newline +
		indent + "Example: /opt/nets/DenseNet" + newline +
		indent + "Example: -" + newline +
		newline +
		"The DIR argument specifies an output directory where the" + newline +
		"generated C99 files will be written." + newline +
		newline +
		indent + "Example: ." + newline +
		indent + "Example: ../src" + newline +
		indent + "Example: /tmp/" + newline)
}

// cmdDoc implements the "doc" command: it writes the bytes returned by
// doc.Bytes to stdout. The command takes no arguments; if any are given the
// usage text is returned as an error.
func cmdDoc() error {
	if args := os.Args; len(args) > 2 {
		return errors.New(usage + args[1] + newline)
	}
	if _, err := os.Stdout.Write(doc.Bytes()); err != nil {
		return err
	}
	return nil
}

// cmdExample implements the "example" command: given a single NAME argument
// for which example.Generate returns a non-nil result, it writes that result
// to stdout. In every other case it returns the usage text, including the
// list from example.Names, as an error.
func cmdExample() error {
	args := os.Args
	if len(args) == 3 {
		if out := example.Generate(args[2]); out != nil {
			_, err := os.Stdout.Write(out)
			return err
		}
	}
	names := strings.Join(example.Names(), newline+indent)
	help := usage +
		args[1] + space + "NAME" + newline +
		newline +
		"The NAME argument can be:" + newline +
		newline +
		indent + names + newline
	return errors.New(help)
}

// cmdServe implements the "serve" command. It needs exactly five arguments
// after the command name (ADDR, CERT, KEY, SRC, BIN) and passes them, in that
// order, to serve.Website. With any other argument count it returns the
// usage text as an error.
func cmdServe() error {
	if args := os.Args; len(args) == 7 {
		return serve.Website(args[2], args[3], args[4], args[5], args[6])
	}
	help := usage +
		os.Args[1] + space +
		"ADDR" + space +
		"CERT" + space + "KEY" + space +
		"SRC" + space + "BIN" + newline +
		newline +
		"The ADDR argument specifies the TCP network addresses to" + newline +
		"listen on (in the form \"host:port\"). If the host spec is" + newline +
		"omitted, the server will listen on all available unicast" + newline +
		"and anycast IP addresses of the local system." + newline +
		newline +
		indent + "Example: :https" + newline +
		indent + "Example: :443" + newline +
		indent + "Example: 127.0.0.1:https" + newline +
		indent + "Example: localhost:4321" + newline +
		indent + "Example: 173.230.145.5:443" + newline +
		indent + "Example: [2600:3c01::2]:443" + newline +
		newline +
		"The CERT argument specifies a PEM-encoded HTTPS certificate" + newline +
		"file. If the certificate is signed by a certificate authority," + newline +
		"the file should consist of the server's certificate, then any" + newline +
		"intermediate certificates, then the authority's certificate" + newline +
		"(concatenated in that order)." + newline +
		newline +
		indent + "Example: fullchain.pem" + newline +
		indent + "Example: tls/cert.pem" + newline +
		indent + "Example: /root/certs/my.crt" + newline +
		newline +
		"The KEY argument specifies a PEM-encoded HTTPS private key" + newline +
		"file that matches the previously specified certificate." + newline +
		newline +
		indent + "Example: privkey.pem" + newline +
		indent + "Example: tls/key.pem" + newline +
		indent + "Example: /root/certs/my.key" + newline +
		newline +
		"The SRC argument specifies a directory that contains the Go" + newline +
		"source code of this program." + newline +
		newline +
		indent + "Example: /go/src/NN-512" + newline +
		indent + "Example: /go/src/NN-512/" + newline +
		indent + "Example: ../src/NN-512" + newline +
		newline +
		"The BIN argument specifies this program's executable file." + newline +
		newline +
		indent + "Example: /go/bin/NN-512" + newline +
		indent + "Example: ../bin/NN-512" + newline
	return errors.New(help)
}

// cmdVersion implements the "version" command: it writes version.Int in
// decimal, followed by a newline, to stdout. The command takes no arguments;
// if any are given the usage text is returned as an error.
func cmdVersion() error {
	if args := os.Args; len(args) > 2 {
		return errors.New(usage + args[1] + newline)
	}
	line := strconv.Itoa(version.Int) + newline
	_, err := os.Stdout.WriteString(line)
	return err
}

// cmds is the table of subcommands. Each entry holds the name that run
// matches against os.Args[1], a one-line description shown in the top-level
// usage text, and the function that carries the command out. The usage text
// lists the entries in the order they appear here.
var cmds = [...]struct {
name string
hint string
call func() error
}{
{"compile", "Read graph language for a neural net and write C99.", cmdCompile},
{"doc", "Write documentation for the graph language to stdout.", cmdDoc},
{"example", "Write graph language for an example neural net to stdout.", cmdExample},
{"serve", "Serve this program's website as HTML over HTTPS.", cmdServe},
{"version", "Write the version number of this program to stdout.", cmdVersion},
}

// run dispatches to the command named by os.Args[1] and returns its result.
// If no command is named, or the name is not in cmds, it returns the
// top-level usage text as an error, listing every command with its hint
// aligned in a column.
func run() error {
	if len(os.Args) >= 2 {
		want := os.Args[1]
		for i := range cmds {
			if c := &cmds[i]; c.name == want {
				return c.call()
			}
		}
	}
	// The hint column starts one indent past the longest command name.
	widest := 0
	for i := range cmds {
		if n := len(cmds[i].name); n > widest {
			widest = n
		}
	}
	column := widest + len(indent)
	var list strings.Builder
	for i := range cmds {
		c := &cmds[i]
		list.WriteString(indent)
		list.WriteString(c.name)
		list.WriteString(strings.Repeat(space, column-len(c.name)))
		list.WriteString(c.hint)
		list.WriteString(newline)
	}
	return errors.New(usage +
		"COMMAND" + newline +
		newline +
		"The COMMAND argument can be:" + newline +
		newline +
		list.String())
}

// main runs the selected command. On failure it prints the error to stderr
// and exits with status 1; otherwise it exits with status 0.
func main() {
	status := 0
	if err := run(); err != nil {
		// Nothing useful can be done if writing to stderr itself fails.
		_, _ = os.Stderr.WriteString(err.Error() + newline)
		status = 1
	}
	os.Exit(status)
}

// Compile-time assertion that uint is at least 64 bits wide: the constant
// 1<<63 does not fit in a 32-bit uint, so this declaration fails to build on
// such platforms.
var _ uint = 1 << 63

Top || internal/compile/compile.go

package compile

import (
"NN-512/internal/compile/author"
"NN-512/internal/compile/plan"
"NN-512/internal/raw"
"errors"
"fmt"
"math"
"sort"
"strings"
"unicode"
)

// Result is what Compile returns on success.
type Result struct {
// Name is taken from the Prefix field of the graph's Config node.
Name string
// H and C are the byte slices produced by the compile stages (copied from
// the state's h and c fields).
H, C []byte
}

// Compile parses graph language text and runs every compile stage over the
// parsed nodes. A parse error is returned as is; an error from a stage is
// returned with the prefix "compile failed: ".
func Compile(text string) (*Result, error) {
	parsed, err := raw.Parse(text)
	if err != nil {
		return nil, err
	}
	st := new(state)
	st.nodes = parsed
	if err = st.stages(); err != nil {
		return nil, errors.New("compile failed: " + err.Error())
	}
	res := &Result{
		Name: st.config.Prefix,
		H:    st.h,
		C:    st.c,
	}
	return res, nil
}

// origNameShield is an uppercase suffix appended to a tensor name before it
// is passed to anError. origName cuts a name at its last uppercase rune, so
// the shield is what gets cut off and the name in front of it is reported in
// full.
const origNameShield = "A"

// origName returns the part of tensor that precedes its last uppercase rune,
// dropping that rune and everything after it.
//
// If tensor contains no uppercase rune it is returned unchanged. Slicing
// with the -1 that strings.LastIndexFunc reports in that case would panic
// while an error message is being built.
func origName(tensor string) string {
	i := strings.LastIndexFunc(tensor, unicode.IsUpper)
	if i < 0 {
		return tensor
	}
	return tensor[:i]
}

// anError builds an error whose text is msg preceded by optional context.
//
// lines may hold zero, one or two source line numbers; more than two is a
// programming error and panics. One line, or two equal lines, yields the
// prefix "line N: ". Two different lines yield "lines A and B: " with the
// smaller number first. A non-empty tensor adds origName(tensor) + ": "
// after the line prefix.
func anError(msg, tensor string, lines ...int) error {
	prefix := ""
	switch len(lines) {
	case 0:
		// No line context.
	case 1:
		prefix = fmt.Sprintf("line %d: ", lines[0])
	case 2:
		lo, hi := lines[0], lines[1]
		if lo > hi {
			lo, hi = hi, lo
		}
		if lo == hi {
			prefix = fmt.Sprintf("line %d: ", lo)
		} else {
			prefix = fmt.Sprintf("lines %d and %d: ", lo, hi)
		}
	default:
		panic("bug")
	}
	if tensor != "" {
		prefix += origName(tensor) + ": "
	}
	return errors.New(prefix + msg)
}

// arc is one edge of the node graph as seen from one of its endpoints.
type arc struct {
// tensor is the name of the tensor carried by the edge.
tensor string
// attach is the index in state.nodes of the node at the other end.
attach int
}

// shape records the dimensions computed for one tensor.
type shape struct {
// tensor is the tensor's name.
tensor string
// nchw holds the four dimensions; the CxHxW error messages in stage6 print
// elements 1, 2 and 3 as channels, height and width.
nchw [4]int
}

// state is the working data shared by the compile stages. Each stage reads
// what earlier stages filled in and adds its own results.
type state struct {
// nodes are the parsed graph nodes; stage1 removes the Config node.
nodes []raw.Node
// config is the graph's single Config node (set by stage1).
config *raw.Config
// inputs and outputs are the indices in nodes of the Input and Output
// nodes (set by stage2).
inputs []int
outputs []int
// fanins[i] and fanouts[i] are the arcs to the producers and consumers of
// node i (set by stage3, edited by stage5).
fanins [][]arc
fanouts [][]arc
// shapes[i] are the shapes of the tensors associated with node i (set by
// stage6).
shapes [][]shape
// plan is built by stage7 and transformed by the stages after it.
plan plan.Plan
// h and c are returned to the caller as Result.H and Result.C.
h, c []byte
}

// stages lists the compile passes as method expressions, in the order in
// which the stages method runs them.
var stages = [...]func(*state) error{
(*state).stage1,
(*state).stage2,
(*state).stage3,
(*state).stage4,
(*state).stage5,
(*state).stage6,
(*state).stage7,
(*state).stage8,
(*state).stage9,
(*state).stage10,
(*state).stage11,
(*state).stage12,
(*state).stage13,
(*state).stage14,
(*state).stage15,
(*state).stage16,
(*state).stage17,
}

// stages runs every entry of the package-level stages table against st, in
// order, and stops at the first one that returns an error.
func (st *state) stages() error {
	for i := 0; i < len(stages); i++ {
		err := stages[i](st)
		if err != nil {
			return err
		}
	}
	return nil
}

// stage1 finds the Config node, stores it in st.config and removes it from
// st.nodes. It returns an error if the graph has no Config or more than one.
func (st *state) stage1() error {
// The range expression is evaluated once, so the loop keeps walking the
// original slice even after st.nodes is reassigned below. Slots after
// index i are not touched by the removal, so later nodes are still seen.
for i, node := range st.nodes {
if config, ok := node.(*raw.Config); ok {
if st.config != nil {
return anError("second Config", "", config.LineNum)
}
st.config = config
// Shift the i nodes that precede the Config one slot toward the end
// (overwriting the Config's slot), then drop the stale first slot.
// The relative order of the remaining nodes is preserved.
copy(st.nodes[1:], st.nodes[:i])
st.nodes = st.nodes[1:]
}
}
if st.config == nil {
return anError("no Config", "")
}
return nil
}

// stage2 collects the indices of the Input and Output nodes into st.inputs
// and st.outputs. It rejects two Inputs with the same ToTensor, two Outputs
// with the same FromTensor, an Output that reads an Input's tensor directly,
// and a graph without at least one Input and one Output.
func (st *state) stage2() error {
	const directly = "Input is directly connected to Output"
	// firstLine remembers where a tensor was first mentioned: the line number
	// is stored negated for an Input and as is for an Output. Zero means the
	// tensor has not been seen.
	firstLine := make(map[string]int)
	for idx, node := range st.nodes {
		switch nd := node.(type) {
		case *raw.Input:
			name := nd.ToTensor
			prev := firstLine[name]
			switch {
			case prev < 0:
				return anError("Inputs have the same ToTensor", name, -prev, nd.LineNum)
			case prev > 0:
				return anError(directly, name, prev, nd.LineNum)
			}
			firstLine[name] = -nd.LineNum
			st.inputs = append(st.inputs, idx)
		case *raw.Output:
			name := nd.FromTensor
			prev := firstLine[name]
			switch {
			case prev < 0:
				return anError(directly, name, -prev, nd.LineNum)
			case prev > 0:
				return anError("Outputs have the same FromTensor", name, prev, nd.LineNum)
			}
			firstLine[name] = nd.LineNum
			st.outputs = append(st.outputs, idx)
		}
	}
	switch {
	case len(st.inputs) == 0:
		return anError("no Input", "")
	case len(st.outputs) == 0:
		return anError("no Output", "")
	}
	return nil
}

// stage3 builds the node graph. For every node it fills st.fanins with one
// arc per distinct (tensor, producer) pair and st.fanouts with one arc per
// distinct (tensor, consumer) pair. It rejects a tensor that is produced
// more than once, a tensor that is consumed but never produced, and a node
// that consumes its own output.
func (st *state) stage3() error {
	count := len(st.nodes)
	consumers := make(map[string][]int, count)
	producer := make(map[string]int, count)
	for i, node := range st.nodes {
		for _, name := range node.FromTensors() {
			consumers[name] = append(consumers[name], i)
		}
		for _, name := range node.ToTensors() {
			if first, dup := producer[name]; dup {
				here, there := node.LineNumber(), st.nodes[first].LineNumber()
				return anError("tensor is produced more than once", name, here, there)
			}
			producer[name] = i
		}
	}
	st.fanins = make([][]arc, count)
	st.fanouts = make([][]arc, count)
	// addOnce appends a to list unless an equal arc is already present.
	addOnce := func(list []arc, a arc) []arc {
		for k := range list {
			if list[k] == a {
				return list
			}
		}
		return append(list, a)
	}
	for i, node := range st.nodes {
		for _, name := range node.FromTensors() {
			j, known := producer[name]
			switch {
			case !known:
				return anError("tensor is consumed but never produced", name, node.LineNumber())
			case j == i:
				return anError("self-loop", name, node.LineNumber())
			}
			st.fanins[i] = addOnce(st.fanins[i], arc{tensor: name, attach: j})
		}
		for _, name := range node.ToTensors() {
			for _, j := range consumers[name] {
				st.fanouts[i] = addOnce(st.fanouts[i], arc{tensor: name, attach: j})
			}
		}
	}
	return nil
}

// stage4 checks that the graph has no cycle and that every node leads to an
// Output. It runs a depth-first search backward from each Output along the
// fanin arcs, using an explicit stack instead of recursion.
func (st *state) stage4() error {
// Node colors: white is unvisited, gray is on the current search path,
// black is fully explored.
const (
white byte = iota
gray
black
)
// frame is one level of the search: the node being explored and the fanin
// arcs of it that have not been followed yet.
type frame struct {
to []arc
from int
}
// v counts the nodes visited so far.
v := 0
n := len(st.nodes)
color := make([]byte, n)
// Every push is of a white node, so the stack never holds more than n
// frames.
stack := make([]frame, n)
for _, i := range st.outputs {
v += 1
color[i] = gray
stack[0] = frame{st.fanins[i], i}
// j is the index of the top frame; the search from this Output ends when
// the bottom frame is popped.
for j := 0; j >= 0; {
top := &stack[j]
if len(top.to) == 0 {
color[top.from] = black
j -= 1
continue
}
arc0 := top.to[0]
top.to = top.to[1:]
k := arc0.attach
if color[k] == white {
v += 1
color[k] = gray
j += 1
stack[j] = frame{st.fanins[k], k}
continue
}
if color[k] == black {
continue
}
// k is gray: it is already on the current path, so this arc closes a
// cycle.
prev := st.nodes[k].LineNumber()
curr := st.nodes[top.from].LineNumber()
return anError("circular dependency", arc0.tensor, prev, curr)
}
}
if v == n {
return nil
}
// Some node was never reached from any Output. Report the unvisited node
// with the highest index; v != n guarantees that the loop finds one.
for i := n - 1; ; i-- {
if color[i] == white {
node := st.nodes[i]
line := node.LineNumber()
tensor := node.ToTensors()[0]
return anError("no path to an Output", tensor, line)
}
}
}

// stage5FeedsOutput reports whether any fanout arc of node i leads to an
// Output node.
func (st *state) stage5FeedsOutput(i int) bool {
	for _, out := range st.fanouts[i] {
		if _, isOutput := st.nodes[out.attach].(*raw.Output); isOutput {
			return true
		}
	}
	return false
}

// stage5DetachFanin disconnects node i from its producers: it clears
// st.fanins[i] and removes every arc that points at i from the fanout list
// of each former producer. Removal swaps the last kept arc into the hole, so
// the order of a fanout list is not preserved.
func (st *state) stage5DetachFanin(i int) {
	producers := st.fanins[i]
	st.fanins[i] = nil
	for _, in := range producers {
		outs := st.fanouts[in.attach]
		kept := len(outs)
		for k := 0; k < kept; {
			if outs[k].attach != i {
				k++
				continue
			}
			// Overwrite with the last kept arc and look at slot k again.
			kept--
			outs[k] = outs[kept]
		}
		st.fanouts[in.attach] = outs[:kept]
	}
}

// stage5Unique removes duplicate arcs from a in place and returns the
// shortened slice. A duplicate is overwritten with the last remaining
// element, so the order of the survivors may change.
func (st *state) stage5Unique(a []arc) []arc {
	size := len(a)
	for i := 1; i < size; {
		repeated := false
		for j := 0; j < i; j++ {
			if a[j] == a[i] {
				repeated = true
				break
			}
		}
		if !repeated {
			i++
			continue
		}
		// Pull the last element into slot i and examine slot i again.
		size--
		a[i] = a[size]
	}
	return a[:size]
}

// stage5Replace removes node this from the graph and makes every consumer of
// its outputs read the corresponding output of node with instead.
func (st *state) stage5Replace(this, with int) {
st.stage5DetachFanin(this)
tensors1 := st.nodes[this].ToTensors()
tensors2 := st.nodes[with].ToTensors()
// match translates an output tensor name of this into the output tensor of
// with at the same position; any other name is returned unchanged.
match := func(tensor string) string {
for i, ii := range tensors1 {
if ii == tensor {
return tensors2[i]
}
}
return tensor
}
fanout1 := st.fanouts[this]
fanout2 := st.fanouts[with]
// Rewire the graph: each consumer's fanin arc from this becomes an arc from
// with, and with gains the matching fanout arc. Rewriting may create
// duplicate arcs, which stage5Unique removes.
for i := range fanout1 {
a := fanout1[i].attach
fanin := st.fanins[a]
for j := range fanin {
if fanin[j].attach == this {
tensor := match(fanin[j].tensor)
fanin[j] = arc{tensor, with}
fanout2 = append(fanout2, arc{tensor, a})
}
}
st.fanins[a] = st.stage5Unique(fanin)
}
st.nodes[this] = nil
st.fanouts[this] = nil
st.fanouts[with] = st.stage5Unique(fanout2)
// Rename the tensors in the consumer nodes themselves. Only the node kinds
// listed here are expected as consumers of a replaced node; anything else
// is treated as a bug.
for i := range fanout1 {
a := fanout1[i].attach
switch aa := st.nodes[a].(type) {
case *raw.Activation:
aa.FromTensor = match(aa.FromTensor)
case *raw.Add:
aa.FromTensor1 = match(aa.FromTensor1)
aa.FromTensor2 = match(aa.FromTensor2)
case *raw.BatchNorm:
aa.FromTensor = match(aa.FromTensor)
case *raw.Concat:
aa.FromTensor1 = match(aa.FromTensor1)
aa.FromTensor2 = match(aa.FromTensor2)
case *raw.Conv:
aa.FromTensor = match(aa.FromTensor)
case *raw.FullyConnected:
aa.FromTensor = match(aa.FromTensor)
case *raw.Pooling:
aa.FromTensor = match(aa.FromTensor)
case *raw.Softmax:
aa.FromTensor = match(aa.FromTensor)
default:
panic("bug")
}
}
}

// stage5 removes duplicate work from the graph. Activation, Add, Concat,
// Pooling and Softmax nodes are reduced to a signature built from their
// input tensors and settings; a node whose signature was already seen is
// replaced by the earlier node, unless it feeds an Output directly.
func (st *state) stage5() error {
// Operation tags stored in sig.i1.
const (
activation int = iota
add
concat
pooling
softmax
)
// sig is the map key that identifies an operation: input tensor names in
// s1/s2, the operation tag in i1, and operation-specific settings in i2-i4.
type sig struct {
s1, s2 string
i1, i2 int
i3, i4 int
}
n := len(st.nodes)
seen := make([]bool, n)
// repl maps a signature to the index of the node that owns it, plus one,
// so that the zero value of a map lookup means "not present".
repl := make(map[sig]int, n)
var dedup func(int)
dedup = func(i int) {
// Process the producers first. Their indices are copied out beforehand
// because a replacement made during the recursion can rewrite this
// node's fanin list.
fanin := st.fanins[i]
saved := make([]int, len(fanin))
for j := range fanin {
saved[j] = fanin[j].attach
}
for _, j := range saved {
if !seen[j] {
seen[j] = true
dedup(j)
}
}
// subst stays 0 unless an equivalent earlier node is found.
var subst int
switch at := st.nodes[i].(type) {
case *raw.Activation:
i3 := int(math.Float32bits(at.Param))
sg := sig{
s1: at.FromTensor,
i1: activation,
i2: int(at.Kind),
i3: i3,
}
if subst = repl[sg]; subst == 0 {
repl[sg] = i + 1
if at.Kind == raw.ReLU {
// Signatures with i3 == -1 are keyed by the output tensor of a ReLU
// whose Param is <= 0. A ReLU that reads such a tensor is replaced
// by the node that produced it, whatever its own Param is
// (presumably because that tensor is already non-negative, making
// the second ReLU an identity; confirm against the ReLU definition).
sg.i3 = -1
if subst = repl[sg]; subst == 0 {
if at.Param <= 0 {
sg.s1 = at.ToTensor
repl[sg] = i + 1
}
} else {
// This node is about to be replaced, so withdraw the exact
// signature that was registered for it just above.
sg.i3 = i3
delete(repl, sg)
}
}
}
case *raw.Add:
sg := sig{s1: at.FromTensor1, s2: at.FromTensor2, i1: add}
if subst = repl[sg]; subst == 0 {
// Register both operand orders so that a later Add with the operands
// swapped is also recognized.
repl[sg] = i + 1
sg.s1, sg.s2 = sg.s2, sg.s1
repl[sg] = i + 1
}
case *raw.Concat:
// Only the given operand order is registered for Concat.
sg := sig{s1: at.FromTensor1, s2: at.FromTensor2, i1: concat}
if subst = repl[sg]; subst == 0 {
repl[sg] = i + 1
}
case *raw.Pooling:
sg := sig{
s1: at.FromTensor,
i1: pooling,
i2: int(at.Kind),
i3: at.PaddingH,
i4: at.PaddingW,
}
if subst = repl[sg]; subst == 0 {
repl[sg] = i + 1
}
case *raw.Softmax:
sg := sig{s1: at.FromTensor, i1: softmax}
if subst = repl[sg]; subst == 0 {
repl[sg] = i + 1
}
}
// A node that feeds an Output directly is never replaced.
if subst != 0 && !st.stage5FeedsOutput(i) {
st.stage5Replace(i, subst-1)
}
}
for _, i := range st.outputs {
dedup(i)
}
return nil
}

// stage6Check validates a tensor shape. It returns "tensor is empty" if any
// dimension is zero or negative, and "tensor is too large" if the running
// product of the dimensions overflows or reaches 1<<48.
func (st *state) stage6Check(nchw *[4]int) error {
	const limit = 1 << 48
	elems := 1
	for i := 0; i < len(nchw); i++ {
		dim := nchw[i]
		if dim <= 0 {
			return errors.New("tensor is empty")
		}
		next := elems * dim
		// Dividing back detects a product that wrapped around.
		wrapped := next/elems != dim
		if wrapped || next >= limit {
			return errors.New("tensor is too large")
		}
		elems = next
	}
	return nil
}

// stage6Add computes the output shape of an Add node. Both inputs must have
// identical shapes, and the output takes that shape; otherwise an error that
// names both inputs and their CxHxW dimensions is returned.
func (st *state) stage6Add(at *raw.Add, from1, from2 *shape) ([]shape, error) {
	if from1.nchw == from2.nchw {
		return []shape{{tensor: at.ToTensor, nchw: from1.nchw}}, nil
	}
	describe := func(s *shape) string {
		return fmt.Sprintf("%s is %dx%dx%d",
			origName(s.tensor), s.nchw[1], s.nchw[2], s.nchw[3])
	}
	msg := describe(from1) + " but " + describe(from2) + " (CxHxW mismatch)"
	return nil, anError(msg, at.ToTensor, at.LineNum)
}

// stage6BatchNorm computes the shapes for a BatchNorm node: the output has
// the input's shape, and the means, variances, scales and shifts tensors
// each hold one value per input channel (1xCx1x1).
func (st *state) stage6BatchNorm(at *raw.BatchNorm, from *shape) []shape {
	var perChannel [4]int
	perChannel[0], perChannel[1] = 1, from.nchw[1]
	perChannel[2], perChannel[3] = 1, 1
	params := [...]string{
		at.MeansTensor,
		at.VariancesTensor,
		at.ScalesTensor,
		at.ShiftsTensor,
	}
	out := make([]shape, 0, 1+len(params))
	out = append(out, shape{tensor: at.ToTensor, nchw: from.nchw})
	for _, name := range params {
		out = append(out, shape{tensor: name, nchw: perChannel})
	}
	return out
}

// stage6Concat computes the output shape of a Concat node. The inputs must
// agree in height and width; the output keeps that height and width and has
// the sum of the two channel counts. The result is validated by stage6Check.
func (st *state) stage6Concat(at *raw.Concat, from1, from2 *shape) ([]shape, error) {
	hw1 := [2]int{from1.nchw[2], from1.nchw[3]}
	hw2 := [2]int{from2.nchw[2], from2.nchw[3]}
	if hw1 != hw2 {
		const hw = "%s is spatially %dx%d"
		msg := fmt.Sprintf(hw+" but "+hw+" (HxW mismatch)",
			origName(from1.tensor), hw1[0], hw1[1],
			origName(from2.tensor), hw2[0], hw2[1])
		return nil, anError(msg, at.ToTensor, at.LineNum)
	}
	joined := shape{
		tensor: at.ToTensor,
		nchw:   [4]int{1, from1.nchw[1] + from2.nchw[1], hw1[0], hw1[1]},
	}
	if err := st.stage6Check(&joined.nchw); err != nil {
		return nil, anError(err.Error(), at.ToTensor, at.LineNum)
	}
	return []shape{joined}, nil
}

// stage6Conv computes the shapes for a Conv node: the output tensor, the
// weights tensor (ToChannels x C/Groups x FilterH x FilterW) and the biases
// tensor (1 x ToChannels x 1 x 1). The input and output channel counts must
// both be divisible by Groups, and the output and weights shapes must pass
// stage6Check.
func (st *state) stage6Conv(at *raw.Conv, from *shape) ([]shape, error) {
	const notDivisible = " has %d channels (not divisible by %d groups)"
	inCh, outCh, groups := from.nchw[1], at.ToChannels, at.Groups
	if inCh%groups != 0 {
		msg := fmt.Sprintf("FromTensor"+notDivisible, inCh, groups)
		return nil, anError(msg, at.FromTensor, at.LineNum)
	}
	if outCh%groups != 0 {
		msg := fmt.Sprintf("ToTensor"+notDivisible, outCh, groups)
		return nil, anError(msg, at.ToTensor, at.LineNum)
	}
	// extent is the output size along one axis. A filter that does not fit
	// even once yields a negative value, which stage6Check rejects below as
	// an empty tensor.
	extent := func(size, pad, filter, dilation, stride int) int {
		room := size + 2*pad - (1 + (filter-1)*dilation)
		if room < 0 {
			return room
		}
		return room/stride + 1
	}
	outH := extent(from.nchw[2], at.PaddingH, at.FilterH, at.DilationH, at.StrideH)
	outW := extent(from.nchw[3], at.PaddingW, at.FilterW, at.DilationW, at.StrideW)
	out := []shape{
		{tensor: at.ToTensor, nchw: [4]int{1, outCh, outH, outW}},
		{tensor: at.WeightsTensor, nchw: [4]int{outCh, inCh / groups, at.FilterH, at.FilterW}},
		{tensor: at.BiasesTensor, nchw: [4]int{1, outCh, 1, 1}},
	}
	if err := st.stage6Check(&out[0].nchw); err != nil {
		return nil, anError(err.Error(), at.ToTensor, at.LineNum)
	}
	if err := st.stage6Check(&out[1].nchw); err != nil {
		return nil, anError(err.Error(), at.WeightsTensor+origNameShield, at.LineNum)
	}
	return out, nil
}

// stage6FullyConnected computes the shapes for a FullyConnected node: the
// output and biases tensors are 1 x ToChannels x 1 x 1, and the weights
// tensor is ToChannels x C x H x W of the input. Only the weights shape is
// validated by stage6Check.
func (st *state) stage6FullyConnected(at *raw.FullyConnected, from *shape) ([]shape, error) {
	outCh := at.ToChannels
	perOutput := [4]int{1, outCh, 1, 1}
	weights := [4]int{outCh, from.nchw[1], from.nchw[2], from.nchw[3]}
	if err := st.stage6Check(&weights); err != nil {
		return nil, anError(err.Error(), at.WeightsTensor+origNameShield, at.LineNum)
	}
	return []shape{
		{tensor: at.ToTensor, nchw: perOutput},
		{tensor: at.WeightsTensor, nchw: weights},
		{tensor: at.BiasesTensor, nchw: perOutput},
	}, nil
}

// stage6Input computes the shape of an Input node's tensor from its
// declared Channels, Height and Width (1 x C x H x W) and validates it with
// stage6Check.
func (st *state) stage6Input(at *raw.Input) ([]shape, error) {
	dims := [4]int{1, at.Channels, at.Height, at.Width}
	if err := st.stage6Check(&dims); err != nil {
		return nil, anError(err.Error(), at.ToTensor, at.LineNum)
	}
	return []shape{{tensor: at.ToTensor, nchw: dims}}, nil
}

// stage6Pooling computes the output shape of a Pooling node. Global pooling
// yields C x 1 x 1. The 2x2 and 3x3 kinds use a stride of 2 over the padded
// input. Padding must be smaller than the window side (so global pooling
// accepts no padding), and the result must pass stage6Check.
func (st *state) stage6Pooling(at *raw.Pooling, from *shape) ([]shape, error) {
	const stride = 2
	window := 0
	switch at.Kind {
	case raw.MaxGlobal, raw.AvgGlobal:
		window = 1
	case raw.Max2x2Stride2, raw.Avg2x2Stride2:
		window = 2
	case raw.Max3x3Stride2, raw.Avg3x3Stride2:
		window = 3
	default:
		panic("bug")
	}
	if at.PaddingH >= window || at.PaddingW >= window {
		return nil, anError("too much padding", at.ToTensor, at.LineNum)
	}
	dims := [4]int{1, from.nchw[1], 1, 1}
	if window != 1 {
		pads := [2]int{at.PaddingH, at.PaddingW}
		for axis, pad := range pads {
			// A window that does not fit leaves a negative size, which
			// stage6Check reports as an empty tensor.
			room := from.nchw[2+axis] + 2*pad - window
			if room >= 0 {
				room = room/stride + 1
			}
			dims[2+axis] = room
		}
	}
	out := []shape{{tensor: at.ToTensor, nchw: dims}}
	if err := st.stage6Check(&out[0].nchw); err != nil {
		return nil, anError(err.Error(), at.ToTensor, at.LineNum)
	}
	return out, nil
}

// stage6Fill computes st.shapes[i], the shapes of the tensors associated
// with node i, after first computing the shapes of all of its producers. A
// node whose shapes are already filled in is left alone. The first shape
// error encountered is returned through the named result.
func (st *state) stage6Fill(i int) (err error) {
slot := &st.shapes[i]
if *slot != nil {
return
}
fanin := st.fanins[i]
for j := range fanin {
k := fanin[j].attach
if err = st.stage6Fill(k); err != nil {
return
}
}
// lookup finds the shape of an input tensor among the shapes of the
// producer that supplies it. Every input tensor must have one; a miss is
// treated as a bug.
lookup := func(tensor string) *shape {
for j := range fanin {
if fanin[j].tensor == tensor {
shapes := st.shapes[fanin[j].attach]
for k := range shapes {
if shapes[k].tensor == tensor {
return &shapes[k]
}
}
}
}
panic("bug")
}
switch at := st.nodes[i].(type) {
case *raw.Activation:
// An Activation keeps its input's shape.
from := lookup(at.FromTensor)
*slot = []shape{{at.ToTensor, from.nchw}}
case *raw.Add:
from1 := lookup(at.FromTensor1)
from2 := lookup(at.FromTensor2)
*slot, err = st.stage6Add(at, from1, from2)
case *raw.BatchNorm:
from := lookup(at.FromTensor)
*slot = st.stage6BatchNorm(at, from)
case *raw.Concat:
from1 := lookup(at.FromTensor1)
from2 := lookup(at.FromTensor2)
*slot, err = st.stage6Concat(at, from1, from2)
case *raw.Conv:
from := lookup(at.FromTensor)
*slot, err = st.stage6Conv(at, from)
case *raw.FullyConnected:
from := lookup(at.FromTensor)
*slot, err = st.stage6FullyConnected(at, from)
case *raw.Input:
*slot, err = st.stage6Input(at)
case *raw.Output:
// An Output records a copy of the shape of the tensor it reads.
from := lookup(at.FromTensor)
*slot = []shape{*from}
case *raw.Pooling:
from := lookup(at.FromTensor)
*slot, err = st.stage6Pooling(at, from)
case *raw.Softmax:
// A Softmax keeps its input's shape.
from := lookup(at.FromTensor)
*slot = []shape{{at.ToTensor, from.nchw}}
default:
panic("bug")
}
return
}

// stage6 computes the tensor shapes for every node that an Output depends
// on, by filling st.shapes starting from each Output. It stops at the first
// shape error.
func (st *state) stage6() error {
	st.shapes = make([][]shape, len(st.nodes))
	for k := 0; k < len(st.outputs); k++ {
		err := st.stage6Fill(st.outputs[k])
		if err != nil {
			return err
		}
	}
	return nil
}

// stage7 translates the node graph into st.plan. Every node reachable from
// an Output becomes one plan.Op, every produced tensor gets its own
// plan.Pile, and ops are connected through plan.Span values that record
// which op writes and which ops read each pile.
func (st *state) stage7() error {
in := len(st.inputs)
// The first len(st.inputs) slots of Seq are reserved for the Input ops; all
// other ops are appended after them.
st.plan = plan.Plan{
Config: st.config,
Seq: make([]*plan.Op, in, len(st.nodes)),
}
// ops[i] is the op created for node i, or nil if it has not been created
// yet.
ops := make([]*plan.Op, len(st.nodes))
var mirror func(int)
mirror = func(i int) {
// Create the ops of all producers first, so that every op is placed
// after the ops it depends on.
fanin := st.fanins[i]
for j := range fanin {
if k := fanin[j].attach; ops[k] == nil {
mirror(k)
}
}
node := st.nodes[i]
op := &plan.Op{Nodes: []raw.Node{node}}
switch node.(type) {
case *raw.Input:
// Input ops fill the reserved slots from the back toward the front.
in--
st.plan.Seq[in] = op
default:
st.plan.Seq = append(st.plan.Seq, op)
}
ops[i] = op
shapes := st.shapes[i]
// Parameter tensors: one group, each entry paired with the shape that
// stage6 recorded under the same tensor name.
params := node.ParamTensors()
op.Params = [][]plan.Param{make([]plan.Param, len(params))}
op.ParamMods = make([][2][]plan.Mod, 1)
for j, tensor := range params {
var nchw *[4]int
for k := range shapes {
if shapes[k].tensor == tensor {
nchw = &shapes[k].nchw
break
}
}
op.Params[0][j] = plan.Param{
Tensor: tensor,
NCHW: *nchw,
}
}
// Input tensors: find the producer's output span for the tensor and read
// from the same pile, registering the new span as a reader of it.
from := node.FromTensors()
op.From = make([]*plan.Span, len(from))
op.FromMods = make([][]plan.Mod, len(from))
for j, tensor := range from {
var pile *plan.Pile
for k := range fanin {
if fanin[k].tensor == tensor {
for _, span := range ops[fanin[k].attach].To {
if span.Tensors[0] == tensor {
pile = span.Piles[0]
break
}
}
break
}
}
span := &plan.Span{
Piles: []*plan.Pile{pile},
Offsets: []int{0},
Tensors: []string{tensor},
Counts: []int{pile.Channels},
Op: op,
}
pile.Readers = append(pile.Readers, span)
op.From[j] = span
}
// Output tensors: each gets a new pile sized from its shape, with this
// op's span as the only writer.
to := node.ToTensors()
op.To = make([]*plan.Span, len(to))
op.ToMods = make([][]plan.Mod, len(to))
for j, tensor := range to {
var nchw *[4]int
for k := range shapes {
if shapes[k].tensor == tensor {
nchw = &shapes[k].nchw
break
}
}
pile := &plan.Pile{
Channels: nchw[1],
Height: nchw[2],
Width: nchw[3],
}
span := &plan.Span{
Piles: []*plan.Pile{pile},
Offsets: []int{0},
Tensors: []string{tensor},
Counts: []int{nchw[1]},
Op: op,
}
pile.Writers = []*plan.Span{span}
op.To[j] = span
}
}
for _, i := range st.outputs {
mirror(i)
}
return nil
}

// stage8Pre folds the chain of BatchNorm ops that directly precedes op1's
// first input into op1. Each BatchNorm is recorded as a plan.Mod in
// op1.ParamMods[0][0], and op1 is rewired to read the BatchNorm's own input.
func (st *state) stage8Pre(op1 *plan.Op) {
var mods []plan.Mod
for span1 := op1.From[0]; ; {
pile1 := span1.Piles[0]
span2 := pile1.Writers[0]
op2 := span2.Op
// Stop at the first producer that is not a BatchNorm.
if _, ok := op2.Nodes[0].(*raw.BatchNorm); !ok {
break
}
mods = append(mods, plan.Mod{
Nodes: op2.Nodes,
Params: op2.Params[0],
})
span3 := op2.From[0]
pile2 := span3.Piles[0]
span1.Piles[0] = pile2
if len(pile1.Readers) == 1 {
// op1 was the only reader of the BatchNorm's output: span1 takes over
// the BatchNorm's place among the readers of its input, and the
// BatchNorm op is blanked out.
for i, ii := range pile2.Readers {
if ii == span3 {
pile2.Readers[i] = span1
break
}
}
*op2 = plan.Op{}
continue
}
// The BatchNorm has other readers and must stay. Move span1 from the
// readers of its output to the readers of its input.
pile2.Readers = append(pile2.Readers, span1)
for i, ii := range pile1.Readers {
if ii == span1 {
j := len(pile1.Readers) - 1
pile1.Readers[i] = pile1.Readers[j]
pile1.Readers[j] = nil
pile1.Readers = pile1.Readers[:j]
break
}
}
// Give the surviving op its own Nodes and Params slices so that it no
// longer shares backing storage with the Mod recorded above.
op2.Nodes = []raw.Node{op2.Nodes[0]}
params := make([]plan.Param, len(op2.Params[0]))
copy(params, op2.Params[0])
op2.Params[0] = params
}
// The mods were collected walking backward from op1; reverse them so that
// the one farthest from op1 comes first.
for i, j := 0, len(mods)-1; i < j; i, j = i+1, j-1 {
mods[i], mods[j] = mods[j], mods[i]
}
op1.ParamMods[0][0] = mods
}

// stage8Post folds the chain of BatchNorm ops that directly follows op1's
// first output into op1. A BatchNorm is absorbed only while it is the sole
// reader of the current output pile. Each one is recorded as a plan.Mod in
// op1.ParamMods[0][1], in the order encountered, and op1 becomes the writer
// of the BatchNorm's output pile.
func (st *state) stage8Post(op1 *plan.Op) {
var mods []plan.Mod
for span1 := op1.To[0]; ; {
pile1 := span1.Piles[0]
if len(pile1.Readers) != 1 {
break
}
span2 := pile1.Readers[0]
op2 := span2.Op
if _, ok := op2.Nodes[0].(*raw.BatchNorm); !ok {
break
}
mods = append(mods, plan.Mod{
Nodes: op2.Nodes,
Params: op2.Params[0],
})
// Write directly into the BatchNorm's output pile and blank out the
// BatchNorm op.
span3 := op2.To[0]
pile2 := span3.Piles[0]
span1.Piles[0] = pile2
pile2.Writers[0] = span1
*op2 = plan.Op{}
}
op1.ParamMods[0][1] = mods
}

// stage8 folds neighboring BatchNorm ops into the Conv and FullyConnected
// ops of the plan. A following BatchNorm chain is always folded in; a
// preceding chain is folded in for FullyConnected, and for Conv only when
// the convolution has no padding. Ops that were blanked out by an earlier
// fold are skipped.
func (st *state) stage8() error {
	for i := range st.plan.Seq {
		op := st.plan.Seq[i]
		if op.Nodes == nil {
			continue
		}
		var foldBefore bool
		switch node := op.Nodes[0].(type) {
		case *raw.Conv:
			foldBefore = node.PaddingH == 0 && node.PaddingW == 0
		case *raw.FullyConnected:
			foldBefore = true
		default:
			continue
		}
		if foldBefore {
			st.stage8Pre(op)
		}
		st.stage8Post(op)
	}
	return nil
}

// stage9 merges chained Add ops. When an input of an Add op is produced by
// another Add op and that input's pile has no other reader, the producer's
// nodes, params and inputs are moved into the consumer and the producer is
// blanked out.
func (st *state) stage9() error {
for _, op1 := range st.plan.Seq {
if op1.Nodes == nil {
continue
}
if _, ok := op1.Nodes[0].(*raw.Add); !ok {
continue
}
// The range expression is evaluated once, so inputs appended to op1.From
// inside the loop are not visited by it.
for i, span1 := range op1.From {
pile1 := span1.Piles[0]
if len(pile1.Readers) != 1 {
continue
}
span2 := pile1.Writers[0]
op2 := span2.Op
if _, ok := op2.Nodes[0].(*raw.Add); !ok {
continue
}
op1.Nodes = append(op1.Nodes, op2.Nodes...)
op1.Params = append(op1.Params, op2.Params...)
op1.ParamMods = append(op1.ParamMods, op2.ParamMods...)
for _, span3 := range op2.From {
span3.Op = op1
}
// The producer's first input takes the place of the merged input; its
// remaining inputs are appended.
op1.From[i] = op2.From[0]
op1.From = append(op1.From, op2.From[1:]...)
op1.FromMods = append(op1.FromMods, op2.FromMods[1:]...)
*op2 = plan.Op{}
}
}
return nil
}

// stage10Via chooses one input of op1 by looking at the op that writes each
// input pile. If exactly one input is written by an Activation or BatchNorm
// op, its index is returned. If no input is, 0 is returned. If more than one
// is, -1 is returned.
func (st *state) stage10Via(op1 *plan.Op) int {
	chosen := -1
	for j, span := range op1.From {
		writer := span.Piles[0].Writers[0].Op
		switch writer.Nodes[0].(type) {
		case *raw.Activation, *raw.BatchNorm:
			if chosen >= 0 {
				return -1
			}
			chosen = j
		}
	}
	if chosen < 0 {
		return 0
	}
	return chosen
}

// stage10Pre folds ops that precede op1 into op1. For each input of op1 it
// walks backward through the producing ops for as long as they can be
// absorbed, records each one as a plan.Mod, and rewires the input to read
// what the absorbed op read. The mods for input i end up in op1.FromMods[i].
//
// Which producers can be absorbed depends on pass: an Activation always, an
// Add only when pass is 1, a BatchNorm only when pass is 1 or 2.
func (st *state) stage10Pre(op1 *plan.Op, pass int) {
for i, span1 := range op1.From {
var mods []plan.Mod
for {
pile1 := span1.Piles[0]
span2 := pile1.Writers[0]
op2 := span2.Op
mod := plan.Mod{Nodes: op2.Nodes}
// span3 is the producer's input through which the walk continues. It
// stays nil when the producer cannot be absorbed. The break statements
// in the cases below leave the switch, not the loop.
var span3 *plan.Span
switch mod.Nodes[0].(type) {
case *raw.Activation:
span3 = op2.From[0]
case *raw.Add:
if pass >= 2 {
break
}
// An Add is absorbed only if op1 is the sole reader of its output.
if len(pile1.Readers) != 1 {
break
}
j := st.stage10Via(op2)
if j < 0 {
break
}
// The walk continues through input j. The Add's other inputs become
// extra inputs of the mod and now belong to op1. Input j is removed
// by moving the last input into its slot, so the order of the
// remaining inputs may change.
mod.From = op2.From
span3 = mod.From[j]
k := len(mod.From) - 1
mod.From[j] = mod.From[k]
mod.From[k] = nil
mod.From = mod.From[:k]
for _, span4 := range mod.From {
span4.Op = op1
}
case *raw.BatchNorm:
if pass >= 3 {
break
}
mod.Params = op2.Params[0]
span3 = op2.From[0]
}
if span3 == nil {
break
}
mods = append(mods, mod)
pile2 := span3.Piles[0]
span1.Piles[0] = pile2
if len(pile1.Readers) == 1 {
// op1 was the producer's only reader: span1 takes the producer's
// place among the readers of pile2, and the producer is blanked out.
for j, span4 := range pile2.Readers {
if span4 == span3 {
pile2.Readers[j] = span1
break
}
}
*op2 = plan.Op{}
continue
}
// The producer has other readers and must stay. Move span1 from the
// readers of its output to the readers of its input.
pile2.Readers = append(pile2.Readers, span1)
for j, span4 := range pile1.Readers {
if span4 == span1 {
k := len(pile1.Readers) - 1
pile1.Readers[j] = pile1.Readers[k]
pile1.Readers[k] = nil
pile1.Readers = pile1.Readers[:k]
break
}
}
// Give the surviving op its own Nodes and Params slices so that it no
// longer shares backing storage with the recorded Mod.
op2.Nodes = []raw.Node{op2.Nodes[0]}
if mod.Params != nil {
op2.Params[0] = make([]plan.Param, len(mod.Params))
copy(op2.Params[0], mod.Params)
}
}
// The mods were collected walking backward from op1; reverse them so
// that the one farthest from op1 comes first.
for j, k := 0, len(mods)-1; j < k; j, k = j+1, k-1 {
mods[j], mods[k] = mods[k], mods[j]
}
op1.FromMods[i] = mods
}
}

// stage10Post folds ops that follow op1 into op1. For each output of op1 it
// walks forward while the output pile has exactly one reader and that reader
// can be absorbed, records each absorbed op as a plan.Mod, and makes op1 the
// writer of the absorbed op's output pile. The mods for output i end up in
// op1.ToMods[i], in the order encountered.
//
// An Activation is always absorbed. An Add is absorbed when pass is at most
// 2 and all of its other inputs read piles whose ElemBytes equals pass. A
// BatchNorm is absorbed when pass is at most 3.
func (st *state) stage10Post(op1 *plan.Op, pass int) {
for i, span1 := range op1.To {
var mods []plan.Mod
for {
pile1 := span1.Piles[0]
if len(pile1.Readers) != 1 {
break
}
span2 := pile1.Readers[0]
op2 := span2.Op
// mod.Nodes is set to nil below to signal that the reader cannot be
// absorbed. The break statements in the cases leave the switch (or the
// inner range loop), not the outer loop.
mod := plan.Mod{Nodes: op2.Nodes}
switch mod.Nodes[0].(type) {
case *raw.Activation:
case *raw.Add:
if pass > 2 {
mod.Nodes = nil
break
}
// j is the position of op1's output among the Add's inputs. Every
// other input must read a pile that stage10 has marked with the
// current pass number.
var j int
for k, span3 := range op2.From {
if span3 == span2 {
j = k
} else if span3.Piles[0].ElemBytes != pass {
mod.Nodes = nil
break
}
}
if mod.Nodes == nil {
break
}
// The Add's other inputs become extra inputs of the mod and now
// belong to op1. Input j is removed by moving the last input into its
// slot, so the order of the remaining inputs may change.
mod.From = op2.From
k := len(mod.From) - 1
mod.From[j] = mod.From[k]
mod.From[k] = nil
mod.From = mod.From[:k]
for _, span3 := range mod.From {
span3.Op = op1
}
case *raw.BatchNorm:
if pass > 3 {
mod.Nodes = nil
break
}
mod.Params = op2.Params[0]
default:
mod.Nodes = nil
}
if mod.Nodes == nil {
break
}
mods = append(mods, mod)
// Write directly into the absorbed op's output pile and blank the op
// out.
span3 := op2.To[0]
pile2 := span3.Piles[0]
span1.Piles[0] = pile2
pile2.Writers[0] = span1
*op2 = plan.Op{}
}
op1.ToMods[i] = mods
}
}

// stage10 fuses neighboring ops in several passes over the plan and then
// removes the ops that were absorbed.
//
// Pass 1 folds neighbors into Conv and Pooling ops (both directions) and
// FullyConnected ops (following ops only). Pass 2 does the same for Add ops
// and pass 3 for BatchNorm ops. A final sweep, in reverse order, folds
// preceding ops into the remaining Activation ops.
func (st *state) stage10() error {
pass := 1
for ; pass <= 3; pass++ {
for _, op := range st.plan.Seq {
// A nil Nodes slice marks an op that has been absorbed by another.
if op.Nodes == nil {
continue
}
switch z := op.Nodes[0]; pass {
case 1:
switch z.(type) {
case *raw.Conv, *raw.Pooling:
st.stage10Pre(op, pass)
st.stage10Post(op, pass)
case *raw.FullyConnected:
st.stage10Post(op, pass)
}
// Mark the op's output piles with the pass number. ElemBytes is used
// here only as a scratch marker that stage10Post compares against
// pass; it is reset to 0 at the end of this stage.
for _, span := range op.To {
span.Piles[0].ElemBytes = pass
}
case 2:
switch z.(type) {
case *raw.Add:
st.stage10Pre(op, pass)
st.stage10Post(op, pass)
}
for _, span := range op.To {
span.Piles[0].ElemBytes = pass
}
case 3:
switch z.(type) {
case *raw.BatchNorm:
st.stage10Pre(op, pass)
st.stage10Post(op, pass)
}
}
}
}
// pass is 4 here, so stage10Pre absorbs only Activation producers.
for i := len(st.plan.Seq) - 1; i >= 0; i-- {
op := st.plan.Seq[i]
if op.Nodes == nil {
continue
}
switch op.Nodes[0].(type) {
case *raw.Activation:
st.stage10Pre(op, pass)
}
}
// Compact Seq in place, dropping absorbed ops and clearing the scratch
// marker on the output piles of the ops that remain.
n := 0
for i, op := range st.plan.Seq {
st.plan.Seq[i] = nil
if op.Nodes == nil {
continue
}
for _, span := range op.To {
span.Piles[0].ElemBytes = 0
}
st.plan.Seq[n] = op
n += 1
}
st.plan.Seq = st.plan.Seq[:n]
return nil
}

// stage11Reduce collapses every run of two or more consecutive ReLU mods in
// mods into a single mod. It works in place and returns the shortened
// slice. If mods contains no two adjacent ReLU mods it is returned
// untouched.
func (st *state) stage11Reduce(mods []plan.Mod) []plan.Mod {
// First scan: phase counts consecutive ReLU mods and is reset by any other
// mod. Reaching 2 means there is something to collapse.
phase := 0
for i := range mods {
node := mods[i].Nodes[0]
if node, ok := node.(*raw.Activation); ok {
if node.Kind == raw.ReLU {
if phase++; phase == 2 {
break
}
continue
}
}
phase = 0
}
if phase != 2 {
return mods
}
// Second scan: first and slopes describe the current run of ReLU mods, and
// keep is the write position for the compacted result. The extra
// iteration with i == have flushes a run that ends at the last mod.
var first int
var slopes []float32
keep, have := 0, len(mods)
for i := 0; i <= have; i++ {
if i < have {
node := mods[i].Nodes[0]
if node, ok := node.(*raw.Activation); ok {
if node.Kind == raw.ReLU {
if len(slopes) == 0 {
first = i
}
slopes = append(slopes, node.Param)
continue
}
}
}
if run := len(slopes); run != 0 {
if run == 1 {
mods[keep] = mods[first]
} else {
// The combined Param is the product of the run's Params, except that
// multiplication stops as soon as the product is <= 0.
var param float32 = 1
for _, slope := range slopes {
if param *= slope; param <= 0 {
break
}
}
// The merged mod starts with a copy of the run's last ReLU node
// carrying the combined Param, followed by the first node of every
// mod in the run.
cross := *mods[first+run-1].Nodes[0].(*raw.Activation)
cross.Param = param
nodes := make([]raw.Node, 1, 1+run)
nodes[0] = &cross
for j := first; j < first+run; j++ {
nodes = append(nodes, mods[j].Nodes[0])
}
mods[keep] = plan.Mod{Nodes: nodes}
}
keep += 1
slopes = slopes[:0]
}
if i < have {
mods[keep] = mods[i]
keep += 1
}
}
// Zero the unused tail so that it does not keep stale references.
for i := keep; i < have; i++ {
mods[i] = plan.Mod{}
}
return mods[:keep]
}

// stage11Absorb merges a ReLU op with a ReLU mod that immediately precedes
// it. It does nothing unless op is a ReLU Activation and the last mod on its
// first input is a ReLU as well.
func (st *state) stage11Absorb(op *plan.Op) {
at, ok := op.Nodes[0].(*raw.Activation)
if !ok || at.Kind != raw.ReLU {
return
}
mods := op.FromMods[0]
i := len(mods) - 1
if i < 0 {
return
}
nodes := mods[i].Nodes
node, ok := nodes[0].(*raw.Activation)
if !ok || node.Kind != raw.ReLU {
return
}
// Remove the mod from the input.
mods[i] = plan.Mod{}
op.FromMods[0] = mods[:i]
// The combined Param is the mod's Param, multiplied by the op's own Param
// only when the mod's Param is positive.
param := node.Param
if param > 0 {
param *= at.Param
}
cross := *at
cross.Param = param
// A mod with more than one node starts with a combined node (as built by
// stage11Reduce); drop it and keep the nodes that follow.
if len(nodes) > 1 {
nodes = nodes[1:]
}
// op.Nodes becomes: the combined ReLU, the absorbed nodes, then the op's
// original node. Params and ParamMods get one (empty) entry per node.
cnt := 1 + len(nodes) + 1
all := make([]raw.Node, 1, cnt)
all[0] = &cross
all = append(all, nodes...)
all = append(all, at)
op.Nodes = all
op.Params = make([][]plan.Param, cnt)
op.ParamMods = make([][2][]plan.Mod, cnt)
}

// stage11 simplifies ReLU chains in every op of the plan: runs of ReLU mods
// on each input and output are collapsed with stage11Reduce, and then a ReLU
// op is merged with a preceding ReLU mod by stage11Absorb.
func (st *state) stage11() error {
	for _, op := range st.plan.Seq {
		for i := 0; i < len(op.FromMods); i++ {
			op.FromMods[i] = st.stage11Reduce(op.FromMods[i])
		}
		for i := 0; i < len(op.ToMods); i++ {
			op.ToMods[i] = st.stage11Reduce(op.ToMods[i])
		}
		st.stage11Absorb(op)
	}
	return nil
}

// stage12Compatible reports whether two nodes are both ungrouped
// convolutions (Groups == 1) with identical filter, stride, padding and
// dilation settings. Any non-convolution node yields false.
func (st *state) stage12Compatible(node1, node2 raw.Node) bool {
	lhs, ok := node1.(*raw.Conv)
	if !ok {
		return false
	}
	rhs, ok := node2.(*raw.Conv)
	if !ok {
		return false
	}
	if lhs.Groups != 1 || rhs.Groups != 1 {
		return false
	}
	sameFilter := lhs.FilterH == rhs.FilterH && lhs.FilterW == rhs.FilterW
	sameStride := lhs.StrideH == rhs.StrideH && lhs.StrideW == rhs.StrideW
	samePadding := lhs.PaddingH == rhs.PaddingH && lhs.PaddingW == rhs.PaddingW
	sameDilation := lhs.DilationH == rhs.DilationH && lhs.DilationW == rhs.DilationW
	return sameFilter && sameStride && samePadding && sameDilation
}

// stage12AddsEq reports whether two span lists refer to the same sources
// regardless of order. Spans are matched one-to-one on their first pile
// and first offset; each span of from2 can satisfy at most one span of
// from1.
func (st *state) stage12AddsEq(from1, from2 []*plan.Span) bool {
	if len(from1) != len(from2) {
		return false
	}
	taken := make([]bool, len(from2))
search:
	for _, want := range from1 {
		for idx, have := range from2 {
			if taken[idx] {
				continue
			}
			if want.Piles[0] != have.Piles[0] || want.Offsets[0] != have.Offsets[0] {
				continue
			}
			taken[idx] = true
			continue search
		}
		// No unclaimed partner exists for this span.
		return false
	}
	return true
}

// stage12Common returns the length of the longest common prefix of two
// modifier lists. Activations match when Kind and Param agree. Add and
// BatchNorm modifiers are only eligible to match when allow is true: adds
// match when their sources are equal per stage12AddsEq, batch norms when
// their MeansTensor agrees. Any other modifier node type in mods1 panics.
func (st *state) stage12Common(mods1, mods2 []plan.Mod, allow bool) int {
	limit := len(mods1)
	if len(mods2) < limit {
		limit = len(mods2)
	}
	// same reports whether the modifiers at position i are equivalent.
	same := func(i int) bool {
		switch lhs := mods1[i].Nodes[0].(type) {
		case *raw.Activation:
			rhs, ok := mods2[i].Nodes[0].(*raw.Activation)
			return ok && lhs.Kind == rhs.Kind && lhs.Param == rhs.Param
		case *raw.Add:
			if !allow {
				return false
			}
			_, ok := mods2[i].Nodes[0].(*raw.Add)
			return ok && st.stage12AddsEq(mods1[i].From, mods2[i].From)
		case *raw.BatchNorm:
			if !allow {
				return false
			}
			rhs, ok := mods2[i].Nodes[0].(*raw.BatchNorm)
			return ok && lhs.MeansTensor == rhs.MeansTensor
		default:
			panic("bug")
		}
	}
	for i := 0; i < limit; i++ {
		if !same(i) {
			return i
		}
	}
	return limit
}

// stage12Stackable reports whether the ops behind two reader spans can be
// merged into one op. The spans must start at the same offset and belong
// to distinct single-node ops whose nodes pass stage12Compatible. Each op
// must have exactly one input (the span itself), exactly one output, and
// an identical list of input modifiers.
func (st *state) stage12Stackable(span1, span2 *plan.Span) bool {
	if span1.Offsets[0] != span2.Offsets[0] {
		return false
	}
	opA, opB := span1.Op, span2.Op
	if opA == opB {
		return false
	}
	nodesA, nodesB := opA.Nodes, opB.Nodes
	single := len(nodesA) == 1 && len(nodesB) == 1
	if !single || !st.stage12Compatible(nodesA[0], nodesB[0]) {
		return false
	}
	fromA, fromB := opA.From, opB.From
	if !(len(fromA) == 1 && len(fromB) == 1) {
		return false
	}
	if !(fromA[0] == span1 && fromB[0] == span2) {
		return false
	}
	modsA, modsB := opA.FromMods[0], opB.FromMods[0]
	if len(modsA) != len(modsB) || st.stage12Common(modsA, modsB, true) != len(modsA) {
		return false
	}
	return len(opA.To) == 1 && len(opB.To) == 1
}

// stage12Clone returns a new span that copies only the first pile, offset,
// tensor and count of a, and shares a's Op. The slices are freshly
// allocated so the clone can be edited independently.
func (st *state) stage12Clone(a *plan.Span) *plan.Span {
	dup := new(plan.Span)
	dup.Piles = []*plan.Pile{a.Piles[0]}
	dup.Offsets = []int{a.Offsets[0]}
	dup.Tensors = []string{a.Tensors[0]}
	dup.Counts = []int{a.Counts[0]}
	dup.Op = a.Op
	return dup
}

// stage12Indirect moves a subset of the readers of op1's first output pile
// behind a newly created pass-through op and returns that op, or nil when
// no reader qualifies. The new op is a ReLU activation with Param 1
// (presumably an identity copy -- confirm against the activation codegen)
// that reads the original pile and writes a new pile of the same
// dimensions; the selected readers are re-pointed at the new pile.
//
// Selection rules, per reader op's first node:
//   - Concat and Output readers are always moved.
//   - FullyConnected and Softmax readers are moved only when broadcast is set.
//   - Any other reader is moved only when broadcast is set and the reader
//     span is not one of its op's direct From spans.
func (st *state) stage12Indirect(op1 *plan.Op, broadcast bool) *plan.Op {
span1 := op1.To[0]
pile1 := span1.Piles[0]
readers, split := pile1.Readers, 0
// Partition in place: selected spans are swapped into readers[:split].
for i, span2 := range readers {
op2 := span2.Op
switch op2.Nodes[0].(type) {
case *raw.Concat, *raw.Output:
case *raw.FullyConnected, *raw.Softmax:
if !broadcast {
continue
}
default:
if !broadcast {
continue
}
// mod stays true only if span2 is absent from op2.From.
mod := true
for _, span3 := range op2.From {
if span3 == span2 {
mod = false
break
}
}
if !mod {
continue
}
}
readers[i] = readers[split]
readers[split] = span2
split += 1
}
if split == 0 {
return nil
}
// span2 is the new op's input (still on pile1); span3 is its output and
// is re-pointed at pile2 below. Both start as clones of a moved reader.
span2 := st.stage12Clone(readers[0])
span3 := st.stage12Clone(readers[0])
op2 := &plan.Op{
Nodes: []raw.Node{
&raw.Activation{
LineNum: op1.Nodes[0].LineNumber(),
FromTensor: span2.Tensors[0],
ToTensor: span3.Tensors[0],
Kind: raw.ReLU,
Param: 1,
},
},
Params: make([][]plan.Param, 1),
ParamMods: make([][2][]plan.Mod, 1),
From: []*plan.Span{span2},
FromMods: make([][]plan.Mod, 1),
To: []*plan.Span{span3},
ToMods: make([][]plan.Mod, 1),
}
span2.Op = op2
span3.Op = op2
pile2 := &plan.Pile{
Channels: pile1.Channels,
Height: pile1.Height,
Width: pile1.Width,
Writers: []*plan.Span{span3},
Readers: make([]*plan.Span, split),
}
span3.Piles[0] = pile2
// The moved readers now belong to pile2.
copy(pile2.Readers, readers)
for _, span4 := range pile2.Readers {
span4.Piles[0] = pile2
}
// pile1 keeps span2 followed by the readers that were not moved; the
// unused tail of the backing array is cleared.
remain := 1 + len(readers) - split
pile1.Readers = readers[:remain]
readers[0] = span2
copy(readers[1:], readers[split:])
for i, n := remain, len(readers); i < n; i++ {
readers[i] = nil
}
return op2
}

// stage12Detach unregisters the source spans of every Add modifier in mods
// from the reader lists of the piles they read. Each source span is
// expected to be present in its pile's Readers; a missing span causes an
// out-of-range slice panic.
func (st *state) stage12Detach(mods []plan.Mod) {
	for i := range mods {
		if _, isAdd := mods[i].Nodes[0].(*raw.Add); !isAdd {
			continue
		}
		for _, gone := range mods[i].From {
			pile := gone.Piles[0]
			at := -1
			for k := range pile.Readers {
				if pile.Readers[k] == gone {
					at = k
					break
				}
			}
			// Shift the tail down one slot and clear the vacated end.
			last := len(pile.Readers) - 1
			copy(pile.Readers[at:], pile.Readers[at+1:])
			pile.Readers[last] = nil
			pile.Readers = pile.Readers[:last]
		}
	}
}

// stage12Broadcast prepends a copy of the modifier list mods1 to the input
// modifiers of every reader of pile1. For each reader span it locates the
// matching From slot of the reading op and replaces that slot's FromMods
// with mods1's copies followed by the existing modifiers.
//
// Node slices are copied per reader. Add modifiers get freshly cloned
// source spans owned by the reading op and registered as readers of their
// piles; BatchNorm modifiers get a copy of their Params.
func (st *state) stage12Broadcast(mods1 []plan.Mod, pile1 *plan.Pile) {
for _, span1 := range pile1.Readers {
op1 := span1.Op
// Find which input of op1 this reader span is. The span is expected to
// be found; i == -1 would make the index expression below panic.
i := -1
for j, span2 := range op1.From {
if span2 == span1 {
i = j
break
}
}
mods2 := op1.FromMods[i]
mods3 := make([]plan.Mod, len(mods1)+len(mods2))
op1.FromMods[i] = mods3
// Existing modifiers go after the broadcast ones.
copy(mods3[len(mods1):], mods2)
for j := range mods1 {
nodes1 := mods1[j].Nodes
nodes2 := make([]raw.Node, len(nodes1))
mods3[j].Nodes = nodes2
copy(nodes2, nodes1)
switch nodes1[0].(type) {
case *raw.Add:
from1 := mods1[j].From
from2 := make([]*plan.Span, len(from1))
mods3[j].From = from2
for k, span2 := range from1 {
// Each reader gets its own span, attached to the source pile.
span3 := st.stage12Clone(span2)
span3.Op = op1
from2[k] = span3
pile2 := span3.Piles[0]
pile2.Readers = append(pile2.Readers, span3)
}
case *raw.BatchNorm:
params1 := mods1[j].Params
params2 := make([]plan.Param, len(params1))
mods3[j].Params = params2
copy(params2, params1)
}
}
}
}

// stage12Fanout builds the output side of a fused op that replaces the ops
// in loose. It:
//  1. computes the prefix of output modifiers shared by every loose op
//     (stage12Common with allow=false, so Add and BatchNorm never match);
//  2. calls stage12Indirect for each loose op and records the pair
//     {fused, indirect op or nil} in edits under that loose op;
//  3. for each loose op with modifiers beyond the shared prefix, detaches
//     those modifiers' Add sources and broadcasts the modifiers to the
//     readers of that op's output pile;
//  4. creates one output span and pile for fused, with one tensor/count
//     entry per loose op, and moves every reader of the loose output piles
//     onto the new pile at the loose op's channel offset.
func (st *state) stage12Fanout(edits map[*plan.Op][2]*plan.Op, loose []*plan.Op, fused *plan.Op) {
op1 := loose[0]
mods1 := op1.ToMods[0]
// Shrink mods1 to the prefix common to all loose ops.
for _, op2 := range loose[1:] {
mods2 := op2.ToMods[0]
n := st.stage12Common(mods1, mods2, false)
mods1 = mods1[:n]
}
for _, op2 := range loose {
mods2 := op2.ToMods[0]
// broadcast is set when op2 has modifiers that are not shared.
broadcast := len(mods1) < len(mods2)
op3 := st.stage12Indirect(op2, broadcast)
edits[op2] = [2]*plan.Op{fused, op3}
}
for _, op2 := range loose {
// Modifiers past the shared prefix are pushed down to the readers.
mods2 := op2.ToMods[0][len(mods1):]
if len(mods2) != 0 {
st.stage12Detach(mods2)
span2 := op2.To[0]
pile2 := span2.Piles[0]
st.stage12Broadcast(mods2, pile2)
}
}
span1 := op1.To[0]
pile1 := span1.Piles[0]
span3 := &plan.Span{
Piles: []*plan.Pile{nil},
Offsets: []int{0},
Tensors: make([]string, len(loose)),
Counts: make([]int, len(loose)),
Op: fused,
}
// Height and Width come from the first loose op's output pile; Channels
// is accumulated in the loop below.
pile3 := &plan.Pile{
Height: pile1.Height,
Width: pile1.Width,
Writers: []*plan.Span{span3},
}
span3.Piles[0] = pile3
fused.To = []*plan.Span{span3}
fused.ToMods = [][]plan.Mod{make([]plan.Mod, len(mods1))}
copy(fused.ToMods[0], mods1)
for i, op2 := range loose {
span2 := op2.To[0]
pile2 := span2.Piles[0]
span3.Tensors[i] = span2.Tensors[0]
span3.Counts[i] = span2.Counts[0]
// This loose op's channels start where the previous ones ended.
offset := pile3.Channels
pile3.Channels += pile2.Channels
for _, span4 := range pile2.Readers {
span4.Piles[0] = pile3
span4.Offsets[0] = offset
pile3.Readers = append(pile3.Readers, span4)
}
}
}

// stage12Fusion merges the ops in loose into a single new op. The fused op
// carries the first node, params and param-mods of every loose op, and
// takes over the inputs and input modifiers of loose[0], including
// ownership of that input span and of any Add modifier source spans. The
// output side is built by stage12Fanout, after which every loose op is
// reset to the zero value.
func (st *state) stage12Fusion(edits map[*plan.Op][2]*plan.Op, loose []*plan.Op) {
	width := len(loose)
	fused := new(plan.Op)
	fused.Nodes = make([]raw.Node, 0, width)
	fused.Params = make([][]plan.Param, 0, width)
	fused.ParamMods = make([][2][]plan.Mod, 0, width)
	fused.From = loose[0].From
	fused.FromMods = loose[0].FromMods
	for _, member := range loose {
		fused.Nodes = append(fused.Nodes, member.Nodes[0])
		fused.Params = append(fused.Params, member.Params[0])
		fused.ParamMods = append(fused.ParamMods, member.ParamMods[0])
	}
	// Re-home the shared input span and the Add sources onto the fused op.
	fused.From[0].Op = fused
	inputMods := fused.FromMods[0]
	for i := range inputMods {
		if _, isAdd := inputMods[i].Nodes[0].(*raw.Add); !isAdd {
			continue
		}
		for _, src := range inputMods[i].From {
			src.Op = fused
		}
	}
	st.stage12Fanout(edits, loose, fused)
	// Zeroed ops are recognized and replaced when the sequence is rebuilt.
	for _, member := range loose {
		*member = plan.Op{}
	}
}

// stage12Readers scans the readers of pile for groups of ops that
// stage12Stackable accepts as mergeable and fuses each group via
// stage12Fusion. For every reader it collects all later readers stackable
// with it, removing the collected ones from pile.Readers; the first
// reader's span stays in place and becomes the input of the fused op.
// Replacement information for the plan sequence is accumulated in edits.
func (st *state) stage12Readers(edits map[*plan.Op][2]*plan.Op, pile *plan.Pile) {
var loose []*plan.Op
n := len(pile.Readers)
for i := 0; i < n; i++ {
span1 := pile.Readers[i]
for j := i + 1; j < n; j++ {
span2 := pile.Readers[j]
if !st.stage12Stackable(span1, span2) {
continue
}
if len(loose) == 0 {
loose = append(loose, span1.Op)
}
loose = append(loose, span2.Op)
// Remove span2 by moving the last reader into its slot, then step j
// back so the moved-in reader is examined too.
n -= 1
pile.Readers[j] = pile.Readers[n]
pile.Readers[n] = nil
pile.Readers = pile.Readers[:n]
j -= 1
}
if len(loose) == 0 {
continue
}
// The absorbed ops' input spans are gone; also unregister the sources
// of their Add input modifiers.
for _, op := range loose[1:] {
st.stage12Detach(op.FromMods[0])
}
st.stage12Fusion(edits, loose)
loose = loose[:0]
// The calls above can add or remove entries of pile.Readers, so refresh
// the length and search backwards for span1 to resume right after it.
n = len(pile.Readers)
if i >= n {
i = n - 1
}
for pile.Readers[i] != span1 {
i -= 1
}
}
}

// stage12 fuses compatible ops that read the same pile. Each pile that
// appears as an op's first input is processed once by stage12Readers;
// ElemBytes is borrowed as a temporary visited marker (-1) and reset to 0
// on every output pile before returning. When fusion happened, the plan
// sequence is rebuilt: each zeroed op is replaced by its fused op (emitted
// once, at the first zeroed member) followed by that member's indirect op,
// if any. It always returns nil.
func (st *state) stage12() error {
	const visited = -1
	edits := make(map[*plan.Op][2]*plan.Op)
	for _, op := range st.plan.Seq {
		// Ops zeroed by an earlier fusion have no inputs and are skipped.
		if len(op.From) != 0 {
			if pile := op.From[0].Piles[0]; pile.ElemBytes == 0 {
				pile.ElemBytes = visited
				st.stage12Readers(edits, pile)
			}
		}
	}
	if count := len(edits); count != 0 {
		rebuilt := make([]*plan.Op, 0, len(st.plan.Seq)+count)
		emitted := make(map[*plan.Op]bool, count)
		for _, op := range st.plan.Seq {
			if op.Nodes != nil {
				rebuilt = append(rebuilt, op)
				continue
			}
			pair := edits[op]
			if fused := pair[0]; !emitted[fused] {
				emitted[fused] = true
				rebuilt = append(rebuilt, fused)
			}
			if extra := pair[1]; extra != nil {
				rebuilt = append(rebuilt, extra)
			}
		}
		st.plan.Seq = rebuilt
	}
	// Clear the visited markers.
	for _, op := range st.plan.Seq {
		for _, out := range op.To {
			out.Piles[0].ElemBytes = 0
		}
	}
	return nil
}

// stage13Fork ensures that each single-tensor output pile of op1 has at
// most one Concat/Output reader. The first such reader stays on the
// original pile; every additional one is removed from it and given a new
// pile of the same dimensions, which is appended to the output span's
// Piles (with offset 0) so the op writes to all of them.
// Output spans holding more than one tensor are left untouched.
func (st *state) stage13Fork(op1 *plan.Op) {
for _, span1 := range op1.To {
if len(span1.Tensors) != 1 {
continue
}
pile1 := span1.Piles[0]
first := true
var moving []*plan.Span
n := len(pile1.Readers)
for i := 0; i < n; i++ {
span2 := pile1.Readers[i]
op2 := span2.Op
switch op2.Nodes[0].(type) {
case *raw.Concat, *raw.Output:
if first {
// This break leaves only the switch: the first match is kept.
first = false
break
}
moving = append(moving, span2)
// Remove by swapping in the last reader, then revisit slot i.
n -= 1
pile1.Readers[i] = pile1.Readers[n]
pile1.Readers[n] = nil
pile1.Readers = pile1.Readers[:n]
i -= 1
}
}
for _, span2 := range moving {
pile2 := &plan.Pile{
Channels: pile1.Channels,
Height: pile1.Height,
Width: pile1.Width,
Writers: []*plan.Span{span1},
Readers: []*plan.Span{span2},
}
span1.Piles = append(span1.Piles, pile2)
span1.Offsets = append(span1.Offsets, 0)
span2.Piles[0] = pile2
}
}
}

// stage13IndirectFork moves every Concat/Output reader of each
// single-tensor output pile of op1 behind a new pass-through op. The new
// op (a ReLU activation with Param 1, presumably an identity copy --
// confirm against the activation codegen) reads the original pile and
// writes one new pile per moved reader. The new op is queued in edits
// under op1. Output spans holding more than one tensor, and piles with no
// Concat/Output reader, are left untouched.
func (st *state) stage13IndirectFork(edits map[*plan.Op][]*plan.Op, op1 *plan.Op) {
for _, span1 := range op1.To {
if len(span1.Tensors) != 1 {
continue
}
pile1 := span1.Piles[0]
var moving []*plan.Span
n := len(pile1.Readers)
for i := 0; i < n; i++ {
span2 := pile1.Readers[i]
op2 := span2.Op
switch op2.Nodes[0].(type) {
case *raw.Concat, *raw.Output:
moving = append(moving, span2)
// Swap-remove; the slice itself is re-sliced after the loop.
n -= 1
pile1.Readers[i] = pile1.Readers[n]
pile1.Readers[n] = nil
i -= 1
}
}
if len(moving) == 0 {
continue
}
// span2 is the new op's input; it takes one of the slots freed above.
span2 := &plan.Span{
Piles: []*plan.Pile{pile1},
Offsets: []int{0},
Tensors: []string{moving[0].Tensors[0]},
Counts: []int{moving[0].Counts[0]},
}
pile1.Readers[n] = span2
pile1.Readers = pile1.Readers[:n+1]
op2 := &plan.Op{
Nodes: []raw.Node{
&raw.Activation{
LineNum: op1.Nodes[0].LineNumber(),
FromTensor: span2.Tensors[0],
ToTensor: span2.Tensors[0],
Kind: raw.ReLU,
Param: 1,
},
},
Params: make([][]plan.Param, 1),
ParamMods: make([][2][]plan.Mod, 1),
From: []*plan.Span{span2},
FromMods: make([][]plan.Mod, 1),
To: make([]*plan.Span, 1),
ToMods: make([][]plan.Mod, 1),
}
span2.Op = op2
edits[op1] = append(edits[op1], op2)
// span3 is the new op's single output span, fanned out over one pile
// per moved reader (offsets are left at zero).
span3 := &plan.Span{
Piles: make([]*plan.Pile, len(moving)),
Offsets: make([]int, len(moving)),
Tensors: []string{span2.Tensors[0]},
Counts: []int{span2.Counts[0]},
Op: op2,
}
op2.To[0] = span3
for i, span4 := range moving {
pile2 := &plan.Pile{
Channels: pile1.Channels,
Height: pile1.Height,
Width: pile1.Width,
Writers: []*plan.Span{span3},
Readers: []*plan.Span{span4},
}
span3.Piles[i] = pile2
span4.Piles[0] = pile2
}
}
}

// stage13Fission duplicates op1 once for every Concat/Output reader of its
// first output pile beyond the first. Each extra reader is removed from
// the original pile and given a new pile of the same dimensions, written
// by a new op that reuses op1's first node and reads fresh spans cloned
// from op1's first two inputs (op1 is expected to have at least two From
// spans). The new ops are queued in edits under op1. Nothing changes when
// there is at most one such reader.
func (st *state) stage13Fission(edits map[*plan.Op][]*plan.Op, op1 *plan.Op) {
span1 := op1.To[0]
pile1 := span1.Piles[0]
first := true
var moving []*plan.Span
n := len(pile1.Readers)
for i := 0; i < n; i++ {
span2 := pile1.Readers[i]
op2 := span2.Op
switch op2.Nodes[0].(type) {
case *raw.Concat, *raw.Output:
if first {
// This break leaves only the switch: the first match is kept.
first = false
break
}
moving = append(moving, span2)
// Swap-remove; the slice itself is re-sliced after the loop.
n -= 1
pile1.Readers[i] = pile1.Readers[n]
pile1.Readers[n] = nil
i -= 1
}
}
if len(moving) == 0 {
return
}
pile1.Readers = pile1.Readers[:n]
span2 := op1.From[0]
span3 := op1.From[1]
pile2 := span2.Piles[0]
pile3 := span3.Piles[0]
for _, span4 := range moving {
// New private output pile for this reader.
pile4 := &plan.Pile{
Channels: pile1.Channels,
Height: pile1.Height,
Width: pile1.Width,
Writers: []*plan.Span{nil},
Readers: []*plan.Span{span4},
}
span4.Piles[0] = pile4
// span5 is the duplicate op's output span.
span5 := &plan.Span{
Piles: []*plan.Pile{pile4},
Offsets: []int{span1.Offsets[0]},
Tensors: []string{span1.Tensors[0]},
Counts: []int{span1.Counts[0]},
}
pile4.Writers[0] = span5
op2 := &plan.Op{
Nodes: []raw.Node{op1.Nodes[0]},
Params: make([][]plan.Param, 1),
ParamMods: make([][2][]plan.Mod, 1),
From: make([]*plan.Span, 2),
FromMods: make([][]plan.Mod, 2),
To: []*plan.Span{span5},
ToMods: make([][]plan.Mod, 1),
}
edits[op1] = append(edits[op1], op2)
span5.Op = op2
// span6 and span7 are the duplicate's own input spans, registered as
// additional readers of op1's two input piles.
span6 := &plan.Span{
Piles: []*plan.Pile{pile2},
Offsets: []int{span2.Offsets[0]},
Tensors: []string{span2.Tensors[0]},
Counts: []int{span2.Counts[0]},
Op: op2,
}
pile2.Readers = append(pile2.Readers, span6)
op2.From[0] = span6
span7 := &plan.Span{
Piles: []*plan.Pile{pile3},
Offsets: []int{span3.Offsets[0]},
Tensors: []string{span3.Tensors[0]},
Counts: []int{span3.Counts[0]},
Op: op2,
}
pile3.Readers = append(pile3.Readers, span7)
op2.From[1] = span7
}
}

// stage13 walks the plan from last op to first and separates multiple
// Concat/Output readers of the same pile: Concat ops go through
// stage13Fission, Input ops through stage13IndirectFork, and every other
// op through stage13Fork. Ops created along the way are spliced into the
// sequence directly after the op they were created for. It always
// returns nil.
func (st *state) stage13() error {
	edits := make(map[*plan.Op][]*plan.Op)
	for i := len(st.plan.Seq) - 1; i >= 0; i-- {
		current := st.plan.Seq[i]
		switch current.Nodes[0].(type) {
		case *raw.Concat:
			st.stage13Fission(edits, current)
		case *raw.Input:
			st.stage13IndirectFork(edits, current)
		default:
			st.stage13Fork(current)
		}
	}
	if len(edits) == 0 {
		return nil
	}
	total := len(st.plan.Seq)
	for _, extra := range edits {
		total += len(extra)
	}
	rebuilt := make([]*plan.Op, 0, total)
	for _, op := range st.plan.Seq {
		rebuilt = append(rebuilt, op)
		// Appending a missing (nil) entry is a no-op.
		rebuilt = append(rebuilt, edits[op]...)
	}
	st.plan.Seq = rebuilt
	return nil
}

// stage14Rewire transfers the writers and readers of pile2 onto pile1,
// shifting their channel offsets by offset. Every writer span entry that
// targets pile2 is redirected to pile1. Reader spans owned by op are
// dropped rather than transferred.
func (st *state) stage14Rewire(op *plan.Op, pile1, pile2 *plan.Pile, offset int) {
	pile1.Writers = append(pile1.Writers, pile2.Writers...)
	for _, writer := range pile2.Writers {
		for slot, target := range writer.Piles {
			if target != pile2 {
				continue
			}
			writer.Piles[slot] = pile1
			writer.Offsets[slot] += offset
		}
	}
	for _, reader := range pile2.Readers {
		if reader.Op == op {
			continue
		}
		pile1.Readers = append(pile1.Readers, reader)
		reader.Piles[0] = pile1
		reader.Offsets[0] += offset
	}
}

// stage14Bypass removes a two-input op from the data flow by making its
// output pile the direct target of everything that wrote to, or read from,
// its two input piles. The first input keeps channel offset 0; the second
// is shifted by the first input pile's channel count. The output pile's
// previous writer list is discarded.
func (st *state) stage14Bypass(op *plan.Op) {
	dst := op.To[0].Piles[0]
	srcA := op.From[0].Piles[0]
	srcB := op.From[1].Piles[0]
	dst.Writers = make([]*plan.Span, 0, len(srcA.Writers)+len(srcB.Writers))
	shift := 0
	for _, src := range [...]*plan.Pile{srcA, srcB} {
		st.stage14Rewire(op, dst, src, shift)
		shift += src.Channels
	}
}

// stage14 eliminates every Concat op from the plan. Each one is bypassed
// with stage14Bypass and zeroed, and the remaining ops are compacted to
// the front of the sequence in their original order. It always returns
// nil.
func (st *state) stage14() error {
	seq := st.plan.Seq
	kept := 0
	for i, op := range seq {
		if _, isConcat := op.Nodes[0].(*raw.Concat); isConcat {
			st.stage14Bypass(op)
			*op = plan.Op{}
			continue
		}
		// kept never exceeds i, so inequality means a gap has opened.
		if kept != i {
			seq[kept], seq[i] = op, nil
		}
		kept++
	}
	st.plan.Seq = seq[:kept]
	return nil
}

// stage15InputOutput reports whether pile is written by an Input op or
// read by an Output op.
func (st *state) stage15InputOutput(pile *plan.Pile) bool {
	for _, writer := range pile.Writers {
		if _, isInput := writer.Op.Nodes[0].(*raw.Input); isInput {
			return true
		}
	}
	for _, reader := range pile.Readers {
		if _, isOutput := reader.Op.Nodes[0].(*raw.Output); isOutput {
			return true
		}
	}
	return false
}

// stage15Include reports whether pile qualifies for the padded layout
// applied in stage15Edit. It returns false as soon as some FullyConnected
// op reads the pile through its first input span.
func (st *state) stage15Include(pile *plan.Pile) bool {
	for _, reader := range pile.Readers {
		consumer := reader.Op
		_, isFC := consumer.Nodes[0].(*raw.FullyConnected)
		if isFC && consumer.From[0] == reader {
			return false
		}
	}
	return true
}

// stage15Edit assigns the byte layout of pile, unless it already has one
// (ElemBytes != 0). Elements are 4 bytes; Pitch1Bytes is the size of one
// row and Pitch2Bytes the size of one channel plane.
//
// Piles connected to an Input or Output op get OffsetBytes -1 and a dense
// layout. Other piles get OffsetBytes 0, and, when stage15Include allows
// it, a padded plane pitch and total size. Padding is only defined for
// the AVX512Float32 platform; any other platform panics.
func (st *state) stage15Edit(pile *plan.Pile) {
	if pile.ElemBytes != 0 {
		return
	}
	const elem = 4
	var (
		rowBytes   = pile.Width * elem
		planeBytes = pile.Height * rowBytes
		total      = pile.Channels * planeBytes
		offset     = 0
	)
	switch {
	case st.stage15InputOutput(pile):
		offset = -1
	case st.stage15Include(pile):
		if st.config.Platform != raw.AVX512Float32 {
			panic("bug")
		}
		// pad leaves sizes of up to 64 bytes alone; larger sizes are
		// rounded up to a multiple of 64 and then forced to an odd
		// multiple of 64 (presumably to avoid power-of-two strides --
		// confirm with the kernel authors).
		pad := func(y int) int {
			const line = 1 << 6
			if y <= line {
				return y
			}
			rounded := (y + line - 1) &^ (line - 1)
			return rounded | line
		}
		planeBytes = pad(planeBytes)
		total = pad(pile.Channels * planeBytes)
	}
	pile.ElemBytes = elem
	pile.Pitch1Bytes = rowBytes
	pile.Pitch2Bytes = planeBytes
	pile.SizeBytes = total
	pile.OffsetBytes = offset
}

// stage15 assigns a byte layout to every pile written by any op in the
// plan, via stage15Edit. It always returns nil.
func (st *state) stage15() error {
	for _, op := range st.plan.Seq {
		for _, out := range op.To {
			for k := range out.Piles {
				st.stage15Edit(out.Piles[k])
			}
		}
	}
	return nil
}

// stage16Cell tracks one pile during offset assignment in stage 16.
type stage16Cell struct {
// pile is the pile this cell describes.
pile *plan.Pile
// first is the index, in the plan sequence, of the op at which the pile
// was first seen as an output (set when the cell is created).
first int
// last is the index of the latest op seen touching the pile, either as
// an input, an output, or a modifier source.
last int
// guide lists the indices of other cells whose [first, last] ranges
// overlap this one and that rank ahead of it per stage16Rank.
guide []int
}

// stage16Rank reports whether cell i1 ranks ahead of cell i2. Larger piles
// rank first; ties go to the cell with the longer first-to-last span, and
// remaining ties to the lower index.
func stage16Rank(cells []stage16Cell, i1, i2 int) bool {
	a, b := &cells[i1], &cells[i2]
	if sizeA, sizeB := a.pile.SizeBytes, b.pile.SizeBytes; sizeA != sizeB {
		return sizeA > sizeB
	}
	if spanA, spanB := a.last-a.first, b.last-b.first; spanA != spanB {
		return spanA > spanB
	}
	return i1 < i2
}

// stage16Guides fills in the guide list of every cell. For each pair of
// cells whose op ranges overlap, the lower-ranked cell (per stage16Rank)
// records the index of the higher-ranked one. Cells whose pile has a
// negative OffsetBytes are left out entirely. The inner scan stops at the
// first later cell that starts after the current cell's last use, which
// relies on cells being ordered by ascending first.
func (st *state) stage16Guides(cells []stage16Cell) {
	for i1 := range cells {
		c1 := &cells[i1]
		if c1.pile.OffsetBytes < 0 {
			continue
		}
		for i2 := i1 + 1; i2 < len(cells) && cells[i2].first <= c1.last; i2++ {
			c2 := &cells[i2]
			if c2.pile.OffsetBytes < 0 {
				continue
			}
			if stage16Rank(cells, i1, i2) {
				c2.guide = append(c2.guide, i1)
			} else {
				c1.guide = append(c1.guide, i2)
			}
		}
	}
}

// stage16Cells builds one cell per pile, in the order the piles first
// appear as op outputs, recording the first and last op index at which
// each pile is used. Uses as an op input, as an op output, and as a
// modifier source all extend a pile's last index. The guide lists are
// populated by stage16Guides before the cells are returned.
func (st *state) stage16Cells() []stage16Cell {
	guess := len(st.plan.Seq)
	index := make(map[*plan.Pile]int, guess)
	cells := make([]stage16Cell, 0, guess)
	step := 0
	// touch extends the recorded lifetime of pile to the current op.
	touch := func(pile *plan.Pile) {
		cells[index[pile]].last = step
	}
	// touchMods extends the lifetime of every pile a modifier reads.
	touchMods := func(mods []plan.Mod) {
		for k := range mods {
			for _, src := range mods[k].From {
				touch(src.Piles[0])
			}
		}
	}
	for i, op := range st.plan.Seq {
		step = i
		for j, in := range op.From {
			touch(in.Piles[0])
			touchMods(op.FromMods[j])
		}
		for j, out := range op.To {
			for _, pile := range out.Piles {
				if _, known := index[pile]; known {
					touch(pile)
					continue
				}
				index[pile] = len(cells)
				cells = append(cells, stage16Cell{
					pile:  pile,
					first: i,
					last:  i,
				})
			}
			touchMods(op.ToMods[j])
		}
	}
	st.stage16Guides(cells)
	return cells
}

// stage16ByRank sorts a permutation of cell indices (seq) so that
// higher-ranked cells, per stage16Rank, come first. It implements
// sort.Interface.
type stage16ByRank struct {
	cells []stage16Cell
	seq   []int
}

// Len returns the number of indices being sorted.
func (r *stage16ByRank) Len() int {
	return len(r.seq)
}

// Less reports whether the cell at position i ranks ahead of the one at j.
func (r *stage16ByRank) Less(i, j int) bool {
	a, b := r.seq[i], r.seq[j]
	return stage16Rank(r.cells, a, b)
}

// Swap exchanges the indices at positions i and j.
func (r *stage16ByRank) Swap(i, j int) {
	s := r.seq
	s[i], s[j] = s[j], s[i]
}

// stage16ByOffset sorts a list of cell indices (guide) by the ascending
// OffsetBytes of the piles they refer to. It implements sort.Interface.
type stage16ByOffset struct {
	cells []stage16Cell
	guide []int
}

// Len returns the number of indices being sorted.
func (o *stage16ByOffset) Len() int {
	return len(o.guide)
}

// Less reports whether the pile at position i starts before the one at j.
func (o *stage16ByOffset) Less(i, j int) bool {
	left := o.cells[o.guide[i]].pile
	right := o.cells[o.guide[j]].pile
	return left.OffsetBytes < right.OffsetBytes
}

// Swap exchanges the indices at positions i and j.
func (o *stage16ByOffset) Swap(i, j int) {
	g := o.guide
	g[i], g[j] = g[j], g[i]
}

// stage16 assigns OffsetBytes to every pile that does not already carry a
// negative offset. Cells are processed in stage16Rank order. Each pile is
// placed at the lowest offset, starting from 0, at which it overlaps none
// of the already-placed piles in its guide list; those are visited in
// ascending offset order. It always returns nil.
func (st *state) stage16() error {
	cells := st.stage16Cells()
	order := make([]int, len(cells))
	for i := range order {
		order[i] = i
	}
	sort.Sort(&stage16ByRank{cells: cells, seq: order})
	sorter := &stage16ByOffset{cells: cells}
	for _, i1 := range order {
		cell := &cells[i1]
		pile := cell.pile
		if pile.OffsetBytes < 0 {
			continue
		}
		sorter.guide = cell.guide
		sort.Sort(sorter)
		at := 0
		for _, i2 := range cell.guide {
			other := cells[i2].pile
			// The pile fits in the gap before this neighbour.
			if at+pile.SizeBytes <= other.OffsetBytes {
				break
			}
			// Otherwise skip past the neighbour's end.
			if end := other.OffsetBytes + other.SizeBytes; at < end {
				at = end
			}
		}
		pile.OffsetBytes = at
	}
	return nil
}

// stage17 hands the finished plan to author.Implement and stores the two
// byte slices it returns in st.h and st.c. It always returns nil.
func (st *state) stage17() error {
	h, c := author.Implement(&st.plan)
	st.h = h
	st.c = c
	return nil
}

Top || internal/compile/author/author.go

package author

import (
"NN-512/internal/compile/author/act"
"NN-512/internal/compile/author/bn"
"NN-512/internal/compile/author/cgen"
"NN-512/internal/compile/author/cpu"
"NN-512/internal/compile/author/elwi"
"NN-512/internal/compile/author/engine"
"NN-512/internal/compile/author/eof"
"NN-512/internal/compile/author/errmsg"
"NN-512/internal/compile/author/exp"
"NN-512/internal/compile/author/fc"
"NN-512/internal/compile/author/glopl"
"NN-512/internal/compile/author/hc"
"NN-512/internal/compile/author/include"
"NN-512/internal/compile/author/license"
"NN-512/internal/compile/author/loom"
"NN-512/internal/compile/author/mod"
"NN-512/internal/compile/author/net"
"NN-512/internal/compile/author/one"
"NN-512/internal/compile/author/params"
"NN-512/internal/compile/author/rsqrt"
"NN-512/internal/compile/author/softmax"
"NN-512/internal/compile/author/strider"
"NN-512/internal/compile/author/threader"
"NN-512/internal/compile/author/three"
"NN-512/internal/compile/author/thrpl"
"NN-512/internal/compile/author/tobuild"
"NN-512/internal/compile/author/twopl"
"NN-512/internal/compile/plan"
"NN-512/internal/nmsrc"
"NN-512/internal/raw"
)

// Implement runs every authoring stage over plan a and returns the two
// byte slices produced by joining the accumulated sections (h and c --
// presumably the generated header and C source; confirm against hc.Join).
func Implement(a *plan.Plan) (h, c []byte) {
	st := &state{
		pl:  a,
		nms: nmsrc.New(),
	}
	st.stages()
	h, c = st.hc.Join()
	return h, c
}

// state carries everything shared between the authoring stages: the plan
// being implemented, the output sections, the name source, and the
// per-component contexts created by the numbered stage methods.
type state struct {
// pl is the plan passed to Implement.
pl *plan.Plan
// hc accumulates the generated sections; Implement returns hc.Join().
hc hc.Sections
// nms is the name source handed to every component context.
nms nmsrc.Src
// paramsName is the name computed by params.Name in stage3.
paramsName string
// The contexts below are created by the stages, in dependency order:
// errmsgCtx (stage4), threaderCtx (stage5), expCtx (stage6),
// softmaxCtx and actCtx (stage7), rsqrtCtx (stage8), bnCtx through
// loomCtx (stage9), netCtx (stage10), engineCtx (stage11).
errmsgCtx *errmsg.Ctx
threaderCtx *threader.Ctx
expCtx *exp.Ctx
softmaxCtx *softmax.Ctx
actCtx *act.Ctx
rsqrtCtx *rsqrt.Ctx
bnCtx *bn.Ctx
elwiCtx *elwi.Ctx
gloplCtx *glopl.Ctx
twoplCtx *twopl.Ctx
thrplCtx *thrpl.Ctx
fcCtx *fc.Ctx
oneCtx *one.Ctx
threeCtx *three.Ctx
striderCtx *strider.Ctx
loomCtx *loom.Ctx
netCtx *net.Ctx
engineCtx *engine.Ctx
}

// stages runs the authoring stages in order, stage1 through stage12.
// Later stages rely on contexts stored by earlier ones.
func (st *state) stages() {
	pipeline := []func(){
		st.stage1,
		st.stage2,
		st.stage3,
		st.stage4,
		st.stage5,
		st.stage6,
		st.stage7,
		st.stage8,
		st.stage9,
		st.stage10,
		st.stage11,
		st.stage12,
	}
	for _, run := range pipeline {
		run()
	}
}

// stage1 appends the fixed H-prefixed sections: pragma once, license,
// includes, the two linkage parts, and the final EOF marker.
func (st *state) stage1() {
	out := &st.hc
	out.Append(hc.HPragmaOnce, cgen.PragmaOnce, cgen.Newline)
	out.Append(hc.HLicense, license.Gen, cgen.Newline)
	out.Append(hc.HInclude, include.H(), cgen.Newline)
	out.Append(hc.HLinkage1, cgen.Linkage1, cgen.Newline)
	out.Append(hc.HLinkage2, cgen.Linkage2, cgen.Newline)
	out.Append(hc.HLast, eof.Gen)
}

// stage2 appends the fixed C-prefixed sections: the to-build text, the
// license, the includes, and the final EOF marker.
func (st *state) stage2() {
	out := &st.hc
	out.Append(hc.CToBuild, tobuild.Gen(st.pl), cgen.Newline)
	out.Append(hc.CLicense, license.Gen, cgen.Newline)
	out.Append(hc.CInclude, include.C(st.pl), cgen.Newline)
	out.Append(hc.CLast, eof.Gen)
}

// stage3 computes the params name for the plan, stores it in
// st.paramsName, and appends the params forward declaration and
// definition sections built from it.
func (st *state) stage3() {
	st.paramsName = params.Name(st.pl)
	st.hc.Append(hc.HParams1, params.Fwd(st.paramsName), cgen.Newline)
	st.hc.Append(hc.HParams2, params.Def(st.pl, st.paramsName), cgen.Newline)
}

// stage4 creates the errmsg context, stores it in st.errmsgCtx, and
// appends its Prep to the CErrmsg section.
func (st *state) stage4() {
	st.errmsgCtx = errmsg.NewCtx(st.pl, st.nms)
	st.hc.Append(hc.CErrmsg, &errmsg.Prep{Ctx: st.errmsgCtx}, cgen.Newline)
}

// stage5 creates the threader context (which depends on the errmsg
// context), stores it in st.threaderCtx, and appends its Prep to the
// CThreader section.
func (st *state) stage5() {
	st.threaderCtx = threader.NewCtx(st.pl, st.nms, st.errmsgCtx)
	st.hc.Append(hc.CThreader, &threader.Prep{Ctx: st.threaderCtx}, cgen.Newline)
}

// stage6 creates the exp context, stores it in st.expCtx, and appends its
// Prep to the CExp section.
func (st *state) stage6() {
	st.expCtx = exp.NewCtx(st.pl, st.nms)
	st.hc.Append(hc.CExp, &exp.Prep{Ctx: st.expCtx}, cgen.Newline)
}

// stage7 creates the softmax context (built on the threader and exp
// contexts) and then the act context. Nothing is appended to the output
// here.
func (st *state) stage7() {
	pl, names := st.pl, st.nms
	st.softmaxCtx = softmax.NewCtx(pl, names, st.threaderCtx, st.expCtx)
	st.actCtx = act.NewCtx(pl, names)
}

// stage8 creates the rsqrt context, stores it in st.rsqrtCtx, and appends
// its Prep to the CRsqrt section.
func (st *state) stage8() {
	st.rsqrtCtx = rsqrt.NewCtx(st.pl, st.nms)
	st.hc.Append(hc.CRsqrt, st.rsqrtCtx.Prep(), cgen.Newline)
}

// stage9 creates the batch-norm context first and then the contexts that
// all share the same dependencies (plan, name source, threader, act and
// bn contexts): elwi, glopl, twopl, thrpl, fc, one, three, strider and
// loom, in that order.
func (st *state) stage9() {
	st.bnCtx = bn.NewCtx(st.pl, st.nms, st.rsqrtCtx)
	var (
		pl    = st.pl
		names = st.nms
		thr   = st.threaderCtx
		actc  = st.actCtx
		bnc   = st.bnCtx
	)
	st.elwiCtx = elwi.NewCtx(pl, names, thr, actc, bnc)
	st.gloplCtx = glopl.NewCtx(pl, names, thr, actc, bnc)
	st.twoplCtx = twopl.NewCtx(pl, names, thr, actc, bnc)
	st.thrplCtx = thrpl.NewCtx(pl, names, thr, actc, bnc)
	st.fcCtx = fc.NewCtx(pl, names, thr, actc, bnc)
	st.oneCtx = one.NewCtx(pl, names, thr, actc, bnc)
	st.threeCtx = three.NewCtx(pl, names, thr, actc, bnc)
	st.striderCtx = strider.NewCtx(pl, names, thr, actc, bnc)
	st.loomCtx = loom.NewCtx(pl, names, thr, actc, bnc)
}

// stage10 creates the net context, stores it in st.netCtx, and appends
// its pieces: comment, struct forward declaration, create and destroy
// declarations to HNet; struct and destroy definitions to CNet.
func (st *state) stage10() {
	st.netCtx = net.NewCtx(st.pl, st.nms, st.paramsName)
	nc, out := st.netCtx, &st.hc
	out.Append(hc.HNet, nc.Comment(), cgen.Newline)
	out.Append(hc.HNet, nc.StructFwd(), cgen.Newline)
	out.Append(hc.HNet, nc.CreateDecl(), cgen.Newline)
	out.Append(hc.HNet, nc.DestroyDecl(), cgen.Newline)
	out.Append(hc.CNet, nc.StructDef(), cgen.Newline)
	out.Append(hc.CNet, nc.DestroyDef(), cgen.Newline)
}

// stage11 creates the engine context (built on the errmsg, threader and
// net contexts), stores it in st.engineCtx, and appends its pieces: the
// comment, struct forward declaration, and create, pthread_t, inference
// and destroy declarations to HEngine; the struct, pthread_t and destroy
// definitions to CEngine.
func (st *state) stage11() {
	st.engineCtx = engine.NewCtx(st.pl, st.nms, st.errmsgCtx, st.threaderCtx, st.netCtx)
	ec, out := st.engineCtx, &st.hc
	out.Append(hc.HEngine, ec.Comment(), cgen.Newline)
	out.Append(hc.HEngine, ec.StructFwd(), cgen.Newline)
	out.Append(hc.HEngine, ec.CreateDecl(), cgen.Newline)
	out.Append(hc.HEngine, ec.PthreadTDecl(), cgen.Newline)
	out.Append(hc.HEngine, ec.InferenceDecl(), cgen.Newline)
	out.Append(hc.HEngine, ec.DestroyDecl(), cgen.Newline)
	out.Append(hc.CEngine, ec.StructDef(), cgen.Newline)
	out.Append(hc.CEngine, ec.PthreadTDef(), cgen.Newline)
	out.Append(hc.CEngine, ec.DestroyDef(), cgen.Newline)
}

func (st *state) stage12() {
type link struct {
chans int
height int
width int
elemBytes int
pitch1Bytes []int
pitch2Bytes []int
addrExprs []cgen.Gen
ops [][]mod.Op
}
type bank struct {
filts int
bnPre int
bnPost int
addrExprs []cgen.Gen
}
var (
usedParams bool
netAlloc cgen.Gen
netAlign cgen.Gen
netBytes int
tmpAlloc cgen.Gen
tmpAlign cgen.Gen
tmpEdge int
tmpBytes int
netTeam cgen.Gen
netStmts cgen.Stmts
netBlocks cgen.Stmts
engNetAlign cgen.Gen
engTeam cgen.Gen
engAlign cgen.Gen
engEdge int
engBytes int
engStmts cgen.Stmts
engBlocks cgen.Stmts
bnPersist map[string]bool
bnNetEng map[string][2]cgen.Gen
planOp *plan.Op
linkFrom *link
linkTo *link
banks []*bank
)
il := func(i int) cgen.Gen {
return cgen.IntLit(i)
}
vb := func(s string) cgen.Gen {
return cgen.Vb(s)
}
nm := func(s string) cgen.Gen {
return vb(st.nms.Name(s))
}
ptr := func(t cgen.Gen) cgen.Gen {
return cgen.Ptr{Type: t}
}
unused := func(what cgen.Gen) cgen.Gen {
return cgen.Cast{
Type: cgen.Void,
Expr: what,
}
}
param := func(name string) cgen.Gen {
usedParams = true
return cgen.Arrow{
Expr: st.netCtx.CreateParams,
Name: name,
}
}
netAddr := func() cgen.Gen {
if netAlloc == nil {
netAlloc = nm("alloc")
netAlign = nm("align")
}
netBytes += st.netCtx.Alignment - 1
netBytes &= -st.netCtx.Alignment
return cgen.Add{
Expr1: netAlign,
Expr2: il(netBytes),
}
}
netExtend := func(bytes int) {
netBytes += bytes
}
tmpAddr := func() cgen.Gen {
if tmpAlloc == nil {
tmpAlloc = nm("tmpAlloc")
tmpAlign = nm("tmpAlign")
}
tmpEdge += st.netCtx.Alignment - 1
tmpEdge &= -st.netCtx.Alignment
return cgen.Add{
Expr1: tmpAlign,
Expr2: il(tmpEdge),
}
}
tmpExtend := func(bytes int) {
tmpEdge += bytes
}
tmpRewind := func() {
if tmpBytes < tmpEdge {
tmpBytes = tmpEdge
}
tmpEdge = 0
}
NetTeam := func() cgen.Gen {
if netTeam == nil {
netTeam = nm("team")
}
return netTeam
}
netStmt := func(stmt cgen.Gen) {
netStmts = append(netStmts, stmt)
}
netBlock := func() {
if netStmts == nil {
return
}
netBlocks = append(
netBlocks, cgen.Block{
Inner: netStmts,
},
)
netStmts = nil
}
engNetAddr := func() cgen.Gen {
if engNetAlign == nil {
engNetAlign = nm("netAlign")
}
return cgen.Add{
Expr1: engNetAlign,
Expr2: il(netBytes),
}
}
EngTeam := func() cgen.Gen {
if engTeam == nil {
engTeam = nm("team")
}
return engTeam
}
engAddr := func(off int) cgen.Gen {
if engAlign == nil {
engAlign = nm("align")
}
if off == -1 {
engEdge += st.engineCtx.Alignment - 1
engEdge &= -st.engineCtx.Alignment
off = st.engineCtx.Split + engEdge
}
return cgen.Add{
Expr1: engAlign,
Expr2: il(off),
}
}
engExtend := func(bytes int) {
engEdge += bytes
}
engRewind := func() {
if engBytes < engEdge {
engBytes = engEdge
}
engEdge = 0
}
engStmt := func(stmt cgen.Gen) {
engStmts = append(engStmts, stmt)
}
engBlock := func() {
if engStmts == nil {
return
}
engBlocks = append(
engBlocks, cgen.Block{
Inner: engStmts,
},
)
engStmts = nil
}
bnKey := func(node *raw.BatchNorm) string {
return node.MeansTensor
}
bnSimplify := func(
node *raw.BatchNorm,
chans int,
) (netEng [2]cgen.Gen) {
var (
key = bnKey(node)
persist = bnPersist[key]
)
switch {
case persist:
var ok bool
if netEng, ok = bnNetEng[key]; ok {
return
}
netEng[0] = netAddr()
netEng[1] = engNetAddr()
bnNetEng[key] = netEng
default:
netEng[0] = tmpAddr()
netEng[1] = nil
}
simplify := &bn.Simplify{
Ctx: st.bnCtx,
Channels: chans,
Epsilon: node.Epsilon,
Means: param(node.MeansTensor),
Variances: param(node.VariancesTensor),
Scales: param(node.ScalesTensor),
Shifts: param(node.ShiftsTensor),
Mas: netEng[0],
}
st.hc.Append(hc.CBn, simplify.Prep())
bytes := simplify.MasBytes()
netStmt(simplify)
switch {
case persist:
netExtend(bytes)
default:
tmpExtend(bytes)
}
return
}
// linker builds the link for one side of an op from its spans and the
// mods attached to each span (pmods[i] belongs to spans[i]). When prior is
// true a span's own pile addresses are recorded before the tensors
// contributed by its mods; when false they are recorded after. If spans is
// empty the returned link is nil.
linker := func(
spans []*plan.Span,
pmods [][]plan.Mod,
prior bool,
) (lnk *link) {
// doOnce allocates lnk from the first span: the channel count is the
// sum of that span's Counts, the shape comes from its first pile, and
// ops gets one (initially nil) slot per pmods entry.
doOnce := func(span *plan.Span) {
var (
chans = 0
pile = span.Piles[0]
)
for _, n := range span.Counts {
chans += n
}
lnk = &link{
chans: chans,
height: pile.Height,
width: pile.Width,
elemBytes: pile.ElemBytes,
ops: make(
[][]mod.Op, len(pmods),
),
}
}
// ioName returns the tensor name of the first Input node writing the
// pile or, failing that, the first Output node reading it. It panics
// when neither exists.
ioName := func(pile *plan.Pile) string {
for _, span := range pile.Writers {
switch node := span.Op.Nodes[0].(type) {
case *raw.Input:
return node.ToTensor
}
}
for _, span := range pile.Readers {
switch node := span.Op.Nodes[0].(type) {
case *raw.Output:
return node.FromTensor
}
}
panic("bug")
}
// doSpan appends, for every pile of the span, both pitches and one
// address expression to lnk.
doSpan := func(span *plan.Span) {
for i, pile := range span.Piles {
var (
pitch1 = pile.Pitch1Bytes
pitch2 = pile.Pitch2Bytes
off1 = pile.OffsetBytes
ae cgen.Gen
)
switch off1 {
case -1:
// An offset of -1 selects the named input/output variable
// (cast to a char pointer) instead of an engine address.
ae = cgen.Cast{
Type: cgen.PtrChar,
Expr: vb(ioName(pile)),
}
default:
// Otherwise the address is the pile offset plus the span's
// offset within the pile, scaled by the second pitch.
off2 := span.Offsets[i] * pitch2
ae = engAddr(off1 + off2)
}
lnk.pitch1Bytes = append(
lnk.pitch1Bytes, pitch1,
)
lnk.pitch2Bytes = append(
lnk.pitch2Bytes, pitch2,
)
lnk.addrExprs = append(
lnk.addrExprs, ae,
)
}
}
doSpans := func(spans []*plan.Span) {
for _, span := range spans {
doSpan(span)
}
}
// doBn appends the second element returned by bnSimplify for node
// (sized by lnk.chans) to the address list.
doBn := func(node *raw.BatchNorm) {
netEng := bnSimplify(
node, lnk.chans,
)
lnk.addrExprs = append(
lnk.addrExprs,
netEng[1],
)
}
// doMod translates one plan mod into a mod.Op, appending whatever
// tensors the op needs (Add sources, BatchNorm data) to lnk. Only
// ReLU activations are supported; anything else panics.
doMod := func(pmod *plan.Mod) (op mod.Op) {
switch node := pmod.Nodes[0].(type) {
case *raw.Activation:
switch node.Kind {
case raw.ReLU:
op.Kind = mod.ReLU
op.Float = node.Param
default:
panic("bug")
}
case *raw.Add:
op.Kind = mod.Add
op.Int = len(pmod.From)
doSpans(pmod.From)
case *raw.BatchNorm:
op.Kind = mod.Bn
doBn(node)
default:
panic("bug")
}
return
}
// doMods converts every mod of span i and stores the result in
// lnk.ops[i].
doMods := func(i int) {
var (
pms = pmods[i]
ops = make([]mod.Op, len(pms))
)
for j := range pms {
ops[j] = doMod(&pms[j])
}
lnk.ops[i] = ops
}
for i, span := range spans {
if i == 0 {
doOnce(span)
}
// The position of doSpan relative to doMods fixes the order of
// lnk.addrExprs and the pitch lists.
if prior {
doSpan(span)
}
doMods(i)
if !prior {
doSpan(span)
}
}
return
}
// layer7 emits the code for the current planOp, dispatching on the type of
// its first node. It reads linkFrom and linkTo (and banks for Conv and
// FullyConnected), registers function definitions through st.hc.Append and
// queues call statements through netStmt/engStmt.
layer7 := func() {
switch node := planOp.Nodes[0].(type) {
case *raw.Activation, *raw.Add, *raw.BatchNorm:
// Standalone elementwise op: fold the node itself into the front of
// the destination op list, then emit a single elwi call.
switch node := node.(type) {
case *raw.Activation:
switch node.Kind {
case raw.ReLU:
linkTo.ops[0] = append(
[]mod.Op{{
Kind: mod.ReLU,
Float: node.Param,
}},
linkTo.ops[0]...,
)
default:
panic("bug")
}
case *raw.BatchNorm:
// The batch norm tensor is prepended so that it lines up with the
// prepended mod.Bn op.
netEng := bnSimplify(
node, linkTo.chans,
)
linkTo.addrExprs = append(
[]cgen.Gen{netEng[1]},
linkTo.addrExprs...,
)
linkTo.ops[0] = append(
[]mod.Op{{Kind: mod.Bn}},
linkTo.ops[0]...,
)
}
// A *raw.Add node prepends nothing here.
spec := &elwi.Spec{
Channels: linkFrom.chans,
Height: linkFrom.height,
Width: linkFrom.width,
ElemBytes: linkFrom.elemBytes,
Pitch1Bytes: append(
linkFrom.pitch1Bytes,
linkTo.pitch1Bytes...,
),
Pitch2Bytes: append(
linkFrom.pitch2Bytes,
linkTo.pitch2Bytes...,
),
Ops: append(
linkFrom.ops,
linkTo.ops[0],
),
}
call := &elwi.Call{
Ctx: st.elwiCtx,
Spec: spec,
Team: EngTeam(),
Tensors: append(
linkFrom.addrExprs,
linkTo.addrExprs...,
),
}
st.hc.Append(hc.CElwi, call.Prep())
engStmt(call)
case *raw.Conv:
// Convolution: pick a generator package from the filter geometry.
// The predicates are tried in the order one, three, strider; loom is
// the fallback.
useOne := func() bool {
return true &&
node.FilterH == 1 &&
node.FilterW == 1
}
useThree := func() bool {
return true &&
node.FilterH == 3 &&
node.FilterW == 3 &&
node.StrideH == 1 &&
node.StrideW == 1 &&
node.DilationH == 1 &&
node.DilationW == 1
}
useStrider := func() bool {
return true &&
node.FilterH <= 14 &&
node.FilterW <= 14 &&
node.FilterH*node.FilterW >= 9 &&
node.StrideH == 2 &&
node.StrideW == 2 &&
node.DilationH == 1 &&
node.DilationW == 1
}
switch {
case useOne():
spec := &one.Spec{
From: one.SpecFrom{
Chans: linkFrom.chans,
Height: linkFrom.height,
Width: linkFrom.width,
Pitch1Bytes: linkFrom.pitch1Bytes,
Pitch2Bytes: linkFrom.pitch2Bytes,
Ops: linkFrom.ops[0],
},
Filts: make(
[]one.SpecFilts, len(banks),
),
To: one.SpecTo{
Pitch1Bytes: linkTo.pitch1Bytes,
Pitch2Bytes: linkTo.pitch2Bytes,
Ops: linkTo.ops[0],
},
StrideH: node.StrideH,
StrideW: node.StrideW,
PaddingH: node.PaddingH,
PaddingW: node.PaddingW,
Groups: node.Groups,
}
// Describe every bank in the spec and gather all bank tensors into
// banks[0].addrExprs, in bank order.
for i, bnk := range banks {
spec.Filts[i] = one.SpecFilts{
Cnt: bnk.filts,
BnPre: bnk.bnPre,
BnPost: bnk.bnPost,
}
if i == 0 {
continue
}
banks[0].addrExprs = append(
banks[0].addrExprs,
bnk.addrExprs...,
)
}
// Weight arrangement runs on the net team and writes to netAddr().
arrangeWts := &one.ArrangeWts{
Ctx: st.oneCtx,
Spec: spec,
Team: NetTeam(),
Tensors: append(
banks[0].addrExprs,
netAddr(),
),
}
// NOTE(review): engNetAddr() is taken before netExtend; presumably
// it must refer to the same region netAddr() just named - confirm.
addrWts := engNetAddr()
st.hc.Append(hc.COne, arrangeWts.Prep())
netExtend(arrangeWts.Bytes())
netStmt(arrangeWts)
// Data arrangement and application run on the engine team.
addrDats := engAddr(-1)
arrangeDats := &one.ArrangeDats{
Ctx: st.oneCtx,
Spec: spec,
Team: EngTeam(),
Tensors: append(
linkFrom.addrExprs,
addrDats,
),
}
st.hc.Append(hc.COne, arrangeDats.Prep())
engExtend(arrangeDats.Bytes())
engStmt(arrangeDats)
apply := &one.Apply{
Ctx: st.oneCtx,
Spec: spec,
Team: EngTeam(),
Tensors: append(
[]cgen.Gen{
addrWts,
addrDats,
},
linkTo.addrExprs...,
),
}
st.hc.Append(hc.COne, apply.Prep())
engStmt(apply)
case useThree():
// Same pipeline shape as above, but with separate produce-sums and
// consume-sums stages communicating through addrSums.
spec := &three.Spec{
From: three.SpecFrom{
Chans: linkFrom.chans,
Height: linkFrom.height,
Width: linkFrom.width,
Pitch1Bytes: linkFrom.pitch1Bytes,
Pitch2Bytes: linkFrom.pitch2Bytes,
Ops: linkFrom.ops[0],
},
Filts: make(
[]three.SpecFilts, len(banks),
),
To: three.SpecTo{
Pitch1Bytes: linkTo.pitch1Bytes,
Pitch2Bytes: linkTo.pitch2Bytes,
Ops: linkTo.ops[0],
},
StrideH: node.StrideH,
StrideW: node.StrideW,
PaddingH: node.PaddingH,
PaddingW: node.PaddingW,
Groups: node.Groups,
}
for i, bnk := range banks {
spec.Filts[i] = three.SpecFilts{
Cnt: bnk.filts,
BnPre: bnk.bnPre,
BnPost: bnk.bnPost,
}
if i == 0 {
continue
}
banks[0].addrExprs = append(
banks[0].addrExprs,
bnk.addrExprs...,
)
}
arrangeFilts := &three.ArrangeFilts{
Ctx: st.threeCtx,
Spec: spec,
Team: NetTeam(),
Tensors: append(
banks[0].addrExprs,
netAddr(),
),
}
addrFilts := engNetAddr()
st.hc.Append(hc.CThree, arrangeFilts.Prep())
netExtend(arrangeFilts.Bytes())
netStmt(arrangeFilts)
addrDats := engAddr(-1)
arrangeDats := &three.ArrangeDats{
Ctx: st.threeCtx,
Spec: spec,
Team: EngTeam(),
Tensors: append(
linkFrom.addrExprs,
addrDats,
),
}
st.hc.Append(hc.CThree, arrangeDats.Prep())
engExtend(arrangeDats.Bytes())
engStmt(arrangeDats)
addrSums := engAddr(-1)
produceSums := &three.ProduceSums{
Ctx: st.threeCtx,
Spec: spec,
Team: EngTeam(),
Tensors: []cgen.Gen{
addrFilts,
addrDats,
addrSums,
},
}
st.hc.Append(hc.CThree, produceSums.Prep())
engExtend(produceSums.Bytes())
engStmt(produceSums)
consumeSums := &three.ConsumeSums{
Ctx: st.threeCtx,
Spec: spec,
Team: EngTeam(),
Tensors: append(
[]cgen.Gen{addrSums},
linkTo.addrExprs...,
),
}
st.hc.Append(hc.CThree, consumeSums.Prep())
engStmt(consumeSums)
case useStrider():
// The strider spec carries filter size and dilation but no stride
// fields; useStrider has already required a stride of 2.
spec := &strider.Spec{
From: strider.SpecFrom{
Chans: linkFrom.chans,
Height: linkFrom.height,
Width: linkFrom.width,
Pitch1Bytes: linkFrom.pitch1Bytes,
Pitch2Bytes: linkFrom.pitch2Bytes,
Ops: linkFrom.ops[0],
},
Filts: make(
[]strider.SpecFilts, len(banks),
),
To: strider.SpecTo{
Pitch1Bytes: linkTo.pitch1Bytes,
Pitch2Bytes: linkTo.pitch2Bytes,
Ops: linkTo.ops[0],
},
FilterH: node.FilterH,
FilterW: node.FilterW,
PaddingH: node.PaddingH,
PaddingW: node.PaddingW,
DilationH: node.DilationH,
DilationW: node.DilationW,
Groups: node.Groups,
}
for i, bnk := range banks {
spec.Filts[i] = strider.SpecFilts{
Cnt: bnk.filts,
BnPre: bnk.bnPre,
BnPost: bnk.bnPost,
}
if i == 0 {
continue
}
banks[0].addrExprs = append(
banks[0].addrExprs,
bnk.addrExprs...,
)
}
arrangeFilts := &strider.ArrangeFilts{
Ctx: st.striderCtx,
Spec: spec,
Team: NetTeam(),
Tensors: append(
banks[0].addrExprs,
netAddr(),
),
}
addrFilts := engNetAddr()
st.hc.Append(hc.CStrider, arrangeFilts.Prep())
netExtend(arrangeFilts.Bytes())
netStmt(arrangeFilts)
addrDats := engAddr(-1)
arrangeDats := &strider.ArrangeDats{
Ctx: st.striderCtx,
Spec: spec,
Team: EngTeam(),
Tensors: append(
linkFrom.addrExprs,
addrDats,
),
}
st.hc.Append(hc.CStrider, arrangeDats.Prep())
engExtend(arrangeDats.Bytes())
engStmt(arrangeDats)
addrSums := engAddr(-1)
produceSums := &strider.ProduceSums{
Ctx: st.striderCtx,
Spec: spec,
Team: EngTeam(),
Tensors: []cgen.Gen{
addrFilts,
addrDats,
addrSums,
},
}
st.hc.Append(hc.CStrider, produceSums.Prep())
engExtend(produceSums.Bytes())
engStmt(produceSums)
consumeSums := &strider.ConsumeSums{
Ctx: st.striderCtx,
Spec: spec,
Team: EngTeam(),
Tensors: append(
[]cgen.Gen{addrSums},
linkTo.addrExprs...,
),
}
st.hc.Append(hc.CStrider, consumeSums.Prep())
engStmt(consumeSums)
default:
// Fallback generator: the loom spec receives the complete filter,
// stride, padding and dilation geometry.
spec := &loom.Spec{
From: loom.SpecFrom{
Chans: linkFrom.chans,
Height: linkFrom.height,
Width: linkFrom.width,
Pitch1Bytes: linkFrom.pitch1Bytes,
Pitch2Bytes: linkFrom.pitch2Bytes,
Ops: linkFrom.ops[0],
},
Filts: make(
[]loom.SpecFilts, len(banks),
),
To: loom.SpecTo{
Pitch1Bytes: linkTo.pitch1Bytes,
Pitch2Bytes: linkTo.pitch2Bytes,
Ops: linkTo.ops[0],
},
FilterH: node.FilterH,
FilterW: node.FilterW,
StrideH: node.StrideH,
StrideW: node.StrideW,
PaddingH: node.PaddingH,
PaddingW: node.PaddingW,
DilationH: node.DilationH,
DilationW: node.DilationW,
Groups: node.Groups,
}
for i, bnk := range banks {
spec.Filts[i] = loom.SpecFilts{
Cnt: bnk.filts,
BnPre: bnk.bnPre,
BnPost: bnk.bnPost,
}
if i == 0 {
continue
}
banks[0].addrExprs = append(
banks[0].addrExprs,
bnk.addrExprs...,
)
}
arrangeFilts := &loom.ArrangeFilts{
Ctx: st.loomCtx,
Spec: spec,
Team: NetTeam(),
Tensors: append(
banks[0].addrExprs,
netAddr(),
),
}
addrFilts := engNetAddr()
st.hc.Append(hc.CLoom, arrangeFilts.Prep())
netExtend(arrangeFilts.Bytes())
netStmt(arrangeFilts)
addrDats := engAddr(-1)
arrangeDats := &loom.ArrangeDats{
Ctx: st.loomCtx,
Spec: spec,
Team: EngTeam(),
Tensors: append(
linkFrom.addrExprs,
addrDats,
),
}
st.hc.Append(hc.CLoom, arrangeDats.Prep())
engExtend(arrangeDats.Bytes())
engStmt(arrangeDats)
addrSums := engAddr(-1)
produceSums := &loom.ProduceSums{
Ctx: st.loomCtx,
Spec: spec,
Team: EngTeam(),
Tensors: []cgen.Gen{
addrFilts,
addrDats,
addrSums,
},
}
st.hc.Append(hc.CLoom, produceSums.Prep())
engExtend(produceSums.Bytes())
engStmt(produceSums)
consumeSums := &loom.ConsumeSums{
Ctx: st.loomCtx,
Spec: spec,
Team: EngTeam(),
Tensors: append(
[]cgen.Gen{addrSums},
linkTo.addrExprs...,
),
}
st.hc.Append(hc.CLoom, consumeSums.Prep())
engStmt(consumeSums)
}
case *raw.FullyConnected:
// Fully connected: one arrange step on the net team writing to
// netEng[0], then one apply step on the engine team reading
// netEng[1]. Only banks[0] and linkFrom.addrExprs[0] are used.
netEng := [2]cgen.Gen{
netAddr(),
engNetAddr(),
}
arrange := &fc.Arrange{
Ctx: st.fcCtx,
ToC: linkTo.chans,
FromC: linkFrom.chans,
FromH: linkFrom.height,
FromW: linkFrom.width,
BnPre: banks[0].bnPre,
BnPost: banks[0].bnPost,
Team: NetTeam(),
Tensors: append(
banks[0].addrExprs,
netEng[0],
),
}
st.hc.Append(hc.CFc, arrange.Prep())
netExtend(arrange.Bytes())
netStmt(arrange)
apply := &fc.Apply{
Ctx: st.fcCtx,
ToC: linkTo.chans,
FromC: linkFrom.chans,
FromH: linkFrom.height,
FromW: linkFrom.width,
Ops: linkTo.ops[0],
Team: EngTeam(),
Tensors: append(
[]cgen.Gen{
netEng[1],
linkFrom.addrExprs[0],
},
linkTo.addrExprs...,
),
}
st.hc.Append(hc.CFc, apply.Prep())
engStmt(apply)
// Input and Output nodes emit no code of their own.
case *raw.Input:
case *raw.Output:
case *raw.Pooling:
// Pooling: sources first, then destinations; the pooling kind picks
// the generator package.
tensors := append(
linkFrom.addrExprs,
linkTo.addrExprs...,
)
switch node.Kind {
case raw.Max2x2Stride2, raw.Avg2x2Stride2:
spec := &twopl.Spec{
Kind: node.Kind,
PaddingH: node.PaddingH,
PaddingW: node.PaddingW,
Channels: linkFrom.chans,
From: twopl.SpecFrom{
Height: linkFrom.height,
Width: linkFrom.width,
Pitch1Bytes: linkFrom.pitch1Bytes,
Pitch2Bytes: linkFrom.pitch2Bytes,
Ops: linkFrom.ops[0],
},
To: twopl.SpecTo{
Pitch1Bytes: linkTo.pitch1Bytes,
Pitch2Bytes: linkTo.pitch2Bytes,
Ops: linkTo.ops[0],
},
}
call := &twopl.Call{
Ctx: st.twoplCtx,
Spec: spec,
Team: EngTeam(),
Tensors: tensors,
}
st.hc.Append(hc.CTwopl, call.Prep())
engStmt(call)
case raw.Max3x3Stride2, raw.Avg3x3Stride2:
spec := &thrpl.Spec{
Kind: node.Kind,
PaddingH: node.PaddingH,
PaddingW: node.PaddingW,
Channels: linkFrom.chans,
From: thrpl.SpecFrom{
Height: linkFrom.height,
Width: linkFrom.width,
Pitch1Bytes: linkFrom.pitch1Bytes,
Pitch2Bytes: linkFrom.pitch2Bytes,
Ops: linkFrom.ops[0],
},
To: thrpl.SpecTo{
Pitch1Bytes: linkTo.pitch1Bytes,
Pitch2Bytes: linkTo.pitch2Bytes,
Ops: linkTo.ops[0],
},
}
call := &thrpl.Call{
Ctx: st.thrplCtx,
Spec: spec,
Team: EngTeam(),
Tensors: tensors,
}
st.hc.Append(hc.CThrpl, call.Prep())
engStmt(call)
case raw.MaxGlobal, raw.AvgGlobal:
// Global pooling passes no destination pitches, only the number of
// destination piles.
spec := &glopl.Spec{
Kind: node.Kind,
Channels: linkFrom.chans,
ElemBytes: linkFrom.elemBytes,
From: glopl.SpecFrom{
Height: linkFrom.height,
Width: linkFrom.width,
Pitch1Bytes: linkFrom.pitch1Bytes,
Pitch2Bytes: linkFrom.pitch2Bytes,
Ops: linkFrom.ops[0],
},
To: glopl.SpecTo{
Ops: linkTo.ops[0],
Cnt: len(planOp.To[0].Piles),
},
}
call := &glopl.Call{
Ctx: st.gloplCtx,
Spec: spec,
Team: EngTeam(),
Tensors: tensors,
}
st.hc.Append(hc.CGlopl, call.Prep())
engStmt(call)
default:
panic("bug")
}
case *raw.Softmax:
// Softmax: entry 0 describes the source (first linkFrom tensor);
// entries 1..n-1 describe each linkTo tensor in order.
var (
n = 1 + len(linkTo.addrExprs)
shapes = make([]softmax.Shape, n)
tensors = make([]cgen.Gen, n)
)
for i := 0; i < n; i++ {
var (
lnk = linkFrom
j = 0
)
if i > 0 {
lnk = linkTo
j = i - 1
}
shapes[i] = softmax.Shape{
Channels: lnk.chans,
Height: lnk.height,
Width: lnk.width,
ElemBytes: lnk.elemBytes,
Pitch1Bytes: lnk.pitch1Bytes[j],
Pitch2Bytes: lnk.pitch2Bytes[j],
}
tensors[i] = lnk.addrExprs[j]
}
call := &softmax.Call{
Ctx: st.softmaxCtx,
Team: EngTeam(),
Tensors: tensors,
Shapes: shapes,
}
st.hc.Append(hc.CSoftmax, call.Prep())
engStmt(call)
default:
panic("bug")
}
}
// layer6 prepares banks for the current planOp and then hands over to
// layer7. Conv and FullyConnected ops get one bank per entry of the first
// destination span's Counts; every other op gets banks = nil.
layer6 := func() {
	// fillBanks builds one bank per filter count. Each bank's address list
	// holds the two parameter tensors first, then one entry per pre batch
	// norm, then one entry per post batch norm.
	fillBanks := func() {
		counts := planOp.To[0].Counts
		banks = make([]*bank, len(counts))
		for i, filts := range counts {
			params := planOp.Params[i]
			mods := planOp.ParamMods[i]
			pre, post := mods[0], mods[1]
			exprs := make([]cgen.Gen, 2+len(pre)+len(post))
			banks[i] = &bank{
				filts:     filts,
				bnPre:     len(pre),
				bnPost:    len(post),
				addrExprs: exprs,
			}
			next := 0
			for ; next < 2; next++ {
				exprs[next] = cgen.Cast{
					Type: cgen.PtrChar,
					Expr: param(params[next].Tensor),
				}
			}
			// Pre batch norms are sized by the source channel count.
			for j := range pre {
				bn := pre[j].Nodes[0].(*raw.BatchNorm)
				netEng := bnSimplify(bn, linkFrom.chans)
				exprs[next] = netEng[0]
				next++
			}
			// Post batch norms are sized by this bank's filter count.
			for j := range post {
				bn := post[j].Nodes[0].(*raw.BatchNorm)
				netEng := bnSimplify(bn, filts)
				exprs[next] = netEng[0]
				next++
			}
		}
	}
	switch planOp.Nodes[0].(type) {
	case *raw.Conv, *raw.FullyConnected:
		fillBanks()
	default:
		banks = nil
	}
	layer7()
}
layer5 := func() {
linkFrom = linker(
planOp.From,
planOp.FromMods,
true,
)
linkTo = linker(
planOp.To,
planOp.ToMods,
false,
)
layer6()
}
// layer4 walks the plan sequence. Each op is published through the shared
// planOp variable, processed by layer5, and followed by the rewind and
// block calls.
layer4 := func() {
	seq := st.pl.Seq
	for i := range seq {
		planOp = seq[i]
		layer5()
		tmpRewind()
		engRewind()
		netBlock()
		engBlock()
	}
}
layer3 := func() {
sublayer2 := func() {
do := func(nodes []raw.Node) {
for _, node := range nodes {
switch node := node.(type) {
case *raw.BatchNorm:
key := bnKey(node)
bnPersist[key] = true
}
}
}
doMods := func(pmods [][]plan.Mod) {
for _, pms := range pmods {
for i := range pms {
pm := &pms[i]
do(pm.Nodes)
}
}
}
for _, op := range st.pl.Seq {
do(op.Nodes)
doMods(op.FromMods)
doMods(op.ToMods)
}
}
sublayer1 := func() {
bnPersist = make(map[string]bool)
bnNetEng = make(map[string][2]cgen.Gen)
sublayer2()
layer4()
}
sublayer1()
}
// layer2 resets the engine-side generator state, runs layer3 to populate
// it, and then appends the engine create definition followed by the
// inference definition to the hc.CEngine section.
layer2 := func() {
// sublayer4 builds the body of the inference function: optional local
// variable declarations, then engBlocks.
sublayer4 := func() cgen.Gen {
var (
eng = st.engineCtx.InferenceEng
used = false
)
// field yields eng->name and records that eng has been referenced.
field := func(name string) cgen.Gen {
used = true
return cgen.Arrow{
Expr: eng, Name: name,
}
}
// The closures below are called in source order, so the final one
// sees used == true whenever an earlier one called field.
return cgen.Stmts{
func() cgen.Gen {
if engNetAlign == nil {
return nil
}
return cgen.Var{
Type: cgen.PtrChar,
What: engNetAlign,
Init: cgen.Arrow{
Expr: field(st.engineCtx.StructNet),
Name: st.netCtx.StructAlign,
},
}
}(),
func() cgen.Gen {
if engTeam == nil {
return nil
}
return cgen.Var{
Type: st.threaderCtx.PtrTeam,
What: engTeam,
Init: field(st.engineCtx.StructTeam),
}
}(),
func() cgen.Gen {
if engAlign == nil {
return nil
}
return cgen.Var{
Type: cgen.PtrChar,
What: engAlign,
Init: field(st.engineCtx.StructAlign),
}
}(),
// Only when no field was read is eng passed to unused.
func() cgen.Gen {
if used {
return nil
}
return unused(eng)
}(),
engBlocks,
}
}
// sublayer3 appends the inference function definition.
sublayer3 := func() {
var (
body = sublayer4()
def = st.engineCtx.InferenceDef(body)
)
st.hc.Append(hc.CEngine, def, cgen.Newline)
}
// sublayer2 appends the create definition, built from the final
// engBytes, before the inference definition.
sublayer2 := func() {
def := st.engineCtx.CreateDef(engBytes)
st.hc.Append(hc.CEngine, def, cgen.Newline)
sublayer3()
}
// sublayer1 clears the engine state, lets layer3 fill it, then emits.
sublayer1 := func() {
engNetAlign = nil
engTeam = nil
engAlign = nil
engEdge = 0
engBytes = 0
engStmts = nil
engBlocks = nil
layer3()
sublayer2()
}
sublayer1()
}
// layer1 resets the net-side generator state, runs layer2 to populate it,
// and then appends the net create function definition to the hc.CNet
// section.
layer1 := func() {
// malloc builds a call to cgen.Malloc with the given size expression.
malloc := func(bytes cgen.Gen) cgen.Gen {
return cgen.Call{
Func: cgen.Malloc,
Args: bytes,
}
}
// alloc declares what as a char pointer initialised by malloc and
// adds an errno check that runs unwind when the pointer is zero. The
// request is Alignment-1 bytes larger than bytes, leaving room for
// the rounding done by align.
alloc := func(what, unwind cgen.Gen, bytes int) cgen.Gen {
return cgen.Stmts{
cgen.Var{
Type: cgen.PtrChar,
What: what,
Init: malloc(il(
st.netCtx.Alignment - 1 + bytes,
)),
},
&errmsg.ErrnoIf{
Ctx: st.errmsgCtx,
Cond: cgen.IsZero{Expr: what},
Unwind: unwind,
},
}
}
// align declares dest as ((size_t)src + Alignment-1) & -Alignment,
// cast to a void pointer.
align := func(dest, src cgen.Gen) cgen.Gen {
expr := cgen.And{
Expr1: cgen.Paren{
Inner: cgen.Add{
Expr1: cgen.Cast{
Type: cgen.SizeT,
Expr: src,
},
Expr2: il(st.netCtx.Alignment - 1),
},
},
Expr2: il(-st.netCtx.Alignment),
}
return cgen.Var{
Type: cgen.PtrChar,
What: dest,
Init: cgen.Cast{
Type: cgen.PtrVoid,
Expr: cgen.Paren{Inner: expr},
},
}
}
// free builds a call to cgen.Free, or returns nil when there is
// nothing to free.
free := func(what cgen.Gen) cgen.Gen {
if what == nil {
return nil
}
return cgen.Call{
Func: cgen.Free,
Args: what,
}
}
// sublayer6 returns netBlocks, wrapped in thread team creation and
// destruction when a net team is in use.
sublayer6 := func() cgen.Gen {
if netTeam == nil {
return netBlocks
}
// If team creation fails, whichever of the tmp and net allocations
// exist are freed.
unwind := func() cgen.Gen {
var (
free1 = free(tmpAlloc)
free2 = free(netAlloc)
)
switch {
case free1 == nil:
return free2
case free2 == nil:
return free1
}
return cgen.Stmts{
free1,
free2,
}
}()
return cgen.Stmts{
cgen.Var{
Type: st.threaderCtx.PtrTeam,
What: netTeam,
Init: il(0),
},
&threader.Create{
Ctx: st.threaderCtx,
Team: cgen.Addr{Expr: netTeam},
Nt: st.netCtx.CreateThreads,
Unwind: unwind,
},
netBlocks,
&threader.Destroy{
Ctx: st.threaderCtx,
Team: netTeam,
},
}
}
// sublayer5 surrounds sublayer6 with allocation and release of the
// temporary buffer when one is needed. A failed tmp allocation frees
// the net allocation.
sublayer5 := func() cgen.Gen {
if tmpAlloc == nil {
return sublayer6()
}
return cgen.Stmts{
alloc(tmpAlloc, free(netAlloc), tmpBytes),
align(tmpAlign, tmpAlloc),
sublayer6(),
free(tmpAlloc),
}
}
// sublayer4 allocates the net buffer (if any), runs sublayer5,
// allocates the net struct, stores the alloc/align pointers in it,
// and assigns it through st.netCtx.CreateNet.
sublayer4 := func() cgen.Gen {
out := nm("net")
// put assigns src to out->dest, writing 0 when src is nil.
put := func(dest string, src cgen.Gen) cgen.Gen {
if src == nil {
src = il(0)
}
return cgen.Assign{
Expr1: cgen.Arrow{
Expr: out, Name: dest,
},
Expr2: src,
}
}
return cgen.Stmts{
func() cgen.Gen {
if netAlloc == nil {
return nil
}
return cgen.Stmts{
alloc(netAlloc, nil, netBytes),
align(netAlign, netAlloc),
}
}(),
sublayer5(),
cgen.Var{
Type: ptr(vb(st.netCtx.StructName)),
What: out,
Init: malloc(cgen.Sizeof{
What: vb(st.netCtx.StructName),
}),
},
&errmsg.ErrnoIf{
Ctx: st.errmsgCtx,
Cond: cgen.IsZero{Expr: out},
Unwind: free(netAlloc),
},
put(st.netCtx.StructAlloc, netAlloc),
put(st.netCtx.StructAlign, netAlign),
// NOTE(review): cgen.At presumably dereferences the CreateNet
// out-parameter - confirm against cgen.
cgen.Assign{
Expr1: cgen.At{
Expr: st.netCtx.CreateNet,
},
Expr2: out,
},
}
}
// sublayer3 builds the create function body: unused markers for
// parameters that were never referenced, the CPU check, sublayer4,
// and a return of 0.
sublayer3 := func() cgen.Gen {
return cgen.Stmts{
func() cgen.Gen {
if usedParams {
return nil
}
return unused(st.netCtx.CreateParams)
}(),
func() cgen.Gen {
if netTeam != nil {
return nil
}
return unused(st.netCtx.CreateThreads)
}(),
&cpu.Chk{
Platform: st.pl.Config.Platform,
Emc: st.errmsgCtx,
},
sublayer4(),
cgen.Return{
Expr: il(0),
},
}
}
// sublayer2 appends the create function definition.
sublayer2 := func() {
var (
body = sublayer3()
def = st.netCtx.CreateDef(body)
)
st.hc.Append(hc.CNet, def, cgen.Newline)
}
// sublayer1 clears the net/tmp state, lets layer2 fill it, then emits.
sublayer1 := func() {
usedParams = false
netAlloc = nil
netAlign = nil
netBytes = 0
tmpAlloc = nil
tmpAlign = nil
tmpEdge = 0
tmpBytes = 0
netTeam = nil
netStmts = nil
netBlocks = nil
layer2()
sublayer2()
}
sublayer1()
}
layer1()
}

Top || internal/compile/author/act/act.go

package act

import (
"NN-512/internal/compile/author/avx"
"NN-512/internal/compile/author/cgen"
"NN-512/internal/compile/plan"
"NN-512/internal/nmsrc"
"NN-512/internal/raw"
)

// Ctx carries the settings shared by the activation generators in this
// package.
type Ctx struct {
// platform selects the instruction-set specific generator.
platform raw.Platform
// nms is the name source behind the name method.
nms nmsrc.Src
}

// NewCtx returns a Ctx that uses the plan's configured platform and the
// given name source.
func NewCtx(pl *plan.Plan, nms nmsrc.Src) *Ctx {
	c := new(Ctx)
	c.platform = pl.Config.Platform
	c.nms = nms
	return c
}

// name asks the context's name source for a name based on s.
func (c *Ctx) name(s string) string {
	named := c.nms.Name(s)
	return named
}

// vb converts s to a cgen.Vb and returns it as a cgen.Gen.
func vb(s string) cgen.Gen {
	var g cgen.Gen = cgen.Vb(s)
	return g
}

// ReLU generates statements that rectify Var in place.
type ReLU struct {
*Ctx
// NegSlope is the factor applied to elements below zero. A value of 0
// emits a plain max with zero and a value of 1 emits nothing.
NegSlope float32
// Var is the expression that is both read and reassigned.
Var cgen.Gen
}

// Append writes the rectifier statements to to and returns the extended
// slice. A negative slope of exactly 1 leaves to untouched; any platform
// other than AVX512Float32 panics.
func (r *ReLU) Append(to []byte) []byte {
	switch {
	case r.NegSlope == 1:
		return to
	case r.platform == raw.AVX512Float32:
		return r.m512(to)
	}
	panic("bug")
}

// m512 emits the AVX-512 form of the rectifier. With a zero slope Var is
// replaced by max(0, Var). Otherwise a mask of the elements below zero is
// computed and only those elements are multiplied by NegSlope.
func (r *ReLU) m512(to []byte) []byte {
	var first, second cgen.Gen
	if r.NegSlope == 0 {
		first = cgen.Assign{
			Expr1: r.Var,
			Expr2: avx.Mm512MaxPs{avx.Mm512SetzeroPs, r.Var},
		}
	} else {
		mask := vb(r.name("mask"))
		first = cgen.Var{
			Type: avx.Mmask16, What: mask,
			Init: avx.Mm512CmpPsMask{
				r.Var, avx.Mm512SetzeroPs, avx.CmpLtOq,
			},
		}
		second = cgen.Assign{
			Expr1: r.Var,
			Expr2: avx.Mm512MaskMulPs{
				r.Var, mask,
				r.Var, avx.Mm512Set1PsLit(r.NegSlope),
			},
		}
	}
	return cgen.Stmts{first, second}.Append(to)
}

Top || internal/compile/author/avx/avx.go

package avx

import "NN-512/internal/compile/author/cgen"

// C identifiers (a comparison predicate and vector/mask type names) that
// are emitted verbatim.
var (
CmpLtOq cgen.Gen = cgen.Vb("_CMP_LT_OQ")
M256i cgen.Gen = cgen.Vb("__m256i")
M512 cgen.Gen = cgen.Vb("__m512")
M512i cgen.Gen = cgen.Vb("__m512i")
Mmask16 cgen.Gen = cgen.Vb("__mmask16")
)

// Rounding-control macro names, combined below.
const (
mmFroundToNearestInt cgen.Vb = "_MM_FROUND_TO_NEAREST_INT"
mmFroundNoExc cgen.Vb = "_MM_FROUND_NO_EXC"
)

// FroundToNearestIntNoExc is the cgen.Or of the two rounding macros above.
var FroundToNearestIntNoExc cgen.Gen = cgen.Or{
Expr1: mmFroundToNearestInt,
Expr2: mmFroundNoExc,
}

// call appends a cgen.Call of the function named fn, with args joined by
// cgen.CommaSpaced, to to.
func call(to []byte, fn string, args []cgen.Gen) []byte {
	expr := cgen.Call{
		Func: cgen.Vb(fn),
		Args: cgen.CommaSpaced(args),
	}
	return expr.Append(to)
}

// Mm512AddEpi32 lists the arguments of a _mm512_add_epi32 call.
type Mm512AddEpi32 []cgen.Gen

// Append appends the _mm512_add_epi32 call expression to to.
func (m Mm512AddEpi32) Append(to []byte) []byte {
	const intrinsic = "_mm512_add_epi32"
	return call(to, intrinsic, m)
}

// Mm512AddPs lists the arguments of a _mm512_add_ps call.
type Mm512AddPs []cgen.Gen

// Append appends the _mm512_add_ps call expression to to.
func (m Mm512AddPs) Append(to []byte) []byte {
	const intrinsic = "_mm512_add_ps"
	return call(to, intrinsic, m)
}

// Mm512AlignrEpi32 lists the arguments of a _mm512_alignr_epi32 call.
type Mm512AlignrEpi32 []cgen.Gen

// Append appends the _mm512_alignr_epi32 call expression to to.
func (m Mm512AlignrEpi32) Append(to []byte) []byte {
	const intrinsic = "_mm512_alignr_epi32"
	return call(to, intrinsic, m)
}

// Mm512CastpsSi512 lists the arguments of a _mm512_castps_si512 call.
type Mm512CastpsSi512 []cgen.Gen

// Append appends the _mm512_castps_si512 call expression to to.
func (m Mm512CastpsSi512) Append(to []byte) []byte {
	const intrinsic = "_mm512_castps_si512"
	return call(to, intrinsic, m)
}

// Mm512Castsi256Si512 lists the arguments of a _mm512_castsi256_si512 call.
type Mm512Castsi256Si512 []cgen.Gen

// Append appends the _mm512_castsi256_si512 call expression to to.
func (m Mm512Castsi256Si512) Append(to []byte) []byte {
	const intrinsic = "_mm512_castsi256_si512"
	return call(to, intrinsic, m)
}

// Mm512Castsi512Ps lists the arguments of a _mm512_castsi512_ps call.
type Mm512Castsi512Ps []cgen.Gen

// Append appends the _mm512_castsi512_ps call expression to to.
func (m Mm512Castsi512Ps) Append(to []byte) []byte {
	const intrinsic = "_mm512_castsi512_ps"
	return call(to, intrinsic, m)
}

// Mm512Castsi512Si256 lists the arguments of a _mm512_castsi512_si256 call.
type Mm512Castsi512Si256 []cgen.Gen

// Append appends the _mm512_castsi512_si256 call expression to to.
func (m Mm512Castsi512Si256) Append(to []byte) []byte {
	const intrinsic = "_mm512_castsi512_si256"
	return call(to, intrinsic, m)
}

// Mm512CmpPsMask lists the arguments of a _mm512_cmp_ps_mask call.
type Mm512CmpPsMask []cgen.Gen

// Append appends the _mm512_cmp_ps_mask call expression to to.
func (m Mm512CmpPsMask) Append(to []byte) []byte {
	const intrinsic = "_mm512_cmp_ps_mask"
	return call(to, intrinsic, m)
}

// Mm512CvtphPs lists the arguments of a _mm512_cvtph_ps call.
type Mm512CvtphPs []cgen.Gen

// Append appends the _mm512_cvtph_ps call expression to to.
func (m Mm512CvtphPs) Append(to []byte) []byte {
	const intrinsic = "_mm512_cvtph_ps"
	return call(to, intrinsic, m)
}

// Mm512CvtpsEpi32 lists the arguments of a _mm512_cvtps_epi32 call.
type Mm512CvtpsEpi32 []cgen.Gen

// Append appends the _mm512_cvtps_epi32 call expression to to.
func (m Mm512CvtpsEpi32) Append(to []byte) []byte {
	const intrinsic = "_mm512_cvtps_epi32"
	return call(to, intrinsic, m)
}

// Mm512CvtpsPh lists the arguments of a _mm512_cvtps_ph call.
type Mm512CvtpsPh []cgen.Gen

// Append appends the _mm512_cvtps_ph call expression to to.
func (m Mm512CvtpsPh) Append(to []byte) []byte {
	const intrinsic = "_mm512_cvtps_ph"
	return call(to, intrinsic, m)
}

// Mm512DivPs lists the arguments of a _mm512_div_ps call.
type Mm512DivPs []cgen.Gen

// Append appends the _mm512_div_ps call expression to to.
func (m Mm512DivPs) Append(to []byte) []byte {
	const intrinsic = "_mm512_div_ps"
	return call(to, intrinsic, m)
}

// Mm512Extracti64x4Epi64 lists the arguments of a
// _mm512_extracti64x4_epi64 call.
type Mm512Extracti64x4Epi64 []cgen.Gen

// Append appends the _mm512_extracti64x4_epi64 call expression to to.
func (m Mm512Extracti64x4Epi64) Append(to []byte) []byte {
	const intrinsic = "_mm512_extracti64x4_epi64"
	return call(to, intrinsic, m)
}

// Mm512FmaddPs lists the arguments of a _mm512_fmadd_ps call.
type Mm512FmaddPs []cgen.Gen

// Append appends the _mm512_fmadd_ps call expression to to.
func (m Mm512FmaddPs) Append(to []byte) []byte {
	const intrinsic = "_mm512_fmadd_ps"
	return call(to, intrinsic, m)
}

// Mm512FmsubPs lists the arguments of a _mm512_fmsub_ps call.
type Mm512FmsubPs []cgen.Gen

// Append appends the _mm512_fmsub_ps call expression to to.
func (m Mm512FmsubPs) Append(to []byte) []byte {
	const intrinsic = "_mm512_fmsub_ps"
	return call(to, intrinsic, m)
}

// Mm512FnmaddPs lists the arguments of a _mm512_fnmadd_ps call.
type Mm512FnmaddPs []cgen.Gen

// Append appends the _mm512_fnmadd_ps call expression to to.
func (m Mm512FnmaddPs) Append(to []byte) []byte {
	const intrinsic = "_mm512_fnmadd_ps"
	return call(to, intrinsic, m)
}

// Mm512FnmsubPs lists the arguments of a _mm512_fnmsub_ps call.
type Mm512FnmsubPs []cgen.Gen

// Append appends the _mm512_fnmsub_ps call expression to to.
func (m Mm512FnmsubPs) Append(to []byte) []byte {
	const intrinsic = "_mm512_fnmsub_ps"
	return call(to, intrinsic, m)
}

// Mm512Inserti64x4 lists the arguments of a _mm512_inserti64x4 call.
type Mm512Inserti64x4 []cgen.Gen

// Append appends the _mm512_inserti64x4 call expression to to.
func (m Mm512Inserti64x4) Append(to []byte) []byte {
	const intrinsic = "_mm512_inserti64x4"
	return call(to, intrinsic, m)
}

// Mm512LoaduPs lists the arguments of a _mm512_loadu_ps call.
type Mm512LoaduPs []cgen.Gen

// Append appends the _mm512_loadu_ps call expression to to.
func (m Mm512LoaduPs) Append(to []byte) []byte {
	const intrinsic = "_mm512_loadu_ps"
	return call(to, intrinsic, m)
}

// Mm512LoaduSi512 lists the arguments of a _mm512_loadu_si512 call.
type Mm512LoaduSi512 []cgen.Gen

// Append appends the _mm512_loadu_si512 call expression to to.
func (m Mm512LoaduSi512) Append(to []byte) []byte {
	const intrinsic = "_mm512_loadu_si512"
	return call(to, intrinsic, m)
}

// Mm512Mask3FmaddPs lists the arguments of a _mm512_mask3_fmadd_ps call.
type Mm512Mask3FmaddPs []cgen.Gen

// Append appends the _mm512_mask3_fmadd_ps call expression to to.
func (m Mm512Mask3FmaddPs) Append(to []byte) []byte {
	const intrinsic = "_mm512_mask3_fmadd_ps"
	return call(to, intrinsic, m)
}

// Mm512Mask3FnmaddPs lists the arguments of a _mm512_mask3_fnmadd_ps call.
type Mm512Mask3FnmaddPs []cgen.Gen

// Append appends the _mm512_mask3_fnmadd_ps call expression to to.
func (m Mm512Mask3FnmaddPs) Append(to []byte) []byte {
	const intrinsic = "_mm512_mask3_fnmadd_ps"
	return call(to, intrinsic, m)
}

// Mm512MaskAddPs lists the arguments of a _mm512_mask_add_ps call.
type Mm512MaskAddPs []cgen.Gen

// Append appends the _mm512_mask_add_ps call expression to to.
func (m Mm512MaskAddPs) Append(to []byte) []byte {
	const intrinsic = "_mm512_mask_add_ps"
	return call(to, intrinsic, m)
}

// Mm512MaskFmaddPs lists the arguments of a _mm512_mask_fmadd_ps call.
type Mm512MaskFmaddPs []cgen.Gen

// Append appends the _mm512_mask_fmadd_ps call expression to to.
func (m Mm512MaskFmaddPs) Append(to []byte) []byte {
	const intrinsic = "_mm512_mask_fmadd_ps"
	return call(to, intrinsic, m)
}

// Mm512MaskFnmaddPs lists the arguments of a _mm512_mask_fnmadd_ps call.
type Mm512MaskFnmaddPs []cgen.Gen

// Append appends the _mm512_mask_fnmadd_ps call expression to to.
func (m Mm512MaskFnmaddPs) Append(to []byte) []byte {
	const intrinsic = "_mm512_mask_fnmadd_ps"
	return call(to, intrinsic, m)
}

// Mm512MaskMaxPs lists the arguments of a _mm512_mask_max_ps call.
type Mm512MaskMaxPs []cgen.Gen

// Append appends the _mm512_mask_max_ps call expression to to.
func (m Mm512MaskMaxPs) Append(to []byte) []byte {
	const intrinsic = "_mm512_mask_max_ps"
	return call(to, intrinsic, m)
}

// Mm512MaskMovPs lists the arguments of a _mm512_mask_mov_ps call.
type Mm512MaskMovPs []cgen.Gen

// Append appends the _mm512_mask_mov_ps call expression to to.
func (m Mm512MaskMovPs) Append(to []byte) []byte {
	const intrinsic = "_mm512_mask_mov_ps"
	return call(to, intrinsic, m)
}

// Mm512MaskMulPs lists the arguments of a _mm512_mask_mul_ps call.
type Mm512MaskMulPs []cgen.Gen

// Append appends the _mm512_mask_mul_ps call expression to to.
func (m Mm512MaskMulPs) Append(to []byte) []byte {
	const intrinsic = "_mm512_mask_mul_ps"
	return call(to, intrinsic, m)
}

// Mm512MaskStoreuEpi32 lists the arguments of a _mm512_mask_storeu_epi32
// call.
type Mm512MaskStoreuEpi32 []cgen.Gen

// Append appends the _mm512_mask_storeu_epi32 call expression to to.
func (m Mm512MaskStoreuEpi32) Append(to []byte) []byte {
	const intrinsic = "_mm512_mask_storeu_epi32"
	return call(to, intrinsic, m)
}

// Mm512MaskStoreuPs lists the arguments of a _mm512_mask_storeu_ps call.
type Mm512MaskStoreuPs []cgen.Gen

// Append appends the _mm512_mask_storeu_ps call expression to to.
func (m Mm512MaskStoreuPs) Append(to []byte) []byte {
	const intrinsic = "_mm512_mask_storeu_ps"
	return call(to, intrinsic, m)
}

// Mm512MaskSubPs lists the arguments of a _mm512_mask_sub_ps call.
type Mm512MaskSubPs []cgen.Gen

// Append appends the _mm512_mask_sub_ps call expression to to.
func (m Mm512MaskSubPs) Append(to []byte) []byte {
	const intrinsic = "_mm512_mask_sub_ps"
	return call(to, intrinsic, m)
}

// Mm512MaskzLoaduEpi32 lists the arguments of a _mm512_maskz_loadu_epi32
// call.
type Mm512MaskzLoaduEpi32 []cgen.Gen

// Append appends the _mm512_maskz_loadu_epi32 call expression to to.
func (m Mm512MaskzLoaduEpi32) Append(to []byte) []byte {
	const intrinsic = "_mm512_maskz_loadu_epi32"
	return call(to, intrinsic, m)
}

// Mm512MaskzLoaduPs lists the arguments of a _mm512_maskz_loadu_ps call.
type Mm512MaskzLoaduPs []cgen.Gen

// Append appends the _mm512_maskz_loadu_ps call expression to to.
func (m Mm512MaskzLoaduPs) Append(to []byte) []byte {
	const intrinsic = "_mm512_maskz_loadu_ps"
	return call(to, intrinsic, m)
}

// Mm512MaxPs lists the arguments of a _mm512_max_ps call.
type Mm512MaxPs []cgen.Gen

// Append appends the _mm512_max_ps call expression to to.
func (m Mm512MaxPs) Append(to []byte) []byte {
	const intrinsic = "_mm512_max_ps"
	return call(to, intrinsic, m)
}

// Mm512MinPs lists the arguments of a _mm512_min_ps call.
type Mm512MinPs []cgen.Gen

// Append appends the _mm512_min_ps call expression to to.
func (m Mm512MinPs) Append(to []byte) []byte {
	const intrinsic = "_mm512_min_ps"
	return call(to, intrinsic, m)
}

// Mm512MulPs lists the arguments of a _mm512_mul_ps call.
type Mm512MulPs []cgen.Gen

// Append appends the _mm512_mul_ps call expression to to.
func (m Mm512MulPs) Append(to []byte) []byte {
	const intrinsic = "_mm512_mul_ps"
	return call(to, intrinsic, m)
}

// Mm512Permutex2varPs lists the arguments of a _mm512_permutex2var_ps call.
type Mm512Permutex2varPs []cgen.Gen

// Append appends the _mm512_permutex2var_ps call expression to to.
func (m Mm512Permutex2varPs) Append(to []byte) []byte {
	const intrinsic = "_mm512_permutex2var_ps"
	return call(to, intrinsic, m)
}

// Mm512PermutexvarPs lists the arguments of a _mm512_permutexvar_ps call.
type Mm512PermutexvarPs []cgen.Gen

// Append appends the _mm512_permutexvar_ps call expression to to.
func (m Mm512PermutexvarPs) Append(to []byte) []byte {
	const intrinsic = "_mm512_permutexvar_ps"
	return call(to, intrinsic, m)
}

// Mm512RoundscalePs lists the arguments of a _mm512_roundscale_ps call.
type Mm512RoundscalePs []cgen.Gen

// Append appends the _mm512_roundscale_ps call expression to to.
func (m Mm512RoundscalePs) Append(to []byte) []byte {
	const intrinsic = "_mm512_roundscale_ps"
	return call(to, intrinsic, m)
}

// Mm512Rsqrt14Ps lists the arguments of a _mm512_rsqrt14_ps call.
type Mm512Rsqrt14Ps []cgen.Gen

// Append appends the _mm512_rsqrt14_ps call expression to to.
func (m Mm512Rsqrt14Ps) Append(to []byte) []byte {
	const intrinsic = "_mm512_rsqrt14_ps"
	return call(to, intrinsic, m)
}

// Mm512Set1Ps lists the arguments of a _mm512_set1_ps call.
type Mm512Set1Ps []cgen.Gen

// Append appends the _mm512_set1_ps call expression to to.
func (m Mm512Set1Ps) Append(to []byte) []byte {
	const intrinsic = "_mm512_set1_ps"
	return call(to, intrinsic, m)
}

// Mm512Set1PsLit is a float literal wrapped in a _mm512_set1_ps call.
type Mm512Set1PsLit cgen.FloatLit

// Append appends a Mm512Set1Ps call whose only argument is the literal.
func (m Mm512Set1PsLit) Append(to []byte) []byte {
	lit := cgen.FloatLit(m)
	return Mm512Set1Ps{lit}.Append(to)
}

// Mm512SetEpi32 lists the arguments of a _mm512_set_epi32 call.
type Mm512SetEpi32 []cgen.Gen

// Append appends the _mm512_set_epi32 call expression to to.
func (m Mm512SetEpi32) Append(to []byte) []byte {
	const intrinsic = "_mm512_set_epi32"
	return call(to, intrinsic, m)
}

// Mm512SetPs lists the arguments of a _mm512_set_ps call.
type Mm512SetPs []cgen.Gen

// Append appends the _mm512_set_ps call expression to to.
func (m Mm512SetPs) Append(to []byte) []byte {
	const intrinsic = "_mm512_set_ps"
	return call(to, intrinsic, m)
}

// Mm512SetzeroPs is a _mm512_setzero_ps call with no arguments.
var Mm512SetzeroPs cgen.Gen = cgen.Call{
Func: cgen.Vb("_mm512_setzero_ps"),
}

// Mm512ShuffleF32x4 lists the arguments of a _mm512_shuffle_f32x4 call.
type Mm512ShuffleF32x4 []cgen.Gen

// Append appends the _mm512_shuffle_f32x4 call expression to to.
func (m Mm512ShuffleF32x4) Append(to []byte) []byte {
	const intrinsic = "_mm512_shuffle_f32x4"
	return call(to, intrinsic, m)
}

// Mm512ShuffleI32x4 lists the arguments of a _mm512_shuffle_i32x4 call.
type Mm512ShuffleI32x4 []cgen.Gen

// Append appends the _mm512_shuffle_i32x4 call expression to to.
func (m Mm512ShuffleI32x4) Append(to []byte) []byte {
	const intrinsic = "_mm512_shuffle_i32x4"
	return call(to, intrinsic, m)
}

// Mm512ShufflePs lists the arguments of a _mm512_shuffle_ps call.
type Mm512ShufflePs []cgen.Gen

// Append appends the _mm512_shuffle_ps call expression to to.
func (m Mm512ShufflePs) Append(to []byte) []byte {
	const intrinsic = "_mm512_shuffle_ps"
	return call(to, intrinsic, m)
}

// Mm512SlliEpi32 lists the arguments of a _mm512_slli_epi32 call.
type Mm512SlliEpi32 []cgen.Gen

// Append appends the _mm512_slli_epi32 call expression to to.
func (m Mm512SlliEpi32) Append(to []byte) []byte {
	const intrinsic = "_mm512_slli_epi32"
	return call(to, intrinsic, m)
}

// Mm512StoreuEpi32 lists the arguments of a _mm512_storeu_epi32 call.
type Mm512StoreuEpi32 []cgen.Gen

// Append appends the _mm512_storeu_epi32 call expression to to.
func (m Mm512StoreuEpi32) Append(to []byte) []byte {
	const intrinsic = "_mm512_storeu_epi32"
	return call(to, intrinsic, m)
}

// Mm512StoreuPs lists the arguments of a _mm512_storeu_ps call.
type Mm512StoreuPs []cgen.Gen

// Append appends the _mm512_storeu_ps call expression to to.
func (m Mm512StoreuPs) Append(to []byte) []byte {
	const intrinsic = "_mm512_storeu_ps"
	return call(to, intrinsic, m)
}

// Mm512SubPs lists the arguments of a _mm512_sub_ps call.
type Mm512SubPs []cgen.Gen

// Append appends the _mm512_sub_ps call expression to to.
func (m Mm512SubPs) Append(to []byte) []byte {
	const intrinsic = "_mm512_sub_ps"
	return call(to, intrinsic, m)
}

// Mm512UnpackhiPs lists the arguments of a _mm512_unpackhi_ps call.
type Mm512UnpackhiPs []cgen.Gen

// Append appends the _mm512_unpackhi_ps call expression to to.
func (m Mm512UnpackhiPs) Append(to []byte) []byte {
	const intrinsic = "_mm512_unpackhi_ps"
	return call(to, intrinsic, m)
}

// Mm512UnpackloPs lists the arguments of a _mm512_unpacklo_ps call.
type Mm512UnpackloPs []cgen.Gen

// Append appends the _mm512_unpacklo_ps call expression to to.
func (m Mm512UnpackloPs) Append(to []byte) []byte {
	const intrinsic = "_mm512_unpacklo_ps"
	return call(to, intrinsic, m)
}

Top || internal/compile/author/bn/bn.go

package bn

import (
"NN-512/internal/compile/author/avx"
"NN-512/internal/compile/author/cgen"
"NN-512/internal/compile/author/rsqrt"
"NN-512/internal/compile/plan"
"NN-512/internal/nmsrc"
"NN-512/internal/raw"
"fmt"
)

// Ctx carries the state shared by the batch norm generators in this
// package.
type Ctx struct {
// prefix starts the names of generated functions (see Simplify.Prep).
prefix string
// platform selects the instruction-set specific generator.
platform raw.Platform
// nms is the name source behind the name method.
nms nmsrc.Src
// rc is handed to the rsqrt.Call built by Simplify.
rc *rsqrt.Ctx
// dedup maps a function signature string to the name of the function
// already generated for it.
dedup map[string]string
}

// NewCtx returns a Ctx for the plan's platform with an empty dedup table.
// Its function name prefix is the plan's configured prefix followed by
// "Bn".
func NewCtx(pl *plan.Plan, nms nmsrc.Src, rc *rsqrt.Ctx) *Ctx {
	c := new(Ctx)
	c.prefix = pl.Config.Prefix + "Bn"
	c.platform = pl.Config.Platform
	c.nms = nms
	c.rc = rc
	c.dedup = make(map[string]string)
	return c
}

// maBytes returns the number of bytes stored per channel in a mas buffer:
// 8 on AVX512Float32. Any other platform panics.
func (c *Ctx) maBytes() int {
	if c.platform != raw.AVX512Float32 {
		panic("bug")
	}
	return 8
}

// name asks the context's name source for a name based on s.
func (c *Ctx) name(s string) string {
	named := c.nms.Name(s)
	return named
}

// vb converts s to a cgen.Vb and returns it as a cgen.Gen.
func vb(s string) cgen.Gen {
	var g cgen.Gen = cgen.Vb(s)
	return g
}

// il converts i to a cgen.IntLit and returns it as a cgen.Gen.
func il(i int) cgen.Gen {
	var g cgen.Gen = cgen.IntLit(i)
	return g
}

// cast returns the integer literal stride cast to cgen.PtrdiffT.
func cast(stride int) cgen.Gen {
	var c cgen.Cast
	c.Type = cgen.PtrdiffT
	c.Expr = il(stride)
	return c
}

// addr returns the expression ptr + stride*idx.
func addr(ptr, stride, idx cgen.Gen) cgen.Gen {
	scaled := cgen.Mul{Expr1: stride, Expr2: idx}
	return cgen.Add{Expr1: ptr, Expr2: scaled}
}

// Simplify generates a call to a function, itself generated by Prep, that
// takes four float arrays (means, variances, scales, shifts) and writes a
// mas buffer of MasBytes() bytes.
type Simplify struct {
*Ctx
// Channels and Epsilon parameterise the generated function and form
// its dedup signature.
Channels int
Epsilon float32
// Means, Variances, Scales, Shifts and Mas are the argument
// expressions of the emitted call, in that order.
Means cgen.Gen
Variances cgen.Gen
Scales cgen.Gen
Shifts cgen.Gen
Mas cgen.Gen
// funcName is set by Prep; Append calls the function of that name.
funcName string
// The fields below are the parameter names of the generated function,
// assigned in funcDef.
means cgen.Gen
variances cgen.Gen
scales cgen.Gen
shifts cgen.Gen
mas cgen.Gen
}

// Append writes a statement calling the function named by funcName with
// Means, Variances, Scales, Shifts and Mas as arguments, one per line.
func (s *Simplify) Append(to []byte) []byte {
	args := cgen.CommaLines{
		s.Means,
		s.Variances,
		s.Scales,
		s.Shifts,
		s.Mas,
	}
	invoke := cgen.Call{
		Func: vb(s.funcName),
		Args: args,
	}
	return cgen.Stmts{invoke}.Append(to)
}

// MasBytes returns the size of the mas buffer: Channels times the
// per-channel byte count reported by maBytes.
func (s *Simplify) MasBytes() int {
	perChannel := s.maBytes()
	return s.Channels * perChannel
}

// Prep chooses the function that Append will call. If a function with the
// same channel count and epsilon was already generated through this Ctx,
// its name is reused and nil is returned. Otherwise a new name is
// registered and the function definition, followed by a newline, is
// returned.
func (s *Simplify) Prep() cgen.Gen {
	const label = "Simplify"
	key := fmt.Sprintf("%s %d %g", label, s.Channels, s.Epsilon)
	if name, found := s.dedup[key]; found {
		s.funcName = name
		return nil
	}
	name := s.name(s.prefix + label)
	s.funcName = name
	s.dedup[key] = name
	return cgen.Gens{s.funcDef(), cgen.Newline}
}

// funcDef names the five parameters and returns the static void function
// definition. The first four parameters are restrict float pointers and
// the last is a restrict char pointer.
func (s *Simplify) funcDef() cgen.Gen {
	s.means = vb(s.name("means"))
	s.variances = vb(s.name("variances"))
	s.scales = vb(s.name("scales"))
	s.shifts = vb(s.name("shifts"))
	s.mas = vb(s.name("mas"))
	params := cgen.CommaLines{
		cgen.Param{Type: cgen.RestrictPtrFloat, What: s.means},
		cgen.Param{Type: cgen.RestrictPtrFloat, What: s.variances},
		cgen.Param{Type: cgen.RestrictPtrFloat, What: s.scales},
		cgen.Param{Type: cgen.RestrictPtrFloat, What: s.shifts},
		cgen.Param{Type: cgen.RestrictPtrChar, What: s.mas},
	}
	// The body is built last; it refers to the parameter names set above.
	def := cgen.StaticFuncDef{
		ReturnType: cgen.Void,
		Name:       s.funcName,
		Params:     params,
		Body:       s.body(),
	}
	return def
}

// body returns the function body for the configured platform. Only
// AVX512Float32 is supported; anything else panics.
func (s *Simplify) body() cgen.Gen {
	if s.platform == raw.AVX512Float32 {
		return s.m512()
	}
	panic("bug")
}

// m512 builds the AVX-512 body of the simplify function. For every channel
// it computes mul = rsqrt(eps + variance) * scale and
// add = shift - mean * mul, then stores the two values interleaved
// (mul, add, mul, add, ...) in mas. Channels are handled 16 per vector;
// groups of unroll vectors go into a loop and the remainder is emitted as
// straight-line code.
func (s *Simplify) m512() cgen.Gen {
const (
unroll = 5
lanes = 16
laneBytes = 4
)
var (
// iters counts full loop iterations; after counts the channels left
// over once the loop is done.
iters = s.Channels / (unroll * lanes)
after = s.Channels % (unroll * lanes)
eps = vb(s.name("eps"))
xlo = vb(s.name("xlo"))
xhi cgen.Gen
)
// ld declares to as a vector loaded from ptr at vector index j (plus
// unroll vectors per loop index i when a loop exists). Fewer than 16
// channels (n) use a zero-masked load.
ld := func(to, ptr, i cgen.Gen, j, n int) cgen.Gen {
var (
stmt = cgen.Var{Type: avx.M512, What: to}
from = addr(ptr, cast(lanes), il(j))
)
if iters > 0 {
from = addr(from, cast(unroll*lanes), i)
}
if n == lanes {
stmt.Init = avx.Mm512LoaduPs{from}
} else {
stmt.Init = avx.Mm512MaskzLoaduPs{
il(1<<uint(n) - 1), from,
}
}
return stmt
}
// st stores the interleaved results for n channels: lo and hi each
// cover half bytes of mas, and 2*n floats are written in total, with
// a masked store for whichever half is partial.
st := func(lo, hi, i cgen.Gen, j, n int) cgen.Gen {
const half = lanes * laneBytes
var (
stmts = make(cgen.Stmts, 2)
alo = addr(s.mas, cast(half), il(j*2+0))
ahi = addr(s.mas, cast(half), il(j*2+1))
)
if iters > 0 {
alo = addr(alo, cast(unroll*half*2), i)
ahi = addr(ahi, cast(unroll*half*2), i)
}
if nn := n * 2; nn < lanes {
stmts[0] = avx.Mm512MaskStoreuPs{
alo, il(1<<uint(nn) - 1), lo,
}
} else {
stmts[0] = avx.Mm512StoreuPs{alo, lo}
if nn -= lanes; nn == lanes {
stmts[1] = avx.Mm512StoreuPs{ahi, hi}
} else if nn > 0 {
stmts[1] = avx.Mm512MaskStoreuPs{
ahi, il(1<<uint(nn) - 1), hi,
}
}
}
return stmts
}
// deck returns the ten statement slots that process one vector of n
// channels at vector index j. Slot 8 (the hi permute) stays nil when
// n fits in the lo half.
deck := func(i cgen.Gen, j, n int) []cgen.Gen {
var (
gs = make([]cgen.Gen, 10)
va = vb(s.name("va"))
rcp = vb(s.name("rcp"))
sc = vb(s.name("sc"))
mul = vb(s.name("mul"))
me = vb(s.name("me"))
sh = vb(s.name("sh"))
add = vb(s.name("add"))
lo = vb(s.name("lo"))
hi cgen.Gen
)
gs[0] = ld(va, s.variances, i, j, n)
gs[1] = cgen.Var{
Type: avx.M512, What: rcp,
Init: &rsqrt.Call{
Ctx: s.rc,
Arg: avx.Mm512AddPs{eps, va},
},
}
gs[2] = ld(sc, s.scales, i, j, n)
gs[3] = cgen.Var{
Type: avx.M512, What: mul,
Init: avx.Mm512MulPs{rcp, sc},
}
gs[4] = ld(me, s.means, i, j, n)
gs[5] = ld(sh, s.shifts, i, j, n)
// fnmadd(me, mul, sh) computes -(me*mul) + sh.
gs[6] = cgen.Var{
Type: avx.M512, What: add,
Init: avx.Mm512FnmaddPs{
me, mul, sh,
},
}
gs[7] = cgen.Var{
Type: avx.M512, What: lo,
Init: avx.Mm512Permutex2varPs{
mul, xlo, add,
},
}
if n > lanes/2 {
hi = vb(s.name("hi"))
gs[8] = cgen.Var{
Type: avx.M512, What: hi,
Init: avx.Mm512Permutex2varPs{
mul, xhi, add,
},
}
}
gs[9] = st(lo, hi, i, j, n)
return gs
}
// shuf interleaves the decks: slot 0 of every deck, then slot 1 of
// every deck, and so on.
shuf := func(a [][]cgen.Gen) cgen.Stmts {
var (
n = len(a[0])
stmts = make(cgen.Stmts, len(a)*n)
i = 0
)
for j := 0; j < n; j++ {
for k := range a {
stmts[i] = a[k][j]
i++
}
}
return stmts
}
var (
stmts = make(cgen.Stmts, 5)
lower = make(avx.Mm512SetEpi32, lanes)
upper = make(avx.Mm512SetEpi32, lanes)
)
stmts[0] = cgen.Var{
Type: avx.M512, What: eps,
Init: avx.Mm512Set1PsLit(s.Epsilon),
}
// Build the permute index vectors. Output lane i takes element i>>1
// from the first table (mul) when i is even and from the second (add)
// when i is odd; upper does the same for the upper eight channels.
// The arguments are filled back to front because _mm512_set_epi32
// takes the highest lane first.
for i := 0; i < lanes; i++ {
x := i>>1 + lanes*(i&1)
lower[lanes-1-i] = il(x)
upper[lanes-1-i] = il(x + lanes/2)
}
stmts[1] = cgen.Var{
Type: avx.M512i, What: xlo,
Init: lower,
}
// xhi is declared only when some vector can hold more than eight
// channels.
if s.Channels > lanes/2 {
xhi = vb(s.name("xhi"))
stmts[2] = cgen.Var{
Type: avx.M512i, What: xhi,
Init: upper,
}
}
if iters > 0 {
var (
inner = make([][]cgen.Gen, unroll)
i = vb(s.name("i"))
)
for j := 0; j < unroll; j++ {
inner[j] = deck(i, j, lanes)
}
stmts[3] = cgen.For{
Init: cgen.Var{
Type: cgen.PtrdiffT, What: i,
Init: cgen.Zero,
},
Cond: cgen.CmpL{
Expr1: i, Expr2: il(iters),
},
Post: cgen.IncPre{Expr: i},
Body: shuf(inner),
}
}
// The remainder uses the literal loop count as index, so its addresses
// continue where the loop stopped.
if after > 0 {
var (
full = after / lanes
part = after % lanes
outer = make([][]cgen.Gen, full, full+1)
i = il(iters)
)
for j := 0; j < full; j++ {
outer[j] = deck(i, j, lanes)
}
if part > 0 {
last := deck(i, full, part)
outer = append(outer, last)
}
stmts[4] = shuf(outer)
}
return stmts
}

// Offset renders the address expression that addr builds from Mas, the
// cast maBytes stride, and Channel.
type Offset struct {
	*Ctx
	Mas     cgen.Gen
	Channel cgen.Gen
}

// Append writes the address expression onto to.
func (o *Offset) Append(to []byte) []byte {
	pitch := cast(o.maBytes())
	return addr(o.Mas, pitch, o.Channel).Append(to)
}

// Load declares the vector variables named by Mul and Add, filled from
// the data at Mas for Channel. Cnt == 0 selects the broadcast form;
// otherwise Cnt pairs are loaded and Spread lanes share each pair.
type Load struct {
	*Ctx
	Mas     cgen.Gen
	Channel cgen.Gen
	Mul     cgen.Gen
	Add     cgen.Gen
	Cnt     int
	Spread  int
}

// Append dispatches on the platform; only AVX-512 float32 is supported
// and anything else is a programming error.
func (l *Load) Append(to []byte) []byte {
	if l.platform != raw.AVX512Float32 {
		panic("bug")
	}
	return l.m512(to)
}

// m512 picks the broadcast form when Cnt is zero and the per-lane form
// otherwise.
func (l *Load) m512(to []byte) []byte {
	switch l.Cnt {
	case 0:
		return l.m512Broadcast(to)
	default:
		return l.m512Singles(to)
	}
}

// m512Broadcast declares Mul and/or Add (whichever is non-nil) as
// __m512 values, each a splat of one float: element 0 (Mul) or
// element 1 (Add) of the float pair selected by Channel.
func (l *Load) m512Broadcast(to []byte) []byte {
	// View Mas as a float pointer and advance two floats per channel.
	pair := cgen.Gen(cgen.Paren{
		Inner: addr(
			cgen.Cast{Type: cgen.PtrFloat, Expr: l.Mas},
			cast(2), l.Channel,
		),
	})
	splat := func(what, idx cgen.Gen) cgen.Gen {
		if what == nil {
			return nil
		}
		return cgen.Var{
			Type: avx.M512, What: what,
			Init: avx.Mm512Set1Ps{cgen.Elem{
				Arr: pair, Idx: idx,
			}},
		}
	}
	return cgen.Stmts{
		splat(l.Mul, cgen.Zero),
		splat(l.Add, cgen.One),
	}.Append(to)
}

// m512Singles declares Mul and Add as __m512 values gathered from
// interleaved (mul, add) float pairs starting at the Offset address.
// Lane i takes pair i/Spread (a Spread of 0 is treated as 1). When the
// Cnt pairs fit in one vector a single masked load is permuted;
// otherwise a full load plus a masked second load are combined with a
// two-source permute.
func (l *Load) m512Singles(to []byte) []byte {
	const (
		lanes     = 16
		laneBytes = 4
	)
	step := l.Spread
	if step == 0 {
		step = 1
	}
	var (
		idxMul = make(avx.Mm512SetEpi32, lanes)
		idxAdd = make(avx.Mm512SetEpi32, lanes)
	)
	for lane := 0; lane < lanes; lane++ {
		// The set-epi32 argument list is filled back to front.
		at, first := lanes-1-lane, lane/step*2
		idxMul[at] = il(first)
		idxAdd[at] = il(first + 1)
	}
	// Names are requested in the same order as before so that the
	// generated identifiers do not change.
	permMul := vb(l.name("pmMul"))
	permAdd := vb(l.name("pmAdd"))
	base := &Offset{
		Ctx:     l.Ctx,
		Mas:     l.Mas,
		Channel: l.Channel,
	}
	var (
		loads    [2]cgen.Gen
		mul, add cgen.Gen
	)
	if floats := l.Cnt * 2; floats > lanes {
		var (
			masLo = vb(l.name("masLo"))
			masHi = vb(l.name("masHi"))
		)
		loads[0] = cgen.Var{
			Type: avx.M512, What: masLo,
			Init: avx.Mm512LoaduPs{base},
		}
		loads[1] = cgen.Var{
			Type: avx.M512, What: masHi,
			Init: avx.Mm512MaskzLoaduPs{
				il(1<<uint(floats-lanes) - 1),
				cgen.Add{
					Expr1: base,
					Expr2: cast(lanes * laneBytes),
				},
			},
		}
		mul = avx.Mm512Permutex2varPs{masLo, permMul, masHi}
		add = avx.Mm512Permutex2varPs{masLo, permAdd, masHi}
	} else {
		mas := vb(l.name("mas"))
		loads[0] = cgen.Var{
			Type: avx.M512, What: mas,
			Init: avx.Mm512MaskzLoaduPs{
				il(1<<uint(floats) - 1), base,
			},
		}
		mul = avx.Mm512PermutexvarPs{permMul, mas}
		add = avx.Mm512PermutexvarPs{permAdd, mas}
	}
	return cgen.Stmts{
		cgen.Var{Type: avx.M512i, What: permMul, Init: idxMul},
		cgen.Var{Type: avx.M512i, What: permAdd, Init: idxAdd},
		loads[0],
		loads[1], // nil (skipped) in the single-load case
		cgen.Var{Type: avx.M512, What: l.Mul, Init: mul},
		cgen.Var{Type: avx.M512, What: l.Add, Init: add},
	}.Append(to)
}

// Apply emits an assignment that updates To with a fused multiply-add
// using Mul and Add, optionally under Mask.
type Apply struct {
	*Ctx
	Mul  cgen.Gen
	Add  cgen.Gen
	To   cgen.Gen
	Mask cgen.Gen
}

// Append dispatches on the platform; only AVX-512 float32 is supported
// and anything else is a programming error.
func (a *Apply) Append(to []byte) []byte {
	if a.platform != raw.AVX512Float32 {
		panic("bug")
	}
	return a.m512(to)
}

// m512 assigns fmadd(To, Mul, Add) back to To, using the masked
// intrinsic when a Mask is present.
func (a *Apply) m512(to []byte) []byte {
	var fma cgen.Gen
	if a.Mask != nil {
		fma = avx.Mm512MaskFmaddPs{
			a.To, a.Mask, a.Mul, a.Add,
		}
	} else {
		fma = avx.Mm512FmaddPs{
			a.To, a.Mul, a.Add,
		}
	}
	return cgen.Stmts{
		cgen.Assign{Expr1: a.To, Expr2: fma},
	}.Append(to)
}

Top || internal/compile/author/cgen/cgen.go

package cgen

import "strconv"

// Literal C text used by the generators in this package: operators and
// punctuation, keywords, preprocessor words, GCC builtins/attributes,
// and libc/pthread identifiers. Names that would collide with Go
// keywords or predeclared identifiers carry a trailing underscore.
const (
aligned = "aligned"
ampersand = "&"
arrow = "->"
assign = "="
asterisk = "*"
attribute = "__attribute__"
bang = "!"
brace1 = "{"
brace2 = "}"
break_ = "break"
calloc = "calloc"
caret = "^"
case_ = "case"
char = "char"
cmpE = "=="
cmpG = ">"
cmpGE = ">="
cmpL = "<"
cmpLE = "<="
cmpNE = "!="
colon = ":"
comma = ","
continue_ = "continue"
cplusplus = "__cplusplus"
cpuSupports = "__builtin_cpu_supports"
ctzl = "__builtin_ctzl"
dec = "--"
default_ = "default"
dot = "."
doubleQuote = "\""
ellipsis = "..."
else_ = "else"
empty = ""
endif = "endif"
errno = "errno"
expect = "__builtin_expect"
extern = "extern"
float = "float"
floatSuffix = "f"
for_ = "for"
free = "free"
gap = "/**/"
goto_ = "goto"
hash = "#"
ifdef = "ifdef"
if_ = "if"
inc = "++"
include = "include"
int64T = "int64_t"
int_ = "int"
land = "&&"
lineNum = "__LINE__"
linkageC = "C"
long = "long"
lor = "||"
malloc = "malloc"
memcpy = "memcpy"
memset = "memset"
minus = "-"
newline = "\n"
once = "once"
one = "1"
packed = "packed"
paren1 = "("
paren2 = ")"
percent = "%"
pipe = "|"
plus = "+"
pragma = "pragma"
pthreadCondDestroy = "pthread_cond_destroy"
pthreadCondInit = "pthread_cond_init"
pthreadCondSignal = "pthread_cond_signal"
pthreadCondT = "pthread_cond_t"
pthreadCondWait = "pthread_cond_wait"
pthreadCreate = "pthread_create"
pthreadJoin = "pthread_join"
pthreadMutexDestroy = "pthread_mutex_destroy"
pthreadMutexInit = "pthread_mutex_init"
pthreadMutexLock = "pthread_mutex_lock"
pthreadMutexT = "pthread_mutex_t"
pthreadMutexUnlock = "pthread_mutex_unlock"
pthreadT = "pthread_t"
ptrdiffT = "ptrdiff_t"
questionMark = "?"
restrict = "restrict"
return_ = "return"
semicolon = ";"
shiftHigh = "<<"
shiftLow = ">>"
sizeof = "sizeof"
sizeT = "size_t"
slash = "/"
slashes = "//"
space = " "
sprintf = "sprintf"
squareBracket1 = "["
squareBracket2 = "]"
static = "static"
struct_ = "struct"
switch_ = "switch"
tilde = "~"
typedef = "typedef"
vaEnd = "va_end"
vaList = "va_list"
vaStart = "va_start"
void = "void"
vsnprintf = "vsnprintf"
zero = "0"
)

// Add renders Expr1+Expr2 with no surrounding spaces.
type Add struct {
	Expr1, Expr2 Gen
}

// Append writes the sum expression onto to.
func (a Add) Append(to []byte) []byte {
	lhs := a.Expr1.Append(to)
	return a.Expr2.Append(append(lhs, plus...))
}

// AddAssign renders "Expr1 += Expr2".
type AddAssign struct {
	Expr1, Expr2 Gen
}

// Append writes the compound assignment onto to.
func (a AddAssign) Append(to []byte) []byte {
	const op = space + plus + assign + space
	to = append(a.Expr1.Append(to), op...)
	return a.Expr2.Append(to)
}

// Addr renders the address-of operator applied to Expr.
type Addr struct {
	Expr Gen
}

// Append writes "&" followed by Expr onto to.
func (a Addr) Append(to []byte) []byte {
	return a.Expr.Append(append(to, ampersand...))
}

// AddrArrow renders the address of an Arrow member access.
type AddrArrow Arrow

// Append writes "&" followed by the Arrow expression onto to.
func (a AddrArrow) Append(to []byte) []byte {
	return Addr{Expr: Arrow(a)}.Append(to)
}

// AddrDot renders the address of a Dot member access.
type AddrDot Dot

// Append writes "&" followed by the Dot expression onto to.
func (a AddrDot) Append(to []byte) []byte {
	return Addr{Expr: Dot(a)}.Append(to)
}

type Aligned int

func (a Aligned) Append(to []byte) []byte {
to = append(to, aligned...)
to = Paren{IntLit(a)}.Append(to)
return to
}

// And renders the bitwise Expr1&Expr2 with no surrounding spaces.
type And struct {
	Expr1, Expr2 Gen
}

// Append writes the expression onto to.
func (a And) Append(to []byte) []byte {
	lhs := a.Expr1.Append(to)
	return a.Expr2.Append(append(lhs, ampersand...))
}

// AndAssign renders "Expr1 &= Expr2".
type AndAssign struct {
	Expr1, Expr2 Gen
}

// Append writes the compound assignment onto to.
func (a AndAssign) Append(to []byte) []byte {
	const op = space + ampersand + assign + space
	to = append(a.Expr1.Append(to), op...)
	return a.Expr2.Append(to)
}

// AngleBracketed renders its text wrapped in '<' and '>'.
type AngleBracketed string

// Append writes the bracketed text onto to.
func (a AngleBracketed) Append(to []byte) []byte {
	to = append(append(to, cmpL...), a...)
	return append(to, cmpG...)
}

// Arrow renders the member access Expr->Name.
type Arrow struct {
	Expr Gen
	Name string
}

// Append writes the member access onto to.
func (a Arrow) Append(to []byte) []byte {
	to = append(a.Expr.Append(to), arrow...)
	return append(to, a.Name...)
}

// Assign renders "Expr1 = Expr2".
type Assign struct {
	Expr1, Expr2 Gen
}

// Append writes the assignment onto to.
func (a Assign) Append(to []byte) []byte {
	const op = space + assign + space
	to = append(a.Expr1.Append(to), op...)
	return a.Expr2.Append(to)
}

// At renders a '*' prefix on Expr (pointer dereference).
type At struct {
	Expr Gen
}

// Append writes "*" followed by Expr onto to.
func (a At) Append(to []byte) []byte {
	return a.Expr.Append(append(to, asterisk...))
}

// AttrSpec renders __attribute__((Attrs)).
type AttrSpec struct {
	Attrs Gen
}

// Append writes the attribute specifier onto to.
func (a AttrSpec) Append(to []byte) []byte {
	inner := Paren{Inner: a.Attrs}
	return Paren{Inner: inner}.Append(append(to, attribute...))
}

// Block renders Inner between "{\n" and "}". A nil Inner yields an
// empty block.
type Block struct {
	Inner Gen
}

// Append writes the braced block onto to.
func (b Block) Append(to []byte) []byte {
	to = append(to, brace1+newline...)
	if b.Inner != nil {
		to = b.Inner.Append(to)
	}
	return append(to, brace2...)
}

// Brace renders Inner between '{' and '}' on one line. A nil Inner
// yields "{}".
type Brace struct {
	Inner Gen
}

// Append writes the braced text onto to.
func (b Brace) Append(to []byte) []byte {
	to = append(to, brace1...)
	if b.Inner != nil {
		to = b.Inner.Append(to)
	}
	return append(to, brace2...)
}

// Call renders Func(Args). Args may be nil for an empty argument list.
type Call struct {
	Func, Args Gen
}

// Append writes the call expression onto to.
func (c Call) Append(to []byte) []byte {
	return Paren{Inner: c.Args}.Append(c.Func.Append(to))
}

// Case renders a switch label: "case Expr:" or, when Expr is nil,
// "default:". A non-nil Body follows as a braced block.
type Case struct {
	Expr, Body Gen
}

// Append writes the label and optional body onto to.
func (c Case) Append(to []byte) []byte {
	if c.Expr != nil {
		to = c.Expr.Append(append(to, case_+space...))
	} else {
		to = append(to, default_...)
	}
	to = append(to, colon...)
	if c.Body == nil {
		return to
	}
	return Block{Inner: c.Body}.Append(append(to, space...))
}

// Cast renders (Type)Expr.
type Cast struct {
	Type, Expr Gen
}

// Append writes the cast expression onto to.
func (c Cast) Append(to []byte) []byte {
	return c.Expr.Append(Paren{Inner: c.Type}.Append(to))
}

// CmpE renders "Expr1 == Expr2".
type CmpE struct {
	Expr1, Expr2 Gen
}

// Append writes the comparison onto to.
func (c CmpE) Append(to []byte) []byte {
	const op = space + cmpE + space
	to = append(c.Expr1.Append(to), op...)
	return c.Expr2.Append(to)
}

// CmpG renders "Expr1 > Expr2".
type CmpG struct {
	Expr1, Expr2 Gen
}

// Append writes the comparison onto to.
func (c CmpG) Append(to []byte) []byte {
	const op = space + cmpG + space
	to = append(c.Expr1.Append(to), op...)
	return c.Expr2.Append(to)
}

// CmpGE renders "Expr1 >= Expr2".
type CmpGE struct {
	Expr1, Expr2 Gen
}

// Append writes the comparison onto to.
func (c CmpGE) Append(to []byte) []byte {
	const op = space + cmpGE + space
	to = append(c.Expr1.Append(to), op...)
	return c.Expr2.Append(to)
}

// CmpL renders "Expr1 < Expr2".
type CmpL struct {
	Expr1, Expr2 Gen
}

// Append writes the comparison onto to.
func (c CmpL) Append(to []byte) []byte {
	const op = space + cmpL + space
	to = append(c.Expr1.Append(to), op...)
	return c.Expr2.Append(to)
}

// CmpLE renders "Expr1 <= Expr2".
type CmpLE struct {
	Expr1, Expr2 Gen
}

// Append writes the comparison onto to.
func (c CmpLE) Append(to []byte) []byte {
	const op = space + cmpLE + space
	to = append(c.Expr1.Append(to), op...)
	return c.Expr2.Append(to)
}

// CmpNE renders "Expr1 != Expr2".
type CmpNE struct {
	Expr1, Expr2 Gen
}

// Append writes the comparison onto to.
func (c CmpNE) Append(to []byte) []byte {
	const op = space + cmpNE + space
	to = append(c.Expr1.Append(to), op...)
	return c.Expr2.Append(to)
}

// CommaLines renders its non-nil elements one per line, separated by
// commas. Output starts and ends with a newline unless every element
// is nil, in which case nothing is written.
type CommaLines []Gen

// Append writes the comma-separated lines onto to.
func (c CommaLines) Append(to []byte) []byte {
	written := 0
	for _, gen := range c {
		if gen == nil {
			continue
		}
		if written > 0 {
			to = append(to, comma...)
		}
		written++
		to = gen.Append(append(to, newline...))
	}
	if written > 0 {
		to = append(to, newline...)
	}
	return to
}

// CommaSpaced renders its non-nil elements separated by ", ".
type CommaSpaced []Gen

// Append writes the comma-separated list onto to.
func (c CommaSpaced) Append(to []byte) []byte {
	sep := ""
	for _, gen := range c {
		if gen == nil {
			continue
		}
		to = gen.Append(append(to, sep...))
		sep = comma + space
	}
	return to
}

type Comment []string

func (c Comment) Append(to []byte) []byte {
for _, line := range c {
switch line {
case empty:
to = append(to, slashes+newline...)
default:
to = append(to, slashes+space...)
to = append(to, line...)
to = append(to, newline...)
}
}
return to
}

// DecPost renders the post-decrement Expr--.
type DecPost struct {
	Expr Gen
}

// Append writes the expression onto to.
func (d DecPost) Append(to []byte) []byte {
	return append(d.Expr.Append(to), dec...)
}

// DecPre renders the pre-decrement --Expr.
type DecPre struct {
	Expr Gen
}

// Append writes the expression onto to.
func (d DecPre) Append(to []byte) []byte {
	return d.Expr.Append(append(to, dec...))
}

// Directive is the word that follows '#' in a preprocessor line; it is
// the Head of a Preprocessor.
type Directive string

// The preprocessor directives this package emits.
const (
Endif Directive = endif
Ifdef Directive = ifdef
Include Directive = include
Pragma Directive = pragma
)

// Dot renders the member access Expr.Name.
type Dot struct {
	Expr Gen
	Name string
}

// Append writes the member access onto to.
func (d Dot) Append(to []byte) []byte {
	to = append(d.Expr.Append(to), dot...)
	return append(to, d.Name...)
}

// DoubleQuoted renders its text between double quotes. The text is
// copied verbatim; no escaping is performed.
type DoubleQuoted string

// Append writes the quoted text onto to.
func (d DoubleQuoted) Append(to []byte) []byte {
	to = append(append(to, doubleQuote...), d...)
	return append(to, doubleQuote...)
}

// Elem renders Arr[Idx]. A nil Idx yields "Arr[]".
type Elem struct {
	Arr, Idx Gen
}

// Append writes the subscript expression onto to.
func (e Elem) Append(to []byte) []byte {
	to = append(e.Arr.Append(to), squareBracket1...)
	if e.Idx != nil {
		to = e.Idx.Append(to)
	}
	return append(to, squareBracket2...)
}

// Extern renders "extern " followed by Tail.
type Extern struct {
	Tail Gen
}

// Append writes the extern declaration prefix and Tail onto to.
func (e Extern) Append(to []byte) []byte {
	return e.Tail.Append(append(to, extern+space...))
}

// Field renders a struct member declaration "Type What;".
type Field struct {
	Type, What Gen
}

// Append writes the member declaration onto to.
func (f Field) Append(to []byte) []byte {
	to = append(f.Type.Append(to), space...)
	return append(f.What.Append(to), semicolon...)
}

// FloatLit renders a C float literal: the shortest exponent-form text
// that round-trips as a 32-bit float, followed by the 'f' suffix.
type FloatLit float64

// Append writes the literal onto to.
func (f FloatLit) Append(to []byte) []byte {
	digits := strconv.AppendFloat(to, float64(f), 'e', -1, 32)
	return append(digits, floatSuffix...)
}

// For renders "for (Init; Cond; Post) { Body }". Any clause may be
// nil; a nil Body omits the block entirely.
type For struct {
	Init, Cond, Post, Body Gen
}

// Append writes the loop onto to.
func (f For) Append(to []byte) []byte {
	to = append(to, for_+space+paren1...)
	if f.Init != nil {
		to = f.Init.Append(to)
	}
	// An Init that already ends in ';' (for example a Var) must not
	// get a second one.
	if tail := to[len(to)-1]; tail != semicolon[0] {
		to = append(to, semicolon...)
	}
	to = append(to, space...)
	if f.Cond != nil {
		to = f.Cond.Append(to)
	}
	to = append(to, semicolon+space...)
	if f.Post != nil {
		to = f.Post.Append(to)
	}
	to = append(to, paren2...)
	if f.Body == nil {
		return to
	}
	return Block{Inner: f.Body}.Append(append(to, space...))
}

// FuncDecl renders a function prototype "ReturnType Name(Params);"
// followed by a newline.
type FuncDecl struct {
	ReturnType Gen
	Name       string
	Params     Gen
}

// Append writes the prototype onto to.
func (f FuncDecl) Append(to []byte) []byte {
	to = append(f.ReturnType.Append(to), space...)
	sig := Call{Func: Vb(f.Name), Args: f.Params}
	return append(sig.Append(to), semicolon+newline...)
}

// FuncDef renders a function definition: return type, name with
// parameter list, and a braced body, followed by a newline.
type FuncDef struct {
	ReturnType Gen
	Name       string
	Params     Gen
	Body       Gen
}

// Append writes the definition onto to.
func (f FuncDef) Append(to []byte) []byte {
	parts := Spaced{
		f.ReturnType,
		Call{Func: Vb(f.Name), Args: f.Params},
		Block{Inner: f.Body},
	}
	return append(parts.Append(to), newline...)
}

// Gen is a piece of generated text. Append adds the text to the end of
// to and returns the extended slice, in the manner of the built-in
// append.
type Gen interface {
Append(to []byte) []byte
}

// Gens renders its non-nil elements back to back with no separator.
type Gens []Gen

// Append writes every non-nil element onto to, in order.
func (gs Gens) Append(to []byte) []byte {
	for i := range gs {
		if g := gs[i]; g != nil {
			to = g.Append(to)
		}
	}
	return to
}

// Goto renders "goto " followed by the label name.
type Goto Label

// Append writes the goto statement (without a semicolon) onto to.
func (g Goto) Append(to []byte) []byte {
	return append(append(to, goto_+space...), g...)
}

// If renders "if (Cond) { Then }" with an optional else part. An Else
// consisting of exactly one If is chained as "else if" instead of
// being wrapped in its own braces.
type If struct {
	Cond Gen
	Then Stmts
	Else Stmts
}

// Append writes the conditional onto to.
func (i If) Append(to []byte) []byte {
	to = append(to, if_+space...)
	to = Paren{Inner: i.Cond}.Append(to)
	to = append(to, space...)
	to = Block{Inner: i.Then}.Append(to)
	if len(i.Else) == 0 {
		return to
	}
	to = append(to, space+else_+space...)
	if len(i.Else) == 1 {
		if elif, ok := i.Else[0].(If); ok {
			return elif.Append(to)
		}
	}
	return Block{Inner: i.Else}.Append(to)
}

// If1 renders a brace-less conditional: "if (Cond) Then", optionally
// followed by "else Else". With an Else, Then is emitted as a
// terminated statement first.
type If1 struct {
	Cond, Then, Else Gen
}

// Append writes the conditional onto to.
func (i If1) Append(to []byte) []byte {
	to = append(to, if_+space...)
	to = append(Paren{Inner: i.Cond}.Append(to), space...)
	if i.Else == nil {
		return i.Then.Append(to)
	}
	to = Stmts{i.Then}.Append(to)
	return i.Else.Append(append(to, else_+space...))
}

// IncPost renders the post-increment Expr++.
type IncPost struct {
	Expr Gen
}

// Append writes the expression onto to.
func (i IncPost) Append(to []byte) []byte {
	return append(i.Expr.Append(to), inc...)
}

// IncPre renders the pre-increment ++Expr.
type IncPre struct {
	Expr Gen
}

// Append writes the expression onto to.
func (i IncPre) Append(to []byte) []byte {
	return i.Expr.Append(append(to, inc...))
}

// IntLit renders an integer as a base-10 literal.
type IntLit int

// Append writes the decimal digits (with a leading '-' if negative)
// onto to.
func (i IntLit) Append(to []byte) []byte {
	return strconv.AppendInt(to, int64(i), 10)
}

// IsNonzero renders !!Expr.
type IsNonzero struct {
	Expr Gen
}

// Append writes the double negation onto to.
func (i IsNonzero) Append(to []byte) []byte {
	return i.Expr.Append(append(to, bang+bang...))
}

// IsZero renders !Expr.
type IsZero struct {
	Expr Gen
}

// Append writes the logical negation onto to.
func (i IsZero) Append(to []byte) []byte {
	return i.Expr.Append(append(to, bang...))
}

// Label renders a statement label: the name followed by ':'.
type Label string

// Append writes the label onto to.
func (l Label) Append(to []byte) []byte {
	return append(append(to, l...), colon...)
}

// Land renders the logical "Expr1 && Expr2".
type Land struct {
	Expr1, Expr2 Gen
}

// Append writes the expression onto to.
func (l Land) Append(to []byte) []byte {
	const op = space + land + space
	to = append(l.Expr1.Append(to), op...)
	return l.Expr2.Append(to)
}

// Lor renders the logical "Expr1 || Expr2".
type Lor struct {
	Expr1, Expr2 Gen
}

// Append writes the expression onto to.
func (l Lor) Append(to []byte) []byte {
	const op = space + lor + space
	to = append(l.Expr1.Append(to), op...)
	return l.Expr2.Append(to)
}

// Maybe renders What if it is non-nil and nothing otherwise.
type Maybe struct {
	What Gen
}

// Append writes What onto to, or returns to untouched when What is nil.
func (m Maybe) Append(to []byte) []byte {
	if m.What == nil {
		return to
	}
	return m.What.Append(to)
}

// MaybeSpace renders a space followed by What if What is non-nil, and
// nothing otherwise.
type MaybeSpace struct {
	What Gen
}

// Append writes " What" onto to, or returns to untouched when What is
// nil.
func (m MaybeSpace) Append(to []byte) []byte {
	if m.What == nil {
		return to
	}
	return m.What.Append(append(to, space...))
}

// Mul renders Expr1*Expr2 with no surrounding spaces.
type Mul struct {
	Expr1, Expr2 Gen
}

// Append writes the product expression onto to.
func (m Mul) Append(to []byte) []byte {
	lhs := m.Expr1.Append(to)
	return m.Expr2.Append(append(lhs, asterisk...))
}

// MulAssign renders "Expr1 *= Expr2".
type MulAssign struct {
	Expr1, Expr2 Gen
}

// Append writes the compound assignment onto to.
func (m MulAssign) Append(to []byte) []byte {
	const op = space + asterisk + assign + space
	to = append(m.Expr1.Append(to), op...)
	return m.Expr2.Append(to)
}

// Neg renders the arithmetic negation -Expr.
type Neg struct {
	Expr Gen
}

// Append writes "-" followed by Expr onto to.
func (n Neg) Append(to []byte) []byte {
	return n.Expr.Append(append(to, minus...))
}

// Not renders the bitwise complement ~Expr.
type Not struct {
	Expr Gen
}

// Append writes "~" followed by Expr onto to.
func (n Not) Append(to []byte) []byte {
	return n.Expr.Append(append(to, tilde...))
}

// Or renders the bitwise Expr1|Expr2 with no surrounding spaces.
type Or struct {
	Expr1, Expr2 Gen
}

// Append writes the expression onto to.
func (o Or) Append(to []byte) []byte {
	lhs := o.Expr1.Append(to)
	return o.Expr2.Append(append(lhs, pipe...))
}

// OrAssign renders "Expr1 |= Expr2".
type OrAssign struct {
	Expr1, Expr2 Gen
}

// Append writes the compound assignment onto to.
func (o OrAssign) Append(to []byte) []byte {
	const op = space + pipe + assign + space
	to = append(o.Expr1.Append(to), op...)
	return o.Expr2.Append(to)
}

// Param renders a function parameter "Type What".
type Param struct {
	Type, What Gen
}

// Append writes the parameter onto to.
func (p Param) Append(to []byte) []byte {
	to = append(p.Type.Append(to), space...)
	return p.What.Append(to)
}

// Paren renders Inner between '(' and ')'. A nil Inner yields "()".
type Paren struct {
	Inner Gen
}

// Append writes the parenthesized text onto to.
func (p Paren) Append(to []byte) []byte {
	to = append(to, paren1...)
	if p.Inner != nil {
		to = p.Inner.Append(to)
	}
	return append(to, paren2...)
}

// Preprocessor renders a directive line: '#', the directive word, an
// optional space-separated Tail, and a newline.
type Preprocessor struct {
	Head Directive
	Tail Gen
}

// Append writes the directive line onto to.
func (p Preprocessor) Append(to []byte) []byte {
	to = append(append(to, hash...), p.Head...)
	if p.Tail != nil {
		to = p.Tail.Append(append(to, space...))
	}
	return append(to, newline...)
}

// Ptr renders the pointer type Type*.
type Ptr struct {
	Type Gen
}

// Append writes the pointer type onto to.
func (p Ptr) Append(to []byte) []byte {
	return append(p.Type.Append(to), asterisk...)
}

// Quo renders Expr1/Expr2 with no surrounding spaces.
type Quo struct {
	Expr1, Expr2 Gen
}

// Append writes the quotient expression onto to.
func (q Quo) Append(to []byte) []byte {
	lhs := q.Expr1.Append(to)
	return q.Expr2.Append(append(lhs, slash...))
}

// QuoAssign renders "Expr1 /= Expr2".
type QuoAssign struct {
	Expr1, Expr2 Gen
}

// Append writes the compound assignment onto to.
func (q QuoAssign) Append(to []byte) []byte {
	const op = space + slash + assign + space
	to = append(q.Expr1.Append(to), op...)
	return q.Expr2.Append(to)
}

// Rem renders Expr1%Expr2 with no surrounding spaces.
type Rem struct {
	Expr1, Expr2 Gen
}

// Append writes the remainder expression onto to.
func (r Rem) Append(to []byte) []byte {
	lhs := r.Expr1.Append(to)
	return r.Expr2.Append(append(lhs, percent...))
}

// RemAssign renders "Expr1 %= Expr2".
type RemAssign struct {
	Expr1, Expr2 Gen
}

// Append writes the compound assignment onto to.
func (r RemAssign) Append(to []byte) []byte {
	const op = space + percent + assign + space
	to = append(r.Expr1.Append(to), op...)
	return r.Expr2.Append(to)
}

// RestrictPtr renders the pointer type followed directly by the
// restrict qualifier: Type*restrict.
type RestrictPtr Ptr

// Append writes the qualified pointer type onto to.
func (r RestrictPtr) Append(to []byte) []byte {
	return append(Ptr(r).Append(to), restrict...)
}

// Return renders "return" or, with a non-nil Expr, "return Expr".
type Return struct {
	Expr Gen
}

// Append writes the return statement (without a semicolon) onto to.
func (r Return) Append(to []byte) []byte {
	to = append(to, return_...)
	if r.Expr != nil {
		to = r.Expr.Append(append(to, space...))
	}
	return to
}

// ShiftHigh renders the left shift Expr1<<Expr2 with no surrounding
// spaces.
type ShiftHigh struct {
	Expr1, Expr2 Gen
}

// Append writes the shift expression onto to.
func (s ShiftHigh) Append(to []byte) []byte {
	lhs := s.Expr1.Append(to)
	return s.Expr2.Append(append(lhs, shiftHigh...))
}

// ShiftHighAssign renders "Expr1 <<= Expr2".
type ShiftHighAssign struct {
	Expr1, Expr2 Gen
}

// Append writes the compound assignment onto to.
func (s ShiftHighAssign) Append(to []byte) []byte {
	const op = space + shiftHigh + assign + space
	to = append(s.Expr1.Append(to), op...)
	return s.Expr2.Append(to)
}

// ShiftLow renders the right shift Expr1>>Expr2 with no surrounding
// spaces.
type ShiftLow struct {
	Expr1, Expr2 Gen
}

// Append writes the shift expression onto to.
func (s ShiftLow) Append(to []byte) []byte {
	lhs := s.Expr1.Append(to)
	return s.Expr2.Append(append(lhs, shiftLow...))
}

// ShiftLowAssign renders "Expr1 >>= Expr2".
type ShiftLowAssign struct {
	Expr1, Expr2 Gen
}

// Append writes the compound assignment onto to.
func (s ShiftLowAssign) Append(to []byte) []byte {
	const op = space + shiftLow + assign + space
	to = append(s.Expr1.Append(to), op...)
	return s.Expr2.Append(to)
}

// Sizeof renders sizeof(What).
type Sizeof struct {
	What Gen
}

// Append writes the sizeof expression onto to.
func (s Sizeof) Append(to []byte) []byte {
	return Paren{Inner: s.What}.Append(append(to, sizeof...))
}

// Spaced renders its non-nil elements separated by single spaces.
type Spaced []Gen

// Append writes the space-separated elements onto to.
func (s Spaced) Append(to []byte) []byte {
	var sep string
	for _, gen := range s {
		if gen == nil {
			continue
		}
		to = gen.Append(append(to, sep...))
		sep = space
	}
	return to
}

// Static renders "static " followed by Tail.
type Static struct {
	Tail Gen
}

// Append writes the static prefix and Tail onto to.
func (s Static) Append(to []byte) []byte {
	return s.Tail.Append(append(to, static+space...))
}

// StaticFuncDef renders a FuncDef prefixed with "static ".
type StaticFuncDef FuncDef

// Append writes the static function definition onto to.
func (s StaticFuncDef) Append(to []byte) []byte {
	return Static{Tail: FuncDef(s)}.Append(to)
}

// Stmts renders its non-nil elements as statements, one per line. Each
// element that produced output is terminated: text already ending in a
// newline is left alone, text ending in '}' or ';' gets a newline, and
// anything else gets ";\n". Elements that produce no output are
// skipped entirely.
type Stmts []Gen

// Append writes the terminated statements onto to.
func (s Stmts) Append(to []byte) []byte {
	for _, gen := range s {
		if gen == nil {
			continue
		}
		before := len(to)
		to = gen.Append(to)
		if len(to) <= before {
			continue
		}
		tail := to[len(to)-1]
		if tail == newline[0] {
			continue
		}
		if tail != brace2[0] && tail != semicolon[0] {
			to = append(to, semicolon...)
		}
		to = append(to, newline...)
	}
	return to
}

// StructDef renders "struct Name { Fields }", an optional
// __attribute__ specifier, and a terminating ";\n".
type StructDef struct {
	Name   string
	Fields Gen
	Attrs  Gen
}

// Append writes the struct definition onto to.
func (s StructDef) Append(to []byte) []byte {
	var attrs Gen
	if s.Attrs != nil {
		attrs = AttrSpec{Attrs: s.Attrs}
	}
	parts := Spaced{
		StructTag(s.Name),
		Block{Inner: s.Fields},
		attrs, // nil entries are skipped by Spaced
	}
	return append(parts.Append(to), semicolon+newline...)
}

type StructFwd string

func (s StructFwd) Append(to []byte) []byte {
var g1, g2 Gen
g1, g2 = StructTag(s), Vb(s)
to = Typedef{g1, g2}.Append(to)
return to
}

// StructTag renders "struct " followed by the tag name.
type StructTag string

// Append writes the tagged struct type onto to.
func (s StructTag) Append(to []byte) []byte {
	return append(append(to, struct_+space...), s...)
}

// Sub renders Expr1-Expr2 with no surrounding spaces.
type Sub struct {
	Expr1, Expr2 Gen
}

// Append writes the difference expression onto to.
func (s Sub) Append(to []byte) []byte {
	lhs := s.Expr1.Append(to)
	return s.Expr2.Append(append(lhs, minus...))
}

// SubAssign renders "Expr1 -= Expr2".
type SubAssign struct {
	Expr1, Expr2 Gen
}

// Append writes the compound assignment onto to.
func (s SubAssign) Append(to []byte) []byte {
	const op = space + minus + assign + space
	to = append(s.Expr1.Append(to), op...)
	return s.Expr2.Append(to)
}

// Switch renders "switch (Expr) { Cases }".
type Switch struct {
	Expr, Cases Gen
}

// Append writes the switch statement onto to.
func (s Switch) Append(to []byte) []byte {
	to = append(to, switch_+space...)
	to = append(Paren{Inner: s.Expr}.Append(to), space...)
	return Block{Inner: s.Cases}.Append(to)
}

// Table renders Flat as rows of Cols cells in row-major order. Every
// column except the last is padded with spaces to the width of its
// widest cell plus one; the last cell of each row is written as-is and
// followed by a newline unless it already ends with one.
//
// Cols <= 0 renders nothing. If len(Flat) is not a multiple of Cols
// the trailing partial row is still rendered in column order, but it
// is not newline-terminated because it has no last-column cell.
type Table struct {
	Flat []Gen
	Cols int
}

// Append writes the aligned table onto to.
func (t Table) Append(to []byte) []byte {
	last := t.Cols - 1
	if last < 0 {
		return to
	}
	var (
		text  []byte                         // rendered padded cells, back to back
		sizes = make([]int, 0, len(t.Flat)) // length of each padded cell in text
		maxes = make([]int, last)           // widest cell per padded column
	)
	// Pass 1: render every padded cell once and measure it.
	col := 0
	for _, gen := range t.Flat {
		if col == last {
			col = 0
			continue
		}
		was := len(text)
		if gen != nil {
			text = gen.Append(text)
		}
		size := len(text) - was
		sizes = append(sizes, size)
		if maxes[col] < size {
			maxes[col] = size
		}
		col++
	}
	most := 0
	for _, width := range maxes {
		if most < width {
			most = width
		}
	}
	sp, nl := space[0], newline[0]
	spaces := make([]byte, most+1)
	for i := range spaces {
		spaces[i] = sp
	}
	// Pass 2 must start at column 0 again. Pass 1 leaves col mid-row
	// when Flat ends with a partial row, and carrying that over would
	// misalign every cell.
	col = 0
	for _, gen := range t.Flat {
		if col == last {
			was := len(to)
			if gen != nil {
				to = gen.Append(to)
			}
			now := len(to)
			if was >= now || to[now-1] != nl {
				to = append(to, nl)
			}
			col = 0
			continue
		}
		size := sizes[0]
		sizes = sizes[1:]
		to = append(to, text[:size]...)
		text = text[size:]
		fill := maxes[col] - size + 1
		to = append(to, spaces[:fill]...)
		col++
	}
	return to
}

// Ternary renders "Cond ? Then : Else".
type Ternary struct {
	Cond, Then, Else Gen
}

// Append writes the conditional expression onto to.
func (t Ternary) Append(to []byte) []byte {
	const (
		query = space + questionMark + space
		alt   = space + colon + space
	)
	to = append(t.Cond.Append(to), query...)
	to = append(t.Then.Append(to), alt...)
	return t.Else.Append(to)
}

// Typedef renders "typedef Type What;" followed by a newline.
type Typedef struct {
	Type, What Gen
}

// Append writes the typedef line onto to.
func (t Typedef) Append(to []byte) []byte {
	decl := Var{Type: t.Type, What: t.What}
	to = decl.Append(append(to, typedef+space...))
	return append(to, newline...)
}

// TypedefPtrFunc renders a function-pointer typedef:
// "typedef ReturnType (*What)(Params);".
type TypedefPtrFunc struct {
	ReturnType, What, Params Gen
}

// Append writes the typedef line onto to.
func (t TypedefPtrFunc) Append(to []byte) []byte {
	declarator := Call{
		Func: Paren{Inner: At{Expr: t.What}},
		Args: t.Params,
	}
	return Typedef{Type: t.ReturnType, What: declarator}.Append(to)
}

// Unlikely renders __builtin_expect(Cond, 0).
type Unlikely struct {
	Cond Gen
}

// Append writes the branch-hint call onto to.
func (u Unlikely) Append(to []byte) []byte {
	hint := Call{
		Func: Vb(expect),
		Args: CommaSpaced{u.Cond, Zero},
	}
	return hint.Append(to)
}

// Var renders a declaration "Type What;" or, with a non-nil Init,
// "Type What = Init;".
type Var struct {
	Type, What, Init Gen
}

// Append writes the declaration, including its semicolon, onto to.
func (v Var) Append(to []byte) []byte {
	to = append(v.Type.Append(to), space...)
	to = v.What.Append(to)
	if v.Init != nil {
		to = v.Init.Append(append(to, space+assign+space...))
	}
	return append(to, semicolon...)
}

// Vb renders its text verbatim.
type Vb string

// Append writes the text onto to unchanged.
func (v Vb) Append(to []byte) []byte {
	return append(to, v...)
}

// Xor renders the bitwise Expr1^Expr2 with no surrounding spaces.
type Xor struct {
	Expr1, Expr2 Gen
}

// Append writes the expression onto to.
func (x Xor) Append(to []byte) []byte {
	lhs := x.Expr1.Append(to)
	return x.Expr2.Append(append(lhs, caret...))
}

// XorAssign renders "Expr1 ^= Expr2".
type XorAssign struct {
	Expr1, Expr2 Gen
}

// Append writes the compound assignment onto to.
func (x XorAssign) Append(to []byte) []byte {
	const op = space + caret + assign + space
	to = append(x.Expr1.Append(to), op...)
	return x.Expr2.Append(to)
}

// Ready-made generators for commonly used C identifiers, types, and
// small expressions. Most wrap one of the string constants above in a
// Vb; the rest are composed from other entries (for example PtrChar is
// Ptr{Char} and BitsPerLong is (sizeof(long)*8)).
var (
BitsPerByte Gen = IntLit(8)
BitsPerLong Gen = Paren{Mul{Sizeof{Long}, BitsPerByte}}
Brace1 Gen = Vb(brace1)
Brace2 Gen = Vb(brace2)
Break Gen = Vb(break_)
Calloc Gen = Vb(calloc)
Char Gen = Vb(char)
Continue Gen = Vb(continue_)
Cplusplus Gen = Vb(cplusplus)
CpuSupports Gen = Vb(cpuSupports)
Ctzl Gen = Vb(ctzl)
Ellipsis Gen = Vb(ellipsis)
Errno Gen = Vb(errno)
Float Gen = Vb(float)
Free Gen = Vb(free)
Gap Gen = Vb(gap)
Int64T Gen = Vb(int64T)
Int Gen = Vb(int_)
LineNum Gen = Vb(lineNum)
LinkageC Gen = DoubleQuoted(linkageC)
Long Gen = Vb(long)
Malloc Gen = Vb(malloc)
Memcpy Gen = Vb(memcpy)
Memset Gen = Vb(memset)
NegOne Gen = Neg{One}
Newline Gen = Vb(newline)
Once Gen = Vb(once)
One Gen = Vb(one)
Packed Gen = Vb(packed)
PragmaOnce Gen = Preprocessor{Pragma, Once}
PthreadCondDestroy Gen = Vb(pthreadCondDestroy)
PthreadCondInit Gen = Vb(pthreadCondInit)
PthreadCondSignal Gen = Vb(pthreadCondSignal)
PthreadCondT Gen = Vb(pthreadCondT)
PthreadCondWait Gen = Vb(pthreadCondWait)
PthreadCreate Gen = Vb(pthreadCreate)
PthreadJoin Gen = Vb(pthreadJoin)
PthreadMutexDestroy Gen = Vb(pthreadMutexDestroy)
PthreadMutexInit Gen = Vb(pthreadMutexInit)
PthreadMutexLock Gen = Vb(pthreadMutexLock)
PthreadMutexT Gen = Vb(pthreadMutexT)
PthreadMutexUnlock Gen = Vb(pthreadMutexUnlock)
PthreadT Gen = Vb(pthreadT)
PtrChar Gen = Ptr{Char}
PtrdiffT Gen = Vb(ptrdiffT)
PtrFloat Gen = Ptr{Float}
PtrInt64T Gen = Ptr{Int64T}
PtrPthreadT Gen = Ptr{PthreadT}
PtrPtrChar Gen = Ptr{PtrChar}
PtrPtrVoid Gen = Ptr{PtrVoid}
PtrVoid Gen = Ptr{Void}
RestrictPtrChar Gen = RestrictPtr{Char}
RestrictPtrFloat Gen = RestrictPtr{Float}
RestrictPtrInt64T Gen = RestrictPtr{Int64T}
SizeT Gen = Vb(sizeT)
Sprintf Gen = Vb(sprintf)
VaEnd Gen = Vb(vaEnd)
VaList Gen = Vb(vaList)
VaStart Gen = Vb(vaStart)
Void Gen = Vb(void)
Vsnprintf Gen = Vb(vsnprintf)
Zero Gen = Vb(zero)
Zeros Gen = Brace{Zero}
)

// Linkage1 renders the opening half of a C-linkage guard:
//
//	#ifdef __cplusplus
//	extern "C" { /**/
//	#endif
var Linkage1 Gen = Gens{
Preprocessor{Ifdef, Cplusplus},
Extern{Spaced{LinkageC, Brace1, Gap}}, Newline,
Preprocessor{Endif, nil},
}

// Linkage2 renders the closing half of a C-linkage guard:
//
//	#ifdef __cplusplus
//	/**/ }
//	#endif
var Linkage2 Gen = Gens{
Preprocessor{Ifdef, Cplusplus},
Spaced{Gap, Brace2}, Newline,
Preprocessor{Endif, nil},
}

Top || internal/compile/author/cov/cov.go

package cov

// Rect picks tile dimensions s1 x s2 with s1*s2 == area and s1 <= max1
// for covering an n1 x n2 grid. Every factorization of area is scored
// by the unused tile area summed over the partial tiles along the
// edges; the lowest score wins and ties go to the smaller s1. If no
// factorization qualifies (including area < 1) both results are zero.
func Rect(area, max1, n1, n2 int) (s1, s2 int) {
	least := -1 // -1 means no candidate has been accepted yet
	try := func(t1, t2 int) {
		if t1 > max1 {
			return
		}
		whole1, part1 := n1/t1, n1%t1
		whole2, part2 := n2/t2, n2%t2
		waste := 0
		if part1 > 0 {
			// Partial tiles along the first dimension's edge.
			waste += (area - part1*t2) * whole2
		}
		if part1 > 0 && part2 > 0 {
			// The corner tile, partial in both dimensions.
			waste += area - part1*part2
		}
		if part2 > 0 {
			// Partial tiles along the second dimension's edge.
			waste += (area - t1*part2) * whole1
		}
		switch {
		case least == -1, least > waste, least == waste && s1 > t1:
			least = waste
			s1, s2 = t1, t2
		}
	}
	for lo := 1; lo*lo <= area; lo++ {
		if area%lo != 0 {
			continue
		}
		hi := area / lo
		try(lo, hi)
		if hi != lo {
			try(hi, lo)
		}
	}
	return s1, s2
}

// Box picks tile dimensions s1 x s2 x s3 with s1*s2*s3 == vol and
// s1 <= max1 for covering an n1 x n2 x n3 grid. Every ordered
// factorization of vol is scored by the volume of the grid after each
// dimension is rounded up to a multiple of its tile size; the lowest
// score wins, with ties going to the smaller s1 and then the smaller
// s2. If no factorization qualifies (including vol < 1) all results
// are zero.
func Box(vol, max1, n1, n2, n3 int) (s1, s2, s3 int) {
	roundUp := func(n, tile int) int {
		return (n + tile - 1) / tile * tile
	}
	least := -1 // -1 means no candidate has been accepted yet
	try := func(t1, t2, t3 int) {
		if t1 > max1 {
			return
		}
		padded := roundUp(n1, t1) * roundUp(n2, t2) * roundUp(n3, t3)
		better := least == -1 || least > padded
		if !better && least == padded {
			better = s1 > t1 || s1 == t1 && s2 > t2
		}
		if better {
			least = padded
			s1, s2, s3 = t1, t2, t3
		}
	}
	// Enumerate lo <= md <= hi with lo*md*hi == vol, then try each
	// distinct ordering of the triple.
	for lo := 1; lo*lo*lo <= vol; lo++ {
		if vol%lo != 0 {
			continue
		}
		area := vol / lo
		for md := lo; md*md <= area; md++ {
			if area%md != 0 {
				continue
			}
			hi := area / md
			try(lo, md, hi)
			if lo != md {
				try(md, lo, hi)
				try(md, hi, lo)
			}
			if md != hi {
				try(lo, hi, md)
				try(hi, lo, md)
				if lo != md {
					try(hi, md, lo)
				}
			}
		}
	}
	return s1, s2, s3
}

Top || internal/compile/author/cpu/cpu.go

package cpu

import (
"NN-512/internal/compile/author/cgen"
"NN-512/internal/compile/author/errmsg"
"NN-512/internal/raw"
)

// Chk emits a runtime check that the CPU supports the instruction set
// required by Platform, reporting failure through Emc.
type Chk struct {
	Platform raw.Platform
	Emc      *errmsg.Ctx
}

// Append dispatches on the platform; only AVX-512 float32 is supported
// and anything else is a programming error.
func (c *Chk) Append(to []byte) []byte {
	if c.Platform != raw.AVX512Float32 {
		panic("bug")
	}
	return c.m512().Append(to)
}

// m512 builds the check: if __builtin_cpu_supports("avx512f") is zero,
// an error with a fixed message is produced via errmsg.FormatIf.
func (c *Chk) m512() cgen.Gen {
	unsupported := cgen.IsZero{
		Expr: cgen.Call{
			Func: cgen.CpuSupports,
			Args: cgen.DoubleQuoted("avx512f"),
		},
	}
	return &errmsg.FormatIf{
		Ctx:    c.Emc,
		Cond:   unsupported,
		Format: "CPU does not support AVX512F",
	}
}

Top || internal/compile/author/elwi/elwi.go

package elwi

import (
"NN-512/internal/compile/author/act"
"NN-512/internal/compile/author/avx"
"NN-512/internal/compile/author/bn"
"NN-512/internal/compile/author/cgen"
"NN-512/internal/compile/author/cov"
"NN-512/internal/compile/author/mod"
"NN-512/internal/compile/author/threader"
"NN-512/internal/compile/plan"
"NN-512/internal/nmsrc"
"NN-512/internal/raw"
"fmt"
)

// Ctx holds the state shared by all code generated by this package:
// the function-name prefix, the target platform, the name source, the
// threader/activation/batch-norm contexts, and a map from spec
// signature to the name of the function already generated for it.
type Ctx struct {
	prefix   string
	platform raw.Platform
	nms      nmsrc.Src
	tc       *threader.Ctx
	ac       *act.Ctx
	bc       *bn.Ctx
	dedup    map[string]string
}

// NewCtx builds a Ctx from the plan's configuration and the given
// collaborators. The prefix is the configured prefix plus "Elwi".
func NewCtx(pl *plan.Plan, nms nmsrc.Src, tc *threader.Ctx, ac *act.Ctx, bc *bn.Ctx) *Ctx {
	c := new(Ctx)
	c.prefix = pl.Config.Prefix + "Elwi"
	c.platform = pl.Config.Platform
	c.nms, c.tc, c.ac, c.bc = nms, tc, ac, bc
	c.dedup = map[string]string{}
	return c
}

// name asks the name source for an identifier based on s.
func (c *Ctx) name(s string) string {
	ident := c.nms.Name(s)
	return ident
}

// lanes reports the number of elements per vector for the platform;
// an unknown platform is a programming error.
func (c *Ctx) lanes() int {
	if c.platform != raw.AVX512Float32 {
		panic("bug")
	}
	return 16
}

// Spec describes one generated function: the Channels x Height x Width
// extent, the element size in bytes, and per-tensor pitches. Call.Prep
// compares Pitch1Bytes[i] against Width*ElemBytes and Pitch2Bytes[i]
// against Height*Width*ElemBytes to decide how tightly tensor i is
// packed, so the two slices are indexed in parallel.
// NOTE(review): Ops appears to hold one operation list per tensor or
// stage; confirm against the code that consumes it.
type Spec struct {
Channels int
Height int
Width int
ElemBytes int
Pitch1Bytes []int
Pitch2Bytes []int
Ops [][]mod.Op
}

// enough returns an element count used as the minimum size of one work
// tile: lanes() times a multiplier of 512 divided by the number of
// tensors, but never less than 8. An unknown platform is a programming
// error.
func enough(ctx *Ctx, spec *Spec) int {
	if ctx.platform != raw.AVX512Float32 {
		panic("bug")
	}
	const floor = 8
	mul := 512 / len(spec.Pitch1Bytes)
	if mul < floor {
		mul = floor
	}
	return ctx.lanes() * mul
}

// vb wraps s as a verbatim generator.
func vb(s string) cgen.Gen {
	var g cgen.Gen = cgen.Vb(s)
	return g
}

// il wraps i as an integer-literal generator.
func il(i int) cgen.Gen {
	var g cgen.Gen = cgen.IntLit(i)
	return g
}

// cast renders pitch as a literal cast to ptrdiff_t.
func cast(pitch int) cgen.Gen {
	lit := il(pitch)
	return cgen.Cast{Type: cgen.PtrdiffT, Expr: lit}
}

// addr renders ptr+pitch*idx.
func addr(ptr, pitch, idx cgen.Gen) cgen.Gen {
	offset := cgen.Mul{Expr1: pitch, Expr2: idx}
	return cgen.Add{Expr1: ptr, Expr2: offset}
}

// mix interleaves several statement lists round-robin: the first
// statement of each list, then the second of each, and so on, skipping
// lists that have run out. A single list is returned as-is.
func mix(a []cgen.Stmts) cgen.Stmts {
	if len(a) == 1 {
		return a[0]
	}
	total := 0
	for _, stmts := range a {
		total += len(stmts)
	}
	out := make(cgen.Stmts, 0, total)
	for row := 0; len(out) < total; row++ {
		for _, stmts := range a {
			if row < len(stmts) {
				out = append(out, stmts[row])
			}
		}
	}
	return out
}

// Call generates one invocation of a function described by Spec. Prep
// must run before Append: it fills in funcName, either by reusing the
// function already generated for an identical Spec or by creating one.
type Call struct {
*Ctx
*Spec
// Team and Tensors are passed to the generated function: Team as
// the first argument, Tensors as the elements of a char* array.
Team cgen.Gen
Tensors []cgen.Gen
// funcName is the name of the function to call; set by Prep.
funcName string
}

// Prep resolves the function this call targets and returns its
// definition, or nil if an identical Spec was already handled (in
// which case the earlier function name is reused).
//
// The implementation is chosen by how tightly the tensors are packed:
// any row pitch that differs from Width*ElemBytes selects the unpacked
// form; otherwise any plane pitch that differs from
// Height*Width*ElemBytes, or the presence of a batch-norm op, selects
// the semipacked form; otherwise the packed form is used.
func (c *Call) Prep() cgen.Gen {
	key := fmt.Sprintf("%v", c.Spec)
	if name, seen := c.dedup[key]; seen {
		c.funcName = name
		return nil
	}
	c.funcName = c.name(c.prefix)
	c.dedup[key] = c.funcName
	var (
		rowBytes   = c.Width * c.ElemBytes
		planeBytes = c.Height * rowBytes
		tightRows  = true
		tightPlane = true
	)
	for i, pitch1 := range c.Pitch1Bytes {
		if pitch1 != rowBytes {
			tightRows = false
			break
		}
		if c.Pitch2Bytes[i] != planeBytes {
			tightPlane = false
		}
	}
	hasBn := func() bool {
		for _, ops := range c.Ops {
			for i := range ops {
				if ops[i].Kind == mod.Bn {
					return true
				}
			}
		}
		return false
	}
	var funcs cgen.Gen
	switch {
	case !tightRows:
		funcs = &unpacked{
			Ctx:      c.Ctx,
			Spec:     c.Spec,
			FuncName: c.funcName,
		}
	case !tightPlane || hasBn():
		funcs = &semipacked{
			Ctx:      c.Ctx,
			Spec:     c.Spec,
			FuncName: c.funcName,
		}
	default:
		funcs = &packed{
			Ctx:      c.Ctx,
			Spec:     c.Spec,
			FuncName: c.funcName,
		}
	}
	return cgen.Gens{
		funcs, cgen.Newline,
	}
}

// Append emits the call site: a local char* array initialized with the
// tensor pointers, then a call to the prepared function with the team
// and that array.
func (c *Call) Append(to []byte) []byte {
	arr := vb(c.name("tensors"))
	decl := cgen.Var{
		Type: cgen.PtrChar,
		What: cgen.Elem{Arr: arr},
		Init: cgen.Brace{Inner: cgen.CommaLines(c.Tensors)},
	}
	invoke := cgen.Call{
		Func: vb(c.funcName),
		Args: cgen.CommaSpaced{c.Team, arr},
	}
	return cgen.Stmts{decl, invoke}.Append(to)
}

// unpacked generates the function for a Spec whose tensors are not
// tightly packed. FuncName is the name of the entry point to define.
type unpacked struct {
*Ctx
*Spec
FuncName string
// Tiling along width, height, and channels, computed by Append:
// xTile is the size of a regular tile, xTiles the number of
// regular tiles, and xScrap the size of one extra final tile
// (zero if there is none).
wTile int
wTiles int
wScrap int
hTile int
hTiles int
hScrap int
cTile int
cTiles int
cScrap int
// funcName is the name of the per-task callee; set by Append.
funcName string
datPtrs []cgen.Gen
bnPtrs []cgen.Gen
}

// Append chooses a tiling, then emits the per-task callee followed by
// the static entry point that runs the callee over the tile grid via
// threader.Do.
//
// The tiling splits along exactly one dimension, chosen so that one
// tile covers at least enough() elements where possible: along the
// width if a single row is big enough, else along the height if a
// single plane is big enough, else along the channels.
func (u *unpacked) Append(to []byte) []byte {
	var (
		want  = enough(u.Ctx, u.Spec)
		plane = u.Height * u.Width
	)
	switch {
	case u.Width >= want:
		var (
			fit  = u.Width / want
			tile = u.Width / fit
		)
		// Keep regular tiles a whole number of vectors wide.
		tile -= tile % u.lanes()
		u.wTile, u.wTiles, u.wScrap = tile, fit, u.Width-tile*fit
		if u.wScrap > 0 {
			// Merge the leftover into one wider final tile.
			u.wTiles--
			u.wScrap += tile
		}
		u.hTile, u.hTiles, u.hScrap = 1, u.Height, 0
		u.cTile, u.cTiles, u.cScrap = 1, u.Channels, 0
	case plane >= want:
		rows := (want + u.Width - 1) / u.Width
		fit := u.Height / rows
		rows = u.Height / fit
		u.wTile, u.wTiles, u.wScrap = u.Width, 1, 0
		u.hTile, u.hTiles, u.hScrap = rows, fit, u.Height-rows*fit
		if u.hScrap > 0 {
			// Merge the leftover into one taller final tile.
			u.hTiles--
			u.hScrap += rows
		}
		u.cTile, u.cTiles, u.cScrap = 1, u.Channels, 0
	default:
		chans := (want + plane - 1) / plane
		u.wTile, u.wTiles, u.wScrap = u.Width, 1, 0
		u.hTile, u.hTiles, u.hScrap = u.Height, 1, 0
		u.cTile, u.cTiles, u.cScrap = chans, u.Channels/chans, u.Channels%chans
	}
	u.funcName = u.name(u.FuncName + "Callee")
	var (
		team    = vb(u.name("team"))
		tensors = vb(u.name("tensors"))
	)
	// Grid extent along one dimension: regular tiles plus the scrap
	// tile, if any.
	hull := func(tiles, scrap int) cgen.Gen {
		if scrap > 0 {
			tiles++
		}
		return il(tiles)
	}
	callee := u.calleeFunc()
	entry := cgen.StaticFuncDef{
		ReturnType: cgen.Void,
		Name:       u.FuncName,
		Params: cgen.CommaSpaced{
			cgen.Param{
				Type: u.tc.PtrTeam,
				What: team,
			},
			cgen.Param{
				Type: cgen.PtrPtrChar,
				What: tensors,
			},
		},
		Body: &threader.Do{
			Ctx:    u.tc,
			Callee: vb(u.funcName),
			Any:    tensors,
			Hull: []cgen.Gen{
				hull(u.wTiles, u.wScrap),
				hull(u.hTiles, u.hScrap),
				hull(u.cTiles, u.cScrap),
			},
			Team: team,
		},
	}
	return cgen.Gens{
		callee,
		cgen.Newline,
		entry,
	}.Append(to)
}

// calleeFunc builds the per-task function passed to the threader. The
// generated body reads the tensor array and the task's tile coordinates
// (w, h, c) from the callee's point, declares the tile pointers via ptrs,
// and then emits one kernel per distinct tile shape. Regular tiles are
// guarded by an index comparison and return early; scrap tiles are
// emitted afterwards as the fall-through case.
//
// NOTE(review): the order of the u.name and u.kernel calls presumably
// determines the generated identifier names; keep it stable when editing.
func (u *unpacked) calleeFunc() cgen.Gen {
var (
// Seven slots: 0-4 are always filled; 5 (regular channel tiles) and
// 6 (channel scrap) are left unset when they do not apply.
body = make(cgen.Stmts, 7)
tensors = vb(u.name("tensors"))
w = vb(u.name("w"))
h = vb(u.name("h"))
c = vb(u.name("c"))
)
callee := &threader.Callee{
Ctx: u.tc,
Name: u.funcName,
Task: vb(u.name("task")),
Pt: vb(u.name("pt")),
}
body[0] = cgen.Var{
Type: cgen.PtrPtrChar, What: tensors,
Init: callee.Any(),
}
// Point components 0, 1, 2 are the width, height, and channel tile
// indexes, matching the Hull order used by Append.
body[1] = cgen.Var{
Type: cgen.PtrdiffT, What: w,
Init: cgen.Elem{Arr: callee.Pt, Idx: il(0)},
}
body[2] = cgen.Var{
Type: cgen.PtrdiffT, What: h,
Init: cgen.Elem{Arr: callee.Pt, Idx: il(1)},
}
body[3] = cgen.Var{
Type: cgen.PtrdiffT, What: c,
Init: cgen.Elem{Arr: callee.Pt, Idx: il(2)},
}
body[4] = u.ptrs(tensors, w, h, c)
// doIf wraps `do` in a conditional on i versus n (cgen.CmpL, presumably
// i < n) and ends the branch with a return.
doIf := func(do, i cgen.Gen, n int) cgen.Gen {
return cgen.If{
Cond: cgen.CmpL{
Expr1: i,
Expr2: il(n),
},
Then: cgen.Stmts{
do,
cgen.Return{},
},
}
}
// wSplit emits the kernel(s) for the width dimension given a fixed
// channel and row count: regular width tiles first, then the scrap.
wSplit := func(chans, rows int) cgen.Gen {
stmts := make(cgen.Stmts, 2)
if u.wTiles > 0 {
k := u.kernel(chans, rows, u.wTile)
if u.wScrap > 0 {
stmts[0] = doIf(k, w, u.wTiles)
} else {
stmts[0] = k
}
}
if u.wScrap > 0 {
stmts[1] = u.kernel(chans, rows, u.wScrap)
}
return stmts
}
// hSplit does the same for the height dimension, delegating each case
// to wSplit.
hSplit := func(chans int) cgen.Gen {
stmts := make(cgen.Stmts, 2)
if u.hTiles > 0 {
ws := wSplit(chans, u.hTile)
if u.hScrap > 0 {
stmts[0] = doIf(ws, h, u.hTiles)
} else {
stmts[0] = ws
}
}
if u.hScrap > 0 {
stmts[1] = wSplit(chans, u.hScrap)
}
return stmts
}
// Outermost split: regular channel tiles, then the channel scrap.
if u.cTiles > 0 {
hs := hSplit(u.cTile)
if u.cScrap > 0 {
body[5] = doIf(hs, c, u.cTiles)
} else {
body[5] = hs
}
}
if u.cScrap > 0 {
body[6] = hSplit(u.cScrap)
}
return callee.Func(body)
}

// ptrs declares one restrict char pointer per tensor used by the kernel
// and records the variables in u.datPtrs (data) and u.bnPtrs (bn
// parameters). Entries of the `tensors` array are consumed in order.
//
// Data pointers are advanced to the start of the task's tile using the
// tile indexes w, h, c and the per-tensor pitches. The helper addr is
// defined outside this view; presumably it yields base + pitch*index.
func (u *unpacked) ptrs(tensors, w, h, c cgen.Gen) cgen.Gen {
var (
stmts cgen.Stmts
tensorIdx = 0 // next unread slot of the tensors array
datPtrIdx = 0 // next unread entry of Pitch1Bytes/Pitch2Bytes
)
// tensor returns the next element of the tensors array.
tensor := func() cgen.Gen {
i := tensorIdx
tensorIdx++
return cgen.Elem{
Arr: tensors,
Idx: il(i),
}
}
// datPtr declares the next data pointer, offset by the tile position
// along width (element size), height (row pitch), and channel
// (channel pitch).
datPtr := func() {
ptr := vb(u.name("ptr"))
u.datPtrs = append(u.datPtrs, ptr)
var (
wPitch = u.wTile * u.ElemBytes
hPitch = u.hTile * u.Pitch1Bytes[datPtrIdx]
cPitch = u.cTile * u.Pitch2Bytes[datPtrIdx]
a1 = tensor()
a2 = addr(a1, cast(wPitch), w)
a3 = addr(a2, cast(hPitch), h)
a4 = addr(a3, cast(cPitch), c)
)
stmts = append(stmts, cgen.Var{
Type: cgen.RestrictPtrChar,
What: ptr, Init: a4,
})
datPtrIdx++
}
// ndp declares n data pointers (nothing when n <= 0).
ndp := func(n int) {
for ; n > 0; n-- {
datPtr()
}
}
// bnPtr declares the next bn pointer, offset to the tile's first
// channel (cTile * c).
bnPtr := func() {
ptr := vb(u.name("ptr"))
u.bnPtrs = append(u.bnPtrs, ptr)
stmts = append(stmts, cgen.Var{
Type: cgen.RestrictPtrChar,
What: ptr,
Init: &bn.Offset{
Ctx: u.bc,
Mas: tensor(),
Channel: cgen.Mul{
Expr1: il(u.cTile),
Expr2: c,
},
},
})
}
// Every op list except the last starts with its own data pointer.
// Within a list, Add consumes op.Int further data pointers, Bn
// consumes one bn pointer, and ReLU consumes none.
for i, ops := range u.Ops {
if i < len(u.Ops)-1 {
datPtr()
}
for j := range ops {
switch op := &ops[j]; op.Kind {
case mod.Add:
ndp(op.Int)
case mod.Bn:
bnPtr()
case mod.ReLU:
default:
panic("bug")
}
}
}
// Whatever pitches remain become data pointers too; m512Cell.Stmts
// stores the result through the pointers left over after all loads.
ndp(len(u.Pitch1Bytes) - datPtrIdx)
return stmts
}

// kernel emits the code for one tile of `chans` channels by `rows` rows
// by `elems` elements per row. AVX512Float32 is the only platform
// handled; any other value is treated as a programming error.
func (u *unpacked) kernel(chans, rows, elems int) cgen.Gen {
	if u.platform == raw.AVX512Float32 {
		return u.m512(chans, rows, elems)
	}
	panic("bug")
}

// m512 emits the AVX-512 kernel for a tile of `chans` channels, `rows`
// rows, and `elems` elements per row. The generated code is a three-level
// loop nest (i over channels, j over rows, k over 16-lane vectors), each
// level unrolled by a factor chosen by cov.Box and followed by
// straight-line code for the iterations that do not fill an unrolled step.
//
// NOTE(review): cov.Box, mix, addr, and cast are defined outside this
// view; their exact semantics are assumed from how they are used here.
func (u *unpacked) m512(chans, rows, elems int) cgen.Gen {
const (
lanes = 16
laneBytes = 4
)
// The unroll budget shrinks as the number of data pointers grows, with
// a floor of 1.
unroll := 6 - len(u.datPtrs)
if unroll < 1 {
unroll = 1
}
iUnroll, jUnroll, kUnroll := cov.Box(
unroll, unroll, chans, rows, (elems+lanes-1)/lanes,
)
var (
iIters = chans / iUnroll
iAfter = chans % iUnroll
// bnMuls/bnAdds hold, per unrolled channel, the bn variables loaded
// by iBlock; leaf reads them when it builds a cell.
bnMuls = make([][]cgen.Gen, iUnroll)
bnAdds = make([][]cgen.Gen, iUnroll)
jIters = rows / jUnroll
jAfter = rows % jUnroll
kIters = elems / (kUnroll * lanes)
kAfter = elems % (kUnroll * lanes)
)
// leaf emits one vector cell. i, j, k are the loop variables (or the
// literal iteration counts in tail code); ii, jj, kk are the positions
// inside the unrolled step; l is the number of active lanes.
leaf := func(i, j, k cgen.Gen, ii, jj, kk, l int) cgen.Stmts {
cell := &m512Cell{
Ctx: u.Ctx,
Spec: u.Spec,
Lanes: l,
Ptrs: make([]cgen.Gen, len(u.datPtrs)),
BnMuls: bnMuls[ii],
BnAdds: bnAdds[ii],
}
for x, ptr := range u.datPtrs {
var (
iiPitch = u.Pitch2Bytes[x]
jjPitch = u.Pitch1Bytes[x]
kkPitch = lanes * laneBytes
iPitch = iiPitch * iUnroll
jPitch = jjPitch * jUnroll
kPitch = kkPitch * kUnroll
)
// Constant offset of this cell inside the unrolled step.
ptr = cgen.Add{
Expr1: ptr,
Expr2: cast(iiPitch*ii + jjPitch*jj + kkPitch*kk),
}
// The loop-dependent offset is added only for dimensions that
// have at least one full unrolled iteration.
if iIters > 0 {
ptr = addr(ptr, cast(iPitch), i)
}
if jIters > 0 {
ptr = addr(ptr, cast(jPitch), j)
}
if kIters > 0 {
ptr = addr(ptr, cast(kPitch), k)
}
cell.Ptrs[x] = ptr
}
return cell.Stmts()
}
// kSplit emits the innermost (vector) dimension for iCnt channels and
// jCnt rows: a loop over full unrolled steps, then the tail.
kSplit := func(i, j cgen.Gen, iCnt, jCnt int) cgen.Gen {
stmts := make(cgen.Stmts, 2)
if kIters > 0 {
var (
body = make([]cgen.Stmts, iCnt*jCnt*kUnroll)
k = vb(u.name("k"))
)
for ii := 0; ii < iCnt; ii++ {
for jj := 0; jj < jCnt; jj++ {
for kk := 0; kk < kUnroll; kk++ {
x := (ii*jCnt+jj)*kUnroll + kk
body[x] = leaf(i, j, k, ii, jj, kk, lanes)
}
}
}
stmts[0] = cgen.For{
Init: cgen.Var{
Type: cgen.PtrdiffT, What: k,
Init: cgen.Zero,
},
Cond: cgen.CmpL{
Expr1: k, Expr2: il(kIters),
},
Post: cgen.IncPre{Expr: k},
Body: mix(body),
}
}
if kAfter > 0 {
var (
// The tail holds `full` whole vectors followed by one partial
// vector of `part` lanes (skipped when part is 0).
full = kAfter / lanes
part = kAfter % lanes
tail = make([]cgen.Stmts, iCnt*jCnt*kUnroll)
k = il(kIters)
)
for ii := 0; ii < iCnt; ii++ {
for jj := 0; jj < jCnt; jj++ {
for kk := 0; kk <= full; kk++ {
var (
x = (ii*jCnt+jj)*kUnroll + kk
l = lanes
)
if kk == full {
l = part
}
if l > 0 {
tail[x] = leaf(i, j, k, ii, jj, kk, l)
}
}
}
}
stmts[1] = mix(tail)
}
return stmts
}
// jSplit emits the row dimension: a loop over full unrolled row steps,
// then the leftover rows with j fixed at the literal jIters.
jSplit := func(i cgen.Gen, iCnt int) cgen.Gen {
stmts := make(cgen.Stmts, 2)
if jIters > 0 {
j := vb(u.name("j"))
stmts[0] = cgen.For{
Init: cgen.Var{
Type: cgen.PtrdiffT, What: j,
Init: cgen.Zero,
},
Cond: cgen.CmpL{
Expr1: j, Expr2: il(jIters),
},
Post: cgen.IncPre{Expr: j},
Body: kSplit(i, j, iCnt, jUnroll),
}
}
if jAfter > 0 {
j := il(jIters)
stmts[1] = kSplit(i, j, iCnt, jAfter)
}
return stmts
}
// iBlock emits the body for iCnt channels: first the bn loads for each
// channel (one per bn pointer), then the row/vector code. It fills
// bnMuls/bnAdds before calling jSplit so that leaf sees them.
iBlock := func(i cgen.Gen, iCnt int) cgen.Gen {
var (
bnLds = make([]cgen.Stmts, iCnt)
bnCnt = len(u.bnPtrs)
)
for ii := 0; ii < iCnt; ii++ {
var (
ch = il(ii)
muls = make([]cgen.Gen, bnCnt)
adds = make([]cgen.Gen, bnCnt)
lds = make(cgen.Stmts, bnCnt)
)
// With a channel loop, the channel index also depends on i.
if iIters > 0 {
ch = cgen.Paren{
Inner: addr(ch, cast(iUnroll), i),
}
}
for x, ptr := range u.bnPtrs {
var (
bnMul = vb(u.name("bnMul"))
bnAdd = vb(u.name("bnAdd"))
)
muls[x] = bnMul
adds[x] = bnAdd
lds[x] = &bn.Load{
Ctx: u.bc,
Mas: ptr,
Channel: ch,
Mul: bnMul,
Add: bnAdd,
}
}
bnMuls[ii] = muls
bnAdds[ii] = adds
bnLds[ii] = lds
}
return cgen.Gens{
mix(bnLds), jSplit(i, iCnt),
}
}
// Channel dimension: a loop over full unrolled channel steps, then the
// leftover channels with i fixed at the literal iIters.
stmts := make(cgen.Stmts, 2)
if iIters > 0 {
i := vb(u.name("i"))
stmts[0] = cgen.For{
Init: cgen.Var{
Type: cgen.PtrdiffT, What: i,
Init: cgen.Zero,
},
Cond: cgen.CmpL{
Expr1: i, Expr2: il(iIters),
},
Post: cgen.IncPre{Expr: i},
Body: iBlock(i, iUnroll),
}
}
if iAfter > 0 {
i := il(iIters)
stmts[1] = iBlock(i, iAfter)
}
return stmts
}

// semipacked generates the function for tensors addressed through the
// channel pitch (Pitch2Bytes) only; within a channel the Height*Width
// elements are treated as one contiguous run. Work is tiled along that
// run and along channels. The lowercase fields are filled in by Append
// and its helpers.
type semipacked struct {
*Ctx
*Spec
FuncName string // name of the emitted entry-point function
elemTile int // elements per regular element tile
elemTiles int // number of regular element tiles
elemScrap int // elements in the final odd-sized element tile (0 if none)
chanTile int // channels per regular channel tile
chanTiles int // number of regular channel tiles
chanScrap int // channels in the final odd-sized channel tile (0 if none)
funcName string // name of the emitted callee, derived from FuncName
datPtrs []cgen.Gen // data pointer variables, in declaration order (set by ptrs)
bnPtrs []cgen.Gen // bn pointer variables, in declaration order (set by ptrs)
}

// Append picks a tiling and emits two functions into `to`: the per-task
// callee, then the entry point (named FuncName) that hands the callee to
// the threader over an element-tile x channel-tile grid.
//
// The target task size comes from enough (defined outside this view). If
// one channel (Height*Width elements) holds at least that many elements,
// the channel is cut into element tiles whose size is a multiple of
// s.lanes(), with any leftover folded into the last tile. Otherwise whole
// channels are grouped, and the channel remainder becomes the scrap.
func (s *semipacked) Append(to []byte) []byte {
	area := s.Height * s.Width
	want := enough(s.Ctx, s.Spec)
	if area < want {
		// Round the channels-per-task count up.
		group := want / area
		if group*area < want {
			group++
		}
		s.elemTile, s.elemTiles, s.elemScrap = area, 1, 0
		s.chanTile = group
		s.chanTiles = s.Channels / group
		s.chanScrap = s.Channels % group
	} else {
		n := area / want
		tile := area / n
		tile -= tile % s.lanes()
		s.elemTile, s.elemTiles, s.elemScrap = tile, n, area-tile*n
		if s.elemScrap > 0 {
			s.elemTiles--
			s.elemScrap += tile
		}
		s.chanTile, s.chanTiles, s.chanScrap = 1, s.Channels, 0
	}
	s.funcName = s.name(s.FuncName + "Callee")
	team := vb(s.name("team"))
	tensors := vb(s.name("tensors"))
	// Each dimension contributes its regular tiles plus one more task
	// when a scrap tile exists.
	extent := func(tiles, scrap int) cgen.Gen {
		if scrap > 0 {
			tiles++
		}
		return il(tiles)
	}
	grid := []cgen.Gen{
		extent(s.elemTiles, s.elemScrap),
		extent(s.chanTiles, s.chanScrap),
	}
	return cgen.Gens{
		s.calleeFunc(),
		cgen.Newline,
		cgen.StaticFuncDef{
			ReturnType: cgen.Void,
			Name:       s.FuncName,
			Params: cgen.CommaSpaced{
				cgen.Param{Type: s.tc.PtrTeam, What: team},
				cgen.Param{Type: cgen.PtrPtrChar, What: tensors},
			},
			Body: &threader.Do{
				Ctx:    s.tc,
				Callee: vb(s.funcName),
				Any:    tensors,
				Hull:   grid,
				Team:   team,
			},
		},
	}.Append(to)
}

// calleeFunc builds the per-task function passed to the threader. The
// generated body reads the tensor array and the task's element-tile and
// channel-tile indexes from the callee's point, declares the tile
// pointers via ptrs, and then emits one kernel per distinct tile shape.
// Regular tiles are guarded by an index comparison and return early;
// scrap tiles follow as the fall-through case.
func (s *semipacked) calleeFunc() cgen.Gen {
	var (
		stmts   = make(cgen.Stmts, 6)
		tensors = vb(s.name("tensors"))
		elemIdx = vb(s.name("e"))
		chanIdx = vb(s.name("c"))
	)
	fn := &threader.Callee{
		Ctx:  s.tc,
		Name: s.funcName,
		Task: vb(s.name("task")),
		Pt:   vb(s.name("pt")),
	}
	stmts[0] = cgen.Var{
		Type: cgen.PtrPtrChar,
		What: tensors,
		Init: fn.Any(),
	}
	stmts[1] = cgen.Var{
		Type: cgen.PtrdiffT,
		What: elemIdx,
		Init: cgen.Elem{Arr: fn.Pt, Idx: cgen.Zero},
	}
	stmts[2] = cgen.Var{
		Type: cgen.PtrdiffT,
		What: chanIdx,
		Init: cgen.Elem{Arr: fn.Pt, Idx: cgen.One},
	}
	stmts[3] = s.ptrs(tensors, elemIdx, chanIdx)
	// early guards `code` with a comparison of idx against limit and
	// returns right after running it.
	early := func(code, idx cgen.Gen, limit int) cgen.Gen {
		return cgen.If{
			Cond: cgen.CmpL{Expr1: idx, Expr2: il(limit)},
			Then: cgen.Stmts{code, cgen.Return{}},
		}
	}
	// perChans emits the element-dimension kernels for a fixed channel
	// count: the regular tile first, then the scrap tile.
	perChans := func(chans int) cgen.Gen {
		pair := make(cgen.Stmts, 2)
		if s.elemTiles > 0 {
			regular := s.kernel(chans, s.elemTile)
			if s.elemScrap > 0 {
				regular = early(regular, elemIdx, s.elemTiles)
			}
			pair[0] = regular
		}
		if s.elemScrap > 0 {
			pair[1] = s.kernel(chans, s.elemScrap)
		}
		return pair
	}
	if s.chanTiles > 0 {
		regular := perChans(s.chanTile)
		if s.chanScrap > 0 {
			regular = early(regular, chanIdx, s.chanTiles)
		}
		stmts[4] = regular
	}
	if s.chanScrap > 0 {
		stmts[5] = perChans(s.chanScrap)
	}
	return fn.Func(stmts)
}

// ptrs declares one restrict char pointer per tensor used by the kernel
// and records the variables in s.datPtrs (data) and s.bnPtrs (bn
// parameters). Entries of the `tensors` array are consumed in order.
//
// Data pointers are advanced to the start of the task's tile using the
// tile indexes e and c. The helper addr is defined outside this view;
// presumably it yields base + pitch*index.
func (s *semipacked) ptrs(tensors, e, c cgen.Gen) cgen.Gen {
var (
stmts cgen.Stmts
tensorIdx = 0 // next unread slot of the tensors array
datPtrIdx = 0 // next unread entry of Pitch2Bytes
)
// tensor returns the next element of the tensors array.
tensor := func() cgen.Gen {
i := tensorIdx
tensorIdx++
return cgen.Elem{
Arr: tensors,
Idx: il(i),
}
}
// pitch returns the channel pitch of the next data tensor.
pitch := func() int {
i := datPtrIdx
datPtrIdx++
return s.Pitch2Bytes[i]
}
// datPtr declares the next data pointer, offset by the element tile
// (element size) and the channel tile (channel pitch).
datPtr := func() {
ptr := vb(s.name("ptr"))
s.datPtrs = append(s.datPtrs, ptr)
var (
ePitch = s.elemTile * s.ElemBytes
cPitch = s.chanTile * pitch()
a1 = tensor()
a2 = addr(a1, cast(ePitch), e)
a3 = addr(a2, cast(cPitch), c)
)
stmts = append(stmts, cgen.Var{
Type: cgen.RestrictPtrChar,
What: ptr, Init: a3,
})
}
// ndp declares n data pointers (nothing when n <= 0).
ndp := func(n int) {
for ; n > 0; n-- {
datPtr()
}
}
// bnPtr declares the next bn pointer, offset to the tile's first
// channel (chanTile * c).
bnPtr := func() {
ptr := vb(s.name("ptr"))
s.bnPtrs = append(s.bnPtrs, ptr)
stmts = append(stmts, cgen.Var{
Type: cgen.RestrictPtrChar,
What: ptr,
Init: &bn.Offset{
Ctx: s.bc,
Mas: tensor(),
Channel: cgen.Mul{
Expr1: il(s.chanTile),
Expr2: c,
},
},
})
}
// Every op list except the last starts with its own data pointer.
// Within a list, Add consumes op.Int further data pointers, Bn
// consumes one bn pointer, and ReLU consumes none.
for i, ops := range s.Ops {
if i < len(s.Ops)-1 {
datPtr()
}
for j := range ops {
switch op := &ops[j]; op.Kind {
case mod.Add:
ndp(op.Int)
case mod.Bn:
bnPtr()
case mod.ReLU:
default:
panic("bug")
}
}
}
// Whatever pitches remain become data pointers too; m512Cell.Stmts
// stores the result through the pointers left over after all loads.
ndp(len(s.Pitch2Bytes) - datPtrIdx)
return stmts
}

// kernel emits the code for one tile of `chans` channels by `elems`
// elements per channel. AVX512Float32 is the only platform handled; any
// other value is treated as a programming error.
func (s *semipacked) kernel(chans, elems int) cgen.Gen {
	if s.platform == raw.AVX512Float32 {
		return s.m512(chans, elems)
	}
	panic("bug")
}

// m512 emits the AVX-512 kernel for a tile of `chans` channels and
// `elems` elements per channel. The generated code is a two-level loop
// nest (i over channels, j over 16-lane vectors), each level unrolled by
// a factor chosen by cov.Rect and followed by straight-line code for the
// iterations that do not fill an unrolled step.
//
// NOTE(review): cov.Rect, mix, addr, and cast are defined outside this
// view; their exact semantics are assumed from how they are used here.
func (s *semipacked) m512(chans, elems int) cgen.Gen {
const (
lanes = 16
laneBytes = 4
)
// The unroll budget shrinks as the number of data pointers grows, with
// a floor of 1.
unroll := 6 - len(s.datPtrs)
if unroll < 1 {
unroll = 1
}
iUnroll, jUnroll := cov.Rect(
unroll, unroll, chans, (elems+lanes-1)/lanes,
)
var (
iIters = chans / iUnroll
iAfter = chans % iUnroll
// bnMuls/bnAdds hold, per unrolled channel, the bn variables loaded
// by outer; leaf reads them when it builds a cell.
bnMuls = make([][]cgen.Gen, iUnroll)
bnAdds = make([][]cgen.Gen, iUnroll)
jIters = elems / (jUnroll * lanes)
jAfter = elems % (jUnroll * lanes)
)
// leaf emits one vector cell. i and j are the loop variables (or the
// literal iteration counts in tail code); ii and jj are the positions
// inside the unrolled step; l is the number of active lanes.
leaf := func(i, j cgen.Gen, ii, jj, l int) cgen.Stmts {
cell := &m512Cell{
Ctx: s.Ctx,
Spec: s.Spec,
Lanes: l,
Ptrs: make([]cgen.Gen, len(s.datPtrs)),
BnMuls: bnMuls[ii],
BnAdds: bnAdds[ii],
}
for k, ptr := range s.datPtrs {
var (
iiPitch = s.Pitch2Bytes[k]
jjPitch = lanes * laneBytes
iPitch = iiPitch * iUnroll
jPitch = jjPitch * jUnroll
)
// Constant offset of this cell inside the unrolled step.
ptr = cgen.Add{
Expr1: ptr,
Expr2: cast(iiPitch*ii + jjPitch*jj),
}
// The loop-dependent offset is added only for dimensions that
// have at least one full unrolled iteration.
if iIters > 0 {
ptr = addr(ptr, cast(iPitch), i)
}
if jIters > 0 {
ptr = addr(ptr, cast(jPitch), j)
}
cell.Ptrs[k] = ptr
}
return cell.Stmts()
}
// inner emits the vector dimension for iCnt channels: a loop over full
// unrolled steps, then the tail.
inner := func(i cgen.Gen, iCnt int) cgen.Gen {
jSplit := make(cgen.Stmts, 2)
if jIters > 0 {
var (
body = make([]cgen.Stmts, iCnt*jUnroll)
j = vb(s.name("j"))
)
for ii := 0; ii < iCnt; ii++ {
for jj := 0; jj < jUnroll; jj++ {
k := ii*jUnroll + jj
body[k] = leaf(i, j, ii, jj, lanes)
}
}
jSplit[0] = cgen.For{
Init: cgen.Var{
Type: cgen.PtrdiffT, What: j,
Init: cgen.Zero,
},
Cond: cgen.CmpL{
Expr1: j, Expr2: il(jIters),
},
Post: cgen.IncPre{Expr: j},
Body: mix(body),
}
}
if jAfter > 0 {
var (
// The tail holds `full` whole vectors followed by one partial
// vector of `part` lanes (skipped when part is 0).
full = jAfter / lanes
part = jAfter % lanes
tail = make([]cgen.Stmts, iCnt*jUnroll)
j = il(jIters)
)
for ii := 0; ii < iCnt; ii++ {
for jj := 0; jj <= full; jj++ {
var (
k = ii*jUnroll + jj
l = lanes
)
if jj == full {
l = part
}
if l > 0 {
tail[k] = leaf(i, j, ii, jj, l)
}
}
}
jSplit[1] = mix(tail)
}
return jSplit
}
// outer emits the body for iCnt channels: first the bn loads for each
// channel (one per bn pointer), then the vector code. It fills
// bnMuls/bnAdds before calling inner so that leaf sees them.
outer := func(i cgen.Gen, iCnt int) cgen.Gen {
var (
bnLds = make([]cgen.Stmts, iCnt)
bnCnt = len(s.bnPtrs)
)
for ii := 0; ii < iCnt; ii++ {
var (
ch = il(ii)
muls = make([]cgen.Gen, bnCnt)
adds = make([]cgen.Gen, bnCnt)
lds = make(cgen.Stmts, bnCnt)
)
// With a channel loop, the channel index also depends on i.
if iIters > 0 {
ch = cgen.Paren{
Inner: addr(ch, cast(iUnroll), i),
}
}
for k, ptr := range s.bnPtrs {
var (
bnMul = vb(s.name("bnMul"))
bnAdd = vb(s.name("bnAdd"))
)
muls[k] = bnMul
adds[k] = bnAdd
lds[k] = &bn.Load{
Ctx: s.bc,
Mas: ptr,
Channel: ch,
Mul: bnMul,
Add: bnAdd,
}
}
bnMuls[ii] = muls
bnAdds[ii] = adds
bnLds[ii] = lds
}
return cgen.Gens{
mix(bnLds), inner(i, iCnt),
}
}
// Channel dimension: a loop over full unrolled channel steps, then the
// leftover channels with i fixed at the literal iIters.
iSplit := make(cgen.Stmts, 2)
if iIters > 0 {
i := vb(s.name("i"))
iSplit[0] = cgen.For{
Init: cgen.Var{
Type: cgen.PtrdiffT, What: i,
Init: cgen.Zero,
},
Cond: cgen.CmpL{
Expr1: i, Expr2: il(iIters),
},
Post: cgen.IncPre{Expr: i},
Body: outer(i, iUnroll),
}
}
if iAfter > 0 {
i := il(iIters)
iSplit[1] = outer(i, iAfter)
}
return iSplit
}

// packed generates the function for tensors treated as one flat run of
// Channels*Height*Width elements, split into fixed-size grains. The
// lowercase fields are filled in by Append and its helpers.
type packed struct {
*Ctx
*Spec
FuncName string // name of the emitted entry-point function
grain int // elements per regular task
grains int // number of regular tasks
remain int // elements in the final short task (0 if none)
funcName string // name of the emitted callee, derived from FuncName
ptrs []cgen.Gen // pointer variables, one per tensor (set by loadPtrs)
}

// Append splits the flat element range into grains of enough(...)
// elements (plus one shorter task for any remainder) and emits two
// functions into `to`: the per-task callee, then the entry point (named
// FuncName) that hands the callee to the threader over a one-dimensional
// grid of tasks.
func (p *packed) Append(to []byte) []byte {
	p.grain = enough(p.Ctx, p.Spec)
	total := p.Channels * p.Height * p.Width
	p.grains, p.remain = total/p.grain, total%p.grain
	p.funcName = p.name(p.FuncName + "Callee")
	team := vb(p.name("team"))
	tensors := vb(p.name("tensors"))
	tasks := p.grains
	if p.remain > 0 {
		tasks++
	}
	entry := cgen.StaticFuncDef{
		ReturnType: cgen.Void,
		Name:       p.FuncName,
		Params: cgen.CommaSpaced{
			cgen.Param{Type: p.tc.PtrTeam, What: team},
			cgen.Param{Type: cgen.PtrPtrChar, What: tensors},
		},
		Body: &threader.Do{
			Ctx:    p.tc,
			Callee: vb(p.funcName),
			Any:    tensors,
			Hull:   []cgen.Gen{il(tasks)},
			Team:   team,
		},
	}
	// The callee is generated last so that its names are requested after
	// the ones above, exactly as when it is written inline in the list.
	return cgen.Gens{
		p.calleeFunc(),
		cgen.Newline,
		entry,
	}.Append(to)
}

// calleeFunc builds the per-task function passed to the threader. The
// generated body reads the tensor array and the task index from the
// callee's point, declares the grain pointers via loadPtrs, and then
// emits the kernel for a full grain and, if a remainder exists, the
// kernel for the short final task. When both exist, the full-grain kernel
// is guarded by an index comparison and returns early.
func (p *packed) calleeFunc() cgen.Gen {
	var (
		stmts   = make(cgen.Stmts, 5)
		tensors = vb(p.name("tensors"))
		task    = vb(p.name("i"))
	)
	fn := &threader.Callee{
		Ctx:  p.tc,
		Name: p.funcName,
		Task: vb(p.name("task")),
		Pt:   vb(p.name("pt")),
	}
	stmts[0] = cgen.Var{
		Type: cgen.PtrPtrChar,
		What: tensors,
		Init: fn.Any(),
	}
	stmts[1] = cgen.Var{
		Type: cgen.PtrdiffT,
		What: task,
		Init: cgen.Elem{Arr: fn.Pt, Idx: cgen.Zero},
	}
	stmts[2] = p.loadPtrs(tensors, task)
	hasTail := p.remain > 0
	if p.grains > 0 {
		full := p.kernel(p.grain)
		if hasTail {
			full = cgen.If{
				Cond: cgen.CmpL{Expr1: task, Expr2: il(p.grains)},
				Then: cgen.Stmts{full, cgen.Return{}},
			}
		}
		stmts[3] = full
	}
	if hasTail {
		stmts[4] = p.kernel(p.remain)
	}
	return fn.Func(stmts)
}

// loadPtrs declares one restrict char pointer per tensor (one for each
// entry of Pitch1Bytes), each built from tensors[j], the grain size in
// bytes, and the task index i via addr (defined outside this view). The
// variables are recorded in p.ptrs in the same order.
func (p *packed) loadPtrs(tensors, i cgen.Gen) cgen.Gen {
	count := len(p.Pitch1Bytes)
	decls := make(cgen.Stmts, 0, count)
	step := cast(p.grain * p.ElemBytes)
	p.ptrs = make([]cgen.Gen, 0, count)
	for j := 0; j < count; j++ {
		v := vb(p.name("ptr"))
		p.ptrs = append(p.ptrs, v)
		base := cgen.Elem{Arr: tensors, Idx: il(j)}
		decls = append(decls, cgen.Var{
			Type: cgen.RestrictPtrChar,
			What: v,
			Init: addr(base, step, i),
		})
	}
	return decls
}

// kernel emits the code that processes `elems` consecutive elements.
// AVX512Float32 is the only platform handled; any other value is treated
// as a programming error.
func (p *packed) kernel(elems int) cgen.Gen {
	if p.platform == raw.AVX512Float32 {
		return p.m512(elems)
	}
	panic("bug")
}

// m512 emits the AVX-512 kernel for `elems` consecutive elements: a loop
// whose body handles four 16-lane vectors per iteration, followed by
// straight-line code for the leftover whole vectors and, last, one
// partial vector for any remaining lanes.
func (p *packed) m512(elems int) cgen.Gen {
	const (
		unroll    = 4
		lanes     = 16
		laneBytes = 4
		vecBytes  = lanes * laneBytes
		stepBytes = unroll * vecBytes
		stepElems = unroll * lanes
	)
	steps := elems / stepElems
	rest := elems % stepElems
	// cell emits the code for vector number k of step j using l lanes.
	cell := func(j cgen.Gen, k, l int) cgen.Stmts {
		c := &m512Cell{
			Ctx:   p.Ctx,
			Spec:  p.Spec,
			Lanes: l,
			Ptrs:  make([]cgen.Gen, len(p.ptrs)),
		}
		for x, at := range p.ptrs {
			at = addr(at, cast(vecBytes), il(k))
			// The step offset is only needed when a loop exists.
			if steps > 0 {
				at = addr(at, cast(stepBytes), j)
			}
			c.Ptrs[x] = at
		}
		return c.Stmts()
	}
	out := make(cgen.Stmts, 2)
	if steps > 0 {
		var (
			body = make([]cgen.Stmts, unroll)
			j    = vb(p.name("j"))
		)
		for k := range body {
			body[k] = cell(j, k, lanes)
		}
		out[0] = cgen.For{
			Init: cgen.Var{
				Type: cgen.PtrdiffT,
				What: j,
				Init: cgen.Zero,
			},
			Cond: cgen.CmpL{Expr1: j, Expr2: il(steps)},
			Post: cgen.IncPre{Expr: j},
			Body: mix(body),
		}
	}
	if rest > 0 {
		var (
			whole = rest / lanes
			frac  = rest % lanes
			tail  = make([]cgen.Stmts, 0, whole+1)
			j     = il(steps)
		)
		for k := 0; k < whole; k++ {
			tail = append(tail, cell(j, k, lanes))
		}
		if frac > 0 {
			tail = append(tail, cell(j, whole, frac))
		}
		out[1] = mix(tail)
	}
	return out
}

// m512Cell generates the code for one 512-bit vector position: masked
// loads of every input, the configured ops, and masked stores of the
// result. The exported fields are set by the caller; the rest is working
// state used while Stmts runs.
type m512Cell struct {
*Ctx
*Spec
Lanes int // number of active lanes; determines the load/store mask
Ptrs []cgen.Gen // address expressions, consumed in order by loads and then stores
BnMuls []cgen.Gen // bn multiplier variables, consumed in order by Bn ops
BnAdds []cgen.Gen // bn addend variables, consumed in order by Bn ops
mask cgen.Gen // mask literal computed from Lanes
nextPtr int // index of the next unused entry of Ptrs
nextBn int // index of the next unused entry of BnMuls/BnAdds
loads cgen.Stmts // load statements collected so far
nonloads cgen.Stmts // all other statements collected so far
}

// ptr hands out the next unused address expression from m.Ptrs and
// advances the cursor.
func (m *m512Cell) ptr() cgen.Gen {
	at := m.nextPtr
	m.nextPtr++
	return m.Ptrs[at]
}

// load appends a declaration of a fresh vector variable initialized by a
// masked load from the next pointer, and returns that variable.
func (m *m512Cell) load() cgen.Gen {
	v := vb(m.name("dat"))
	decl := cgen.Var{
		Type: avx.M512,
		What: v,
		Init: avx.Mm512MaskzLoaduPs{m.mask, m.ptr()},
	}
	m.loads = append(m.loads, decl)
	return v
}

// nonload appends a to the list of statements that are not loads.
func (m *m512Cell) nonload(a cgen.Gen) {
m.nonloads = append(m.nonloads, a)
}

// adder emits the additions that sum all of dats into dats[0]. Each
// round folds the upper half of the live range onto the lower half
// (pairing the innermost elements first), so the sum is built as a tree
// rather than a chain.
func (m *m512Cell) adder(dats []cgen.Gen) {
	live := len(dats)
	for live > 1 {
		half := live / 2
		live -= half
		for k := 0; k < half; k++ {
			dst, src := dats[live-1-k], dats[live+k]
			m.nonload(cgen.Assign{
				Expr1: dst,
				Expr2: avx.Mm512AddPs{dst, src},
			})
		}
	}
}

// apply emits the code for each op in ops, in order, acting on dat:
//   - Add loads op.Int further inputs and sums them into dat;
//   - Bn applies the next bn multiplier/addend pair to dat;
//   - ReLU applies the activation with op.Float as the negative slope.
//
// Any other op kind is treated as a programming error.
func (m *m512Cell) apply(dat cgen.Gen, ops []mod.Op) {
	for i := range ops {
		op := &ops[i]
		switch op.Kind {
		case mod.Add:
			extra := op.Int
			terms := make([]cgen.Gen, 1+extra)
			terms[0] = dat
			for k := 1; k <= extra; k++ {
				terms[k] = m.load()
			}
			m.adder(terms)
		case mod.Bn:
			at := m.nextBn
			m.nextBn++
			m.nonload(&bn.Apply{
				Ctx: m.bc,
				Mul: m.BnMuls[at],
				Add: m.BnAdds[at],
				To:  dat,
			})
		case mod.ReLU:
			m.nonload(&act.ReLU{
				Ctx:      m.ac,
				NegSlope: op.Float,
				Var:      dat,
			})
		default:
			panic("bug")
		}
	}
}

// Stmts generates the complete code for this cell and returns it with all
// loads first, followed by all other statements.
//
// m.Ops holds one op list per input plus a final list. Each input is
// loaded and run through its own list; the inputs are then summed into
// the first one, the final list is applied to that sum, and the result is
// stored through every pointer that was not consumed by a load.
//
// NOTE(review): dats[0] assumes len(m.Ops) >= 2; with fewer op lists this
// would panic. Presumably callers guarantee that — confirm.
func (m *m512Cell) Stmts() cgen.Stmts {
// Mask with the low m.Lanes bits set.
m.mask = il(1<<uint(m.Lanes) - 1)
var (
last = len(m.Ops) - 1
dats = make([]cgen.Gen, last)
loads = make([]cgen.Stmts, last)
nonloads = make([]cgen.Stmts, last)
)
for i := 0; i < last; i++ {
// Collect each input's statements separately so that they can be
// combined by mix below.
m.loads = nil
m.nonloads = nil
dat := m.load()
m.apply(dat, m.Ops[i])
dats[i] = dat
loads[i] = m.loads
nonloads[i] = m.nonloads
}
// With a single input, m.loads/m.nonloads already hold its statements
// and there is nothing to combine or sum.
if last > 1 {
m.loads = mix(loads)
m.nonloads = mix(nonloads)
m.adder(dats)
}
dat := dats[0]
m.apply(dat, m.Ops[last])
// One masked store per remaining pointer.
for range m.Ptrs[m.nextPtr:] {
m.nonload(avx.Mm512MaskStoreuPs{
m.ptr(), m.mask, dat,
})
}
return append(
m.loads, m.nonloads...,
)
}

Top || internal/compile/author/engine/engine.go

package engine

import (
"NN-512/internal/compile/author/cgen"
"NN-512/internal/compile/author/errmsg"
"NN-512/internal/compile/author/net"
"NN-512/internal/compile/author/threader"
"NN-512/internal/compile/plan"
"NN-512/internal/nmsrc"
"NN-512/internal/raw"
"fmt"
"sort"
)

// il is shorthand for converting i to a cgen.IntLit.
func il(i int) cgen.Gen {
return cgen.IntLit(i)
}

// vb is shorthand for converting s to a cgen.Vb.
func vb(s string) cgen.Gen {
return cgen.Vb(s)
}

// ptr wraps the type t in a cgen.Ptr.
func ptr(t cgen.Gen) cgen.Gen {
return cgen.Ptr{Type: t}
}

// tensor describes one input or output tensor of the inference function:
// its name and its channel, height, and width extents.
type tensor struct {
name string
chans int
height int
width int
}

// tensors is a list of tensor descriptions. Its Len, Less, and Swap
// methods let it be passed to sort.Sort, ordering the list by name using
// Go's string comparison.
type tensors []*tensor

// Len returns the number of tensors.
func (ts tensors) Len() int {
return len(ts)
}

// Less reports whether the name of tensor i sorts before that of tensor j.
func (ts tensors) Less(i, j int) bool {
return ts[i].name < ts[j].name
}

// Swap exchanges tensors i and j.
func (ts tensors) Swap(i, j int) {
ts[i], ts[j] = ts[j], ts[i]
}

// Ctx holds the names and sizes shared by the generators of the Engine
// struct and its create, pthread_t, inference, and destroy functions.
// It is built by NewCtx.
type Ctx struct {
nms nmsrc.Src
emc *errmsg.Ctx
tc *threader.Ctx
nc *net.Ctx
structName string // Engine struct name: config prefix + "Engine"
StructNet string // field holding the pointer to the Net
StructTeam string // field holding the thread team
structAlloc string // field holding the raw allocation (passed to free)
StructAlign string // field holding the aligned pointer into the allocation
Alignment int // alignment in bytes, copied from the net context
Split int // highest pile end offset over pl.Seq, rounded up to Alignment
createName string // structName + "Create"
pthreadTName string // structName + "PthreadT"
inferenceName string // structName + "Inference"
InferenceEng cgen.Gen // engine parameter of the inference function definition
inferenceTensors tensors // input and output tensors, sorted by name
destroyName string // structName + "Destroy"
}

// NewCtx builds the engine context for the plan pl.
//
// Besides deriving the struct, field, and function names, it computes:
//   - Split: the largest end offset (OffsetBytes + SizeBytes) over all
//     destination piles with a non-negative offset, rounded up to a
//     multiple of the alignment;
//   - the inference tensor list: one entry per Input node (dimensions
//     from the node) and per Output node (dimensions from the first pile
//     of the first source span), sorted by name.
func NewCtx(pl *plan.Plan, nms nmsrc.Src, emc *errmsg.Ctx, tc *threader.Ctx, nc *net.Ctx) *Ctx {
	engine := pl.Config.Prefix + "Engine"
	ctx := &Ctx{
		nms:           nms,
		emc:           emc,
		tc:            tc,
		nc:            nc,
		structName:    engine,
		StructNet:     nms.Name("net"),
		StructTeam:    nms.Name("team"),
		structAlloc:   nms.Name("alloc"),
		StructAlign:   nms.Name("align"),
		Alignment:     nc.Alignment,
		createName:    engine + "Create",
		pthreadTName:  engine + "PthreadT",
		inferenceName: engine + "Inference",
		destroyName:   engine + "Destroy",
	}

	// Find where the placed piles end.
	top := 0
	for _, op := range pl.Seq {
		for _, span := range op.To {
			for _, pile := range span.Piles {
				if pile.OffsetBytes < 0 {
					continue
				}
				if end := pile.OffsetBytes + pile.SizeBytes; top < end {
					top = end
				}
			}
		}
	}
	// Round up to the alignment (the mask trick relies on the alignment
	// being a power of two).
	ctx.Split = (top + ctx.Alignment - 1) & -ctx.Alignment

	ctx.InferenceEng = vb(nms.Name("eng"))

	var found tensors
	for _, op := range pl.Seq {
		switch node := op.Nodes[0].(type) {
		case *raw.Input:
			found = append(found, &tensor{
				name:   node.ToTensor,
				chans:  node.Channels,
				height: node.Height,
				width:  node.Width,
			})
		case *raw.Output:
			pile := op.From[0].Piles[0]
			found = append(found, &tensor{
				name:   node.FromTensor,
				chans:  pile.Channels,
				height: pile.Height,
				width:  pile.Width,
			})
		}
	}
	sort.Sort(found)
	ctx.inferenceTensors = found
	return ctx
}

// Comment returns the explanatory comment that documents the Engine API
// in the generated code: how to create and destroy an Engine, how to
// obtain a thread's pthread_t, and how to call the inference function.
// The sample code in the comment is specialized with the generated names
// and with one malloc/argument/free line per inference tensor.
func (c *Ctx) Comment() cgen.Gen {
const (
space = ` `
indent = space + space + space + space
)
var comment cgen.Comment
// text appends lines to the comment.
text := func(lines ...string) {
comment = append(comment, lines...)
}
text(
`An Engine performs inference. It contains inference threads, scratch`,
`memory, and a pointer to the Net. Any number of Engines can share the`,
`same Net (and perform inference in parallel) because the Net is never`,
`modified. For best performance the number of inference threads should`,
`not exceed the number of CPU cores.`,
``,
indent+c.nc.StructName+`* net;`,
``,
indent+`... Create net ...`,
``,
indent+c.structName+`* engine; // For example, 4 inference threads:`,
indent+`char* err = `+c.createName+`(&engine, net, 4);`,
``,
indent+`if (err) { // Nonzero err means failure; engine is unmodified.`,
indent+indent+`printf("%s\n", err); // Explain the failure, add a newline.`,
indent+indent+`free(err); // Free the error string to avoid a memory leak.`,
``,
indent+indent+`... Destroy net ...`,
``,
indent+indent+`exit(1); // Exit, or propagate the failure some other way.`,
indent+`}`,
``,
indent+`... Use the POSIX threads API to adjust engine's threads ...`,
indent+`... Use engine to perform inference (dependent on net) ...`,
``,
indent+c.destroyName+`(engine); // Terminate threads, free memory.`,
``,
indent+`... Destroy net ...`,
``,
`The POSIX threads API can be used to adjust an Engine's threads. If`,
`an Engine has N threads, those threads are indexed 0, 1, 2, ..., N-1`,
`and a pthread_t identifier is associated with each index. To set the`,
`CPU affinity mask for the first inference thread, for example:`,
``,
indent+`pthread_t thread; // The first thread has index 0:`,
indent+`char* err = `+c.pthreadTName+`(engine, 0, &thread);`,
``,
indent+`assert(!err); // Can only fail if the thread index is invalid.`,
``,
indent+`pthread_setaffinity_np(thread, ...); // Details omitted.`,
``,
`The inference function reads floats from (one or more) input tensors`,
`and writes floats to (one or more) output tensors. All the input and`,
`output tensors are owned (allocated and freed) by the caller and are`,
`in CHW format, 32-bit floating point, fully packed (in other words,`,
`C has the largest pitch, W has the smallest pitch, and there is no`,
`padding anywhere).`,
``,
)
// One sample allocation per tensor.
for _, t := range c.inferenceTensors {
text(fmt.Sprintf(
indent+`float* %s = malloc(sizeof(float)*%d*%d*%d);`,
t.name, t.chans, t.height, t.width,
))
}
text(
``,
indent+`for (...) { // Reuse the input and output tensors.`,
``,
indent+indent+`... Write the input floats ...`,
``,
indent+indent+c.inferenceName+`( // This function cannot fail.`,
indent+indent+indent+`engine, // Pass an Engine as the first argument.`,
)
// One argument line per tensor: the first carries a remark, the last
// has no trailing comma, and the ones in between end with a comma.
// NOTE(review): with exactly one tensor, case 0 is matched before the
// "last" case, so a trailing comma would be emitted. Presumably there
// are always at least two tensors (an input and an output) — confirm.
for x, t := range c.inferenceTensors {
var follow string
switch x {
case 0:
follow = `, // The tensor arguments are sorted by name.`
case len(c.inferenceTensors) - 1:
default:
follow = `,`
}
text(fmt.Sprintf(
indent+indent+indent+`%s%s`,
t.name, follow,
))
}
text(
indent+indent+`);`,
``,
indent+indent+`... Read the output floats ...`,
``,
indent+`}`,
``,
)
// One sample free per tensor.
for _, t := range c.inferenceTensors {
text(fmt.Sprintf(
indent+`free(%s);`,
t.name,
))
}
text(
``,
`The tensor parameters of the inference function are ordered by name,`,
`lexically bytewise. In other words, the function parameters have been`,
`sorted by name using Go's "<" string comparison operator (a bytewise`,
`lexical string sort).`,
)
return comment
}

// StructFwd returns the forward declaration of the Engine struct.
func (c *Ctx) StructFwd() cgen.Gen {
return cgen.StructFwd(c.structName)
}

// StructDef returns the definition of the Engine struct. Its fields are,
// in order: the Net pointer, the thread team, the raw allocation, and
// the aligned pointer into that allocation.
func (c *Ctx) StructDef() cgen.Gen {
	member := func(typ cgen.Gen, name string) cgen.Gen {
		return cgen.Field{Type: typ, What: vb(name)}
	}
	return cgen.StructDef{
		Name: c.structName,
		Fields: cgen.Stmts{
			member(ptr(vb(c.nc.StructName)), c.StructNet),
			member(c.tc.PtrTeam, c.StructTeam),
			member(cgen.PtrChar, c.structAlloc),
			member(cgen.PtrChar, c.StructAlign),
		},
	}
}

// CreateDecl returns the prototype of the Engine create function. It
// returns a char pointer and takes a pointer to an Engine pointer, a Net
// pointer, and a thread count named "threads".
func (c *Ctx) CreateDecl() cgen.Gen {
	var (
		engineOut = ptr(ptr(vb(c.structName)))
		netIn     = ptr(vb(c.nc.StructName))
		threads   = cgen.Param{Type: cgen.PtrdiffT, What: vb("threads")}
	)
	return cgen.FuncDecl{
		ReturnType: cgen.PtrChar,
		Name:       c.createName,
		Params:     cgen.CommaLines{engineOut, netIn, threads},
	}
}

// CreateDef returns the definition of the Engine create function. The
// generated function:
//  1. mallocs the Engine struct;
//  2. mallocs Alignment-1 + Split + bytes of scratch memory (the extra
//     Alignment-1 bytes leave room to align the start);
//  3. stores the raw allocation and its aligned address in the struct;
//  4. creates the thread team with the requested number of threads;
//  5. stores the Net pointer, writes the Engine to the out parameter,
//     and returns 0.
//
// A failing step returns an error message after releasing what was
// allocated so far (the Unwind statements).
//
// NOTE(review): the order of the c.nms.Name calls presumably determines
// the generated identifier names; keep it stable when editing.
func (c *Ctx) CreateDef(bytes int) cgen.Gen {
var (
paramEng = vb(c.nms.Name("eng"))
paramNet = vb(c.nms.Name("net"))
paramThreads = vb(c.nms.Name("threads"))
eng = vb(c.nms.Name("eng"))
alloc = vb(c.nms.Name("alloc"))
)
freeEng := cgen.Call{
Func: cgen.Free, Args: eng,
}
freeAlloc := cgen.Call{
Func: cgen.Free, Args: alloc,
}
// field returns the expression eng->nm.
field := func(nm string) cgen.Gen {
return cgen.Arrow{
Expr: eng, Name: nm,
}
}
// (void*)(((size_t)alloc + Alignment-1) & -Alignment): the allocation
// address rounded up to the alignment.
align := cgen.Cast{
Type: cgen.PtrVoid,
Expr: cgen.Paren{
Inner: cgen.And{
Expr1: cgen.Paren{
Inner: cgen.Add{
Expr1: cgen.Cast{
Type: cgen.SizeT,
Expr: alloc,
},
Expr2: il(c.Alignment - 1),
},
},
Expr2: il(-c.Alignment),
},
},
}
return cgen.FuncDef{
ReturnType: cgen.PtrChar,
Name: c.createName,
Params: cgen.CommaLines{
cgen.Param{
Type: ptr(ptr(vb(c.structName))),
What: paramEng,
},
cgen.Param{
Type: ptr(vb(c.nc.StructName)),
What: paramNet,
},
cgen.Param{
Type: cgen.PtrdiffT,
What: paramThreads,
},
},
Body: cgen.Stmts{
cgen.Var{
Type: ptr(vb(c.structName)),
What: eng,
Init: cgen.Call{
Func: cgen.Malloc,
Args: cgen.Sizeof{
What: vb(c.structName),
},
},
},
&errmsg.ErrnoIf{
Ctx: c.emc,
Cond: cgen.IsZero{Expr: eng},
},
cgen.Var{
Type: cgen.PtrChar,
What: alloc,
Init: cgen.Call{
Func: cgen.Malloc,
Args: il(
c.Alignment - 1 +
c.Split +
bytes,
),
},
},
&errmsg.ErrnoIf{
Ctx: c.emc,
Cond: cgen.IsZero{Expr: alloc},
Unwind: freeEng,
},
cgen.Assign{
Expr1: field(c.structAlloc),
Expr2: alloc,
},
cgen.Assign{
Expr1: field(c.StructAlign),
Expr2: align,
},
&threader.Create{
Ctx: c.tc,
Team: cgen.Addr{
Expr: field(c.StructTeam),
},
Nt: paramThreads,
Unwind: cgen.Stmts{
freeEng,
freeAlloc,
},
},
cgen.Assign{
Expr1: field(c.StructNet),
Expr2: paramNet,
},
cgen.Assign{
Expr1: cgen.At{
Expr: paramEng,
},
Expr2: eng,
},
cgen.Return{
Expr: il(0),
},
},
}
}

// PthreadTDecl returns the prototype of the function that reports the
// pthread_t of one Engine thread. It returns a char pointer and takes an
// Engine pointer, a thread index named "threadIdx", and a pthread_t
// pointer named "to".
func (c *Ctx) PthreadTDecl() cgen.Gen {
	var (
		engineIn = ptr(vb(c.structName))
		index    = cgen.Param{Type: cgen.PtrdiffT, What: vb("threadIdx")}
		out      = cgen.Param{Type: cgen.PtrPthreadT, What: vb("to")}
	)
	return cgen.FuncDecl{
		ReturnType: cgen.PtrChar,
		Name:       c.pthreadTName,
		Params:     cgen.CommaLines{engineIn, index, out},
	}
}

// PthreadTDef returns the definition of the function that reports the
// pthread_t of one Engine thread. The body is delegated to
// threader.PthreadT, which is given the Engine's team field, the thread
// index, and the destination pointer.
func (c *Ctx) PthreadTDef() cgen.Gen {
	engArg := vb(c.nms.Name("eng"))
	idxArg := vb(c.nms.Name("idx"))
	outArg := vb(c.nms.Name("to"))
	params := cgen.CommaLines{
		cgen.Param{Type: ptr(vb(c.structName)), What: engArg},
		cgen.Param{Type: cgen.PtrdiffT, What: idxArg},
		cgen.Param{Type: cgen.PtrPthreadT, What: outArg},
	}
	body := &threader.PthreadT{
		Ctx:  c.tc,
		Thr:  outArg,
		Team: cgen.Arrow{Expr: engArg, Name: c.StructTeam},
		Idx:  idxArg,
	}
	return cgen.FuncDef{
		ReturnType: cgen.PtrChar,
		Name:       c.pthreadTName,
		Params:     params,
		Body:       body,
	}
}

// inferenceParams returns the parameter list of the inference function:
// the Engine pointer followed by one float pointer per inference tensor,
// named after the tensor. For a definition (isDef) the Engine parameter
// is named c.InferenceEng; for a declaration it is left unnamed.
func (c *Ctx) inferenceParams(isDef bool) cgen.Gen {
	var engine cgen.Gen = ptr(vb(c.structName))
	if isDef {
		engine = cgen.Param{Type: engine, What: c.InferenceEng}
	}
	params := make(cgen.CommaLines, 0, 1+len(c.inferenceTensors))
	params = append(params, engine)
	for _, t := range c.inferenceTensors {
		params = append(params, cgen.Param{
			Type: cgen.PtrFloat,
			What: vb(t.name),
		})
	}
	return params
}

// InferenceDecl returns the prototype of the inference function, which
// returns void and takes the declaration form of the inference
// parameters.
func (c *Ctx) InferenceDecl() cgen.Gen {
	params := c.inferenceParams(false)
	return cgen.FuncDecl{
		ReturnType: cgen.Void,
		Name:       c.inferenceName,
		Params:     params,
	}
}

// InferenceDef returns the definition of the inference function with the
// given body and the definition form of the inference parameters.
func (c *Ctx) InferenceDef(body cgen.Gen) cgen.Gen {
	params := c.inferenceParams(true)
	return cgen.FuncDef{
		ReturnType: cgen.Void,
		Name:       c.inferenceName,
		Params:     params,
		Body:       body,
	}
}

// DestroyDecl returns the prototype of the Engine destroy function, which
// returns void and takes an Engine pointer.
func (c *Ctx) DestroyDecl() cgen.Gen {
	engineIn := ptr(vb(c.structName))
	return cgen.FuncDecl{
		ReturnType: cgen.Void,
		Name:       c.destroyName,
		Params:     engineIn,
	}
}

// DestroyDef returns the definition of the Engine destroy function. The
// generated body destroys the thread team, frees the scratch allocation,
// and finally frees the Engine struct itself.
func (c *Ctx) DestroyDef() cgen.Gen {
	engArg := vb(c.nms.Name("eng"))
	member := func(name string) cgen.Gen {
		return cgen.Arrow{Expr: engArg, Name: name}
	}
	release := func(what cgen.Gen) cgen.Gen {
		return cgen.Call{Func: cgen.Free, Args: what}
	}
	return cgen.FuncDef{
		ReturnType: cgen.Void,
		Name:       c.destroyName,
		Params: cgen.Param{
			Type: ptr(vb(c.structName)),
			What: engArg,
		},
		Body: cgen.Stmts{
			&threader.Destroy{
				Ctx:  c.tc,
				Team: member(c.StructTeam),
			},
			release(member(c.structAlloc)),
			release(engArg),
		},
	}
}

Top || internal/compile/author/eof/eof.go

package eof

import "NN-512/internal/compile/author/cgen"

// Gen is the comment "End of file." as a cgen.Gen.
var Gen cgen.Gen = cgen.Comment{"End of file."}

Top || internal/compile/author/errmsg/errmsg.go

package errmsg

import (
"NN-512/internal/compile/author/cgen"
"NN-512/internal/compile/plan"
"NN-512/internal/nmsrc"
)

// Ctx holds what the error-message generators share: the name source,
// the prefix placed at the start of every message, and the name of the
// generated message-formatting function.
type Ctx struct {
nms nmsrc.Src
msgPrefix string
funcName string
}

// NewCtx builds the error-message context for the plan pl. The configured
// prefix is used both as the message prefix and, with "Errmsg" appended,
// as the basis for the formatting function's name.
func NewCtx(pl *plan.Plan, nms nmsrc.Src) *Ctx {
	ctx := new(Ctx)
	ctx.nms = nms
	ctx.msgPrefix = pl.Config.Prefix
	ctx.funcName = nms.Name(ctx.msgPrefix + "Errmsg")
	return ctx
}

// name asks the name source for a name based on prefix and returns it as
// a cgen.Vb.
func (c *Ctx) name(prefix string) cgen.Gen {
return cgen.Vb(c.nms.Name(prefix))
}

// Prep generates the definition of the message-formatting function that
// FormatIf calls. lineNum and format are the function's parameter names,
// set by Append.
type Prep struct {
*Ctx
lineNum cgen.Gen
format cgen.Gen
}

// body returns the statements of the message-formatting function. The
// generated code mallocs a buffer sized for the prefix plus 256 bytes,
// writes "<prefix>: line <n>: " with sprintf, appends the caller's
// formatted message with vsnprintf (bounded by the space left), and
// returns the buffer.
//
// NOTE(review): the generated code does not check the malloc result
// before writing to the buffer.
func (p *Prep) body() cgen.Gen {
	const plenty = 1 << 8
	var (
		msg  = p.name("msg")
		pre  = cgen.DoubleQuoted(p.msgPrefix + ": line %td: ")
		size = cgen.IntLit(len(pre) + plenty)
		step = p.name("step")
		ap   = p.name("ap")
	)
	allocate := cgen.Var{
		Type: cgen.PtrChar,
		What: msg,
		Init: cgen.Call{Func: cgen.Malloc, Args: size},
	}
	// step receives sprintf's return value and is used as the offset at
	// which the caller's message begins.
	writePrefix := cgen.Var{
		Type: cgen.Int,
		What: step,
		Init: cgen.Call{
			Func: cgen.Sprintf,
			Args: cgen.CommaSpaced{msg, pre, p.lineNum},
		},
	}
	writeMessage := cgen.Call{
		Func: cgen.Vsnprintf,
		Args: cgen.CommaSpaced{
			cgen.Add{Expr1: msg, Expr2: step},
			cgen.Sub{Expr1: size, Expr2: step},
			p.format,
			ap,
		},
	}
	return cgen.Stmts{
		allocate,
		writePrefix,
		cgen.Var{Type: cgen.VaList, What: ap},
		cgen.Call{Func: cgen.VaStart, Args: cgen.CommaSpaced{ap, p.format}},
		writeMessage,
		cgen.Call{Func: cgen.VaEnd, Args: ap},
		cgen.Return{Expr: msg},
	}
}

// Append emits the static message-formatting function into `to`. The
// function returns a char pointer and takes a line number, a format
// string, and variadic arguments.
func (p *Prep) Append(to []byte) []byte {
	p.lineNum = p.name("lineNum")
	p.format = p.name("format")
	params := cgen.CommaSpaced{
		cgen.Param{Type: cgen.PtrdiffT, What: p.lineNum},
		cgen.Param{Type: cgen.PtrChar, What: p.format},
		cgen.Ellipsis,
	}
	def := cgen.StaticFuncDef{
		ReturnType: cgen.PtrChar,
		Name:       p.funcName,
		Params:     params,
		Body:       p.body(),
	}
	return def.Append(to)
}

// FormatIf generates "if Cond, return a formatted error message".
type FormatIf struct {
*Ctx
Cond cgen.Gen // condition that signals the error
Format string // format string for the message
Args []cgen.Gen // arguments for Format
Unwind cgen.Gen // optional cleanup emitted before the return; may be nil
}

// Append emits an if statement, marked unlikely, that calls the
// message-formatting function with the current line number, the format
// string, and the arguments, and returns the result. When Unwind is set,
// the message is first stored in a variable, Unwind is emitted, and then
// the variable is returned.
func (f *FormatIf) Append(to []byte) []byte {
	args := make(cgen.CommaSpaced, 0, 2+len(f.Args))
	args = append(args, cgen.LineNum, cgen.DoubleQuoted(f.Format))
	args = append(args, f.Args...)
	report := cgen.Call{
		Func: cgen.Vb(f.funcName),
		Args: args,
	}
	var then cgen.Stmts
	if f.Unwind != nil {
		msg := f.name("msg")
		then = cgen.Stmts{
			cgen.Var{Type: cgen.PtrChar, What: msg, Init: report},
			f.Unwind,
			cgen.Return{Expr: msg},
		}
	} else {
		then = cgen.Stmts{cgen.Return{Expr: report}}
	}
	guard := cgen.If{
		Cond: cgen.Unlikely{Cond: f.Cond},
		Then: then,
	}
	return cgen.Stmts{guard}.Append(to)
}

// errnoFormat is the message format used by ErrnoIf and ReturnedErrnoIf.
const errnoFormat = "errno %d"

// ErrnoIf generates "if Cond, return an error message containing errno",
// with optional cleanup (Unwind) before the return.
type ErrnoIf struct {
*Ctx
Cond cgen.Gen
Unwind cgen.Gen
}

// Append emits the check by delegating to FormatIf with the errno format
// and cgen.Errno as the single argument.
func (e *ErrnoIf) Append(to []byte) []byte {
	check := FormatIf{
		Ctx:    e.Ctx,
		Cond:   e.Cond,
		Format: errnoFormat,
		Args:   []cgen.Gen{cgen.Errno},
		Unwind: e.Unwind,
	}
	return check.Append(to)
}

// ReturnedErrnoIf generates a check for calls that return an error number
// instead of setting errno: the result of Call is stored in an int and,
// if nonzero, reported as an error message, with optional cleanup
// (Unwind) before the return.
type ReturnedErrnoIf struct {
*Ctx
Call cgen.Gen
Unwind cgen.Gen
}

// Append emits an int variable initialised from r.Call followed by a
// FormatIf that fires when that variable is nonzero and formats its value.
func (r *ReturnedErrnoIf) Append(to []byte) []byte {
	code := r.name("err")
	decl := cgen.Var{Type: cgen.Int, What: code, Init: r.Call}
	check := &FormatIf{
		Ctx:    r.Ctx,
		Cond:   code,
		Format: errnoFormat,
		Args:   []cgen.Gen{code},
		Unwind: r.Unwind,
	}
	return cgen.Stmts{decl, check}.Append(to)
}

Top || internal/compile/author/exp/exp.go

package exp

import (
"NN-512/internal/compile/author/avx"
"NN-512/internal/compile/author/cgen"
"NN-512/internal/compile/plan"
"NN-512/internal/nmsrc"
"NN-512/internal/raw"
)

// Ctx holds what the exp code generator needs: the target platform, the
// name source used to create identifiers, and the name of the generated
// exp function.
type Ctx struct {
platform raw.Platform
nms nmsrc.Src
funcName string
}

// NewCtx builds a Ctx from the plan's platform and the given name source.
// The generated function name is derived from the plan's prefix plus "Exp".
func NewCtx(pl *plan.Plan, nms nmsrc.Src) *Ctx {
	c := new(Ctx)
	c.platform = pl.Config.Platform
	c.nms = nms
	c.funcName = nms.Name(pl.Config.Prefix + "Exp")
	return c
}

// name asks the name source for an identifier based on a and wraps it as a
// verbatim cgen value.
func (c *Ctx) name(a string) cgen.Gen {
	ident := c.nms.Name(a)
	return cgen.Vb(ident)
}

// Prep generates the definition of the exp helper function. The to field
// is the output buffer being appended to during Append.
type Prep struct {
*Ctx
to []byte
}

// Append writes the exp function definition for the configured platform
// onto to and returns the extended buffer. Any platform other than
// raw.AVX512Float32 is a programming error and panics.
func (p *Prep) Append(to []byte) []byte {
	p.to = to
	if p.platform != raw.AVX512Float32 {
		panic("bug")
	}
	p.m512()
	return p.to
}

// m512 appends to p.to a static C function named p.funcName that takes an
// avx.M512 argument x and returns an avx.M512. The emitted body:
//
//   - clamps x to [-87.33654, 88.72284];
//   - computes t = x * 1.442695 (1.442695 is approximately log2(e)) and
//     r = t rounded to the nearest integer;
//   - computes f = x + r*(-0.69314575) + r*(-1.4286068e-6); the two
//     constants sum to approximately -ln(2), so this is presumably a
//     two-step reduction of x by r*ln(2);
//   - evaluates a degree-4 polynomial g in f by Horner's rule;
//   - converts t to 32-bit integers, shifts them left by 23 bits and adds
//     them to the bit pattern of g, then reinterprets the sum as floats.
//     NOTE(review): 23 is the float32 mantissa width, so this looks like
//     scaling g by a power of two through the exponent field — confirm.
func (p *Prep) m512() {
var (
x = p.name("x")
t = p.name("t")
r = p.name("r")
f = p.name("f")
g = p.name("g")
y = p.name("y")
)
p.to = cgen.StaticFuncDef{
ReturnType: avx.M512,
Name: p.funcName,
Params: cgen.Param{Type: avx.M512, What: x},
Body: cgen.Stmts{
// Clamp the input from below, then from above.
cgen.Assign{
Expr1: x,
Expr2: avx.Mm512MaxPs{x, avx.Mm512Set1PsLit(-87.33654)},
},
cgen.Assign{
Expr1: x,
Expr2: avx.Mm512MinPs{x, avx.Mm512Set1PsLit(88.72284)},
},
// t = x * 1.442695; r = t rounded to nearest.
cgen.Var{
Type: avx.M512, What: t,
Init: avx.Mm512MulPs{x, avx.Mm512Set1PsLit(1.442695)},
},
cgen.Var{
Type: avx.M512, What: r,
Init: avx.Mm512RoundscalePs{t, avx.FroundToNearestIntNoExc},
},
// f = x + r*c1, then f += r*c2 (two fused multiply-adds).
cgen.Var{
Type: avx.M512, What: f,
Init: avx.Mm512FmaddPs{r, avx.Mm512Set1PsLit(-0.69314575), x},
},
cgen.Assign{
Expr1: f,
Expr2: avx.Mm512FmaddPs{r, avx.Mm512Set1PsLit(-1.4286068e-6), f},
},
// Horner evaluation: g = (((c4*f + c3)*f + c2)*f + c1)*f + c0.
cgen.Var{
Type: avx.M512, What: g, Init: avx.Mm512Set1PsLit(0.04194439),
},
cgen.Assign{
Expr1: g,
Expr2: avx.Mm512FmaddPs{g, f, avx.Mm512Set1PsLit(0.16800667)},
},
cgen.Assign{
Expr1: g,
Expr2: avx.Mm512FmaddPs{g, f, avx.Mm512Set1PsLit(0.49999994)},
},
cgen.Assign{
Expr1: g,
Expr2: avx.Mm512FmaddPs{g, f, avx.Mm512Set1PsLit(0.9999569)},
},
cgen.Assign{
Expr1: g,
Expr2: avx.Mm512FmaddPs{g, f, avx.Mm512Set1PsLit(0.99999964)},
},
// y = int32(t) << 23, added below to the raw bits of g.
cgen.Var{
Type: avx.M512i, What: y,
Init: avx.Mm512SlliEpi32{
avx.Mm512CvtpsEpi32{t}, cgen.IntLit(23),
},
},
cgen.Return{Expr: avx.Mm512Castsi512Ps{
avx.Mm512AddEpi32{y, avx.Mm512CastpsSi512{g}},
}},
},
}.Append(p.to)
}

// Call generates a call to the exp helper function with Arg as its
// argument.
type Call struct {
*Ctx
Arg cgen.Gen
}

// Append emits "funcName(Arg)" onto to.
func (c *Call) Append(to []byte) []byte {
	invoke := cgen.Call{
		Func: cgen.Vb(c.funcName),
		Args: c.Arg,
	}
	return invoke.Append(to)
}

Top || internal/compile/author/fc/fc.go

package fc

import (
"NN-512/internal/compile/author/act"
"NN-512/internal/compile/author/avx"
"NN-512/internal/compile/author/bn"
"NN-512/internal/compile/author/cgen"
"NN-512/internal/compile/author/mod"
"NN-512/internal/compile/author/sumr"
"NN-512/internal/compile/author/threader"
"NN-512/internal/compile/plan"
"NN-512/internal/nmsrc"
"NN-512/internal/raw"
"fmt"
)

// btoi converts a bool to an int: 1 for true, 0 for false.
func btoi(b bool) int {
	n := 0
	if b {
		n = 1
	}
	return n
}

// min returns the smaller of x and y.
func min(x, y int) int {
	if y < x {
		return y
	}
	return x
}

// max returns the larger of x and y.
func max(x, y int) int {
	if y > x {
		return y
	}
	return x
}

// ceilQuo returns n divided by d, rounded up, computed as (n + d - 1) / d
// with Go's truncating integer division.
func ceilQuo(n, d int) int {
	biased := n + d - 1
	return biased / d
}

// vb is shorthand for wrapping s with cgen.Vb.
func vb(s string) cgen.Gen {
	var g cgen.Gen = cgen.Vb(s)
	return g
}

// il is shorthand for wrapping i with cgen.IntLit.
func il(i int) cgen.Gen {
	var g cgen.Gen = cgen.IntLit(i)
	return g
}

// loMask returns an integer literal whose n lowest bits are set.
func loMask(n int) cgen.Gen {
	bits := 1<<uint(n) - 1
	return il(bits)
}

// cast returns the integer literal pitch cast to ptrdiff_t.
func cast(pitch int) cgen.Gen {
	lit := il(pitch)
	return cgen.Cast{Type: cgen.PtrdiffT, Expr: lit}
}

// addr returns the expression ptr + pitch*idx.
func addr(ptr, pitch, idx cgen.Gen) cgen.Gen {
	offset := cgen.Mul{Expr1: pitch, Expr2: idx}
	return cgen.Add{Expr1: ptr, Expr2: offset}
}

// mix interleaves several statement lists round-robin: the first statement
// of every list, then the second of every list that has one, and so on.
// A single list is returned as is, without copying.
func mix(a []cgen.Stmts) cgen.Stmts {
	if len(a) == 1 {
		return a[0]
	}
	total, longest := 0, 0
	for _, list := range a {
		total += len(list)
		if len(list) > longest {
			longest = len(list)
		}
	}
	out := make(cgen.Stmts, 0, total)
	for row := 0; row < longest; row++ {
		for _, list := range a {
			if row < len(list) {
				out = append(out, list[row])
			}
		}
	}
	return out
}

// Ctx is the state shared by the fully-connected code generators in this
// package.
type Ctx struct {
// prefix is the plan's prefix followed by "Fc"; it starts generated
// function names.
prefix string
platform raw.Platform
nms nmsrc.Src
// tc, ac and bc are the contexts of the threader, act and bn packages.
tc *threader.Ctx
ac *act.Ctx
bc *bn.Ctx
// dedup maps a generator signature to the name of the function already
// emitted for it, so identical functions are generated once.
dedup map[string]string
}

// NewCtx builds a Ctx from the plan's configuration, the name source and
// the contexts of the helper generators. The dedup table starts empty.
func NewCtx(pl *plan.Plan, nms nmsrc.Src, tc *threader.Ctx, ac *act.Ctx, bc *bn.Ctx) *Ctx {
	c := new(Ctx)
	c.prefix = pl.Config.Prefix + "Fc"
	c.platform = pl.Config.Platform
	c.nms = nms
	c.tc = tc
	c.ac = ac
	c.bc = bc
	c.dedup = make(map[string]string)
	return c
}

// name asks the name source for an identifier based on s.
func (c *Ctx) name(s string) string {
	ident := c.nms.Name(s)
	return ident
}

// newLayout computes the arranged-weights memory layout for a layer with
// toC outputs and fromC*fromH*fromW inputs. Only raw.AVX512Float32 is
// supported; any other platform panics.
func (c *Ctx) newLayout(toC, fromC, fromH, fromW int) *layout {
	if c.platform != raw.AVX512Float32 {
		panic("bug")
	}
	y := &layout{
		cellWeights1: 16,
		groupCells1:  16,
		alignment:    64,
		weightBytes1: 4,
		weightBytes2: 2,
		biasBytes:    4,
		datBytes:     4,
	}
	// roundUp rounds n up to a multiple of the alignment.
	roundUp := func(n int) int {
		return (n + y.alignment - 1) & -y.alignment
	}

	// Input side: cells per strip, with an optional partial trailing cell.
	y.fromHW = fromH * fromW
	y.fromCHW = fromC * y.fromHW
	y.cellWeights2 = y.fromCHW % y.cellWeights1
	y.stripGroups1 = y.fromCHW / y.cellWeights1
	y.stripGroups2 = y.stripGroups1 + btoi(y.cellWeights2 > 0)

	// Output side: full strips, with an optional partial trailing strip.
	y.groupCells2 = toC % y.groupCells1
	y.strips1 = toC / y.groupCells1
	y.strips2 = y.strips1 + btoi(y.groupCells2 > 0)

	// Byte sizes of the arranged data.
	y.cellBytes = y.cellWeights1 * y.weightBytes2
	y.groupBytes1 = y.groupCells1 * y.cellBytes
	y.groupBytes2 = y.groupCells2 * y.cellBytes
	y.stripBytes1 = roundUp(y.stripGroups2 * y.groupBytes1)
	y.stripBytes2 = roundUp(y.stripGroups2 * y.groupBytes2)
	y.biasOffset = y.strips1*y.stripBytes1 + y.stripBytes2
	y.totalBytes = y.biasOffset + toC*y.biasBytes
	return y
}

// layout describes the arranged weight/bias buffer, as computed by
// newLayout. Fields ending in 1 describe the full-size unit and fields
// ending in 2 the trailing partial unit (or the count including it).
type layout struct {
// fromHW = fromH*fromW; fromCHW = fromC*fromHW.
fromHW int
fromCHW int
// cellWeights1 is the number of weights in a full cell; cellWeights2 is
// fromCHW % cellWeights1, the size of the trailing partial cell.
cellWeights1 int
cellWeights2 int
// stripGroups1 = fromCHW / cellWeights1; stripGroups2 adds one when
// cellWeights2 > 0.
stripGroups1 int
stripGroups2 int
// groupCells1 is the number of cells in a full group; groupCells2 is
// toC % groupCells1.
groupCells1 int
groupCells2 int
// strips1 = toC / groupCells1; strips2 adds one when groupCells2 > 0.
strips1 int
strips2 int
// alignment is the byte multiple that strip sizes are rounded up to.
alignment int
// weightBytes1 sizes the source weights and weightBytes2 the arranged
// weights (cellBytes = cellWeights1*weightBytes2).
weightBytes1 int
weightBytes2 int
biasBytes int
datBytes int
// cellBytes, groupBytes1/2 and stripBytes1/2 are the byte sizes of a
// cell, a full/partial group and a (padded) full/partial strip.
cellBytes int
groupBytes1 int
groupBytes2 int
stripBytes1 int
stripBytes2 int
// biasOffset = strips1*stripBytes1 + stripBytes2; totalBytes adds
// toC*biasBytes to that.
biasOffset int
totalBytes int
}

// Arrange generates the call that rearranges a fully-connected layer's
// weights and biases into the layout computed by newLayout. Prep must run
// before Bytes or Append, since it fills in layout and callerName.
type Arrange struct {
*Ctx
// ToC, FromC, FromH and FromW are passed to newLayout in that order.
ToC int
FromC int
FromH int
FromW int
// BnPre and BnPost count the batch-norm tensors; ptrs treats the first
// BnPre of them differently from the following BnPost.
BnPre int
BnPost int
// Team and Tensors are forwarded to the generated caller function.
Team cgen.Gen
Tensors []cgen.Gen
*layout
// callerName is the name of the generated function that Append calls.
callerName string
}

// Prep computes the layout and decides which function Append will call.
// If an arrange function with the same signature (dimensions and
// batch-norm counts) was already generated, its name is reused and nil is
// returned; otherwise a new name is registered and the generator for the
// new function definition is returned.
func (a *Arrange) Prep() cgen.Gen {
	const affix = "Arrange"
	a.layout = a.newLayout(a.ToC, a.FromC, a.FromH, a.FromW)
	sig := fmt.Sprint(
		affix, " ",
		a.ToC, a.FromC, a.FromH, a.FromW,
		a.BnPre, a.BnPost,
	)
	prior, seen := a.dedup[sig]
	if seen {
		a.callerName = prior
		return nil
	}
	fresh := a.name(a.prefix + affix)
	a.callerName = fresh
	a.dedup[sig] = fresh
	return cgen.Gens{
		&arrange{Arrange: a},
		cgen.Newline,
	}
}

// Bytes returns the total size in bytes of the arranged buffer, as
// computed by the layout set in Prep.
func (a *Arrange) Bytes() int {
	return a.layout.totalBytes
}

// Append emits a local char-pointer array initialised with a.Tensors,
// followed by a call to the arrange function with the team and that array.
func (a *Arrange) Append(to []byte) []byte {
	tensors := vb(a.name("tensors"))
	decl := cgen.Var{
		Type: cgen.PtrChar,
		What: cgen.Elem{Arr: tensors},
		Init: cgen.Brace{Inner: cgen.CommaLines(a.Tensors)},
	}
	call := cgen.Call{
		Func: vb(a.callerName),
		Args: cgen.CommaSpaced{a.Team, tensors},
	}
	return cgen.Stmts{decl, call}.Append(to)
}

// arrange generates the definition of the arrange caller function and its
// threaded callee for one Arrange.
type arrange struct {
*Arrange
// tile is the number of full strips handled per task; tiles is the
// number of full tiles and scrap the leftover full strips
// (strips1 / tile and strips1 % tile).
tile int
tiles int
scrap int
// hull1 counts the tasks covering full strips (tiles, plus one when
// scrap > 0); hull2 adds one more when a partial strip exists.
hull1 int
hull2 int
calleeName string
// Pointer variables declared by ptrs: source weights/biases, the
// batch-norm tensors, and destination weights/biases.
weights1 cgen.Gen
biases1 cgen.Gen
bnPtrs []cgen.Gen
weights2 cgen.Gen
biases2 cgen.Gen
// strips and groupCells parameterise the kernel being generated; they
// are set by calleeFunc before each call to kernel.
strips int
groupCells int
}

// Append emits the callee function followed by the static caller function.
// The caller hands the tensors array to threader.Do with a one-dimensional
// hull of a.hull2 tasks. Work is split so that each task covers roughly
// threadVecs vectors.
func (a *arrange) Append(to []byte) []byte {
	stripVecs := a.stripGroups2 * a.groupCells1
	team := vb(a.name("team"))
	tensors := vb(a.name("tensors"))
	if a.platform != raw.AVX512Float32 {
		panic("bug")
	}
	const threadVecs = 512

	// Partition the full strips into tiles, plus leftovers.
	a.tile = ceilQuo(threadVecs, stripVecs)
	a.tiles = a.strips1 / a.tile
	a.scrap = a.strips1 % a.tile
	a.hull1 = a.tiles + btoi(a.scrap > 0)
	a.hull2 = a.hull1 + btoi(a.strips1 < a.strips2)
	a.calleeName = a.name(a.callerName + "Callee")

	callee := a.calleeFunc()
	params := cgen.CommaSpaced{
		cgen.Param{Type: a.tc.PtrTeam, What: team},
		cgen.Param{Type: cgen.PtrPtrChar, What: tensors},
	}
	caller := cgen.StaticFuncDef{
		ReturnType: cgen.Void,
		Name:       a.callerName,
		Params:     params,
		Body: &threader.Do{
			Ctx:    a.tc,
			Callee: vb(a.calleeName),
			Any:    tensors,
			Hull:   []cgen.Gen{il(a.hull2)},
			Team:   team,
		},
	}
	return cgen.Gens{callee, cgen.Newline, caller}.Append(to)
}

// calleeFunc builds the threaded callee. Its body reads the tensors array
// and the task coordinate t, declares the per-task pointers (ptrs), and
// then contains up to three kernels selected by t:
//
//   - t < a.tiles: a full tile of a.tile strips;
//   - t < a.hull1: the a.scrap leftover full strips;
//   - otherwise:   the single partial strip with a.groupCells2 cells.
//
// Slots of body that do not apply are left nil.
func (a *arrange) calleeFunc() cgen.Gen {
var (
body = make(cgen.Stmts, 6)
tensors = vb(a.name("tensors"))
t = vb(a.name("t"))
)
callee := &threader.Callee{
Ctx: a.tc,
Name: a.calleeName,
Task: vb(a.name("task")),
Pt: vb(a.name("pt")),
}
body[0] = cgen.Var{
Type: cgen.PtrPtrChar, What: tensors,
Init: callee.Any(),
}
body[1] = cgen.Var{
Type: cgen.PtrdiffT, What: t,
Init: cgen.Elem{Arr: callee.Pt, Idx: il(0)},
}
body[2] = a.ptrs(tensors, t)
// part stores a kernel in body[i]. Unless n is the last task bound
// (n == a.hull2), the kernel is wrapped in "if (t < n) { ...; return; }"
// so later parts are skipped for this task.
part := func(i, n int) {
body[i] = a.kernel()
if n < a.hull2 {
body[i] = cgen.If{
Cond: cgen.CmpL{
Expr1: t,
Expr2: il(n),
},
Then: cgen.Stmts{
body[i],
cgen.Return{},
},
}
}
}
// a.strips and a.groupCells must be set before each kernel() call,
// since the kernel generators read them.
if 0 < a.tiles {
a.strips = a.tile
a.groupCells = a.groupCells1
part(3, a.tiles)
}
if a.tiles < a.hull1 {
a.strips = a.scrap
a.groupCells = a.groupCells1
part(4, a.hull1)
}
if a.hull1 < a.hull2 {
a.strips = 1
a.groupCells = a.groupCells2
body[5] = a.kernel()
}
return callee.Func(body)
}

// ptrs declares the per-task pointer variables and records them on a:
// source weights and biases, one pointer per batch-norm tensor, and the
// destination weights and biases inside the arranged buffer. Entries of
// the tensors array are consumed in that order. The pointers are offset by
// the index s of the first strip handled by task t.
func (a *arrange) ptrs(tensors, t cgen.Gen) cgen.Gen {
var (
bnCnt = a.BnPre + a.BnPost
stmts = make(cgen.Stmts, 3+bnCnt+2)
s = t
)
// With one strip per task the strip index is t itself and stmts[0]
// stays nil. Otherwise s = tile*t, minus a correction for the task
// with index a.tiles+1 (when it exists): that task follows the scrap
// task, which covers only a.scrap strips rather than a.tile.
if a.tile > 1 {
s = vb(a.name("s"))
var strip cgen.Gen = cgen.Mul{
Expr1: cast(a.tile),
Expr2: t,
}
if i := a.tiles + 1; i < a.hull2 {
fix := cgen.Paren{
Inner: cgen.Ternary{
Cond: cgen.CmpE{
Expr1: t,
Expr2: il(i),
},
Then: il(a.tile - a.scrap),
Else: il(0),
},
}
strip = cgen.Sub{
Expr1: strip,
Expr2: fix,
}
}
stmts[0] = cgen.Var{
Type: cgen.PtrdiffT, What: s,
Init: strip,
}
}
// Byte pitches per strip: source weights, arranged weights, biases.
var (
tensorIdx = 0
n = a.groupCells1
weightPitch1 = cast(n * a.fromCHW * a.weightBytes1)
weightPitch2 = cast(a.stripBytes1)
biasPitch = cast(n * a.biasBytes)
)
// tensor returns the next entry of the tensors array.
tensor := func() cgen.Gen {
i := tensorIdx
tensorIdx++
return cgen.Elem{
Arr: tensors,
Idx: il(i),
}
}
a.weights1 = vb(a.name("weights"))
stmts[1] = cgen.Var{
Type: cgen.RestrictPtrChar, What: a.weights1,
Init: addr(tensor(), weightPitch1, s),
}
a.biases1 = vb(a.name("biases"))
stmts[2] = cgen.Var{
Type: cgen.RestrictPtrChar, What: a.biases1,
Init: addr(tensor(), biasPitch, s),
}
// The first a.BnPre batch-norm pointers are used as is; the remaining
// (post) ones are advanced by bn.Offset to channel n*s.
a.bnPtrs = make([]cgen.Gen, bnCnt)
for i := range a.bnPtrs {
var (
bnPtr = vb(a.name("bnPtr"))
expr = tensor()
)
if i >= a.BnPre {
expr = &bn.Offset{
Ctx: a.bc,
Mas: expr,
Channel: cgen.Mul{
Expr1: il(n),
Expr2: s,
},
}
}
stmts[3+i] = cgen.Var{
Type: cgen.RestrictPtrChar,
What: bnPtr, Init: expr,
}
a.bnPtrs[i] = bnPtr
}
// The last tensor is the arranged buffer: weights start at its
// beginning, biases at a.biasOffset bytes into it.
var (
arranged1 = tensor()
arranged2 = cgen.Add{
Expr1: arranged1,
Expr2: cast(a.biasOffset),
}
)
a.weights2 = vb(a.name("weights"))
stmts[3+bnCnt] = cgen.Var{
Type: cgen.RestrictPtrChar, What: a.weights2,
Init: addr(arranged1, weightPitch2, s),
}
a.biases2 = vb(a.name("biases"))
stmts[3+bnCnt+1] = cgen.Var{
Type: cgen.RestrictPtrChar, What: a.biases2,
Init: addr(arranged2, biasPitch, s),
}
return stmts
}

// kernel returns the arrange kernel for the configured platform. Any
// platform other than raw.AVX512Float32 is a programming error and panics.
func (a *arrange) kernel() cgen.Gen {
	if a.platform == raw.AVX512Float32 {
		return a.m512()
	}
	panic("bug")
}

// m512 picks the AVX-512 kernel variant: the plain one when there is no
// pre batch-norm, the special one when a cell holds a whole number of
// channels (cellWeights1 is a multiple of fromHW), and the general one
// otherwise.
func (a *arrange) m512() cgen.Gen {
	switch {
	case a.BnPre == 0:
		return a.m512NoBnPre()
	case a.cellWeights1%a.fromHW == 0:
		return a.m512BnPreSpecial()
	default:
		return a.m512BnPreGeneral()
	}
}

// m512NoBnPre builds the AVX-512 arrange kernel used when a.BnPre == 0.
// It returns a C for-loop over i in [0, a.strips). Each iteration loads
// the strip's weights from a.weights1, optionally scales them by the
// post batch-norm multipliers, converts them with avx.Mm512CvtpsPh and
// stores them at a.weights2; it also copies the strip's biases from
// a.biases1 to a.biases2, applying each post batch-norm on the way.
//
// The code is assembled by nested closures layer1..layer4 (outermost
// first) that share the variables declared below.
func (a *arrange) m512NoBnPre() cgen.Gen {
const (
lanes = 16
unroll = 16
)
var (
i = vb(a.name("i"))
gc = a.groupCells
bnMuls []cgen.Gen
j cgen.Gen
jg = 1 + gc%2
)
// jg is the number of groups handled per j step: two when gc is odd
// (so cell pairs come out even), scaled up while fewer than unroll
// cells would be covered.
if cells := jg * gc; cells < unroll {
jg *= unroll / cells
}
// ch returns the channel expression (cell + gc*i).
ch := func(cell int) cgen.Gen {
return cgen.Paren{
Inner: cgen.Add{
Expr1: il(cell),
Expr2: cgen.Mul{
Expr1: il(gc),
Expr2: i,
},
},
}
}
// ld declares wt as a masked load of elems floats from the source
// weights. The cell number k = pair*2+side is split into a group and a
// cell within the group; i and j add their own pitches.
ld := func(wt cgen.Gen, pair, side, elems int) cgen.Gen {
var (
from = a.weights1
k = pair*2 + side
group = k / gc
cell = k % gc
groupPitch = a.cellWeights1 * a.weightBytes1
cellPitch = a.fromCHW * a.weightBytes1
iPitch = gc * cellPitch
jPitch = jg * groupPitch
)
from = cgen.Add{
Expr1: from,
Expr2: cast(group*groupPitch + cell*cellPitch),
}
from = addr(from, cast(iPitch), i)
from = addr(from, cast(jPitch), j)
return cgen.Var{
Type: avx.M512, What: wt,
Init: avx.Mm512MaskzLoaduPs{
loMask(elems), from,
},
}
}
// mul scales wt by the batch-norm multiplier of its cell, or returns
// nil when there are no multipliers.
mul := func(wt cgen.Gen, pair, side int) cgen.Gen {
if bnMuls == nil {
return nil
}
var (
k = pair*2 + side
bnMul = bnMuls[k%gc]
)
return cgen.Assign{
Expr1: wt,
Expr2: avx.Mm512MulPs{
bnMul, wt,
},
}
}
// cvt declares half as the avx.Mm512CvtpsPh conversion of wt.
cvt := func(half, wt cgen.Gen) cgen.Gen {
return cgen.Var{
Type: avx.M256i, What: half,
Init: avx.Mm512CvtpsPh{
wt, avx.FroundToNearestIntNoExc,
},
}
}
// st stores the low elems 32-bit lanes of yield to the arranged
// weights at the position of the given pair.
st := func(yield cgen.Gen, pair, elems int) cgen.Gen {
var (
to = a.weights2
iPitch = a.stripBytes1
jPitch = jg * gc * a.cellBytes
)
to = cgen.Add{
Expr1: to,
Expr2: cast(pair * 2 * a.cellBytes),
}
to = addr(to, cast(iPitch), i)
to = addr(to, cast(jPitch), j)
return avx.Mm512MaskStoreuEpi32{
to, loMask(elems), yield,
}
}
// two handles a pair of cells: both are loaded, scaled and converted,
// the two 256-bit halves are combined into one 512-bit value and stored
// with all lanes enabled.
two := func(pair, elemsLo, elemsHi int) cgen.Stmts {
var (
wtLo = vb(a.name("wtLo"))
wtHi = vb(a.name("wtHi"))
halfLo = vb(a.name("halfLo"))
halfHi = vb(a.name("halfHi"))
yield = vb(a.name("yield"))
)
return cgen.Stmts{
cgen.Stmts{
ld(wtLo, pair, 0, elemsLo),
ld(wtHi, pair, 1, elemsHi),
},
cgen.Stmts{
mul(wtLo, pair, 0),
mul(wtHi, pair, 1),
},
cgen.Stmts{
cvt(halfLo, wtLo),
cvt(halfHi, wtHi),
},
cgen.Var{
Type: avx.M512i, What: yield,
Init: avx.Mm512Inserti64x4{
avx.Mm512Castsi256Si512{halfLo},
halfHi, il(1),
},
},
st(yield, pair, lanes),
}
}
// one handles a trailing unpaired cell; only the low half of the lanes
// is stored.
one := func(pair, elemsLo int) cgen.Stmts {
var (
wtLo = vb(a.name("wtLo"))
halfLo = vb(a.name("halfLo"))
yield = vb(a.name("yield"))
)
return cgen.Stmts{
ld(wtLo, pair, 0, elemsLo),
mul(wtLo, pair, 0),
cvt(halfLo, wtLo),
cgen.Var{
Type: avx.M512i, What: yield,
Init: avx.Mm512Castsi256Si512{halfLo},
},
st(yield, pair, lanes/2),
}
}
// layer4 emits the code for cells2 cells, of which the first cells1
// are full (a.cellWeights1 elements) and the rest are partial
// (a.cellWeights2 elements). Pairs are interleaved by mix in bundles of
// unroll/2.
layer4 := func(cells1, cells2 int) cgen.Gen {
var (
n1 = cells2 / 2
n2 = n1 + cells2%2
toMix = make([]cgen.Stmts, n2)
)
for pair := 0; pair < n1; pair++ {
var (
k = pair*2 + 1
elemsLo = a.cellWeights1
elemsHi = elemsLo
)
if k >= cells1 {
elemsHi = a.cellWeights2
if k-1 >= cells1 {
elemsLo = elemsHi
}
}
toMix[pair] = two(pair, elemsLo, elemsHi)
}
if n1 < n2 {
elemsLo := a.cellWeights1
if n1*2 >= cells1 {
elemsLo = a.cellWeights2
}
toMix[n1] = one(n1, elemsLo)
}
const bundle = unroll / 2
var (
bundles = ceilQuo(n2, bundle)
ret = make(cgen.Gens, bundles)
)
for x := range ret {
var (
first = x * bundle
past = min(first+bundle, n2)
)
ret[x] = mix(toMix[first:past])
}
return ret
}
// layer3 walks the groups of a strip: a for-loop over j doing jg full
// groups per step, then straight-line code (with j fixed at iters) for
// the remaining full groups and the partial group, if any.
layer3 := func() cgen.Gen {
var (
stmts = make(cgen.Stmts, 2)
iters = a.stripGroups1 / jg
after = a.stripGroups1 % jg
n1 = jg * gc
n2 = after * gc
n3 = n2
)
if a.stripGroups1 < a.stripGroups2 {
after++
n3 += gc
}
if iters > 0 {
j = vb(a.name("j"))
stmts[0] = cgen.For{
Init: cgen.Var{
Type: cgen.PtrdiffT, What: j,
Init: il(0),
},
Cond: cgen.CmpL{
Expr1: j,
Expr2: il(iters),
},
Post: cgen.IncPre{
Expr: j,
},
Body: layer4(n1, n1),
}
}
if after > 0 {
j = il(iters)
stmts[1] = layer4(n2, n3)
}
return stmts
}
// layer2 emits the weight code (layer3) followed by the bias code:
// per block of up to lanes cells, load the biases, apply each post
// batch-norm, and store them to the arranged buffer.
layer2 := func() cgen.Gen {
var (
parts = ceilQuo(gc, lanes)
toMix = make([]cgen.Stmts, parts)
)
for part := range toMix {
var (
bnCnt = a.BnPost
stmts = make(cgen.Stmts, 1+bnCnt*2+1)
bias = vb(a.name("bias"))
first = part * lanes
cnt = min(gc-first, lanes)
mask = loMask(cnt)
bnCh = ch(first)
)
offset := cgen.Add{
Expr1: cast(first * a.biasBytes),
Expr2: cgen.Mul{
Expr1: cast(gc * a.biasBytes),
Expr2: i,
},
}
from := cgen.Add{
Expr1: a.biases1,
Expr2: offset,
}
to := cgen.Add{
Expr1: a.biases2,
Expr2: offset,
}
stmts[0] = cgen.Var{
Type: avx.M512, What: bias,
Init: avx.Mm512MaskzLoaduPs{
mask, from,
},
}
for x := 0; x < bnCnt; x++ {
var (
bnPtr = a.bnPtrs[x]
bnMul = vb(a.name("bnMul"))
bnAdd = vb(a.name("bnAdd"))
)
stmts[1+x*2] = &bn.Load{
Ctx: a.bc,
Mas: bnPtr,
Channel: bnCh,
Mul: bnMul,
Add: bnAdd,
Cnt: cnt,
}
stmts[1+x*2+1] = &bn.Apply{
Ctx: a.bc,
Mul: bnMul,
Add: bnAdd,
To: bias,
}
}
stmts[1+bnCnt*2] = avx.Mm512MaskStoreuPs{
to, mask, bias,
}
toMix[part] = stmts
}
return cgen.Gens{
layer3(),
mix(toMix),
}
}
// layer1 prepares one batch-norm multiplier per cell when a.BnPost > 0
// (the product of the multipliers of all post batch-norms) and records
// them in bnMuls for mul to use, then emits layer2.
layer1 := func() cgen.Gen {
var (
bnPrep cgen.Gen
bnCnt = a.BnPost
)
if bnCnt > 0 {
bnMuls = make([]cgen.Gen, gc)
toMix := make([]cgen.Stmts, gc)
for cell := range toMix {
var (
stmts = make(cgen.Stmts, bnCnt*2)
bnCh = ch(cell)
)
for x := 0; x < bnCnt; x++ {
var (
bnPtr = a.bnPtrs[x]
bnMul = vb(a.name("bnMul"))
)
stmts[x*2] = &bn.Load{
Ctx: a.bc,
Mas: bnPtr,
Channel: bnCh,
Mul: bnMul,
}
if x == 0 {
bnMuls[cell] = bnMul
continue
}
prod := bnMuls[cell]
stmts[x*2+1] = cgen.Assign{
Expr1: prod,
Expr2: avx.Mm512MulPs{
prod, bnMul,
},
}
}
toMix[cell] = stmts
}
bnPrep = mix(toMix)
}
return cgen.Gens{
bnPrep,
layer2(),
}
}
return cgen.For{
Init: cgen.Var{
Type: cgen.PtrdiffT, What: i,
Init: il(0),
},
Cond: cgen.CmpL{
Expr1: i,
Expr2: il(a.strips),
},
Post: cgen.IncPre{
Expr: i,
},
Body: layer1(),
}
}

// m512BnPreSpecial builds the AVX-512 arrange kernel for the case where
// pre batch-norms exist and a.cellWeights1 is a multiple of a.fromHW (see
// m512), so every cell covers a whole number of input channels.
//
// The generated loops nest as i (strips) > j (blocks of up to jUnroll
// cells) > k (groups of a strip). For each weight vector the kernel
// accumulates weight*preAdd into a per-cell sum (later reduced by
// sumr.Pack and added to the bias), then multiplies the weight by the pre
// and post batch-norm multipliers, converts it with avx.Mm512CvtpsPh and
// stores it. The closures layer1..layer7 (outermost first) share the
// variables declared below.
func (a *arrange) m512BnPreSpecial() cgen.Gen {
var (
i = vb(a.name("i"))
gc = a.groupCells
jUnroll = 8
j cgen.Gen
cells int
postMuls []cgen.Gen
sums []cgen.Gen
k cgen.Gen
elems int
preMul1 cgen.Gen
preAdd1 cgen.Gen
)
// Without post batch-norms, twice as many cells are handled per j step.
if a.BnPost == 0 {
jUnroll = 16
}
// ch returns the output channel expression (cell + gc*i + jUnroll*j).
ch := func(cell int) cgen.Gen {
return cgen.Paren{
Inner: cgen.Add{
Expr1: il(cell),
Expr2: cgen.Add{
Expr1: cgen.Mul{
Expr1: il(gc),
Expr2: i,
},
Expr2: cgen.Mul{
Expr1: il(jUnroll),
Expr2: j,
},
},
},
}
}
// ld declares wt as a masked load of elems floats from the source
// weights of cell pair*2+side, at group k.
ld := func(wt cgen.Gen, pair, side int) cgen.Gen {
var (
from = a.weights1
cell = pair*2 + side
cellPitch = a.fromCHW * a.weightBytes1
iPitch = gc * cellPitch
jPitch = jUnroll * cellPitch
kPitch = a.cellWeights1 * a.weightBytes1
)
from = cgen.Add{
Expr1: from,
Expr2: cast(cell * cellPitch),
}
from = addr(from, cast(iPitch), i)
from = addr(from, cast(jPitch), j)
from = addr(from, cast(kPitch), k)
return cgen.Var{
Type: avx.M512, What: wt,
Init: avx.Mm512MaskzLoaduPs{
loMask(elems), from,
},
}
}
// madd accumulates wt*preAdd1 into the cell's sum.
madd := func(wt cgen.Gen, pair, side int) cgen.Gen {
var (
cell = pair*2 + side
sum = sums[cell]
)
return cgen.Assign{
Expr1: sum,
Expr2: avx.Mm512FmaddPs{
wt, preAdd1, sum,
},
}
}
// muls scales wt by the cell's post multiplier (if any) and then by
// preMul1.
muls := func(wt cgen.Gen, pair, side int) cgen.Gen {
inner := wt
if postMuls != nil {
cell := pair*2 + side
inner = avx.Mm512MulPs{
wt, postMuls[cell],
}
}
return cgen.Assign{
Expr1: wt,
Expr2: avx.Mm512MulPs{
inner, preMul1,
},
}
}
// cvt declares half as the avx.Mm512CvtpsPh conversion of wt.
cvt := func(half, wt cgen.Gen) cgen.Gen {
return cgen.Var{
Type: avx.M256i, What: half,
Init: avx.Mm512CvtpsPh{
wt, avx.FroundToNearestIntNoExc,
},
}
}
// st stores yield to the arranged weights; have is the number of cells
// it contains (1 or 2), each occupying eight 32-bit lanes.
st := func(yield cgen.Gen, pair, have int) cgen.Gen {
var (
to = a.weights2
iPitch = a.stripBytes1
jPitch = jUnroll * a.cellBytes
kPitch = gc * a.cellBytes
mask = loMask(have * 8)
)
to = cgen.Add{
Expr1: to,
Expr2: cast(pair * 2 * a.cellBytes),
}
to = addr(to, cast(iPitch), i)
to = addr(to, cast(jPitch), j)
to = addr(to, cast(kPitch), k)
return avx.Mm512MaskStoreuEpi32{
to, mask, yield,
}
}
// two processes a pair of cells and stores them with one 512-bit store.
two := func(pair int) cgen.Stmts {
var (
wtLo = vb(a.name("wtLo"))
wtHi = vb(a.name("wtHi"))
halfLo = vb(a.name("halfLo"))
halfHi = vb(a.name("halfHi"))
yield = vb(a.name("yield"))
)
return cgen.Stmts{
cgen.Stmts{
ld(wtLo, pair, 0),
ld(wtHi, pair, 1),
},
cgen.Stmts{
madd(wtLo, pair, 0),
madd(wtHi, pair, 1),
},
cgen.Stmts{
muls(wtLo, pair, 0),
muls(wtHi, pair, 1),
},
cgen.Stmts{
cvt(halfLo, wtLo),
cvt(halfHi, wtHi),
},
cgen.Var{
Type: avx.M512i, What: yield,
Init: avx.Mm512Inserti64x4{
avx.Mm512Castsi256Si512{halfLo},
halfHi, il(1),
},
},
st(yield, pair, 2),
}
}
// one processes a trailing unpaired cell.
one := func(pair int) cgen.Stmts {
var (
wtLo = vb(a.name("wtLo"))
halfLo = vb(a.name("halfLo"))
yield = vb(a.name("yield"))
)
return cgen.Stmts{
ld(wtLo, pair, 0),
madd(wtLo, pair, 0),
muls(wtLo, pair, 0),
cvt(halfLo, wtLo),
cgen.Var{
Type: avx.M512i, What: yield,
Init: avx.Mm512Castsi256Si512{halfLo},
},
st(yield, pair, 1),
}
}
// layer7 emits the per-cell code for the current block of cells, with
// pairs interleaved by mix in bundles of four.
layer7 := func() cgen.Gen {
var (
n1 = cells / 2
n2 = n1 + cells%2
toMix = make([]cgen.Stmts, n2)
)
for pair := 0; pair < n1; pair++ {
toMix[pair] = two(pair)
}
if n1 < n2 {
toMix[n1] = one(n1)
}
const bundle = 4
var (
n3 = ceilQuo(n2, bundle)
ret = make(cgen.Gens, n3)
)
for x := range ret {
var (
first = x * bundle
past = min(first+bundle, n2)
)
ret[x] = mix(toMix[first:past])
}
return ret
}
// layer6 loads the pre batch-norm parameters for the channels covered
// by group k and combines all pre batch-norms into preMul1/preAdd1,
// then emits layer7. When the group spans several channels, bn.Load is
// given the channel count and a spread of a.fromHW.
layer6 := func() cgen.Gen {
var (
preCnt = a.BnPre
stmts = make(cgen.Stmts, preCnt*3)
chans = elems / a.fromHW
)
preCh := cgen.Mul{
Expr1: il(a.cellWeights1 / a.fromHW),
Expr2: k,
}
for x := 0; x < preCnt; x++ {
var (
prePtr = a.bnPtrs[x]
preMul2 = vb(a.name("preMul"))
preAdd2 = vb(a.name("preAdd"))
)
if chans == 1 {
stmts[x*3] = &bn.Load{
Ctx: a.bc,
Mas: prePtr,
Channel: preCh,
Mul: preMul2,
Add: preAdd2,
}
} else {
stmts[x*3] = &bn.Load{
Ctx: a.bc,
Mas: prePtr,
Channel: preCh,
Mul: preMul2,
Add: preAdd2,
Cnt: chans,
Spread: a.fromHW,
}
}
if x == 0 {
preMul1 = preMul2
preAdd1 = preAdd2
continue
}
stmts[x*3+1] = cgen.Assign{
Expr1: preMul1,
Expr2: avx.Mm512MulPs{
preMul1, preMul2,
},
}
stmts[x*3+2] = &bn.Apply{
Ctx: a.bc,
Mul: preMul2,
Add: preAdd2,
To: preAdd1,
}
}
return cgen.Gens{
stmts,
layer7(),
}
}
// layer5 walks the groups of a strip: a for-loop over k for the full
// groups, then straight-line code for the partial group, if any.
layer5 := func() cgen.Gen {
var (
stmts = make(cgen.Stmts, 2)
iters = a.stripGroups1
)
if iters > 0 {
k = vb(a.name("k"))
elems = a.cellWeights1
stmts[0] = cgen.For{
Init: cgen.Var{
Type: cgen.PtrdiffT, What: k,
Init: il(0),
},
Cond: cgen.CmpL{
Expr1: k,
Expr2: il(iters),
},
Post: cgen.IncPre{
Expr: k,
},
Body: layer6(),
}
}
if iters < a.stripGroups2 {
k = il(iters)
elems = a.cellWeights2
stmts[1] = layer6()
}
return stmts
}
// layer4 emits the weight code (layer5) followed by the bias code:
// reduce the sums with sumr.Pack, add sums[0] to the loaded biases,
// apply each post batch-norm, and store the result.
layer4 := func() cgen.Gen {
var (
postCnt = a.BnPost
stmts = make(cgen.Stmts, 3+postCnt*2+1)
bias = vb(a.name("bias"))
mask = loMask(cells)
iPitch = cast(gc * a.biasBytes)
jPitch = cast(jUnroll * a.biasBytes)
from = addr(a.biases1, iPitch, i)
to = addr(a.biases2, iPitch, i)
postCh = ch(0)
)
stmts[0] = &sumr.Pack{
Platform: a.platform,
Nms: a.nms,
Vars: sums,
}
stmts[1] = cgen.Var{
Type: avx.M512, What: bias,
Init: avx.Mm512MaskzLoaduPs{
mask, addr(from, jPitch, j),
},
}
stmts[2] = cgen.Assign{
Expr1: bias,
Expr2: avx.Mm512AddPs{
sums[0], bias,
},
}
for x := 0; x < postCnt; x++ {
var (
postPtr = a.bnPtrs[a.BnPre+x]
postMul = vb(a.name("postMul"))
postAdd = vb(a.name("postAdd"))
)
stmts[3+x*2] = &bn.Load{
Ctx: a.bc,
Mas: postPtr,
Channel: postCh,
Mul: postMul,
Add: postAdd,
Cnt: cells,
}
stmts[3+x*2+1] = &bn.Apply{
Ctx: a.bc,
Mul: postMul,
Add: postAdd,
To: bias,
}
}
stmts[3+postCnt*2] = avx.Mm512MaskStoreuPs{
addr(to, jPitch, j), mask, bias,
}
return cgen.Gens{
layer5(),
stmts,
}
}
// layer3 declares one zero-initialised sum per cell, then emits layer4.
layer3 := func() cgen.Gen {
sums = make([]cgen.Gen, cells)
stmts := make(cgen.Stmts, cells)
for cell := range stmts {
sum := vb(a.name("sum"))
sums[cell] = sum
stmts[cell] = cgen.Var{
Type: avx.M512, What: sum,
Init: avx.Mm512SetzeroPs,
}
}
return cgen.Gens{
stmts,
layer4(),
}
}
// layer2 prepares one post batch-norm multiplier per cell when
// a.BnPost > 0 (the product over all post batch-norms) and records
// them in postMuls, then emits layer3.
layer2 := func() cgen.Gen {
var (
postPrep cgen.Gen
postCnt = a.BnPost
)
if postCnt > 0 {
postMuls = make([]cgen.Gen, cells)
toMix := make([]cgen.Stmts, cells)
for cell := range toMix {
var (
stmts = make(cgen.Stmts, postCnt*2)
postCh = ch(cell)
)
for x := 0; x < postCnt; x++ {
var (
postPtr = a.bnPtrs[a.BnPre+x]
postMul = vb(a.name("postMul"))
)
stmts[x*2] = &bn.Load{
Ctx: a.bc,
Mas: postPtr,
Channel: postCh,
Mul: postMul,
}
if x == 0 {
postMuls[cell] = postMul
continue
}
prod := postMuls[cell]
stmts[x*2+1] = cgen.Assign{
Expr1: prod,
Expr2: avx.Mm512MulPs{
prod, postMul,
},
}
}
toMix[cell] = stmts
}
postPrep = mix(toMix)
}
return cgen.Gens{
postPrep,
layer3(),
}
}
// layer1 splits the gc cells of a group into blocks of jUnroll: a
// for-loop over j for the full blocks, then straight-line code for the
// remaining cells, if any.
layer1 := func() cgen.Gen {
var (
stmts = make(cgen.Stmts, 2)
iters = gc / jUnroll
after = gc % jUnroll
)
if iters > 0 {
j = vb(a.name("j"))
cells = jUnroll
stmts[0] = cgen.For{
Init: cgen.Var{
Type: cgen.PtrdiffT, What: j,
Init: il(0),
},
Cond: cgen.CmpL{
Expr1: j,
Expr2: il(iters),
},
Post: cgen.IncPre{
Expr: j,
},
Body: layer2(),
}
}
if after > 0 {
j = il(iters)
cells = after
stmts[1] = layer2()
}
return stmts
}
return cgen.For{
Init: cgen.Var{
Type: cgen.PtrdiffT, What: i,
Init: il(0),
},
Cond: cgen.CmpL{
Expr1: i,
Expr2: il(a.strips),
},
Post: cgen.IncPre{
Expr: i,
},
Body: layer1(),
}
}

// m512BnPreGeneral builds the AVX-512 arrange kernel for the case where
// pre batch-norms exist and a.cellWeights1 is not a multiple of a.fromHW
// (see m512), so channel boundaries do not line up with cell boundaries.
//
// The generated loops nest as i (strips) > j (blocks of up to jUnroll
// cells) > k (blocks of kUnroll input channels) > l (groups). For every
// run of groups the per-lane pre batch-norm vectors preMul1/preAdd1 are
// assembled from per-channel values with masked moves. As in the special
// variant, weight*preAdd is accumulated into per-cell sums that are later
// added to the bias, and weights are scaled, converted with
// avx.Mm512CvtpsPh and stored. The closures layer1..layer9 (outermost
// first) share the variables declared below.
func (a *arrange) m512BnPreGeneral() cgen.Gen {
var (
i = vb(a.name("i"))
gc = a.groupCells
jUnroll = 8
j cgen.Gen
cells int
postMuls []cgen.Gen
sums []cgen.Gen
cw = a.cellWeights1
kUnroll int
k cgen.Gen
preChans int
group int
run int
preMul1 cgen.Gen
preAdd1 cgen.Gen
preMul2 cgen.Gen
preAdd2 cgen.Gen
elems int
l cgen.Gen
)
// postChan returns the output channel expression
// (cell + gc*i + jUnroll*j).
postChan := func(cell int) cgen.Gen {
return cgen.Paren{
Inner: cgen.Add{
Expr1: il(cell),
Expr2: cgen.Add{
Expr1: cgen.Mul{
Expr1: il(gc),
Expr2: i,
},
Expr2: cgen.Mul{
Expr1: il(jUnroll),
Expr2: j,
},
},
},
}
}
// ld declares wt as a masked load of elems floats from the source
// weights of cell pair*2+side, at channel block k and group l.
ld := func(wt cgen.Gen, pair, side int) cgen.Gen {
var (
from = a.weights1
cell = pair*2 + side
cellPitch = a.fromCHW * a.weightBytes1
iPitch = gc * cellPitch
jPitch = jUnroll * cellPitch
kPitch = kUnroll * a.fromHW * a.weightBytes1
lPitch = cw * a.weightBytes1
)
from = cgen.Add{
Expr1: from,
Expr2: cast(cell * cellPitch),
}
from = addr(from, cast(iPitch), i)
from = addr(from, cast(jPitch), j)
from = addr(from, cast(kPitch), k)
from = addr(from, cast(lPitch), l)
return cgen.Var{
Type: avx.M512, What: wt,
Init: avx.Mm512MaskzLoaduPs{
loMask(elems), from,
},
}
}
// madd accumulates wt*preAdd1 into the cell's sum.
madd := func(wt cgen.Gen, pair, side int) cgen.Gen {
var (
cell = pair*2 + side
sum = sums[cell]
)
return cgen.Assign{
Expr1: sum,
Expr2: avx.Mm512FmaddPs{
wt, preAdd1, sum,
},
}
}
// muls scales wt by preMul1, first multiplied by the cell's post
// multiplier when post batch-norms exist.
muls := func(wt cgen.Gen, pair, side int) cgen.Gen {
inner := preMul1
if postMuls != nil {
var (
cell = pair*2 + side
postMul = postMuls[cell]
)
inner = avx.Mm512MulPs{
postMul, preMul1,
}
}
return cgen.Assign{
Expr1: wt,
Expr2: avx.Mm512MulPs{
wt, inner,
},
}
}
// cvt declares half as the avx.Mm512CvtpsPh conversion of wt.
cvt := func(half, wt cgen.Gen) cgen.Gen {
return cgen.Var{
Type: avx.M256i, What: half,
Init: avx.Mm512CvtpsPh{
wt, avx.FroundToNearestIntNoExc,
},
}
}
// st stores yield to the arranged weights; have is the number of cells
// it contains (1 or 2), each occupying eight 32-bit lanes. One k step
// advances by the number of groups that kUnroll channels occupy.
st := func(yield cgen.Gen, pair, have int) cgen.Gen {
var (
to = a.weights2
iPitch = a.stripBytes1
jPitch = jUnroll * a.cellBytes
groupPitch = gc * a.cellBytes
kGroups = kUnroll * a.fromHW / cw
kPitch = kGroups * groupPitch
lPitch = groupPitch
mask = loMask(have * 8)
)
to = cgen.Add{
Expr1: to,
Expr2: cast(pair * 2 * a.cellBytes),
}
to = addr(to, cast(iPitch), i)
to = addr(to, cast(jPitch), j)
to = addr(to, cast(kPitch), k)
to = addr(to, cast(lPitch), l)
return avx.Mm512MaskStoreuEpi32{
to, mask, yield,
}
}
// two processes a pair of cells and stores them with one 512-bit store.
two := func(pair int) cgen.Stmts {
var (
wtLo = vb(a.name("wtLo"))
wtHi = vb(a.name("wtHi"))
halfLo = vb(a.name("halfLo"))
halfHi = vb(a.name("halfHi"))
yield = vb(a.name("yield"))
)
return cgen.Stmts{
cgen.Stmts{
ld(wtLo, pair, 0),
ld(wtHi, pair, 1),
},
cgen.Stmts{
madd(wtLo, pair, 0),
madd(wtHi, pair, 1),
},
cgen.Stmts{
muls(wtLo, pair, 0),
muls(wtHi, pair, 1),
},
cgen.Stmts{
cvt(halfLo, wtLo),
cvt(halfHi, wtHi),
},
cgen.Var{
Type: avx.M512i, What: yield,
Init: avx.Mm512Inserti64x4{
avx.Mm512Castsi256Si512{halfLo},
halfHi, il(1),
},
},
st(yield, pair, 2),
}
}
// one processes a trailing unpaired cell.
one := func(pair int) cgen.Stmts {
var (
wtLo = vb(a.name("wtLo"))
halfLo = vb(a.name("halfLo"))
yield = vb(a.name("yield"))
)
return cgen.Stmts{
ld(wtLo, pair, 0),
madd(wtLo, pair, 0),
muls(wtLo, pair, 0),
cvt(halfLo, wtLo),
cgen.Var{
Type: avx.M512i, What: yield,
Init: avx.Mm512Castsi256Si512{halfLo},
},
st(yield, pair, 1),
}
}
// layer9 emits the per-cell code for the current block of cells, with
// pairs interleaved by mix in bundles of two.
layer9 := func() cgen.Gen {
var (
n1 = cells / 2
n2 = n1 + cells%2
toMix = make([]cgen.Stmts, n2)
)
for pair := 0; pair < n1; pair++ {
toMix[pair] = two(pair)
}
if n1 < n2 {
toMix[n1] = one(n1)
}
const bundle = 2
var (
n3 = ceilQuo(n2, bundle)
ret = make(cgen.Gens, n3)
)
for x := range ret {
var (
first = x * bundle
past = min(first+bundle, n2)
)
ret[x] = mix(toMix[first:past])
}
return ret
}
// layer8 covers the groups [group, group+run): straight-line code with
// l fixed when the run is a single group, a for-loop over l otherwise.
layer8 := func() cgen.Gen {
if run == 1 {
l = il(group)
return layer9()
}
l = vb(a.name("l"))
return cgen.For{
Init: cgen.Var{
Type: cgen.PtrdiffT, What: l,
Init: il(group),
},
Cond: cgen.CmpL{
Expr1: l,
Expr2: il(group + run),
},
Post: cgen.IncPre{
Expr: l,
},
Body: layer9(),
}
}
// layer7 assembles preMul1/preAdd1 for the group at index group and
// then emits layer8. Starting from the channel that contains the
// group's first weight, it walks channels until the cw lanes are filled
// or the channels of this k block run out. Parameters are loaded when a
// channel starts inside the group (shift != 0) or exactly at its start
// (remain == a.fromHW); otherwise the values left in preMul2/preAdd2 by
// an earlier group are reused. Values for lanes at shift and above are
// merged in with a masked move. elems ends up as the number of valid
// lanes.
layer7 := func() cgen.Gen {
var (
prePrep cgen.Gens
shift = 0
before = group * cw
preChan = before / a.fromHW
seen = before % a.fromHW
remain = a.fromHW - seen
)
for {
var do cgen.Gen
if shift != 0 || remain == a.fromHW {
var (
preCnt = a.BnPre
stmts = make(cgen.Stmts, preCnt*3)
)
preCh := cgen.Paren{
Inner: cgen.Add{
Expr1: il(preChan),
Expr2: cgen.Mul{
Expr1: cast(kUnroll),
Expr2: k,
},
},
}
for x := 0; x < preCnt; x++ {
var (
prePtr = a.bnPtrs[x]
preMul3 = vb(a.name("preMul"))
preAdd3 = vb(a.name("preAdd"))
)
stmts[x*3] = &bn.Load{
Ctx: a.bc,
Mas: prePtr,
Channel: preCh,
Mul: preMul3,
Add: preAdd3,
}
if x == 0 {
preMul2 = preMul3
preAdd2 = preAdd3
continue
}
stmts[x*3+1] = cgen.Assign{
Expr1: preMul2,
Expr2: avx.Mm512MulPs{
preMul2, preMul3,
},
}
stmts[x*3+2] = &bn.Apply{
Ctx: a.bc,
Mul: preMul3,
Add: preAdd3,
To: preAdd2,
}
}
do = stmts
}
if shift == 0 {
preMul1 = preMul2
preAdd1 = preAdd2
} else {
var (
n = min(cw-shift, remain)
bits = 1<<uint(n) - 1
mask = il(bits << uint(shift))
)
do = cgen.Stmts{
do,
cgen.Assign{
Expr1: preMul1,
Expr2: avx.Mm512MaskMovPs{
preMul1, mask, preMul2,
},
},
cgen.Assign{
Expr1: preAdd1,
Expr2: avx.Mm512MaskMovPs{
preAdd1, mask, preAdd2,
},
},
}
}
if do != nil {
prePrep = append(prePrep, do)
}
if shift += remain; shift >= cw {
elems = cw
break
}
if preChan++; preChan == preChans {
elems = shift
break
}
remain = a.fromHW
}
return cgen.Stmts{
prePrep,
layer8(),
}
}
// layer6 walks the groups that cover preChans channels. Consecutive
// groups lying entirely within the same channel are handled as one run
// (run = max(remain/cw, 1)).
layer6 := func() cgen.Gen {
var (
ret cgen.Gens
n = ceilQuo(preChans*a.fromHW, cw)
)
for group = 0; group < n; group += run {
var (
before = group * cw
seen = before % a.fromHW
remain = a.fromHW - seen
)
run = max(remain/cw, 1)
ret = append(ret, layer7())
}
return ret
}
// layer5 picks kUnroll as the smallest power of two for which
// kUnroll*a.fromHW is a multiple of cw (cw must itself be a power of
// two), then walks the input channels in blocks of kUnroll: a for-loop
// over k for the full blocks, then straight-line code for the rest.
layer5 := func() cgen.Gen {
if cw&(cw-1) != 0 {
panic("bug")
}
kUnroll = 1
for n := a.fromHW; n%cw != 0; n *= 2 {
kUnroll *= 2
}
var (
stmts = make(cgen.Stmts, 2)
iters = a.FromC / kUnroll
after = a.FromC % kUnroll
)
if iters > 0 {
k = vb(a.name("k"))
preChans = kUnroll
stmts[0] = cgen.For{
Init: cgen.Var{
Type: cgen.PtrdiffT, What: k,
Init: il(0),
},
Cond: cgen.CmpL{
Expr1: k,
Expr2: il(iters),
},
Post: cgen.IncPre{
Expr: k,
},
Body: layer6(),
}
}
if after > 0 {
k = il(iters)
preChans = after
stmts[1] = layer6()
}
return stmts
}
// layer4 emits the weight code (layer5) followed by the bias code:
// reduce the sums with sumr.Pack, add sums[0] to the loaded biases,
// apply each post batch-norm, and store the result.
layer4 := func() cgen.Gen {
var (
postCnt = a.BnPost
stmts = make(cgen.Stmts, 3+postCnt*2+1)
bias = vb(a.name("bias"))
mask = loMask(cells)
iPitch = cast(gc * a.biasBytes)
jPitch = cast(jUnroll * a.biasBytes)
from = addr(a.biases1, iPitch, i)
to = addr(a.biases2, iPitch, i)
postCh = postChan(0)
)
stmts[0] = &sumr.Pack{
Platform: a.platform,
Nms: a.nms,
Vars: sums,
}
stmts[1] = cgen.Var{
Type: avx.M512, What: bias,
Init: avx.Mm512MaskzLoaduPs{
mask, addr(from, jPitch, j),
},
}
stmts[2] = cgen.Assign{
Expr1: bias,
Expr2: avx.Mm512AddPs{
sums[0], bias,
},
}
for x := 0; x < postCnt; x++ {
var (
postPtr = a.bnPtrs[a.BnPre+x]
postMul = vb(a.name("postMul"))
postAdd = vb(a.name("postAdd"))
)
stmts[3+x*2] = &bn.Load{
Ctx: a.bc,
Mas: postPtr,
Channel: postCh,
Mul: postMul,
Add: postAdd,
Cnt: cells,
}
stmts[3+x*2+1] = &bn.Apply{
Ctx: a.bc,
Mul: postMul,
Add: postAdd,
To: bias,
}
}
stmts[3+postCnt*2] = avx.Mm512MaskStoreuPs{
addr(to, jPitch, j), mask, bias,
}
return cgen.Gens{
layer5(),
stmts,
}
}
// layer3 declares one zero-initialised sum per cell, then emits layer4.
layer3 := func() cgen.Gen {
sums = make([]cgen.Gen, cells)
stmts := make(cgen.Stmts, cells)
for cell := range stmts {
sum := vb(a.name("sum"))
sums[cell] = sum
stmts[cell] = cgen.Var{
Type: avx.M512, What: sum,
Init: avx.Mm512SetzeroPs,
}
}
return cgen.Gens{
stmts,
layer4(),
}
}
// layer2 prepares one post batch-norm multiplier per cell when
// a.BnPost > 0 (the product over all post batch-norms) and records
// them in postMuls, then emits layer3.
layer2 := func() cgen.Gen {
var (
postPrep cgen.Gen
postCnt = a.BnPost
)
if postCnt > 0 {
postMuls = make([]cgen.Gen, cells)
toMix := make([]cgen.Stmts, cells)
for cell := range toMix {
var (
stmts = make(cgen.Stmts, postCnt*2)
postCh = postChan(cell)
)
for x := 0; x < postCnt; x++ {
var (
postPtr = a.bnPtrs[a.BnPre+x]
postMul = vb(a.name("postMul"))
)
stmts[x*2] = &bn.Load{
Ctx: a.bc,
Mas: postPtr,
Channel: postCh,
Mul: postMul,
}
if x == 0 {
postMuls[cell] = postMul
continue
}
prod := postMuls[cell]
stmts[x*2+1] = cgen.Assign{
Expr1: prod,
Expr2: avx.Mm512MulPs{
prod, postMul,
},
}
}
toMix[cell] = stmts
}
postPrep = mix(toMix)
}
return cgen.Gens{
postPrep,
layer3(),
}
}
// layer1 splits the gc cells of a group into blocks of jUnroll: a
// for-loop over j for the full blocks, then straight-line code for the
// remaining cells, if any.
layer1 := func() cgen.Gen {
var (
stmts = make(cgen.Stmts, 2)
iters = gc / jUnroll
after = gc % jUnroll
)
if iters > 0 {
j = vb(a.name("j"))
cells = jUnroll
stmts[0] = cgen.For{
Init: cgen.Var{
Type: cgen.PtrdiffT, What: j,
Init: il(0),
},
Cond: cgen.CmpL{
Expr1: j,
Expr2: il(iters),
},
Post: cgen.IncPre{
Expr: j,
},
Body: layer2(),
}
}
if after > 0 {
j = il(iters)
cells = after
stmts[1] = layer2()
}
return stmts
}
return cgen.For{
Init: cgen.Var{
Type: cgen.PtrdiffT, What: i,
Init: il(0),
},
Cond: cgen.CmpL{
Expr1: i,
Expr2: il(a.strips),
},
Post: cgen.IncPre{
Expr: i,
},
Body: layer1(),
}
}

// Apply generates the call that runs a fully-connected layer. Prep must
// run before Append, since it fills in layout and callerName.
type Apply struct {
*Ctx
// ToC, FromC, FromH and FromW are passed to newLayout in that order.
ToC int
FromC int
FromH int
FromW int
// Ops is part of the deduplication signature computed in Prep.
// NOTE(review): presumably the operations fused onto the layer's
// output — confirm against the apply kernel.
Ops []mod.Op
// Team and Tensors are forwarded to the generated caller function.
Team cgen.Gen
Tensors []cgen.Gen
*layout
// callerName is the name of the generated function that Append calls.
callerName string
}

// Prep computes the layout and decides which function Append will call.
// If an apply function with the same signature (dimensions, ops and
// tensor count) was already generated, its name is reused and nil is
// returned; otherwise a new name is registered and the generator for the
// new function definition is returned.
func (a *Apply) Prep() cgen.Gen {
	const affix = "Apply"
	a.layout = a.newLayout(a.ToC, a.FromC, a.FromH, a.FromW)
	sig := fmt.Sprint(
		affix, " ",
		a.ToC, a.FromC, a.FromH, a.FromW,
		a.Ops, len(a.Tensors),
	)
	prior, seen := a.dedup[sig]
	if seen {
		a.callerName = prior
		return nil
	}
	fresh := a.name(a.prefix + affix)
	a.callerName = fresh
	a.dedup[sig] = fresh
	return cgen.Gens{
		&apply{Apply: a},
		cgen.Newline,
	}
}

// Append emits two statements onto to: the declaration of a local
// array initialized with a.Tensors, and a call of the function chosen
// by Prep with a.Team and that array as arguments.
func (a *Apply) Append(to []byte) []byte {
	arr := vb(a.name("tensors"))
	elems := cgen.CommaLines(a.Tensors)
	decl := cgen.Var{
		Type: cgen.PtrChar,
		What: cgen.Elem{Arr: arr},
		Init: cgen.Brace{Inner: elems},
	}
	call := cgen.Call{
		Func: vb(a.callerName),
		Args: cgen.CommaSpaced{a.Team, arr},
	}
	return cgen.Stmts{decl, call}.Append(to)
}

// apply generates the definition behind an Apply: a threaded caller
// function plus the callee that runs the kernel.
type apply struct {
	*Apply

	// Task partitioning, computed by Append: tile strips per full task,
	// tiles full tasks, scrap leftover strips, and the task counts
	// hull1 and hull2.
	tile, tiles, scrap, hull1, hull2 int

	calleeName string // Name of the callee function; set by Append.

	wtPtr, biasPtr cgen.Gen   // Pointer variables declared by ptrs.
	seq            []cgen.Gen // Data and batch-norm pointers, in the order ptrs declares them.

	// Parameters of the next kernel() call; calleeFunc sets them before
	// each call.
	strips, groupCells int
}

// Append splits the strips into tasks, then emits the callee function
// followed by the static caller function that runs the callee through
// threader.Do over a.hull2 tasks.
func (a *apply) Append(to []byte) []byte {
	team := vb(a.name("team"))
	tensors := vb(a.name("tensors"))
	if a.platform != raw.AVX512Float32 {
		panic("bug")
	}
	const threadVecs = 512
	stripVecs := a.stripGroups2 * a.groupCells1
	a.tile = ceilQuo(threadVecs, stripVecs)
	a.tiles = a.strips1 / a.tile
	a.scrap = a.strips1 % a.tile
	// hull1 counts the tasks covering strips1 (full tiles plus an
	// optional scrap task); hull2 adds one more task when strips2
	// exceeds strips1.
	a.hull1 = a.tiles + btoi(a.scrap > 0)
	a.hull2 = a.hull1 + btoi(a.strips1 < a.strips2)
	a.calleeName = a.name(a.callerName + "Callee")
	callee := a.calleeFunc()
	params := cgen.CommaSpaced{
		cgen.Param{Type: a.tc.PtrTeam, What: team},
		cgen.Param{Type: cgen.PtrPtrChar, What: tensors},
	}
	caller := cgen.StaticFuncDef{
		ReturnType: cgen.Void,
		Name:       a.callerName,
		Params:     params,
		Body: &threader.Do{
			Ctx:    a.tc,
			Callee: vb(a.calleeName),
			Any:    tensors,
			Hull:   []cgen.Gen{il(a.hull2)},
			Team:   team,
		},
	}
	return cgen.Gens{
		callee,
		cgen.Newline,
		caller,
	}.Append(to)
}

// calleeFunc builds the callee: it reads the tensors array and the
// task index, declares the pointers, and then emits up to three kernel
// variants (full tile, scrap, and the final single strip that uses
// groupCells2). Every variant but the last one present is guarded by a
// task-index comparison and followed by a return.
func (a *apply) calleeFunc() cgen.Gen {
	tensors := vb(a.name("tensors"))
	t := vb(a.name("t"))
	callee := &threader.Callee{
		Ctx:  a.tc,
		Name: a.calleeName,
		Task: vb(a.name("task")),
		Pt:   vb(a.name("pt")),
	}
	// Slots 3..5 stay nil unless the matching variant exists.
	body := cgen.Stmts{
		0: cgen.Var{
			Type: cgen.PtrPtrChar, What: tensors,
			Init: callee.Any(),
		},
		1: cgen.Var{
			Type: cgen.PtrdiffT, What: t,
			Init: cgen.Elem{Arr: callee.Pt, Idx: il(0)},
		},
		2: a.ptrs(tensors, t),
		5: nil,
	}
	// guarded emits the kernel for the current a.strips/a.groupCells.
	// When bound is below a.hull2 the kernel is wrapped so that only
	// tasks with index < bound run it, after which they return.
	guarded := func(bound int) cgen.Gen {
		kern := a.kernel()
		if bound >= a.hull2 {
			return kern
		}
		return cgen.If{
			Cond: cgen.CmpL{Expr1: t, Expr2: il(bound)},
			Then: cgen.Stmts{kern, cgen.Return{}},
		}
	}
	if a.tiles > 0 {
		a.strips, a.groupCells = a.tile, a.groupCells1
		body[3] = guarded(a.tiles)
	}
	if a.hull1 > a.tiles {
		a.strips, a.groupCells = a.scrap, a.groupCells1
		body[4] = guarded(a.hull1)
	}
	if a.hull2 > a.hull1 {
		a.strips, a.groupCells = 1, a.groupCells2
		body[5] = a.kernel()
	}
	return callee.Func(body)
}

// ptrs emits the declarations of the pointer variables used by the
// kernel. tensors is the expression for the tensor pointer array and t
// the expression for the task index. As a side effect it sets a.wtPtr
// and a.biasPtr and appends every data and batch-norm pointer variable
// to a.seq in declaration order.
func (a *apply) ptrs(tensors, t cgen.Gen) cgen.Gen {
var (
stmts cgen.Stmts
// s multiplies every pitch below; it stays t unless a.tile > 1.
s = t
tensorIdx = 0
wtPitch = cast(a.stripBytes1)
gc = a.groupCells1
biasPitch = cast(gc * a.biasBytes)
datPitch = cast(gc * a.datBytes)
)
if a.tile > 1 {
// Declare s = a.tile*t, with a correction for one particular task.
s = vb(a.name("s"))
var expr cgen.Gen = cgen.Mul{
Expr1: il(a.tile),
Expr2: t,
}
// If task index a.tiles+1 exists (it is below a.hull2), that task
// gets a.tile-a.scrap subtracted from a.tile*t.
if i := a.tiles + 1; i < a.hull2 {
fix := cgen.Paren{
Inner: cgen.Ternary{
Cond: cgen.CmpE{
Expr1: t,
Expr2: il(i),
},
Then: il(a.tile - a.scrap),
Else: il(0),
},
}
expr = cgen.Sub{
Expr1: expr,
Expr2: fix,
}
}
stmts = append(stmts, cgen.Var{
Type: cgen.PtrdiffT, What: s,
Init: expr,
})
}
// tensor returns tensors[tensorIdx] and advances tensorIdx.
tensor := func() cgen.Gen {
i := tensorIdx
tensorIdx++
return cgen.Elem{
Arr: tensors,
Idx: il(i),
}
}
// The first tensor is the base for both the weight pointer and, at
// byte offset a.biasOffset, the bias pointer.
var (
arranged1 = tensor()
arranged2 = cgen.Add{
Expr1: arranged1,
Expr2: cast(a.biasOffset),
}
)
a.wtPtr = vb(a.name("wtPtr"))
stmts = append(stmts, cgen.Var{
Type: cgen.RestrictPtrChar, What: a.wtPtr,
Init: addr(arranged1, wtPitch, s),
})
a.biasPtr = vb(a.name("biasPtr"))
stmts = append(stmts, cgen.Var{
Type: cgen.RestrictPtrChar, What: a.biasPtr,
Init: addr(arranged2, biasPitch, s),
})
// dp declares one data pointer from the next tensor. The first one
// (a.seq still empty) is used as is; later ones are advanced by
// datPitch*s.
dp := func() {
var (
datPtr = vb(a.name("datPtr"))
expr = tensor()
)
if len(a.seq) > 0 {
expr = addr(expr, datPitch, s)
}
a.seq = append(a.seq, datPtr)
stmts = append(stmts, cgen.Var{
Type: cgen.RestrictPtrChar,
What: datPtr, Init: expr,
})
}
// ndp declares n data pointers.
ndp := func(n int) {
for ; n > 0; n-- {
dp()
}
}
// bp declares one batch-norm pointer from the next tensor, positioned
// by bn.Offset at channel gc*s.
bp := func() {
bnPtr := vb(a.name("bnPtr"))
a.seq = append(a.seq, bnPtr)
stmts = append(stmts, cgen.Var{
Type: cgen.RestrictPtrChar,
What: bnPtr,
Init: &bn.Offset{
Ctx: a.bc,
Mas: tensor(),
Channel: cgen.Mul{
Expr1: il(gc),
Expr2: s,
},
},
})
}
// One leading data pointer, then pointers in the order the ops consume
// them, then one data pointer for each tensor not yet consumed.
dp()
for i := range a.Ops {
op := &a.Ops[i]
switch op.Kind {
case mod.Add:
ndp(op.Int)
case mod.Bn:
bp()
case mod.ReLU:
default:
panic("bug")
}
}
ndp(len(a.Tensors) - tensorIdx)
return stmts
}

// kernel returns the kernel code for the configured platform. Only
// raw.AVX512Float32 is supported; any other platform panics.
func (a *apply) kernel() cgen.Gen {
	if a.platform == raw.AVX512Float32 {
		return a.m512()
	}
	panic("bug")
}

// m512 emits the AVX-512 kernel for the current a.strips and
// a.groupCells: a loop over strips (index i) whose body declares
// zeroed sum vectors, accumulates weight*data products into them over
// the strip groups (index j), and then post-processes and stores the
// results. The nested closures share the variables declared below;
// layerN calls layerN+1 while building its own statements.
func (a *apply) m512() cgen.Gen {
const (
lanes = 16
bundle = 4
)
var (
i = vb(a.name("i"))
gc = a.groupCells
sums []cgen.Gen
jUnroll = 1 + gc%2
j cgen.Gen
groups int
edge int
dats []cgen.Gen
)
// Raise jUnroll when jUnroll*gc is below bundle*2 cells.
if cells := jUnroll * gc; cells < bundle*2 {
jUnroll *= bundle * 2 / cells
}
// ldWts declares wts with a masked load (Mm512MaskzLoaduEpi32) from
// a.wtPtr, offset by the pair, i and j. sides (1 or 2) scales the
// mask: lanes/2 low bits per side.
ldWts := func(wts cgen.Gen, pair, sides int) cgen.Gen {
var (
from = a.wtPtr
iPitch = a.stripBytes1
jPitch = jUnroll * gc * a.cellBytes
mask = loMask(lanes / 2 * sides)
)
from = cgen.Add{
Expr1: from,
Expr2: cast(pair * 2 * a.cellBytes),
}
from = addr(from, cast(iPitch), i)
from = addr(from, cast(jPitch), j)
return cgen.Var{
Type: avx.M512i, What: wts,
Init: avx.Mm512MaskzLoaduEpi32{
mask, from,
},
}
}
// ldDat declares the data vector for the group that cell pair*2+side
// belongs to, loading it from a.seq[0]. It returns nil when that
// group's vector was already declared. The last group loads edge
// elements instead of a.cellWeights1.
ldDat := func(pair, side int) cgen.Gen {
var (
k = pair*2 + side
group = k / gc
)
if dats[group] != nil {
return nil
}
dat := vb(a.name("dat"))
dats[group] = dat
var (
elems = a.cellWeights1
from = a.seq[0]
groupPitch = a.cellWeights1 * a.datBytes
jPitch = jUnroll * groupPitch
)
if group == groups-1 {
elems = edge
}
from = cgen.Add{
Expr1: from,
Expr2: cast(group * groupPitch),
}
from = addr(from, cast(jPitch), j)
return cgen.Var{
Type: avx.M512, What: dat,
Init: avx.Mm512MaskzLoaduPs{
loMask(elems), from,
},
}
}
// cvt declares wt as Mm512CvtphPs applied to half.
cvt := func(wt, half cgen.Gen) cgen.Gen {
return cgen.Var{
Type: avx.M512, What: wt,
Init: avx.Mm512CvtphPs{half},
}
}
// madd accumulates wt*dat into the sum selected by cell k = pair*2+side:
// the data vector of group k/gc and the sum of index k%gc.
madd := func(wt cgen.Gen, pair, side int) cgen.Gen {
var (
k = pair*2 + side
dat = dats[k/gc]
sum = sums[k%gc]
)
return cgen.Assign{
Expr1: sum,
Expr2: avx.Mm512FmaddPs{
wt, dat, sum,
},
}
}
// two handles a full pair of cells: one weight load feeding two
// conversions (low half via cast, high half via extract) and two
// multiply-adds.
two := func(pair int) cgen.Stmts {
var (
wts = vb(a.name("wts"))
wtLo = vb(a.name("wtLo"))
wtHi = vb(a.name("wtHi"))
)
return cgen.Stmts{
cgen.Stmts{
ldWts(wts, pair, 2),
ldDat(pair, 0),
ldDat(pair, 1),
},
cgen.Stmts{
cvt(wtLo, avx.Mm512Castsi512Si256{wts}),
cvt(wtHi, avx.Mm512Extracti64x4Epi64{
wts, il(1),
}),
},
cgen.Stmts{
madd(wtLo, pair, 0),
madd(wtHi, pair, 1),
},
}
}
// one handles a trailing single cell (odd cell count).
one := func(pair int) cgen.Stmts {
var (
wts = vb(a.name("wts"))
wtLo = vb(a.name("wtLo"))
)
return cgen.Stmts{
cgen.Stmts{
ldWts(wts, pair, 1),
ldDat(pair, 0),
},
cvt(wtLo, avx.Mm512Castsi512Si256{wts}),
madd(wtLo, pair, 0),
}
}
// layer4 emits the loads and multiply-adds for groups*gc cells, taken
// in pairs. The per-pair statement lists are interleaved with mix in
// bundles of up to bundle pairs.
layer4 := func() cgen.Gen {
dats = make([]cgen.Gen, groups)
var (
cells = groups * gc
pairs = cells / 2
slots = pairs + cells%2
toMix = make([]cgen.Stmts, slots)
)
for pair := 0; pair < pairs; pair++ {
toMix[pair] = two(pair)
}
if pairs < slots {
toMix[pairs] = one(pairs)
}
var (
n = ceilQuo(slots, bundle)
ret = make(cgen.Gens, n)
)
for x := range ret {
var (
first = x * bundle
past = min(first+bundle, slots)
)
ret[x] = mix(toMix[first:past])
}
return ret
}
// layer3 covers the strip groups: a loop of iters iterations handling
// jUnroll groups each, then a remainder of after groups with j fixed at
// iters. When stripGroups2 exceeds stripGroups1 the remainder gains one
// group whose edge is a.cellWeights2.
layer3 := func() cgen.Gen {
var (
stmts = make(cgen.Stmts, 2)
iters = a.stripGroups1 / jUnroll
after = a.stripGroups1 % jUnroll
edge1 = a.cellWeights1
edge2 = edge1
)
if a.stripGroups1 < a.stripGroups2 {
after++
edge2 = a.cellWeights2
}
if iters > 0 {
j = vb(a.name("j"))
groups = jUnroll
edge = edge1
stmts[0] = cgen.For{
Init: cgen.Var{
Type: cgen.PtrdiffT, What: j,
Init: il(0),
},
Cond: cgen.CmpL{
Expr1: j,
Expr2: il(iters),
},
Post: cgen.IncPre{
Expr: j,
},
Body: layer4(),
}
}
if after > 0 {
j = il(iters)
groups = after
edge = edge2
stmts[1] = layer4()
}
return stmts
}
// layer2 emits the accumulation (layer3) followed by the epilogue. The
// epilogue works on parts of up to lanes sums: each part is handed to
// sumr.Pack, has the bias added to its first sum (yield), runs a.Ops on
// yield, and stores yield through every remaining pointer in a.seq.
layer2 := func() cgen.Gen {
var (
parts = ceilQuo(gc, lanes)
toMix = make([]cgen.Stmts, parts)
)
for part := range toMix {
var (
// stmts[0] collects loads, stmts[1] everything else; the loads are
// placed first when the two are concatenated below.
stmts [2]cgen.Stmts
first = part * lanes
cnt = min(gc-first, lanes)
bias = vb(a.name("bias"))
mask = loMask(cnt)
yield = sums[first]
at = 0
)
stmt := func(x int, s cgen.Gen) {
stmts[x] = append(stmts[x], s)
}
// ae offsets ptr to this part (first*each bytes) and this strip
// (gc*each bytes per i).
ae := func(ptr cgen.Gen, each int) cgen.Gen {
ptr = cgen.Add{
Expr1: ptr,
Expr2: cast(first * each),
}
return addr(ptr, cast(gc*each), i)
}
// next returns the following entry of a.seq (starting at index 1),
// or nil once the entries are used up.
next := func() cgen.Gen {
if at++; at >= len(a.seq) {
return nil
}
return a.seq[at]
}
stmt(1, &sumr.Pack{
Platform: a.platform,
Nms: a.nms,
Vars: sums[first : first+cnt],
})
stmt(0, cgen.Var{
Type: avx.M512, What: bias,
Init: avx.Mm512MaskzLoaduPs{
mask, ae(a.biasPtr, a.biasBytes),
},
})
stmt(1, cgen.Assign{
Expr1: yield,
Expr2: avx.Mm512AddPs{
yield, bias,
},
})
for op := range a.Ops {
op := &a.Ops[op]
switch op.Kind {
case mod.Add:
// Load op.Int extra vectors and sum them with yield pairwise; the
// total ends up in ds[0], which is yield.
var (
n = 1 + op.Int
ds = make([]cgen.Gen, n)
)
ds[0] = yield
for x := 1; x < n; x++ {
var (
dat = vb(a.name("dat"))
from = ae(next(), a.datBytes)
)
ds[x] = dat
stmt(0, cgen.Var{
Type: avx.M512, What: dat,
Init: avx.Mm512MaskzLoaduPs{
mask, from,
},
})
}
for n > 1 {
fold := n >> 1
n -= fold
for x := 0; x < fold; x++ {
keep := ds[x]
stmt(1, cgen.Assign{
Expr1: keep,
Expr2: avx.Mm512AddPs{
keep, ds[n+x],
},
})
}
}
case mod.Bn:
// Load the batch-norm multiplier/adder for channel first+gc*i
// (cnt channels) and apply them to yield.
var (
bnPtr = next()
bnMul = vb(a.name("bnMul"))
bnAdd = vb(a.name("bnAdd"))
)
bnChan := cgen.Paren{
Inner: cgen.Add{
Expr1: il(first),
Expr2: cgen.Mul{
Expr1: il(gc),
Expr2: i,
},
},
}
stmt(0, &bn.Load{
Ctx: a.bc,
Mas: bnPtr,
Channel: bnChan,
Mul: bnMul,
Add: bnAdd,
Cnt: cnt,
})
stmt(1, &bn.Apply{
Ctx: a.bc,
Mul: bnMul,
Add: bnAdd,
To: yield,
})
case mod.ReLU:
stmt(1, &act.ReLU{
Ctx: a.ac,
NegSlope: op.Float,
Var: yield,
})
default:
panic("bug")
}
}
// Every pointer left in a.seq receives a masked store of yield.
for {
datPtr := next()
if datPtr == nil {
break
}
stmt(1, avx.Mm512MaskStoreuPs{
ae(datPtr, a.datBytes),
mask, yield,
})
}
toMix[part] = append(
stmts[0], stmts[1]...,
)
}
return cgen.Gens{
layer3(),
mix(toMix),
}
}
// layer1 declares gc sum vectors initialized with Mm512SetzeroPs and
// then emits layer2.
layer1 := func() cgen.Gen {
sums = make([]cgen.Gen, gc)
stmts := make(cgen.Stmts, gc)
for cell := range stmts {
sum := vb(a.name("sum"))
sums[cell] = sum
stmts[cell] = cgen.Var{
Type: avx.M512, What: sum,
Init: avx.Mm512SetzeroPs,
}
}
return cgen.Gens{
stmts,
layer2(),
}
}
// Outer loop: i runs over the a.strips strips of this task.
return cgen.For{
Init: cgen.Var{
Type: cgen.PtrdiffT, What: i,
Init: il(0),
},
Cond: cgen.CmpL{
Expr1: i,
Expr2: il(a.strips),
},
Post: cgen.IncPre{
Expr: i,
},
Body: layer1(),
}
}

Top || internal/compile/author/glopl/glopl.go

package glopl

import (
"NN-512/internal/compile/author/act"
"NN-512/internal/compile/author/avx"
"NN-512/internal/compile/author/bn"
"NN-512/internal/compile/author/cgen"
"NN-512/internal/compile/author/cov"
"NN-512/internal/compile/author/mod"
"NN-512/internal/compile/author/threader"
"NN-512/internal/compile/plan"
"NN-512/internal/nmsrc"
"NN-512/internal/raw"
"fmt"
)

// Ctx holds the state shared by every Call in this package. It is
// created by NewCtx.
type Ctx struct {
// prefix is the plan's configured prefix followed by "Glopl"; generated
// function names start from it.
prefix string
// platform selects the kernel implementation.
platform raw.Platform
// lanes is 16 for raw.AVX512Float32; Append uses it to count vectors.
lanes int
// nms supplies the identifiers returned by name.
nms nmsrc.Src
// tc, ac and bc are the contexts passed on to the threader, act and bn
// generators respectively.
tc *threader.Ctx
ac *act.Ctx
bc *bn.Ctx
// dedup maps a Spec's formatted text to the name of the function
// already generated for it.
dedup map[string]string
}

// NewCtx builds the package context from the plan's configuration and
// the supplied helper contexts. It panics for any platform other than
// raw.AVX512Float32.
func NewCtx(pl *plan.Plan, nms nmsrc.Src, tc *threader.Ctx, ac *act.Ctx, bc *bn.Ctx) *Ctx {
	if pl.Config.Platform != raw.AVX512Float32 {
		panic("bug")
	}
	ctx := new(Ctx)
	ctx.prefix = pl.Config.Prefix + "Glopl"
	ctx.platform = pl.Config.Platform
	ctx.lanes = 16
	ctx.nms = nms
	ctx.tc = tc
	ctx.ac = ac
	ctx.bc = bc
	ctx.dedup = make(map[string]string)
	return ctx
}

// name asks the name source for an identifier based on s.
func (c *Ctx) name(s string) string {
	ident := c.nms.Name(s)
	return ident
}

// Spec describes one global pooling operation. Call.Prep formats the
// whole value with %v to obtain its dedup key.
type Spec struct {
	// Kind selects how values are combined (raw.AvgGlobal or
	// raw.MaxGlobal are the kinds handled by the kernels).
	Kind raw.PoolingKind

	// Channels is the channel count that Append divides into tiles;
	// ElemBytes is the size of one element in bytes.
	Channels, ElemBytes int

	From SpecFrom
	To   SpecTo
}

// SpecFrom describes the input side of a Spec.
type SpecFrom struct {
	Height, Width int

	// Per-input byte pitches. In the unpacked kernel a Pitch1Bytes entry
	// multiplies the row index and a Pitch2Bytes entry multiplies the
	// channel index. The number of Pitch1Bytes entries is also used as
	// the per-element cost when sizing thread tasks.
	Pitch1Bytes, Pitch2Bytes []int

	// Ops are applied to each loaded input vector before it is combined
	// into the accumulator.
	Ops []mod.Op
}

// SpecTo describes the output side of a Spec.
type SpecTo struct {
// Ops drive the pointer declarations made after the input side's.
Ops []mod.Op
// Cnt is the number of data pointers declared after those required by
// Ops (presumably one per output tensor; confirm against the caller).
Cnt int
}

// ceilQuo returns (n + d - 1) / d using Go's truncating integer
// division, which is the quotient n/d rounded up when n is
// non-negative and d is positive.
func ceilQuo(n, d int) int {
	padded := n + (d - 1)
	return padded / d
}

// vb is shorthand for cgen.Vb(s).
func vb(s string) cgen.Gen {
	var g cgen.Gen = cgen.Vb(s)
	return g
}

// il is shorthand for cgen.IntLit(i).
func il(i int) cgen.Gen {
	var g cgen.Gen = cgen.IntLit(i)
	return g
}

// cast returns the integer literal pitch wrapped in a cast to
// cgen.PtrdiffT.
func cast(pitch int) cgen.Gen {
	var c cgen.Cast
	c.Type = cgen.PtrdiffT
	c.Expr = il(pitch)
	return c
}

// addr returns the expression ptr + pitch*idx.
func addr(ptr, pitch, idx cgen.Gen) cgen.Gen {
	offset := cgen.Mul{Expr1: pitch, Expr2: idx}
	return cgen.Add{Expr1: ptr, Expr2: offset}
}

// mix interleaves the given statement lists round-robin: the first
// statement of every list, then the second of every list that has one,
// and so on. A single list is returned as is.
func mix(a ...cgen.Stmts) cgen.Stmts {
	if len(a) == 1 {
		return a[0]
	}
	var total, longest int
	for _, list := range a {
		total += len(list)
		if len(list) > longest {
			longest = len(list)
		}
	}
	merged := make(cgen.Stmts, 0, total)
	for row := 0; row < longest; row++ {
		for _, list := range a {
			if row < len(list) {
				merged = append(merged, list[row])
			}
		}
	}
	return merged
}

// Call describes one call site of a generated pooling function. Prep
// registers (or reuses) the function definition and Append emits the
// call statement.
type Call struct {
*Ctx
*Spec
// Team is passed as the first argument of the emitted call.
Team cgen.Gen
// Tensors are the expressions gathered into the tensors array.
Tensors []cgen.Gen
// funcName is the name of the function to call; set by Prep.
funcName string
}

// Prep picks the name of the function that Append will call. It
// returns the function definitions to emit, or nil when a Spec with the
// same formatted text is already recorded in the dedup table (in which
// case the earlier name is reused).
func (c *Call) Prep() cgen.Gen {
	key := fmt.Sprintf("%v", c.Spec)
	prior, seen := c.dedup[key]
	if seen {
		c.funcName = prior
		return nil
	}
	fresh := c.name(c.prefix)
	c.funcName = fresh
	c.dedup[key] = fresh
	defs := &funcDefs{
		Ctx:      c.Ctx,
		Spec:     c.Spec,
		FuncName: fresh,
	}
	return cgen.Gens{defs, cgen.Newline}
}

// Append emits two statements onto to: the declaration of a local
// array initialized with c.Tensors, and a call of the function chosen
// by Prep with c.Team and that array as arguments.
func (c *Call) Append(to []byte) []byte {
	arr := vb(c.name("tensors"))
	elems := cgen.CommaLines(c.Tensors)
	decl := cgen.Var{
		Type: cgen.PtrChar,
		What: cgen.Elem{Arr: arr},
		Init: cgen.Brace{Inner: elems},
	}
	call := cgen.Call{
		Func: vb(c.funcName),
		Args: cgen.CommaSpaced{c.Team, arr},
	}
	return cgen.Stmts{decl, call}.Append(to)
}

// funcDefs generates the definitions behind a Call: a threaded caller
// function named FuncName plus the callee that runs the kernel.
type funcDefs struct {
	*Ctx
	*Spec

	FuncName string // Name of the caller function.

	// unpacked is set by Append when some From.Pitch1Bytes entry differs
	// from Width*ElemBytes; it selects m512Unpacked over m512Semipacked.
	unpacked bool

	// Channel partitioning, computed by Append: chanTile channels per
	// full task, chanTiles full tasks, chanScrap leftover channels.
	chanTile, chanTiles, chanScrap int

	funcName string // Name of the callee function; set by Append.

	// Pointer variables declared by ptrs, input side first.
	datPtrs, bnPtrs []cgen.Gen

	// Lengths of datPtrs and bnPtrs after the input side was processed;
	// entries from these indexes on belong to the output side.
	datSplit, bnSplit int
}

// Append decides between the unpacked and semipacked kernels, splits
// the channels into tasks, and emits the callee function followed by
// the static caller function that runs the callee through threader.Do.
func (f *funcDefs) Append(to []byte) []byte {
	inputs := len(f.From.Pitch1Bytes)
	if f.platform != raw.AVX512Float32 {
		panic("bug")
	}
	// Vectors per task: 512 shared among the inputs, but at least 8.
	const floor = 8
	threadVecs := 512 / inputs
	if threadVecs < floor {
		threadVecs = floor
	}
	// Any row pitch other than the tightly packed one forces the
	// unpacked kernel.
	width := f.From.Width
	packedPitch := width * f.ElemBytes
	for x := range f.From.Pitch1Bytes {
		if f.From.Pitch1Bytes[x] != packedPitch {
			f.unpacked = true
			break
		}
	}
	var chanVecs int
	if f.unpacked {
		chanVecs = f.From.Height * ceilQuo(width, f.lanes)
	} else {
		chanVecs = ceilQuo(f.From.Height*width, f.lanes)
	}
	f.chanTile = ceilQuo(threadVecs, chanVecs)
	f.chanTiles = f.Channels / f.chanTile
	f.chanScrap = f.Channels % f.chanTile
	f.funcName = f.name(f.FuncName + "Callee")
	team := vb(f.name("team"))
	tensors := vb(f.name("tensors"))
	// One task per full tile plus one for the scrap, if any.
	tasks := f.chanTiles
	if f.chanScrap > 0 {
		tasks++
	}
	callee := f.calleeFunc()
	params := cgen.CommaSpaced{
		cgen.Param{Type: f.tc.PtrTeam, What: team},
		cgen.Param{Type: cgen.PtrPtrChar, What: tensors},
	}
	caller := cgen.StaticFuncDef{
		ReturnType: cgen.Void,
		Name:       f.FuncName,
		Params:     params,
		Body: &threader.Do{
			Ctx:    f.tc,
			Callee: vb(f.funcName),
			Any:    tensors,
			Hull:   []cgen.Gen{il(tasks)},
			Team:   team,
		},
	}
	return cgen.Gens{
		callee,
		cgen.Newline,
		caller,
	}.Append(to)
}

// calleeFunc builds the callee: it reads the tensors array and the
// task index, declares the pointers, and emits the kernel for a full
// channel tile and/or for the scrap. When both exist, the full-tile
// kernel is guarded by a task-index comparison and followed by a
// return.
func (f *funcDefs) calleeFunc() cgen.Gen {
	tensors := vb(f.name("tensors"))
	c := vb(f.name("c"))
	callee := &threader.Callee{
		Ctx:  f.tc,
		Name: f.funcName,
		Task: vb(f.name("task")),
		Pt:   vb(f.name("pt")),
	}
	// Slots 3 and 4 stay nil unless the matching kernel exists.
	body := cgen.Stmts{
		0: cgen.Var{
			Type: cgen.PtrPtrChar, What: tensors,
			Init: callee.Any(),
		},
		1: cgen.Var{
			Type: cgen.PtrdiffT, What: c,
			Init: cgen.Elem{Arr: callee.Pt, Idx: cgen.Zero},
		},
		2: f.ptrs(tensors, c),
		4: nil,
	}
	hasScrap := f.chanScrap > 0
	if f.chanTiles > 0 {
		body[3] = f.kernel(f.chanTile)
		if hasScrap {
			body[3] = cgen.If{
				Cond: cgen.CmpL{
					Expr1: c,
					Expr2: il(f.chanTiles),
				},
				Then: cgen.Stmts{body[3], cgen.Return{}},
			}
		}
	}
	if hasScrap {
		body[4] = f.kernel(f.chanScrap)
	}
	return callee.Func(body)
}

// ptrs emits the declarations of the pointer variables used by the
// kernel. tensors is the expression for the tensor pointer array and c
// the expression for the task index. Data pointers are appended to
// f.datPtrs and batch-norm pointers to f.bnPtrs; f.datSplit and
// f.bnSplit record how many of each belong to the input side.
func (f *funcDefs) ptrs(tensors, c cgen.Gen) cgen.Gen {
var (
stmts cgen.Stmts
pitch2Idx = 0
tensorIdx = 0
)
// pitch2 returns the next entry of From.Pitch2Bytes, or ElemBytes once
// those entries are used up.
pitch2 := func() int {
i := pitch2Idx
pitch2Idx++
if i < len(f.From.Pitch2Bytes) {
return f.From.Pitch2Bytes[i]
}
return f.ElemBytes
}
// tensor returns tensors[tensorIdx] and advances tensorIdx.
tensor := func() cgen.Gen {
i := tensorIdx
tensorIdx++
return cgen.Elem{
Arr: tensors, Idx: il(i),
}
}
// datPtr declares one data pointer: the next tensor advanced by
// (next pitch2 * chanTile) bytes per task.
datPtr := func() {
var (
ptr = vb(f.name("ptr"))
cPitch = cast(pitch2() * f.chanTile)
cAddr = addr(tensor(), cPitch, c)
)
f.datPtrs = append(f.datPtrs, ptr)
stmts = append(stmts, cgen.Var{
Type: cgen.RestrictPtrChar,
What: ptr, Init: cAddr,
})
}
// ndp declares n data pointers.
ndp := func(n int) {
for ; n > 0; n-- {
datPtr()
}
}
// bnPtr declares one batch-norm pointer from the next tensor,
// positioned by bn.Offset at channel chanTile*c.
bnPtr := func() {
ptr := vb(f.name("ptr"))
f.bnPtrs = append(f.bnPtrs, ptr)
stmts = append(stmts, cgen.Var{
Type: cgen.RestrictPtrChar,
What: ptr,
Init: &bn.Offset{
Ctx: f.bc,
Mas: tensor(),
Channel: cgen.Mul{
Expr1: il(f.chanTile),
Expr2: c,
},
},
})
}
// do declares the pointers of one side. The input side starts with one
// data pointer, walks From.Ops and then records the split points; the
// output side walks To.Ops and ends with To.Cnt data pointers.
do := func(from bool) {
var ops []mod.Op
if from {
datPtr()
ops = f.From.Ops
} else {
ops = f.To.Ops
}
for i := range ops {
switch op := &ops[i]; op.Kind {
case mod.Add:
ndp(op.Int)
case mod.Bn:
bnPtr()
case mod.ReLU:
default:
panic("bug")
}
}
if from {
f.datSplit = len(f.datPtrs)
f.bnSplit = len(f.bnPtrs)
} else {
ndp(f.To.Cnt)
}
}
do(true)
do(false)
return stmts
}

// kernel returns the kernel code for chans channels on the configured
// platform. Only raw.AVX512Float32 is supported; any other platform
// panics.
func (f *funcDefs) kernel(chans int) cgen.Gen {
	if f.platform == raw.AVX512Float32 {
		return f.m512(chans)
	}
	panic("bug")
}

// m512 picks the AVX-512 kernel variant: the unpacked one when
// f.unpacked is set, the semipacked one otherwise.
func (f *funcDefs) m512(chans int) cgen.Gen {
	variant := f.m512Semipacked
	if f.unpacked {
		variant = f.m512Unpacked
	}
	return variant(chans)
}

// m512Unpacked emits the AVX-512 kernel for chans channels when the
// input rows are addressed through their own pitch (f.unpacked). It
// walks channels (index i), rows (index j) and row vectors (index k),
// reducing each channel into accumulators, folding those into one lane
// of buf per channel, and pushing buf to the outputs. The nested
// closures share the variables declared below; layerN calls layerN+1.
func (f *funcDefs) m512Unpacked(chans int) cgen.Gen {
const (
lanes = 16
laneBytes = 4
)
var (
unroll = 1
height = f.From.Height
width = f.From.Width
vecs = ceilQuo(width, lanes)
)
// A single input data pointer allows a larger unroll budget.
if f.datSplit == 1 {
unroll = 4
}
iUnroll, jUnroll, kUnroll := cov.Box(
unroll, 4, chans, height, vecs,
)
var (
buf = vb(f.name("buf"))
mask = vb(f.name("mask"))
// bufChans is the largest multiple of iUnroll that fits in lanes.
bufChans = lanes - lanes%iUnroll
maskFull = il(1<<uint(bufChans) - 1)
iIters = chans / iUnroll
iAfter = chans % iUnroll
bnMuls = make([][]cgen.Gen, iUnroll)
bnAdds = make([][]cgen.Gen, iUnroll)
jIters int
jAfter int
kIters int
kAfter int
accs []cgen.Gen
accLanes []int
)
// leaf emits, through m512Pull, the statements that bring one vector
// of l lanes into accumulator slot a. The accumulator variable is
// created on first use of the slot, in which case the pull loads
// directly into it (LdAcc). ii, jj and kk are the constant offsets
// within the unrolled block; i, j and k the loop index expressions.
leaf := func(i, j, k cgen.Gen, ii, jj, kk, l, a int) cgen.Stmts {
var (
acc = accs[a]
ldAcc = false
)
if acc == nil {
acc = vb(f.name("acc"))
ldAcc = true
accs[a] = acc
accLanes[a] = l
}
pull := &m512Pull{
Ctx: f.Ctx,
Spec: f.Spec,
Lanes: l,
Ptrs: make([]cgen.Gen, f.datSplit),
BnMuls: bnMuls[ii],
BnAdds: bnAdds[ii],
Acc: acc,
LdAcc: ldAcc,
}
for x, ptr := range f.datPtrs[:f.datSplit] {
var (
iiPitch = f.From.Pitch2Bytes[x]
jjPitch = f.From.Pitch1Bytes[x]
kkPitch = lanes * laneBytes
iPitch = iiPitch * iUnroll
jPitch = jjPitch * jUnroll
kPitch = kkPitch * kUnroll
)
ptr = cgen.Add{
Expr1: ptr,
Expr2: cast(iiPitch*ii + jjPitch*jj + kkPitch*kk),
}
// A loop index contributes only when its loop exists.
if iIters > 0 {
ptr = addr(ptr, cast(iPitch), i)
}
if jIters > 0 {
ptr = addr(ptr, cast(jPitch), j)
}
if kIters > 0 {
ptr = addr(ptr, cast(kPitch), k)
}
pull.Ptrs[x] = ptr
}
return pull.Stmts()
}
// layer5 covers one block of iCnt channels by jCnt rows along the
// width: an optional peeled first k iteration (when first is set, so
// the accumulators get created), the k loop, and the tail of kAfter
// elements. Accumulator slot a is (ii*jUnroll+jj)*kUnroll+kk.
layer5 := func(i, j cgen.Gen, iCnt, jCnt int, first bool) cgen.Gen {
stmts := make(cgen.Stmts, 3)
if kIters > 0 {
peeled := 0
if first {
peeled = 1
var (
peel = make([]cgen.Stmts, len(accs))
k = cgen.Zero
)
for ii := 0; ii < iCnt; ii++ {
a := ii * jUnroll * kUnroll
for jj := 0; jj < jCnt; jj++ {
for kk := 0; kk < kUnroll; kk++ {
peel[a] = leaf(i, j, k, ii, jj, kk, lanes, a)
a++
}
}
}
stmts[0] = mix(peel...)
}
if peeled < kIters {
var (
body = make([]cgen.Stmts, len(accs))
k = vb(f.name("k"))
)
for ii := 0; ii < iCnt; ii++ {
a := ii * jUnroll * kUnroll
for jj := 0; jj < jCnt; jj++ {
for kk := 0; kk < kUnroll; kk++ {
body[a] = leaf(i, j, k, ii, jj, kk, lanes, a)
a++
}
}
}
stmts[1] = cgen.For{
Init: cgen.Var{
Type: cgen.PtrdiffT, What: k,
Init: il(peeled),
},
Cond: cgen.CmpL{
Expr1: k, Expr2: il(kIters),
},
Post: cgen.IncPre{Expr: k},
Body: mix(body...),
}
}
}
if kAfter > 0 {
// The tail consists of full vectors followed by one partial vector of
// part lanes (skipped when part is zero).
var (
full = kAfter / lanes
part = kAfter % lanes
tail = make([]cgen.Stmts, len(accs))
k = il(kIters)
)
for ii := 0; ii < iCnt; ii++ {
for jj := 0; jj < jCnt; jj++ {
a := (ii*jUnroll + jj) * kUnroll
for kk := 0; kk <= full; kk++ {
l := lanes
if kk == full {
if l = part; l == 0 {
break
}
}
tail[a] = leaf(i, j, k, ii, jj, kk, l, a)
a++
}
}
}
stmts[2] = mix(tail...)
}
return stmts
}
// layer4 covers the rows: the first block of jUnroll rows is peeled
// (j = 0, first = true), the remaining blocks run in a loop starting
// at 1, and jAfter leftover rows follow with j fixed at jIters.
layer4 := func(i cgen.Gen, iCnt int) cgen.Gen {
stmts := make(cgen.Stmts, 3)
if jIters > 0 {
j := cgen.Zero
stmts[0] = layer5(i, j, iCnt, jUnroll, true)
if jIters > 1 {
j := vb(f.name("j"))
stmts[1] = cgen.For{
Init: cgen.Var{
Type: cgen.PtrdiffT, What: j,
Init: cgen.One,
},
Cond: cgen.CmpL{
Expr1: j, Expr2: il(jIters),
},
Post: cgen.IncPre{Expr: j},
Body: layer5(i, j, iCnt, jUnroll, false),
}
}
}
if jAfter > 0 {
var (
j = il(jIters)
first = jIters == 0
)
stmts[2] = layer5(i, j, iCnt, jAfter, first)
}
return stmts
}
// layer3 derives the j/k iteration counts from the current unroll
// factors, allocates the accumulator slots, emits layer4, and then
// compacts accs/accLanes down to the slots that were actually used.
layer3 := func(i cgen.Gen, iCnt int) cgen.Gen {
jIters = height / jUnroll
jAfter = height % jUnroll
kIters = width / (kUnroll * lanes)
kAfter = width % (kUnroll * lanes)
accs = make([]cgen.Gen, iCnt*jUnroll*kUnroll)
accLanes = make([]int, len(accs))
var (
stmts = layer4(i, iCnt)
n = 0
)
for x, acc := range accs {
if acc != nil {
accs[n] = acc
accLanes[n] = accLanes[x]
n++
}
}
accs = accs[:n]
accLanes = accLanes[:n]
return stmts
}
// layer2 emits the reduction for iCnt channels (layer3), folds the
// accumulators with m512Fold, and merges the folded vector into buf
// under mask. For a full block of iUnroll channels it then removes the
// low iUnroll positions from mask and, once mask is zero, resets mask
// and pushes the bufChans buffered channels.
layer2 := func(i cgen.Gen, iCnt int) (stmts cgen.Stmts) {
stmts = make(cgen.Stmts, 5)
stmts[0] = layer3(i, iCnt)
fold := &m512Fold{
Ctx: f.Ctx,
Spec: f.Spec,
Chans: iCnt,
Frame: iUnroll,
Accs: accs,
Lanes: accLanes,
}
var folded cgen.Gen
stmts[1], folded = fold.Gens()
stmts[2] = cgen.Assign{
Expr1: buf,
Expr2: avx.Mm512MaskMovPs{
buf, mask, folded,
},
}
if iCnt != iUnroll {
return
}
stmts[3] = cgen.AndAssign{
Expr1: mask,
Expr2: cgen.ShiftHigh{
Expr1: mask,
Expr2: il(iUnroll),
},
}
// With fewer than bufChans channels the buffer never fills inside the
// loop, so no push is emitted here.
if chans < bufChans {
return
}
// ch is the first channel held in buf when it is full:
// iUnroll*i - (bufChans - iUnroll).
ch := cgen.Paren{Inner: cgen.Sub{
Expr1: cgen.Mul{
Expr1: cast(iUnroll),
Expr2: i,
},
Expr2: il(bufChans - iUnroll),
}}
stmts[4] = cgen.If{
Cond: cgen.Unlikely{
Cond: cgen.IsZero{Expr: mask},
},
Then: cgen.Stmts{
cgen.Assign{
Expr1: mask,
Expr2: maskFull,
},
&m512Push{
Ctx: f.Ctx,
Spec: f.Spec,
DatPtrs: f.datPtrs[f.datSplit:],
BnPtrs: f.bnPtrs[f.bnSplit:],
Buf: buf,
Chan: ch,
ChanCnt: bufChans,
},
},
}
return
}
// layer1 emits, for each of the iCnt channels, the bn.Load statements
// of the input-side batch-norm pointers (recording the resulting
// variables in bnMuls/bnAdds for leaf), followed by layer2.
layer1 := func(i cgen.Gen, iCnt int) cgen.Gen {
var (
bnLds = make([]cgen.Stmts, iCnt)
bnCnt = f.bnSplit
)
for ii := 0; ii < iCnt; ii++ {
var (
ch = il(ii)
muls = make([]cgen.Gen, bnCnt)
adds = make([]cgen.Gen, bnCnt)
lds = make(cgen.Stmts, bnCnt)
)
if iIters > 0 {
ch = cgen.Paren{Inner: cgen.Add{
Expr1: ch,
Expr2: cgen.Mul{
Expr1: cast(iUnroll),
Expr2: i,
},
}}
}
for x, ptr := range f.bnPtrs[:bnCnt] {
var (
bnMul = vb(f.name("bnMul"))
bnAdd = vb(f.name("bnAdd"))
)
muls[x] = bnMul
adds[x] = bnAdd
lds[x] = &bn.Load{
Ctx: f.bc,
Mas: ptr,
Channel: ch,
Mul: bnMul,
Add: bnAdd,
}
}
bnMuls[ii] = muls
bnAdds[ii] = adds
bnLds[ii] = lds
}
return cgen.Gens{
mix(bnLds...),
layer2(i, iCnt),
}
}
// Top level: declare buf (zeroed) and mask (full), loop over blocks of
// iUnroll channels, handle the iAfter leftover channels, and push
// whatever is still buffered.
stmts := make(cgen.Stmts, 5)
stmts[0] = cgen.Var{
Type: avx.M512, What: buf,
Init: avx.Mm512SetzeroPs,
}
stmts[1] = cgen.Var{
Type: avx.Mmask16, What: mask,
Init: maskFull,
}
if iIters > 0 {
i := vb(f.name("i"))
stmts[2] = cgen.For{
Init: cgen.Var{
Type: cgen.PtrdiffT, What: i,
Init: cgen.Zero,
},
Cond: cgen.CmpL{
Expr1: i, Expr2: il(iIters),
},
Post: cgen.IncPre{Expr: i},
Body: layer1(i, iUnroll),
}
}
if iAfter > 0 {
// The leftover channels get fresh j/k unroll factors from cov.Rect,
// based on the budget left per channel.
var (
i = il(iIters)
jkUnroll = unroll / iAfter
)
jUnroll, kUnroll = cov.Rect(
jkUnroll, jkUnroll, height, vecs,
)
stmts[3] = layer1(i, iAfter)
}
if rem := chans % bufChans; rem > 0 {
stmts[4] = &m512Push{
Ctx: f.Ctx,
Spec: f.Spec,
DatPtrs: f.datPtrs[f.datSplit:],
BnPtrs: f.bnPtrs[f.bnSplit:],
Buf: buf,
Chan: il(chans - rem),
ChanCnt: rem,
}
}
return stmts
}

// m512Semipacked emits the AVX-512 kernel for chans channels when
// every input row pitch equals the tightly packed row size, so each
// channel is treated as one run of Height*Width elements. It walks
// channels (index i) and vectors within a channel (index j), reducing
// each channel into accumulators, folding those into one lane of buf
// per channel, and pushing buf to the outputs. The nested closures
// share the variables declared below; layerN calls layerN+1.
func (f *funcDefs) m512Semipacked(chans int) cgen.Gen {
const (
lanes = 16
laneBytes = 4
)
var (
unroll = 1
elems = f.From.Height * f.From.Width
vecs = ceilQuo(elems, lanes)
)
// The unroll budget depends on the number of input data pointers.
switch f.datSplit {
case 1:
unroll = 8
case 2:
unroll = 2
}
iUnroll, _ := cov.Rect(
unroll, 4, chans, vecs,
)
var (
buf = vb(f.name("buf"))
mask = vb(f.name("mask"))
// bufChans is the largest multiple of iUnroll that fits in lanes.
bufChans = lanes - lanes%iUnroll
maskFull = il(1<<uint(bufChans) - 1)
iIters = chans / iUnroll
iAfter = chans % iUnroll
bnMuls = make([][]cgen.Gen, iUnroll)
bnAdds = make([][]cgen.Gen, iUnroll)
jUnroll int
jIters int
accs []cgen.Gen
accLanes []int
)
// leaf emits, through m512Pull, the statements that bring one vector
// of l lanes into accumulator slot a. The accumulator variable is
// created on first use of the slot, in which case the pull loads
// directly into it (LdAcc). ii and jj are the constant offsets within
// the unrolled block; i and j the loop index expressions.
leaf := func(i, j cgen.Gen, ii, jj, l, a int) cgen.Stmts {
var (
acc = accs[a]
ldAcc = false
)
if acc == nil {
acc = vb(f.name("acc"))
ldAcc = true
accs[a] = acc
accLanes[a] = l
}
pull := &m512Pull{
Ctx: f.Ctx,
Spec: f.Spec,
Lanes: l,
Ptrs: make([]cgen.Gen, f.datSplit),
BnMuls: bnMuls[ii],
BnAdds: bnAdds[ii],
Acc: acc,
LdAcc: ldAcc,
}
for x, ptr := range f.datPtrs[:f.datSplit] {
var (
iiPitch = f.From.Pitch2Bytes[x]
jjPitch = lanes * laneBytes
iPitch = iiPitch * iUnroll
jPitch = jjPitch * jUnroll
)
ptr = cgen.Add{
Expr1: ptr,
Expr2: cast(iiPitch*ii + jjPitch*jj),
}
// A loop index contributes only when its loop exists.
if iIters > 0 {
ptr = addr(ptr, cast(iPitch), i)
}
if jIters > 0 {
ptr = addr(ptr, cast(jPitch), j)
}
pull.Ptrs[x] = ptr
}
return pull.Stmts()
}
// layer3 reduces iCnt channels: a peeled first j iteration that
// creates the accumulators, the j loop starting at 1, and a tail of
// jAfter elements. The per-channel unroll is unroll/iCnt vectors.
layer3 := func(i cgen.Gen, iCnt int) cgen.Gen {
stmts := make(cgen.Stmts, 3)
jUnroll = unroll / iCnt
jIter := jUnroll * lanes
jIters = elems / jIter
jAfter := elems % jIter
if jIters > 0 {
n := iCnt * jUnroll
accs = make([]cgen.Gen, n)
accLanes = make([]int, n)
var (
peel = make([]cgen.Stmts, n)
j = cgen.Zero
)
for a, ii := 0, 0; ii < iCnt; ii++ {
for jj := 0; jj < jUnroll; jj++ {
peel[a] = leaf(i, j, ii, jj, lanes, a)
a++
}
}
stmts[0] = mix(peel...)
if jIters > 1 {
var (
body = make([]cgen.Stmts, n)
j = vb(f.name("j"))
)
for a, ii := 0, 0; ii < iCnt; ii++ {
for jj := 0; jj < jUnroll; jj++ {
body[a] = leaf(i, j, ii, jj, lanes, a)
a++
}
}
stmts[1] = cgen.For{
Init: cgen.Var{
Type: cgen.PtrdiffT, What: j,
Init: cgen.One,
},
Cond: cgen.CmpL{
Expr1: j, Expr2: il(jIters),
},
Post: cgen.IncPre{Expr: j},
Body: mix(body...),
}
}
}
if jAfter > 0 {
// The tail consists of full vectors followed by one partial vector of
// part lanes (skipped when part is zero).
var (
full = jAfter / lanes
part = jAfter % lanes
jCnt = jUnroll
)
// Without any full iteration the accumulators do not exist yet; size
// them to exactly the vectors the tail needs per channel.
if jIters == 0 {
if jCnt = full; part > 0 {
jCnt++
}
accs = make([]cgen.Gen, iCnt*jCnt)
accLanes = make([]int, len(accs))
}
var (
tail = make([]cgen.Stmts, len(accs))
j = il(jIters)
)
for ii := 0; ii < iCnt; ii++ {
a := ii * jCnt
for jj := 0; jj <= full; jj++ {
l := lanes
if jj == full {
if l = part; l == 0 {
break
}
}
tail[a] = leaf(i, j, ii, jj, l, a)
a++
}
}
stmts[2] = mix(tail...)
}
return stmts
}
// layer2 emits the reduction for iCnt channels (layer3), folds the
// accumulators with m512Fold, and merges the folded vector into buf
// under mask. For a full block of iUnroll channels it then removes the
// low iUnroll positions from mask and, once mask is zero, resets mask
// and pushes the bufChans buffered channels.
layer2 := func(i cgen.Gen, iCnt int) (stmts cgen.Stmts) {
stmts = make(cgen.Stmts, 5)
stmts[0] = layer3(i, iCnt)
fold := &m512Fold{
Ctx: f.Ctx,
Spec: f.Spec,
Chans: iCnt,
Frame: iUnroll,
Accs: accs,
Lanes: accLanes,
}
var folded cgen.Gen
stmts[1], folded = fold.Gens()
stmts[2] = cgen.Assign{
Expr1: buf,
Expr2: avx.Mm512MaskMovPs{
buf, mask, folded,
},
}
if iCnt != iUnroll {
return
}
stmts[3] = cgen.AndAssign{
Expr1: mask,
Expr2: cgen.ShiftHigh{
Expr1: mask,
Expr2: il(iUnroll),
},
}
// With fewer than bufChans channels the buffer never fills inside the
// loop, so no push is emitted here.
if chans < bufChans {
return
}
// ch is the first channel held in buf when it is full:
// iUnroll*i - (bufChans - iUnroll).
ch := cgen.Paren{Inner: cgen.Sub{
Expr1: cgen.Mul{
Expr1: cast(iUnroll),
Expr2: i,
},
Expr2: il(bufChans - iUnroll),
}}
stmts[4] = cgen.If{
Cond: cgen.Unlikely{
Cond: cgen.IsZero{Expr: mask},
},
Then: cgen.Stmts{
cgen.Assign{
Expr1: mask,
Expr2: maskFull,
},
&m512Push{
Ctx: f.Ctx,
Spec: f.Spec,
DatPtrs: f.datPtrs[f.datSplit:],
BnPtrs: f.bnPtrs[f.bnSplit:],
Buf: buf,
Chan: ch,
ChanCnt: bufChans,
},
},
}
return
}
// layer1 emits, for each of the iCnt channels, the bn.Load statements
// of the input-side batch-norm pointers (recording the resulting
// variables in bnMuls/bnAdds for leaf), followed by layer2.
layer1 := func(i cgen.Gen, iCnt int) cgen.Gen {
var (
bnLds = make([]cgen.Stmts, iCnt)
bnCnt = f.bnSplit
)
for ii := 0; ii < iCnt; ii++ {
var (
ch = il(ii)
muls = make([]cgen.Gen, bnCnt)
adds = make([]cgen.Gen, bnCnt)
lds = make(cgen.Stmts, bnCnt)
)
if iIters > 0 {
ch = cgen.Paren{Inner: cgen.Add{
Expr1: ch,
Expr2: cgen.Mul{
Expr1: cast(iUnroll),
Expr2: i,
},
}}
}
for x, ptr := range f.bnPtrs[:bnCnt] {
var (
bnMul = vb(f.name("bnMul"))
bnAdd = vb(f.name("bnAdd"))
)
muls[x] = bnMul
adds[x] = bnAdd
lds[x] = &bn.Load{
Ctx: f.bc,
Mas: ptr,
Channel: ch,
Mul: bnMul,
Add: bnAdd,
}
}
bnMuls[ii] = muls
bnAdds[ii] = adds
bnLds[ii] = lds
}
return cgen.Gens{
mix(bnLds...),
layer2(i, iCnt),
}
}
// Top level: declare buf (zeroed) and mask (full), loop over blocks of
// iUnroll channels, handle the iAfter leftover channels, and push
// whatever is still buffered.
stmts := make(cgen.Stmts, 5)
stmts[0] = cgen.Var{
Type: avx.M512, What: buf,
Init: avx.Mm512SetzeroPs,
}
stmts[1] = cgen.Var{
Type: avx.Mmask16, What: mask,
Init: maskFull,
}
if iIters > 0 {
i := vb(f.name("i"))
stmts[2] = cgen.For{
Init: cgen.Var{
Type: cgen.PtrdiffT, What: i,
Init: cgen.Zero,
},
Cond: cgen.CmpL{
Expr1: i, Expr2: il(iIters),
},
Post: cgen.IncPre{Expr: i},
Body: layer1(i, iUnroll),
}
}
if iAfter > 0 {
i := il(iIters)
stmts[3] = layer1(i, iAfter)
}
if rem := chans % bufChans; rem > 0 {
stmts[4] = &m512Push{
Ctx: f.Ctx,
Spec: f.Spec,
DatPtrs: f.datPtrs[f.datSplit:],
BnPtrs: f.bnPtrs[f.bnSplit:],
Buf: buf,
Chan: il(chans - rem),
ChanCnt: rem,
}
}
return stmts
}

// m512Pull generates the statements that load one input vector, apply
// the input-side ops to it, and combine it into an accumulator. Use
// Stmts to obtain the result.
type m512Pull struct {
	*Ctx
	*Spec

	Lanes int // Number of low lanes covered by the load mask.

	// Ptrs holds the load addresses in the order ld consumes them;
	// BnMuls and BnAdds hold the batch-norm variables in the order the
	// Bn ops consume them.
	Ptrs, BnMuls, BnAdds []cgen.Gen

	Acc   cgen.Gen // Accumulator variable.
	LdAcc bool     // When set, the first load declares Acc itself.

	mask          cgen.Gen   // Lane mask literal; set by Stmts.
	ptrIdx, bnIdx int        // Next index into Ptrs / BnMuls and BnAdds.
	lds, nonlds   cgen.Stmts // Load statements, and all other statements.
}

// ld queues a masked load from the next pointer in m.Ptrs and returns
// the variable it declares. The very first load declares m.Acc when
// m.LdAcc is set; every other load declares a fresh "dat" variable.
func (m *m512Pull) ld() (dat cgen.Gen) {
	idx := m.ptrIdx
	dat = m.Acc
	if idx != 0 || !m.LdAcc {
		dat = vb(m.name("dat"))
	}
	load := avx.Mm512MaskzLoaduPs{m.mask, m.Ptrs[idx]}
	m.lds = append(m.lds, cgen.Var{
		Type: avx.M512, What: dat,
		Init: load,
	})
	m.ptrIdx = idx + 1
	return dat
}

// nonld queues a statement that is not a load; Stmts places these
// after all loads.
func (m *m512Pull) nonld(a cgen.Gen) {
	queued := append(m.nonlds, a)
	m.nonlds = queued
}

// adder queues masked additions that sum all of a into a[0]. Each
// round adds the upper part of the live range onto the leading
// elements, roughly halving the range, until one element is left.
func (m *m512Pull) adder(a ...cgen.Gen) {
	live := len(a)
	for live > 1 {
		half := live >> 1
		live -= half
		for i := 0; i < half; i++ {
			dst, src := a[i], a[live+i]
			m.nonld(cgen.Assign{
				Expr1: dst,
				Expr2: avx.Mm512MaskAddPs{
					dst, m.mask,
					dst, src,
				},
			})
		}
	}
}

// apply queues the statements for ops acting on dat. It reports
// whether the accumulator already reflects dat, in which case Stmts
// skips its own combine step. That is the case when dat is the
// accumulator (m.LdAcc), or when a trailing Add was merged with the
// accumulation for average pooling. A trailing ReLU with zero negative
// slope is skipped for max pooling when dat is not the accumulator.
func (m *m512Pull) apply(dat cgen.Gen, ops []mod.Op) bool {
	for i := range ops {
		var (
			op     = &ops[i]
			isLast = i == len(ops)-1
		)
		switch op.Kind {
		case mod.Add:
			var (
				terms []cgen.Gen
				merge = isLast && !m.LdAcc && m.Kind == raw.AvgGlobal
			)
			if merge {
				// Sum the accumulator, dat and the extra loads in one tree.
				terms = make([]cgen.Gen, 2+op.Int)
				terms[0] = m.Acc
				terms[1] = dat
			} else {
				terms = make([]cgen.Gen, 1+op.Int)
				terms[0] = dat
			}
			for at := len(terms) - op.Int; at < len(terms); at++ {
				terms[at] = m.ld()
			}
			m.adder(terms...)
			if merge {
				return true
			}
		case mod.Bn:
			idx := m.bnIdx
			m.bnIdx++
			m.nonld(&bn.Apply{
				Ctx: m.bc,
				Mul: m.BnMuls[idx],
				Add: m.BnAdds[idx],
				To:  dat,
			})
		case mod.ReLU:
			slope := op.Float
			if isLast && slope == 0 &&
				!m.LdAcc && m.Kind == raw.MaxGlobal {
				return false
			}
			m.nonld(&act.ReLU{
				Ctx:      m.ac,
				NegSlope: slope,
				Var:      dat,
			})
		default:
			panic("bug")
		}
	}
	return m.LdAcc
}

// Stmts generates the statements for one pull: all loads first, then
// the ops and, unless apply reports that the accumulator is already up
// to date, the combine step (masked add for average pooling, masked
// max for max pooling).
func (m *m512Pull) Stmts() cgen.Stmts {
	m.mask = il(1<<uint(m.Lanes) - 1)
	first := m.ld()
	if m.apply(first, m.From.Ops) {
		return append(m.lds, m.nonlds...)
	}
	switch m.Kind {
	case raw.AvgGlobal:
		m.adder(m.Acc, first)
	case raw.MaxGlobal:
		combined := avx.Mm512MaskMaxPs{
			m.Acc, m.mask,
			m.Acc, first,
		}
		m.nonld(cgen.Assign{
			Expr1: m.Acc,
			Expr2: combined,
		})
	default:
		panic("bug")
	}
	return append(m.lds, m.nonlds...)
}

// m512Fold generates the statements that reduce a set of accumulators
// to a single vector. Use Gens to obtain the result.
type m512Fold struct {
	*Ctx
	*Spec

	// Chans is the number of channels covered by Accs; Frame is the
	// period used when building the pm1 permutation tables.
	Chans, Frame int

	Accs  []cgen.Gen // Accumulator variables; reordered in place.
	Lanes []int      // Lane count of each accumulator; updated in place.

	// Permutation table variables, created by funnel when needed and
	// declared by pms.
	pm1lo, pm1hi, pm4lo, pm4hi cgen.Gen
}

// combine wraps the arguments in the masked combining operation of the
// pooling kind: masked add for raw.AvgGlobal, masked max for
// raw.MaxGlobal. Any other kind panics.
func (m *m512Fold) combine(a ...cgen.Gen) cgen.Gen {
	if m.Kind == raw.AvgGlobal {
		return avx.Mm512MaskAddPs(a)
	}
	if m.Kind == raw.MaxGlobal {
		return avx.Mm512MaskMaxPs(a)
	}
	panic("bug")
}

// chanwise combines the accumulators of each channel into one. The
// accumulators are taken as m.Chans consecutive groups of equal size.
// Within a group the upper part is combined onto the leading elements
// in rounds; before each combine the pair is swapped if needed so that
// the kept accumulator has at least as many lanes as the one merged in,
// whose lane count gives the mask. Afterwards the surviving accumulator
// of channel ch is stored at index ch.
func (m *m512Fold) chanwise() cgen.Stmts {
	var (
		chans  = m.Chans
		stride = len(m.Accs) / chans
		perCh  = make([]cgen.Stmts, chans)
	)
	for ch := range perCh {
		base := ch * stride
		for live := stride; live > 1; {
			half := live >> 1
			live -= half
			for off := 0; off < half; off++ {
				keep, drop := base+off, base+off+live
				if m.Lanes[keep] < m.Lanes[drop] {
					m.Accs[keep], m.Accs[drop] = m.Accs[drop], m.Accs[keep]
					m.Lanes[keep], m.Lanes[drop] = m.Lanes[drop], m.Lanes[keep]
				}
				acc := m.Accs[keep]
				perCh[ch] = append(perCh[ch], cgen.Assign{
					Expr1: acc,
					Expr2: m.combine(
						acc, il(1<<uint(m.Lanes[drop])-1),
						acc, m.Accs[drop],
					),
				})
			}
		}
		m.Accs[ch] = m.Accs[base]
		m.Lanes[ch] = m.Lanes[base]
	}
	return mix(perCh...)
}

// funnel emits the statements that merge the accumulators indexed by
// xs; the statements it appends at this level assign to
// m.Accs[xs[0]]. fit starts at 1 and doubles with every level of
// recursion. With several accumulators, xs is split into its even- and
// odd-position elements, each half is funneled at fit*2, and the two
// surviving accumulators are then merged here. With a single
// accumulator, the recursion continues at fit*2 until its lane count
// is at most fit. As a side effect funnel lowers entries of m.Lanes to
// fit and creates the pm1/pm4 permutation variables it uses.
func (m *m512Fold) funnel(xs []int, fit int) (stmts cgen.Stmts) {
n := len(xs)
if n > 1 {
var (
n2 = n >> 1
n1 = n - n2
xs1 = make([]int, n1)
xs2 = make([]int, n2)
)
for i, x := range xs {
if ii := i >> 1; i&1 == 0 {
xs1[ii] = x
} else {
xs2[ii] = x
}
}
stmts = mix(
m.funnel(xs1, fit*2),
m.funnel(xs2, fit*2),
)
} else {
// A single accumulator that already fits needs nothing at this level
// (except at fit == 1, which always emits its statements).
if fit > 1 {
if m.Lanes[xs[0]] <= fit {
return
}
}
stmts = m.funnel(xs, fit*2)
}
// acc2 is the second accumulator when there is one, else acc1 itself.
var (
acc1 = m.Accs[xs[0]]
acc2 = acc1
)
if n > 1 {
acc2 = m.Accs[xs[1]]
}
// permute uses the two-source permute when there are two accumulators
// and the single-source permute otherwise.
permute := func(pm cgen.Gen) cgen.Gen {
if n > 1 {
return avx.Mm512Permutex2varPs{
acc1, pm, acc2,
}
}
return avx.Mm512PermutexvarPs{
pm, acc1,
}
}
inner := func(ctrl int) cgen.Gen {
return avx.Mm512ShufflePs{
acc1, acc2, il(ctrl),
}
}
outer := func(ctrl int) cgen.Gen {
return avx.Mm512ShuffleF32x4{
acc1, acc2, il(ctrl),
}
}
// A "hi" vector is needed only if some accumulator has more than fit
// lanes.
var hi cgen.Gen
for _, x := range xs {
if m.Lanes[x] > fit {
hi = vb(m.name("hi"))
break
}
}
if hi != nil {
// Declare hi with the rearrangement chosen for this fit.
var call cgen.Gen
switch fit {
case 1:
m.pm1hi = vb(m.name("pm1hi"))
call = permute(m.pm1hi)
case 2:
call = inner(0xee)
case 4:
if n > 1 {
if m.pm4hi == nil {
m.pm4hi = vb(m.name("pm4hi"))
}
call = permute(m.pm4hi)
} else {
call = outer(0x01)
}
case 8:
call = outer(0xee)
}
stmts = append(stmts, cgen.Var{
Type: avx.M512, What: hi,
Init: call,
})
}
if n > 1 || fit == 1 {
// Rearrange into acc1 with the "lo" counterpart for this fit.
var call cgen.Gen
switch fit {
case 1:
m.pm1lo = vb(m.name("pm1lo"))
call = permute(m.pm1lo)
case 2:
call = inner(0x44)
case 4:
if m.pm4lo == nil {
m.pm4lo = vb(m.name("pm4lo"))
}
call = permute(m.pm4lo)
case 8:
call = outer(0x44)
}
stmts = append(stmts, cgen.Assign{
Expr1: acc1, Expr2: call,
})
}
if hi != nil {
// Build the combine mask with one field of fit bits per accumulator
// (xs[0] in the lowest field): an accumulator with l > fit lanes
// contributes l-fit low bits and has its lane count lowered to fit.
mask := 0
for i := n - 1; i >= 0; i-- {
mask <<= uint(fit)
if l := &m.Lanes[xs[i]]; *l > fit {
mask |= 1<<uint(*l-fit) - 1
*l = fit
}
}
// At fit == 1 the mask is replicated at doubling distances starting
// from m.Frame and cut to 16 bits.
if fit == 1 {
for bits := m.Frame; bits < 16; {
mask |= mask << uint(bits)
bits *= 2
}
mask &= 0xffff
}
stmts = append(stmts, cgen.Assign{
Expr1: acc1,
Expr2: m.combine(
acc1, il(mask),
acc1, hi,
),
})
}
return
}

// pms declares the permutation index tables that funnel asked for.
// A table is emitted only when its variable (m.pm1lo, m.pm1hi, m.pm4lo,
// m.pm4hi) is non-nil, always in that order. Each table is a 16-entry
// Mm512SetEpi32 whose last argument holds the index for lane 0.
func (m *m512Fold) pms() cgen.Stmts {
	const (
		lanes = 16
		quads = lanes / 4
	)
	type table struct {
		pm    cgen.Gen
		index func(lane int) int
	}
	tables := [...]table{
		{m.pm1lo, func(lane int) int {
			at := lane % m.Frame
			return (at & -2) + (at&1)*lanes
		}},
		{m.pm1hi, func(lane int) int {
			at := lane % m.Frame
			return (at | 1) + (at&1)*lanes
		}},
		{m.pm4lo, func(lane int) int {
			return (lane &^ 4) + (lane&4)*quads
		}},
		{m.pm4hi, func(lane int) int {
			return (lane | 4) + (lane&4)*quads
		}},
	}
	decls := make(cgen.Stmts, 0, len(tables))
	for _, tbl := range tables {
		if tbl.pm == nil {
			continue
		}
		set := make(avx.Mm512SetEpi32, lanes)
		for lane := 0; lane < lanes; lane++ {
			// The set intrinsic lists the highest lane first.
			set[lanes-1-lane] = il(tbl.index(lane))
		}
		decls = append(decls, cgen.Var{
			Type: avx.M512i,
			What: tbl.pm,
			Init: set,
		})
	}
	return decls
}

// Gens builds the complete fold and returns its statements together with
// the accumulator that holds the result (m.Accs[0]).
//
// The emitted order is: chanwise statements, permutation-table
// declarations, funnel statements. funnel must run before pms because it
// decides which permutation tables exist.
func (m *m512Fold) Gens() (stmts, acc cgen.Gen) {
	perChan := m.chanwise()
	all := make([]int, m.Chans)
	for ch := 0; ch < len(all); ch++ {
		all[ch] = ch
	}
	merged := m.funnel(all, 1)
	tables := m.pms()
	return cgen.Gens{perChan, tables, merged}, m.Accs[0]
}

// m512Push generates the code that post-processes one vector of results
// (Buf) and stores it; see Append for the emitted sequence.
type m512Push struct {
*Ctx
*Spec
// DatPtrs is consumed front to back by Append: first by the loads of
// Add operations, then by the final stores.
DatPtrs []cgen.Gen
// BnPtrs is consumed front to back, one entry per Bn operation.
BnPtrs []cgen.Gen
// Buf is the vector variable that is modified and stored.
Buf cgen.Gen
// Chan is the channel expression used to offset data pointers and
// batchnorm loads.
Chan cgen.Gen
// ChanCnt is the number of low lanes enabled in load/store masks.
ChanCnt int
}

// Append writes the generated code to the byte slice and returns it.
// Statements are collected in two lists, loads (lds) and everything else
// (nonlds), and all loads are written before the other statements.
//
// Sequence: optional scaling of Buf by 1/(Height*Width) for global
// average pooling, then each operation in m.To.Ops (Add, Bn, ReLU) applied
// to Buf, then m.To.Cnt masked stores of Buf.
//
// Side effects: entries are removed from the front of m.DatPtrs and
// m.BnPtrs as they are used. Panics with "bug" on an unknown op kind.
func (m *m512Push) Append(to []byte) []byte {
var (
lds cgen.Stmts
nonlds cgen.Stmts
pitch = cast(m.ElemBytes)
// Mask with the low ChanCnt bits set.
mask = il(1<<uint(m.ChanCnt) - 1)
)
ld := func(a cgen.Gen) {
lds = append(lds, a)
}
nonld := func(a cgen.Gen) {
nonlds = append(nonlds, a)
}
// datPtr pops the next data pointer and offsets it by pitch and m.Chan.
datPtr := func() cgen.Gen {
ptr := m.DatPtrs[0]
m.DatPtrs = m.DatPtrs[1:]
return addr(ptr, pitch, m.Chan)
}
if m.Kind == raw.AvgGlobal {
// Multiply by the reciprocal of the number of spatial positions.
var (
hw = m.From.Height * m.From.Width
rcp = 1 / avx.Mm512Set1PsLit(hw)
)
nonld(cgen.Assign{
Expr1: m.Buf,
Expr2: avx.Mm512MulPs{m.Buf, rcp},
})
}
for i := range m.To.Ops {
op := &m.To.Ops[i]
switch op.Kind {
case mod.Add:
// dats[0] is Buf; dats[1:] are op.Int freshly loaded vectors.
var (
n = 1 + op.Int
dats = make([]cgen.Gen, n)
)
dats[0] = m.Buf
for j := 1; j < n; j++ {
dats[j] = vb(m.name("dat"))
ld(cgen.Var{
Type: avx.M512, What: dats[j],
Init: avx.Mm512MaskzLoaduPs{
mask, datPtr(),
},
})
}
// Pairwise reduction: each pass adds the upper part of the list
// into the lower part until only dats[0] (Buf) remains.
for n > 1 {
fold := n >> 1
n -= fold
for j := 0; j < fold; j++ {
keep := dats[n-1-j]
nonld(cgen.Assign{
Expr1: keep,
Expr2: avx.Mm512AddPs{
keep, dats[n+j],
},
})
}
}
case mod.Bn:
// Load multiplier/addend from the next batchnorm pointer, then
// apply them to Buf.
var (
bnMul = vb(m.name("bnMul"))
bnAdd = vb(m.name("bnAdd"))
)
ld(&bn.Load{
Ctx: m.bc,
Mas: m.BnPtrs[0],
Channel: m.Chan,
Mul: bnMul,
Add: bnAdd,
Cnt: m.ChanCnt,
})
m.BnPtrs = m.BnPtrs[1:]
nonld(&bn.Apply{
Ctx: m.bc,
Mul: bnMul,
Add: bnAdd,
To: m.Buf,
})
case mod.ReLU:
nonld(&act.ReLU{
Ctx: m.ac,
NegSlope: op.Float,
Var: m.Buf,
})
default:
panic("bug")
}
}
// One masked store per destination, each using the next data pointer.
for n := m.To.Cnt; n > 0; n-- {
nonld(avx.Mm512MaskStoreuPs{
datPtr(), mask, m.Buf,
})
}
to = lds.Append(to)
to = nonlds.Append(to)
return to
}

Top || internal/compile/author/hc/hc.go

package hc

import "NN-512/internal/compile/author/cgen"

// Section identifies one slot of generated output inside Sections.
type Section int

// The H-prefixed sections, HFirst through HLast, are concatenated in
// declaration order by Sections.Join to form its first result; the
// C-prefixed sections, CFirst through CLast, form its second result.
// sectionCount is the number of sections and sizes the backing array.
const (
HFirst Section = iota
HPragmaOnce
HLicense
HInclude
HLinkage1
HParams1
HNet
HEngine
HParams2
HLinkage2
HLast
CFirst
CToBuild
CLicense
CInclude
CErrmsg
CThreader
CExp
CSoftmax
CRsqrt
CBn
CElwi
CGlopl
CTwopl
CThrpl
CFc
COne
CThree
CStrider
CLoom
CNet
CEngine
CLast
sectionCount
)

// Sections accumulates generated bytes separately for every Section.
// The zero value is ready to use.
type Sections struct {
// a holds the bytes appended so far, indexed by Section.
a [sectionCount][]byte
}

// Append renders each non-nil generator in from, in order, onto the end
// of section to. Nil generators are ignored.
func (s *Sections) Append(to Section, from ...cgen.Gen) {
	for i := 0; i < len(from); i++ {
		g := from[i]
		if g == nil {
			continue
		}
		s.a[to] = g.Append(s.a[to])
	}
}

// Join concatenates and indents the accumulated sections. h is built from
// sections HFirst through HLast and c from sections CFirst through CLast.
func (s *Sections) Join() (h, c []byte) {
	return s.join(HFirst, HLast), s.join(CFirst, CLast)
}

// join concatenates sections first through last (inclusive) and inserts
// tab indentation while copying.
//
// The indentation grows by one tab at a newline that directly follows '{'
// or '(' and shrinks by one tab when a line begins with '}' or ')'. It is
// written in front of the first byte of every non-empty line. The state
// carries across section boundaries.
func (s *Sections) join(first, last Section) (to []byte) {
	var (
		tabs   []byte // one '\t' per currently open brace/paren
		before byte   // previously copied byte (0 at the very start)
	)
	for _, text := range s.a[first : last+1] {
		for i := 0; i < len(text); i++ {
			ch := text[i]
			if ch == '\n' {
				if before == '{' || before == '(' {
					tabs = append(tabs, '\t')
				}
			} else if before == '\n' {
				if ch == '}' || ch == ')' {
					tabs = tabs[:len(tabs)-1]
				}
				to = append(to, tabs...)
			}
			to = append(to, ch)
			before = ch
		}
	}
	return
}

Top || internal/compile/author/include/include.go

package include

import (
"NN-512/internal/compile/author/cgen"
"NN-512/internal/compile/plan"
"NN-512/internal/raw"
)

// alwaysH lists the headers that H emits angle-bracketed includes for.
var alwaysH = [...]string{
"pthread.h",
"stddef.h",
}

// alwaysC lists the headers that C always emits angle-bracketed includes
// for, ahead of the platform-specific and generated headers.
var alwaysC = [...]string{
"errno.h",
"stdarg.h",
"stdint.h",
"stdio.h",
"stdlib.h",
"string.h",
}

// H returns one angle-bracketed #include directive per entry of alwaysH,
// in the listed order.
func H() cgen.Gen {
	directives := make(cgen.Gens, 0, len(alwaysH))
	for i := range alwaysH {
		directives = append(directives, cgen.Preprocessor{
			Head: cgen.Include,
			Tail: cgen.AngleBracketed(alwaysH[i]),
		})
	}
	return directives
}

// C returns the include directives for the generated C file: the headers
// in alwaysC, a blank line, the platform header, a blank line, and finally
// the double-quoted generated header named pl.Config.Prefix + ".h".
// It panics with "bug" if the platform is not raw.AVX512Float32.
func C(pl *plan.Plan) cgen.Gen {
	out := make(cgen.Gens, 0, len(alwaysC)+4)
	include := func(tail cgen.Gen) {
		out = append(out, cgen.Preprocessor{
			Head: cgen.Include,
			Tail: tail,
		})
	}
	for i := range alwaysC {
		include(cgen.AngleBracketed(alwaysC[i]))
	}
	out = append(out, cgen.Newline)
	if pl.Config.Platform != raw.AVX512Float32 {
		panic("bug")
	}
	include(cgen.AngleBracketed("immintrin.h"))
	out = append(out, cgen.Newline)
	include(cgen.DoubleQuoted(pl.Config.Prefix + ".h"))
	return out
}

Top || internal/compile/author/license/license.go

package license

import "NN-512/internal/compile/author/cgen"

// Gen is the license notice as a cgen.Comment, one element per line of
// text. NOTE(review): presumably emitted into the generated sources;
// the call sites are not visible here.
var Gen cgen.Gen = cgen.Comment{
`NN-512 (https://NN-512.com)`,
``,
`Copyright (C) 2019 [`,
` 37ef ced3 3727 60b4`,
` 3c29 f9c6 dc30 d518`,
` f4f3 4106 6964 cab4`,
` a06f c1a3 83fd 090e`,
`]`,
``,
`All rights reserved.`,
``,
`Redistribution and use in source and binary forms, with or without`,
`modification, are permitted provided that the following conditions`,
`are met:`,
``,
`1. Redistributions of source code must retain the above copyright`,
` notice, this list of conditions and the following disclaimer.`,
``,
`2. Redistributions in binary form must reproduce the above copyright`,
` notice, this list of conditions and the following disclaimer in`,
` the documentation and/or other materials provided with the`,
` distribution.`,
``,
`THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS`,
`"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT`,
`LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR`,
`A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT`,
`HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,`,
`SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT`,
`LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,`,
`DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY`,
`THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT`,
`(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE`,
`OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.`,
}

Top || internal/compile/author/loom/loom.go

package loom

import (
"NN-512/internal/compile/author/act"
"NN-512/internal/compile/author/avx"
"NN-512/internal/compile/author/bn"
"NN-512/internal/compile/author/cgen"
"NN-512/internal/compile/author/mod"
"NN-512/internal/compile/author/threader"
"NN-512/internal/compile/author/trans"
"NN-512/internal/compile/plan"
"NN-512/internal/nmsrc"
"NN-512/internal/raw"
"fmt"
)

// btoi converts a boolean to an integer: 1 for true, 0 for false.
func btoi(b bool) int {
	n := 0
	if b {
		n = 1
	}
	return n
}

// min returns the smaller of x and y.
func min(x, y int) int {
	if y < x {
		return y
	}
	return x
}

// max returns the larger of x and y.
func max(x, y int) int {
	if y > x {
		return y
	}
	return x
}

// ceilQuo returns n divided by d, rounded up, for non-negative n and
// positive d. It is computed as (n + d - 1) / d with truncating division.
func ceilQuo(n, d int) int {
	biased := n + (d - 1)
	return biased / d
}

// vb wraps the identifier s as a cgen.Vb generator.
func vb(s string) cgen.Gen {
	var ident cgen.Gen = cgen.Vb(s)
	return ident
}

// il wraps the integer i as a cgen.IntLit generator.
func il(i int) cgen.Gen {
	var lit cgen.Gen = cgen.IntLit(i)
	return lit
}

// loMask returns an integer literal whose low n bits are set.
func loMask(n int) cgen.Gen {
	bits := 1<<uint(n) - 1
	return cgen.IntLit(bits)
}

// addMul builds the expression a + b*c.
func addMul(a, b, c cgen.Gen) cgen.Gen {
	product := cgen.Mul{Expr1: b, Expr2: c}
	return cgen.Add{Expr1: a, Expr2: product}
}

// mix interleaves several statement lists round-robin: the first
// statement of every list (in list order), then the second of every list
// that has one, and so on. A single list is returned unchanged.
func mix(a []cgen.Stmts) cgen.Stmts {
	if len(a) == 1 {
		return a[0]
	}
	total := 0
	for _, list := range a {
		total += len(list)
	}
	merged := make(cgen.Stmts, 0, total)
	for round := 0; len(merged) < total; round++ {
		for k := range a {
			if round < len(a[k]) {
				merged = append(merged, a[k][round])
			}
		}
	}
	return merged
}

// Ctx carries the state shared by the generators in this package.
// It is built by NewCtx.
type Ctx struct {
// prefix is the configured prefix with "Loom" appended.
prefix string
platform raw.Platform
// cacheBytes1 and cacheBytes2 come from the plan's
// L1DataCachePerThread and L2CachePerThreadExL1 settings.
cacheBytes1 int
cacheBytes2 int
// nms is the name source behind the name method.
nms nmsrc.Src
tc *threader.Ctx
ac *act.Ctx
bc *bn.Ctx
// dedup caches earlier results by signature string (block tables in
// newBlocks, function names in ArrangeFilts.Prep).
dedup map[string]interface{}
}

// NewCtx creates the shared context for this package from the plan's
// configuration (prefix, platform, per-thread cache sizes), the name
// source, and the threader, activation and batchnorm contexts. The dedup
// cache starts empty.
func NewCtx(pl *plan.Plan, nms nmsrc.Src, tc *threader.Ctx, ac *act.Ctx, bc *bn.Ctx) *Ctx {
	c := new(Ctx)
	c.prefix = pl.Config.Prefix + "Loom"
	c.platform = pl.Config.Platform
	c.cacheBytes1 = pl.Config.L1DataCachePerThread
	c.cacheBytes2 = pl.Config.L2CachePerThreadExL1
	c.nms = nms
	c.tc = tc
	c.ac = ac
	c.bc = bc
	c.dedup = make(map[string]interface{})
	return c
}

// name asks the context's name source for an identifier based on s.
// NOTE(review): presumably the result is unique per call; confirm in
// nmsrc.
func (c *Ctx) name(s string) string {
	ident := c.nms.Name(s)
	return ident
}

// Spec describes one convolution handled by this package: the input
// tensor, the filter sets, the output, and the filter geometry.
type Spec struct {
From SpecFrom
// Filts may hold more than one entry only when Groups == 1
// (newLayout panics otherwise).
Filts []SpecFilts
To SpecTo
FilterH int
FilterW int
StrideH int
StrideW int
// PaddingH and PaddingW are applied on both sides of the input
// height and width (see newBlocks).
PaddingH int
PaddingW int
DilationH int
DilationW int
// Groups divides both From.Chans and the total filter count.
Groups int
}

// SpecFrom describes the input side of a Spec.
type SpecFrom struct {
// Chans is the input channel count across all groups.
Chans int
Height int
Width int
// NOTE(review): pitch slices presumably hold one entry per input
// tensor; not established by the code visible here.
Pitch1Bytes []int
Pitch2Bytes []int
Ops []mod.Op
}

// SpecFilts describes one set of filters.
type SpecFilts struct {
// Cnt is the number of filters in the set.
Cnt int
// BnPre and BnPost are the numbers of batchnorm tensors that
// accompany the set; in the tensor array they follow the set's first
// two tensors, BnPre entries first (see arrangeFilts.kernel1).
BnPre int
BnPost int
}

// SpecTo describes the output side of a Spec.
type SpecTo struct {
// NOTE(review): pitch slices presumably hold one entry per output
// tensor; not established by the code visible here.
Pitch1Bytes []int
Pitch2Bytes []int
Ops []mod.Op
}

// spans records, for one block, how far the leading padding (1), the
// data (2) and the trailing padding (3) reach into the block along each
// axis. Every value is clamped to the block's step (see newBlocks).
// Blocks without any data have span1 and span2 set equal to span3.
type spans struct {
spanH1 int
spanH2 int
spanH3 int
spanW1 int
spanW2 int
spanW3 int
}

// loopW is a run of consecutive blocks along the width that share the
// same spans.
type loopW struct {
// fromW is the padded-width coordinate of the first block in the run.
fromW int
// fromStep is the fromW distance between blocks (0 for a single block).
fromStep int
// blkFirst and blkPast delimit the run's block indices; once stored
// in a loopH they are relative to that loopH's blkFirst.
blkFirst int
blkPast int
spans
}

// loopH is a run of consecutive block rows whose loopW lists are
// identical.
type loopH struct {
// fromH is the padded-height coordinate of the first row in the run.
fromH int
// fromStep and blkStep are the fromH and block-index distances
// between rows (0 for a single row).
fromStep int
blkFirst int
blkStep int
blkPast int
lws []*loopW
}

// blocks is the result of newBlocks: the total block count and the
// compressed description of the blocks as runs of rows.
type blocks struct {
cnt int
lhs []*loopH
}

// newBlocks tiles the padded input plane into blocks of
// blkVecs*StrideH rows by vecLanes*StrideW columns and returns a
// run-length description of them (rows of loopW runs, grouped into loopH
// runs).
//
// Results are cached in ctx.dedup under a signature made of the input
// size, strides, paddings, blkVecs and vecLanes, so equal requests return
// the same *blocks.
//
// The work is organised as closures that share the variables declared at
// the top; layer1 is the entry point and each layer feeds the next
// higher-numbered one.
func newBlocks(ctx *Ctx, spec *Spec, blkVecs, vecLanes int) *blocks {
var (
blks blocks
// fromH1/lw1 describe the block just visited; fromH2/lw2 the current
// run of equal-span blocks; lh1 the current row; lh2 the current run
// of equal rows.
fromH1 int
lw1 loopW
fromH2 int
lw2 loopW
lh1 loopH
lh2 loopH
)
// layer6 appends a copy of the current row run to the result.
layer6 := func() {
lh := lh2
blks.lhs = append(
blks.lhs, &lh,
)
}
// layer5 merges the finished row lh1 into the row run lh2 when their
// loopW lists are equal; otherwise it emits lh2 (if non-empty) and
// starts a new run from a deep copy of lh1.
layer5 := func(flush bool) {
split := true
if len(lh2.lws) == len(lh1.lws) {
split = false
for x, lw := range lh2.lws {
if *lw != *lh1.lws[x] {
split = true
break
}
}
}
switch {
case split:
if lh2.blkFirst < lh2.blkPast {
layer6()
}
lh2 = lh1
lh2.lws = make([]*loopW, len(lh1.lws))
for x, lw := range lh1.lws {
lw := *lw
lh2.lws[x] = &lw
}
default:
// The steps are fixed by the second row of a run.
if lh2.fromStep == 0 {
lh2.fromStep = lh1.fromH - lh2.fromH
lh2.blkStep = lh1.blkFirst - lh2.blkFirst
}
lh2.blkPast = lh1.blkPast
}
if flush {
layer6()
}
}
// layer4 adds the finished width run lw2 to the current row lh1. A
// run starting at fromW == 0 begins a new row, so the previous row is
// handed to layer5 first.
layer4 := func(flush bool) {
if lw2.fromW == 0 {
if lh1.blkFirst < lh1.blkPast {
layer5(false)
}
lh1.fromH = fromH2
lh1.blkFirst = lw2.blkFirst
lh1.lws = lh1.lws[:0]
}
lh1.blkPast = lw2.blkPast
// Store block indices relative to the start of the row.
lw2.blkFirst -= lh1.blkFirst
lw2.blkPast -= lh1.blkFirst
// Reuse loopW values left in the slice's spare capacity.
x := len(lh1.lws)
switch x {
case cap(lh1.lws):
lh1.lws = append(
lh1.lws, new(loopW),
)
default:
lh1.lws = lh1.lws[:x+1]
if lh1.lws[x] == nil {
lh1.lws[x] = new(loopW)
}
}
*lh1.lws[x] = lw2
if flush {
layer5(true)
}
}
// layer3 extends the width run lw2 with block lw1 when lw1 is not at
// the start of a row and has the same spans; otherwise it closes lw2
// and starts a new run from lw1.
layer3 := func(flush bool) {
if flush {
layer4(true)
return
}
switch {
case lw1.fromW == 0:
case lw1.spans != lw2.spans:
default:
if lw2.fromStep == 0 {
lw2.fromStep = lw1.fromW - lw2.fromW
}
lw2.blkPast = lw1.blkPast
return
}
if lw2.blkFirst < lw2.blkPast {
layer4(false)
}
fromH2 = fromH1
lw2 = lw1
}
// layer2 visits every block in row-major order and computes its spans.
layer2 := func() {
var (
// h1/w1: end of leading padding; h2/w2: end of data; h3/w3: end of
// trailing padding.
h1 = spec.PaddingH
h2 = h1 + spec.From.Height
h3 = h2 + spec.PaddingH
w1 = spec.PaddingW
w2 = w1 + spec.From.Width
w3 = w2 + spec.PaddingW
stepH = blkVecs * spec.StrideH
stepW = vecLanes * spec.StrideW
)
for h := 0; h < h3; h += stepH {
for w := 0; w < w3; w += stepW {
fromH1 = h
lw1.fromW = w
lw1.blkFirst = blks.cnt
lw1.blkPast = blks.cnt + 1
blks.cnt++
lw1.spanH1 = min(max(h1-h, 0), stepH)
lw1.spanH2 = min(max(h2-h, 0), stepH)
lw1.spanH3 = min(h3-h, stepH)
lw1.spanW1 = min(max(w1-w, 0), stepW)
lw1.spanW2 = min(max(w2-w, 0), stepW)
lw1.spanW3 = min(w3-w, stepW)
var (
datH = lw1.spanH2 - lw1.spanH1
datW = lw1.spanW2 - lw1.spanW1
)
// A block with no data rows or no data columns is pure padding;
// give all such blocks the same canonical spans.
if datH == 0 || datW == 0 {
lw1.spanH1 = lw1.spanH3
lw1.spanH2 = lw1.spanH3
lw1.spanW1 = lw1.spanW3
lw1.spanW2 = lw1.spanW3
}
layer3(false)
}
}
layer3(true)
}
// layer1 consults the cache and otherwise builds the table.
layer1 := func() *blocks {
sig := fmt.Sprint(
"newBlocks",
" ",
spec.From.Height,
spec.From.Width,
spec.StrideH,
spec.StrideW,
spec.PaddingH,
spec.PaddingW,
blkVecs,
vecLanes,
)
if prior, ok := ctx.dedup[sig]; ok {
return prior.(*blocks)
}
ctx.dedup[sig] = &blks
layer2()
return &blks
}
return layer1()
}

// node describes one filter position (filtH, filtW) in a layout.
type node struct {
filtH int
filtW int
// deck is the index into layout.lifts of this node's row offset
// (dilated filtH divided by StrideH).
deck int
// pile is the index into layout.shifts of this node's column offset
// (dilated filtW divided by StrideW).
pile int
// base is true for the node that first introduced its pile.
base bool
}

// field groups the nodes whose dilated filter offsets share the same
// remainders modulo the strides.
type field struct {
// sboxH and sboxW are those remainders (offset % StrideH, % StrideW).
sboxH int
sboxW int
// nodeFirst is the index in layout.nodes of the field's first node.
nodeFirst int
// nodeStep counts the field's nodes that share the first node's filtH.
nodeStep int
}

// layout holds every derived size and count computed by newLayout for one
// Spec. Fields ending in 1 and 2 generally come in pairs: within newLayout
// the "1" value is multiplied by a count and the "2" value is added once
// as the remainder (for example wtTotalBytes =
// epochs1*wtEpochBytes1 + wtEpochBytes2).
type layout struct {
// Per-group channel counts.
fromChans int
toChans int
// Input-channel slices per epoch (slices1 for the epochs1 regular
// epochs, slices2 for the final one); epochs2 is the total epoch count.
slices1 int
slices2 int
epochs1 int
epochs2 int
// Bias sizes.
biasBytes int
biasGroupBytes int
biasEpochBytes int
biasTotalBytes int
// Distinct row offsets (lifts) and column offsets (shifts) of the
// filter positions, plus the positions themselves and their grouping.
lifts []int
shifts []int
nodes []*node
fields []*field
// Weight sizes.
wtBytes int
wtSliceWts1 int
wtSliceWts2 int
wtSliceBytes1 int
wtSliceBytes2 int
wtCores1 int
wtCores2 int
wtCoreBytes11 int
wtCoreBytes12 int
wtCoreBytes21 int
wtCoreBytes22 int
wtNodeBytes1 int
wtNodeBytes2 int
wtGroupBytes1 int
wtGroupBytes2 int
wtEpochBytes1 int
wtEpochBytes2 int
wtTotalBytes int
// Block table from newBlocks and the block step of its first row run.
blks *blocks
blkStep int
// Data sizes.
datBytes int
datVecDats int
datVecBytes int
datSliceVecs int
datSliceBytes int
datCores int
datCoreBytes1 int
datCoreBytes2 int
datGroupBytes1 int
datGroupBytes2 int
datFieldBytes1 int
datFieldBytes2 int
datEpochBytes1 int
datEpochBytes2 int
datTotalBytes int
// Sum sizes.
sumSiteBytes1 int
sumSiteBytes2 int
sumPileBytes int
sumCores int
sumCoreBytes int
sumGroupBytes int
sumTotalBytes int
}

// newLayout computes the layout for spec on ctx's platform.
//
// The computation is a chain of closures that all fill in the shared
// layout value y. layer1 is the entry point and each layer ends by
// calling the next higher-numbered one, so the definitions read bottom-up.
// Panics with "bug" on an unsupported platform or when several filter
// sets are combined with Groups > 1.
func newLayout(ctx *Ctx, spec *Spec) *layout {
var (
y layout
)
// layer10: data byte sizes per core, group, field and epoch, and total.
layer10 := func() {
y.datCoreBytes1 = y.slices1 * y.datSliceBytes
y.datCoreBytes2 = y.slices2 * y.datSliceBytes
y.datGroupBytes1 = y.datCores * y.datCoreBytes1
y.datGroupBytes2 = y.datCores * y.datCoreBytes2
y.datFieldBytes1 = spec.Groups * y.datGroupBytes1
y.datFieldBytes2 = spec.Groups * y.datGroupBytes2
y.datEpochBytes1 = len(y.fields) * y.datFieldBytes1
y.datEpochBytes2 = len(y.fields) * y.datFieldBytes2
y.datTotalBytes = y.epochs1*y.datEpochBytes1 + y.datEpochBytes2
}
// layer9: weight byte sizes per core, node, group and epoch, and total.
layer9 := func() {
y.wtCoreBytes11 = y.slices1 * y.wtSliceBytes1
y.wtCoreBytes12 = y.slices1 * y.wtSliceBytes2
y.wtCoreBytes21 = y.slices2 * y.wtSliceBytes1
y.wtCoreBytes22 = y.slices2 * y.wtSliceBytes2
y.wtNodeBytes1 = y.wtCores1*y.wtCoreBytes11 + y.wtCoreBytes12
y.wtNodeBytes2 = y.wtCores1*y.wtCoreBytes21 + y.wtCoreBytes22
y.wtGroupBytes1 = len(y.nodes) * y.wtNodeBytes1
y.wtGroupBytes2 = len(y.nodes) * y.wtNodeBytes2
y.wtEpochBytes1 = spec.Groups * y.wtGroupBytes1
y.wtEpochBytes2 = spec.Groups * y.wtGroupBytes2
y.wtTotalBytes = y.epochs1*y.wtEpochBytes1 + y.wtEpochBytes2
layer10()
}
// layer8: bias byte sizes.
layer8 := func() {
y.biasGroupBytes = y.toChans * y.biasBytes
y.biasEpochBytes = spec.Groups * y.biasGroupBytes
y.biasTotalBytes = y.epochs2 * y.biasEpochBytes
layer9()
}
// layer7: choose how many input-channel slices form an epoch, based on
// the per-thread cache sizes, then derive the epoch counts.
layer7 := func() {
wtSliceBytes := y.wtSliceBytes1
if y.wtCores1 == 0 {
wtSliceBytes = y.wtSliceBytes2
}
switch ctx.platform {
case raw.AVX512Float32:
var (
sliceBytes = 2*wtSliceBytes + y.datSliceBytes
cacheBytes = ctx.cacheBytes1 + ctx.cacheBytes2
)
const (
empirical1 = 4
empirical2 = 512
empirical3 = 4
)
y.slices1 = cacheBytes / empirical1 / sliceBytes
y.slices1 = max(y.slices1, empirical2)
y.slices2 = y.fromChans % y.slices1
y.epochs1 = y.fromChans / y.slices1
y.epochs2 = y.epochs1 + btoi(y.slices2 > 0)
// A small remainder epoch is folded into the last full epoch.
if y.epochs1 > 0 && y.epochs1 < y.epochs2 {
if y.slices2*empirical3 < y.slices1 {
y.slices2 += y.slices1
y.epochs1--
y.epochs2--
}
}
default:
panic("bug")
}
layer8()
}
// layer6: sum sizes. sumCores is datCores reduced by a cut derived
// from the last lift and the block step.
layer6 := func() {
y.sumSiteBytes1 = y.wtSliceWts1 * y.datSliceBytes
y.sumSiteBytes2 = y.wtSliceWts2 * y.datSliceBytes
y.sumPileBytes = y.wtCores1*y.sumSiteBytes1 + y.sumSiteBytes2
var (
lift = y.lifts[len(y.lifts)-1]
cut1 = lift / y.datSliceVecs
cut2 = cut1 - btoi(cut1 > 0)
cut3 = cut2 * y.blkStep
)
y.sumCores = y.datCores - cut3
y.sumCoreBytes = len(y.shifts) * y.sumPileBytes
y.sumGroupBytes = y.sumCores * y.sumCoreBytes
y.sumTotalBytes = spec.Groups * y.sumGroupBytes
layer7()
}
// layer5: build the block table and take the block step from its
// first row run (falling back to that run's block count when the run
// has no step).
layer5 := func() {
y.blks = newBlocks(ctx, spec, y.datSliceVecs, y.datVecDats)
switch lh := y.blks.lhs[0]; lh.blkStep {
case 0:
y.blkStep = lh.blkPast - lh.blkFirst
default:
y.blkStep = lh.blkStep
}
y.datVecBytes = y.datVecDats * y.datBytes
y.datSliceBytes = y.datSliceVecs * y.datVecBytes
y.datCores = y.blks.cnt
layer6()
}
// layer4: split the output channels into weight cores of wtSliceWts1
// channels plus an optional remainder core of wtSliceWts2 channels.
layer4 := func() {
y.wtSliceWts2 = y.toChans % y.wtSliceWts1
y.wtSliceBytes1 = y.wtSliceWts1 * y.wtBytes
y.wtSliceBytes2 = y.wtSliceWts2 * y.wtBytes
y.wtCores1 = y.toChans / y.wtSliceWts1
y.wtCores2 = y.wtCores1 + btoi(y.wtSliceWts2 > 0)
layer5()
}
// layer3: per-group input and output channel counts.
layer3 := func() {
if len(spec.Filts) > 1 && spec.Groups > 1 {
panic("bug")
}
filts := 0
for i := range spec.Filts {
filts += spec.Filts[i].Cnt
}
y.fromChans = spec.From.Chans / spec.Groups
y.toChans = filts / spec.Groups
layer4()
}
// layer2: create one node per filter position, bucketed by the
// remainders of its dilated offsets modulo the strides, and record the
// distinct quotients in y.lifts (rows) and y.shifts (columns). Each
// non-empty bucket then becomes a field.
layer2 := func() {
nds := make([][][]*node, spec.StrideH)
for sboxH := range nds {
nds[sboxH] = make([][]*node, spec.StrideW)
}
for filtH := 0; filtH < spec.FilterH; filtH++ {
var (
dilaH = filtH * spec.DilationH
sboxH = dilaH % spec.StrideH
nds = nds[sboxH]
lift = dilaH / spec.StrideH
deck = -1
)
// Find or add this row's lift.
for at, is := range y.lifts {
if is == lift {
deck = at
break
}
}
if deck == -1 {
deck = len(y.lifts)
y.lifts = append(
y.lifts, lift,
)
}
for filtW := 0; filtW < spec.FilterW; filtW++ {
var (
dilaW = filtW * spec.DilationW
sboxW = dilaW % spec.StrideW
shift = dilaW / spec.StrideW
)
nd := &node{
filtH: filtH,
filtW: filtW,
deck: deck,
pile: -1,
base: false,
}
// Find or add this column's shift; the node that adds it is
// marked as the pile's base.
for at, is := range y.shifts {
if is == shift {
nd.pile = at
break
}
}
if nd.pile == -1 {
nd.pile = len(y.shifts)
nd.base = true
y.shifts = append(
y.shifts, shift,
)
}
nds[sboxW] = append(
nds[sboxW], nd,
)
}
}
for sboxH, nds := range nds {
for sboxW, nds := range nds {
if nds == nil {
continue
}
fld := &field{
sboxH: sboxH,
sboxW: sboxW,
nodeFirst: len(y.nodes),
nodeStep: 0,
}
for _, nd := range nds {
if nd.filtH == nds[0].filtH {
fld.nodeStep++
}
y.nodes = append(
y.nodes, nd,
)
}
y.fields = append(
y.fields, fld,
)
}
}
layer3()
}
// layer1: platform constants (element sizes and vector geometry).
layer1 := func() *layout {
switch ctx.platform {
case raw.AVX512Float32:
y.biasBytes = 4
y.wtBytes = 4
y.wtSliceWts1 = 6
y.datBytes = 4
y.datVecDats = 16
y.datSliceVecs = 4
default:
panic("bug")
}
layer2()
return &y
}
return layer1()
}

// ArrangeFilts generates the call that arranges a convolution's filter
// tensors. Prep must be called first: it computes the layout and decides
// the name of the function that Append later calls.
type ArrangeFilts struct {
*Ctx
*Spec
// Team is passed as the first argument of the generated call.
Team cgen.Gen
// Tensors are the pointer expressions placed in the generated array.
Tensors []cgen.Gen
*layout
// callerName is the name of the generated function, set by Prep.
callerName string
}

// Prep computes the layout and returns the definition of the arranging
// function, followed by a newline. If a function for an identically
// printed Spec was already produced through this context, its name is
// reused and nil is returned.
func (a *ArrangeFilts) Prep() cgen.Gen {
	const affix = "ArrangeFilts"
	a.layout = newLayout(a.Ctx, a.Spec)
	key := fmt.Sprint(affix, " ", a.Spec)
	prior, seen := a.dedup[key]
	if !seen {
		fresh := a.name(a.prefix + affix)
		a.callerName = fresh
		a.dedup[key] = fresh
		return cgen.Gens{
			&arrangeFilts{ArrangeFilts: a},
			cgen.Newline,
		}
	}
	a.callerName = prior.(string)
	return nil
}

// Bytes returns the size of the arranged output: all bias bytes plus all
// weight bytes of the layout.
func (a *ArrangeFilts) Bytes() int {
	total := a.biasTotalBytes
	total += a.wtTotalBytes
	return total
}

// Append emits two statements: the declaration of a char-pointer array
// initialised with a.Tensors, and a call of the function named by Prep
// with a.Team and that array as arguments.
func (a *ArrangeFilts) Append(to []byte) []byte {
	arr := vb(a.name("tensors"))
	decl := cgen.Var{
		Type: cgen.PtrChar,
		What: cgen.Elem{Arr: arr},
		Init: cgen.Brace{
			Inner: cgen.CommaLines(a.Tensors),
		},
	}
	call := cgen.Call{
		Func: vb(a.callerName),
		Args: cgen.CommaSpaced{a.Team, arr},
	}
	return cgen.Stmts{decl, call}.Append(to)
}

// arrangeFilts generates the definition of the filter-arranging function
// (the threaded callee and its static caller). Its fields are scratch
// state that the generation methods set for the kernels they invoke.
type arrangeFilts struct {
*ArrangeFilts
// Work partitioning chosen in Append: filters per bundle, and the
// tiling of bundles and of groups across the task hull.
bundleFilts int
bundleTile int
bundleTiles int
bundleScrap int
bundleHull int
groupTile int
groupTiles int
groupScrap int
groupHull int
calleeName string
// Variables of the callee: tensor array and task coordinates.
tensors cgen.Gen
bundleCoord cgen.Gen
groupCoord cgen.Gen
epochCoord cgen.Gen
// Sizes and epoch range of the kernel being generated (set in
// calleeFunc for the regular epochs and for the final epoch).
slices int
coreBytes int
nodeBytes int
groupBytes int
epochFirst int
epochCnt int
// Variables and indices set in kernel1.
arrangedB cgen.Gen
arrangedW cgen.Gen
filtsIdx int
wtPtr cgen.Gen
biasPtr cgen.Gen
bnPtrs []cgen.Gen
groupIdx cgen.Gen
bundleIdx cgen.Gen
bundleLast cgen.Gen
baseFilt int
baseBundle int
// Per-bundle-range state set in kernel2.
filts1 int
filts2 int
bundleFirst int
bundlePast int
}

// Append chooses the work partitioning and emits the callee function, a
// newline, and the static caller function that runs the callee through
// threader.Do over a hull of (bundle tiles, group tiles, epochs).
//
// Partitioning: when one group holds at least threadVecs vectors of
// weights, the bundles of a group are tiled and every group is its own
// task; otherwise each task takes all bundles of one or more groups.
// If the tiles do not divide evenly, the last tile absorbs the remainder
// ("scrap"). Panics with "bug" on an unsupported platform.
func (a *arrangeFilts) Append(to []byte) []byte {
var (
vecWts int
threadVecs int
)
switch a.platform {
case raw.AVX512Float32:
vecWts = 16
a.bundleFilts = vecWts
threadVecs = 512
default:
panic("bug")
}
// Estimate the number of weight vectors per filter.
var (
epochChans = ceilQuo(a.fromChans, a.epochs2)
spatialWts = a.FilterH * a.FilterW
filtVecs int
)
switch {
case spatialWts <= vecWts:
filtVecs = ceilQuo(epochChans, vecWts/spatialWts)
default:
filtVecs = epochChans * ceilQuo(spatialWts, vecWts)
}
var (
bundleVecs = a.bundleFilts * filtVecs
groupVecs = a.toChans * filtVecs
groupBundles int
)
// With several filter sets, every set is rounded up to whole bundles.
switch len(a.Filts) {
case 1:
groupBundles = ceilQuo(a.toChans, a.bundleFilts)
default:
for i := range a.Filts {
filts := a.Filts[i].Cnt
groupBundles += ceilQuo(filts, a.bundleFilts)
}
}
switch {
case threadVecs <= groupVecs:
var (
tile = ceilQuo(threadVecs, bundleVecs)
tiles = max(groupBundles/tile, 1)
)
a.bundleTile = groupBundles / tiles
a.bundleTiles = tiles
a.bundleScrap = groupBundles - tiles*a.bundleTile
a.bundleHull = tiles
if a.bundleScrap > 0 {
a.bundleTiles--
a.bundleScrap += a.bundleTile
}
a.groupTile = 1
a.groupTiles = a.Groups
a.groupScrap = 0
a.groupHull = a.Groups
default:
a.bundleTile = groupBundles
a.bundleTiles = 1
a.bundleScrap = 0
a.bundleHull = 1
var (
tile = ceilQuo(threadVecs, groupVecs)
tiles = max(a.Groups/tile, 1)
)
a.groupTile = a.Groups / tiles
a.groupTiles = tiles
a.groupScrap = a.Groups - tiles*a.groupTile
a.groupHull = tiles
if a.groupScrap > 0 {
a.groupTiles--
a.groupScrap += a.groupTile
}
}
a.calleeName = a.name(a.callerName + "Callee")
var (
team = vb(a.name("team"))
tensors = vb(a.name("tensors"))
)
return cgen.Gens{
a.calleeFunc(),
cgen.Newline,
cgen.StaticFuncDef{
ReturnType: cgen.Void,
Name: a.callerName,
Params: cgen.CommaSpaced{
cgen.Param{
Type: a.tc.PtrTeam,
What: team,
},
cgen.Param{
Type: cgen.PtrPtrChar,
What: tensors,
},
},
Body: &threader.Do{
Ctx: a.tc,
Callee: vb(a.calleeName),
Any: tensors,
Hull: []cgen.Gen{
il(a.bundleHull),
il(a.groupHull),
il(a.epochs2),
},
Team: team,
},
},
}.Append(to)
}

// calleeFunc builds the threaded callee. Its body has seven slots:
// [0] the tensor array taken from the callee's Any value, [1..3] the
// bundle, group and epoch coordinates, [4] a void cast of the point
// argument when no coordinate reads it, [5] the kernel for the regular
// epochs, and [6] the kernel for the final (remainder) epoch.
func (a *arrangeFilts) calleeFunc() cgen.Gen {
callee := &threader.Callee{
Ctx: a.tc,
Name: a.calleeName,
Task: vb(a.name("task")),
Pt: vb(a.name("pt")),
}
var (
body = make(cgen.Stmts, 7)
usedPt = false
)
a.tensors = vb(a.name("tensors"))
body[0] = cgen.Var{
Type: cgen.PtrPtrChar, What: a.tensors,
Init: callee.Any(),
}
// coord declares one coordinate variable. A hull of 1 has the constant
// coordinate 0; otherwise the coordinate is read from element i of the
// point argument.
coord := func(nm string, hull, i int) cgen.Gen {
var (
ret = vb(a.name(nm))
expr cgen.Gen
)
switch hull {
case 1:
expr = il(0)
default:
expr = cgen.Elem{
Arr: callee.Pt, Idx: il(i),
}
usedPt = true
}
body[1+i] = cgen.Var{
Type: cgen.PtrdiffT, What: ret,
Init: expr,
}
return ret
}
a.bundleCoord = coord("b", a.bundleHull, 0)
a.groupCoord = coord("g", a.groupHull, 1)
a.epochCoord = coord("e", a.epochs2, 2)
if !usedPt {
body[4] = cgen.Cast{
Type: cgen.Void,
Expr: callee.Pt,
}
}
// kernel emits kernel1 for the currently configured epoch range. When
// that range is a single epoch out of several, the epoch coordinate is
// first assigned its constant value.
kernel := func() cgen.Gen {
var assn cgen.Gen
if a.epochs2 > 1 && a.epochCnt == 1 {
assn = cgen.Assign{
Expr1: a.epochCoord,
Expr2: il(a.epochFirst),
}
}
return cgen.Stmts{
assn,
a.kernel1(),
}
}
if a.epochs1 > 0 {
// Regular epochs use the slices1-based sizes.
a.slices = a.slices1
a.coreBytes = a.wtCoreBytes11
a.nodeBytes = a.wtNodeBytes1
a.groupBytes = a.wtGroupBytes1
a.epochFirst = 0
a.epochCnt = a.epochs1
put := kernel()
if a.epochs1 < a.epochs2 {
// A final epoch also exists: guard the regular kernel and return
// after it so the final kernel is skipped.
put = cgen.If{
Cond: cgen.CmpL{
Expr1: a.epochCoord,
Expr2: il(a.epochs1),
},
Then: cgen.Stmts{
put,
cgen.Return{},
},
}
}
body[5] = put
}
if a.epochs1 < a.epochs2 {
// Final epoch uses the slices2-based sizes.
a.slices = a.slices2
a.coreBytes = a.wtCoreBytes21
a.nodeBytes = a.wtNodeBytes2
a.groupBytes = a.wtGroupBytes2
a.epochFirst = a.epochs1
a.epochCnt = 1
body[6] = kernel()
}
return callee.Func(body)
}

// kernel1 emits the code for one epoch range: the output pointers for
// arranged biases and weights, the source pointer declarations, the loop
// over the task's groups, the bundle range of the task, and finally the
// per-bundle code from kernel2 (selected per filter set when there are
// several).
//
// The closures are numbered from the outside in: layer1 is the entry
// point and layer5 is the innermost.
func (a *arrangeFilts) kernel1() cgen.Gen {
var (
n = len(a.Filts)
savedFiltsIdx = 0
savedTensorIdx = 0
)
// tensor returns element (base + off) of the tensor array, where base
// is the index of the first tensor of filter set filtsIdx. Every
// earlier set occupies 2 + BnPre + BnPost entries. The base of the
// most recently requested set is cached. Index n addresses the entry
// after all filter sets.
tensor := func(filtsIdx, off int) cgen.Gen {
if savedFiltsIdx != filtsIdx {
savedFiltsIdx = filtsIdx
at := 0
for x := 0; x < filtsIdx; x++ {
at += 2
at += a.Filts[x].BnPre
at += a.Filts[x].BnPost
}
savedTensorIdx = at
}
return cgen.Elem{
Arr: a.tensors,
Idx: il(savedTensorIdx + off),
}
}
// ptrDecls declares the source pointers of one filter set: weights
// (tensor 0, advanced by epoch), biases (tensor 1, only when the epoch
// range starts at epoch 0) and the batchnorm pointers (tensors 2...).
ptrDecls := func(filtsIdx int) cgen.Gen {
wtDecl := func() cgen.Gen {
a.wtPtr = vb(a.name("wtPtr"))
filtHW := a.FilterH * a.FilterW
return cgen.Var{
Type: cgen.RestrictPtrChar,
What: a.wtPtr,
Init: addMul(
tensor(filtsIdx, 0),
il(a.slices1*filtHW*a.wtBytes),
a.epochCoord,
),
}
}
biasDecl := func() cgen.Gen {
if a.epochFirst == 0 {
a.biasPtr = vb(a.name("biasPtr"))
return cgen.Var{
Type: cgen.RestrictPtrChar,
What: a.biasPtr,
Init: tensor(filtsIdx, 1),
}
}
a.biasPtr = nil
return nil
}
bnDecls := func() cgen.Gen {
var (
pre = a.Filts[filtsIdx].BnPre
post = a.Filts[filtsIdx].BnPost
ret = make(cgen.Stmts, pre+post)
)
a.bnPtrs = make([]cgen.Gen, pre+post)
for x := range a.bnPtrs {
var (
bnPtr = vb(a.name("bnPtr"))
expr = tensor(filtsIdx, 2+x)
)
// The first pre pointers are offset to the epoch's first channel.
if x < pre {
expr = &bn.Offset{
Ctx: a.bc,
Mas: expr,
Channel: cgen.Mul{
Expr1: il(a.slices1),
Expr2: a.epochCoord,
},
}
}
ret[x] = cgen.Var{
Type: cgen.RestrictPtrChar,
What: bnPtr, Init: expr,
}
a.bnPtrs[x] = bnPtr
}
return ret
}
a.filtsIdx = filtsIdx
return cgen.Stmts{
wtDecl(),
biasDecl(),
bnDecls(),
}
}
// layer5 emits the per-bundle code. With one filter set it is just
// kernel2. With several, a binary decision tree on the bundle index
// selects the set; every leaf declares that set's pointers, runs
// kernel2, and (except for the last set) moves the bundle index to the
// first bundle of the next set.
layer5 := func() cgen.Gen {
if n == 1 {
a.baseFilt = 0
a.baseBundle = 0
return a.kernel2()
}
// Prefix sums: first filter and first bundle of each set.
var (
atFilt = make([]int, n+1)
atBundle = make([]int, n+1)
)
for x := 0; x < n; x++ {
var (
filts = a.Filts[x].Cnt
bundles = ceilQuo(filts, a.bundleFilts)
)
atFilt[x+1] = atFilt[x] + filts
atBundle[x+1] = atBundle[x] + bundles
}
leaf := func(x int) cgen.Stmts {
a.baseFilt = atFilt[x]
a.baseBundle = atBundle[x]
var assn cgen.Gen
if x+1 < n {
assn = cgen.Assign{
Expr1: a.bundleIdx,
Expr2: il(atBundle[x+1]),
}
}
return cgen.Stmts{
ptrDecls(x),
a.kernel2(),
assn,
}
}
// tree splits the sets first..last near the middle of their bundle
// range. The left part is guarded by a comparison of the bundle
// index; the right part follows unguarded.
var tree func(int, int) cgen.Stmts
tree = func(first, last int) cgen.Stmts {
if first == last {
return leaf(first)
}
var (
start = atBundle[first]
stop = atBundle[last+1]
split = start + (stop-start)/2
x = first + 1
)
for atBundle[x+1] <= split {
x++
}
return cgen.Stmts{
cgen.If{
Cond: cgen.CmpL{
Expr1: a.bundleIdx,
Expr2: il(atBundle[x]),
},
Then: tree(first, x-1),
},
tree(x, last),
}
}
return tree(0, n-1)
}
// layer4 declares the task's first bundle index and, when the bundle
// hull has more than one tile, its last bundle index. The tile length
// is a constant when all tiles are equal and otherwise depends on
// whether the task is a regular tile or the scrap tile.
layer4 := func() cgen.Gen {
a.bundleIdx = vb(a.name("j"))
switch a.bundleHull {
case 1:
a.bundleLast = nil
default:
a.bundleLast = vb(a.name("jj"))
}
stmts := make(cgen.Stmts, 3)
stmts[0] = cgen.Var{
Type: cgen.PtrdiffT,
What: a.bundleIdx,
Init: cgen.Mul{
Expr1: il(a.bundleTile),
Expr2: a.bundleCoord,
},
}
if a.bundleLast != nil {
var expr cgen.Gen
switch a.bundleTiles {
case a.bundleHull:
expr = il(a.bundleTile - 1)
case 0:
expr = il(a.bundleScrap - 1)
default:
expr = cgen.Paren{
Inner: cgen.Ternary{
Cond: cgen.CmpL{
Expr1: a.bundleCoord,
Expr2: il(a.bundleTiles),
},
Then: il(a.bundleTile - 1),
Else: il(a.bundleScrap - 1),
},
}
}
stmts[1] = cgen.Var{
Type: cgen.PtrdiffT,
What: a.bundleLast,
Init: cgen.Add{
Expr1: a.bundleIdx,
Expr2: expr,
},
}
}
stmts[2] = layer5()
return stmts
}
// layer3 declares the task's first group index and loops over the
// task's groups. iters is the constant group count per task when it is
// known (0 means it varies between regular and scrap tiles); a count
// of exactly 1 needs no loop.
layer3 := func() cgen.Gen {
a.groupIdx = vb(a.name("i"))
var (
stmts = make(cgen.Stmts, 3)
iters = 0
)
stmts[0] = cgen.Var{
Type: cgen.PtrdiffT,
What: a.groupIdx,
Init: cgen.Mul{
Expr1: il(a.groupTile),
Expr2: a.groupCoord,
},
}
switch a.groupTiles {
case a.groupHull:
iters = a.groupTile
case 0:
iters = a.groupScrap
}
switch iters {
case 1:
stmts[2] = layer4()
default:
var (
last = vb(a.name("ii"))
expr cgen.Gen
)
switch iters {
case 0:
expr = cgen.Paren{
Inner: cgen.Ternary{
Cond: cgen.CmpL{
Expr1: a.groupCoord,
Expr2: il(a.groupTiles),
},
Then: il(a.groupTile - 1),
Else: il(a.groupScrap - 1),
},
}
default:
expr = il(iters - 1)
}
stmts[1] = cgen.Var{
Type: cgen.PtrdiffT,
What: last,
Init: cgen.Add{
Expr1: a.groupIdx,
Expr2: expr,
},
}
stmts[2] = cgen.For{
Cond: cgen.CmpLE{
Expr1: a.groupIdx,
Expr2: last,
},
Post: cgen.IncPre{
Expr: a.groupIdx,
},
Body: layer4(),
}
}
return stmts
}
// layer2 declares the source pointers up front when there is only one
// filter set (otherwise layer5 declares them per set).
layer2 := func() cgen.Gen {
var decls cgen.Gen
if n == 1 {
decls = ptrDecls(0)
}
return cgen.Gens{
decls,
layer3(),
}
}
// layer1 declares the output pointers. Both are based on the tensor
// entry after all filter sets: biases start at its beginning, weights
// follow after biasTotalBytes, and each is advanced per epoch.
layer1 := func() cgen.Gen {
a.arrangedB = vb(a.name("arrangedB"))
a.arrangedW = vb(a.name("arrangedW"))
return cgen.Stmts{
cgen.Var{
Type: cgen.RestrictPtrChar,
What: a.arrangedB,
Init: addMul(
tensor(n, 0),
il(a.biasEpochBytes),
a.epochCoord,
),
},
cgen.Var{
Type: cgen.RestrictPtrChar,
What: a.arrangedW,
Init: addMul(
cgen.Add{
Expr1: tensor(n, 0),
Expr2: il(a.biasTotalBytes),
},
il(a.wtEpochBytes1),
a.epochCoord,
),
},
layer2(),
}
}
return layer1()
}

// kernel2 emits the bundle-processing code for the current filter set
// (a.filtsIdx, a.baseFilt, a.baseBundle). The set's bundles are divided
// into up to four consecutive ranges that differ in a.filts1 and
// a.filts2, and the platform kernel is generated once per range.
//
// filts2 (local) is the number of filters in the set; filts1 (local) is
// how many of them lie before the point toChans - wtSliceWts2.
func (a *arrangeFilts) kernel2() cgen.Gen {
var (
filts1 int
filts2 int
)
// layer3 emits the platform-specific code for one bundle.
layer3 := func() cgen.Gen {
switch a.platform {
case raw.AVX512Float32:
return a.m512()
default:
panic("bug")
}
}
layer2 := func() cgen.Gen {
var (
retIf cgen.Gen
past = a.baseBundle
)
// When the task has an explicit last bundle, return after it has
// been processed.
if a.bundleLast != nil {
retIf = cgen.If1{
Cond: cgen.CmpGE{
Expr1: a.bundleIdx,
Expr2: a.bundleLast,
},
Then: cgen.Return{},
}
}
// do emits the code for the next `bundles` bundles and advances
// past. A single bundle is matched by equality and then steps the
// bundle index by assignment; several bundles are covered by a
// guarded loop that runs until the index reaches past.
do := func(bundles int) cgen.Gen {
a.bundleFirst = past
past += bundles
a.bundlePast = past
if bundles == 1 {
return cgen.If{
Cond: cgen.CmpE{
Expr1: a.bundleIdx,
Expr2: il(past - 1),
},
Then: cgen.Stmts{
layer3(),
retIf,
cgen.Assign{
Expr1: a.bundleIdx,
Expr2: il(past),
},
},
}
}
return cgen.If{
Cond: cgen.CmpL{
Expr1: a.bundleIdx,
Expr2: il(past),
},
Then: cgen.Stmts{
cgen.For{
Cond: cgen.CmpNE{
Expr1: a.bundleIdx,
Expr2: il(past),
},
Post: cgen.IncPre{
Expr: a.bundleIdx,
},
Body: cgen.Stmts{
layer3(),
retIf,
},
},
},
}
}
// Ranges, in order: full bundles made only of the first filts1
// filters; one bundle that straddles the filts1 boundary; full
// bundles after the boundary; one final partial bundle.
var (
stmts = make(cgen.Stmts, 4)
quo1 = filts1 / a.bundleFilts
rem1 = filts1 - a.bundleFilts*quo1
tail = filts2 - a.bundleFilts*quo1
)
if quo1 > 0 {
a.filts1 = a.bundleFilts
a.filts2 = a.bundleFilts
stmts[0] = do(quo1)
}
if rem1 > 0 {
a.filts1 = rem1
a.filts2 = min(tail, a.bundleFilts)
tail -= a.filts2
stmts[1] = do(1)
}
if tail > 0 {
var (
quo2 = tail / a.bundleFilts
rem2 = tail - a.bundleFilts*quo2
)
if quo2 > 0 {
a.filts1 = 0
a.filts2 = a.bundleFilts
stmts[2] = do(quo2)
}
if rem2 > 0 {
a.filts1 = 0
a.filts2 = rem2
stmts[3] = do(1)
}
}
return stmts
}
// layer1 determines filts2 and filts1 for the current filter set.
layer1 := func() cgen.Gen {
switch len(a.Filts) {
case 1:
filts2 = a.toChans
default:
filts2 = a.Filts[a.filtsIdx].Cnt
}
// clamp2 is the number of this set's filters at or beyond split.
var (
past = a.baseFilt + filts2
split = a.toChans - a.wtSliceWts2
clamp1 = max(past-split, 0)
clamp2 = min(clamp1, filts2)
)
filts1 = filts2 - clamp2
return layer2()
}
return layer1()
}

func (a *arrangeFilts) m512() cgen.Gen {
var (
filtHW int
nodeIdxes []int
preCnt int
postCnt int
postMul1 cgen.Gen
postAdd1 cgen.Gen
bias cgen.Gen
coreIdx1 cgen.Gen
coreOff1 int
stepIdx cgen.Gen
stepChans int
stepWts int
haveWts int
tpOff int
tp *trans.Pose
colVec cgen.Gen
colQuo int
colRem int
preMul1 cgen.Gen
preAdd1 cgen.Gen
nodeIdx int
coreIdx2 int
coreOff2 int
emitLane int
emitLanes int
)
layer17 := func() cgen.Gen {
var (
ae = a.arrangedW
slicePitch = a.wtSliceBytes1
)
if emitLane == a.filts1 {
slicePitch = a.wtSliceBytes2
}
var (
stepPitch = stepChans * slicePitch
mask1 = 1<<uint(emitLanes) - 1
mask2 = mask1 << uint(emitLane)
)
ae = cgen.Add{
Expr1: ae,
Expr2: il(
nodeIdx*a.nodeBytes +
coreIdx2*a.coreBytes +
colQuo*slicePitch +
coreOff2*a.wtBytes -
emitLane*a.wtBytes,
),
}
ae = addMul(ae, il(a.groupBytes), a.groupIdx)
ae = addMul(ae, il(a.coreBytes), coreIdx1)
ae = addMul(ae, il(stepPitch), stepIdx)
return avx.Mm512MaskStoreuPs{
ae, il(mask2), colVec,
}
}
layer16 := func() cgen.Gen {
var stmts cgen.Stmts
nodeIdx = nodeIdxes[colRem]
coreIdx2 = 0
coreOff2 = coreOff1
emitLane = 0
for emitLane < a.filts2 {
var (
lanes1 = a.filts2 - emitLane
lanes2 = a.wtSliceWts1 - coreOff2
)
emitLanes = min(lanes1, lanes2)
stmts = append(stmts, layer17())
coreIdx2++
coreOff2 = 0
emitLane += emitLanes
}
return stmts
}
layer15 := func() cgen.Gen {
if preCnt == 0 {
return layer16()
}
return cgen.Stmts{
cgen.Assign{
Expr1: bias,
Expr2: avx.Mm512FmaddPs{
colVec, preAdd1,
bias,
},
},
cgen.Assign{
Expr1: colVec,
Expr2: avx.Mm512MulPs{
colVec, preMul1,
},
},
layer16(),
}
}
layer14 := func() cgen.Gen {
if preCnt == 0 ||
colRem > 0 {
return layer15()
}
var stmts cgen.Stmts
stmt := func(st cgen.Gen) {
stmts = append(stmts, st)
}
preCh := cgen.Paren{
Inner: addMul(
addMul(
il(colQuo),
il(a.fromChans),
a.groupIdx,
),
il(stepChans),
stepIdx,
),
}
for x, prePtr := range a.bnPtrs[:preCnt] {
var (
preMul2 = vb(a.name("preMul"))
preAdd2 = vb(a.name("preAdd"))
)
stmt(&bn.Load{
Ctx: a.bc,
Mas: prePtr,
Channel: preCh,
Mul: preMul2,
Add: preAdd2,
})
if x == 0 {
preMul1 = preMul2
preAdd1 = preAdd2
continue
}
stmt(cgen.Assign{
Expr1: preMul1,
Expr2: avx.Mm512MulPs{
preMul1, preMul2,
},
})
stmt(cgen.Assign{
Expr1: preAdd1,
Expr2: avx.Mm512FmaddPs{
preAdd1, preMul2,
preAdd2,
},
})
}
stmt(layer15())
return stmts
}
layer13 := func() cgen.Gen {
if postCnt == 0 {
return layer14()
}
return cgen.Stmts{
cgen.Assign{
Expr1: colVec,
Expr2: avx.Mm512MulPs{
colVec, postMul1,
},
},
layer14(),
}
}
layer12 := func() cgen.Gen {
var (
n = tp.Cols
gens = make(cgen.Gens, n)
)
for x, wt := range tp.Vars[:n] {
colVec = wt
colQuo = (tpOff + x) / filtHW
colRem = (tpOff + x) % filtHW
gens[x] = layer13()
}
return gens
}
layer11 := func() cgen.Gen {
var (
n = tp.Rows
stmts = make(cgen.Stmts, n+2)
)
for x, wt := range tp.Vars[:n] {
var (
mask = loMask(tp.Cols)
ae = a.wtPtr
filtPitch = a.fromChans * filtHW * a.wtBytes
groupPitch = a.toChans * filtPitch
bundlePitch = a.bundleFilts * filtPitch
stepPitch = stepWts * a.wtBytes
)
ae = cgen.Add{
Expr1: ae,
Expr2: il(
-a.baseBundle*bundlePitch +
x*filtPitch +
tpOff*a.wtBytes,
),
}
ae = addMul(ae, il(groupPitch), a.groupIdx)
ae = addMul(ae, il(bundlePitch), a.bundleIdx)
ae = addMul(ae, il(stepPitch), stepIdx)
stmts[x] = cgen.Var{
Type: avx.M512, What: wt,
Init: avx.Mm512MaskzLoaduPs{
mask, ae,
},
}
}
stmts[n] = tp
stmts[n+1] = layer12()
return stmts
}
layer10 := func() cgen.Gen {
var (
cols1 = haveWts - tpOff
cols2 = min(cols1, a.bundleFilts)
)
tp = &trans.Pose{
Platform: a.platform,
Nms: a.nms,
Rows: a.filts2,
Cols: cols2,
}
tp.Vars = make(
[]cgen.Gen,
max(tp.Rows, tp.Cols),
)
for x := range tp.Vars {
wt := vb(a.name("wt"))
tp.Vars[x] = wt
}
return layer11()
}
layer9 := func() cgen.Gen {
var (
n = ceilQuo(haveWts, a.bundleFilts)
gens = make(cgen.Gens, n)
)
for x := range gens {
tpOff = x * a.bundleFilts
gens[x] = layer10()
}
return gens
}
layer8 := func() cgen.Gen {
stepIdx = vb(a.name("k"))
switch {
case filtHW < a.bundleFilts:
stepChans = a.bundleFilts / filtHW
stepWts = stepChans * filtHW
default:
stepChans = 1
stepWts = filtHW
}
var (
stmts = make(cgen.Stmts, 3)
iters = a.slices / stepChans
after = a.slices % stepChans
)
stmts[0] = cgen.Var{
Type: cgen.PtrdiffT,
What: stepIdx,
Init: il(0),
}
if iters > 0 {
haveWts = stepWts
stmts[1] = cgen.For{
Cond: cgen.CmpNE{
Expr1: stepIdx,
Expr2: il(iters),
},
Post: cgen.IncPre{
Expr: stepIdx,
},
Body: layer9(),
}
}
if after > 0 {
haveWts = after * filtHW
stmts[2] = layer9()
}
return stmts
}
layer7 := func() cgen.Gen {
coreIdx1 = vb(a.name("c"))
var (
stmts = make(cgen.Stmts, 2)
add = a.baseFilt
sub = a.baseBundle * a.bundleFilts
numer = cgen.Cast{
Type: cgen.SizeT,
Expr: cgen.Paren{
Inner: addMul(
il(add-sub),
il(a.bundleFilts),
a.bundleIdx,
),
},
}
denom = il(a.wtSliceWts1)
marks = make([]bool, a.wtSliceWts1)
marked = 0
)
stmts[0] = cgen.Var{
Type: cgen.PtrdiffT,
What: coreIdx1,
Init: cgen.Quo{
Expr1: numer,
Expr2: denom,
},
}
for x1 := a.bundleFirst; x1 < a.bundlePast; x1++ {
var (
x2 = add - sub + x1*a.bundleFilts
x3 = x2 % a.wtSliceWts1
)
if marks[x3] {
break
}
marks[x3] = true
marked++
}
switch marked {
case 1:
for off, mark := range marks {
if mark {
coreOff1 = off
stmts[1] = layer8()
break
}
}
default:
cases := make(cgen.Stmts, 0, marked)
for off, mark := range marks {
if mark {
var expr cgen.Gen
if len(cases)+1 < marked {
expr = il(off)
}
coreOff1 = off
cases = append(
cases, cgen.Case{
Expr: expr,
Body: cgen.Stmts{
layer8(),
cgen.Break,
},
},
)
}
}
stmts[1] = cgen.Switch{
Expr: cgen.Rem{
Expr1: numer,
Expr2: denom,
},
Cases: cases,
}
}
return stmts
}
layer6 := func() cgen.Gen {
store := func() cgen.Gen {
var (
ae = a.arrangedB
bundlePitch = a.bundleFilts * a.biasBytes
mask = loMask(a.filts2)
)
ae = cgen.Sub{
Expr1: ae,
Expr2: il(
a.baseBundle*bundlePitch -
a.baseFilt*a.biasBytes,
),
}
ae = addMul(ae, il(a.biasGroupBytes), a.groupIdx)
ae = addMul(ae, il(bundlePitch), a.bundleIdx)
return avx.Mm512MaskStoreuPs{
ae, mask, bias,
}
}
if preCnt == 0 {
return cgen.Stmts{
store(),
layer7(),
}
}
return cgen.Stmts{
layer7(),
store(),
}
}
layer5 := func() cgen.Gen {
var stmt cgen.Gen
switch a.epochFirst {
case 0:
load := func() cgen.Gen {
var (
mask = loMask(a.filts2)
ae = a.biasPtr
groupPitch = a.toChans * a.biasBytes
bundlePitch = a.bundleFilts * a.biasBytes
)
ae = cgen.Sub{
Expr1: ae,
Expr2: il(a.baseBundle * bundlePitch),
}
ae = addMul(ae, il(groupPitch), a.groupIdx)
ae = addMul(ae, il(bundlePitch), a.bundleIdx)
return cgen.Assign{
Expr1: bias,
Expr2: avx.Mm512MaskzLoaduPs{
mask, ae,
},
}
}
post := func() cgen.Gen {
if postCnt == 0 {
return nil
}
return cgen.Assign{
Expr1: bias,
Expr2: avx.Mm512FmaddPs{
postMul1, bias,
postAdd1,
},
}
}
stmt = cgen.If{
Cond: cgen.IsZero{
Expr: a.epochCoord,
},
Then: cgen.Stmts{
load(),
post(),
},
}
default:
if postCnt > 0 {
stmt = cgen.Cast{
Type: cgen.Void,
Expr: postAdd1,
}
}
}
return cgen.Stmts{
stmt,
layer6(),
}
}
layer4 := func() cgen.Gen {
bias = vb(a.name("bias"))
return cgen.Stmts{
cgen.Var{
Type: avx.M512, What: bias,
Init: avx.Mm512SetzeroPs,
},
layer5(),
}
}
layer3 := func() cgen.Gen {
if postCnt == 0 {
return layer4()
}
var stmts cgen.Stmts
stmt := func(st cgen.Gen) {
stmts = append(stmts, st)
}
postCh := cgen.Paren{
Inner: addMul(
addMul(
il(-a.baseBundle*a.bundleFilts),
il(a.toChans),
a.groupIdx,
),
il(a.bundleFilts),
a.bundleIdx,
),
}
for x, postPtr := range a.bnPtrs[preCnt:] {
var (
postMul2 = vb(a.name("postMul"))
postAdd2 = vb(a.name("postAdd"))
)
stmt(&bn.Load{
Ctx: a.bc,
Mas: postPtr,
Channel: postCh,
Mul: postMul2,
Add: postAdd2,
Cnt: a.filts2,
})
if x == 0 {
postMul1 = postMul2
postAdd1 = postAdd2
continue
}
stmt(cgen.Assign{
Expr1: postMul1,
Expr2: avx.Mm512MulPs{
postMul1, postMul2,
},
})
stmt(cgen.Assign{
Expr1: postAdd1,
Expr2: avx.Mm512FmaddPs{
postAdd1, postMul2,
postAdd2,
},
})
}
stmt(layer4())
return stmts
}
layer2 := func() cgen.Gen {
preCnt = a.Filts[a.filtsIdx].BnPre
postCnt = a.Filts[a.filtsIdx].BnPost
return layer3()
}
layer1 := func() cgen.Gen {
filtHW = a.FilterH * a.FilterW
nodeIdxes = make([]int, filtHW)
for x1, nd := range a.nodes {
x2 := nd.filtH*a.FilterW + nd.filtW
nodeIdxes[x2] = x1
}
return layer2()
}
return layer1()
}

// ArrangeDats is the public handle for the data-arrangement stage of the
// generated code. Prep emits the (deduplicated) arrangement function and
// Append emits a call to it.
type ArrangeDats struct {
*Ctx
*Spec
// Team is passed as the first argument of the generated call.
Team cgen.Gen
// Tensors are emitted as the initializer of the pointer array that is
// passed as the second argument of the generated call.
Tensors []cgen.Gen
// layout is (re)built by Prep from Ctx and Spec.
*layout
// callerName is the name of the generated function that Append calls.
// Prep sets it, reusing a previously generated name when one exists.
callerName string
}

// Prep computes the layout for this spec and returns the generator for the
// arrangement function. Functions are deduplicated by a signature made of
// the stage name and the spec: when an identical function was already
// generated, its name is reused and nil is returned.
func (a *ArrangeDats) Prep() cgen.Gen {
	const affix = "ArrangeDats"
	a.layout = newLayout(a.Ctx, a.Spec)
	key := fmt.Sprint(affix, " ", a.Spec)
	prior, seen := a.dedup[key]
	if seen {
		// An equivalent function already exists; just call that one.
		a.callerName = prior.(string)
		return nil
	}
	fresh := a.name(a.prefix + affix)
	a.callerName = fresh
	a.dedup[key] = fresh
	return cgen.Gens{
		&arrangeDats{ArrangeDats: a},
		cgen.Newline,
	}
}

// Bytes returns the layout's datTotalBytes value.
func (a *ArrangeDats) Bytes() int {
	total := a.datTotalBytes
	return total
}

// Append emits the call site for the arrangement function: a local array
// initialized with the tensor pointers, followed by a call that passes the
// team and that array to the function named by callerName.
func (a *ArrangeDats) Append(to []byte) []byte {
	arr := vb(a.name("tensors"))
	decl := cgen.Var{
		Type: cgen.PtrChar,
		What: cgen.Elem{Arr: arr},
		Init: cgen.Brace{
			Inner: cgen.CommaLines(a.Tensors),
		},
	}
	call := cgen.Call{
		Func: vb(a.callerName),
		Args: cgen.CommaSpaced{a.Team, arr},
	}
	return cgen.Stmts{decl, call}.Append(to)
}

// arrangeDats is the working state used while generating one arrangement
// function. The fields are filled in progressively by Append, calleeFunc
// and the kernelN methods, each of which sets the values its callees read.
type arrangeDats struct {
*ArrangeDats
// Tiling of the slice, core and group axes, computed in Append. The
// "1"/"2" suffixed slice values are computed from slices1 and slices2
// respectively.
sliceTile1 int
sliceTile2 int
sliceTiles int
sliceScrap1 int
sliceScrap2 int
sliceHull int
coreTile int
coreTiles int
coreScrap int
coreHull int
groupTile int
groupTiles int
groupScrap int
groupHull int
// Name of the generated callee, set in Append.
calleeName string
// Variables declared at the top of the callee (set in calleeFunc).
tensors cgen.Gen
sliceCoord cgen.Gen
coreCoord cgen.Gen
groupCoord cgen.Gen
epochCoord cgen.Gen
// Values selected in calleeFunc from the "1" or "2" variants before
// each kernel1 call.
sliceTile int
sliceScrap int
coreBytes int
groupBytes int
fieldBytes int
// Pointer variables declared in kernel1.
datPtrs []cgen.Gen
bnPtrs []cgen.Gen
arranged cgen.Gen
// Loop variables: groupIdx is set in kernel2, coreIdx and coreLast in
// kernel3, coreH, coreW and spans in kernel4, sliceIdx in kernel5.
groupIdx cgen.Gen
coreIdx cgen.Gen
coreLast cgen.Gen
coreH cgen.Gen
coreW cgen.Gen
*spans
sliceIdx cgen.Gen
// Batch-norm multiplier/addend variables, rebuilt in kernel6.
bnMuls []cgen.Gen
bnAdds []cgen.Gen
}

// Append emits the callee function followed by the static caller function,
// which dispatches the callee through threader.Do over a four-dimensional
// hull (slice, core, group, epoch).
//
// Before emitting, it chooses which single axis gets split into tiles by
// comparing threadVecs with the vector counts per core and per group:
// the slice axis when one core is already large enough, otherwise the core
// axis, otherwise the group axis. For every axis, tile is the nominal tile
// size, tiles the number of nominal tiles, scrap the size of the final
// irregular tile (0 when there is none) and hull the task count.
func (a *arrangeDats) Append(to []byte) []byte {
	if a.platform != raw.AVX512Float32 {
		panic("bug")
	}
	// Threshold (in vectors) that the tiling below is sized against.
	const threadVecs = 512
	blkVecs := len(a.fields) * a.datSliceVecs
	groupVecs := ceilQuo(
		a.fromChans*(a.datCores*blkVecs),
		a.epochs2,
	)
	coreVecs := ceilQuo(groupVecs, a.datCores)

	// Defaults: slices are not split and each group is its own task.
	a.sliceTile1, a.sliceTile2 = a.slices1, a.slices2
	a.sliceTiles, a.sliceHull = 1, 1
	a.sliceScrap1, a.sliceScrap2 = 0, 0
	a.groupTile, a.groupTiles = 1, a.Groups
	a.groupScrap, a.groupHull = 0, a.Groups

	// split divides total units into tasks sized from threadVecs/per. A
	// remainder is folded into one final, larger scrap tile.
	split := func(total, per int) (int, int, int, int) {
		hull := max(total/ceilQuo(threadVecs, per), 1)
		tile := total / hull
		tiles := hull
		scrap := total - hull*tile
		if scrap > 0 {
			tiles--
			scrap += tile
		}
		return tile, tiles, scrap, hull
	}

	if threadVecs <= coreVecs {
		// Split the slice axis; size the split by the smaller slice count
		// that is actually in use.
		least := a.slices1
		if a.epochs1 != a.epochs2 &&
			(a.epochs1 == 0 || a.slices1 > a.slices2) {
			least = a.slices2
		}
		n := max(least/ceilQuo(threadVecs, blkVecs), 1)
		a.sliceTile1 = a.slices1 / n
		a.sliceTile2 = a.slices2 / n
		a.sliceTiles = n
		a.sliceHull = n
		a.sliceScrap1 = a.slices1 - n*a.sliceTile1
		a.sliceScrap2 = a.slices2 - n*a.sliceTile2
		ragged1 := a.epochs1 > 0 && a.sliceScrap1 > 0
		ragged2 := a.epochs1 < a.epochs2 && a.sliceScrap2 > 0
		if ragged1 || ragged2 {
			a.sliceTiles--
			a.sliceScrap1 += a.sliceTile1
			a.sliceScrap2 += a.sliceTile2
		}
		a.coreTile, a.coreTiles = 1, a.datCores
		a.coreScrap, a.coreHull = 0, a.datCores
	} else if threadVecs <= groupVecs {
		// Split the core axis.
		a.coreTile, a.coreTiles, a.coreScrap, a.coreHull =
			split(a.datCores, coreVecs)
	} else {
		// All cores go to one task; split the group axis instead.
		a.coreTile, a.coreTiles = a.datCores, 1
		a.coreScrap, a.coreHull = 0, 1
		a.groupTile, a.groupTiles, a.groupScrap, a.groupHull =
			split(a.Groups, groupVecs)
	}

	// Names are generated in a fixed order: callee, team, tensors, and
	// only then the callee body.
	a.calleeName = a.name(a.callerName + "Callee")
	team := vb(a.name("team"))
	tensors := vb(a.name("tensors"))
	calleeDef := a.calleeFunc()
	callerDef := cgen.StaticFuncDef{
		ReturnType: cgen.Void,
		Name:       a.callerName,
		Params: cgen.CommaSpaced{
			cgen.Param{Type: a.tc.PtrTeam, What: team},
			cgen.Param{Type: cgen.PtrPtrChar, What: tensors},
		},
		Body: &threader.Do{
			Ctx:    a.tc,
			Callee: vb(a.calleeName),
			Any:    tensors,
			Hull: []cgen.Gen{
				il(a.sliceHull),
				il(a.coreHull),
				il(a.groupHull),
				il(a.epochs2),
			},
			Team: team,
		},
	}
	return cgen.Gens{
		calleeDef,
		cgen.Newline,
		callerDef,
	}.Append(to)
}

// calleeFunc builds the threader callee. The body declares the tensors
// pointer and the four task coordinates (slice, core, group, epoch), then
// contains up to two kernels: one for epochs below epochs1 and one for the
// remaining epoch, each generated with its own tile and byte-size values.
func (a *arrangeDats) calleeFunc() cgen.Gen {
	callee := &threader.Callee{
		Ctx:  a.tc,
		Name: a.calleeName,
		Task: vb(a.name("task")),
		Pt:   vb(a.name("pt")),
	}
	// Slots: 0 tensors, 1-4 coordinates, 5 unused-pt cast, 6-7 kernels.
	body := make(cgen.Stmts, 8)
	ptRead := false

	a.tensors = vb(a.name("tensors"))
	body[0] = cgen.Var{
		Type: cgen.PtrPtrChar, What: a.tensors,
		Init: callee.Any(),
	}

	// coord declares the coordinate for hull axis i. An axis with a hull
	// of one is the constant 0; any other axis is read from the task point.
	coord := func(nm string, hull, i int) cgen.Gen {
		v := vb(a.name(nm))
		var from cgen.Gen
		if hull == 1 {
			from = il(0)
		} else {
			from = cgen.Elem{
				Arr: callee.Pt, Idx: il(i),
			}
			ptRead = true
		}
		body[1+i] = cgen.Var{
			Type: cgen.PtrdiffT, What: v,
			Init: from,
		}
		return v
	}
	a.sliceCoord = coord("s", a.sliceHull, 0)
	a.coreCoord = coord("c", a.coreHull, 1)
	a.groupCoord = coord("g", a.groupHull, 2)
	a.epochCoord = coord("e", a.epochs2, 3)
	if !ptRead {
		// Mention the otherwise unused parameter via a void cast.
		body[5] = cgen.Cast{
			Type: cgen.Void,
			Expr: callee.Pt,
		}
	}

	// emit generates one kernel. When the kernel covers exactly one epoch
	// (and more than one epoch exists) the epoch coordinate is first
	// overwritten with that constant.
	emit := func(epoch int, single bool) cgen.Gen {
		var pin cgen.Gen
		if single && a.epochs2 > 1 {
			pin = cgen.Assign{
				Expr1: a.epochCoord,
				Expr2: il(epoch),
			}
		}
		return cgen.Stmts{pin, a.kernel1()}
	}

	hasTail := a.epochs1 < a.epochs2
	if a.epochs1 > 0 {
		a.sliceTile = a.sliceTile1
		a.sliceScrap = a.sliceScrap1
		a.coreBytes = a.datCoreBytes1
		a.groupBytes = a.datGroupBytes1
		a.fieldBytes = a.datFieldBytes1
		head := emit(0, a.epochs1 == 1)
		if hasTail {
			// Guard the first kernel and return so the second is skipped.
			head = cgen.If{
				Cond: cgen.CmpL{
					Expr1: a.epochCoord,
					Expr2: il(a.epochs1),
				},
				Then: cgen.Stmts{
					head,
					cgen.Return{},
				},
			}
		}
		body[6] = head
	}
	if hasTail {
		a.sliceTile = a.sliceTile2
		a.sliceScrap = a.sliceScrap2
		a.coreBytes = a.datCoreBytes2
		a.groupBytes = a.datGroupBytes2
		a.fieldBytes = a.datFieldBytes2
		body[7] = emit(a.epochs1, true)
	}
	return callee.Func(body)
}

// kernel1 declares the restrict-qualified pointers used by the kernel and
// then appends kernel2. Entries of the tensors array are consumed in order:
// the primary data tensor, then one entry per added tensor or batch-norm
// parameter block as dictated by From.Ops, and finally the arranged output.
// Each pointer is pre-offset by the epoch coordinate.
func (a *arrangeDats) kernel1() cgen.Gen {
	a.datPtrs = a.datPtrs[:0]
	a.bnPtrs = a.bnPtrs[:0]
	var out cgen.Stmts
	next := 0
	// nextTensor yields tensors[next] and advances the cursor.
	nextTensor := func() cgen.Gen {
		elem := cgen.Elem{
			Arr: a.tensors,
			Idx: il(next),
		}
		next++
		return elem
	}
	declare := func(ptr, from cgen.Gen) {
		out = append(out, cgen.Var{
			Type: cgen.RestrictPtrChar,
			What: ptr, Init: from,
		})
	}
	// addDat declares the next data pointer, moved back by the padding so
	// that padded coordinates can be used for addressing.
	addDat := func() {
		ptr := vb(a.name("datPtr"))
		which := len(a.datPtrs)
		rowPitch := a.From.Pitch1Bytes[which]
		chanPitch := a.From.Pitch2Bytes[which]
		a.datPtrs = append(a.datPtrs, ptr)
		origin := cgen.Sub{
			Expr1: nextTensor(),
			Expr2: il(
				a.PaddingH*rowPitch +
					a.PaddingW*a.datBytes,
			),
		}
		declare(ptr, addMul(
			origin,
			il(a.slices1*chanPitch),
			a.epochCoord,
		))
	}
	// addBn declares the next batch-norm parameter pointer.
	addBn := func() {
		ptr := vb(a.name("bnPtr"))
		a.bnPtrs = append(a.bnPtrs, ptr)
		declare(ptr, &bn.Offset{
			Ctx: a.bc,
			Mas: nextTensor(),
			Channel: cgen.Mul{
				Expr1: il(a.slices1),
				Expr2: a.epochCoord,
			},
		})
	}
	addDat()
	for i := range a.From.Ops {
		op := &a.From.Ops[i]
		switch op.Kind {
		case mod.Add:
			for left := op.Int; left > 0; left-- {
				addDat()
			}
		case mod.Bn:
			addBn()
		case mod.ReLU:
			// Consumes no tensor.
		default:
			panic("bug")
		}
	}
	a.arranged = vb(a.name("arranged"))
	declare(a.arranged, addMul(
		nextTensor(),
		il(a.datEpochBytes1),
		a.epochCoord,
	))
	out = append(out, a.kernel2())
	return out
}

// kernel2 emits the group loop. The group index starts at
// groupTile*groupCoord. When every task handles exactly one group, kernel3
// is emitted directly. Otherwise an inclusive last index is declared and
// kernel3 becomes a loop body; the iteration count is a constant when all
// tasks are alike, and otherwise chosen at run time between the regular
// tile and the scrap tile.
func (a *arrangeDats) kernel2() cgen.Gen {
	a.groupIdx = vb(a.name("i"))
	out := make(cgen.Stmts, 3)
	out[0] = cgen.Var{
		Type: cgen.PtrdiffT,
		What: a.groupIdx,
		Init: cgen.Mul{
			Expr1: il(a.groupTile),
			Expr2: a.groupCoord,
		},
	}
	// span is the compile-time iteration count, or 0 when it varies.
	span := 0
	if a.groupTiles == a.groupHull {
		span = a.groupTile
	} else if a.groupTiles == 0 {
		span = a.groupScrap
	}
	if span == 1 {
		out[2] = a.kernel3()
		return out
	}
	last := vb(a.name("ii"))
	var extent cgen.Gen
	if span == 0 {
		extent = cgen.Paren{
			Inner: cgen.Ternary{
				Cond: cgen.CmpL{
					Expr1: a.groupCoord,
					Expr2: il(a.groupTiles),
				},
				Then: il(a.groupTile - 1),
				Else: il(a.groupScrap - 1),
			},
		}
	} else {
		extent = il(span - 1)
	}
	out[1] = cgen.Var{
		Type: cgen.PtrdiffT,
		What: last,
		Init: cgen.Add{
			Expr1: a.groupIdx,
			Expr2: extent,
		},
	}
	out[2] = cgen.For{
		Cond: cgen.CmpLE{
			Expr1: a.groupIdx,
			Expr2: last,
		},
		Post: cgen.IncPre{
			Expr: a.groupIdx,
		},
		Body: a.kernel3(),
	}
	return out
}

// kernel3 declares the core index (coreTile*coreCoord) and, unless the core
// hull is one, the inclusive last core index for this task. It then emits
// kernel4. coreLast is left nil when the hull is one, which kernel4 uses to
// omit its early-return check.
func (a *arrangeDats) kernel3() cgen.Gen {
	a.coreIdx = vb(a.name("j"))
	if a.coreHull == 1 {
		a.coreLast = nil
	} else {
		a.coreLast = vb(a.name("last"))
	}
	out := make(cgen.Stmts, 3)
	out[0] = cgen.Var{
		Type: cgen.PtrdiffT,
		What: a.coreIdx,
		Init: cgen.Mul{
			Expr1: il(a.coreTile),
			Expr2: a.coreCoord,
		},
	}
	if a.coreLast != nil {
		var extent cgen.Gen
		if a.coreTiles == a.coreHull {
			// Every task handles a full tile.
			extent = il(a.coreTile - 1)
		} else if a.coreTiles == 0 {
			// The only task handles the scrap tile.
			extent = il(a.coreScrap - 1)
		} else {
			// The final task handles the scrap tile; decide at run time.
			extent = cgen.Paren{
				Inner: cgen.Ternary{
					Cond: cgen.CmpL{
						Expr1: a.coreCoord,
						Expr2: il(a.coreTiles),
					},
					Then: il(a.coreTile - 1),
					Else: il(a.coreScrap - 1),
				},
			}
		}
		out[1] = cgen.Var{
			Type: cgen.PtrdiffT,
			What: a.coreLast,
			Init: cgen.Add{
				Expr1: a.coreIdx,
				Expr2: extent,
			},
		}
	}
	out[2] = a.kernel4()
	return out
}

// kernel4 emits the code that walks the cores assigned to this task. The
// cores are described by a.blks.lhs (a list of loopH entries), each of
// which holds a list of loopW entries (lh.lws). The generated code first
// dispatches on coreIdx to find the loopH entry it starts in, then on the
// position relative to that entry (rel) to find the loopW entry, and from
// there falls through the remaining entries in order, calling kernel5 once
// per core.
//
// The layers are written innermost-last: layer1 calls layer2, and so on
// down to layer7. They communicate through the captured variables lh, rel
// and lw and through the a.coreH, a.coreW and a.spans fields, so the order
// in which they run matters.
func (a *arrangeDats) kernel4() cgen.Gen {
var (
lh *loopH
rel cgen.Gen
lw *loopW
)
// layer7 emits the per-core work: kernel5, then (when coreLast is set) a
// return once the task's last core has been handled, then ++coreIdx.
layer7 := func() cgen.Gen {
var retIf cgen.Gen
if a.coreLast != nil {
retIf = cgen.If1{
Cond: cgen.CmpGE{
Expr1: a.coreIdx,
Expr2: a.coreLast,
},
Then: cgen.Return{},
}
}
return cgen.Stmts{
a.kernel5(),
retIf,
cgen.IncPre{
Expr: a.coreIdx,
},
}
}
// layer6 wraps layer7 in a loop over the blocks of the current loopW
// entry when that entry has a nonzero step. The loop runs until coreIdx
// passes an inclusive last index derived from lw.blkPast and rel, and
// advances coreW by lw.fromStep. coreIdx itself is advanced by layer7.
layer6 := func() cgen.Gen {
if lw.fromStep == 0 {
return layer7()
}
last := vb(a.name("jj"))
return cgen.Stmts{
cgen.Var{
Type: cgen.PtrdiffT,
What: last,
Init: cgen.Add{
Expr1: cgen.Sub{
Expr1: il(lw.blkPast - 1),
Expr2: rel,
},
Expr2: a.coreIdx,
},
},
cgen.For{
Cond: cgen.CmpLE{
Expr1: a.coreIdx,
Expr2: last,
},
Post: cgen.AddAssign{
Expr1: a.coreW,
Expr2: il(lw.fromStep),
},
Body: layer7(),
},
}
}
// layer5 declares coreW for the current loopW entry and publishes that
// entry's spans through a.spans. With a zero step coreW is the constant
// lw.fromW; otherwise it is computed from rel.
layer5 := func() cgen.Gen {
a.coreW = vb(a.name("w"))
a.spans = &lw.spans
var expr cgen.Gen
switch lw.fromStep {
case 0:
expr = il(lw.fromW)
default:
expr = addMul(
il(lw.fromW-lw.blkFirst*lw.fromStep),
il(lw.fromStep),
rel,
)
}
return cgen.Stmts{
cgen.Var{
Type: cgen.PtrdiffT,
What: a.coreW,
Init: expr,
},
layer6(),
}
}
// layer4 emits a binary decision tree over lh.lws keyed on rel. Each
// "if (rel < first-of-right-half)" guards the left half only; the right
// half follows unconditionally, so execution that enters a left leaf
// falls through into every later leaf.
layer4 := func() cgen.Gen {
var (
lws = lh.lws
tree func(int, int) cgen.Stmts
)
// leaf emits layer5 for entry x. Unless x is the last entry, it then
// sets rel to the entry's blkPast so that the next leaf starts at its
// first block.
leaf := func(x int) cgen.Stmts {
lw = lws[x]
var assn cgen.Gen
if x+1 < len(lws) {
assn = cgen.Assign{
Expr1: rel,
Expr2: il(lw.blkPast),
}
}
return cgen.Stmts{
layer5(),
assn,
}
}
tree = func(first, last int) cgen.Stmts {
if first == last {
return leaf(first)
}
// Split at the first entry whose blkPast lies beyond the midpoint of
// the block range covered by [first, last].
var (
start = lws[first].blkFirst
stop = lws[last].blkPast
split = start + (stop-start)/2
x = first + 1
)
for lws[x].blkPast <= split {
x++
}
return cgen.Stmts{
cgen.If{
Cond: cgen.CmpL{
Expr1: rel,
Expr2: il(lws[x].blkFirst),
},
Then: tree(first, x-1),
},
tree(x, last),
}
}
return tree(0, len(lws)-1)
}
// layer3 wraps layer4 in a loop when the current loopH entry has a
// nonzero step. The loop runs while coreIdx is below lh.blkPast; after
// each iteration rel is reset to 0 and coreH advances by lh.fromStep.
layer3 := func() cgen.Gen {
if lh.fromStep == 0 {
return layer4()
}
return cgen.For{
Cond: cgen.CmpL{
Expr1: a.coreIdx,
Expr2: il(lh.blkPast),
},
Post: cgen.CommaSpaced{
cgen.Assign{
Expr1: rel,
Expr2: il(0),
},
cgen.AddAssign{
Expr1: a.coreH,
Expr2: il(lh.fromStep),
},
},
Body: layer4(),
}
}
// layer2 declares rel and coreH for the current loopH entry. rel starts
// as coreIdx-lh.blkFirst. When lh.blkStep is nonzero, that difference is
// divided (as size_t) by blkStep: the remainder becomes rel and the
// quotient, scaled by lh.fromStep, is added to lh.fromH to form coreH.
layer2 := func() cgen.Gen {
rel = vb(a.name("rel"))
a.coreH = vb(a.name("h"))
var (
relExpr cgen.Gen = cgen.Sub{
Expr1: a.coreIdx,
Expr2: il(lh.blkFirst),
}
hExpr = il(lh.fromH)
)
if lh.blkStep != 0 {
var (
numer cgen.Gen = cgen.Cast{
Type: cgen.SizeT,
Expr: cgen.Paren{
Inner: relExpr,
},
}
denom = il(lh.blkStep)
)
relExpr = cgen.Rem{
Expr1: numer,
Expr2: denom,
}
hExpr = addMul(
hExpr,
cgen.Quo{
Expr1: numer,
Expr2: denom,
},
il(lh.fromStep),
)
}
return cgen.Stmts{
cgen.Var{
Type: cgen.PtrdiffT,
What: rel,
Init: relExpr,
},
cgen.Var{
Type: cgen.PtrdiffT,
What: a.coreH,
Init: hExpr,
},
layer3(),
}
}
// layer1 emits the outer decision tree over a.blks.lhs keyed on coreIdx.
// It has the same shape as the tree in layer4: left halves are guarded,
// right halves follow unconditionally.
layer1 := func() cgen.Gen {
var (
lhs = a.blks.lhs
tree func(int, int) cgen.Stmts
)
// leaf emits layer2 for entry x. Unless x is the last entry, it then
// sets coreIdx to the entry's blkPast.
leaf := func(x int) cgen.Stmts {
lh = lhs[x]
var assn cgen.Gen
if x+1 < len(lhs) {
assn = cgen.Assign{
Expr1: a.coreIdx,
Expr2: il(lh.blkPast),
}
}
return cgen.Stmts{
layer2(),
assn,
}
}
tree = func(first, last int) cgen.Stmts {
if first == last {
return leaf(first)
}
var (
start = lhs[first].blkFirst
stop = lhs[last].blkPast
split = start + (stop-start)/2
x = first + 1
)
for lhs[x].blkPast <= split {
x++
}
return cgen.Stmts{
cgen.If{
Cond: cgen.CmpL{
Expr1: a.coreIdx,
Expr2: il(lhs[x].blkFirst),
},
Then: tree(first, x-1),
},
tree(x, last),
}
}
return tree(0, len(lhs)-1)
}
return layer1()
}

// kernel5 emits the slice loop. The slice index starts at
// sliceTile*sliceCoord. When every task handles exactly one slice, kernel6
// is emitted directly. Otherwise an inclusive last index is declared and
// kernel6 becomes a loop body; the iteration count is a constant when it is
// the same for all tasks, and otherwise chosen at run time between the
// regular tile and the scrap tile.
func (a *arrangeDats) kernel5() cgen.Gen {
	a.sliceIdx = vb(a.name("k"))
	out := make(cgen.Stmts, 3)
	out[0] = cgen.Var{
		Type: cgen.PtrdiffT,
		What: a.sliceIdx,
		Init: cgen.Mul{
			Expr1: il(a.sliceTile),
			Expr2: a.sliceCoord,
		},
	}
	// span is the compile-time iteration count, or 0 when it varies.
	span := 0
	if a.sliceTiles == a.sliceHull {
		span = a.sliceTile
	} else if a.sliceTiles == 0 || a.sliceTile == a.sliceScrap {
		span = a.sliceScrap
	}
	if span == 1 {
		out[2] = a.kernel6()
		return out
	}
	last := vb(a.name("kk"))
	var extent cgen.Gen
	if span == 0 {
		extent = cgen.Paren{
			Inner: cgen.Ternary{
				Cond: cgen.CmpL{
					Expr1: a.sliceCoord,
					Expr2: il(a.sliceTiles),
				},
				Then: il(a.sliceTile - 1),
				Else: il(a.sliceScrap - 1),
			},
		}
	} else {
		extent = il(span - 1)
	}
	out[1] = cgen.Var{
		Type: cgen.PtrdiffT,
		What: last,
		Init: cgen.Add{
			Expr1: a.sliceIdx,
			Expr2: extent,
		},
	}
	out[2] = cgen.For{
		Cond: cgen.CmpLE{
			Expr1: a.sliceIdx,
			Expr2: last,
		},
		Post: cgen.IncPre{
			Expr: a.sliceIdx,
		},
		Body: a.kernel6(),
	}
	return out
}

// kernel6 emits, for the current slice, one bn.Load per batch-norm pointer
// (recording the multiplier and addend variables in bnMuls and bnAdds) and
// then the platform-specific body. The channel passed to every load is
// sliceIdx + fromChans*groupIdx.
func (a *arrangeDats) kernel6() cgen.Gen {
	a.bnMuls = a.bnMuls[:0]
	a.bnAdds = a.bnAdds[:0]
	n := len(a.bnPtrs)
	out := make(cgen.Gens, n+1)
	channel := cgen.Paren{
		Inner: addMul(
			a.sliceIdx,
			il(a.fromChans),
			a.groupIdx,
		),
	}
	for i := 0; i < n; i++ {
		mul := vb(a.name("bnMul"))
		add := vb(a.name("bnAdd"))
		a.bnMuls = append(a.bnMuls, mul)
		a.bnAdds = append(a.bnAdds, add)
		out[i] = &bn.Load{
			Ctx:     a.bc,
			Mas:     a.bnPtrs[i],
			Channel: channel,
			Mul:     mul,
			Add:     add,
		}
	}
	// Only one platform is supported; anything else is a programming error.
	if a.platform != raw.AVX512Float32 {
		panic("bug")
	}
	out[n] = a.m512()
	return out
}

// m512 generates the AVX-512 body for the current (group, core, slice)
// position. It loads 16-lane vectors from the input rows, applies the
// From.Ops chain to them, and stores the result into the arranged buffer,
// once per field and per output vector row.
//
// As elsewhere in this file the work is split into layers: layer1 calls
// layer2, which calls layer3, and so on. The layers share state through the
// variables declared below.
func (a *arrangeDats) m512() cgen.Gen {
var (
// sbox[h][w] holds ^fieldIndex for the field at that position of the
// StrideH x StrideW box, or 0 when no field sits there. Rows with no
// field at all stay nil.
sbox [][]int
// locs lists the starting columns of the 16-lane loads (see layer2).
locs []int
// sboxH and vecIdx select the row being generated.
sboxH int
vecIdx int
// dats[x] is the variable loaded at locs[x], or nil when that load is
// entirely outside the span (see layer4).
dats []cgen.Gen
// sboxW and fieldIdx select the field being stored.
sboxW int
fieldIdx int
// pms caches permutation-control variables (see ctrl in layer6).
pms map[int]cgen.Gen
)
// layer6 emits the stores producing the 16 output lanes of one field row.
// Output lane k takes input column sboxW + k*StrideW. Each round of the
// main loop covers as many lanes as can be gathered from one or two
// loaded vectors.
layer6 := func() cgen.Gen {
var (
// stmts[0] is reserved for a single full-width zero store, which is
// emitted at most once and ahead of all other stores.
stmts = cgen.Stmts{nil}
zero = avx.Mm512SetzeroPs
)
stmt := func(st cgen.Gen) {
stmts = append(stmts, st)
}
// anOff replays the load-placement walk of layer2 and returns the
// distance between the load containing lane `from` and the load in
// which the gather continues, or 16 when there is none. ctrl uses it
// only when it is called with off == 0.
anOff := func(from int) int {
var (
loc1 = 0
loc2 = -1
seek int
)
for x := 0; x < 16; x++ {
for w, fld := range sbox[0] {
if fld == 0 {
continue
}
at := x*a.StrideW + w
if at >= loc1+16 {
loc1 = at
}
switch loc2 {
case -1:
if at == loc1+from {
loc2 = loc1
seek = at
for seek < loc1+16 {
seek += a.StrideW
}
}
default:
if at == seek {
return loc1 - loc2
}
}
}
}
return 16
}
// ctrl returns the permutation-index variable for a gather that starts
// at lane `from` of the first vector, with the second vector starting
// `off` columns after the first. The variable is declared on first use
// and cached in pms.
ctrl := func(from, off int) cgen.Gen {
if pms == nil {
pms = make(map[int]cgen.Gen)
}
pm := pms[from+off*16]
if pm == nil {
pm = vb(a.name("pm"))
if off == 0 {
off = anOff(from)
}
var (
set = make(avx.Mm512SetEpi32, 16)
at = from
)
// Fill from index 15 downwards, stepping the source lane by StrideW.
// On crossing from the first vector (lanes 0-15) into the second,
// rebase the lane into the 16-31 range. Values are clamped to 31.
for x := 15; x >= 0; x-- {
set[x] = il(min(at, 31))
was := at
at += a.StrideW
if was < 16 && at >= 16 {
at -= off
at += 16
}
}
stmt(cgen.Var{
Type: avx.M512i, What: pm,
Init: set,
})
// Cache under the resolved offset and also under off == 0.
pms[from+off*16] = pm
pms[from] = pm
}
return pm
}
// store emits a masked store of the low `lanes` lanes of expr, starting
// at output lane `to` of the current field/vector position.
store := func(to, lanes int, expr cgen.Gen) cgen.Gen {
var (
ae = a.arranged
mask = loMask(lanes)
)
ae = cgen.Add{
Expr1: ae,
Expr2: il(
fieldIdx*a.fieldBytes +
vecIdx*a.datVecBytes +
to*a.datBytes,
),
}
ae = addMul(ae, il(a.groupBytes), a.groupIdx)
ae = addMul(ae, il(a.coreBytes), a.coreIdx)
ae = addMul(ae, il(a.datSliceBytes), a.sliceIdx)
return avx.Mm512MaskStoreuPs{
ae, mask, expr,
}
}
var (
// x indexes locs/dats; w is the input column of the next output lane.
x = 0
w = sboxW
)
for need := 16; need > 0; {
var (
lanes = 0
nonpad = 0
)
// take consumes one output lane, counting it as non-padding when its
// source position lies inside the span, and advances w.
take := func() {
lanes++
h := vecIdx*a.StrideH + sboxH
if h >= a.spanH1 && h < a.spanH2 &&
w >= a.spanW1 && w < a.spanW2 {
nonpad++
}
w += a.StrideW
}
// Skip loads that end at or before w.
for locs[x]+16 <= w {
x++
}
var (
ndats = 1
loc1 = locs[x]
dat1 = dats[x]
from = w - loc1
)
take()
for lanes < need && w < loc1+16 {
take()
}
var (
loc2 = loc1
dat2 cgen.Gen
)
if lanes < need {
// The first vector is exhausted; continue into a second one.
x++
for locs[x]+16 <= w {
x++
}
ndats = 2
loc2 = locs[x]
dat2 = dats[x]
take()
for lanes < need && w < loc2+16 {
take()
}
// If the second vector already reaches past the span, treat all
// remaining lanes as covered by this round.
if lanes < need && loc2+16 > a.spanW2 {
lanes = need
}
}
switch {
case nonpad == 0:
// Nothing but padding in this round: make sure the shared zero store
// has been emitted.
if stmts[0] == nil {
stmts[0] = store(0, 16, zero)
}
case a.StrideW == 1:
// Unit stride: the loaded vector is stored as-is.
stmt(store(0, 16, dat1))
default:
var (
to = 16 - need
off = loc2 - loc1
pm = ctrl(from, off)
expr cgen.Gen
)
switch ndats {
case 1:
expr = avx.Mm512PermutexvarPs{
pm, dat1,
}
default:
// A nil vector lies outside the span; substitute zero. At most one of
// the two is replaced here.
switch {
case dat1 == nil:
dat1 = zero
case dat2 == nil:
dat2 = zero
}
expr = avx.Mm512Permutex2varPs{
dat1, pm, dat2,
}
}
stmt(store(to, lanes, expr))
}
need -= lanes
}
return stmts
}
// layer5 runs layer6 once for every field present in row sboxH of the box.
layer5 := func() cgen.Gen {
var gens cgen.Gens
for w, fld := range sbox[sboxH] {
if fld == 0 {
continue
}
sboxW = w
fieldIdx = ^fld
gens = append(
gens, layer6(),
)
}
return gens
}
// layer4 emits the loads for input row h = vecIdx*StrideH + sboxH and then
// layer5. When h lies outside [spanH1, spanH2) nothing is loaded and all
// dats are nil.
layer4 := func() cgen.Gen {
if dats == nil {
dats = make([]cgen.Gen, len(locs))
}
h := vecIdx*a.StrideH + sboxH
if h < a.spanH1 || h >= a.spanH2 {
for x := range dats {
dats[x] = nil
}
return layer5()
}
var stmts cgen.Stmts
stmt := func(st cgen.Gen) {
stmts = append(stmts, st)
}
// load emits a zero-masked load from data pointer `ptr` at row h and
// column w, offset by the group, slice, coreH and coreW variables.
load := func(mask cgen.Gen, ptr, w int) cgen.Gen {
var (
ae = a.datPtrs[ptr]
pitch1 = a.From.Pitch1Bytes[ptr]
pitch2 = a.From.Pitch2Bytes[ptr]
groupPitch = a.fromChans * pitch2
)
ae = cgen.Add{
Expr1: ae,
Expr2: il(h*pitch1 + w*a.datBytes),
}
ae = addMul(ae, il(groupPitch), a.groupIdx)
ae = addMul(ae, il(pitch2), a.sliceIdx)
ae = addMul(ae, il(pitch1), a.coreH)
ae = addMul(ae, il(a.datBytes), a.coreW)
return avx.Mm512MaskzLoaduPs{
mask, ae,
}
}
for x, loc := range locs {
// Skip loads that do not overlap [spanW1, spanW2) at all.
if loc+16 <= a.spanW1 || loc >= a.spanW2 {
dats[x] = nil
continue
}
dat := vb(a.name("dat"))
dats[x] = dat
var (
// lane is the first in-span lane and lanes the number of in-span
// lanes; mask2 has exactly those bits set.
lane = max(a.spanW1-loc, 0)
lanes = min(a.spanW2-loc, 16) - lane
mask1 = 1<<uint(lanes) - 1
mask2 = mask1 << uint(lane)
mask3 = il(mask2)
datPtr = 0
bnPtr = 0
)
stmt(cgen.Var{
Type: avx.M512, What: dat,
Init: load(mask3, datPtr, loc),
})
// Apply the operation chain in order. Add consumes further data
// pointers and Bn consumes the next multiplier/addend pair.
for op := range a.From.Ops {
op := &a.From.Ops[op]
switch op.Kind {
case mod.Add:
for n := op.Int; n > 0; n-- {
datPtr++
ld := load(mask3, datPtr, loc)
stmt(cgen.Assign{
Expr1: dat,
Expr2: avx.Mm512AddPs{
dat, ld,
},
})
}
case mod.Bn:
stmt(&bn.Apply{
Ctx: a.bc,
Mul: a.bnMuls[bnPtr],
Add: a.bnAdds[bnPtr],
To: dat,
Mask: mask3,
})
bnPtr++
case mod.ReLU:
stmt(&act.ReLU{
Ctx: a.ac,
NegSlope: op.Float,
Var: dat,
})
default:
panic("bug")
}
}
}
stmt(layer5())
return stmts
}
// layer3 runs layer4 for every non-empty box row and every vector index
// in [0, datSliceVecs).
layer3 := func() cgen.Gen {
var gens cgen.Gens
for h := range sbox {
if sbox[h] == nil {
continue
}
sboxH = h
for x := 0; x < a.datSliceVecs; x++ {
vecIdx = x
gens = append(
gens, layer4(),
)
}
}
return gens
}
// layer2 chooses the load positions. It visits the needed input columns
// (x*StrideW + w, for 16 steps and every occupied column of box row 0) in
// ascending order and starts a new 16-lane load whenever the column is
// not covered by the previous one.
// NOTE(review): only sbox[0] is consulted here and in anOff — this assumes
// box row 0 is non-nil and representative of the other rows; confirm.
layer2 := func() cgen.Gen {
past := 0
for x := 0; x < 16; x++ {
for w, fld := range sbox[0] {
if fld == 0 {
continue
}
at := x*a.StrideW + w
if at < past {
continue
}
locs = append(locs, at)
past = at + 16
}
}
return layer3()
}
// layer1 builds sbox from a.fields. Field indices are stored bitwise
// complemented so that 0 can mean "no field here".
layer1 := func() cgen.Gen {
sbox = make([][]int, a.StrideH)
for x, fld := range a.fields {
var (
h = fld.sboxH
w = fld.sboxW
)
if sbox[h] == nil {
sbox[h] = make([]int, a.StrideW)
}
sbox[h][w] = ^x
}
return layer2()
}
return layer1()
}

// ProduceSums is the public handle for the sum-production stage of the
// generated code. Prep emits the (deduplicated) function and Append emits a
// call to it.
type ProduceSums struct {
*Ctx
*Spec
// Team is passed as the first argument of the generated call.
Team cgen.Gen
// Tensors are emitted as the initializer of the pointer array that is
// passed as the second argument of the generated call.
Tensors []cgen.Gen
// layout is (re)built by Prep from Ctx and Spec.
*layout
// callerName is the name of the generated function that Append calls.
// Prep sets it, reusing a previously generated name when one exists.
callerName string
}

// Prep computes the layout for this spec and returns the generator for the
// sum-production function. Functions are deduplicated by a signature made
// of the stage name and the spec: when an identical function was already
// generated, its name is reused and nil is returned.
func (p *ProduceSums) Prep() cgen.Gen {
	const affix = "ProduceSums"
	p.layout = newLayout(p.Ctx, p.Spec)
	key := fmt.Sprint(affix, " ", p.Spec)
	prior, seen := p.dedup[key]
	if seen {
		// An equivalent function already exists; just call that one.
		p.callerName = prior.(string)
		return nil
	}
	fresh := p.name(p.prefix + affix)
	p.callerName = fresh
	p.dedup[key] = fresh
	return cgen.Gens{
		&produceSums{ProduceSums: p},
		cgen.Newline,
	}
}

// Bytes returns the layout's sumTotalBytes value.
func (p *ProduceSums) Bytes() int {
	total := p.sumTotalBytes
	return total
}

// Append emits the call site for the sum-production function: a local array
// initialized with the tensor pointers, followed by a call that passes the
// team and that array to the function named by callerName.
func (p *ProduceSums) Append(to []byte) []byte {
	arr := vb(p.name("tensors"))
	decl := cgen.Var{
		Type: cgen.PtrChar,
		What: cgen.Elem{Arr: arr},
		Init: cgen.Brace{
			Inner: cgen.CommaLines(p.Tensors),
		},
	}
	call := cgen.Call{
		Func: vb(p.callerName),
		Args: cgen.CommaSpaced{p.Team, arr},
	}
	return cgen.Stmts{decl, call}.Append(to)
}

// produceSums is the working state used while generating one sum-production
// function. The fields are filled in progressively by Append, calleeFunc
// and the kernelN methods, each of which sets the values its callees read.
type produceSums struct {
*ProduceSums
// Static node table variable, set in Append and read in kernel1.
nodeTbl cgen.Gen
// Per-callee configuration, set by the callee closure in Append: the
// epoch range covered and the sizes for that epoch class.
epochFirst int
epochCnt int
slices int
wtCoreBytes int
wtNodeBytes int
wtGroupBytes int
datCoreBytes int
datGroupBytes int
datFieldBytes int
// Tiling of the weight cores across tasks, set by the callee closure in
// Append.
wtTile int
wtTiles int
wtScrap int
wtHull int
// Name of the generated callee, set by the callee closure in Append.
calleeName string
// Variables declared at the top of the callee (set in calleeFunc).
tensors cgen.Gen
epochCoord cgen.Gen
fieldCoord cgen.Gen
nodeFirst cgen.Gen
groupCoord cgen.Gen
toCoord cgen.Gen
nodeOff cgen.Gen
wtCoord cgen.Gen
// Variables declared in kernel1.
nodeCoord cgen.Gen
lift cgen.Gen
pileCoord cgen.Gen
base cgen.Gen
fromCoord cgen.Gen
// Pointer variables declared in kernel2.
biasPtr cgen.Gen
wtPtr cgen.Gen
datPtr cgen.Gen
sumPtr cgen.Gen
// Set in kernel3 before each kernel4 call.
vecs1 int
vecs2 int
// NOTE(review): the remaining fields are presumably used by kernel4 and
// later stages — confirm there.
bnPre bool
bias bool
rdwr bool
wtIdx cgen.Gen
wtShort bool
}

// Append emits everything needed for the sum-production function: two
// static lookup tables (one per field, one per node), up to two callee
// functions (one for epochs below epochs1 and one for the remaining epoch),
// and the static caller. The caller fills a small tuple of void pointers
// with the tensors pointer and the current epoch, field and node
// coordinates, and dispatches the callee through threader.Do inside nested
// epoch / field / node loops.
//
// The callee closure mutates p (epochFirst, epochCnt, wtHull, calleeName,
// ...) and loop1 reads those values, so each callee(...) call must be
// immediately followed by its loop1() call, as done at the end.
func (p *produceSums) Append(to []byte) []byte {
// nm derives a unique name from the caller's name plus a suffix.
nm := func(s string) string {
return p.name(p.callerName + s)
}
fieldTbl := vb(nm("FieldTbl"))
p.nodeTbl = vb(nm("NodeTbl"))
type fn func(int) cgen.Gen
// table emits a static ptrdiff_t array named arr whose initializer has n
// lines, line x being produced by line(x).
table := func(arr cgen.Gen, n int, line fn) cgen.Gen {
lines := make(cgen.CommaLines, n)
for x := range lines {
lines[x] = line(x)
}
return cgen.Gens{
cgen.Static{
Tail: cgen.Var{
Type: cgen.PtrdiffT,
What: cgen.Elem{Arr: arr},
Init: cgen.Brace{Inner: lines},
},
},
cgen.Newline,
cgen.Newline,
}
}
// callee configures p for the epochs [first, first+cnt) and returns the
// callee function for them. Epochs below epochs1 use the "1" sizes, the
// others the "2" sizes.
callee := func(first, cnt int) cgen.Gen {
p.epochFirst = first
p.epochCnt = cnt
switch {
case first < p.epochs1:
p.slices = p.slices1
p.wtCoreBytes = p.wtCoreBytes11
p.wtNodeBytes = p.wtNodeBytes1
p.wtGroupBytes = p.wtGroupBytes1
p.datCoreBytes = p.datCoreBytes1
p.datGroupBytes = p.datGroupBytes1
p.datFieldBytes = p.datFieldBytes1
default:
p.slices = p.slices2
p.wtCoreBytes = p.wtCoreBytes21
p.wtNodeBytes = p.wtNodeBytes2
p.wtGroupBytes = p.wtGroupBytes2
p.datCoreBytes = p.datCoreBytes2
p.datGroupBytes = p.datGroupBytes2
p.datFieldBytes = p.datFieldBytes2
}
var threadSlices int
switch p.platform {
case raw.AVX512Float32:
threadSlices = 512
default:
panic("bug")
}
// Split the wtCores2 weight cores into wtHull tasks. A remainder is
// folded into one final, larger scrap tile.
var (
tile = ceilQuo(threadSlices, p.slices)
tiles = max(p.wtCores2/tile, 1)
)
p.wtTile = p.wtCores2 / tiles
p.wtTiles = tiles
p.wtScrap = p.wtCores2 - tiles*p.wtTile
p.wtHull = tiles
if p.wtScrap > 0 {
p.wtTiles--
p.wtScrap += p.wtTile
}
p.calleeName = nm("Callee")
return cgen.Gens{
p.calleeFunc(),
cgen.Newline,
}
}
var (
team = vb(p.name("team"))
tensors = vb(p.name("tensors"))
tuple = vb(p.name("tuple"))
)
// store emits tuple[x] = (void*)expr.
store := func(x int, expr cgen.Gen) cgen.Gen {
return cgen.Assign{
Expr1: cgen.Elem{
Arr: tuple, Idx: il(x),
},
Expr2: cgen.Cast{
Type: cgen.PtrVoid,
Expr: expr,
},
}
}
// loop3 emits the node loop for one field. It reads the first node, the
// step and the past-the-end node from fieldTbl at indices 0, 1 and 2
// relative to 2*fieldCoord, and launches the callee once per step with
// the current first node stored in tuple[3].
loop3 := func(fieldCoord cgen.Gen) cgen.Gen {
var (
nodeFirst = vb(p.name("node"))
nodeStep = vb(p.name("step"))
nodePast = vb(p.name("past"))
)
load := func(x int, what cgen.Gen) cgen.Gen {
return cgen.Var{
Type: cgen.PtrdiffT,
What: what,
Init: cgen.Elem{
Arr: fieldTbl,
Idx: addMul(
il(x),
il(2),
fieldCoord,
),
},
}
}
return cgen.Stmts{
load(0, nodeFirst),
load(1, nodeStep),
load(2, nodePast),
cgen.For{
Cond: cgen.CmpL{
Expr1: nodeFirst,
Expr2: nodePast,
},
Post: cgen.AddAssign{
Expr1: nodeFirst,
Expr2: nodeStep,
},
Body: cgen.Stmts{
store(3, nodeFirst),
&threader.Do{
Ctx: p.tc,
Callee: vb(p.calleeName),
Any: tuple,
Hull: []cgen.Gen{
il(p.wtHull),
nodeStep,
il(p.sumCores),
il(p.Groups),
},
Team: team,
},
},
},
}
}
// loop2 emits the loop over all fields, storing the field in tuple[2].
loop2 := func() cgen.Gen {
fieldCoord := vb(p.name("field"))
return cgen.For{
Init: cgen.Var{
Type: cgen.PtrdiffT,
What: fieldCoord,
Init: il(0),
},
Cond: cgen.CmpL{
Expr1: fieldCoord,
Expr2: il(len(p.fields)),
},
Post: cgen.IncPre{
Expr: fieldCoord,
},
Body: cgen.Stmts{
store(2, fieldCoord),
loop3(fieldCoord),
},
}
}
// loop1 emits the loop over the epochs most recently configured by
// callee, storing the epoch in tuple[1].
loop1 := func() cgen.Gen {
epochCoord := vb(p.name("epoch"))
return cgen.For{
Init: cgen.Var{
Type: cgen.PtrdiffT,
What: epochCoord,
Init: il(p.epochFirst),
},
Cond: cgen.CmpL{
Expr1: epochCoord,
Expr2: il(
p.epochFirst + p.epochCnt,
),
},
Post: cgen.IncPre{
Expr: epochCoord,
},
Body: cgen.Stmts{
store(1, epochCoord),
loop2(),
},
}
}
var (
// prep: the two tables, then up to two callees.
// body: tuple declaration, tuple[0] assignment, up to two epoch loops.
prep = make(cgen.Gens, 4)
body = make(cgen.Stmts, 4)
)
// Field table: first node and node step per field. Only the last line
// carries a third value, the total node count.
// NOTE(review): loop3 indexes this table with a stride of 2 per field, so
// a nil `past` presumably emits nothing and index 2 of a field reads the
// next field's first node — confirm against cgen.CommaSpaced.
prep[0] = table(
fieldTbl,
len(p.fields),
func(x int) cgen.Gen {
var (
fld = p.fields[x]
past cgen.Gen
)
if x+1 == len(p.fields) {
past = il(len(p.nodes))
}
return cgen.CommaSpaced{
il(fld.nodeFirst),
il(fld.nodeStep),
past,
}
},
)
// Node table: three values per node (lift of the node's deck, pile, and
// the base flag as 0/1).
prep[1] = table(
p.nodeTbl,
len(p.nodes),
func(x int) cgen.Gen {
nd := p.nodes[x]
return cgen.CommaSpaced{
il(p.lifts[nd.deck]),
il(nd.pile),
il(btoi(nd.base)),
}
},
)
body[0] = cgen.Var{
Type: cgen.PtrVoid,
What: cgen.Elem{
Arr: tuple, Idx: il(4),
},
}
body[1] = cgen.Assign{
Expr1: cgen.Elem{
Arr: tuple, Idx: il(0),
},
Expr2: tensors,
}
if p.epochs1 > 0 {
prep[2] = callee(0, p.epochs1)
body[2] = loop1()
}
if p.epochs1 < p.epochs2 {
prep[3] = callee(p.epochs1, 1)
body[3] = loop1()
}
return cgen.Gens{
prep,
cgen.StaticFuncDef{
ReturnType: cgen.Void,
Name: p.callerName,
Params: cgen.CommaSpaced{
cgen.Param{
Type: p.tc.PtrTeam,
What: team,
},
cgen.Param{
Type: cgen.PtrPtrChar,
What: tensors,
},
},
Body: body,
},
}.Append(to)
}

// calleeFunc builds the threader callee. The body unpacks the tuple passed
// by the caller (tensors, epoch, field, first node) and the task point
// (group, to, node offset, weight coordinate) into local variables, then
// appends kernel1. A coordinate whose value is known at generation time is
// declared as that constant instead of being read.
func (p *produceSums) calleeFunc() cgen.Gen {
	callee := &threader.Callee{
		Ctx:  p.tc,
		Name: p.calleeName,
		Task: vb(p.name("task")),
		Pt:   vb(p.name("pt")),
	}
	var body cgen.Stmts
	emit := func(g cgen.Gen) {
		body = append(body, g)
	}
	// known returns the constant v when the value is fixed, and nil (meaning
	// "read it at run time") otherwise.
	known := func(fixed bool, v int) cgen.Gen {
		if !fixed {
			return nil
		}
		return il(v)
	}

	tuple := vb(p.name("tuple"))
	emit(cgen.Var{
		Type: cgen.PtrPtrVoid, What: tuple,
		Init: callee.Any(),
	})
	p.tensors = vb(p.name("tensors"))
	emit(cgen.Var{
		Type: cgen.PtrPtrChar, What: p.tensors,
		Init: cgen.Elem{
			Arr: tuple, Idx: il(0),
		},
	})

	// Tuple slots are consumed in ascending order starting at 1; a slot is
	// consumed even when its value is a known constant.
	slot := 1
	fromTuple := func(nm string, src cgen.Gen) cgen.Gen {
		v := vb(p.name(nm))
		if src == nil {
			src = cgen.Cast{
				Type: cgen.PtrdiffT,
				Expr: cgen.Elem{
					Arr: tuple, Idx: il(slot),
				},
			}
		}
		slot++
		emit(cgen.Var{
			Type: cgen.PtrdiffT, What: v,
			Init: src,
		})
		return v
	}
	p.epochCoord = fromTuple(
		"epoch", known(p.epochCnt == 1, p.epochFirst),
	)
	p.fieldCoord = fromTuple(
		"field", known(len(p.fields) == 1, 0),
	)
	p.nodeFirst = fromTuple(
		"nodeFirst",
		known(p.fields[0].nodeStep == len(p.nodes), 0),
	)

	// Task-point axes are consumed in descending order starting at 3.
	axis := 3
	ptRead := false
	fromPt := func(nm string, src cgen.Gen) cgen.Gen {
		v := vb(p.name(nm))
		if src == nil {
			src = cgen.Elem{
				Arr: callee.Pt, Idx: il(axis),
			}
			ptRead = true
		}
		axis--
		emit(cgen.Var{
			Type: cgen.PtrdiffT, What: v,
			Init: src,
		})
		return v
	}
	p.groupCoord = fromPt("group", known(p.Groups == 1, 0))
	p.toCoord = fromPt("to", known(p.sumCores == 1, 0))
	// The node offset is only needed when some field steps by more than one
	// node at a time.
	unitSteps := true
	for i := range p.fields {
		if p.fields[i].nodeStep > 1 {
			unitSteps = false
			break
		}
	}
	p.nodeOff = fromPt("nodeOff", known(unitSteps, 0))
	p.wtCoord = fromPt("w", known(p.wtHull == 1, 0))
	if !ptRead {
		// Mention the otherwise unused parameter via a void cast.
		emit(cgen.Cast{
			Type: cgen.Void,
			Expr: callee.Pt,
		})
	}
	emit(p.kernel1())
	return callee.Func(body)
}

// kernel1 declares the node-derived variables: the node coordinate
// (nodeFirst + nodeOff), the node's lift, pile and base values (read from
// the node table unless they are the same for every node), and the source
// core coordinate computed from toCoord and lift. Tasks whose source core is
// out of range return immediately; otherwise kernel2 follows.
func (p *produceSums) kernel1() cgen.Gen {
	var body cgen.Stmts
	declare := func(nm string, src cgen.Gen) cgen.Gen {
		v := vb(p.name(nm))
		body = append(body, cgen.Var{
			Type: cgen.PtrdiffT, What: v,
			Init: src,
		})
		return v
	}
	p.nodeCoord = declare("node", cgen.Add{
		Expr1: p.nodeFirst,
		Expr2: p.nodeOff,
	})

	// The node table holds three columns per node; columns are consumed in
	// order, and a column is skipped over even when its value is constant.
	column := 0
	tblRead := false
	fromTbl := func(nm string, src cgen.Gen) cgen.Gen {
		if src == nil {
			src = cgen.Elem{
				Arr: p.nodeTbl,
				Idx: addMul(
					il(column),
					il(3),
					p.nodeCoord,
				),
			}
			tblRead = true
		}
		column++
		return declare(nm, src)
	}

	var liftConst cgen.Gen
	if len(p.lifts) == 1 {
		liftConst = il(0)
	}
	p.lift = fromTbl("lift", liftConst)

	var pileConst cgen.Gen
	if len(p.shifts) == 1 {
		pileConst = il(0)
	}
	p.pileCoord = fromTbl("pile", pileConst)

	// base is the constant 1 only when every node has its base flag set.
	allBase := true
	for i := range p.nodes {
		if !p.nodes[i].base {
			allBase = false
			break
		}
	}
	var baseConst cgen.Gen
	if allBase {
		baseConst = il(1)
	}
	p.base = fromTbl("base", baseConst)

	if !tblRead {
		// Mention the otherwise unused table via a void cast.
		body = append(body, cgen.Cast{
			Type: cgen.Void,
			Expr: p.nodeTbl,
		})
	}

	liftCores := cgen.Quo{
		Expr1: cgen.Cast{
			Type: cgen.SizeT,
			Expr: p.lift,
		},
		Expr2: il(p.datSliceVecs),
	}
	p.fromCoord = declare(
		"from",
		addMul(p.toCoord, liftCores, il(p.blkStep)),
	)
	body = append(body, cgen.If1{
		Cond: cgen.CmpGE{
			Expr1: p.fromCoord,
			Expr2: il(p.datCores),
		},
		Then: cgen.Return{},
	})
	body = append(body, p.kernel2())
	return body
}

// kernel2 declares the four restrict char pointers (bias, weights, data,
// sums) as byte offsets from the tensors array, then appends the statements
// produced by kernel3.
func (p *produceSums) kernel2() cgen.Gen {
	var body cgen.Stmts
	// pointer emits a restrict char pointer initialized to addr.
	pointer := func(nm string, addr cgen.Gen) cgen.Gen {
		v := vb(p.name(nm))
		body = append(body, cgen.Var{
			Type: cgen.RestrictPtrChar,
			What: v,
			Init: addr,
		})
		return v
	}
	tensor := func(x int) cgen.Gen {
		return cgen.Elem{
			Arr: p.tensors,
			Idx: il(x),
		}
	}
	// offset extends addr by pitch bytes per unit of coord.
	offset := func(addr cgen.Gen, pitch int, coord cgen.Gen) cgen.Gen {
		return addMul(addr, il(pitch), coord)
	}

	biasAt := tensor(0)
	biasAt = offset(biasAt, p.biasEpochBytes, p.epochCoord)
	biasAt = offset(biasAt, p.biasGroupBytes, p.groupCoord)
	p.biasPtr = pointer("biasPtr", biasAt)

	// Weights follow the biases within tensor 0.
	var wtAt cgen.Gen = cgen.Add{
		Expr1: tensor(0),
		Expr2: il(p.biasTotalBytes),
	}
	wtAt = offset(wtAt, p.wtEpochBytes1, p.epochCoord)
	wtAt = offset(wtAt, p.wtGroupBytes, p.groupCoord)
	wtAt = offset(wtAt, p.wtNodeBytes, p.nodeCoord)
	p.wtPtr = pointer("wtPtr", wtAt)

	datAt := tensor(1)
	datAt = offset(datAt, p.datEpochBytes1, p.epochCoord)
	datAt = offset(datAt, p.datFieldBytes, p.fieldCoord)
	datAt = offset(datAt, p.datGroupBytes, p.groupCoord)
	datAt = offset(datAt, p.datCoreBytes, p.fromCoord)
	p.datPtr = pointer("datPtr", datAt)

	sumAt := tensor(2)
	sumAt = offset(sumAt, p.sumGroupBytes, p.groupCoord)
	sumAt = offset(sumAt, p.sumCoreBytes, p.toCoord)
	sumAt = offset(sumAt, p.sumPileBytes, p.pileCoord)
	p.sumPtr = pointer("sumPtr", sumAt)

	body = append(body, p.kernel3())
	return body
}

// kernel3 chooses the vecs1/vecs2 split for the kernel. With a single lift
// the split is fixed (0, datSliceVecs). Otherwise it emits a switch on
// 2*(lift % datSliceVecs) + (to >= blkStep), with one kernel4 body per
// selector value that can occur.
func (p *produceSums) kernel3() cgen.Gen {
	if len(p.lifts) == 1 {
		p.vecs1, p.vecs2 = 0, p.datSliceVecs
		return p.kernel4()
	}
	// Two bits per remainder: the low bit is always needed, the high bit
	// only when sumCores exceeds blkStep.
	pair := 1
	if p.sumCores > p.blkStep {
		pair = 3
	}
	need := 0
	for _, lift := range p.lifts {
		rem := lift % p.datSliceVecs
		need |= pair << uint(2*rem)
	}
	var cases cgen.Stmts
	for x := 0; need != 0; x++ {
		wanted := need&1 != 0
		need >>= 1
		// Selector 1 never gets its own case.
		if !wanted || x == 1 {
			continue
		}
		half := x >> 1
		p.vecs2 = p.datSliceVecs - half
		p.vecs1 = 0
		if x&1 != 0 {
			p.vecs1 = half
		}
		// Selector 0 is emitted with a nil case expression.
		var label cgen.Gen
		if x != 0 {
			label = il(x)
		}
		body := cgen.Stmts{
			p.kernel4(),
			cgen.Break,
		}
		cases = append(cases, cgen.Case{
			Expr: label,
			Body: body,
		})
	}
	liftRem := cgen.Rem{
		Expr1: cgen.Cast{
			Type: cgen.SizeT,
			Expr: p.lift,
		},
		Expr2: il(p.datSliceVecs),
	}
	pastBlk := cgen.Paren{
		Inner: cgen.CmpGE{
			Expr1: p.toCoord,
			Expr2: il(p.blkStep),
		},
	}
	return cgen.Switch{
		Expr: cgen.Add{
			Expr1: cgen.Mul{
				Expr1: liftRem,
				Expr2: il(2),
			},
			Expr2: pastBlk,
		},
		Cases: cases,
	}
}

// kernel4 decides whether the kernel initializes its sums from the bias.
// When the choice is known at generation time a single kernel5 body is
// returned; otherwise a runtime test selects between a no-bias body (then
// return) and a bias body.
func (p *produceSums) kernel4() cgen.Gen {
	p.bnPre = false
	for i := range p.Filts {
		if p.Filts[i].BnPre > 0 {
			p.bnPre = true
			break
		}
	}
	emit := func(withBias bool) cgen.Gen {
		p.bias = withBias
		return p.kernel5()
	}
	skip := p.nodeCoord
	if p.bnPre {
		if len(p.nodes) == 1 {
			return emit(true)
		}
	} else {
		if p.epochFirst > 0 {
			return emit(false)
		}
		if p.epochCnt == 1 && len(p.nodes) == 1 {
			return emit(true)
		}
		skip = cgen.Or{
			Expr1: p.epochCoord,
			Expr2: skip,
		}
	}
	// The no-bias variant is generated first, then the bias variant.
	without := cgen.Stmts{
		emit(false),
		cgen.Return{},
	}
	with := emit(true)
	return cgen.Stmts{
		cgen.If{
			Cond: skip,
			Then: without,
		},
		with,
	}
}

// kernel5 decides whether the kernel overwrites the sums or adds to what is
// already stored (rdwr). When that is known at generation time one kernel6
// body is returned, preceded by a void cast of base; otherwise a runtime
// test on epoch == 0 && base selects between the two variants.
func (p *produceSums) kernel5() cgen.Gen {
	baseRead := false
	emit := func(accumulate bool) cgen.Gen {
		p.rdwr = accumulate
		// When base is not read by a condition, reference it once.
		var silence cgen.Gen
		if !baseRead {
			silence = cgen.Cast{
				Type: cgen.Void,
				Expr: p.base,
			}
		}
		return cgen.Stmts{
			silence,
			p.kernel6(),
		}
	}
	allBase := func() bool {
		for _, nd := range p.nodes {
			if !nd.base {
				return false
			}
		}
		return true
	}
	switch {
	case p.epochFirst > 0:
		return emit(true)
	case p.epochCnt == 1 && p.bias:
		return emit(false)
	case p.epochCnt == 1 && allBase():
		return emit(false)
	case !p.bnPre && p.bias:
		return emit(false)
	}
	baseRead = true
	// The overwrite variant is generated first, then the accumulate one.
	overwrite := cgen.Stmts{
		emit(false),
		cgen.Return{},
	}
	accumulate := emit(true)
	return cgen.Stmts{
		cgen.If{
			Cond: cgen.Land{
				Expr1: cgen.IsZero{
					Expr: p.epochCoord,
				},
				Expr2: p.base,
			},
			Then: overwrite,
		},
		accumulate,
	}
}

// kernel6 emits the loop over weight indices for one task: the starting
// index wtTile*w, an optional last index (when wtHull > 1), a for loop over
// the first wtCores1 indices, and a trailing short body when
// wtCores1 < wtCores2. The result always has four slots; unused ones are nil.
func (p *produceSums) kernel6() cgen.Gen {
	p.wtIdx = vb(p.name("i"))
	first := cgen.Var{
		Type: cgen.PtrdiffT,
		What: p.wtIdx,
		Init: cgen.Mul{
			Expr1: il(p.wtTile),
			Expr2: p.wtCoord,
		},
	}
	var lastDecl, stop, full, short cgen.Gen
	if p.wtHull > 1 {
		last := vb(p.name("ii"))
		// span is the distance from the first index to the last one.
		var span cgen.Gen
		if p.wtTiles == p.wtHull {
			span = il(p.wtTile - 1)
		} else if p.wtTiles == 0 {
			span = il(p.wtScrap - 1)
		} else {
			span = cgen.Paren{
				Inner: cgen.Ternary{
					Cond: cgen.CmpL{
						Expr1: p.wtCoord,
						Expr2: il(p.wtTiles),
					},
					Then: il(p.wtTile - 1),
					Else: il(p.wtScrap - 1),
				},
			}
		}
		lastDecl = cgen.Var{
			Type: cgen.PtrdiffT,
			What: last,
			Init: cgen.Add{
				Expr1: p.wtIdx,
				Expr2: span,
			},
		}
		stop = cgen.If1{
			Cond: cgen.CmpGE{
				Expr1: p.wtIdx,
				Expr2: last,
			},
			Then: cgen.Return{},
		}
	}
	if p.wtCores1 > 0 {
		p.wtShort = false
		full = cgen.For{
			Cond: cgen.CmpNE{
				Expr1: p.wtIdx,
				Expr2: il(p.wtCores1),
			},
			Post: cgen.IncPre{
				Expr: p.wtIdx,
			},
			Body: cgen.Stmts{
				p.kernel7(),
				stop,
			},
		}
	}
	if p.wtCores1 < p.wtCores2 {
		p.wtShort = true
		short = p.kernel7()
	}
	return cgen.Stmts{first, lastDecl, full, short}
}

// kernel7 dispatches to the platform-specific kernel body. Only
// raw.AVX512Float32 is handled; any other platform panics.
func (p *produceSums) kernel7() cgen.Gen {
	if p.platform == raw.AVX512Float32 {
		return p.m512()
	}
	panic("bug")
}

// m512 generates the AVX-512 body of the sum-producing kernel. The work is
// split into nested closures, layer1 (outermost) through layer8 (innermost),
// which share the variables declared below.
func (p *produceSums) m512() cgen.Gen {
var (
rows int
cols int
sums [][]cgen.Gen
sliceIdx cgen.Gen
dats []cgen.Gen
)
// layer8: for each row, broadcast one float read from the weight pointer
// (offset by row, wtIdx and sliceIdx) and fused-multiply-add it with every
// data vector into that row's sums.
layer8 := func() cgen.Gen {
var stmts cgen.Stmts
stmt := func(st cgen.Gen) {
stmts = append(stmts, st)
}
slicePitch := p.wtSliceBytes1
if p.wtShort {
slicePitch = p.wtSliceBytes2
}
for r, sums := range sums {
var (
ae = p.wtPtr
wt = vb(p.name("wt"))
)
ae = cgen.Add{
Expr1: ae,
Expr2: il(r * p.wtBytes),
}
ae = addMul(ae, il(p.wtCoreBytes), p.wtIdx)
ae = addMul(ae, il(slicePitch), sliceIdx)
stmt(cgen.Var{
Type: avx.M512, What: wt,
Init: avx.Mm512Set1Ps{
cgen.At{
Expr: cgen.Cast{
Type: cgen.PtrFloat,
Expr: cgen.Paren{
Inner: ae,
},
},
},
},
})
for c, sum := range sums {
stmt(cgen.Assign{
Expr1: sum,
Expr2: avx.Mm512FmaddPs{
wt, dats[c], sum,
},
})
}
}
return stmts
}
// layer7: declare cols data vectors, each loaded (unaligned) from the data
// pointer at a per-column offset plus datSliceBytes*sliceIdx, then run
// layer8.
layer7 := func() cgen.Gen {
var stmts cgen.Stmts
stmt := func(st cgen.Gen) {
stmts = append(stmts, st)
}
dats = make([]cgen.Gen, cols)
for c := range dats {
dat := vb(p.name("dat"))
dats[c] = dat
}
for c, dat := range dats {
ae := addMul(
cgen.Add{
Expr1: p.datPtr,
Expr2: il(
p.datSliceBytes +
(c-cols)*p.datVecBytes,
),
},
il(p.datSliceBytes),
sliceIdx,
)
stmt(cgen.Var{
Type: avx.M512, What: dat,
Init: avx.Mm512LoaduPs{ae},
})
}
stmt(layer8())
return stmts
}
// layer6: a for loop with index j running from 0 while j < p.slices, whose
// body is layer7.
layer6 := func() cgen.Gen {
sliceIdx = vb(p.name("j"))
return cgen.For{
Init: cgen.Var{
Type: cgen.PtrdiffT,
What: sliceIdx,
Init: il(0),
},
Cond: cgen.CmpL{
Expr1: sliceIdx,
Expr2: il(p.slices),
},
Post: cgen.IncPre{
Expr: sliceIdx,
},
Body: layer7(),
}
}
// layer5: emit the slice loop, then store every sum at the sum pointer.
// Columns below vecs1 are stored at an offset that goes back by
// blkStep*sumCoreBytes; when p.rdwr is set the stored value is the sum added
// to what is already in memory.
layer5 := func() cgen.Gen {
var stmts cgen.Stmts
stmt := func(st cgen.Gen) {
stmts = append(stmts, st)
}
stmt(layer6())
for r, sums := range sums {
for c, sum := range sums {
off := r * p.datSliceBytes
off += c * p.datVecBytes
switch {
case c < p.vecs1:
off -= p.blkStep * p.sumCoreBytes
off += p.vecs2 * p.datVecBytes
default:
off -= p.vecs1 * p.datVecBytes
}
ae := addMul(
cgen.Add{
Expr1: p.sumPtr,
Expr2: il(off),
},
il(p.sumSiteBytes1),
p.wtIdx,
)
if p.rdwr {
sum = avx.Mm512AddPs{
sum,
avx.Mm512LoaduPs{ae},
}
}
stmt(avx.Mm512StoreuPs{
ae, sum,
})
}
}
return stmts
}
// layer4: declare columns 1..cols-1 of every row as copies of column 0.
layer4 := func() cgen.Gen {
var stmts cgen.Stmts
stmt := func(st cgen.Gen) {
stmts = append(stmts, st)
}
for _, sums := range sums {
for c := 1; c < cols; c++ {
stmt(cgen.Var{
Type: avx.M512,
What: sums[c],
Init: sums[0],
})
}
}
stmt(layer5())
return stmts
}
// layer3: declare column 0 of every row, initialized either to a broadcast
// of that row's bias (when p.bias) or to zero. When the bias is not used,
// the bias pointer is cast to void so it is still referenced.
layer3 := func() cgen.Gen {
var stmts cgen.Stmts
stmt := func(st cgen.Gen) {
stmts = append(stmts, st)
}
for r, sums := range sums {
var expr cgen.Gen
switch {
case p.bias:
expr = addMul(
cgen.Add{
Expr1: p.biasPtr,
Expr2: il(r * p.biasBytes),
},
il(p.wtSliceWts1*p.biasBytes),
p.wtIdx,
)
expr = avx.Mm512Set1Ps{
cgen.At{
Expr: cgen.Cast{
Type: cgen.PtrFloat,
Expr: cgen.Paren{
Inner: expr,
},
},
},
}
default:
expr = avx.Mm512SetzeroPs
}
stmt(cgen.Var{
Type: avx.M512,
What: sums[0],
Init: expr,
})
}
if !p.bias {
stmt(cgen.Cast{
Type: cgen.Void,
Expr: p.biasPtr,
})
}
stmt(layer4())
return stmts
}
// layer2: allocate a rows x cols grid of freshly named "sum" variables.
layer2 := func() cgen.Gen {
sums = make([][]cgen.Gen, rows)
for r := range sums {
sums[r] = make([]cgen.Gen, cols)
for c := range sums[r] {
sum := vb(p.name("sum"))
sums[r][c] = sum
}
}
return layer3()
}
// layer1: fix the grid size. rows is wtSliceWts1, or wtSliceWts2 for the
// short case; cols is vecs1 + vecs2.
layer1 := func() cgen.Gen {
rows = p.wtSliceWts1
if p.wtShort {
rows = p.wtSliceWts2
}
cols = p.vecs1 + p.vecs2
return layer2()
}
return layer1()
}

// ConsumeSums is the caller-side handle for the sum-consuming stage. Prep
// assigns callerName (reusing an earlier function when the Spec prints the
// same) and Append emits a call to that function with Team and an array
// built from Tensors.
type ConsumeSums struct {
*Ctx
*Spec
Team cgen.Gen
Tensors []cgen.Gen
callerName string
}

// Prep chooses the name of the function that Append will call. Functions
// are deduplicated by the printed form of the Spec: if an identical one was
// already prepared its name is reused and nil is returned, otherwise a new
// name is registered and the generator for the function is returned.
func (c *ConsumeSums) Prep() cgen.Gen {
	const affix = "ConsumeSums"
	key := fmt.Sprint(affix, " ", c.Spec)
	prior, seen := c.dedup[key]
	if seen {
		c.callerName = prior.(string)
		return nil
	}
	fresh := c.name(c.prefix + affix)
	c.callerName = fresh
	c.dedup[key] = fresh
	return cgen.Gens{
		&consumeSums{ConsumeSums: c},
		cgen.Newline,
	}
}

// Append emits the call site: a char pointer array initialized with
// c.Tensors, followed by a call to the prepared function with the team and
// that array.
func (c *ConsumeSums) Append(to []byte) []byte {
	array := vb(c.name("tensors"))
	elems := cgen.CommaLines(c.Tensors)
	decl := cgen.Var{
		Type: cgen.PtrChar,
		What: cgen.Elem{Arr: array},
		Init: cgen.Brace{Inner: elems},
	}
	call := cgen.Call{
		Func: vb(c.callerName),
		Args: cgen.CommaSpaced{
			c.Team, array,
		},
	}
	return cgen.Stmts{decl, call}.Append(to)
}

// consumeSums holds the state used while generating the sum-consuming
// function. The integer fields are filled in by Append; the cgen.Gen fields
// are the C variables declared by calleeFunc and the kernelN methods.
type consumeSums struct {
	*ConsumeSums
	*layout
	// Remainders and counts of the output split into strips and cells.
	remH, remW       int
	strips1, strips2 int
	cells1, cells2   int
	// Per-dimension task tiling: tile size, number of full tiles, size
	// of the final odd tile, and the task-grid extent (hull).
	cellTile, cellTiles, cellScrap, cellHull     int
	stripTile, stripTiles, stripScrap, stripHull int
	chanTile, chanTiles, chanScrap, chanHull     int
	groupTile, groupTiles, groupScrap, groupHull int
	calleeName                                   string
	// Variables of the generated callee.
	tensors                                      cgen.Gen
	cellCoord, stripCoord, chanCoord, groupCoord cgen.Gen
	sumPtr                                       cgen.Gen
	datPtrs, bnPtrs                              []cgen.Gen
	groupIdx, chanIdx                            cgen.Gen
	bnMuls, bnAdds                               []cgen.Gen
	stripIdx                                     cgen.Gen
	shortH                                       bool
	cellIdx                                      cgen.Gen
	shortW                                       bool
}

// Append computes the output extent and the task tiling, then emits the
// callee function followed by the static caller that runs the callee over a
// four-dimensional task grid (cell, strip, channel, group).
func (c *consumeSums) Append(to []byte) []byte {
	c.layout = newLayout(c.Ctx, c.Spec)
	// extent evaluates (from + 2*pad - (1 + (filt-1)*dila))/str + 1.
	extent := func(from, pad, filt, dila, str int) int {
		padded := from + 2*pad
		reach := 1 + (filt-1)*dila
		return (padded-reach)/str + 1
	}
	yieldH := extent(
		c.From.Height, c.PaddingH,
		c.FilterH, c.DilationH,
		c.StrideH,
	)
	yieldW := extent(
		c.From.Width, c.PaddingW,
		c.FilterW, c.DilationW,
		c.StrideW,
	)
	c.strips1 = yieldH / c.datSliceVecs
	c.remH = yieldH % c.datSliceVecs
	c.strips2 = c.strips1 + btoi(c.remH > 0)
	c.cells1 = yieldW / c.datVecDats
	c.remW = yieldW % c.datVecDats
	c.cells2 = c.cells1 + btoi(c.remW > 0)

	// Work estimates, from one cell up to one whole group.
	cellWork := len(c.shifts)
	stripWork := c.cells2 * cellWork
	chanWork := c.strips2 * stripWork
	groupWork := c.toChans * chanWork
	if c.platform != raw.AVX512Float32 {
		panic("bug")
	}
	const threadWork = 512

	// split divides total units among tasks so each task gets roughly
	// threadWork worth of work. A nonzero remainder is folded into one
	// larger final tile (scrap), leaving one fewer regular tile.
	split := func(total, unitWork int) (int, int, int, int) {
		want := ceilQuo(threadWork, unitWork)
		hull := max(total/want, 1)
		tile := total / hull
		tiles := hull
		scrap := total - hull*tile
		if scrap > 0 {
			tiles--
			scrap += tile
		}
		return tile, tiles, scrap, hull
	}

	// Defaults: one task covers all cells and strips; channels and
	// groups get one task each.
	c.cellTile, c.cellTiles, c.cellScrap, c.cellHull = c.cells2, 1, 0, 1
	c.stripTile, c.stripTiles, c.stripScrap, c.stripHull = c.strips2, 1, 0, 1
	c.chanTile, c.chanTiles, c.chanScrap, c.chanHull = 1, c.toChans, 0, c.toChans
	c.groupTile, c.groupTiles, c.groupScrap, c.groupHull = 1, c.Groups, 0, c.Groups
	if threadWork <= stripWork {
		c.cellTile, c.cellTiles, c.cellScrap, c.cellHull =
			split(c.cells2, cellWork)
		c.stripTile, c.stripTiles, c.stripScrap, c.stripHull =
			1, c.strips2, 0, c.strips2
	} else if threadWork <= chanWork {
		c.stripTile, c.stripTiles, c.stripScrap, c.stripHull =
			split(c.strips2, stripWork)
	} else if threadWork <= groupWork {
		c.chanTile, c.chanTiles, c.chanScrap, c.chanHull =
			split(c.toChans, chanWork)
	} else {
		c.chanTile, c.chanTiles, c.chanScrap, c.chanHull =
			c.toChans, 1, 0, 1
		c.groupTile, c.groupTiles, c.groupScrap, c.groupHull =
			split(c.Groups, groupWork)
	}

	c.calleeName = c.name(c.callerName + "Callee")
	team := vb(c.name("team"))
	tensors := vb(c.name("tensors"))
	callee := c.calleeFunc()
	caller := cgen.StaticFuncDef{
		ReturnType: cgen.Void,
		Name:       c.callerName,
		Params: cgen.CommaSpaced{
			cgen.Param{
				Type: c.tc.PtrTeam,
				What: team,
			},
			cgen.Param{
				Type: cgen.PtrPtrChar,
				What: tensors,
			},
		},
		Body: &threader.Do{
			Ctx:    c.tc,
			Callee: vb(c.calleeName),
			Any:    tensors,
			Hull: []cgen.Gen{
				il(c.cellHull),
				il(c.stripHull),
				il(c.chanHull),
				il(c.groupHull),
			},
			Team: team,
		},
	}
	return cgen.Gens{
		callee,
		cgen.Newline,
		caller,
	}.Append(to)
}

// calleeFunc builds the per-task function: it recovers the tensors array,
// declares the four task coordinates (read from the task point, or the
// constant 0 when the matching hull is 1), and appends kernel1.
func (c *consumeSums) calleeFunc() cgen.Gen {
	task := vb(c.name("task"))
	pt := vb(c.name("pt"))
	callee := &threader.Callee{
		Ctx:  c.tc,
		Name: c.calleeName,
		Task: task,
		Pt:   pt,
	}
	c.tensors = vb(c.name("tensors"))
	body := cgen.Stmts{
		cgen.Var{
			Type: cgen.PtrPtrChar,
			What: c.tensors,
			Init: callee.Any(),
		},
	}
	// axis is the position of the next coordinate within the point; it
	// advances for every coordinate, read or not.
	axis := 0
	ptRead := false
	coord := func(nm string, hull int) cgen.Gen {
		v := vb(c.name(nm))
		var val cgen.Gen
		if hull == 1 {
			val = il(0)
		} else {
			val = cgen.Elem{
				Arr: callee.Pt,
				Idx: il(axis),
			}
			ptRead = true
		}
		axis++
		body = append(body, cgen.Var{
			Type: cgen.PtrdiffT,
			What: v,
			Init: val,
		})
		return v
	}
	c.cellCoord = coord("cell", c.cellHull)
	c.stripCoord = coord("strip", c.stripHull)
	c.chanCoord = coord("chan", c.chanHull)
	c.groupCoord = coord("group", c.groupHull)
	if !ptRead {
		// Reference the point once so it is not reported as unused.
		body = append(body, cgen.Cast{
			Type: cgen.Void,
			Expr: callee.Pt,
		})
	}
	body = append(body, c.kernel1())
	return callee.Func(body)
}

// kernel1 binds consecutive entries of the tensors array to named restrict
// pointers: first the sums, then one pointer per Add input or Bn operand in
// the order the To ops list them, then enough extra data pointers to reach
// len(To.Pitch1Bytes). It finishes by appending kernel2.
func (c *consumeSums) kernel1() cgen.Gen {
	c.datPtrs, c.bnPtrs = nil, nil
	var body cgen.Stmts
	next := 0
	// bind declares ptr as the next unused tensor.
	bind := func(ptr cgen.Gen) {
		body = append(body, cgen.Var{
			Type: cgen.RestrictPtrChar,
			What: ptr,
			Init: cgen.Elem{
				Arr: c.tensors,
				Idx: il(next),
			},
		})
		next++
	}
	// bindDats declares n more data pointers (none when n <= 0).
	bindDats := func(n int) {
		for i := 0; i < n; i++ {
			ptr := vb(c.name("datPtr"))
			c.datPtrs = append(c.datPtrs, ptr)
			bind(ptr)
		}
	}
	c.sumPtr = vb(c.name("sumPtr"))
	bind(c.sumPtr)
	for i := range c.To.Ops {
		switch op := &c.To.Ops[i]; op.Kind {
		case mod.Add:
			bindDats(op.Int)
		case mod.Bn:
			ptr := vb(c.name("bnPtr"))
			c.bnPtrs = append(c.bnPtrs, ptr)
			bind(ptr)
		case mod.ReLU:
			// Takes no tensor.
		default:
			panic("bug")
		}
	}
	bindDats(len(c.To.Pitch1Bytes) - len(c.datPtrs))
	body = append(body, c.kernel2())
	return body
}

// kernel2 emits the loop over the groups handled by one task. The result
// always has three slots: the index declaration, an optional last-index
// declaration, and either a bare kernel3 body (exactly one iteration) or a
// for loop around it.
func (c *consumeSums) kernel2() cgen.Gen {
	c.groupIdx = vb(c.name("i"))
	first := cgen.Var{
		Type: cgen.PtrdiffT,
		What: c.groupIdx,
		Init: cgen.Mul{
			Expr1: il(c.groupTile),
			Expr2: c.groupCoord,
		},
	}
	// fixed is the iteration count when it is the same for every task;
	// it stays 0 when neither uniform case below applies.
	fixed := 0
	if c.groupTiles == c.groupHull {
		fixed = c.groupTile
	} else if c.groupTiles == 0 {
		fixed = c.groupScrap
	}
	if fixed == 1 {
		return cgen.Stmts{first, nil, c.kernel3()}
	}
	last := vb(c.name("ii"))
	var span cgen.Gen
	if fixed == 0 {
		// Pick the span at run time from the task coordinate.
		span = cgen.Paren{
			Inner: cgen.Ternary{
				Cond: cgen.CmpL{
					Expr1: c.groupCoord,
					Expr2: il(c.groupTiles),
				},
				Then: il(c.groupTile - 1),
				Else: il(c.groupScrap - 1),
			},
		}
	} else {
		span = il(fixed - 1)
	}
	lastDecl := cgen.Var{
		Type: cgen.PtrdiffT,
		What: last,
		Init: cgen.Add{
			Expr1: c.groupIdx,
			Expr2: span,
		},
	}
	loop := cgen.For{
		Cond: cgen.CmpLE{
			Expr1: c.groupIdx,
			Expr2: last,
		},
		Post: cgen.IncPre{
			Expr: c.groupIdx,
		},
		Body: c.kernel3(),
	}
	return cgen.Stmts{first, lastDecl, loop}
}

// kernel3 emits the loop over the channels handled by one task. The result
// always has three slots: the index declaration, an optional last-index
// declaration, and either a bare kernel4 body (exactly one iteration) or a
// for loop around it.
func (c *consumeSums) kernel3() cgen.Gen {
	c.chanIdx = vb(c.name("j"))
	first := cgen.Var{
		Type: cgen.PtrdiffT,
		What: c.chanIdx,
		Init: cgen.Mul{
			Expr1: il(c.chanTile),
			Expr2: c.chanCoord,
		},
	}
	// fixed is the iteration count when it is the same for every task;
	// it stays 0 when neither uniform case below applies.
	fixed := 0
	if c.chanTiles == c.chanHull {
		fixed = c.chanTile
	} else if c.chanTiles == 0 {
		fixed = c.chanScrap
	}
	if fixed == 1 {
		return cgen.Stmts{first, nil, c.kernel4()}
	}
	last := vb(c.name("jj"))
	var span cgen.Gen
	if fixed == 0 {
		// Pick the span at run time from the task coordinate.
		span = cgen.Paren{
			Inner: cgen.Ternary{
				Cond: cgen.CmpL{
					Expr1: c.chanCoord,
					Expr2: il(c.chanTiles),
				},
				Then: il(c.chanTile - 1),
				Else: il(c.chanScrap - 1),
			},
		}
	} else {
		span = il(fixed - 1)
	}
	lastDecl := cgen.Var{
		Type: cgen.PtrdiffT,
		What: last,
		Init: cgen.Add{
			Expr1: c.chanIdx,
			Expr2: span,
		},
	}
	loop := cgen.For{
		Cond: cgen.CmpLE{
			Expr1: c.chanIdx,
			Expr2: last,
		},
		Post: cgen.IncPre{
			Expr: c.chanIdx,
		},
		Body: c.kernel4(),
	}
	return cgen.Stmts{first, lastDecl, loop}
}

// kernel4 emits one bn.Load per batch-norm pointer, each producing a fresh
// multiplier/addend variable pair for channel (chanIdx + toChans*groupIdx),
// and then the kernel5 body.
func (c *consumeSums) kernel4() cgen.Gen {
	c.bnMuls, c.bnAdds = nil, nil
	channel := cgen.Paren{
		Inner: addMul(
			c.chanIdx,
			il(c.toChans),
			c.groupIdx,
		),
	}
	out := make(cgen.Gens, 0, len(c.bnPtrs)+1)
	for _, bnPtr := range c.bnPtrs {
		mul := vb(c.name("bnMul"))
		add := vb(c.name("bnAdd"))
		c.bnMuls = append(c.bnMuls, mul)
		c.bnAdds = append(c.bnAdds, add)
		out = append(out, &bn.Load{
			Ctx:     c.bc,
			Mas:     bnPtr,
			Channel: channel,
			Mul:     mul,
			Add:     add,
		})
	}
	return append(out, c.kernel5())
}

// kernel5 emits the loop over strips for one task: the starting index
// stripTile*strip, an optional last index (when stripHull > 1), a for loop
// over the first strips1 strips, and a trailing short body when
// strips1 < strips2. The result always has four slots; unused ones are nil.
func (c *consumeSums) kernel5() cgen.Gen {
	c.stripIdx = vb(c.name("k"))
	first := cgen.Var{
		Type: cgen.PtrdiffT,
		What: c.stripIdx,
		Init: cgen.Mul{
			Expr1: il(c.stripTile),
			Expr2: c.stripCoord,
		},
	}
	var lastDecl, stop, full, short cgen.Gen
	if c.stripHull > 1 {
		last := vb(c.name("kk"))
		// span is the distance from the first index to the last one.
		var span cgen.Gen
		if c.stripTiles == c.stripHull {
			span = il(c.stripTile - 1)
		} else if c.stripTiles == 0 {
			span = il(c.stripScrap - 1)
		} else {
			span = cgen.Paren{
				Inner: cgen.Ternary{
					Cond: cgen.CmpL{
						Expr1: c.stripCoord,
						Expr2: il(c.stripTiles),
					},
					Then: il(c.stripTile - 1),
					Else: il(c.stripScrap - 1),
				},
			}
		}
		lastDecl = cgen.Var{
			Type: cgen.PtrdiffT,
			What: last,
			Init: cgen.Add{
				Expr1: c.stripIdx,
				Expr2: span,
			},
		}
		stop = cgen.If1{
			Cond: cgen.CmpGE{
				Expr1: c.stripIdx,
				Expr2: last,
			},
			Then: cgen.Return{},
		}
	}
	if c.strips1 > 0 {
		c.shortH = false
		full = cgen.For{
			Cond: cgen.CmpNE{
				Expr1: c.stripIdx,
				Expr2: il(c.strips1),
			},
			Post: cgen.IncPre{
				Expr: c.stripIdx,
			},
			Body: cgen.Stmts{
				c.kernel6(),
				stop,
			},
		}
	}
	if c.strips1 < c.strips2 {
		c.shortH = true
		short = c.kernel6()
	}
	return cgen.Stmts{first, lastDecl, full, short}
}

// kernel6 emits the loop over cells for one task: the starting index
// cellTile*cell, an optional last index (when cellHull > 1), a for loop over
// the first cells1 cells, and a trailing short body when cells1 < cells2.
// The result always has four slots; unused ones are nil.
func (c *consumeSums) kernel6() cgen.Gen {
	c.cellIdx = vb(c.name("l"))
	first := cgen.Var{
		Type: cgen.PtrdiffT,
		What: c.cellIdx,
		Init: cgen.Mul{
			Expr1: il(c.cellTile),
			Expr2: c.cellCoord,
		},
	}
	var lastDecl, stop, full, short cgen.Gen
	if c.cellHull > 1 {
		last := vb(c.name("ll"))
		// span is the distance from the first index to the last one.
		var span cgen.Gen
		if c.cellTiles == c.cellHull {
			span = il(c.cellTile - 1)
		} else if c.cellTiles == 0 {
			span = il(c.cellScrap - 1)
		} else {
			span = cgen.Paren{
				Inner: cgen.Ternary{
					Cond: cgen.CmpL{
						Expr1: c.cellCoord,
						Expr2: il(c.cellTiles),
					},
					Then: il(c.cellTile - 1),
					Else: il(c.cellScrap - 1),
				},
			}
		}
		lastDecl = cgen.Var{
			Type: cgen.PtrdiffT,
			What: last,
			Init: cgen.Add{
				Expr1: c.cellIdx,
				Expr2: span,
			},
		}
		stop = cgen.If1{
			Cond: cgen.CmpGE{
				Expr1: c.cellIdx,
				Expr2: last,
			},
			Then: cgen.Return{},
		}
	}
	if c.cells1 > 0 {
		c.shortW = false
		full = cgen.For{
			Cond: cgen.CmpNE{
				Expr1: c.cellIdx,
				Expr2: il(c.cells1),
			},
			Post: cgen.IncPre{
				Expr: c.cellIdx,
			},
			Body: cgen.Stmts{
				c.kernel7(),
				stop,
			},
		}
	}
	if c.cells1 < c.cells2 {
		c.shortW = true
		short = c.kernel7()
	}
	return cgen.Stmts{first, lastDecl, full, short}
}

// kernel7 dispatches to the platform-specific kernel body. Only
// raw.AVX512Float32 is handled; any other platform panics.
func (c *consumeSums) kernel7() cgen.Gen {
	if c.platform == raw.AVX512Float32 {
		return c.m512()
	}
	panic("bug")
}

// m512 generates the AVX-512 body that reads the partial sums for one cell,
// combines them across piles, applies the To ops, and stores the result.
// Statements are generated separately for each row and then interleaved
// with mix.
func (c *consumeSums) m512() cgen.Gen {
var (
rows int
cols int
rowIdx int
stmts cgen.Stmts
out cgen.Gen
)
stmt := func(st cgen.Gen) {
stmts = append(stmts, st)
}
// layer4: apply the To ops to out in order, then store out through every
// data pointer that the Add ops did not consume. Loads and stores are masked
// to the low cols lanes.
layer4 := func() {
var (
datPtr = 0
mask = loMask(cols)
bnPtr = 0
)
// ae builds the address for the current data pointer at the current
// row, group, channel, strip and cell.
ae := func() cgen.Gen {
var (
ret = c.datPtrs[datPtr]
pitch1 = c.To.Pitch1Bytes[datPtr]
pitch2 = c.To.Pitch2Bytes[datPtr]
groupPitch = c.toChans * pitch2
stripPitch = c.datSliceVecs * pitch1
)
ret = cgen.Add{
Expr1: ret,
Expr2: il(rowIdx * pitch1),
}
ret = addMul(ret, il(groupPitch), c.groupIdx)
ret = addMul(ret, il(pitch2), c.chanIdx)
ret = addMul(ret, il(stripPitch), c.stripIdx)
ret = addMul(ret, il(c.datVecBytes), c.cellIdx)
return ret
}
for op := range c.To.Ops {
op := &c.To.Ops[op]
switch op.Kind {
case mod.Add:
for n := op.Int; n > 0; n-- {
stmt(cgen.Assign{
Expr1: out,
Expr2: avx.Mm512AddPs{
out,
avx.Mm512MaskzLoaduPs{
mask, ae(),
},
},
})
datPtr++
}
case mod.Bn:
stmt(&bn.Apply{
Ctx: c.bc,
Mul: c.bnMuls[bnPtr],
Add: c.bnAdds[bnPtr],
To: out,
})
bnPtr++
case mod.ReLU:
stmt(&act.ReLU{
Ctx: c.ac,
NegSlope: op.Float,
Var: out,
})
default:
panic("bug")
}
}
for datPtr < len(c.datPtrs) {
stmt(avx.Mm512MaskStoreuPs{
ae(), mask, out,
})
datPtr++
}
}
// layer3: load and sum the piles for the current row into out, then run
// layer4.
layer3 := func() {
var (
trees []cgen.Gen
)
// load declares a vector read from the sum pointer at the given core
// offset and pile.
load := func(coreOff, pileIdx int) cgen.Gen {
var (
ae = c.sumPtr
stripPitch = c.blkStep * c.sumCoreBytes
ret = vb(c.name("load"))
)
ae = cgen.Add{
Expr1: ae,
Expr2: il(
coreOff*c.sumCoreBytes +
pileIdx*c.sumPileBytes +
rowIdx*c.datVecBytes,
),
}
ae = addMul(ae, il(c.sumGroupBytes), c.groupIdx)
ae = addMul(ae, il(stripPitch), c.stripIdx)
ae = addMul(ae, il(c.sumCoreBytes), c.cellIdx)
ae = addMul(ae, il(c.datSliceBytes), c.chanIdx)
stmt(cgen.Var{
Type: avx.M512, What: ret,
Init: avx.Mm512LoaduPs{ae},
})
return ret
}
// cast declares an integer-vector view of a float vector.
cast := func(ps cgen.Gen) cgen.Gen {
ret := vb(c.name("cast"))
stmt(cgen.Var{
Type: avx.M512i, What: ret,
Init: avx.Mm512CastpsSi512{ps},
})
return ret
}
// join declares the alignr of hi and lo with the given count. When hi is
// nil, lo is used for both operands.
join := func(hi, lo cgen.Gen, drop int) cgen.Gen {
var (
ret = vb(c.name("join"))
castLo = cast(lo)
castHi = castLo
)
if hi != nil {
castHi = cast(hi)
}
stmt(cgen.Var{
Type: avx.M512, What: ret,
Init: avx.Mm512Castsi512Ps{
avx.Mm512AlignrEpi32{
castHi, castLo,
il(drop),
},
},
})
return ret
}
// add declares the sum of two vectors.
add := func(older, newer cgen.Gen) cgen.Gen {
ret := vb(c.name("add"))
stmt(cgen.Var{
Type: avx.M512, What: ret,
Init: avx.Mm512AddPs{
older, newer,
},
})
return ret
}
// sublayer3: fold whatever remains in trees into out.
sublayer3 := func() {
out = nil
for _, tree := range trees {
if tree != nil {
switch out {
case nil:
out = tree
default:
out = add(tree, out)
}
}
}
}
// sublayer2: for each pile, load its vector (joined with the next core's
// vector when the shift is not a multiple of 16 lanes) and merge it into
// trees: an occupied slot is added in and cleared, and the running sum is
// parked in the first empty slot.
sublayer2 := func() {
coreOff := 0
for pileIdx, shift := range c.shifts {
at := coreOff*16 - shift
for ; at <= -16; at += 16 {
coreOff++
}
tree1 := load(coreOff, pileIdx)
if at < 0 {
var hi cgen.Gen
if at+16 < cols {
hi = load(coreOff+1, pileIdx)
}
tree1 = join(hi, tree1, -at)
}
for x, tree2 := range trees {
if tree2 == nil {
trees[x] = tree1
break
}
tree1 = add(tree2, tree1)
trees[x] = nil
}
}
sublayer3()
}
// sublayer1: size trees to the number of bits needed to represent
// len(c.shifts).
sublayer1 := func() {
var (
n1 = len(c.shifts)
n2 = 0
)
for ; n1 > 0; n1 >>= 1 {
n2++
}
trees = make([]cgen.Gen, n2)
sublayer2()
}
sublayer1()
layer4()
}
// layer2: generate each row's statements into a fresh list, then interleave
// the lists.
layer2 := func() cgen.Gen {
toMix := make([]cgen.Stmts, rows)
for x := range toMix {
rowIdx = x
stmts = nil
layer3()
toMix[x] = stmts
}
return mix(toMix)
}
// layer1: rows is datSliceVecs, or remH for the short strip; cols is
// datVecDats, or remW for the short cell.
layer1 := func() cgen.Gen {
rows = c.datSliceVecs
if c.shortH {
rows = c.remH
}
cols = c.datVecDats
if c.shortW {
cols = c.remW
}
return layer2()
}
return layer1()
}

Top || internal/compile/author/mod/mod.go

package mod

// Kind identifies which modification an Op performs.
type Kind int

// The Kind values, numbered from zero in declaration order.
const (
Add Kind = iota
Bn
ReLU
)

// Op is one modification together with its parameters. Which of Int and
// Float is meaningful depends on the embedded Kind.
type Op struct {
Kind
Int int
Float float32
}

Top || internal/compile/author/net/net.go

package net

import (
"NN-512/internal/compile/author/cgen"
"NN-512/internal/compile/plan"
"NN-512/internal/nmsrc"
"NN-512/internal/raw"
)

// vb is shorthand for cgen.Vb.
func vb(s string) cgen.Gen {
return cgen.Vb(s)
}

// Ctx holds the names used when generating the Net struct and its create
// and destroy functions. All fields are filled in by NewCtx.
type Ctx struct {
StructName string
StructAlloc string
StructAlign string
Alignment int
paramsName string
createName string
CreateNet cgen.Gen
CreateParams cgen.Gen
CreateThreads cgen.Gen
destroyName string
destroyNet cgen.Gen
}

// NewCtx builds the Ctx for a plan. The struct and function names derive
// from the plan's prefix; field and parameter names are drawn from nms.
// It panics on any platform other than raw.AVX512Float32.
func NewCtx(pl *plan.Plan, nms nmsrc.Src, paramsName string) *Ctx {
	if pl.Config.Platform != raw.AVX512Float32 {
		panic("bug")
	}
	base := pl.Config.Prefix + "Net"
	ctx := &Ctx{
		StructName:  base,
		Alignment:   64,
		paramsName:  paramsName,
		createName:  base + "Create",
		destroyName: base + "Destroy",
	}
	// Names are requested from nms in a fixed order.
	ctx.StructAlloc = nms.Name("alloc")
	ctx.StructAlign = nms.Name("align")
	ctx.CreateNet = vb(nms.Name("net"))
	ctx.CreateParams = vb(nms.Name("params"))
	ctx.CreateThreads = vb(nms.Name("threads"))
	ctx.destroyNet = vb(nms.Name("net"))
	return ctx
}

// Comment returns the explanatory comment for the Net: a description
// followed by a usage example that substitutes the params type name, the
// struct name, and the create and destroy function names.
func (c *Ctx) Comment() cgen.Gen {
const (
space = " "
indent = space + space + space + space
)
return cgen.Comment{
`The Net contains weights, biases, and other trained parameters in a`,
`form that enables efficient inference. It is created from the input`,
`parameter struct without modifying that struct. The input parameter`,
`struct is no longer needed once the Net has been created. Threads`,
`that are used to create the Net are temporary (in particular, those`,
`threads are not used for inference).`,
``,
indent + c.paramsName + `* params = malloc(sizeof(` + c.paramsName + `));`,
``,
indent + `... Load params (read from a file, perhaps) ...`,
``,
indent + c.StructName + `* net; // For example, 4 threads:`,
indent + `char* err = ` + c.createName + `(&net, params, 4);`,
indent + `free(params);`,
``,
indent + `if (err) { // Nonzero err indicates failure; net is unmodified.`,
indent + indent + `printf("%s\n", err); // Explain the failure, add a newline.`,
indent + indent + `free(err); // Free the error string to avoid a memory leak.`,
indent + indent + `exit(1); // Exit, or propagate the failure some other way.`,
indent + `}`,
``,
indent + `... Perform all inference that depends on net ...`,
``,
indent + c.destroyName + `(net);`,
``,
`The Net can be shared and reused without restriction because it is`,
`never modified (not even temporarily) after being created. The Net`,
`should be destroyed (to free memory) once all dependent inference`,
`is complete.`,
}
}

// StructFwd returns the forward declaration of the Net struct.
func (c *Ctx) StructFwd() cgen.Gen {
return cgen.StructFwd(c.StructName)
}

// StructDef returns the definition of the Net struct, which has two char
// pointer fields named by StructAlloc and StructAlign.
func (c *Ctx) StructDef() cgen.Gen {
	alloc := cgen.Field{
		Type: cgen.PtrChar,
		What: vb(c.StructAlloc),
	}
	align := cgen.Field{
		Type: cgen.PtrChar,
		What: vb(c.StructAlign),
	}
	return cgen.StructDef{
		Name:   c.StructName,
		Fields: cgen.Stmts{alloc, align},
	}
}

// CreateDecl returns the declaration of the create function. It returns a
// char pointer and takes a pointer to a Net pointer, a pointer to the
// params struct, and a ptrdiff_t named "threads".
func (c *Ctx) CreateDecl() cgen.Gen {
	netOut := cgen.Ptr{
		Type: cgen.Ptr{
			Type: vb(c.StructName),
		},
	}
	params := cgen.Ptr{
		Type: vb(c.paramsName),
	}
	threads := cgen.Param{
		Type: cgen.PtrdiffT,
		What: vb("threads"),
	}
	return cgen.FuncDecl{
		ReturnType: cgen.PtrChar,
		Name:       c.createName,
		Params:     cgen.CommaLines{netOut, params, threads},
	}
}

// CreateDef returns the definition of the create function with the given
// body. Its parameters are named by CreateNet, CreateParams and
// CreateThreads.
func (c *Ctx) CreateDef(body cgen.Gen) cgen.Gen {
	netOut := cgen.Param{
		Type: cgen.Ptr{
			Type: cgen.Ptr{
				Type: vb(c.StructName),
			},
		},
		What: c.CreateNet,
	}
	params := cgen.Param{
		Type: cgen.Ptr{
			Type: vb(c.paramsName),
		},
		What: c.CreateParams,
	}
	threads := cgen.Param{
		Type: cgen.PtrdiffT,
		What: c.CreateThreads,
	}
	return cgen.FuncDef{
		ReturnType: cgen.PtrChar,
		Name:       c.createName,
		Params:     cgen.CommaLines{netOut, params, threads},
		Body:       body,
	}
}

// DestroyDecl returns the declaration of the destroy function, which
// returns void and takes a pointer to the Net struct.
func (c *Ctx) DestroyDecl() cgen.Gen {
	net := cgen.Ptr{
		Type: vb(c.StructName),
	}
	return cgen.FuncDecl{
		ReturnType: cgen.Void,
		Name:       c.destroyName,
		Params:     net,
	}
}

// DestroyDef returns the definition of the destroy function. The generated
// body frees the net's StructAlloc field and then the net itself.
func (c *Ctx) DestroyDef() cgen.Gen {
	net := cgen.Param{
		Type: cgen.Ptr{
			Type: vb(c.StructName),
		},
		What: c.destroyNet,
	}
	freeAlloc := cgen.Call{
		Func: cgen.Free,
		Args: cgen.Arrow{
			Expr: c.destroyNet,
			Name: c.StructAlloc,
		},
	}
	freeNet := cgen.Call{
		Func: cgen.Free,
		Args: c.destroyNet,
	}
	return cgen.FuncDef{
		ReturnType: cgen.Void,
		Name:       c.destroyName,
		Params:     net,
		Body:       cgen.Stmts{freeAlloc, freeNet},
	}
}

Top || internal/compile/author/one/one.go

package one

import (
"NN-512/internal/compile/author/act"
"NN-512/internal/compile/author/avx"
"NN-512/internal/compile/author/bn"
"NN-512/internal/compile/author/cgen"
"NN-512/internal/compile/author/mod"
"NN-512/internal/compile/author/threader"
"NN-512/internal/compile/author/trans"
"NN-512/internal/compile/plan"
"NN-512/internal/nmsrc"
"NN-512/internal/raw"
"fmt"
)

// btoi converts a bool to an int: 1 for true, 0 for false.
func btoi(b bool) int {
	n := 0
	if b {
		n = 1
	}
	return n
}

// min returns the smaller of x and y.
func min(x, y int) int {
	if y < x {
		return y
	}
	return x
}

// max returns the larger of x and y.
func max(x, y int) int {
	if y > x {
		return y
	}
	return x
}

// ceilQuo returns (n + d - 1) / d, which is n/d rounded up when n is
// non-negative and d is positive.
func ceilQuo(n, d int) int {
	padded := n + (d - 1)
	return padded / d
}

// vb is shorthand for cgen.Vb.
func vb(s string) cgen.Gen {
return cgen.Vb(s)
}

// il is shorthand for cgen.IntLit.
func il(i int) cgen.Gen {
return cgen.IntLit(i)
}

// loMask returns the integer literal whose low n bits are set.
func loMask(n int) cgen.Gen {
	bits := (1 << uint(n)) - 1
	return il(bits)
}

// cast returns the integer literal i wrapped in a cast to ptrdiff_t.
func cast(i int) cgen.Gen {
	lit := il(i)
	return cgen.Cast{
		Expr: lit,
		Type: cgen.PtrdiffT,
	}
}

// addMul returns the expression x + y*z.
func addMul(x, y, z cgen.Gen) cgen.Gen {
	product := cgen.Mul{
		Expr1: y,
		Expr2: z,
	}
	return cgen.Add{
		Expr1: x,
		Expr2: product,
	}
}

// mix interleaves several statement lists round-robin: first every list's
// statement 0, then every list's statement 1, and so on, skipping lists
// that have run out. A single list is returned unchanged.
func mix(a []cgen.Stmts) cgen.Stmts {
	if len(a) == 1 {
		return a[0]
	}
	total := 0
	for _, list := range a {
		total += len(list)
	}
	merged := make(cgen.Stmts, 0, total)
	for round := 0; len(merged) < total; round++ {
		for _, list := range a {
			if round < len(list) {
				merged = append(merged, list[round])
			}
		}
	}
	return merged
}

// Ctx carries the settings and shared helpers used by this package's code
// generators. All fields are filled in by NewCtx. dedup is keyed by a
// printed signature so generated functions can be reused.
type Ctx struct {
prefix string
platform raw.Platform
cacheBytes1 int
cacheBytes2 int
nms nmsrc.Src
tc *threader.Ctx
ac *act.Ctx
bc *bn.Ctx
dedup map[string]interface{}
}

// NewCtx builds a Ctx from the plan's configuration and the given helper
// contexts. The name prefix is the plan's prefix followed by "One", and
// the dedup map starts empty.
func NewCtx(pl *plan.Plan, nms nmsrc.Src, tc *threader.Ctx, ac *act.Ctx, bc *bn.Ctx) *Ctx {
	ctx := new(Ctx)
	ctx.prefix = pl.Config.Prefix + "One"
	ctx.platform = pl.Config.Platform
	ctx.cacheBytes1 = pl.Config.L1DataCachePerThread
	ctx.cacheBytes2 = pl.Config.L2CachePerThreadExL1
	ctx.nms = nms
	ctx.tc = tc
	ctx.ac = ac
	ctx.bc = bc
	ctx.dedup = make(map[string]interface{})
	return ctx
}

// name asks the context's name source for a name based on s.
func (c *Ctx) name(s string) string {
return c.nms.Name(s)
}

// Spec describes one operation to generate code for: its source, filters
// and destination, plus stride, padding and group count. Field order is
// kept as is because the printed form of a Spec is used as a dedup key.
type Spec struct {
	From               SpecFrom
	Filts              []SpecFilts
	To                 SpecTo
	StrideH, StrideW   int
	PaddingH, PaddingW int
	Groups             int
}

// SpecFrom describes the source side of a Spec: its dimensions, the byte
// pitches of its tensors, and the ops applied to it.
type SpecFrom struct {
	Chans, Height, Width     int
	Pitch1Bytes, Pitch2Bytes []int
	Ops                      []mod.Op
}

// SpecFilts describes one set of filters in a Spec: a count and two
// batch-norm related integers.
type SpecFilts struct {
	Cnt, BnPre, BnPost int
}

// SpecTo describes the destination side of a Spec: the byte pitches of its
// tensors and the ops applied to it.
type SpecTo struct {
	Pitch1Bytes, Pitch2Bytes []int
	Ops                      []mod.Op
}

// ctxSpec bundles a Ctx with a Spec so both sets of fields are reachable
// through one value.
type ctxSpec struct {
*Ctx
*Spec
}

// tokens is the result of tokenize: the total token count (IdxPast) and
// the sections the tokens were grouped into.
type tokens struct {
IdxPast int
Sects []*section
}

// section is a run of tokens covering indices [IdxFirst, IdxPast). Uniqs
// holds the tokens actually stored for the run, with their H positions made
// relative to FromBase and ToBase. tokenize sets FromWrap and ToWrap only
// for a section in which it detected a repeat.
type section struct {
	IdxFirst, IdxPast  int
	FromBase, FromWrap int
	ToBase, ToWrap     int
	Uniqs              []*token
}

// token pairs the commands for the source side (From) with those for the
// destination side (To). Slots counts the slots handed out for this token.
type token struct {
From tokenFrom
To tokenTo
Slots int
}

// tokenFrom is the source side of a token: the positions it starts and
// ends at, and the commands to run.
type tokenFrom struct {
	FirstH, FirstW, LastH int
	Cmds                  []interface{}
}

// tokenTo is the destination side of a token: the H position it starts at
// and the commands to run.
type tokenTo struct {
FirstH int
Cmds []interface{}
}

// tokenize produces the platform's tokens for cs and groups them into
// sections. Tokens are compared by a signature (FirstW, the padding zones of
// their first and last H, and their H span). When a signature repeats, the
// tokens from its first occurrence onward form a looping section that stores
// only the tokens before the repeat. It panics on any platform other than
// raw.AVX512Float32.
func tokenize(cs ctxSpec, slots int) *tokens {
var (
ret tokens
toks []*token
)
switch cs.platform {
case raw.AVX512Float32:
toks = m512Toks(cs, slots)
default:
panic("bug")
}
ret.IdxPast = len(toks)
// put stores toks[i:j] in the section and rebases their H positions onto
// the section's bases. The tokens are shared pointers, so this modifies the
// entries of toks in place.
put := func(s *section, i, j int) {
s.Uniqs = make([]*token, j-i)
copy(s.Uniqs, toks[i:j])
for _, tok := range s.Uniqs {
tok.From.FirstH -= s.FromBase
tok.From.LastH -= s.FromBase
tok.To.FirstH -= s.ToBase
}
ret.Sects = append(
ret.Sects, s,
)
}
// start..stop is the range not yet emitted. loop is the index where a
// repeat begins and tie the index of the first repeated token; both are -1
// while no repeat is pending.
var (
start = 0
loop = -1
tie = -1
stop = 0
)
// encode emits the pending range: a plain section for [start, loop) and,
// if a repeat was seen, a looping section for [loop, stop) whose wrap
// distances are the H offsets between the tie token and the loop token.
encode := func() {
if loop == -1 {
loop = stop
}
if start < loop {
sect := &section{
IdxFirst: start,
IdxPast: loop,
FromBase: toks[start].From.FirstH,
ToBase: toks[start].To.FirstH,
}
put(sect, start, loop)
}
if loop < stop {
var (
loopTok = toks[loop]
tieTok = toks[tie]
fromBase = loopTok.From.FirstH
toBase = loopTok.To.FirstH
)
sect := &section{
IdxFirst: loop,
IdxPast: stop,
FromBase: fromBase,
FromWrap: tieTok.From.FirstH - fromBase,
ToBase: toBase,
ToWrap: tieTok.To.FirstH - toBase,
}
put(sect, loop, tie)
}
start = stop
loop = -1
tie = -1
}
// zone classifies h as before (0), inside (1) or after (2) the unpadded
// height.
zone := func(h int) int {
switch {
case h < cs.PaddingH:
return 0
case h < cs.PaddingH+cs.From.Height:
return 1
}
return 2
}
type Sig struct {
FirstW int
Zones int
Span int
}
// idx maps each signature to the index of the first token that had it.
idx := make(map[Sig]int)
for i, tok := range toks {
var (
firstH = tok.From.FirstH
lastH = tok.From.LastH
)
sig := Sig{
FirstW: tok.From.FirstW,
Zones: zone(lastH)<<2 | zone(firstH),
Span: lastH - firstH,
}
if at, ok := idx[sig]; ok {
// A repeat: remember where the loop starts, unless one is already
// pending.
if loop == -1 {
loop = at
tie = i
}
} else {
// A new signature ends any pending loop.
idx[sig] = i
if loop != -1 {
encode()
}
}
stop++
}
encode()
return &ret
}

// m512CmdZero is the command m512Toks emits for a value built entirely
// from padding; Id is the fresh source-side id it defines.
type m512CmdZero struct {
Id int
}

// m512CmdCopy is a command naming a destination id and a source id.
type m512CmdCopy struct {
	DstId, SrcId int
}

// m512CmdRotate is a command naming a destination id, a source id and a
// count. m512Toks emits it after a load whose first lane is not lane 0,
// with the same id as source and destination.
type m512CmdRotate struct {
	DstId, SrcId, Cnt int
}

// m512CmdBlend is a command naming a destination id, a source id, a lane
// offset and a lane count.
type m512CmdBlend struct {
	DstId, SrcId, Off, Cnt int
}

// m512CmdPermute1 is a command naming a destination id, one source id, a
// lane offset and an increment. m512Toks sets Inc to the W stride.
type m512CmdPermute1 struct {
	DstId, SrcId, Off, Inc int
}

// m512CmdPermute2 is a command naming a destination id, two source ids, a
// lane offset and an increment. m512Toks sets Inc to the W stride.
type m512CmdPermute2 struct {
	DstId, SrcId1, SrcId2, Off, Inc int
}

// m512CmdLoad is a source-side command that defines Id from the position
// (RelH, W), where RelH is relative to the token's first source H.
// m512Toks passes a Cnt of 0 for a broadcast.
type m512CmdLoad struct {
	Id, RelH, W, Cnt int
}

// m512CmdFromModAddPre is the source-side command that m512Toks emits
// immediately after every m512CmdLoad, carrying the same field values.
type m512CmdFromModAddPre struct {
	Id, RelH, W, Cnt int
}

// m512CmdFromModPostAdd is a source-side command naming an id and a lane
// mask. m512Toks emits it for the current pile only when the mask is
// nonzero.
type m512CmdFromModPostAdd struct {
	Id, Mask int
}

// m512CmdSlotPut is a source-side command that places the value Id into
// slot Slot.
type m512CmdSlotPut struct {
	Slot, Id int
}

// m512CmdSlotGet is a destination-side command that defines Id from slot
// Slot.
type m512CmdSlotGet struct {
	Id, Slot int
}

// m512CmdToModPreAdd is the destination-side command that m512Toks emits
// immediately after every m512CmdSlotGet, for the same id.
type m512CmdToModPreAdd struct {
Id int
}

// m512CmdToModAddPost is the destination-side command that m512Toks emits
// immediately before every m512CmdStore, carrying the same field values.
type m512CmdToModAddPost struct {
	Id, RelH, W, Cnt int
}

// m512CmdStore is a destination-side command that writes the value Id at
// the position (RelH, W), where RelH is relative to the token's first
// destination H.
type m512CmdStore struct {
	Id, RelH, W, Cnt int
}

// m512Toks walks the strided, padded output plane described by cs and
// returns the resulting tokens in visiting order.
//
// Output positions are visited row by row in runs of at most 16 (lanes)
// consecutive columns. For every run, commands are appended to the current
// token's From.Cmds and To.Cmds lists. A token is closed when it already
// uses `slots` slots and the next run does not fit into the free lanes of
// the current pile; the next run then opens a fresh token.
func m512Toks(cs ctxSpec, slots int) (ret []*token) {
const lanes = 16
var (
strideH = cs.StrideH
strideW = cs.StrideW
padH = cs.PaddingH
padW = cs.PaddingW
// edgeH/edgeW: first padded coordinate past the real data.
edgeH = padH + cs.From.Height
edgeW = padW + cs.From.Width
// fromH/fromW: extent of the source including padding on both sides.
fromH = edgeH + padH
fromW = edgeW + padW
// toH/toW: extent of the strided output.
toH = 1 + (fromH-1)/strideH
toW = 1 + (fromW-1)/strideW
// h, w: current output position; n: length of the current run.
h = 0
w = 0
n int
// hh, ww: padded source coordinates matching (h, w).
hh int
ww int
tok *token
// Per-token id counters and "pile" state. A pile is a from-side value
// that collects several short runs before being put into one slot.
fromNextId int
fromPileId int
fromPileFree int
fromPileMask int
fromPileSlot int
toNextId int
toPileId int
// Per-run counts of leading padding, trailing padding and real data.
loPad int
hiPad int
nonPad int
)
// fromCmd appends a command to the current token's from side.
fromCmd := func(cmd interface{}) {
tok.From.Cmds = append(
tok.From.Cmds, cmd,
)
}
// toCmd appends a command to the current token's to side.
toCmd := func(cmd interface{}) {
tok.To.Cmds = append(
tok.To.Cmds, cmd,
)
}
// newSlot reserves the next slot number of the current token.
newSlot := func() int {
slot := tok.Slots
tok.Slots++
return slot
}
// fromNewId / toNewId hand out consecutive value ids per side.
fromNewId := func() int {
id := fromNextId
fromNextId++
return id
}
toNewId := func() int {
id := toNextId
toNextId++
return id
}
// zero emits an m512CmdZero for a fresh from-side id.
zero := func() int {
id := fromNewId()
fromCmd(&m512CmdZero{
Id: id,
})
return id
}
// load emits a load of cnt elements at source column `at` of the current
// source row, immediately followed by the matching FromModAddPre command.
load := func(at, cnt int) int {
var (
id = fromNewId()
relH = hh - tok.From.FirstH
)
fromCmd(&m512CmdLoad{
Id: id,
RelH: relH,
W: at,
Cnt: cnt,
})
fromCmd(&m512CmdFromModAddPre{
Id: id,
RelH: relH,
W: at,
Cnt: cnt,
})
return id
}
// broadcast is a load with a count of 0.
broadcast := func(at int) int {
return load(at, 0)
}
// pilePut flushes the current pile (if any) into its slot, preceded by a
// FromModPostAdd when at least one lane holds non-padding data.
pilePut := func() {
if fromPileId == -1 {
return
}
if fromPileMask != 0 {
fromCmd(&m512CmdFromModPostAdd{
Id: fromPileId,
Mask: fromPileMask,
})
}
fromCmd(&m512CmdSlotPut{
Slot: fromPileSlot,
Id: fromPileId,
})
}
// slotGet emits the to-side read of a slot plus its ToModPreAdd command.
slotGet := func(slot int) int {
id := toNewId()
toCmd(&m512CmdSlotGet{
Id: id,
Slot: slot,
})
toCmd(&m512CmdToModPreAdd{
Id: id,
})
return id
}
// store emits ToModAddPost followed by Store for the current run.
store := func(id int) {
relH := h - tok.To.FirstH
toCmd(&m512CmdToModAddPost{
Id: id,
RelH: relH,
W: w,
Cnt: n,
})
toCmd(&m512CmdStore{
Id: id,
RelH: relH,
W: w,
Cnt: n,
})
}
// build emits the from-side commands that produce the value for the
// current run and returns its id.
build := func() int {
switch {
case nonPad == 0:
return zero()
case n == 1:
return broadcast(ww)
}
var (
// lane: destination lane of the first non-padding element.
lane = loPad
// at: source column of the first non-padding element.
at = ww + loPad*strideW
)
// When the run will be merged into the current pile, place it directly
// at the pile's first free lane.
if n <= fromPileFree {
lane += lanes - fromPileFree
}
// Contiguous case: one load, rotated into position if necessary.
if strideW == 1 || nonPad == 1 {
id := load(at, nonPad)
if lane > 0 {
fromCmd(&m512CmdRotate{
DstId: id,
SrcId: id,
Cnt: lanes - lane,
})
}
return id
}
var (
id = -1
// each: number of strided elements covered by one load of `lanes`.
each = 1 + (lanes-1)/strideW
take int
)
// Strided case: gather up to 2*each elements per step with one or two
// loads and a permute, blending the steps together.
for have := 0; have < nonPad; have += take {
take = min(nonPad-have, 2*each)
var (
off = lane + have
lower = at + have*strideW
tight int
)
switch {
case take == 1:
tight = broadcast(lower)
case take <= each:
var (
span = 1 + (take-1)*strideW
loose = load(lower, span)
)
tight = fromNewId()
fromCmd(&m512CmdPermute1{
DstId: tight,
SrcId: loose,
Off: off,
Inc: strideW,
})
default:
var (
upper = lower + each*strideW
span1 = 1 + (each-1)*strideW
span2 = 1 + (take-each-1)*strideW
loose1 = load(lower, span1)
loose2 = load(upper, span2)
)
tight = fromNewId()
fromCmd(&m512CmdPermute2{
DstId: tight,
SrcId1: loose1,
SrcId2: loose2,
Off: off,
Inc: strideW,
})
}
if id == -1 {
// Without padding the first step's result is used as the base;
// otherwise blending starts from a zero value.
if loPad+hiPad == 0 {
id = tight
continue
}
id = zero()
}
fromCmd(&m512CmdBlend{
DstId: id,
SrcId: tight,
Off: off,
Cnt: take,
})
}
return id
}
// encode emits all commands for the current run of n output positions.
encode := func() {
// Rows that lie entirely in the H padding count as all-padding.
loPad = n
hiPad = 0
if hh >= padH && hh < edgeH {
var (
lo = padW - ww
nn = 1 + (n-1)*strideW
hi = ww + nn - edgeW
)
loPad = 0
if lo > 0 {
loPad = 1 + (lo-1)/strideW
loPad = min(loPad, n)
}
if hi > 0 {
hiPad = 1 + (hi-1)/strideW
hiPad = min(hiPad, n)
}
}
nonPad = n - loPad - hiPad
var (
fromId = build()
// mask2: one bit per non-padding lane, relative to the run's start.
mask1 = 1<<uint(nonPad) - 1
mask2 = mask1 << uint(loPad)
)
switch {
case n == lanes:
// A full-width run gets a slot of its own.
if mask2 != 0 {
fromCmd(&m512CmdFromModPostAdd{
Id: fromId,
Mask: mask2,
})
}
slot := newSlot()
fromCmd(&m512CmdSlotPut{
Slot: slot,
Id: fromId,
})
toId := slotGet(slot)
store(toId)
case n <= fromPileFree:
// The run fits into the free lanes of the current pile.
off := lanes - fromPileFree
fromCmd(&m512CmdBlend{
DstId: fromPileId,
SrcId: fromId,
Off: off,
Cnt: n,
})
fromPileFree -= n
fromPileMask |= mask2 << uint(off)
toId := toNewId()
toCmd(&m512CmdRotate{
DstId: toId,
SrcId: toPileId,
Cnt: off,
})
store(toId)
default:
// Flush the previous pile and start a new one with this run.
pilePut()
fromPileId = fromId
fromPileFree = lanes - n
fromPileMask = mask2
fromPileSlot = newSlot()
toPileId = slotGet(fromPileSlot)
toId := toNewId()
toCmd(&m512CmdCopy{
DstId: toId,
SrcId: toPileId,
})
store(toId)
}
}
// Outer loop: one iteration per token. Inner loop: one iteration per run.
for {
for {
n = toW - w
if n == 0 {
// End of row: advance, and stop once every row is done.
if h++; h == toH {
break
}
w = 0
n = toW
}
n = min(n, lanes)
hh = h * strideH
ww = w * strideW
if tok == nil {
tok = &token{
From: tokenFrom{
FirstH: hh,
FirstW: ww,
},
To: tokenTo{
FirstH: h,
},
}
fromNextId = 0
fromPileId = -1
fromPileFree = 0
toNextId = 0
}
// All slots used: continue only while runs still fit in the pile.
if tok.Slots == slots {
if n > fromPileFree {
break
}
}
tok.From.LastH = hh
encode()
w += n
}
if tok != nil {
pilePut()
ret = append(ret, tok)
tok = nil
}
if h == toH {
break
}
}
return
}

// layout holds the sizes and counts that newLayout derives for one
// ctxSpec. Fields come in "1"/"2" pairs: the "1" variant describes the
// regular unit and the "2" variant the differently sized trailing unit (a
// partial slice, core or epoch). Two-digit suffixes on the core sizes
// combine the slice-count variant (first digit) with the slice-size
// variant (second digit).
type layout struct {
	// Channels per group on the source and destination side.
	fromChans, toChans int
	// Slices per epoch and number of epochs.
	slices1, slices2 int
	epochs1, epochs2 int
	// Arranged-weight geometry.
	wtBytes                                                    int
	wtSliceWts1, wtSliceWts2                                   int
	wtSliceBytes1, wtSliceBytes2                               int
	wtCores1, wtCores2                                         int
	wtCoreBytes11, wtCoreBytes12, wtCoreBytes21, wtCoreBytes22 int
	wtGroupBytes1, wtGroupBytes2                               int
	wtEpochBytes1, wtEpochBytes2                               int
	wtTotalBytes                                               int
	// Arranged-data geometry.
	datBytes                                                       int
	slotDats, slotBytes                                            int
	datSliceSlots1, datSliceSlots2                                 int
	datSliceDats1, datSliceDats2                                   int
	datSliceBytes1, datSliceBytes2                                 int
	datCores1, datCores2                                           int
	datCoreBytes11, datCoreBytes12, datCoreBytes21, datCoreBytes22 int
	datGroupBytes1, datGroupBytes2                                 int
	datEpochBytes1, datEpochBytes2                                 int
	datTotalBytes                                                  int
	// Tokenization of the data plane; only set by newLayout on the
	// non-special (strided, padded or non-tight pitch) path.
	toks *tokens
}

// newLayout derives the weight and data layout for cs and returns it.
//
// The work is split into closures stage1..stage7 that are declared in
// reverse order; stage1 is the entry point and every stage finishes by
// calling the next one, so all of them read and write the shared layout y.
// Unsupported platforms panic with "bug".
func newLayout(cs ctxSpec) *layout {
var y layout
// special reports whether the data can take the simple path: unit
// strides, no padding, and every source and destination row pitch equal
// to the tight row size.
special := func() bool {
if cs.StrideH != 1 ||
cs.StrideW != 1 ||
cs.PaddingH != 0 ||
cs.PaddingW != 0 {
return false
}
tight := cs.From.Width * y.datBytes
for _, pitch := range cs.From.Pitch1Bytes {
if pitch != tight {
return false
}
}
for _, pitch := range cs.To.Pitch1Bytes {
if pitch != tight {
return false
}
}
return true
}
// stage7: data byte sizes per core, group and epoch, and in total.
stage7 := func() {
y.datCoreBytes11 = y.slices1 * y.datSliceBytes1
y.datCoreBytes12 = y.slices1 * y.datSliceBytes2
y.datCoreBytes21 = y.slices2 * y.datSliceBytes1
y.datCoreBytes22 = y.slices2 * y.datSliceBytes2
y.datGroupBytes1 = y.datCores1 * y.datCoreBytes11
y.datGroupBytes2 = y.datCores1 * y.datCoreBytes21
// A trailing core with the second slice size exists only when
// datCores2 exceeds datCores1.
if y.datCores1 < y.datCores2 {
y.datGroupBytes1 += y.datCoreBytes12
y.datGroupBytes2 += y.datCoreBytes22
}
y.datEpochBytes1 = cs.Groups * y.datGroupBytes1
y.datEpochBytes2 = cs.Groups * y.datGroupBytes2
y.datTotalBytes = y.epochs1 * y.datEpochBytes1
if y.epochs1 < y.epochs2 {
y.datTotalBytes += y.datEpochBytes2
}
}
// stage6: weight byte sizes per core, group and epoch, and in total.
// Each core holds one extra slice next to its weight slices (the names
// withBias1/withBias2 suggest it is the bias).
stage6 := func() {
var (
withBias1 = 1 + y.slices1
withBias2 = 1 + y.slices2
)
y.wtCoreBytes11 = withBias1 * y.wtSliceBytes1
y.wtCoreBytes12 = withBias1 * y.wtSliceBytes2
y.wtCoreBytes21 = withBias2 * y.wtSliceBytes1
y.wtCoreBytes22 = withBias2 * y.wtSliceBytes2
y.wtGroupBytes1 = y.wtCores1 * y.wtCoreBytes11
y.wtGroupBytes2 = y.wtCores1 * y.wtCoreBytes21
if y.wtCores1 < y.wtCores2 {
y.wtGroupBytes1 += y.wtCoreBytes12
y.wtGroupBytes2 += y.wtCoreBytes22
}
y.wtEpochBytes1 = cs.Groups * y.wtGroupBytes1
y.wtEpochBytes2 = cs.Groups * y.wtGroupBytes2
y.wtTotalBytes = y.epochs1 * y.wtEpochBytes1
if y.epochs1 < y.epochs2 {
y.wtTotalBytes += y.wtEpochBytes2
}
stage7()
}
// stage5: choose the slices per epoch from the cache sizes and split
// fromChans into epochs.
stage5 := func() {
// Use the second slice size when there is no regular core at all.
wtSliceBytes := y.wtSliceBytes1
if y.wtCores1 == 0 {
wtSliceBytes = y.wtSliceBytes2
}
datSliceBytes := y.datSliceBytes1
if y.datCores1 == 0 {
datSliceBytes = y.datSliceBytes2
}
switch cs.platform {
case raw.AVX512Float32:
var (
sliceBytes = 2*wtSliceBytes + datSliceBytes
cacheBytes = cs.cacheBytes1 + cs.cacheBytes2
)
// Tuning constants; their names mark them as empirically chosen.
const (
empirical1 = 4
empirical2 = 512
empirical3 = 4
)
y.slices1 = cacheBytes / empirical1 / sliceBytes
y.slices1 = max(y.slices1, empirical2)
y.slices2 = y.fromChans % y.slices1
y.epochs1 = y.fromChans / y.slices1
y.epochs2 = y.epochs1 + btoi(y.slices2 > 0)
// Fold a small remainder epoch into the last full epoch.
if y.epochs1 > 0 && y.epochs1 < y.epochs2 {
if y.slices2*empirical3 < y.slices1 {
y.slices2 += y.slices1
y.epochs1--
y.epochs2--
}
}
default:
panic("bug")
}
stage6()
}
// stage4: number of data cores and size of the trailing data slice.
stage4 := func() {
if special() {
// Simple path: the plane is cut into slices of datSliceDats1 values.
chanDats := cs.From.Height * cs.From.Width
y.datSliceDats2 = chanDats % y.datSliceDats1
y.datSliceSlots2 = ceilQuo(y.datSliceDats2, y.slotDats)
y.datSliceBytes2 = y.datSliceSlots2 * y.slotBytes
y.datCores1 = chanDats / y.datSliceDats1
y.datCores2 = y.datCores1 + btoi(y.datSliceDats2 > 0)
} else {
// General path: tokenize the plane, reusing an earlier result stored
// in cs.dedup under a signature of the shape parameters.
sig := fmt.Sprint(
"tokenize", " ",
cs.From.Height, cs.From.Width,
cs.StrideH, cs.StrideW,
cs.PaddingH, cs.PaddingW,
)
if prior, ok := cs.dedup[sig]; ok {
y.toks = prior.(*tokens)
} else {
y.toks = tokenize(cs, y.datSliceSlots1)
cs.dedup[sig] = y.toks
}
y.datCores1 = y.toks.IdxPast
y.datCores2 = y.datCores1
var (
sect = y.toks.Sects[len(y.toks.Sects)-1]
tok = sect.Uniqs[len(sect.Uniqs)-1]
)
// If the selected token (last Uniqs entry of the last section) uses
// fewer slots than a regular slice, it becomes the trailing core.
if tok.Slots != y.datSliceSlots1 {
y.datSliceSlots2 = tok.Slots
y.datSliceDats2 = y.datSliceSlots2 * y.slotDats
y.datSliceBytes2 = y.datSliceSlots2 * y.slotBytes
y.datCores1--
}
}
stage5()
}
// stage3: number of weight cores and size of the trailing weight slice.
stage3 := func() {
y.wtSliceWts2 = y.toChans % y.wtSliceWts1
y.wtSliceBytes2 = y.wtSliceWts2 * y.wtBytes
y.wtCores1 = y.toChans / y.wtSliceWts1
y.wtCores2 = y.wtCores1 + btoi(y.wtSliceWts2 > 0)
stage4()
}
// stage2: channels per group on both sides.
stage2 := func() {
y.fromChans = cs.From.Chans / cs.Groups
for i := range cs.Filts {
y.toChans += cs.Filts[i].Cnt
}
y.toChans /= cs.Groups
stage3()
}
// stage1: platform constants and the sizes that follow from them.
stage1 := func() {
switch cs.platform {
case raw.AVX512Float32:
y.wtBytes = 4
y.wtSliceWts1 = 6
y.datBytes = 4
y.slotDats = 16
y.datSliceSlots1 = 4
default:
panic("bug")
}
y.wtSliceBytes1 = y.wtSliceWts1 * y.wtBytes
y.slotBytes = y.slotDats * y.datBytes
y.datSliceDats1 = y.datSliceSlots1 * y.slotDats
y.datSliceBytes1 = y.datSliceSlots1 * y.slotBytes
stage2()
}
stage1()
return &y
}

// ArrangeWts generates the code that rearranges weights into the layout
// computed by newLayout. Prep must be called first: it fills in the
// embedded layout and callerName, which Bytes and Append then use.
type ArrangeWts struct {
*Ctx
*Spec
// Team is passed as the first argument of the generated call.
Team cgen.Gen
// Tensors are the expressions placed into the generated tensor array.
Tensors []cgen.Gen
*layout
// callerName is the name of the generated function that Append calls.
callerName string
}

// Prep computes the layout for a's Ctx and Spec and picks the name of the
// arranging function. The first call for a given Spec registers a new name
// in the dedup table and returns the generator for the function's
// definition followed by a newline; later calls for an equal Spec reuse
// the registered name and return nil.
func (a *ArrangeWts) Prep() cgen.Gen {
	const affix = "ArrangeWts"
	a.layout = newLayout(ctxSpec{Ctx: a.Ctx, Spec: a.Spec})
	key := fmt.Sprint(affix, " ", a.Spec)
	prior, seen := a.dedup[key]
	if seen {
		a.callerName = prior.(string)
		return nil
	}
	fresh := a.name(a.prefix + affix)
	a.callerName = fresh
	a.dedup[key] = fresh
	def := &arrangeWts{ArrangeWts: a}
	return cgen.Gens{def, cgen.Newline}
}

// Bytes returns the total size in bytes of the arranged weights, as
// computed by newLayout. Prep must have been called before.
func (a *ArrangeWts) Bytes() int {
	return a.layout.wtTotalBytes
}

// Append writes two generated statements to the byte slice and returns the
// extended slice: the declaration of a char-pointer array initialized with
// a.Tensors, and a call of the function named by Prep with a.Team and that
// array as arguments.
func (a *ArrangeWts) Append(to []byte) []byte {
	arr := vb(a.name("tensors"))
	decl := cgen.Var{
		Type: cgen.PtrChar,
		What: cgen.Elem{Arr: arr},
		Init: cgen.Brace{Inner: cgen.CommaLines(a.Tensors)},
	}
	call := cgen.Call{
		Func: vb(a.callerName),
		Args: cgen.CommaSpaced{a.Team, arr},
	}
	return cgen.Stmts{decl, call}.Append(to)
}

// arrangeWts is the generator for the definition of the weight-arranging
// function. Its fields are scratch state shared by Append, calleeFunc,
// ptrs, kernel and m512 while the code is being emitted.
type arrangeWts struct {
	*ArrangeWts
	// Tiling of output-channel bundles and of groups over threads
	// (set in Append).
	bundleChans, bundleTile, bundleTiles, bundleScrap, bundleHull int
	groupTile, groupTiles, groupScrap, groupHull                  int
	calleeName                                                    string
	// Task coordinates inside the callee; nil when the matching hull is 1.
	bundleCoord, groupCoord, epochCoord cgen.Gen
	// Parameters of the epoch variant currently being emitted.
	slices, coreBytes, groupBytes, epochFirst, epochCnt int
	// Pointer variables declared by ptrs.
	wtPtrs, biasPtrs []cgen.Gen
	bnPtrs           [][]cgen.Gen
	arranged         cgen.Gen
	// State of the kernel loops.
	groupIdx          cgen.Gen
	filtsIdx          int
	shortCore         bool
	workChan          cgen.Gen
	workChans         int
	workCore, workCut cgen.Gen
	workCores         int
}

// Append emits the definition of the weight-arranging code: first the
// callee (per-task) function, then a static caller function that hands the
// callee to threader.Do with the hull (bundleHull, groupHull, epochs2).
//
// Before emitting, it decides how the work is tiled: by bundles of output
// channels, by groups, or one task per epoch. Unsupported platforms and
// multi-part filters combined with Groups != 1 panic with "bug".
func (a *arrangeWts) Append(to []byte) []byte {
var (
// threadWts is the threshold against which the per-bundle and per-group
// weight counts are compared below.
threadWts int
groupBundles int
)
switch a.platform {
case raw.AVX512Float32:
a.bundleChans = 16
threadWts = a.bundleChans * 512
default:
panic("bug")
}
// Count bundles per group; each filter part is rounded up separately.
if len(a.Filts) == 1 {
groupBundles = ceilQuo(a.toChans, a.bundleChans)
} else {
if a.Groups != 1 {
panic("bug")
}
for i := range a.Filts {
chans := a.Filts[i].Cnt
groupBundles += ceilQuo(chans, a.bundleChans)
}
}
var (
// chanWts: slices per epoch averaged over all epochs, rounded up.
n1 = a.slices1 * a.epochs1
n2 = a.slices2 * (a.epochs2 - a.epochs1)
chanWts = ceilQuo(n1+n2, a.epochs2)
bundleWts = a.bundleChans * chanWts
groupWts = a.toChans * chanWts
)
switch {
case threadWts <= bundleWts:
// One bundle per task.
a.bundleTile = 1
a.bundleTiles = groupBundles
a.bundleScrap = 0
a.bundleHull = groupBundles
a.groupTile = 1
a.groupTiles = a.Groups
a.groupScrap = 0
a.groupHull = a.Groups
case threadWts <= groupWts:
// Several bundles per task; the last task takes the leftover bundles
// in addition to a regular tile.
var (
threadBundles = ceilQuo(threadWts, bundleWts)
fit = max(groupBundles/threadBundles, 1)
)
a.bundleTile = groupBundles / fit
a.bundleTiles = fit
a.bundleScrap = groupBundles - fit*a.bundleTile
a.bundleHull = fit
if a.bundleScrap > 0 {
a.bundleTiles--
a.bundleScrap += a.bundleTile
}
a.groupTile = 1
a.groupTiles = a.Groups
a.groupScrap = 0
a.groupHull = a.Groups
default:
// Whole groups per task, tiled the same way over the groups.
a.bundleTile = groupBundles
a.bundleTiles = 1
a.bundleScrap = 0
a.bundleHull = 1
var (
threadGroups = ceilQuo(threadWts, groupWts)
fit = max(a.Groups/threadGroups, 1)
)
a.groupTile = a.Groups / fit
a.groupTiles = fit
a.groupScrap = a.Groups - fit*a.groupTile
a.groupHull = fit
if a.groupScrap > 0 {
a.groupTiles--
a.groupScrap += a.groupTile
}
}
a.calleeName = a.name(a.callerName + "Callee")
var (
team = vb(a.name("team"))
tensors = vb(a.name("tensors"))
)
return cgen.Gens{
a.calleeFunc(),
cgen.Newline,
cgen.StaticFuncDef{
ReturnType: cgen.Void,
Name: a.callerName,
Params: cgen.CommaSpaced{
cgen.Param{
Type: a.tc.PtrTeam,
What: team,
},
cgen.Param{
Type: cgen.PtrPtrChar,
What: tensors,
},
},
Body: &threader.Do{
Ctx: a.tc,
Callee: vb(a.calleeName),
Any: tensors,
Hull: []cgen.Gen{
il(a.bundleHull),
il(a.groupHull),
il(a.epochs2),
},
Team: team,
},
},
}.Append(to)
}

// calleeFunc builds the per-task function handed to the threader. The body
// has six statement positions: [0] the tensors pointer, [1..3] the bundle,
// group and epoch coordinates (left empty when the matching hull is 1),
// [4] the code for the regular epochs and [5] the code for the trailing
// epoch whose slice count differs.
func (a *arrangeWts) calleeFunc() cgen.Gen {
callee := &threader.Callee{
Ctx: a.tc,
Name: a.calleeName,
Task: vb(a.name("task")),
Pt: vb(a.name("pt")),
}
var (
body = make(cgen.Stmts, 6)
tensors = vb(a.name("tensors"))
usedPt = false
)
body[0] = cgen.Var{
Type: cgen.PtrPtrChar, What: tensors,
Init: callee.Any(),
}
// coord declares a variable reading coordinate i of the task point, or
// returns nil when that dimension has extent 1.
coord := func(hull, i int, nm string) cgen.Gen {
if hull == 1 {
return nil
}
ret := vb(a.name(nm))
body[1+i] = cgen.Var{
Type: cgen.PtrdiffT, What: ret,
Init: cgen.Elem{
Arr: callee.Pt, Idx: il(i),
},
}
usedPt = true
return ret
}
a.bundleCoord = coord(a.bundleHull, 0, "b")
a.groupCoord = coord(a.groupHull, 1, "g")
a.epochCoord = coord(a.epochs2, 2, "e")
// If no coordinate is read, emit a void cast of the point parameter
// (presumably to silence an unused-parameter warning).
if !usedPt {
body[1] = cgen.Cast{
Type: cgen.Void,
Expr: callee.Pt,
}
}
// impl emits pointer setup followed by the kernel for the epoch variant
// currently selected through a.slices, a.coreBytes and related fields.
impl := func() cgen.Gen {
return cgen.Gens{
a.ptrs(tensors),
a.kernel(),
}
}
if a.epochs1 > 0 {
a.slices = a.slices1
a.coreBytes = a.wtCoreBytes11
a.groupBytes = a.wtGroupBytes1
a.epochFirst = 0
a.epochCnt = a.epochs1
put := impl()
// When a trailing epoch exists too, guard the regular code with
// "epoch < epochs1" and return afterwards.
if a.epochs1 < a.epochs2 {
put = cgen.If{
Cond: cgen.CmpL{
Expr1: a.epochCoord,
Expr2: il(a.epochs1),
},
Then: cgen.Stmts{
put,
cgen.Return{},
},
}
}
body[4] = put
}
if a.epochs1 < a.epochs2 {
a.slices = a.slices2
a.coreBytes = a.wtCoreBytes21
a.groupBytes = a.wtGroupBytes2
a.epochFirst = a.epochs1
a.epochCnt = 1
body[5] = impl()
}
return callee.Func(body)
}

// ptrs emits the declarations of the pointer variables used by the kernel:
// per filter part a weight pointer, a bias pointer (only for the variant
// that starts at epoch 0) and one pointer per batch-norm tensor, followed
// by the pointer into the arranged output. Entries of the tensors array
// are consumed in exactly that order.
//
// The closures stage1..stage6 are declared in reverse order; stage1 is the
// entry point and each stage ends by calling the next.
func (a *arrangeWts) ptrs(tensors cgen.Gen) cgen.Gen {
var (
parts = len(a.Filts)
epoch = a.epochCoord
group = a.groupCoord
wtOff cgen.Gen
biasOff cgen.Gen
preCh cgen.Gen
postCh cgen.Gen
arOff cgen.Gen
)
// stage6: emit the declarations themselves.
stage6 := func() cgen.Gen {
var (
stmts cgen.Stmts
next = 0
)
stmt := func(s cgen.Gen) {
stmts = append(stmts, s)
}
// tensor returns the next entry of the tensors array.
tensor := func() cgen.Gen {
i := next
next++
return cgen.Elem{
Arr: tensors,
Idx: il(i),
}
}
decl := func(what, off cgen.Gen) {
stmt(cgen.Var{
Type: cgen.RestrictPtrChar,
What: what,
Init: cgen.Add{
Expr1: tensor(),
Expr2: off,
},
})
}
for i := 0; i < parts; i++ {
decl(a.wtPtrs[i], wtOff)
if a.epochFirst == 0 {
decl(a.biasPtrs[i], biasOff)
} else {
// The bias tensor is not read here, but its array entry must
// still be skipped.
next++
}
// The first BnPre pointers use the input-channel offset, the
// remaining ones the output-channel offset.
split := a.Filts[i].BnPre
for j, ptr := range a.bnPtrs[i] {
ch := preCh
if j >= split {
ch = postCh
}
stmt(cgen.Var{
Type: cgen.RestrictPtrChar,
What: ptr,
Init: &bn.Offset{
Ctx: a.bc,
Mas: tensor(),
Channel: ch,
},
})
}
}
decl(a.arranged, arOff)
return stmts
}
// stage5: offset into the arranged output for this epoch and group.
stage5 := func() cgen.Gen {
a.arranged = vb(a.name("arranged"))
arOff = cgen.Add{
Expr1: cgen.Mul{
Expr1: cast(a.wtEpochBytes1),
Expr2: epoch,
},
Expr2: cgen.Mul{
Expr1: cast(a.groupBytes),
Expr2: group,
},
}
return stage6()
}
// stage4: names for the batch-norm pointers and their channel offsets.
stage4 := func() cgen.Gen {
a.bnPtrs = make([][]cgen.Gen, parts)
for i := range a.bnPtrs {
var (
pre = a.Filts[i].BnPre
post = a.Filts[i].BnPost
put = make([]cgen.Gen, pre+post)
)
for j := range put {
put[j] = vb(a.name("bnPtr"))
}
a.bnPtrs[i] = put
}
preCh = cgen.Paren{
Inner: cgen.Add{
Expr1: cgen.Mul{
Expr1: cast(a.slices1),
Expr2: epoch,
},
Expr2: cgen.Mul{
Expr1: cast(a.fromChans),
Expr2: group,
},
},
}
postCh = cgen.Mul{
Expr1: il(a.toChans),
Expr2: group,
}
return stage5()
}
// stage3: bias pointers exist only for the variant starting at epoch 0.
stage3 := func() cgen.Gen {
if a.epochFirst == 0 {
a.biasPtrs = make([]cgen.Gen, parts)
for i := range a.biasPtrs {
a.biasPtrs[i] = vb(a.name("biasPtr"))
}
biasOff = cgen.Mul{
Expr1: cast(a.toChans * a.wtBytes),
Expr2: group,
}
} else {
a.biasPtrs = nil
}
return stage4()
}
// stage2: names for the weight pointers and their byte offset.
stage2 := func() cgen.Gen {
a.wtPtrs = make([]cgen.Gen, parts)
for i := range a.wtPtrs {
a.wtPtrs[i] = vb(a.name("wtPtr"))
}
filtBytes := a.fromChans * a.wtBytes
wtOff = cgen.Add{
Expr1: cgen.Mul{
Expr1: cast(a.slices1 * a.wtBytes),
Expr2: epoch,
},
Expr2: cgen.Mul{
Expr1: cast(a.toChans * filtBytes),
Expr2: group,
},
}
return stage3()
}
// stage1: turn the task coordinates into epoch and first-group
// expressions. A single-epoch variant uses the literal epoch index.
stage1 := func() cgen.Gen {
if a.epochCnt == 1 {
epoch = il(a.epochFirst)
}
if group == nil {
group = il(0)
} else {
group = cgen.Mul{
Expr1: il(a.groupTile),
Expr2: group,
}
}
return stage2()
}
return stage1()
}

// kernel emits the loop nest of the weight-arranging callee. The closures
// layer1..layer8 are declared innermost first; layer1 is the entry point:
//
// layer1: loop over the groups of this task
// layer2: loop over the bundles of this task
// layer3: select the filter part a bundle belongs to
// layer4: split a part's bundles into up to three regions
// layer5: declare the first channel of the bundle (workChan)
// layer6: declare the core index and the cut inside the core
// layer7: specialize on the number of cores a bundle touches
// layer8: platform-specific body
func (a *arrangeWts) kernel() cgen.Gen {
var (
bundleIdx cgen.Gen
outerChan int
outerChans int
outerBundle int
innerChan int
innerChans int
innerBundle int
)
layer8 := func() cgen.Gen {
switch a.platform {
case raw.AVX512Float32:
return a.m512()
default:
panic("bug")
}
}
layer7 := func() cgen.Gen {
var (
n = a.wtSliceWts1
// spans[cut]: number of cores touched by a bundle whose first
// channel sits at position cut inside its core (0 = cut not seen).
spans = make([]int, n)
stop = innerChan + innerChans
each = a.workChans
)
// Visit the bundles of the region until a cut repeats.
for ch := innerChan; ch != stop; ch += each {
cut := ch % n
if spans[cut] != 0 {
break
}
span := 1
if fill := n - cut; each > fill {
span += ceilQuo(each-fill, n)
}
spans[cut] = span
}
// only: the common span if all seen cuts agree, otherwise -1.
only := 0
for _, span := range spans {
switch {
case span == 0:
case only == 0:
only = span
case only != span:
only = -1
}
}
if only != -1 {
a.workCores = only
return layer8()
}
// Spans differ: emit a switch on the cut with one body per distinct
// span.
var (
cases = make(cgen.Stmts, 0, n)
cuts = make([]int, 0, n-1)
)
for {
var (
take = 0
last = true
)
// Collect every remaining cut sharing the first remaining span.
for cut, span := range spans {
switch {
case span == 0:
case take == 0:
take = span
fallthrough
case take == span:
spans[cut] = 0
cuts = append(cuts, cut)
default:
last = false
}
}
a.workCores = take
if last {
// Final group: a case without an expression. With a single
// possible cut, the cut variable is assigned that constant first.
var assn cgen.Gen
if len(cuts) == 1 {
assn = cgen.Assign{
Expr1: a.workCut,
Expr2: il(cuts[0]),
}
}
cases = append(
cases, cgen.Case{
Body: cgen.Stmts{
assn,
layer8(),
},
},
)
break
}
// One case label per cut; only the last label carries the body.
for x, cut := range cuts {
var body cgen.Gen
if x == len(cuts)-1 {
body = cgen.Stmts{
layer8(),
cgen.Break,
}
}
cases = append(
cases, cgen.Case{
Expr: il(cut),
Body: body,
},
)
}
cuts = cuts[:0]
}
return cgen.Switch{
Expr: a.workCut,
Cases: cases,
}
}
layer6 := func() cgen.Gen {
a.workCore = vb(a.name("l"))
a.workCut = vb(a.name("cut"))
var (
stmts = make(cgen.Stmts, 3)
// numer: absolute channel index of the bundle's first channel.
numer = cgen.Cast{
Type: cgen.SizeT,
Expr: cgen.Paren{
Inner: cgen.Add{
Expr1: il(outerChan),
Expr2: a.workChan,
},
},
}
denom = il(a.wtSliceWts1)
)
stmts[0] = cgen.Var{
Type: cgen.PtrdiffT,
What: a.workCore,
Init: cgen.Quo{
Expr1: numer,
Expr2: denom,
},
}
// When bundles are a whole number of cores, the cut is a generation-time
// constant; otherwise it is computed as a remainder at run time.
var cut cgen.Gen
if a.bundleChans%a.wtSliceWts1 == 0 {
cut = il(outerChan % a.wtSliceWts1)
} else {
cut = cgen.Rem{
Expr1: numer,
Expr2: denom,
}
}
stmts[1] = cgen.Var{
Type: cgen.PtrdiffT,
What: a.workCut,
Init: cut,
}
stmts[2] = layer7()
return stmts
}
layer5 := func() cgen.Stmts {
a.workChan = vb(a.name("k"))
a.workChans = min(innerChans, a.bundleChans)
var (
stmts = make(cgen.Stmts, 2)
ch = il(innerChan - outerChan)
)
// With more than one bundle in the region, the channel also depends on
// the bundle index.
if a.workChans < innerChans {
ch = addMul(
ch,
il(a.bundleChans),
cgen.Paren{
Inner: cgen.Sub{
Expr1: bundleIdx,
Expr2: il(innerBundle),
},
},
)
}
stmts[0] = cgen.Var{
Type: cgen.PtrdiffT,
What: a.workChan,
Init: ch,
}
stmts[1] = layer6()
return stmts
}
layer4 := func() cgen.Stmts {
var stmts cgen.Stmts
// ite wraps what has been built so far as the else branch of
// "if (bundleIdx < upper) { ... }".
ite := func(upper int) {
if stmts == nil {
stmts = layer5()
return
}
stmts = cgen.Stmts{
cgen.If{
Cond: cgen.CmpL{
Expr1: bundleIdx,
Expr2: il(upper),
},
Then: layer5(),
Else: stmts,
},
}
}
var (
// first: first channel of the short trailing core; short: how many
// channels of this part fall into that core.
first = a.toChans - a.wtSliceWts2
past = outerChan + outerChans
short = min(max(past-first, 0), outerChans)
bunds = (outerChans - short) / a.bundleChans
chans1 = bunds * a.bundleChans
chans2 = outerChans - chans1
quo = chans2 / a.bundleChans
rem = chans2 % a.bundleChans
split = outerBundle + bunds
)
// Regions are generated last to first so that ite can nest them.
if rem > 0 {
a.shortCore = short > 0
innerChan = past - rem
innerChans = rem
innerBundle = split + quo
stmts = layer5()
}
if quo > 0 {
a.shortCore = true
innerChan = past - chans2
innerChans = chans2 - rem
innerBundle = split
ite(split + quo)
}
if chans1 > 0 {
a.shortCore = false
innerChan = outerChan
innerChans = chans1
innerBundle = outerBundle
ite(split)
}
return stmts
}
layer3 := func() cgen.Gen {
parts := len(a.Filts)
if parts == 1 {
a.filtsIdx = 0
outerChan = 0
outerChans = a.toChans
outerBundle = 0
return layer4()
}
// Prefix sums of channels and bundles over the filter parts.
var (
atChan = make([]int, parts+1)
atBund = make([]int, parts+1)
)
for part := 0; part < parts; part++ {
var (
chans = a.Filts[part].Cnt
bunds = ceilQuo(chans, a.bundleChans)
)
atChan[part+1] = atChan[part] + chans
atBund[part+1] = atBund[part] + bunds
}
leaf := func(part int) cgen.Stmts {
a.filtsIdx = part
outerChan = atChan[part]
outerChans = a.Filts[part].Cnt
outerBundle = atBund[part]
return layer4()
}
// tree emits a binary decision tree on bundleIdx, splitting the part
// range near the middle of its bundle range.
var tree func(int, int) cgen.Stmts
tree = func(first, last int) cgen.Stmts {
if first == last {
return leaf(first)
}
var (
start = atBund[first]
stop = atBund[last+1]
upper = start + (stop-start)/2
split = first + 1
)
for atBund[split+1] <= upper {
split++
}
return cgen.Stmts{
cgen.If{
Cond: cgen.CmpL{
Expr1: bundleIdx,
Expr2: il(atBund[split]),
},
Then: tree(first, split-1),
Else: tree(split, last),
},
}
}
return tree(0, parts-1)
}
layer2 := func() cgen.Gen {
bundleIdx = vb(a.name("j"))
var (
past = vb(a.name("jj"))
first cgen.Gen
iters cgen.Gen
)
if a.bundleCoord == nil {
first = il(0)
} else {
first = cgen.Mul{
Expr1: il(a.bundleTile),
Expr2: a.bundleCoord,
}
}
// Iteration count: a constant when all tasks are alike, otherwise a
// ternary choosing between regular tiles and the scrap tile.
switch a.bundleTiles {
case a.bundleHull:
iters = il(a.bundleTile)
case 0:
iters = il(a.bundleScrap)
default:
iters = cgen.Paren{
Inner: cgen.Ternary{
Cond: cgen.CmpL{
Expr1: a.bundleCoord,
Expr2: il(a.bundleTiles),
},
Then: il(a.bundleTile),
Else: il(a.bundleScrap),
},
}
}
return cgen.Stmts{
cgen.Var{
Type: cgen.PtrdiffT,
What: bundleIdx,
Init: first,
},
cgen.Var{
Type: cgen.PtrdiffT,
What: past,
Init: cgen.Add{
Expr1: bundleIdx,
Expr2: iters,
},
},
cgen.For{
Cond: cgen.CmpL{
Expr1: bundleIdx,
Expr2: past,
},
Post: cgen.IncPre{
Expr: bundleIdx,
},
Body: layer3(),
},
}
}
layer1 := func() cgen.Gen {
a.groupIdx = vb(a.name("i"))
var (
past = vb(a.name("ii"))
iters cgen.Gen
)
// Same scheme as in layer2, applied to groups. The group loop counts
// from 0 because ptrs already added the task's first group to the
// pointers.
switch a.groupTiles {
case a.groupHull:
iters = il(a.groupTile)
case 0:
iters = il(a.groupScrap)
default:
iters = cgen.Ternary{
Cond: cgen.CmpL{
Expr1: a.groupCoord,
Expr2: il(a.groupTiles),
},
Then: il(a.groupTile),
Else: il(a.groupScrap),
}
}
return cgen.Stmts{
cgen.Var{
Type: cgen.PtrdiffT,
What: past, Init: iters,
},
cgen.For{
Init: cgen.Var{
Type: cgen.PtrdiffT,
What: a.groupIdx,
Init: il(0),
},
Cond: cgen.CmpL{
Expr1: a.groupIdx,
Expr2: past,
},
Post: cgen.IncPre{
Expr: a.groupIdx,
},
Body: layer2(),
},
}
}
return layer1()
}

// m512 emits the AVX-512 body that arranges the weights of one bundle of
// a.workChans output channels. The closures layer1..layer7 are declared
// innermost first; layer1 is the entry point:
//
// layer1: declare the running sum, initialized from the bias or zero
// layer2: load the post batch-norm parameters; order the sum store
// layer3: loop over cells of 16 input channels plus a remainder cell
// layer4: load one row per output channel and transpose
// layer5: scale by the combined post batch-norm multiplier
// layer6: apply the pre batch-norm parameters per input channel
// layer7: store the transposed weights
//
// emit produces the masked stores for one vector across the cores that the
// bundle touches.
func (a *arrangeWts) m512() cgen.Gen {
const lanes = 16
var (
sum cgen.Gen
postMul1 cgen.Gen
cellIdx cgen.Gen
cellSlices int
wts []cgen.Gen
)
// emit stores `what` as slice sliceIdx of the cores covered by the bundle,
// one masked store per core.
emit := func(what, sliceIdx cgen.Gen) cgen.Stmts {
var (
cores = a.workCores
stmts = make(cgen.Stmts, cores)
ae = a.arranged
slicePitch1 = il(a.wtSliceBytes1)
slicePitch2 = slicePitch1
n = a.wtSliceWts1
)
ae = addMul(ae, il(a.groupBytes), a.groupIdx)
ae = addMul(ae, il(a.coreBytes), a.workCore)
ae = addMul(ae, il(a.wtBytes), a.workCut)
// The last core touched may be the short one with a smaller slice
// pitch.
if a.shortCore {
slicePitch2 = il(a.wtSliceBytes2)
}
if cores == 1 {
stmts[0] = avx.Mm512MaskStoreuPs{
addMul(ae, slicePitch2, sliceIdx),
loMask(a.workChans),
what,
}
return stmts
}
// All cores but the last: the mask selects n lanes, shifted by the cut.
// fwd-bwd moves the address to core x while compensating for the lanes
// already consumed by earlier cores.
for x := 0; x < cores-1; x++ {
var (
fwd = x * a.coreBytes
bwd = x * a.wtSliceBytes1
)
stmts[x] = avx.Mm512MaskStoreuPs{
cgen.Add{
Expr1: addMul(ae, slicePitch1, sliceIdx),
Expr2: cast(fwd - bwd),
},
cgen.ShiftLow{
Expr1: il((1<<uint(n) - 1) << uint(x*n)),
Expr2: a.workCut,
},
what,
}
}
// Last core: all remaining lanes of the bundle.
var (
x = cores - 1
fwd = x * a.coreBytes
bwd = x * a.wtSliceBytes1
)
stmts[x] = avx.Mm512MaskStoreuPs{
cgen.Add{
Expr1: addMul(ae, slicePitch2, sliceIdx),
Expr2: cast(fwd - bwd),
},
cgen.Sub{
Expr1: loMask(a.workChans),
Expr2: cgen.Paren{
Inner: cgen.ShiftLow{
Expr1: loMask(x * n),
Expr2: a.workCut,
},
},
},
what,
}
return stmts
}
layer7 := func() cgen.Gen {
// Slice 0 is used for the sum (see layer2), so weight slices start at
// index 1.
toMix := make([]cgen.Stmts, cellSlices)
for x := range toMix {
toMix[x] = emit(
wts[x],
cgen.Paren{
Inner: addMul(
il(1+x), il(lanes), cellIdx,
),
},
)
}
return mix(toMix)
}
layer6 := func() cgen.Gen {
preCnt := a.Filts[a.filtsIdx].BnPre
if preCnt == 0 {
return layer7()
}
var (
n1 = cellSlices
outer = make(cgen.Gens, n1+1)
prePtrs = a.bnPtrs[a.filtsIdx][:preCnt]
)
for x1 := 0; x1 < n1; x1++ {
var (
n2 = preCnt * 3
inner = make(cgen.Stmts, n2+2)
wt = wts[x1]
preMul1 cgen.Gen
preAdd1 cgen.Gen
)
preCh := cgen.Paren{
Inner: addMul(
addMul(il(x1), il(lanes), cellIdx),
il(a.fromChans),
a.groupIdx,
),
}
// Chain the pre batch-norms into one multiplier/addend pair.
for x2, prePtr := range prePtrs {
var (
preMul2 = vb(a.name("preMul"))
preAdd2 = vb(a.name("preAdd"))
)
inner[x2*3] = &bn.Load{
Ctx: a.bc,
Mas: prePtr,
Channel: preCh,
Mul: preMul2,
Add: preAdd2,
}
if x2 == 0 {
preMul1 = preMul2
preAdd1 = preAdd2
continue
}
inner[x2*3+1] = cgen.Assign{
Expr1: preMul1,
Expr2: avx.Mm512MulPs{
preMul1, preMul2,
},
}
inner[x2*3+2] = cgen.Assign{
Expr1: preAdd1,
Expr2: avx.Mm512FmaddPs{
preAdd1, preMul2, preAdd2,
},
}
}
// The addend goes into the sum through the weight; the multiplier
// is folded into the weight itself.
inner[n2] = cgen.Assign{
Expr1: sum,
Expr2: avx.Mm512FmaddPs{
wt, preAdd1, sum,
},
}
inner[n2+1] = cgen.Assign{
Expr1: wt,
Expr2: avx.Mm512MulPs{
wt, preMul1,
},
}
outer[x1] = inner
}
outer[n1] = layer7()
return outer
}
layer5 := func() cgen.Gen {
// postMul1 is only set when the part has post batch-norms (layer2).
if postMul1 == nil {
return layer6()
}
var (
n = cellSlices
stmts = make(cgen.Stmts, n+1)
)
for x := 0; x < n; x++ {
wt := wts[x]
stmts[x] = cgen.Assign{
Expr1: wt,
Expr2: avx.Mm512MulPs{
wt, postMul1,
},
}
}
stmts[n] = layer6()
return stmts
}
layer4 := func() cgen.Gen {
var (
rows = a.workChans
cols = cellSlices
)
// The same variables hold the rows before and the columns after the
// transpose, hence max(rows, cols) of them.
wts = make([]cgen.Gen, max(rows, cols))
for x := range wts {
wts[x] = vb(a.name("wt"))
}
var (
stmts = make(cgen.Stmts, rows+2)
mask = loMask(cols)
ae = a.wtPtrs[a.filtsIdx]
filtBytes = a.fromChans * a.wtBytes
groupPitch = il(a.toChans * filtBytes)
chanPitch = il(filtBytes)
cellPitch = il(lanes * a.wtBytes)
)
ae = addMul(ae, groupPitch, a.groupIdx)
ae = addMul(ae, chanPitch, a.workChan)
ae = addMul(ae, cellPitch, cellIdx)
for x := 0; x < rows; x++ {
stmts[x] = cgen.Var{
Type: avx.M512, What: wts[x],
Init: avx.Mm512MaskzLoaduPs{
mask,
cgen.Add{
Expr1: ae,
Expr2: cast(x * filtBytes),
},
},
}
}
stmts[rows] = &trans.Pose{
Platform: a.platform,
Nms: a.nms,
Rows: rows,
Cols: cols,
Vars: wts,
}
stmts[rows+1] = layer5()
return stmts
}
layer3 := func() cgen.Gen {
cellIdx = vb(a.name("c"))
var (
stmts = make(cgen.Stmts, 3)
quo = a.slices / lanes
rem = a.slices % lanes
)
stmts[0] = cgen.Var{
Type: cgen.PtrdiffT,
What: cellIdx,
Init: il(0),
}
// Full cells run in a loop; after the loop cellIdx equals quo, which
// is the index the remainder cell needs.
if quo > 0 {
cellSlices = lanes
stmts[1] = cgen.For{
Cond: cgen.CmpNE{
Expr1: cellIdx,
Expr2: il(quo),
},
Post: cgen.IncPre{
Expr: cellIdx,
},
Body: layer4(),
}
}
if rem > 0 {
cellSlices = rem
stmts[2] = layer4()
}
return stmts
}
layer2 := func() cgen.Gen {
var (
outer = make(cgen.Gens, 3)
filts = &a.Filts[a.filtsIdx]
preCnt = filts.BnPre
postCnt = filts.BnPost
)
if postCnt > 0 {
var (
inner = make(cgen.Stmts, postCnt*3)
postPtrs = a.bnPtrs[a.filtsIdx][preCnt:]
)
postCh := cgen.Paren{
Inner: addMul(
a.workChan,
il(a.toChans),
a.groupIdx,
),
}
for x, postPtr := range postPtrs {
var (
postMul2 = vb(a.name("postMul"))
postAdd = vb(a.name("postAdd"))
)
inner[x*3] = &bn.Load{
Ctx: a.bc,
Mas: postPtr,
Channel: postCh,
Mul: postMul2,
Add: postAdd,
Cnt: a.workChans,
}
// Accumulate the product of all post multipliers in postMul1.
if x == 0 {
postMul1 = postMul2
} else {
inner[x*3+1] = cgen.Assign{
Expr1: postMul1,
Expr2: avx.Mm512MulPs{
postMul1, postMul2,
},
}
}
// The post batch-norm is applied to the sum only in epoch 0; in
// other variants the unused addend is cast to void.
var stmt cgen.Gen
if a.epochFirst == 0 {
stmt = cgen.Assign{
Expr1: sum,
Expr2: avx.Mm512FmaddPs{
sum, postMul2, postAdd,
},
}
if a.epochCnt > 1 {
stmt = cgen.If1{
Cond: cgen.IsZero{
Expr: a.epochCoord,
},
Then: stmt,
}
}
} else {
stmt = cgen.Cast{
Type: cgen.Void,
Expr: postAdd,
}
}
inner[x*3+2] = stmt
}
outer[0] = inner
}
// With pre batch-norms the sum is modified while the cells are
// processed (layer6), so it is stored afterwards; otherwise it is
// stored first.
if preCnt > 0 {
outer[1] = layer3()
outer[2] = emit(sum, il(0))
} else {
outer[1] = emit(sum, il(0))
outer[2] = layer3()
}
return outer
}
layer1 := func() cgen.Gen {
sum = vb(a.name("sum"))
var (
decl = cgen.Var{
Type: avx.M512, What: sum,
}
bias cgen.Gen
assn cgen.Gen
)
// The bias is only available in the variant that starts at epoch 0.
if a.epochFirst == 0 {
var (
ae = a.biasPtrs[a.filtsIdx]
groupPitch = il(a.toChans * a.wtBytes)
chanPitch = il(a.wtBytes)
mask = loMask(a.workChans)
)
ae = addMul(ae, groupPitch, a.groupIdx)
ae = addMul(ae, chanPitch, a.workChan)
bias = avx.Mm512MaskzLoaduPs{
mask, ae,
}
}
switch {
case bias == nil:
decl.Init = avx.Mm512SetzeroPs
case a.epochCnt == 1:
decl.Init = bias
default:
// Several epochs share this code: pick bias or zero at run time.
assn = cgen.If{
Cond: cgen.IsZero{
Expr: a.epochCoord,
},
Then: cgen.Stmts{cgen.Assign{
Expr1: sum,
Expr2: bias,
}},
Else: cgen.Stmts{cgen.Assign{
Expr1: sum,
Expr2: avx.Mm512SetzeroPs,
}},
}
}
return cgen.Stmts{
decl,
assn,
layer2(),
}
}
return layer1()
}

// ArrangeDats generates the code that rearranges data into the layout
// computed by newLayout. Prep must be called first: it fills in the
// embedded layout and callerName, which Bytes and Append then use.
type ArrangeDats struct {
*Ctx
*Spec
// Team is passed as the first argument of the generated call.
Team cgen.Gen
// Tensors are the expressions placed into the generated tensor array.
Tensors []cgen.Gen
*layout
// callerName is the name of the generated function that Append calls.
callerName string
}

// Prep computes the layout for a's Ctx and Spec and picks the name of the
// arranging function. The first call for a given Spec registers a new name
// in the dedup table and returns the generator for the function's
// definition followed by a newline; later calls for an equal Spec reuse
// the registered name and return nil.
func (a *ArrangeDats) Prep() cgen.Gen {
	const affix = "ArrangeDats"
	a.layout = newLayout(ctxSpec{Ctx: a.Ctx, Spec: a.Spec})
	key := fmt.Sprint(affix, " ", a.Spec)
	prior, seen := a.dedup[key]
	if seen {
		a.callerName = prior.(string)
		return nil
	}
	fresh := a.name(a.prefix + affix)
	a.callerName = fresh
	a.dedup[key] = fresh
	def := &arrangeDats{ArrangeDats: a}
	return cgen.Gens{def, cgen.Newline}
}

// Bytes returns the total size in bytes of the arranged data, as computed
// by newLayout. Prep must have been called before.
func (a *ArrangeDats) Bytes() int {
	return a.layout.datTotalBytes
}

// Append writes two generated statements to the byte slice and returns the
// extended slice: the declaration of a char-pointer array initialized with
// a.Tensors, and a call of the function named by Prep with a.Team and that
// array as arguments.
func (a *ArrangeDats) Append(to []byte) []byte {
	arr := vb(a.name("tensors"))
	decl := cgen.Var{
		Type: cgen.PtrChar,
		What: cgen.Elem{Arr: arr},
		Init: cgen.Brace{Inner: cgen.CommaLines(a.Tensors)},
	}
	call := cgen.Call{
		Func: vb(a.callerName),
		Args: cgen.CommaSpaced{a.Team, arr},
	}
	return cgen.Stmts{decl, call}.Append(to)
}

// arrangeDats is the generator for the definition of the data-arranging
// function. Its fields are scratch state shared by its methods while the
// code is being emitted.
type arrangeDats struct {
	*ArrangeDats
	// Tiling of slices, cores and groups over threads (set in Append).
	// The slice values come in a "1" and a "2" variant, matching slices1
	// and slices2 of the layout.
	sliceTile1, sliceTile2, sliceTiles, sliceScrap1, sliceScrap2, sliceHull int
	coreTile, coreTiles, coreScrap, coreHull                                int
	groupTile, groupTiles, groupScrap, groupHull                            int
	calleeName                                                              string
	// Task coordinates inside the callee; nil when the matching hull is 1.
	sliceCoord, coreCoord, groupCoord, epochCoord cgen.Gen
	// Parameters of the epoch variant currently being emitted.
	sliceTile, sliceScrap, coreBytes, groupBytes, epochFirst, epochCnt int
	// Pointer variables and loop state used while emitting the kernel.
	datPtrs, bnPtrs            []cgen.Gen
	arranged, groupIdx, coreIdx cgen.Gen
	short                      bool
	sectH                      cgen.Gen
	tok                        *token
	sliceIdx                   cgen.Gen
}

// Append emits the definition of the data-arranging code: first the callee
// (per-task) function, then a static caller function that hands the callee
// to threader.Do with the hull (sliceHull, coreHull, groupHull, epochs2).
//
// Before emitting, it decides how the work is tiled. If one core has more
// slices than the per-thread target, cores are split by slices. Otherwise
// whole cores are combined per task, and if a whole group is still too
// small, whole groups are combined. Unsupported platforms panic with "bug".
func (a *arrangeDats) Append(to []byte) []byte {
var (
// threadSlots is the per-task size in slots that the tiling is
// measured against.
threadSlots int
sliceSlots = a.datSliceSlots1
)
switch a.platform {
case raw.AVX512Float32:
threadSlots = 512
default:
panic("bug")
}
// Without any regular core, measure with the trailing core's slice size.
if a.datCores1 == 0 {
sliceSlots = a.datSliceSlots2
}
var (
threadSlices = ceilQuo(threadSlots, sliceSlots)
coreSlices = a.slices1
)
// When a differently sized trailing epoch exists, use the smaller slice
// count of the two epoch variants (or slices2 if it is the only one).
switch {
case a.epochs1 == a.epochs2:
case a.epochs1 == 0 || a.slices1 > a.slices2:
coreSlices = a.slices2
}
// Default: one core and one group per task.
a.coreTile = 1
a.coreTiles = a.datCores2
a.coreScrap = 0
a.coreHull = a.coreTiles
a.groupTile = 1
a.groupTiles = a.Groups
a.groupScrap = 0
a.groupHull = a.groupTiles
if threadSlices < coreSlices {
// Split each core by slices. The last task always takes the scrap
// count, so there are fit-1 regular tiles.
var (
fit = coreSlices / threadSlices
tiles = fit - 1
)
a.sliceTile1 = a.slices1 / fit
a.sliceTile2 = a.slices2 / fit
a.sliceTiles = tiles
a.sliceScrap1 = a.slices1 - tiles*a.sliceTile1
a.sliceScrap2 = a.slices2 - tiles*a.sliceTile2
a.sliceHull = fit
} else {
a.sliceTile1 = a.slices1
a.sliceTile2 = a.slices2
a.sliceTiles = 1
a.sliceScrap1 = 0
a.sliceScrap2 = 0
a.sliceHull = 1
var (
threadCores = ceilQuo(threadSlices, coreSlices)
groupCores = a.datCores2
)
if threadCores < groupCores {
// Several cores per task; the last task takes the leftover cores
// in addition to a regular tile.
fit := groupCores / threadCores
a.coreTile = groupCores / fit
a.coreTiles = fit
a.coreScrap = groupCores - fit*a.coreTile
a.coreHull = fit
if a.coreScrap > 0 {
a.coreTiles--
a.coreScrap += a.coreTile
}
} else {
// All cores of a group per task; tile over groups if possible.
a.coreTile = groupCores
a.coreTiles = 1
a.coreScrap = 0
a.coreHull = 1
var (
threadGroups = ceilQuo(threadCores, groupCores)
epochGroups = a.Groups
)
if threadGroups < epochGroups {
fit := epochGroups / threadGroups
a.groupTile = epochGroups / fit
a.groupTiles = fit
a.groupScrap = epochGroups - fit*a.groupTile
a.groupHull = fit
if a.groupScrap > 0 {
a.groupTiles--
a.groupScrap += a.groupTile
}
} else {
a.groupTile = epochGroups
a.groupTiles = 1
a.groupScrap = 0
a.groupHull = 1
}
}
}
a.calleeName = a.name(a.callerName + "Callee")
var (
team = vb(a.name("team"))
tensors = vb(a.name("tensors"))
)
return cgen.Gens{
a.calleeFunc(),
cgen.Newline,
cgen.StaticFuncDef{
ReturnType: cgen.Void,
Name: a.callerName,
Params: cgen.CommaSpaced{
cgen.Param{
Type: a.tc.PtrTeam,
What: team,
},
cgen.Param{
Type: cgen.PtrPtrChar,
What: tensors,
},
},
Body: &threader.Do{
Ctx: a.tc,
Callee: vb(a.calleeName),
Any: tensors,
Hull: []cgen.Gen{
il(a.sliceHull),
il(a.coreHull),
il(a.groupHull),
il(a.epochs2),
},
Team: team,
},
},
}.Append(to)
}

// calleeFunc builds the per-task function handed to the threader. The body
// has seven statement positions: [0] the tensors pointer, [1..4] the
// slice, core, group and epoch coordinates (left empty when the matching
// hull is 1), [5] the code for the regular epochs and [6] the code for the
// trailing epoch whose slice count differs.
func (a *arrangeDats) calleeFunc() cgen.Gen {
callee := &threader.Callee{
Ctx: a.tc,
Name: a.calleeName,
Task: vb(a.name("task")),
Pt: vb(a.name("pt")),
}
var (
body = make(cgen.Stmts, 7)
tensors = vb(a.name("tensors"))
usedPt = false
)
body[0] = cgen.Var{
Type: cgen.PtrPtrChar, What: tensors,
Init: callee.Any(),
}
// coord declares a variable reading coordinate i of the task point, or
// returns nil when that dimension has extent 1.
coord := func(hull, i int, nm string) cgen.Gen {
if hull == 1 {
return nil
}
ret := vb(a.name(nm))
body[1+i] = cgen.Var{
Type: cgen.PtrdiffT, What: ret,
Init: cgen.Elem{
Arr: callee.Pt, Idx: il(i),
},
}
usedPt = true
return ret
}
a.sliceCoord = coord(a.sliceHull, 0, "s")
a.coreCoord = coord(a.coreHull, 1, "c")
a.groupCoord = coord(a.groupHull, 2, "g")
a.epochCoord = coord(a.epochs2, 3, "e")
// If no coordinate is read, emit a void cast of the point parameter
// (presumably to silence an unused-parameter warning).
if !usedPt {
body[1] = cgen.Cast{
Type: cgen.Void,
Expr: callee.Pt,
}
}
// impl emits pointer setup followed by the kernel for the epoch variant
// currently selected through a.sliceTile, a.coreBytes and related fields.
impl := func() cgen.Gen {
return cgen.Gens{
a.ptrs(tensors),
a.kernel(),
}
}
if a.epochs1 > 0 {
a.sliceTile = a.sliceTile1
a.sliceScrap = a.sliceScrap1
a.coreBytes = a.datCoreBytes11
a.groupBytes = a.datGroupBytes1
a.epochFirst = 0
a.epochCnt = a.epochs1
put := impl()
// When a trailing epoch exists too, guard the regular code with
// "epoch < epochs1" and return afterwards.
if a.epochs1 < a.epochs2 {
put = cgen.If{
Cond: cgen.CmpL{
Expr1: a.epochCoord,
Expr2: il(a.epochs1),
},
Then: cgen.Stmts{
put,
cgen.Return{},
},
}
}
body[5] = put
}
if a.epochs1 < a.epochs2 {
a.sliceTile = a.sliceTile2
a.sliceScrap = a.sliceScrap2
a.coreBytes = a.datCoreBytes21
a.groupBytes = a.datGroupBytes2
a.epochFirst = a.epochs1
a.epochCnt = 1
body[6] = impl()
}
return callee.Func(body)
}

// ptrs emits the declarations of the base pointers that the kernel
// indexes from: one per source data tensor (a.datPtrs), one per
// batch-norm op (a.bnPtrs) and finally the destination (a.arranged).
// Each pointer is initialized from successive elements of tensors plus
// an offset built from the epoch and group coordinates.
//
// The closures stage1..stage5 execute in that order, each one ending by
// calling the next; they are defined in reverse so that every stage can
// refer to its successor.
func (a *arrangeDats) ptrs(tensors cgen.Gen) cgen.Gen {
var (
epoch = a.epochCoord
group = a.groupCoord
datCnt = len(a.From.Pitch1Bytes)
bnCnt = 0
datExprs []cgen.Gen
bnExpr cgen.Gen
arExpr cgen.Gen
)
// stage5 writes out the declarations, consuming the elements of
// tensors in order: the first data tensor, then whatever each op
// needs, then the destination.
stage5 := func() cgen.Gen {
var (
stmtCnt = datCnt + bnCnt + 1
stmts = make(cgen.Stmts, stmtCnt)
stmtIdx = 0
tensorIdx = 0
datIdx = 0
bnIdx = 0
)
stmt := func(s cgen.Gen) {
stmts[stmtIdx] = s
stmtIdx++
}
// tensor yields the next unused element of the tensors array.
tensor := func() cgen.Gen {
i := tensorIdx
tensorIdx++
return cgen.Elem{
Arr: tensors,
Idx: il(i),
}
}
// dp declares the next data pointer.
dp := func() {
i := datIdx
datIdx++
stmt(cgen.Var{
Type: cgen.RestrictPtrChar,
What: a.datPtrs[i],
Init: cgen.Add{
Expr1: tensor(),
Expr2: datExprs[i],
},
})
}
// ndp declares the next n data pointers.
ndp := func(n int) {
for ; n > 0; n-- {
dp()
}
}
// bp declares the next batch-norm pointer.
bp := func() {
i := bnIdx
bnIdx++
stmt(cgen.Var{
Type: cgen.RestrictPtrChar,
What: a.bnPtrs[i],
Init: &bn.Offset{
Ctx: a.bc,
Mas: tensor(),
Channel: bnExpr,
},
})
}
dp()
for i := range a.From.Ops {
op := &a.From.Ops[i]
switch op.Kind {
case mod.Add:
// An Add op consumes op.Int additional data tensors.
ndp(op.Int)
case mod.Bn:
bp()
case mod.ReLU:
// ReLU consumes no tensor.
default:
panic("bug")
}
}
stmt(cgen.Var{
Type: cgen.RestrictPtrChar,
What: a.arranged,
Init: cgen.Add{
Expr1: tensor(),
Expr2: arExpr,
},
})
return stmts
}
// stage4 names the destination pointer and builds its offset from
// the epoch and group terms.
stage4 := func() cgen.Gen {
a.arranged = vb(a.name("arranged"))
arExpr = cgen.Add{
Expr1: cgen.Mul{
Expr1: cast(a.datEpochBytes1),
Expr2: epoch,
},
Expr2: cgen.Mul{
Expr1: cast(a.groupBytes),
Expr2: group,
},
}
return stage5()
}
// stage3 names the batch-norm pointers and builds the channel
// expression shared by all of them.
stage3 := func() cgen.Gen {
a.bnPtrs = make([]cgen.Gen, bnCnt)
for i := range a.bnPtrs {
a.bnPtrs[i] = vb(a.name("bnPtr"))
}
bnExpr = cgen.Paren{
Inner: cgen.Add{
Expr1: cgen.Mul{
Expr1: cast(a.slices1),
Expr2: epoch,
},
Expr2: cgen.Mul{
Expr1: cast(a.fromChans),
Expr2: group,
},
},
}
return stage4()
}
// stage2 names the data pointers and builds one byte offset per
// data tensor. The constant term is negative: it backs the pointer
// up by the padding rows and columns.
// NOTE(review): addMul is defined elsewhere; it presumably yields
// first + second*third. Confirm against its definition.
stage2 := func() cgen.Gen {
a.datPtrs = make([]cgen.Gen, datCnt)
for i := range a.datPtrs {
a.datPtrs[i] = vb(a.name("datPtr"))
}
datExprs = make([]cgen.Gen, datCnt)
for i := range datExprs {
var (
pitch1 = a.From.Pitch1Bytes[i]
pitch2 = a.From.Pitch2Bytes[i]
padH = a.PaddingH * pitch1
padW = a.PaddingW * a.datBytes
expr = cast(-padH + -padW)
epochPitch = cast(a.slices1 * pitch2)
groupPitch = cast(a.fromChans * pitch2)
)
expr = addMul(expr, epochPitch, epoch)
expr = addMul(expr, groupPitch, group)
datExprs[i] = expr
}
return stage3()
}
// stage1 normalizes the epoch and group expressions and counts the
// batch-norm ops.
stage1 := func() cgen.Gen {
if a.epochCnt == 1 {
// A single epoch is known at generation time, so use a literal
// instead of the coordinate variable.
epoch = il(a.epochFirst)
}
if group == nil {
// No group coordinate was declared: the task starts at group 0.
group = il(0)
} else {
// Scale the group coordinate to the first group of its tile.
group = cgen.Mul{
Expr1: il(a.groupTile),
Expr2: group,
}
}
for i := range a.From.Ops {
if a.From.Ops[i].Kind == mod.Bn {
bnCnt++
}
}
return stage2()
}
return stage1()
}

// kernel emits the loop nest that surrounds the innermost body. The
// layers run outermost first:
//
//	layer1  loop over the groups of this task's group tile
//	layer2  set up the core index and, when a core coordinate
//	        exists, the last-core bound and the "next" label
//	layer3  choose the special or general core iteration
//	layer4  iterate over cores (layer4Special or layer4General)
//	layer5  loop over the slices of this task's slice tile
//	layer6  the innermost body, from a.special() or a.general()
//
// The closures are defined innermost first so that each one can call
// the layer below it.
func (a *arrangeDats) kernel() cgen.Gen {
var (
// gotoNext is the generated "if the core index reached the last
// core, jump to the next label" statement. It stays nil when there
// is no core coordinate.
gotoNext cgen.Gen
)
layer6 := func() cgen.Gen {
if a.toks == nil {
return a.special()
}
return a.general()
}
layer5 := func() cgen.Gen {
a.sliceIdx = vb(a.name("k"))
var (
stmts = make(cgen.Stmts, 3)
first cgen.Gen
iters cgen.Gen
past = vb(a.name("kk"))
)
// The slice index starts at the beginning of this task's tile.
if a.sliceCoord == nil {
first = il(0)
} else {
first = cgen.Mul{
Expr1: il(a.sliceTile),
Expr2: a.sliceCoord,
}
}
stmts[0] = cgen.Var{
Type: cgen.PtrdiffT,
What: a.sliceIdx,
Init: first,
}
// Pick the iteration count. It is a constant unless full tiles and
// a differently sized scrap tile both exist, in which case the
// coordinate is tested in the generated code.
switch {
case a.sliceTiles == a.sliceHull:
iters = il(a.sliceTile)
case a.sliceTiles == 0:
fallthrough
case a.sliceTile == a.sliceScrap:
iters = il(a.sliceScrap)
default:
iters = cgen.Paren{
Inner: cgen.Ternary{
Cond: cgen.CmpL{
Expr1: a.sliceCoord,
Expr2: il(a.sliceTiles),
},
Then: il(a.sliceTile),
Else: il(a.sliceScrap),
},
}
}
stmts[1] = cgen.Var{
Type: cgen.PtrdiffT, What: past,
Init: cgen.Add{
Expr1: a.sliceIdx,
Expr2: iters,
},
}
stmts[2] = cgen.For{
Cond: cgen.CmpL{
Expr1: a.sliceIdx,
Expr2: past,
},
Post: cgen.IncPre{
Expr: a.sliceIdx,
},
Body: layer6(),
}
return stmts
}
// layer4Special iterates over the first a.datCores1 cores with
// a.short cleared, then emits one more body with a.short set when
// a.datCores1 < a.datCores2.
layer4Special := func() cgen.Gen {
stmts := make(cgen.Stmts, 2)
if a.datCores1 > 0 {
a.short = false
stmts[0] = cgen.For{
Cond: cgen.CmpNE{
Expr1: a.coreIdx,
Expr2: il(a.datCores1),
},
Post: cgen.IncPre{
Expr: a.coreIdx,
},
Body: cgen.Stmts{
layer5(),
gotoNext,
},
}
}
if a.datCores1 < a.datCores2 {
a.short = true
stmts[1] = layer5()
}
return stmts
}
// layer4General dispatches on the core index to the section that
// contains it and, inside the section, to the matching token.
layer4General := func() cgen.Gen {
// leaf emits the code for one section: an optional declaration of
// the section row variable, a switch over the section's unique
// tokens, and an assignment that moves the core index to the first
// index past the section.
leaf := func(sect *section) cgen.Stmts {
var (
decl cgen.Gen
which cgen.Gen
n = len(sect.Uniqs)
cases = make(cgen.Stmts, n)
)
// which is the core index relative to the start of the section.
which = cgen.Sub{
Expr1: cgen.Cast{
Type: cgen.SizeT,
Expr: a.coreIdx,
},
Expr2: il(sect.IdxFirst),
}
if sect.FromWrap == 0 {
// No wrap: the row base is constant and each case pins the core
// index to its own position in the section.
a.sectH = cast(sect.FromBase)
for x, tok := range sect.Uniqs {
a.tok = tok
var (
expr cgen.Gen
body = make(cgen.Stmts, 3)
)
// The last case keeps a nil expression.
// NOTE(review): presumably cgen renders that as the default
// case. Confirm in cgen.Case.
if x < n-1 {
expr = il(x)
}
body[0] = cgen.Assign{
Expr1: a.coreIdx,
Expr2: il(sect.IdxFirst + x),
}
body[1] = layer5()
body[2] = gotoNext
cases[x] = cgen.Case{
Expr: expr,
Body: body,
}
}
} else {
// Wrap: the tokens repeat every n cores, and the row variable
// advances by sect.FromWrap on each repetition.
which = cgen.Paren{
Inner: which,
}
a.sectH = vb(a.name("h"))
decl = cgen.Var{
Type: cgen.PtrdiffT, What: a.sectH,
Init: addMul(
il(sect.FromBase),
cgen.Quo{
Expr1: which,
Expr2: il(n),
},
il(sect.FromWrap),
),
}
which = cgen.Rem{
Expr1: which,
Expr2: il(n),
}
var (
wrap = cgen.Label(a.name("wrap"))
last = sect.IdxPast - 1
// at is the position within the repetition of the section's
// final core; only that case needs the end-of-section check.
at = (last - sect.IdxFirst) % n
)
for x, tok := range sect.Uniqs {
a.tok = tok
var (
expr cgen.Gen
body = make(cgen.Stmts, 7)
)
if x < n-1 {
expr = il(x)
}
if x == 0 {
body[0] = wrap
}
body[1] = layer5()
body[2] = gotoNext
if x == at {
body[3] = cgen.If1{
Cond: cgen.CmpGE{
Expr1: a.coreIdx,
Expr2: il(last),
},
Then: cgen.Break,
}
}
body[4] = cgen.IncPre{
Expr: a.coreIdx,
}
if x == n-1 {
// After the last token, advance the row and jump back to the
// first case.
body[5] = cgen.AddAssign{
Expr1: a.sectH,
Expr2: il(sect.FromWrap),
}
body[6] = cgen.Goto(wrap)
}
cases[x] = cgen.Case{
Expr: expr,
Body: body,
}
}
}
return cgen.Stmts{
decl,
cgen.Switch{
Expr: which,
Cases: cases,
},
cgen.Assign{
Expr1: a.coreIdx,
Expr2: il(sect.IdxPast),
},
}
}
var (
sects = a.toks.Sects
tree func(int, int) cgen.Stmts
)
// tree emits a binary decision over sections first..last. It splits
// at the first section (after first) whose IdxPast exceeds the
// midpoint of the covered index range. The lower part is guarded by
// a comparison of the core index against the split section's
// IdxFirst; the upper part follows unguarded.
tree = func(first, last int) cgen.Stmts {
if first == last {
return leaf(sects[first])
}
var (
start = sects[first].IdxFirst
stop = sects[last].IdxPast
split = start + (stop-start)/2
x = first + 1
)
for sects[x].IdxPast <= split {
x++
}
return cgen.Stmts{
cgen.If{
Cond: cgen.CmpL{
Expr1: a.coreIdx,
Expr2: il(sects[x].IdxFirst),
},
Then: tree(first, x-1),
},
tree(x, last),
}
}
return tree(0, len(sects)-1)
}
layer3 := func() cgen.Gen {
if a.toks == nil {
return layer4Special()
}
return layer4General()
}
layer2 := func() cgen.Gen {
a.coreIdx = vb(a.name("j"))
var (
stmts = make(cgen.Stmts, 4)
first cgen.Gen
)
// The core index starts at the beginning of this task's core tile.
if a.coreCoord == nil {
first = il(0)
} else {
first = cgen.Mul{
Expr1: il(a.coreTile),
Expr2: a.coreCoord,
}
}
stmts[0] = cgen.Var{
Type: cgen.PtrdiffT,
What: a.coreIdx,
Init: first,
}
if a.coreCoord != nil {
// With a core coordinate the task covers only part of the cores:
// compute the last core of the tile and prepare the early exit.
var (
last = vb(a.name("jj"))
expr cgen.Gen
)
switch a.coreTiles {
case a.coreHull:
expr = il(a.coreTile - 1)
case 0:
expr = il(a.coreScrap - 1)
default:
expr = cgen.Paren{
Inner: cgen.Ternary{
Cond: cgen.CmpL{
Expr1: a.coreCoord,
Expr2: il(a.coreTiles),
},
Then: il(a.coreTile - 1),
Else: il(a.coreScrap - 1),
},
}
}
stmts[1] = cgen.Var{
Type: cgen.PtrdiffT, What: last,
Init: cgen.Add{
Expr1: a.coreIdx,
Expr2: expr,
},
}
next := cgen.Label(a.name("next"))
gotoNext = cgen.If1{
Cond: cgen.CmpGE{
Expr1: a.coreIdx,
Expr2: last,
},
Then: cgen.Goto(next),
}
// The label sits after the core iteration (slot 2), so the jump
// leaves it.
stmts[3] = next
}
// gotoNext must be assigned before layer3 runs, because the lower
// layers embed it in the code they emit.
stmts[2] = layer3()
return stmts
}
layer1 := func() cgen.Gen {
a.groupIdx = vb(a.name("i"))
var (
past = vb(a.name("ii"))
iters cgen.Gen
)
// Number of groups handled by this task: constant, or chosen in the
// generated code when full tiles and a scrap tile both exist.
switch a.groupTiles {
case a.groupHull:
iters = il(a.groupTile)
case 0:
iters = il(a.groupScrap)
default:
iters = cgen.Ternary{
Cond: cgen.CmpL{
Expr1: a.groupCoord,
Expr2: il(a.groupTiles),
},
Then: il(a.groupTile),
Else: il(a.groupScrap),
}
}
return cgen.Stmts{
cgen.Var{
Type: cgen.PtrdiffT,
What: past, Init: iters,
},
cgen.For{
Init: cgen.Var{
Type: cgen.PtrdiffT,
What: a.groupIdx,
Init: il(0),
},
Cond: cgen.CmpL{
Expr1: a.groupIdx,
Expr2: past,
},
Post: cgen.IncPre{
Expr: a.groupIdx,
},
Body: layer2(),
},
}
}
return layer1()
}

// special returns the platform-specific body used when no token table
// is present. AVX512Float32 is the only platform handled; any other
// value is a programming error.
func (a *arrangeDats) special() cgen.Gen {
	if a.platform != raw.AVX512Float32 {
		panic("bug")
	}
	return a.m512Special()
}

// general returns the platform-specific body used when a token table
// is present. AVX512Float32 is the only platform handled; any other
// value is a programming error.
func (a *arrangeDats) general() cgen.Gen {
	if a.platform != raw.AVX512Float32 {
		panic("bug")
	}
	return a.m512General()
}

// m512Special emits the AVX-512 body for the case without a token
// table. For every slot of the current slice it loads the slot from the
// first data tensor with a lane mask, applies the ops in a.From.Ops,
// and stores the result into the arranged buffer with the same mask.
// The per-slot statement lists are combined by mix.
// NOTE(review): mix, loMask and addMul are defined elsewhere. addMul is
// presumably first + second*third, and loMask(n) presumably selects the
// low n lanes. Confirm against their definitions.
func (a *arrangeDats) m512Special() cgen.Gen {
var (
// bnMuls and bnAdds cache the batch-norm variables so that they are
// loaded once and reused by later slots.
bnMuls []cgen.Gen
bnAdds []cgen.Gen
slotIdx int
mask cgen.Gen
stmts cgen.Stmts
slot cgen.Gen
)
stmt := func(s cgen.Gen) {
stmts = append(stmts, s)
}
// datLoad emits a masked load of the current slot from data tensor x
// and returns the variable that holds it.
datLoad := func(x int) cgen.Gen {
var (
dat = vb(a.name("dat"))
ae = a.datPtrs[x]
pitch2 = a.From.Pitch2Bytes[x]
groupPitch = il(a.fromChans * pitch2)
corePitch = il(a.datSliceBytes1)
slicePitch = il(pitch2)
)
ae = addMul(ae, groupPitch, a.groupIdx)
ae = addMul(ae, corePitch, a.coreIdx)
ae = addMul(ae, slicePitch, a.sliceIdx)
ae = cgen.Add{
Expr1: ae,
Expr2: cast(slotIdx * a.slotBytes),
}
stmt(cgen.Var{
Type: avx.M512, What: dat,
Init: avx.Mm512MaskzLoaduPs{
mask, ae,
},
})
return dat
}
// inner3 applies every op to the current slot, in order.
inner3 := func() {
var (
// Data tensor 0 is already in slot, so Add ops start at tensor 1.
datIdx = 1
bnIdx = 0
)
for op := range a.From.Ops {
op := &a.From.Ops[op]
switch op.Kind {
case mod.Add:
var (
n = 1 + op.Int
ds = make([]cgen.Gen, n)
)
ds[0] = slot
for x := 1; x < n; x++ {
ds[x] = datLoad(datIdx)
datIdx++
}
// Sum pairwise: each pass adds the upper entries onto the lower
// ones until only ds[0] (the slot) remains.
for n > 1 {
fold := n >> 1
n -= fold
for x := 0; x < fold; x++ {
keep := ds[x]
stmt(cgen.Assign{
Expr1: keep,
Expr2: avx.Mm512AddPs{
keep, ds[n+x],
},
})
}
}
case mod.Bn:
x := bnIdx
bnIdx++
var (
bnMul = bnMuls[x]
bnAdd = bnAdds[x]
)
if bnMul == nil {
// First use: emit the load and remember the variables.
bnMul = vb(a.name("bnMul"))
bnAdd = vb(a.name("bnAdd"))
bnMuls[x] = bnMul
bnAdds[x] = bnAdd
stmt(&bn.Load{
Ctx: a.bc,
Mas: a.bnPtrs[x],
Channel: cgen.Paren{
Inner: addMul(
a.sliceIdx,
il(a.fromChans),
a.groupIdx,
),
},
Mul: bnMul,
Add: bnAdd,
})
} else {
// Already loaded: append a nil placeholder so that this slot
// has the same number of statements as the one that loaded.
stmt(nil)
}
stmt(&bn.Apply{
Ctx: a.bc,
Mul: bnMul,
Add: bnAdd,
To: slot,
})
case mod.ReLU:
stmt(&act.ReLU{
Ctx: a.ac,
NegSlope: op.Float,
Var: slot,
})
default:
panic("bug")
}
}
}
// inner2 applies the ops and then emits the masked store of the slot
// into the arranged buffer.
inner2 := func() {
inner3()
var (
ae = a.arranged
groupPitch = il(a.groupBytes)
corePitch = il(a.coreBytes)
slicePitch = il(a.datSliceBytes1)
)
if a.short {
slicePitch = il(a.datSliceBytes2)
}
ae = addMul(ae, groupPitch, a.groupIdx)
ae = addMul(ae, corePitch, a.coreIdx)
ae = addMul(ae, slicePitch, a.sliceIdx)
ae = cgen.Add{
Expr1: ae,
Expr2: cast(slotIdx * a.slotBytes),
}
stmt(avx.Mm512MaskStoreuPs{
ae, mask, slot,
})
}
// inner1 produces the complete statement list for one slot.
inner1 := func() cgen.Stmts {
stmts = nil
slot = datLoad(0)
inner2()
return stmts
}
// outer2 visits every slot of the slice, using the short-core counts
// when a.short is set.
outer2 := func() cgen.Gen {
var (
ns = a.datSliceSlots1
nd = a.datSliceDats1
)
if a.short {
ns = a.datSliceSlots2
nd = a.datSliceDats2
}
toMix := make([]cgen.Stmts, ns)
for x := range toMix {
// The mask is set at slot 0 (all a.slotDats lanes) and again at
// the last slot (only the remaining lanes). Slots in between
// reuse the mask from slot 0. When ns == 1 the last-slot case is
// matched first.
switch slotIdx = x; x {
case ns - 1:
rem := nd - x*a.slotDats
mask = loMask(rem)
case 0:
mask = loMask(a.slotDats)
}
toMix[x] = inner1()
}
return mix(toMix)
}
// outer1 resets the batch-norm caches before the slots are generated.
outer1 := func() cgen.Gen {
n := len(a.bnPtrs)
bnMuls = make([]cgen.Gen, n)
bnAdds = make([]cgen.Gen, n)
return outer2()
}
return outer1()
}

// m512General emits the AVX-512 body for the current token (a.tok). It
// interprets the token's command list (a.tok.From.Cmds), turning each
// command into statements that operate on numbered vector variables,
// and finally stores the token's slots into the arranged buffer.
// NOTE(review): addMul and loMask are defined elsewhere. addMul is
// presumably first + second*third, and loMask(n) presumably selects the
// low n lanes. Confirm against their definitions.
func (a *arrangeDats) m512General() cgen.Gen {
const lanes = 16
var (
stmts cgen.Stmts
// bnMuls and bnAdds cache the loaded batch-norm variables per op.
bnMuls []cgen.Gen
bnAdds []cgen.Gen
// opSplit is the index just past the last Add op; bnSplit is the
// number of Bn ops before that point (both computed in stage1).
bnSplit = 0
opSplit = 0
slots []cgen.Gen
// dats maps a command's numeric id to its generated variable.
dats []cgen.Gen
)
stmt := func(s cgen.Gen) {
stmts = append(stmts, s)
}
// eval stores expr into the variable with the given id, declaring the
// variable on first use and assigning to it afterwards.
eval := func(id int, expr cgen.Gen) {
for id >= len(dats) {
dats = append(dats, nil)
}
dat := dats[id]
if dat == nil {
dat = vb(a.name("dat"))
dats[id] = dat
stmt(cgen.Var{
Type: avx.M512, What: dat,
Init: expr,
})
} else {
stmt(cgen.Assign{
Expr1: dat,
Expr2: expr,
})
}
}
// rot rotates the source by cmd.Cnt 32-bit elements: cast to an
// integer vector, align it with itself, and cast back.
rot := func(cmd *m512CmdRotate) {
var (
dst = cmd.DstId
src = dats[cmd.SrcId]
cnt = il(cmd.Cnt)
via = vb(a.name("via"))
)
stmt(cgen.Var{
Type: avx.M512i, What: via,
Init: avx.Mm512CastpsSi512{src},
})
stmt(cgen.Assign{
Expr1: via,
Expr2: avx.Mm512AlignrEpi32{
via, via, cnt,
},
})
eval(dst, avx.Mm512Castsi512Ps{via})
}
// blend copies cmd.Cnt lanes starting at lane cmd.Off from the source
// into the destination, using a mask move.
blend := func(cmd *m512CmdBlend) {
var (
dst = dats[cmd.DstId]
src = dats[cmd.SrcId]
mask1 = 1<<uint(cmd.Cnt) - 1
mask2 = mask1 << uint(cmd.Off)
)
stmt(cgen.Assign{
Expr1: dst,
Expr2: avx.Mm512MaskMovPs{
dst, il(mask2), src,
},
})
}
// ctrl declares a permutation control vector. The index is 0 for
// lanes 0..off and then grows by inc per lane. When it steps from
// below lanes to above lanes it is set to exactly lanes. Indices of 0
// or of at least 2*lanes are emitted as cgen.Zero. Entries are stored
// in reverse lane order.
// NOTE(review): presumably reversed because the set intrinsic takes
// the highest lane first. Confirm in avx.Mm512SetEpi32.
ctrl := func(off, inc int) cgen.Gen {
var (
pm = vb(a.name("pm"))
set = make(avx.Mm512SetEpi32, lanes)
x = 0
)
for lane := 0; lane < lanes; lane++ {
if lane > off {
was := x
x += inc
if was < lanes && x > lanes {
x = lanes
}
}
var entry cgen.Gen
if x == 0 || x >= 2*lanes {
entry = cgen.Zero
} else {
entry = il(x)
}
set[lanes-1-lane] = entry
}
stmt(cgen.Var{
Type: avx.M512i, What: pm,
Init: set,
})
return pm
}
// perm1 permutes a single source vector.
perm1 := func(cmd *m512CmdPermute1) {
var (
dst = cmd.DstId
src = dats[cmd.SrcId]
pm = ctrl(cmd.Off, cmd.Inc)
)
eval(dst, avx.Mm512PermutexvarPs{
pm, src,
})
}
// perm2 permutes across two source vectors.
perm2 := func(cmd *m512CmdPermute2) {
var (
dst = cmd.DstId
src1 = dats[cmd.SrcId1]
src2 = dats[cmd.SrcId2]
pm = ctrl(cmd.Off, cmd.Inc)
)
eval(dst, avx.Mm512Permutex2varPs{
src1, pm, src2,
})
}
// datLoad builds the load expression for data tensor x at row
// a.tok.From.FirstH+relH (relative to the section row a.sectH) and
// column w. A cnt of 0 means "broadcast one float to every lane";
// otherwise cnt lanes are loaded under a mask.
datLoad := func(x, relH, w, cnt int) cgen.Gen {
var (
ae = a.datPtrs[x]
pitch1 = a.From.Pitch1Bytes[x]
pitch2 = a.From.Pitch2Bytes[x]
groupPitch = il(a.fromChans * pitch2)
h = a.tok.From.FirstH + relH
)
ae = addMul(ae, groupPitch, a.groupIdx)
ae = addMul(ae, il(pitch1), a.sectH)
ae = addMul(ae, il(pitch2), a.sliceIdx)
ae = cgen.Add{
Expr1: ae,
Expr2: cast(h*pitch1 + w*a.datBytes),
}
if cnt == 0 {
return avx.Mm512Set1Ps{
cgen.At{Expr: cgen.Cast{
Type: cgen.PtrFloat,
Expr: cgen.Paren{Inner: ae},
}},
}
}
return avx.Mm512MaskzLoaduPs{
loMask(cnt), ae,
}
}
// load reads from the first data tensor into variable cmd.Id.
load := func(cmd *m512CmdLoad) {
var (
dst = cmd.Id
relH = cmd.RelH
w = cmd.W
cnt = cmd.Cnt
)
eval(dst, datLoad(0, relH, w, cnt))
}
// bnLoad emits the load of batch-norm op x unless it was already
// emitted for this token.
bnLoad := func(x int) {
if bnMuls[x] != nil {
return
}
var (
bnMul = vb(a.name("bnMul"))
bnAdd = vb(a.name("bnAdd"))
)
stmt(&bn.Load{
Ctx: a.bc,
Mas: a.bnPtrs[x],
Channel: cgen.Paren{
Inner: addMul(
a.sliceIdx,
il(a.fromChans),
a.groupIdx,
),
},
Mul: bnMul,
Add: bnAdd,
})
bnMuls[x] = bnMul
bnAdds[x] = bnAdd
}
// modAddPre applies the ops up to and including the last Add op
// (indices below opSplit) to variable cmd.Id. Add ops load their extra
// operands from the same position as the original load.
modAddPre := func(cmd *m512CmdFromModAddPre) {
var (
dst = dats[cmd.Id]
relH = cmd.RelH
w = cmd.W
cnt = cmd.Cnt
// Data tensor 0 is already loaded, so Add ops start at tensor 1.
datIdx = 1
bnIdx = 0
)
for op := 0; op < opSplit; op++ {
op := &a.From.Ops[op]
switch op.Kind {
case mod.Add:
var (
n = 1 + op.Int
ds = make([]cgen.Gen, n)
)
ds[0] = dst
for x := 1; x < n; x++ {
dat := vb(a.name("dat"))
ds[x] = dat
stmt(cgen.Var{
Type: avx.M512, What: dat,
Init: datLoad(
datIdx, relH, w, cnt,
),
})
datIdx++
}
// Sum pairwise: each pass adds the upper entries onto the lower
// ones until only ds[0] (dst) remains.
for n > 1 {
fold := n >> 1
n -= fold
for x := 0; x < fold; x++ {
keep := ds[x]
stmt(cgen.Assign{
Expr1: keep,
Expr2: avx.Mm512AddPs{
keep, ds[n+x],
},
})
}
}
case mod.Bn:
// A broadcast load (cnt == 0) fills every lane, so the mask
// covers all of them.
n := cnt
if n == 0 {
n = lanes
}
bnLoad(bnIdx)
stmt(&bn.Apply{
Ctx: a.bc,
Mul: bnMuls[bnIdx],
Add: bnAdds[bnIdx],
To: dst,
Mask: loMask(n),
})
bnIdx++
case mod.ReLU:
stmt(&act.ReLU{
Ctx: a.ac,
NegSlope: op.Float,
Var: dst,
})
default:
panic("bug")
}
}
}
// modPostAdd applies the ops after the last Add op (from opSplit on)
// to variable cmd.Id. An Add op in this range is a bug.
modPostAdd := func(cmd *m512CmdFromModPostAdd) {
var (
dst = dats[cmd.Id]
mask = il(cmd.Mask)
bnIdx = bnSplit
ops = a.From.Ops[opSplit:]
)
for op := range ops {
op := &ops[op]
switch op.Kind {
case mod.Bn:
bnLoad(bnIdx)
stmt(&bn.Apply{
Ctx: a.bc,
Mul: bnMuls[bnIdx],
Add: bnAdds[bnIdx],
To: dst,
Mask: mask,
})
bnIdx++
case mod.ReLU:
stmt(&act.ReLU{
Ctx: a.ac,
NegSlope: op.Float,
Var: dst,
})
default:
panic("bug")
}
}
}
// stage3 dispatches every command of the token to its handler.
stage3 := func() {
for _, cmd := range a.tok.From.Cmds {
switch cmd := cmd.(type) {
case *m512CmdZero:
eval(cmd.Id, avx.Mm512SetzeroPs)
case *m512CmdRotate:
rot(cmd)
case *m512CmdBlend:
blend(cmd)
case *m512CmdPermute1:
perm1(cmd)
case *m512CmdPermute2:
perm2(cmd)
case *m512CmdLoad:
load(cmd)
case *m512CmdFromModAddPre:
modAddPre(cmd)
case *m512CmdFromModPostAdd:
modPostAdd(cmd)
case *m512CmdSlotPut:
// Record which variable ends up in which output slot.
slots[cmd.Slot] = dats[cmd.Id]
default:
panic("bug")
}
}
}
// stage2 runs the commands and then stores each slot, unmasked, at
// consecutive a.slotBytes offsets in the arranged buffer.
stage2 := func() {
n := a.tok.Slots
slots = make([]cgen.Gen, n)
stage3()
ae := a.arranged
ae = addMul(ae, il(a.groupBytes), a.groupIdx)
ae = addMul(ae, il(a.coreBytes), a.coreIdx)
ae = addMul(ae, il(n*a.slotBytes), a.sliceIdx)
for x, slot := range slots {
stmt(avx.Mm512StoreuPs{
cgen.Add{
Expr1: ae,
Expr2: cast(x * a.slotBytes),
},
slot,
})
}
}
// stage1 resets the batch-norm caches and locates the last Add op,
// which separates the "pre" ops from the "post" ops.
stage1 := func() {
bnCnt := len(a.bnPtrs)
bnMuls = make([]cgen.Gen, bnCnt)
bnAdds = make([]cgen.Gen, bnCnt)
bnPre := 0
for x := range a.From.Ops {
switch a.From.Ops[x].Kind {
case mod.Add:
bnSplit = bnPre
opSplit = x + 1
case mod.Bn:
bnPre++
}
}
stage2()
}
stage1()
return stmts
}

// Apply is one call site of a generated apply function. Prep emits the
// function definition (once per distinct Spec) and Append emits the
// call.
type Apply struct {
*Ctx
*Spec
// Team is passed as the first argument of the generated call.
Team cgen.Gen
// Tensors become the initializer of the pointer array that is passed
// as the second argument of the generated call.
Tensors []cgen.Gen
// callerName is the name of the generated function; Prep sets it.
callerName string
}

// Prep decides which generated function this call site will use and
// stores its name in a.callerName. The printed form of the Spec is the
// dedup key. When a function already exists for that key, its name is
// reused and nil is returned. Otherwise a new name is registered and
// the generator for the function definition is returned, followed by a
// newline.
func (a *Apply) Prep() cgen.Gen {
	const affix = "Apply"
	key := affix + " " + fmt.Sprint(a.Spec)
	prior, seen := a.dedup[key]
	if !seen {
		fresh := a.name(a.prefix + affix)
		a.callerName = fresh
		a.dedup[key] = fresh
		return cgen.Gens{
			&apply{Apply: a},
			cgen.Newline,
		}
	}
	a.callerName = prior.(string)
	return nil
}

// Append writes the call site to the output: a declaration of a pointer
// array initialized from a.Tensors, followed by a call to the generated
// function with a.Team and that array as arguments.
func (a *Apply) Append(to []byte) []byte {
	arr := vb(a.name("tensors"))
	decl := cgen.Var{
		Type: cgen.PtrChar,
		What: cgen.Elem{Arr: arr},
		Init: cgen.Brace{
			Inner: cgen.CommaLines(a.Tensors),
		},
	}
	call := cgen.Call{
		Func: vb(a.callerName),
		Args: cgen.CommaSpaced{a.Team, arr},
	}
	return cgen.Stmts{decl, call}.Append(to)
}

// apply generates the definition of one apply function. Its fields are
// scratch state shared between the generator methods. Most are
// overwritten each time a callee variant is generated.
type apply struct {
*Apply
*layout
// Per-epoch sizes, selected in Append from the layout's "1" or "2"
// variants depending on the epoch.
slices int
wtCoreBytes int
wtGroupBytes int
datCoreBytes int
datGroupBytes int
// Whether the callee being generated handles epoch 0 or the final
// epoch.
epochFirst bool
epochLast bool
// Task tiling along the weight-core, data-core and group axes: tile
// size, number of full tiles, size of the trailing scrap tile (0 if
// none) and number of tasks (hull).
wtTile int
wtTiles int
wtScrap int
wtHull int
datTile int
datTiles int
datScrap int
datHull int
groupTile int
groupTiles int
groupScrap int
groupHull int
calleeName string
// Coordinate variables declared by calleeFunc.
epochCoord cgen.Gen
groupCoord cgen.Gen
datCoord cgen.Gen
wtCoord cgen.Gen
// Base pointer variables and related state, set by ptrs.
arrangedWts cgen.Gen
arrangedDats cgen.Gen
// datSplit is the total operand count of the Add ops in To.Ops; it
// indexes the data pointer that is set up for every epoch.
datSplit int
datPtrs []cgen.Gen
bnPtrs []cgen.Gen
// Loop variables and flags set by kernel.
groupIdx cgen.Gen
datCore cgen.Gen
datShort bool
sectH cgen.Gen
tok *token
wtCore cgen.Gen
wtShort bool
// NOTE(review): rows, cols and sums are read by m512Special but are
// assigned outside the visible code (presumably by m512Dot). Confirm.
rows int
cols int
sums []cgen.Gen
}

// Append writes the apply function definition to the output: up to
// three callee variants (first epoch, middle epochs, last epoch)
// followed by the caller that runs them through the threader.
//
// The closures callee and do communicate through receiver fields
// (a.epochFirst, a.calleeName and the hulls), so each do call must come
// after the callee call whose variant it invokes.
func (a *apply) Append(to []byte) []byte {
a.layout = newLayout(ctxSpec{
Ctx: a.Ctx,
Spec: a.Spec,
})
// callee configures the receiver for the given epoch, chooses the task
// tiling, and returns the generated callee function followed by a
// newline.
callee := func(epoch int) cgen.Gen {
if epoch < a.epochs1 {
a.slices = a.slices1
a.wtCoreBytes = a.wtCoreBytes11
a.wtGroupBytes = a.wtGroupBytes1
a.datCoreBytes = a.datCoreBytes11
a.datGroupBytes = a.datGroupBytes1
} else {
a.slices = a.slices2
a.wtCoreBytes = a.wtCoreBytes21
a.wtGroupBytes = a.wtGroupBytes2
a.datCoreBytes = a.datCoreBytes21
a.datGroupBytes = a.datGroupBytes2
}
a.epochFirst = epoch == 0
a.epochLast = epoch == a.epochs2-1
// Default tiling: one task per weight core, data core and group.
a.wtTile = 1
a.wtTiles = a.wtCores2
a.wtScrap = 0
a.wtHull = a.wtCores2
a.datTile = 1
a.datTiles = a.datCores2
a.datScrap = 0
a.datHull = a.datCores2
a.groupTile = 1
a.groupTiles = a.Groups
a.groupScrap = 0
a.groupHull = a.Groups
a.calleeName = a.name(
a.callerName + "Callee",
)
// Work estimates at each nesting level, compared with the minimum
// work wanted per task (threadWork).
var (
wtWork = a.slices
datWork = a.wtCores2 * wtWork
groupWork = a.datCores2 * datWork
threadWork int
)
switch a.platform {
case raw.AVX512Float32:
threadWork = 512
default:
panic("bug")
}
// Coarsen the tiling at the innermost level whose total work reaches
// threadWork. Outer levels keep one task per unit, and inner levels
// collapse to a single tile. When full tiles do not divide the axis
// evenly, the remainder is merged into one larger scrap tile that
// replaces the last full tile.
switch {
case threadWork <= wtWork:
// One weight core is already enough work: keep the defaults.
case threadWork <= datWork:
var (
tile = ceilQuo(threadWork, wtWork)
tiles = max(a.wtCores2/tile, 1)
)
a.wtTile = a.wtCores2 / tiles
a.wtTiles = tiles
a.wtScrap = a.wtCores2 - tiles*a.wtTile
a.wtHull = tiles
if a.wtScrap > 0 {
a.wtTiles--
a.wtScrap += a.wtTile
}
case threadWork <= groupWork:
a.wtTile = a.wtCores2
a.wtTiles = 1
a.wtScrap = 0
a.wtHull = 1
var (
tile = ceilQuo(threadWork, datWork)
tiles = max(a.datCores2/tile, 1)
)
a.datTile = a.datCores2 / tiles
a.datTiles = tiles
a.datScrap = a.datCores2 - tiles*a.datTile
a.datHull = tiles
if a.datScrap > 0 {
a.datTiles--
a.datScrap += a.datTile
}
default:
a.wtTile = a.wtCores2
a.wtTiles = 1
a.wtScrap = 0
a.wtHull = 1
a.datTile = a.datCores2
a.datTiles = 1
a.datScrap = 0
a.datHull = 1
var (
tile = ceilQuo(threadWork, groupWork)
tiles = max(a.Groups/tile, 1)
)
a.groupTile = a.Groups / tiles
a.groupTiles = tiles
a.groupScrap = a.Groups - tiles*a.groupTile
a.groupHull = tiles
if a.groupScrap > 0 {
a.groupTiles--
a.groupScrap += a.groupTile
}
}
return cgen.Gens{
a.calleeFunc(),
cgen.Newline,
}
}
var (
team = vb(a.name("team"))
tensors = vb(a.name("tensors"))
pair = vb(a.name("pair"))
)
// do emits one threader invocation of the most recently generated
// callee. The callee receives a two-element array: the tensors pointer
// and the epoch. The first-epoch variant declares the array; later
// variants only overwrite its second element.
do := func(epoch cgen.Gen) cgen.Gen {
stmts := make(cgen.Stmts, 2)
if a.epochFirst {
stmts[0] = cgen.Var{
Type: cgen.PtrVoid,
What: cgen.Elem{Arr: pair},
Init: cgen.Brace{
Inner: cgen.CommaSpaced{
tensors, epoch,
},
},
}
} else {
stmts[0] = cgen.Assign{
Expr1: cgen.Elem{
Arr: pair, Idx: il(1),
},
Expr2: cgen.Cast{
Type: cgen.PtrVoid,
Expr: epoch,
},
}
}
stmts[1] = &threader.Do{
Ctx: a.tc,
Callee: vb(a.calleeName),
Any: pair,
Hull: []cgen.Gen{
il(a.wtHull),
il(a.datHull),
il(a.groupHull),
},
Team: team,
}
return stmts
}
var (
prep = make(cgen.Gens, 3)
body = make(cgen.Gens, 3)
)
// Epoch 0 always gets its own variant.
prep[0] = callee(0)
body[0] = do(il(0))
if a.epochs2 > 1 {
var (
start = 1
stop = a.epochs1
)
// When the final epoch would otherwise fall into the middle range
// but has extra work (output ops, or more than one output tensor),
// exclude it from the range so that it gets a dedicated variant.
if stop == a.epochs2 {
if len(a.To.Ops) > 0 ||
len(a.To.Pitch1Bytes) > 1 {
stop--
}
}
if start < stop {
// Middle epochs share one variant, invoked from a generated loop.
prep[1] = callee(start)
epoch := vb(a.name("e"))
body[1] = cgen.Stmts{
cgen.For{
Init: cgen.Var{
Type: cgen.PtrdiffT,
What: epoch,
Init: il(start),
},
Cond: cgen.CmpL{
Expr1: epoch,
Expr2: il(stop),
},
Post: cgen.IncPre{
Expr: epoch,
},
Body: do(epoch),
},
}
}
if stop < a.epochs2 {
prep[2] = callee(stop)
body[2] = do(il(stop))
}
}
return cgen.Gens{
prep,
cgen.StaticFuncDef{
ReturnType: cgen.Void,
Name: a.callerName,
Params: cgen.CommaSpaced{
cgen.Param{
Type: a.tc.PtrTeam,
What: team,
},
cgen.Param{
Type: cgen.PtrPtrChar,
What: tensors,
},
},
Body: body,
},
}.Append(to)
}

// calleeFunc generates the worker function for the callee variant that
// is currently configured in the receiver. The nine body slots are:
// 0 the pair array, 1 the tensor array, 2 the epoch coordinate, 3..5
// the group, data and weight coordinates, 6 a void cast of the task
// point when no coordinate reads it, 7 the pointer setup and 8 the
// kernel.
func (a *apply) calleeFunc() cgen.Gen {
	fn := &threader.Callee{
		Ctx:  a.tc,
		Name: a.calleeName,
		Task: vb(a.name("task")),
		Pt:   vb(a.name("pt")),
	}
	stmts := make(cgen.Stmts, 9)
	pair := vb(a.name("pair"))
	tensors := vb(a.name("tensors"))
	stmts[0] = cgen.Var{
		Type: cgen.PtrPtrVoid,
		What: pair,
		Init: fn.Any(),
	}
	stmts[1] = cgen.Var{
		Type: cgen.PtrPtrChar,
		What: tensors,
		Init: cgen.Elem{Arr: pair, Idx: il(0)},
	}
	// The first and last variants know their epoch at generation time;
	// the middle variant reads it from the second element of the pair.
	var epochInit cgen.Gen
	if a.epochFirst {
		epochInit = il(0)
	} else if a.epochLast {
		epochInit = il(a.epochs2 - 1)
	} else {
		epochInit = cgen.Cast{
			Type: cgen.PtrdiffT,
			Expr: cgen.Elem{Arr: pair, Idx: il(1)},
		}
	}
	a.epochCoord = vb(a.name("e"))
	stmts[2] = cgen.Var{
		Type: cgen.PtrdiffT,
		What: a.epochCoord,
		Init: epochInit,
	}
	// axis always declares the coordinate variable. A hull of exactly 1
	// gets the constant 0; anything else reads element dim of the task
	// point.
	readsPt := false
	axis := func(nm string, hull, dim int) cgen.Gen {
		v := vb(a.name(nm))
		var src cgen.Gen
		if hull != 1 {
			src = cgen.Elem{Arr: fn.Pt, Idx: il(dim)}
			readsPt = true
		} else {
			src = il(0)
		}
		stmts[5-dim] = cgen.Var{
			Type: cgen.PtrdiffT,
			What: v,
			Init: src,
		}
		return v
	}
	a.groupCoord = axis("g", a.groupHull, 2)
	a.datCoord = axis("d", a.datHull, 1)
	a.wtCoord = axis("w", a.wtHull, 0)
	if !readsPt {
		// Reference the otherwise unused point parameter.
		stmts[6] = cgen.Cast{Type: cgen.Void, Expr: fn.Pt}
	}
	stmts[7] = a.ptrs(tensors)
	stmts[8] = a.kernel()
	return fn.Func(stmts)
}

// ptrs emits the declarations of the base pointers used by the kernel:
// the arranged weights, the arranged data, and the output data and
// batch-norm pointers that this callee variant needs. Pointers are
// taken from successive elements of tensors; elements that this variant
// does not use are skipped without a declaration.
//
// The closures stage1..stage6 execute in that order, each one ending by
// calling the next; they are defined in reverse so that every stage can
// refer to its successor.
func (a *apply) ptrs(tensors cgen.Gen) cgen.Gen {
var (
group cgen.Gen
awOff cgen.Gen
adOff cgen.Gen
datOffs []cgen.Gen
bnCh cgen.Gen
)
// stage6 writes out the declarations in tensor order.
stage6 := func() cgen.Gen {
var (
stmts cgen.Stmts
tensorIdx = 0
datIdx = 0
bnIdx = 0
)
stmt := func(s cgen.Gen) {
stmts = append(stmts, s)
}
// tensor yields the next unused element of the tensors array.
tensor := func() cgen.Gen {
i := tensorIdx
tensorIdx++
return cgen.Elem{
Arr: tensors,
Idx: il(i),
}
}
// decl declares what as the next tensor plus a byte offset.
decl := func(what, off cgen.Gen) {
stmt(cgen.Var{
Type: cgen.RestrictPtrChar,
What: what,
Init: cgen.Add{
Expr1: tensor(),
Expr2: off,
},
})
}
decl(a.arrangedWts, awOff)
decl(a.arrangedDats, adOff)
// dps handles the next n data tensors. A nil entry in a.datPtrs
// means the tensor is unused by this variant: only the tensor index
// advances.
dps := func(n int) {
for ; n > 0; n-- {
i := datIdx
datIdx++
datPtr := a.datPtrs[i]
if datPtr == nil {
tensorIdx++
} else {
decl(datPtr, datOffs[i])
}
}
}
// bp handles the next batch-norm tensor, skipping it when unused.
bp := func() {
bnPtr := a.bnPtrs[bnIdx]
bnIdx++
if bnPtr == nil {
tensorIdx++
return
}
stmt(cgen.Var{
Type: cgen.RestrictPtrChar,
What: bnPtr,
Init: &bn.Offset{
Ctx: a.bc,
Mas: tensor(),
Channel: bnCh,
},
})
}
for i := range a.To.Ops {
op := &a.To.Ops[i]
switch op.Kind {
case mod.Add:
dps(op.Int)
case mod.Bn:
bp()
case mod.ReLU:
// ReLU consumes no tensor.
default:
panic("bug")
}
}
// The remaining data tensors follow all the op tensors.
dps(len(a.datPtrs) - datIdx)
return stmts
}
// stage5 sizes a.bnPtrs by the number of Bn ops. The pointers are
// named only for the last-epoch variant; otherwise they stay nil and
// stage6 skips them.
stage5 := func() cgen.Gen {
n := 0
for i := range a.To.Ops {
if a.To.Ops[i].Kind == mod.Bn {
n++
}
}
a.bnPtrs = make([]cgen.Gen, n)
if a.epochLast {
for i := range a.bnPtrs {
a.bnPtrs[i] = vb(a.name("bnPtr"))
}
bnCh = cgen.Mul{
Expr1: il(a.toChans),
Expr2: group,
}
}
return stage6()
}
// stage4 computes a.datSplit (the total operand count of the Add ops)
// and names the data pointers. The last-epoch variant names all of
// them; other variants name only the one at index a.datSplit.
stage4 := func() cgen.Gen {
a.datSplit = 0
for i := range a.To.Ops {
op := &a.To.Ops[i]
if op.Kind == mod.Add {
a.datSplit += op.Int
}
}
n := len(a.To.Pitch1Bytes)
a.datPtrs = make([]cgen.Gen, n)
for i := range a.datPtrs {
if a.epochLast || i == a.datSplit {
a.datPtrs[i] = vb(a.name("datPtr"))
}
}
datOffs = make([]cgen.Gen, n)
for i := range datOffs {
if a.datPtrs[i] == nil {
continue
}
chanPitch := a.To.Pitch2Bytes[i]
datOffs[i] = cgen.Mul{
Expr1: cast(a.toChans * chanPitch),
Expr2: group,
}
}
return stage5()
}
// stage3 names the arranged-data pointer and builds its offset from
// the epoch and group terms.
stage3 := func() cgen.Gen {
a.arrangedDats = vb(a.name("arrangedDats"))
adOff = cgen.Add{
Expr1: cgen.Mul{
Expr1: il(a.datEpochBytes1),
Expr2: a.epochCoord,
},
Expr2: cgen.Mul{
Expr1: cast(a.datGroupBytes),
Expr2: group,
},
}
return stage4()
}
// stage2 names the arranged-weights pointer and builds its offset
// from the epoch and group terms.
stage2 := func() cgen.Gen {
a.arrangedWts = vb(a.name("arrangedWts"))
awOff = cgen.Add{
Expr1: cgen.Mul{
Expr1: il(a.wtEpochBytes1),
Expr2: a.epochCoord,
},
Expr2: cgen.Mul{
Expr1: cast(a.wtGroupBytes),
Expr2: group,
},
}
return stage3()
}
// stage1 scales the group coordinate to the first group of its tile.
stage1 := func() cgen.Gen {
group = cgen.Mul{
Expr1: il(a.groupTile),
Expr2: a.groupCoord,
}
return stage2()
}
return stage1()
}

// kernel emits the loop nest that surrounds the innermost body. The
// layers run outermost first:
//
//	layer1  loop over the groups of this task's group tile
//	layer2  set up the data-core index and its early return
//	layer3  choose the special or general data-core iteration
//	layer4  iterate over data cores (layer4Special or layer4General)
//	layer5  set up the weight-core index and its early return
//	layer6  iterate over weight cores
//	layer7  the innermost body, from a.special() or a.general()
//
// The closures are defined innermost first so that each one can call
// the layer below it.
func (a *apply) kernel() cgen.Gen {
var (
// datRet and wtRet are the generated "return once the last core of
// this task's tile has been processed" statements. Each stays nil
// when the corresponding hull is not greater than 1.
datRet cgen.Gen
wtRet cgen.Gen
)
layer7 := func() cgen.Gen {
if a.toks == nil {
return a.special()
}
return a.general()
}
// layer6 iterates over the first a.wtCores1 weight cores with
// a.wtShort cleared, then emits one more body with a.wtShort set when
// a.wtCores1 < a.wtCores2.
layer6 := func() cgen.Gen {
stmts := make(cgen.Stmts, 2)
if a.wtCores1 > 0 {
a.wtShort = false
stmts[0] = cgen.For{
Cond: cgen.CmpNE{
Expr1: a.wtCore,
Expr2: il(a.wtCores1),
},
Post: cgen.IncPre{
Expr: a.wtCore,
},
Body: cgen.Stmts{
layer7(),
wtRet,
},
}
}
if a.wtCores1 < a.wtCores2 {
a.wtShort = true
stmts[1] = layer7()
}
return stmts
}
layer5 := func() cgen.Gen {
a.wtCore = vb(a.name("k"))
stmts := make(cgen.Stmts, 3)
// The weight-core index starts at the beginning of this task's tile.
stmts[0] = cgen.Var{
Type: cgen.PtrdiffT,
What: a.wtCore,
Init: cgen.Mul{
Expr1: il(a.wtTile),
Expr2: a.wtCoord,
},
}
if a.wtHull > 1 {
// Several tasks share the weight cores: compute the last core of
// this task's tile and prepare the early return.
var (
last = vb(a.name("kk"))
expr cgen.Gen
)
switch a.wtTiles {
case a.wtHull:
expr = il(a.wtTile - 1)
case 0:
expr = il(a.wtScrap - 1)
default:
expr = cgen.Paren{
Inner: cgen.Ternary{
Cond: cgen.CmpL{
Expr1: a.wtCoord,
Expr2: il(a.wtTiles),
},
Then: il(a.wtTile - 1),
Else: il(a.wtScrap - 1),
},
}
}
stmts[1] = cgen.Var{
Type: cgen.PtrdiffT,
What: last,
Init: cgen.Add{
Expr1: a.wtCore,
Expr2: expr,
},
}
wtRet = cgen.If1{
Cond: cgen.CmpGE{
Expr1: a.wtCore,
Expr2: last,
},
Then: cgen.Return{},
}
}
// wtRet must be assigned before layer6 runs, because layer6 embeds it
// in the loop body it emits.
stmts[2] = layer6()
return stmts
}
// layer4Special iterates over the first a.datCores1 data cores with
// a.datShort cleared, then emits one more body with a.datShort set
// when a.datCores1 < a.datCores2.
layer4Special := func() cgen.Gen {
stmts := make(cgen.Stmts, 2)
if a.datCores1 > 0 {
a.datShort = false
stmts[0] = cgen.For{
Cond: cgen.CmpNE{
Expr1: a.datCore,
Expr2: il(a.datCores1),
},
Post: cgen.IncPre{
Expr: a.datCore,
},
Body: cgen.Stmts{
layer5(),
datRet,
},
}
}
if a.datCores1 < a.datCores2 {
a.datShort = true
stmts[1] = layer5()
}
return stmts
}
// layer4General dispatches on the data-core index to the section that
// contains it and, inside the section, to the matching token.
layer4General := func() cgen.Gen {
// leaf emits the code for one section: the declaration of the
// section row variable, a switch over the section's unique tokens,
// and an assignment that moves the data-core index to the first
// index past the section.
leaf := func(sect *section) cgen.Stmts {
a.sectH = vb(a.name("h"))
var (
initH cgen.Gen
which cgen.Gen
n = len(sect.Uniqs)
cases = make(cgen.Stmts, n)
)
if sect.ToWrap == 0 {
// No wrap: the row is constant and the switch is on the absolute
// data-core index.
initH = il(sect.ToBase)
which = a.datCore
for x, tok := range sect.Uniqs {
a.tok = tok
var (
expr = il(sect.IdxFirst + x)
body = make(cgen.Stmts, 3)
)
body[0] = cgen.Assign{
Expr1: a.datCore,
Expr2: expr,
}
// The last case gets a nil expression.
// NOTE(review): presumably cgen renders that as the default
// case. Confirm in cgen.Case.
if x == n-1 {
expr = nil
}
body[1] = layer5()
body[2] = datRet
cases[x] = cgen.Case{
Expr: expr,
Body: body,
}
}
} else {
// Wrap: the tokens repeat every n cores, and the row variable
// advances by sect.ToWrap on each repetition.
var (
numer = cgen.Paren{
Inner: cgen.Sub{
Expr1: cgen.Cast{
Type: cgen.SizeT,
Expr: a.datCore,
},
Expr2: il(sect.IdxFirst),
},
}
denom = il(n)
)
initH = addMul(
il(sect.ToBase),
cgen.Quo{
Expr1: numer,
Expr2: denom,
},
il(sect.ToWrap),
)
which = cgen.Rem{
Expr1: numer,
Expr2: denom,
}
var (
wrap = cgen.Label(a.name("wrap"))
last = sect.IdxPast - 1
// at is the position within the repetition of the section's
// final core; only that case needs the end-of-section check.
at = (last - sect.IdxFirst) % n
)
for x, tok := range sect.Uniqs {
a.tok = tok
var (
expr cgen.Gen
body = make(cgen.Stmts, 7)
)
if x < n-1 {
expr = il(x)
}
if x == 0 {
body[0] = wrap
}
body[1] = layer5()
body[2] = datRet
if x == at {
body[3] = cgen.If1{
Cond: cgen.CmpGE{
Expr1: a.datCore,
Expr2: il(last),
},
Then: cgen.Break,
}
}
body[4] = cgen.IncPre{
Expr: a.datCore,
}
if x == n-1 {
// After the last token, advance the row and jump back to the
// first case.
body[5] = cgen.AddAssign{
Expr1: a.sectH,
Expr2: il(sect.ToWrap),
}
body[6] = cgen.Goto(wrap)
}
cases[x] = cgen.Case{
Expr: expr,
Body: body,
}
}
}
return cgen.Stmts{
cgen.Var{
Type: cgen.PtrdiffT,
What: a.sectH, Init: initH,
},
cgen.Switch{
Expr: which,
Cases: cases,
},
cgen.Assign{
Expr1: a.datCore,
Expr2: il(sect.IdxPast),
},
}
}
var (
sects = a.toks.Sects
tree func(int, int) cgen.Stmts
)
// tree emits a binary decision over sections first..last. It splits
// at the first section (after first) whose IdxPast exceeds the
// midpoint of the covered index range. The lower part is guarded by
// a comparison of the data-core index against the split section's
// IdxFirst; the upper part follows unguarded.
tree = func(first, last int) cgen.Stmts {
if first == last {
return leaf(sects[first])
}
var (
start = sects[first].IdxFirst
stop = sects[last].IdxPast
upper = start + (stop-start)/2
x = first + 1
)
for sects[x].IdxPast <= upper {
x++
}
return cgen.Stmts{
cgen.If{
Cond: cgen.CmpL{
Expr1: a.datCore,
Expr2: il(sects[x].IdxFirst),
},
Then: tree(first, x-1),
},
tree(x, last),
}
}
return tree(0, len(sects)-1)
}
layer3 := func() cgen.Gen {
if a.toks == nil {
return layer4Special()
}
return layer4General()
}
layer2 := func() cgen.Gen {
a.datCore = vb(a.name("j"))
stmts := make(cgen.Stmts, 3)
// The data-core index starts at the beginning of this task's tile.
stmts[0] = cgen.Var{
Type: cgen.PtrdiffT,
What: a.datCore,
Init: cgen.Mul{
Expr1: il(a.datTile),
Expr2: a.datCoord,
},
}
if a.datHull > 1 {
// Several tasks share the data cores: compute the last core of
// this task's tile and prepare the early return.
var (
last = vb(a.name("jj"))
expr cgen.Gen
)
switch a.datTiles {
case a.datHull:
expr = il(a.datTile - 1)
case 0:
expr = il(a.datScrap - 1)
default:
expr = cgen.Paren{
Inner: cgen.Ternary{
Cond: cgen.CmpL{
Expr1: a.datCoord,
Expr2: il(a.datTiles),
},
Then: il(a.datTile - 1),
Else: il(a.datScrap - 1),
},
}
}
stmts[1] = cgen.Var{
Type: cgen.PtrdiffT,
What: last,
Init: cgen.Add{
Expr1: a.datCore,
Expr2: expr,
},
}
datRet = cgen.If1{
Cond: cgen.CmpGE{
Expr1: a.datCore,
Expr2: last,
},
Then: cgen.Return{},
}
}
// datRet must be assigned before layer3 runs, because the lower
// layers embed it in the code they emit.
stmts[2] = layer3()
return stmts
}
layer1 := func() cgen.Gen {
a.groupIdx = vb(a.name("i"))
var (
past = vb(a.name("ii"))
groups cgen.Gen
)
// Number of groups handled by this task: constant, or chosen in the
// generated code when full tiles and a scrap tile both exist.
switch a.groupTiles {
case a.groupHull:
groups = il(a.groupTile)
case 0:
groups = il(a.groupScrap)
default:
groups = cgen.Ternary{
Cond: cgen.CmpL{
Expr1: a.groupCoord,
Expr2: il(a.groupTiles),
},
Then: il(a.groupTile),
Else: il(a.groupScrap),
}
}
return cgen.Stmts{
cgen.Var{
Type: cgen.PtrdiffT,
What: past, Init: groups,
},
cgen.For{
Init: cgen.Var{
Type: cgen.PtrdiffT,
What: a.groupIdx,
Init: il(0),
},
Cond: cgen.CmpL{
Expr1: a.groupIdx,
Expr2: past,
},
Post: cgen.IncPre{
Expr: a.groupIdx,
},
Body: layer2(),
},
}
}
return layer1()
}

// special returns the platform-specific body used when no token table
// is present. AVX512Float32 is the only platform handled; any other
// value is a programming error.
func (a *apply) special() cgen.Gen {
	if a.platform != raw.AVX512Float32 {
		panic("bug")
	}
	return a.m512Special()
}

// general returns the platform-specific body used when a token table
// is present. AVX512Float32 is the only platform handled; any other
// value is a programming error.
func (a *apply) general() cgen.Gen {
	if a.platform != raw.AVX512Float32 {
		panic("bug")
	}
	return a.m512General()
}

// m512Special emits the AVX-512 body for the case without a token
// table. It first emits the code returned by a.m512Dot() and then, for
// each (row, col) entry of a.sums, the code that finishes the entry and
// stores it. Finishing means adding the partial result from earlier
// epochs and, for the last epoch, applying the output ops.
// NOTE(review): a.rows, a.cols and a.sums are not assigned in this
// function; presumably a.m512Dot() sets them. mix, loMask and addMul
// are also defined elsewhere. Confirm against their definitions.
func (a *apply) m512Special() cgen.Gen {
var (
row int
col int
slot cgen.Gen
stmts cgen.Stmts
bnMuls []cgen.Gen
bnAdds []cgen.Gen
)
stmt := func(st cgen.Gen) {
stmts = append(stmts, st)
}
// mask selects the lanes of the current column: all a.slotDats lanes,
// except in the last column of a short data core, where only the
// remainder of a.datSliceDats2 is used (when that remainder is
// nonzero).
mask := func() cgen.Gen {
n := a.slotDats
if a.datShort &&
col == a.cols-1 {
rem := a.datSliceDats2 % n
if rem > 0 {
n = rem
}
}
return loMask(n)
}
// ae builds the address of the current (row, col) entry within output
// tensor datPtrIdx.
ae := func(datPtrIdx int) cgen.Gen {
var (
ret = a.datPtrs[datPtrIdx]
chanPitch = a.To.Pitch2Bytes[datPtrIdx]
groupPitch = a.toChans * chanPitch
datCorePitch = a.datSliceBytes1
wtCorePitch = a.wtSliceWts1 * chanPitch
rowPitch = chanPitch
colPitch = a.slotBytes
)
ret = addMul(ret, il(groupPitch), a.groupIdx)
ret = addMul(ret, il(datCorePitch), a.datCore)
ret = addMul(ret, il(wtCorePitch), a.wtCore)
ret = cgen.Add{
Expr1: ret,
Expr2: cast(row*rowPitch + col*colPitch),
}
return ret
}
// addLd adds a masked load from tensor datPtrIdx into the slot.
addLd := func(datPtrIdx int) {
stmt(cgen.Assign{
Expr1: slot,
Expr2: avx.Mm512AddPs{
slot,
avx.Mm512MaskzLoaduPs{
mask(),
ae(datPtrIdx),
},
},
})
}
// layer5 applies the output ops to the slot, in order.
layer5 := func() {
var (
datPtrIdx = 0
bnPtrIdx = 0
)
// bnPrep loads all batch-norm variables for the current row. It does
// so only for column 0; later columns of the same row reuse bnMuls
// and bnAdds, and bnPrep returns nil for them.
bnPrep := func() cgen.Gen {
if col > 0 {
return nil
}
var (
bnCnt = len(a.bnPtrs)
bnLds = make(cgen.Gens, bnCnt)
)
bnCh := cgen.Paren{
Inner: addMul(
addMul(
il(row),
il(a.wtSliceWts1),
a.wtCore,
),
il(a.toChans),
a.groupIdx,
),
}
bnMuls = make([]cgen.Gen, bnCnt)
bnAdds = make([]cgen.Gen, bnCnt)
for x, bnPtr := range a.bnPtrs {
var (
bnMul = vb(a.name("bnMul"))
bnAdd = vb(a.name("bnAdd"))
)
bnLds[x] = &bn.Load{
Ctx: a.bc,
Mas: bnPtr,
Channel: bnCh,
Mul: bnMul,
Add: bnAdd,
}
bnMuls[x] = bnMul
bnAdds[x] = bnAdd
}
return bnLds
}
for op := range a.To.Ops {
op := &a.To.Ops[op]
switch op.Kind {
case mod.Add:
for n := op.Int; n > 0; n-- {
addLd(datPtrIdx)
datPtrIdx++
}
case mod.Bn:
// The loads are emitted once, just before the first Bn op. For
// columns after the first this appends a nil statement.
if bnPtrIdx == 0 {
stmt(bnPrep())
}
stmt(&bn.Apply{
Ctx: a.bc,
Mul: bnMuls[bnPtrIdx],
Add: bnAdds[bnPtrIdx],
To: slot,
})
bnPtrIdx++
case mod.ReLU:
stmt(&act.ReLU{
Ctx: a.ac,
NegSlope: op.Float,
Var: slot,
})
default:
panic("bug")
}
}
}
// layer4 finishes one slot. Unless this is the first epoch, it adds
// the partial result stored at tensor a.datSplit. For the last epoch
// it applies the ops and stores to every tensor from a.datSplit on;
// otherwise it stores only to tensor a.datSplit.
layer4 := func() {
var (
to = a.datSplit
stop = to + 1
)
if !a.epochFirst {
addLd(to)
}
if a.epochLast {
layer5()
stop = len(a.datPtrs)
}
for ; to < stop; to++ {
stmt(avx.Mm512MaskStoreuPs{
ae(to), mask(), slot,
})
}
}
// layer3 produces the complete statement list for one slot.
layer3 := func() cgen.Stmts {
stmts = nil
layer4()
return stmts
}
// layer2 visits every entry, row by row, and combines the per-column
// statement lists of each row with mix.
layer2 := func() cgen.Gen {
var (
rr = a.rows
cc = a.cols
ret = make(cgen.Gens, rr)
toMix = make([]cgen.Stmts, cc)
)
for r := 0; r < rr; r++ {
row = r
for c := 0; c < cc; c++ {
col = c
slot = a.sums[r*cc+c]
toMix[c] = layer3()
}
ret[r] = mix(toMix)
}
return ret
}
layer1 := func() cgen.Gen {
return cgen.Gens{
a.m512Dot(),
layer2(),
}
}
return layer1()
}

// m512General generates the AVX-512 code for the "general" path: the
// dot-product code from m512Dot followed, for every row, by the
// interpretation of the command list a.tok.To.Cmds (copy, rotate, slot
// fetch, modifier application, store).
//
// Computation statements are collected in stmts[0] and stores in stmts[1];
// each row emits all of stmts[0] before stmts[1].
func (a *apply) m512General() cgen.Gen {
var (
// opSplit is the number of leading To.Ops handled by modPreAdd and
// bnSplit the number of batchnorm ops among them (see layer2).
opSplit int
bnSplit int
row int
stmts [2]cgen.Stmts
// dats maps a command-level value id to the expression holding it.
dats []cgen.Gen
// bnMuls/bnAdds cache the batchnorm variables loaded for this row.
bnMuls []cgen.Gen
bnAdds []cgen.Gen
)
// stmt appends st to statement list x (0 = compute, 1 = store).
stmt := func(x int, st cgen.Gen) {
stmts[x] = append(
stmts[x], st,
)
}
// idPrep grows dats with nil entries until index id is valid.
idPrep := func(id int) {
for id >= len(dats) {
dats = append(dats, nil)
}
}
// eval binds expr to value id: the first use declares a fresh __m512
// variable, later uses assign to the existing one.
eval := func(id int, expr cgen.Gen) {
idPrep(id)
dat := dats[id]
if dat == nil {
dat = vb(a.name("dat"))
dats[id] = dat
stmt(0, cgen.Var{
Type: avx.M512, What: dat,
Init: expr,
})
} else {
stmt(0, cgen.Assign{
Expr1: dat,
Expr2: expr,
})
}
}
// rot rotates the source value by cmd.Cnt 32-bit lanes: cast to an
// integer vector, alignr of the vector with itself, cast back, and bind
// the result to the destination id.
rot := func(cmd *m512CmdRotate) {
var (
dst = cmd.DstId
src = dats[cmd.SrcId]
cnt = il(cmd.Cnt)
via = vb(a.name("via"))
)
stmt(0, cgen.Var{
Type: avx.M512i, What: via,
Init: avx.Mm512CastpsSi512{src},
})
stmt(0, cgen.Assign{
Expr1: via,
Expr2: avx.Mm512AlignrEpi32{
via, via, cnt,
},
})
eval(dst, avx.Mm512Castsi512Ps{via})
}
// slotGet aliases value id directly to the sum register of the current
// row and the requested slot; no statement is emitted.
slotGet := func(cmd *m512CmdSlotGet) {
var (
id = cmd.Id
slot = cmd.Slot
)
idPrep(id)
dats[id] = a.sums[row*a.cols+slot]
}
// bnPrep emits the load for batchnorm x once per row; subsequent calls
// for the same x are no-ops because bnMuls[x] is already set.
bnPrep := func(x int) {
if bnMuls[x] != nil {
return
}
var (
bnMul = vb(a.name("bnMul"))
bnAdd = vb(a.name("bnAdd"))
)
stmt(0, &bn.Load{
Ctx: a.bc,
Mas: a.bnPtrs[x],
Channel: cgen.Paren{
Inner: addMul(
addMul(
il(row),
il(a.wtSliceWts1),
a.wtCore,
),
il(a.toChans),
a.groupIdx,
),
},
Mul: bnMul,
Add: bnAdd,
})
bnMuls[x] = bnMul
bnAdds[x] = bnAdd
}
// modPreAdd applies the first opSplit modifiers (those before the first
// Add op) to the value. It does nothing unless this is the last epoch.
// An Add op inside that range is a bug.
modPreAdd := func(cmd *m512CmdToModPreAdd) {
if !a.epochLast {
return
}
var (
dst = dats[cmd.Id]
bnPtrIdx = 0
)
for op := 0; op < opSplit; op++ {
op := &a.To.Ops[op]
switch op.Kind {
case mod.Bn:
bnPrep(bnPtrIdx)
stmt(0, &bn.Apply{
Ctx: a.bc,
Mul: bnMuls[bnPtrIdx],
Add: bnAdds[bnPtrIdx],
To: dst,
})
bnPtrIdx++
case mod.ReLU:
stmt(0, &act.ReLU{
Ctx: a.ac,
NegSlope: op.Float,
Var: dst,
})
default:
panic("bug")
}
}
}
// ae builds the address expression into tensor x for the current row at
// relative height relH and width position w. Run-time indices (group,
// section height, weight core) are scaled by their pitches; the
// compile-time row/h/w offset is added as a constant.
// NOTE(review): assumes addMul(x, y, z) yields x + y*z — helper is
// defined outside this view; confirm.
ae := func(x, relH, w int) cgen.Gen {
var (
ret = a.datPtrs[x]
cPitch = a.To.Pitch2Bytes[x]
hPitch = a.To.Pitch1Bytes[x]
wPitch = a.datBytes
groupPitch = a.toChans * cPitch
wtCorePitch = a.wtSliceWts1 * cPitch
rowPitch = cPitch
h = a.tok.To.FirstH + relH
)
ret = addMul(ret, il(groupPitch), a.groupIdx)
ret = addMul(ret, il(hPitch), a.sectH)
ret = addMul(ret, il(wtCorePitch), a.wtCore)
ret = cgen.Add{
Expr1: ret,
Expr2: cast(row*rowPitch + h*hPitch + w*wPitch),
}
return ret
}
// addLd emits dst += masked-zero load of cnt lanes from tensor x.
addLd := func(dst cgen.Gen, x, relH, w, cnt int) {
stmt(0, cgen.Assign{
Expr1: dst,
Expr2: avx.Mm512AddPs{
dst,
avx.Mm512MaskzLoaduPs{
loMask(cnt),
ae(x, relH, w),
},
},
})
}
// modAddPost first accumulates the value previously stored in tensor
// a.datSplit (unless this is the first epoch) and then, on the last
// epoch only, applies the modifiers from opSplit onward.
modAddPost := func(cmd *m512CmdToModAddPost) {
var (
dst = dats[cmd.Id]
relH = cmd.RelH
w = cmd.W
cnt = cmd.Cnt
)
if !a.epochFirst {
addLd(dst, a.datSplit, relH, w, cnt)
}
if !a.epochLast {
return
}
var (
ops = a.To.Ops[opSplit:]
datPtrIdx = 0
// Batchnorm numbering continues after those consumed by
// modPreAdd.
bnPtrIdx = bnSplit
)
for op := range ops {
op := &ops[op]
switch op.Kind {
case mod.Add:
// Each Add op consumes op.Int consecutive tensors,
// starting from a.datPtrs[0].
for n := op.Int; n > 0; n-- {
addLd(dst, datPtrIdx, relH, w, cnt)
datPtrIdx++
}
case mod.Bn:
bnPrep(bnPtrIdx)
stmt(0, &bn.Apply{
Ctx: a.bc,
Mul: bnMuls[bnPtrIdx],
Add: bnAdds[bnPtrIdx],
To: dst,
})
bnPtrIdx++
case mod.ReLU:
stmt(0, &act.ReLU{
Ctx: a.ac,
NegSlope: op.Float,
Var: dst,
})
default:
panic("bug")
}
}
}
// store emits masked stores of the value. Before the last epoch only
// tensor a.datSplit is written; on the last epoch every tensor from
// a.datSplit to the end of a.datPtrs is written.
store := func(cmd *m512CmdStore) {
var (
src = dats[cmd.Id]
relH = cmd.RelH
w = cmd.W
cnt = cmd.Cnt
to = a.datSplit
stop = to + 1
)
if a.epochLast {
stop = len(a.datPtrs)
}
for ; to < stop; to++ {
stmt(1, avx.Mm512MaskStoreuPs{
ae(to, relH, w),
loMask(cnt),
src,
})
}
}
// layer6 interprets the command list for the current row.
layer6 := func() {
for _, cmd := range a.tok.To.Cmds {
switch cmd := cmd.(type) {
case *m512CmdCopy:
src := dats[cmd.SrcId]
eval(cmd.DstId, src)
case *m512CmdRotate:
rot(cmd)
case *m512CmdSlotGet:
slotGet(cmd)
case *m512CmdToModPreAdd:
modPreAdd(cmd)
case *m512CmdToModAddPost:
modAddPost(cmd)
case *m512CmdStore:
store(cmd)
default:
panic("bug")
}
}
}
// layer5 resets the per-row state: the id table is emptied (keeping its
// backing array) and, on the last epoch, the batchnorm caches are
// cleared so the loads are re-emitted for this row.
layer5 := func() {
dats = dats[:0]
if a.epochLast {
n := len(a.bnPtrs)
bnMuls = make([]cgen.Gen, n)
bnAdds = make([]cgen.Gen, n)
}
layer6()
}
// layer4 produces the code for one row: compute statements followed by
// store statements.
layer4 := func() cgen.Gen {
stmts[0] = nil
stmts[1] = nil
layer5()
return cgen.Gens{
stmts[0],
stmts[1],
}
}
// layer3 iterates over the rows.
layer3 := func() cgen.Gen {
var (
rr = a.rows
ret = make(cgen.Gens, rr)
)
for r := 0; r < rr; r++ {
row = r
ret[r] = layer4()
}
return ret
}
// layer2 computes opSplit/bnSplit. They are only nonzero when this
// epoch is both first and last: then opSplit counts the ops preceding
// the first Add op and bnSplit the batchnorm ops among them. On a last
// epoch that is not the first, both are zero, so every modifier is
// applied by modAddPost.
layer2 := func() cgen.Gen {
if a.epochLast {
opSplit = 0
bnSplit = 0
if a.epochFirst {
for op := range a.To.Ops {
kind := a.To.Ops[op].Kind
if kind == mod.Add {
break
}
opSplit++
if kind == mod.Bn {
bnSplit++
}
}
}
}
return layer3()
}
// layer1 runs m512Dot first; that call also sets a.rows, a.cols and
// a.sums, which the later layers depend on.
layer1 := func() cgen.Gen {
return cgen.Gens{
a.m512Dot(),
layer2(),
}
}
return layer1()
}

// m512Dot generates the AVX-512 accumulation code shared by the special and
// general paths. It sizes the rows x cols grid of sum registers (storing
// a.rows, a.cols and a.sums on the receiver as a side effect), initializes
// every sum, and emits a loop over a.slices slices in which each sum is
// updated with a fused multiply-add of a broadcast weight and a data vector.
func (a *apply) m512Dot() cgen.Gen {
var (
sliceIdx cgen.Gen
)
// ldWt returns a broadcast of one float for the given row. The address is
// relative to sliceIdx, with a constant offset of one extra slice: with
// sliceIdx == -1 (its initial value, see layer2) the load reads the
// a.rows floats that precede slice 0.
// NOTE(review): those leading values are presumably the biases, since
// they seed the sums — confirm against the weight arranger.
ldWt := func(row int) cgen.Gen {
var (
ae = a.arrangedWts
groupPitch = il(a.wtGroupBytes)
corePitch = il(a.wtCoreBytes)
sliceBytes = a.rows * a.wtBytes
)
ae = addMul(ae, groupPitch, a.groupIdx)
ae = addMul(ae, corePitch, a.wtCore)
ae = addMul(ae, il(sliceBytes), sliceIdx)
ae = cgen.Add{
Expr1: ae,
Expr2: cast(sliceBytes + row*a.wtBytes),
}
return avx.Mm512Set1Ps{
cgen.At{
Expr: cgen.Cast{
Type: cgen.PtrFloat,
Expr: cgen.Paren{
Inner: ae,
},
},
},
}
}
// ldDat returns an unaligned full-vector load of column col of the
// current slice from the arranged data.
ldDat := func(col int) cgen.Gen {
var (
ae = a.arrangedDats
groupPitch = il(a.datGroupBytes)
corePitch = il(a.datCoreBytes)
slicePitch = il(a.cols * a.slotBytes)
)
ae = addMul(ae, groupPitch, a.groupIdx)
ae = addMul(ae, corePitch, a.datCore)
ae = addMul(ae, slicePitch, sliceIdx)
ae = cgen.Add{
Expr1: ae,
Expr2: cast(col * a.slotBytes),
}
return avx.Mm512LoaduPs{ae}
}
// layer5 is the loop body: load the cc data vectors once, then for each
// row broadcast its weight and accumulate into that row's sums.
layer5 := func() cgen.Gen {
var (
rr = a.rows
cc = a.cols
dats = make([]cgen.Gen, cc)
// Capacity: cc data loads plus, per row, one weight load and
// cc fused multiply-adds.
stmts = make(cgen.Stmts, 0, cc+rr*(1+cc))
)
stmt := func(st cgen.Gen) {
stmts = append(stmts, st)
}
for c := 0; c < cc; c++ {
dat := vb(a.name("dat"))
dats[c] = dat
stmt(cgen.Var{
Type: avx.M512, What: dat,
Init: ldDat(c),
})
}
for r := 0; r < rr; r++ {
wt := vb(a.name("wt"))
stmt(cgen.Var{
Type: avx.M512, What: wt,
Init: ldWt(r),
})
for c := 0; c < cc; c++ {
var (
sum = a.sums[r*cc+c]
dat = dats[c]
)
stmt(cgen.Assign{
Expr1: sum,
Expr2: avx.Mm512FmaddPs{
wt, dat, sum,
},
})
}
}
return stmts
}
// layer4 emits the loop over slices: sliceIdx runs from 0 up to (but not
// including) a.slices. The index variable itself is declared in layer2.
layer4 := func() cgen.Gen {
return cgen.For{
Init: cgen.Assign{
Expr1: sliceIdx,
Expr2: il(0),
},
Cond: cgen.CmpL{
Expr1: sliceIdx,
Expr2: il(a.slices),
},
Post: cgen.IncPre{
Expr: sliceIdx,
},
Body: layer5(),
}
}
// layer3 declares the sums of columns 1..cc-1 as copies of column 0 of the
// same row, then appends the slice loop.
layer3 := func() cgen.Gen {
var (
rr = a.rows
cc = a.cols
stmts = make(cgen.Stmts, rr*(cc-1)+1)
x = 0
)
for r := 0; r < rr; r++ {
for c := 1; c < cc; c++ {
stmts[x] = cgen.Var{
Type: avx.M512,
What: a.sums[r*cc+c],
Init: a.sums[r*cc],
}
x++
}
}
stmts[x] = layer4()
return stmts
}
// layer2 declares the slice index with initial value -1 and initializes
// the column-0 sum of every row through ldWt, which at index -1 reads the
// values stored ahead of the first slice.
layer2 := func() cgen.Gen {
sliceIdx = vb(a.name("s"))
var (
rr = a.rows
stmts = make(cgen.Stmts, 1+rr+1)
)
stmts[0] = cgen.Var{
Type: cgen.PtrdiffT,
What: sliceIdx,
Init: il(-1),
}
for r := 0; r < rr; r++ {
stmts[1+r] = cgen.Var{
Type: avx.M512,
What: a.sums[r*a.cols],
Init: ldWt(r),
}
}
stmts[1+rr] = layer3()
return stmts
}
// layer1 chooses the grid dimensions. Rows come from the weight slice
// size (the short variant when a.wtShort). Columns come from the data
// slice size when a.toks is nil, otherwise from the current token's slot
// count. A fresh name is then allocated for every sum.
layer1 := func() cgen.Gen {
a.rows = a.wtSliceWts1
if a.wtShort {
a.rows = a.wtSliceWts2
}
if a.toks == nil {
a.cols = a.datSliceSlots1
if a.datShort {
a.cols = a.datSliceSlots2
}
} else {
a.cols = a.tok.Slots
}
a.sums = make([]cgen.Gen, a.rows*a.cols)
for x := range a.sums {
a.sums[x] = vb(a.name("sum"))
}
return layer2()
}
return layer1()
}

Top || internal/compile/author/params/params.go

package params

import (
"NN-512/internal/compile/author/cgen"
"NN-512/internal/compile/plan"
"fmt"
"sort"
)

// Text fragments used when naming the parameter struct and when formatting
// the explanatory comment emitted by Fwd.
const (
// indent is four spaces; it prefixes the code lines of the example in
// the comment emitted by Fwd.
indent = space + space + space + space
space = " "
// suffix is appended to the configured prefix to form the struct name.
suffix = "Params"
)

// Name returns the name of the parameter struct for the given plan: the
// configured prefix followed by "Params".
func Name(pl *plan.Plan) string {
	prefix := pl.Config.Prefix
	return prefix + suffix
}

// Fwd returns the forward declaration of the parameter struct called name,
// preceded by a comment that explains how the struct is laid out and how to
// load it from a file.
func Fwd(name string) cgen.Gen {
	// code formats one line of the C example embedded in the comment.
	code := func(line string) string {
		return indent + line
	}
	note := cgen.Comment{
		`All weights, biases, and other trained parameters are passed into`,
		`the initialization code through the ` + suffix + ` struct that is declared`,
		`just below this comment. The corresponding struct definition can be`,
		`found near the end of this header file.`,
		``,
		`Each field of the ` + suffix + ` struct is an array of float that holds a`,
		`parameter tensor in NCHW format with no padding. The struct fields`,
		`are ordered by name, lexically bytewise. If you concatenate all the`,
		`trained parameter tensors to a file in this same format and order`,
		`you can load the struct as follows (error checking omitted here):`,
		``,
		code(`size_t size = sizeof(` + name + `);`),
		code(name + `* to = malloc(size);`),
		code(`FILE* from = fopen("` + suffix + `File", "r");`),
		code(`fread(to, size, 1, from);`),
		code(`fclose(from);`),
		``,
		`Be careful to match endianness (and floating point format).`,
	}
	decl := cgen.StructFwd(name)
	return cgen.Gens{note, cgen.Newline, decl}
}

// byTensor orders parameters by tensor name using Go's "<" string
// comparison. It implements sort.Interface.
type byTensor []*plan.Param

// Len reports the number of parameters.
func (by byTensor) Len() int {
	n := len(by)
	return n
}

// Less reports whether the tensor name at index i sorts before the one at
// index j.
func (by byTensor) Less(i, j int) bool {
	lhs, rhs := by[i].Tensor, by[j].Tensor
	return lhs < rhs
}

// Swap exchanges the parameters at indices i and j.
func (by byTensor) Swap(i, j int) {
	held := by[i]
	by[i] = by[j]
	by[j] = held
}

// gather collects a pointer to every parameter referenced by the plan: the
// direct parameters of each op in pl.Seq, plus the parameters of the
// modifiers in ParamMods, FromMods and ToMods. The result is sorted by
// tensor name and may contain several entries with the same name.
func gather(pl *plan.Plan) (by byTensor) {
	addParams := func(params []plan.Param) {
		for k := 0; k < len(params); k++ {
			by = append(by, &params[k])
		}
	}
	addMods := func(lists [][]plan.Mod) {
		for _, list := range lists {
			for k := range list {
				addParams(list[k].Params)
			}
		}
	}
	for _, op := range pl.Seq {
		for k, params := range op.Params {
			addParams(params)
			addMods(op.ParamMods[k][:])
		}
		addMods(op.FromMods)
		addMods(op.ToMods)
	}
	sort.Sort(by)
	return by
}

// fields builds the two-column table that makes up the struct body: one
// float array field per distinct tensor name, sized to the product of its
// four dimensions, paired with a comment giving the dimensions as NxCxHxW.
func fields(pl *plan.Plan) cgen.Gen {
	const width = 2
	sorted := gather(pl)
	flat := make([]cgen.Gen, 0, len(sorted)*width)
	prev := ""
	for _, p := range sorted {
		// gather sorts by name, so duplicates are adjacent; only the
		// first entry for each name produces a field.
		if p.Tensor == prev {
			continue
		}
		prev = p.Tensor
		dims := &p.NCHW
		elems := 1
		for k := range dims {
			elems *= dims[k]
		}
		field := cgen.Field{
			Type: cgen.Float,
			What: cgen.Elem{
				Arr: cgen.Vb(prev),
				Idx: cgen.IntLit(elems),
			},
		}
		shape := cgen.Comment{
			fmt.Sprintf("%dx%dx%dx%d",
				dims[0], dims[1], dims[2], dims[3]),
		}
		flat = append(flat, field, shape)
	}
	return cgen.Table{
		Flat: flat,
		Cols: width,
	}
}

// Def returns the definition of the packed parameter struct called name,
// with one field per distinct tensor in the plan, preceded by a comment
// describing the field order and the tensor layout.
func Def(pl *plan.Plan, name string) cgen.Gen {
	note := cgen.Comment{
		`The fields of the following struct have been sorted by name using`,
		`Go's "<" string comparison operator (bytewise lexical string sort).`,
		`Tensor dimensions are NxCxHxW where N is the outermost/slowest and`,
		`W is the innermost/fastest. There is no padding anywhere.`,
	}
	def := cgen.StructDef{
		Name:   name,
		Fields: fields(pl),
		Attrs:  cgen.Packed,
	}
	return cgen.Gens{note, cgen.Newline, def}
}

Top || internal/compile/author/quadfft/quadfft.go

package quadfft

import (
"NN-512/internal/compile/author/avx"
"NN-512/internal/compile/author/cgen"
"NN-512/internal/nmsrc"
"NN-512/internal/raw"
"math"
)

// fl is shorthand for a floating point literal in the generated code.
func fl(f float64) cgen.Gen {
	var lit cgen.Gen = cgen.FloatLit(f)
	return lit
}

// il is shorthand for an integer literal in the generated code.
func il(i int) cgen.Gen {
	var lit cgen.Gen = cgen.IntLit(i)
	return lit
}

// Fwd is a code generator for the forward transform of this package.
// Append emits the statements for the selected platform.
type Fwd struct {
// Platform selects the instruction set; only raw.AVX512Float32 is handled.
Platform raw.Platform
// Nms supplies fresh names for the temporaries.
Nms nmsrc.Src
// In holds the input expressions. Nil entries are replaced with a zero
// vector when the code is generated.
In [16]cgen.Gen
// Out holds the identifiers the results are declared under. A nil entry
// gets a freshly generated name instead.
Out [16]cgen.Gen
}

// Append generates the forward transform for F.Platform and appends its
// text to the given buffer. An unsupported platform is an internal error.
func (F *Fwd) Append(to []byte) []byte {
	if F.Platform != raw.AVX512Float32 {
		panic("bug")
	}
	gen := F.m512()
	return gen.Append(to)
}

// m512 generates the AVX-512 statements of the forward transform.
//
// The computation is a fixed network of numbered nodes, built by the chain
// layer4 -> layer5 -> ... -> layer18. The chain is run twice (see layer2),
// once on the even-indexed inputs writing Out[0:8] and once on the
// odd-indexed inputs writing Out[8:16]; the two resulting statement lists
// are interleaved one statement at a time.
//
// Every helper that declares a temporary appends exactly one statement, so
// the order of helper calls is the order of the emitted code.
func (F *Fwd) m512() cgen.Gen {
var (
stmts cgen.Stmts
in []cgen.Gen
out []cgen.Gen
// coeffs and perms cache constant vectors; they are shared by both
// passes, so each constant is declared only once.
coeffs [6]cgen.Gen
perms [2]cgen.Gen
nodes [80]cgen.Gen
)
stmt := func(st cgen.Gen) {
stmts = append(stmts, st)
}
// decl emits "t id = expr" and returns id. When id is nil a fresh name
// based on "fft" is generated.
decl := func(t, id, expr cgen.Gen) cgen.Gen {
if id == nil {
fft := F.Nms.Name("fft")
id = cgen.Vb(fft)
}
stmt(cgen.Var{
Type: t, What: id,
Init: expr,
})
return id
}
add := func(a, b cgen.Gen) cgen.Gen {
return decl(
avx.M512, nil,
avx.Mm512AddPs{a, b},
)
}
sub := func(a, b cgen.Gen) cgen.Gen {
return decl(
avx.M512, nil,
avx.Mm512SubPs{a, b},
)
}
// bcast returns an inline broadcast constant (no statement is emitted):
// index 0 is sqrt(2)/2 and index 1 is 0.5.
bcast := func(i int) cgen.Gen {
return avx.Mm512Set1PsLit(
[2]float64{
math.Sqrt2 * 0.5,
0.5,
}[i],
)
}
fmadd := func(a, b, c cgen.Gen) cgen.Gen {
return decl(
avx.M512, nil,
avx.Mm512FmaddPs{a, b, c},
)
}
fnmsub := func(a, b, c cgen.Gen) cgen.Gen {
return decl(
avx.M512, nil,
avx.Mm512FnmsubPs{a, b, c},
)
}
fnmadd := func(a, b, c cgen.Gen) cgen.Gen {
return decl(
avx.M512, nil,
avx.Mm512FnmaddPs{a, b, c},
)
}
// coeff returns constant vector i, declaring it on first use. On later
// uses it appends a nil statement instead; presumably this keeps both
// passes at the same statement count, which the interleaving in layer2
// requires.
coeff := func(i int) cgen.Gen {
cf := coeffs[i]
switch cf {
case nil:
var (
neg1 = il(-1)
neg2 = fl(-math.Sqrt2 * 0.5)
pos1 = il(1)
pos2 = fl(math.Sqrt2 * 0.5)
zero = il(0)
expr cgen.Gen
)
// NOTE(review): if Mm512SetPs maps to _mm512_set_ps, the
// arguments are listed from the highest lane down to lane 0.
switch i {
case 0:
expr = avx.Mm512SetPs{
neg1, neg1, neg1, neg1,
neg1, neg1, neg1, neg1,
pos1, pos1, pos1, pos1,
pos1, pos1, pos1, pos1,
}
case 1:
expr = avx.Mm512SetPs{
neg2, neg2, zero, zero,
pos2, pos2, pos1, pos1,
pos1, pos1, pos1, pos1,
pos1, pos1, pos1, pos1,
}
case 2:
expr = avx.Mm512SetPs{
pos2, pos2, pos1, pos1,
pos2, pos2, zero, zero,
zero, zero, zero, zero,
zero, zero, zero, zero,
}
case 3:
expr = avx.Mm512SetPs{
neg1, neg1, neg1, neg1,
pos1, pos1, pos1, pos1,
neg1, neg1, neg1, neg1,
pos1, pos1, pos1, pos1,
}
case 4:
expr = avx.Mm512SetPs{
neg1, neg1, pos1, pos1,
neg1, neg1, pos1, pos1,
neg1, neg1, pos1, pos1,
neg1, neg1, pos1, pos1,
}
case 5:
expr = avx.Mm512SetPs{
neg1, pos1, neg1, pos1,
neg1, pos1, zero, zero,
neg1, pos1, neg1, pos1,
neg1, pos1, zero, zero,
}
}
cf = decl(avx.M512, nil, expr)
coeffs[i] = cf
default:
stmt(nil)
}
return cf
}
// shuf returns an inline shuffle of node with itself (no statement is
// emitted). Variants 0 and 1 use the 128-bit-block shuffle and variant 2
// the per-element shuffle; variants 0 and 2 share one control value.
shuf := func(i int, node cgen.Gen) cgen.Gen {
var (
ctrl int
expr cgen.Gen
)
switch i {
case 0, 2:
ctrl = 1<<6 | 0<<4 | 3<<2 | 2<<0
case 1:
ctrl = 2<<6 | 3<<4 | 0<<2 | 1<<0
}
switch i {
case 0, 1:
expr = avx.Mm512ShuffleF32x4{
node, node, il(ctrl),
}
case 2:
expr = avx.Mm512ShufflePs{
node, node, il(ctrl),
}
}
return expr
}
mul := func(a, b cgen.Gen) cgen.Gen {
return decl(
avx.M512, nil,
avx.Mm512MulPs{a, b},
)
}
// blend declares a masked combination of node0 and node1. The first
// switch picks the lane mask; the second picks the operation: a masked
// move (0, 2, 3, 4), a masked subtraction from zero (1), or a masked
// multiplication by 0.5 (5).
blend := func(i int, node0, node1 cgen.Gen) cgen.Gen {
var (
mask cgen.Gen
expr cgen.Gen
)
switch i {
case 0, 1:
mask = il(0xc0c0)
case 2:
mask = il(0x5555)
case 3:
mask = il(0xa8a8)
case 4:
mask = il(0x5656)
case 5:
mask = il(0xfcfc)
}
switch i {
case 0, 2, 3, 4:
expr = avx.Mm512MaskMovPs{
node0, mask, node1,
}
case 1:
expr = avx.Mm512MaskSubPs{
node0, mask,
avx.Mm512SetzeroPs, node1,
}
case 5:
expr = avx.Mm512MaskMulPs{
node0, mask,
node1, bcast(1),
}
}
return decl(avx.M512, nil, expr)
}
// perm declares a cross-lane permutation of node using index vector i.
// The index vector is declared on first use; later uses append a nil
// statement in its place (same padding scheme as coeff).
perm := func(i int, node cgen.Gen) cgen.Gen {
pm := perms[i]
switch pm {
case nil:
var (
set = make(avx.Mm512SetEpi32, 16)
tbl []int
)
switch i {
case 0:
tbl = []int{13, 9, 5, 3}
case 1:
tbl = []int{11, 15, 7, 1}
}
// Each table entry is used for two consecutive positions; the
// second half of the vector (j >= 8) uses the same entries
// minus one.
for j := range set {
set[j] = il(tbl[j%8/2] - j/8)
}
pm = decl(avx.M512i, nil, set)
perms[i] = pm
default:
stmt(nil)
}
return decl(
avx.M512, nil,
avx.Mm512PermutexvarPs{
pm, node,
},
)
}
// layer18 binds the eight results of this pass to the output
// identifiers.
layer18 := func() {
decl(avx.M512, out[0], nodes[78])
decl(avx.M512, out[1], nodes[79])
decl(avx.M512, out[2], nodes[62])
decl(avx.M512, out[3], nodes[63])
decl(avx.M512, out[4], nodes[64])
decl(avx.M512, out[5], nodes[65])
decl(avx.M512, out[6], nodes[66])
decl(avx.M512, out[7], nodes[67])
}
layer17 := func() {
nodes[78] = blend(5, nodes[76], nodes[76])
nodes[79] = blend(5, nodes[77], nodes[77])
layer18()
}
layer16 := func() {
nodes[76] = blend(3, nodes[74], nodes[73])
nodes[77] = blend(4, nodes[75], nodes[73])
layer17()
}
layer15 := func() {
nodes[74] = blend(2, nodes[71], nodes[72])
nodes[75] = blend(3, nodes[68], nodes[72])
layer16()
}
layer14 := func() {
cf := coeff(5)
nodes[72] = fmadd(nodes[68], cf, nodes[69])
nodes[73] = fnmadd(nodes[71], cf, nodes[70])
layer15()
}
layer13 := func() {
nodes[68] = perm(0, nodes[60])
nodes[69] = perm(1, nodes[60])
nodes[70] = perm(0, nodes[61])
nodes[71] = perm(1, nodes[61])
layer14()
}
layer12 := func() {
cf := coeff(4)
nodes[60] = fmadd(nodes[52], cf, shuf(2, nodes[52]))
nodes[61] = fmadd(nodes[53], cf, shuf(2, nodes[53]))
nodes[62] = fmadd(nodes[54], cf, shuf(2, nodes[54]))
nodes[63] = fmadd(nodes[55], cf, shuf(2, nodes[55]))
nodes[64] = fmadd(nodes[56], cf, shuf(2, nodes[56]))
nodes[65] = fmadd(nodes[57], cf, shuf(2, nodes[57]))
nodes[66] = fmadd(nodes[58], cf, shuf(2, nodes[58]))
nodes[67] = fmadd(nodes[59], cf, shuf(2, nodes[59]))
layer13()
}
layer11 := func() {
nodes[52] = blend(0, nodes[44], nodes[45])
nodes[53] = blend(1, nodes[45], nodes[44])
nodes[54] = blend(0, nodes[46], nodes[47])
nodes[55] = blend(1, nodes[47], nodes[46])
nodes[56] = blend(0, nodes[48], nodes[49])
nodes[57] = blend(1, nodes[49], nodes[48])
nodes[58] = blend(0, nodes[50], nodes[51])
nodes[59] = blend(1, nodes[51], nodes[50])
layer12()
}
layer10 := func() {
cf := coeff(3)
nodes[44] = fmadd(nodes[36], cf, shuf(1, nodes[36]))
nodes[45] = fmadd(nodes[37], cf, shuf(1, nodes[37]))
nodes[46] = fmadd(nodes[38], cf, shuf(1, nodes[38]))
nodes[47] = fmadd(nodes[39], cf, shuf(1, nodes[39]))
nodes[48] = fmadd(nodes[40], cf, shuf(1, nodes[40]))
nodes[49] = fmadd(nodes[41], cf, shuf(1, nodes[41]))
nodes[50] = fmadd(nodes[42], cf, shuf(1, nodes[42]))
nodes[51] = fmadd(nodes[43], cf, shuf(1, nodes[43]))
layer11()
}
layer9 := func() {
cf := coeff(2)
nodes[36] = fmadd(nodes[21], cf, nodes[28])
nodes[37] = fnmadd(nodes[20], cf, nodes[29])
nodes[38] = fmadd(nodes[23], cf, nodes[30])
nodes[39] = fnmadd(nodes[22], cf, nodes[31])
nodes[40] = fmadd(nodes[25], cf, nodes[32])
nodes[41] = fnmadd(nodes[24], cf, nodes[33])
nodes[42] = fmadd(nodes[27], cf, nodes[34])
nodes[43] = fnmadd(nodes[26], cf, nodes[35])
layer10()
}
layer8 := func() {
cf := coeff(1)
nodes[28] = mul(nodes[20], cf)
nodes[29] = mul(nodes[21], cf)
nodes[30] = mul(nodes[22], cf)
nodes[31] = mul(nodes[23], cf)
nodes[32] = mul(nodes[24], cf)
nodes[33] = mul(nodes[25], cf)
nodes[34] = mul(nodes[26], cf)
nodes[35] = mul(nodes[27], cf)
layer9()
}
layer7 := func() {
cf := coeff(0)
nodes[20] = fmadd(nodes[14], cf, shuf(0, nodes[14]))
nodes[21] = fmadd(nodes[15], cf, shuf(0, nodes[15]))
nodes[22] = fmadd(nodes[16], cf, shuf(0, nodes[16]))
nodes[23] = fmadd(nodes[17], cf, shuf(0, nodes[17]))
nodes[24] = fmadd(nodes[9], cf, shuf(0, nodes[9]))
nodes[25] = fmadd(nodes[11], cf, shuf(0, nodes[11]))
nodes[26] = fmadd(nodes[18], cf, shuf(0, nodes[18]))
nodes[27] = fmadd(nodes[19], cf, shuf(0, nodes[19]))
layer8()
}
layer6 := func() {
bc := bcast(0)
nodes[14] = add(nodes[8], nodes[10])
nodes[15] = sub(nodes[8], nodes[10])
nodes[16] = fmadd(nodes[12], bc, nodes[1])
nodes[17] = fnmsub(nodes[13], bc, nodes[5])
nodes[18] = fnmadd(nodes[12], bc, nodes[1])
nodes[19] = fnmadd(nodes[13], bc, nodes[5])
layer7()
}
layer5 := func() {
nodes[8] = add(nodes[0], nodes[4])
nodes[9] = sub(nodes[0], nodes[4])
nodes[10] = add(nodes[2], nodes[6])
nodes[11] = sub(nodes[6], nodes[2])
nodes[12] = sub(nodes[3], nodes[7])
nodes[13] = add(nodes[3], nodes[7])
layer6()
}
// layer4 is the first stage: sums and differences of inputs eight
// positions apart, reading every second entry of the in window.
layer4 := func() {
nodes[0] = add(in[0], in[8])
nodes[1] = sub(in[0], in[8])
nodes[2] = add(in[2], in[10])
nodes[3] = sub(in[2], in[10])
nodes[4] = add(in[4], in[12])
nodes[5] = sub(in[4], in[12])
nodes[6] = add(in[6], in[14])
nodes[7] = sub(in[6], in[14])
layer5()
}
// layer3 runs one full pass over the inputs starting at index from,
// binding results to the outputs starting at index to, and returns the
// statements it produced.
layer3 := func(from, to int) cgen.Stmts {
stmts = nil
in = F.In[from:]
out = F.Out[to:]
layer4()
return stmts
}
// layer2 runs both passes and interleaves their statements: even
// positions come from the first pass, odd positions from the second.
// Both lists must have the same length.
layer2 := func() cgen.Gen {
toMix := [2]cgen.Stmts{
layer3(0, 0),
layer3(1, 8),
}
var (
n = len(toMix[0])
mixed = make(cgen.Stmts, 2*n)
)
for i := range mixed {
mixed[i] = toMix[i&1][i>>1]
}
return mixed
}
// layer1 substitutes a zero vector for every missing input. Note that
// this writes to F.In.
layer1 := func() cgen.Gen {
for i := 0; i < 16; i++ {
if F.In[i] == nil {
F.In[i] = avx.Mm512SetzeroPs
}
}
return layer2()
}
return layer1()
}

// Bwd is a code generator for the backward (inverse) transform of this
// package. Append emits the statements for the selected platform.
type Bwd struct {
// Platform selects the instruction set; only raw.AVX512Float32 is handled.
Platform raw.Platform
// Nms supplies fresh names for the temporaries.
Nms nmsrc.Src
// In holds the input expressions in two halves of eight. A half whose
// first entry (In[0] or In[8]) is nil is skipped.
In [16]cgen.Gen
// Out holds the identifiers the results are declared under. A nil entry
// means the corresponding result is discarded.
Out [16]cgen.Gen
}

// Append generates the backward transform for B.Platform and appends its
// text to the given buffer. An unsupported platform is an internal error.
func (B *Bwd) Append(to []byte) []byte {
	if B.Platform != raw.AVX512Float32 {
		panic("bug")
	}
	gen := B.m512()
	return gen.Append(to)
}

// m512 generates the AVX-512 statements of the backward transform.
//
// The computation is a fixed network of numbered nodes, built by the chain
// layer3 -> layer4 -> ... -> layer16. The chain is run once per half of the
// inputs (In[0:8] -> Out[0:8] and In[8:16] -> Out[8:16], see layer1). When
// both halves are present their statement lists are interleaved one
// statement at a time.
//
// Every helper that declares a temporary appends exactly one statement, so
// the order of helper calls is the order of the emitted code.
func (B *Bwd) m512() cgen.Gen {
var (
stmts cgen.Stmts
in []cgen.Gen
out []cgen.Gen
// perms and coeffs cache constant vectors; they are shared by both
// passes, so each constant is declared only once.
perms [2]cgen.Gen
coeffs [6]cgen.Gen
nodes [84]cgen.Gen
)
stmt := func(st cgen.Gen) {
stmts = append(stmts, st)
}
// decl emits "t id = expr" and returns id. When id is nil a fresh name
// based on "ifft" is generated.
decl := func(t, id, expr cgen.Gen) cgen.Gen {
if id == nil {
ifft := B.Nms.Name("ifft")
id = cgen.Vb(ifft)
}
stmt(cgen.Var{
Type: t, What: id,
Init: expr,
})
return id
}
// perm declares a cross-lane permutation of node using index vector i.
// The index vector is declared on first use; later uses append a nil
// statement instead, presumably so that both passes produce the same
// number of statements for the interleaving in layer1.
perm := func(i int, node cgen.Gen) cgen.Gen {
pm := perms[i]
switch pm {
case nil:
var (
set = make(avx.Mm512SetEpi32, 16)
tbl []int
)
switch i {
case 0:
tbl = []int{12, 14, 14, 12, 10, 10, 9, 8}
case 1:
tbl = []int{13, 15, 15, 13, 11, 11, 8, 9}
}
// In Go, & binds tighter than -, so this is
// tbl[j%8] - (j&8): the second half of the vector (j >= 8)
// uses the same entries minus eight.
for j := range set {
set[j] = il(tbl[j%8] - j&8)
}
pm = decl(avx.M512i, nil, set)
perms[i] = pm
default:
stmt(nil)
}
return decl(
avx.M512, nil,
avx.Mm512PermutexvarPs{
pm, node,
},
)
}
// coeff returns constant vector i, declaring it on first use and
// appending a nil statement on later uses (same padding scheme as perm).
coeff := func(i int) cgen.Gen {
cf := coeffs[i]
switch cf {
case nil:
var (
neg1 = il(-1)
neg2 = fl(-math.Sqrt2 * 0.5)
pos1 = il(1)
pos2 = fl(math.Sqrt2 * 0.5)
zero = il(0)
expr cgen.Gen
)
// NOTE(review): if Mm512SetPs maps to _mm512_set_ps, the
// arguments are listed from the highest lane down to lane 0.
switch i {
case 0:
expr = avx.Mm512SetPs{
pos1, neg1, pos1, neg1,
pos1, neg1, zero, zero,
pos1, neg1, pos1, neg1,
pos1, neg1, zero, zero,
}
case 1:
expr = avx.Mm512SetPs{
neg1, pos1, neg1, pos1,
neg1, pos1, neg1, pos1,
neg1, pos1, neg1, pos1,
neg1, pos1, neg1, pos1,
}
case 2:
expr = avx.Mm512SetPs{
neg2, pos1, pos2, pos1,
zero, pos1, pos1, pos1,
neg2, pos1, pos2, pos1,
zero, pos1, pos1, pos1,
}
case 3:
expr = avx.Mm512SetPs{
pos2, zero, pos2, zero,
pos1, zero, zero, zero,
pos2, zero, pos2, zero,
pos1, zero, zero, zero,
}
case 4:
expr = avx.Mm512SetPs{
neg1, neg1, pos1, pos1,
neg1, neg1, pos1, pos1,
neg1, neg1, pos1, pos1,
neg1, neg1, pos1, pos1,
}
case 5:
expr = avx.Mm512SetPs{
neg1, neg1, neg1, neg1,
pos1, pos1, pos1, pos1,
neg1, neg1, neg1, neg1,
pos1, pos1, pos1, pos1,
}
}
cf = decl(avx.M512, nil, expr)
coeffs[i] = cf
default:
stmt(nil)
}
return cf
}
// blend declares a masked combination of node0 and node1. The first
// switch picks the lane mask; the second picks the operation: masked
// fused multiply-add / negated multiply-add with coefficient 0 (0, 1),
// masked subtraction from zero (2), or masked move (3). Variants 0 and 1
// call coeff(0), which itself appends a statement before the blend.
blend := func(i int, node0, node1 cgen.Gen) cgen.Gen {
var (
mask cgen.Gen
expr cgen.Gen
)
switch i {
case 0, 1:
mask = il(0xfdfd)
case 2, 3:
mask = il(0xc0c0)
}
switch i {
case 0:
expr = avx.Mm512MaskFmaddPs{
node0, mask, coeff(0),
node1,
}
case 1:
expr = avx.Mm512MaskFnmaddPs{
node0, mask, coeff(0),
node1,
}
case 2:
expr = avx.Mm512MaskSubPs{
node0, mask,
avx.Mm512SetzeroPs, node1,
}
case 3:
expr = avx.Mm512MaskMovPs{
node0, mask, node1,
}
}
return decl(avx.M512, nil, expr)
}
// shuf returns an inline shuffle of node with itself (no statement is
// emitted). Variants 0 and 1 use the per-element shuffle and variant 2
// the 128-bit-block shuffle; variants 0 and 2 share one control value.
shuf := func(i int, node cgen.Gen) cgen.Gen {
var (
ctrl int
expr cgen.Gen
)
switch i {
case 0, 2:
ctrl = 2<<6 | 3<<4 | 0<<2 | 1<<0
case 1:
ctrl = 1<<6 | 0<<4 | 3<<2 | 2<<0
}
switch i {
case 0, 1:
expr = avx.Mm512ShufflePs{
node, node, il(ctrl),
}
case 2:
expr = avx.Mm512ShuffleF32x4{
node, node, il(ctrl),
}
}
return expr
}
fmadd := func(a, b, c cgen.Gen) cgen.Gen {
return decl(
avx.M512, nil,
avx.Mm512FmaddPs{a, b, c},
)
}
mul := func(a, b cgen.Gen) cgen.Gen {
return decl(
avx.M512, nil,
avx.Mm512MulPs{a, b},
)
}
fnmadd := func(a, b, c cgen.Gen) cgen.Gen {
return decl(
avx.M512, nil,
avx.Mm512FnmaddPs{a, b, c},
)
}
fnmsub := func(a, b, c cgen.Gen) cgen.Gen {
return decl(
avx.M512, nil,
avx.Mm512FnmsubPs{a, b, c},
)
}
add := func(a, b cgen.Gen) cgen.Gen {
return decl(
avx.M512, nil,
avx.Mm512AddPs{a, b},
)
}
sub := func(a, b cgen.Gen) cgen.Gen {
return decl(
avx.M512, nil,
avx.Mm512SubPs{a, b},
)
}
// bcast returns an inline broadcast constant (no statement is emitted):
// index 0 is 1/32, index 1 is 1/64 and index 2 is sqrt(2)/2.
bcast := func(i int) cgen.Gen {
return avx.Mm512Set1PsLit(
[3]float64{
1.0 / 32,
1.0 / 64,
math.Sqrt2 * 0.5,
}[i],
)
}
fmsub := func(a, b, c cgen.Gen) cgen.Gen {
return decl(
avx.M512, nil,
avx.Mm512FmsubPs{a, b, c},
)
}
// emit binds node to output i. When the output identifier is nil the
// value is only evaluated as a void expression, so a statement is still
// appended and the statement count does not depend on which outputs are
// requested.
emit := func(i int, node cgen.Gen) {
id := out[i]
switch id {
case nil:
stmt(cgen.Cast{
Type: cgen.Void,
Expr: node,
})
default:
decl(avx.M512, id, node)
}
}
layer16 := func() {
emit(0, nodes[76])
emit(1, nodes[78])
emit(2, nodes[80])
emit(3, nodes[82])
emit(4, nodes[77])
emit(5, nodes[79])
emit(6, nodes[81])
emit(7, nodes[83])
}
layer15 := func() {
bc := bcast(1)
nodes[76] = fmadd(nodes[72], bc, nodes[62])
nodes[77] = fnmadd(nodes[72], bc, nodes[62])
nodes[78] = fmadd(nodes[74], bc, nodes[64])
nodes[79] = fnmadd(nodes[74], bc, nodes[64])
nodes[80] = fnmadd(nodes[75], bc, nodes[63])
nodes[81] = fmadd(nodes[75], bc, nodes[63])
nodes[82] = fmadd(nodes[73], bc, nodes[65])
nodes[83] = fnmadd(nodes[73], bc, nodes[65])
layer16()
}
layer14 := func() {
nodes[72] = add(nodes[68], nodes[69])
nodes[73] = sub(nodes[68], nodes[69])
nodes[74] = add(nodes[70], nodes[71])
nodes[75] = sub(nodes[70], nodes[71])
layer15()
}
layer13 := func() {
bc := bcast(2)
nodes[68] = fnmadd(nodes[66], bc, nodes[58])
nodes[69] = fmadd(nodes[66], bc, nodes[58])
nodes[70] = fmadd(nodes[67], bc, nodes[59])
nodes[71] = fmsub(nodes[67], bc, nodes[59])
layer14()
}
layer12 := func() {
bc := bcast(1)
nodes[62] = fmadd(nodes[54], bc, nodes[60])
nodes[63] = fmsub(nodes[54], bc, nodes[60])
nodes[64] = fmadd(nodes[55], bc, nodes[61])
nodes[65] = fmsub(nodes[55], bc, nodes[61])
nodes[66] = add(nodes[56], nodes[57])
nodes[67] = sub(nodes[56], nodes[57])
layer13()
}
layer11 := func() {
bc := bcast(0)
nodes[54] = add(nodes[46], nodes[47])
nodes[55] = sub(nodes[46], nodes[47])
nodes[56] = sub(nodes[48], nodes[52])
nodes[57] = add(nodes[49], nodes[53])
nodes[58] = add(nodes[48], nodes[52])
nodes[59] = sub(nodes[49], nodes[53])
nodes[60] = mul(nodes[50], bc)
nodes[61] = mul(nodes[51], bc)
layer12()
}
layer10 := func() {
cf := coeff(5)
nodes[46] = fmadd(nodes[38], cf, shuf(2, nodes[38]))
nodes[47] = fmadd(nodes[39], cf, shuf(2, nodes[39]))
nodes[48] = fmadd(nodes[40], cf, shuf(2, nodes[40]))
nodes[49] = fmadd(nodes[41], cf, shuf(2, nodes[41]))
nodes[50] = fmadd(nodes[42], cf, shuf(2, nodes[42]))
nodes[51] = fnmsub(nodes[43], cf, shuf(2, nodes[43]))
nodes[52] = fmadd(nodes[44], cf, shuf(2, nodes[44]))
nodes[53] = fmadd(nodes[45], cf, shuf(2, nodes[45]))
layer11()
}
layer9 := func() {
nodes[38] = blend(2, nodes[30], nodes[31])
nodes[39] = blend(3, nodes[31], nodes[30])
nodes[40] = blend(2, nodes[32], nodes[33])
nodes[41] = blend(3, nodes[33], nodes[32])
nodes[42] = blend(2, nodes[34], nodes[35])
nodes[43] = blend(3, nodes[35], nodes[34])
nodes[44] = blend(2, nodes[36], nodes[37])
nodes[45] = blend(3, nodes[37], nodes[36])
layer10()
}
layer8 := func() {
cf := coeff(4)
nodes[30] = fmadd(nodes[22], cf, shuf(1, nodes[22]))
nodes[31] = fmadd(nodes[23], cf, shuf(1, nodes[23]))
nodes[32] = fmadd(nodes[24], cf, shuf(1, nodes[24]))
nodes[33] = fmadd(nodes[25], cf, shuf(1, nodes[25]))
nodes[34] = fmadd(nodes[26], cf, shuf(1, nodes[26]))
nodes[35] = fmadd(nodes[27], cf, shuf(1, nodes[27]))
nodes[36] = fmadd(nodes[28], cf, shuf(1, nodes[28]))
nodes[37] = fmadd(nodes[29], cf, shuf(1, nodes[29]))
layer9()
}
layer7 := func() {
cf := coeff(3)
nodes[22] = fnmadd(nodes[7], cf, nodes[14])
nodes[23] = fmadd(nodes[6], cf, nodes[15])
nodes[24] = fnmadd(nodes[9], cf, nodes[16])
nodes[25] = fmadd(nodes[8], cf, nodes[17])
nodes[26] = fnmadd(nodes[11], cf, nodes[18])
nodes[27] = fmadd(nodes[10], cf, nodes[19])
nodes[28] = fnmadd(nodes[13], cf, nodes[20])
nodes[29] = fmadd(nodes[12], cf, nodes[21])
layer8()
}
layer6 := func() {
cf := coeff(2)
nodes[14] = mul(nodes[6], cf)
nodes[15] = mul(nodes[7], cf)
nodes[16] = mul(nodes[8], cf)
nodes[17] = mul(nodes[9], cf)
nodes[18] = mul(nodes[10], cf)
nodes[19] = mul(nodes[11], cf)
nodes[20] = mul(nodes[12], cf)
nodes[21] = mul(nodes[13], cf)
layer7()
}
layer5 := func() {
cf := coeff(1)
nodes[6] = fmadd(nodes[4], cf, shuf(0, nodes[4]))
nodes[7] = fmadd(nodes[5], cf, shuf(0, nodes[5]))
nodes[8] = fmadd(in[2], cf, shuf(0, in[2]))
nodes[9] = fmadd(in[3], cf, shuf(0, in[3]))
nodes[10] = fmadd(in[4], cf, shuf(0, in[4]))
nodes[11] = fmadd(in[5], cf, shuf(0, in[5]))
nodes[12] = fmadd(in[6], cf, shuf(0, in[6]))
nodes[13] = fmadd(in[7], cf, shuf(0, in[7]))
layer6()
}
layer4 := func() {
nodes[4] = blend(0, nodes[3], nodes[0])
nodes[5] = blend(1, nodes[2], nodes[1])
layer5()
}
// layer3 is the first stage: the first two inputs of the half are each
// permuted two ways.
layer3 := func() {
nodes[0] = perm(0, in[0])
nodes[1] = perm(1, in[0])
nodes[2] = perm(0, in[1])
nodes[3] = perm(1, in[1])
layer4()
}
// layer2 runs one full pass over the half of the inputs starting at
// index i and returns its statements, or nil when that half has no first
// input.
layer2 := func(i int) cgen.Stmts {
if B.In[i] == nil {
return nil
}
stmts = nil
in = B.In[i:]
out = B.Out[i:]
layer3()
return stmts
}
// layer1 runs both halves. If the second half is absent the first is
// returned as is; otherwise the two lists are interleaved (even
// positions from the first half, odd from the second), which requires
// them to have the same length.
// NOTE(review): a nil first half combined with a present second half
// yields an empty result, since the length is taken from the first
// list — presumably callers never do that; confirm.
layer1 := func() cgen.Gen {
toMix := [2]cgen.Stmts{
layer2(0),
layer2(8),
}
if toMix[1] == nil {
return toMix[0]
}
var (
n = len(toMix[0])
mixed = make(cgen.Stmts, 2*n)
)
for i := range mixed {
mixed[i] = toMix[i&1][i>>1]
}
return mixed
}
return layer1()
}

Top || internal/compile/author/rsqrt/rsqrt.go

package rsqrt

import (
"NN-512/internal/compile/author/avx"
"NN-512/internal/compile/author/cgen"
"NN-512/internal/compile/plan"
"NN-512/internal/nmsrc"
"NN-512/internal/raw"
)

// vb is shorthand for an identifier in the generated code.
func vb(s string) cgen.Gen {
	var ident cgen.Gen = cgen.Vb(s)
	return ident
}

// Ctx holds what is needed to define and call the generated reciprocal
// square root helper function.
type Ctx struct {
// platform selects the instruction set; only raw.AVX512Float32 is handled.
platform raw.Platform
// nms supplies fresh names for the function and its locals.
nms nmsrc.Src
// funcName is the name of the generated helper function.
funcName string
}

// NewCtx creates a context for the plan's platform and reserves the name of
// the helper function, derived from the configured prefix plus "Rsqrt".
func NewCtx(pl *plan.Plan, nms nmsrc.Src) *Ctx {
	cfg := &pl.Config
	c := new(Ctx)
	c.platform = cfg.Platform
	c.nms = nms
	c.funcName = nms.Name(cfg.Prefix + "Rsqrt")
	return c
}

// Prep returns the definition of the helper function for the context's
// platform. An unsupported platform is an internal error.
func (c *Ctx) Prep() cgen.Gen {
	if c.platform == raw.AVX512Float32 {
		return c.m512()
	}
	panic("bug")
}

// name returns a fresh identifier derived from s.
func (c *Ctx) name(s string) cgen.Gen {
	fresh := c.nms.Name(s)
	return vb(fresh)
}

// m512 returns the definition of the AVX-512 helper function. The function
// takes one vector x and computes, per lane, 0.5*y*(3 - x*y*y), where y is
// the approximate reciprocal square root of x from the rsqrt14 instruction;
// that is one Newton-Raphson refinement step of the approximation.
func (c *Ctx) m512() cgen.Gen {
	// Fresh names are requested in the order x, y, z, a, b.
	arg := c.name("x")
	est := c.name("y")
	prod := c.name("z")
	half := c.name("a")
	corr := c.name("b")
	// declare builds "__m512 what = expr".
	declare := func(what, expr cgen.Gen) cgen.Gen {
		return cgen.Var{
			Type: avx.M512, What: what,
			Init: expr,
		}
	}
	body := cgen.Stmts{
		// y = rsqrt14(x)
		declare(est, avx.Mm512Rsqrt14Ps{arg}),
		// z = x*y
		declare(prod, avx.Mm512MulPs{arg, est}),
		// a = 0.5*y
		declare(half, avx.Mm512MulPs{
			est, avx.Mm512Set1PsLit(0.5),
		}),
		// b = 3 - y*z
		declare(corr, avx.Mm512FnmaddPs{
			est, prod, avx.Mm512Set1PsLit(3),
		}),
		cgen.Return{
			Expr: avx.Mm512MulPs{half, corr},
		},
	}
	return cgen.StaticFuncDef{
		ReturnType: avx.M512,
		Name:       c.funcName,
		Params:     cgen.Param{Type: avx.M512, What: arg},
		Body:       body,
	}
}

// Call is a code generator for one invocation of the helper function
// defined by the embedded context.
type Call struct {
*Ctx
// Arg is the argument expression passed to the helper.
Arg cgen.Gen
}

// Append appends the text of a call to the helper function, with c.Arg as
// its argument, to the given buffer.
func (c *Call) Append(to []byte) []byte {
	call := cgen.Call{
		Func: vb(c.funcName),
		Args: c.Arg,
	}
	return call.Append(to)
}

Top || internal/compile/author/softmax/softmax.go

package softmax

import (
"NN-512/internal/compile/author/avx"
"NN-512/internal/compile/author/cgen"
"NN-512/internal/compile/author/exp"
"NN-512/internal/compile/author/threader"
"NN-512/internal/compile/plan"
"NN-512/internal/nmsrc"
"NN-512/internal/raw"
"fmt"
)

// Ctx holds the state shared by all softmax calls generated for one plan.
type Ctx struct {
// prefix is the base name of the generated softmax functions.
prefix string
// platform selects the instruction set; only raw.AVX512Float32 is handled.
platform raw.Platform
// nms supplies fresh names.
nms nmsrc.Src
// tc is the threader context used to distribute work.
tc *threader.Ctx
// ec is the exp context.
// NOTE(review): presumably used by the platform-specific generators
// defined later in this file — confirm.
ec *exp.Ctx
// dedup maps a formatted shape list to the name of the function already
// generated for it, so identical shape lists share one function.
dedup map[string]string
}

// NewCtx creates a softmax context for the plan. Generated functions are
// named after the configured prefix plus "Softmax"; the deduplication table
// starts out empty.
func NewCtx(pl *plan.Plan, nms nmsrc.Src, tc *threader.Ctx, ec *exp.Ctx) *Ctx {
	cfg := &pl.Config
	c := new(Ctx)
	c.prefix = cfg.Prefix + "Softmax"
	c.platform = cfg.Platform
	c.nms = nms
	c.tc = tc
	c.ec = ec
	c.dedup = make(map[string]string)
	return c
}

// name returns a fresh name derived from s.
func (c *Ctx) name(s string) string {
	fresh := c.nms.Name(s)
	return fresh
}

// vb is shorthand for an identifier in the generated code.
func vb(s string) cgen.Gen {
	var ident cgen.Gen = cgen.Vb(s)
	return ident
}

// il is shorthand for an integer literal in the generated code.
func il(i int) cgen.Gen {
	var lit cgen.Gen = cgen.IntLit(i)
	return lit
}

// cast returns the integer literal pitch cast to ptrdiff_t.
func cast(pitch int) cgen.Gen {
	lit := il(pitch)
	return cgen.Cast{
		Expr: lit,
		Type: cgen.PtrdiffT,
	}
}

// addr returns the expression ptr + pitch*idx.
func addr(ptr, pitch, idx cgen.Gen) cgen.Gen {
	offset := cgen.Mul{Expr1: pitch, Expr2: idx}
	return cgen.Add{Expr1: ptr, Expr2: offset}
}

// Call is a code generator for one softmax invocation. Prep must run before
// Append, because Prep chooses the function name that Append calls.
type Call struct {
*Ctx
// Team is the thread team expression passed to the generated function.
Team cgen.Gen
// Tensors are the tensor pointer expressions, passed as one array.
Tensors []cgen.Gen
// Shapes describe the tensors; the formatted list is also the
// deduplication key.
Shapes []Shape
// funcName is set by Prep.
funcName string
}

// Prep selects the function this call will invoke. If a function has
// already been generated for an identical shape list, its name is reused
// and nil is returned. Otherwise a new name is reserved and recorded, and
// the function definition followed by a newline is returned.
func (c *Call) Prep() cgen.Gen {
	key := fmt.Sprintf("%v", c.Shapes)
	prior, seen := c.dedup[key]
	if seen {
		c.funcName = prior
		return nil
	}
	fresh := c.name(c.prefix)
	c.funcName = fresh
	c.dedup[key] = fresh
	def := &funcDef{
		Ctx:      c.Ctx,
		funcName: fresh,
		shapes:   c.Shapes,
	}
	return cgen.Gens{def, cgen.Newline}
}

// Append appends the text of the call to the given buffer: a local array of
// char pointers initialized with the tensor expressions, followed by a call
// to the function chosen by Prep with the team and that array.
func (c *Call) Append(to []byte) []byte {
	array := vb(c.name("tensors"))
	elems := cgen.CommaLines(c.Tensors)
	declare := cgen.Var{
		Type: cgen.PtrChar,
		What: cgen.Elem{Arr: array},
		Init: cgen.Brace{Inner: elems},
	}
	invoke := cgen.Call{
		Func: vb(c.funcName),
		Args: cgen.CommaSpaced{
			c.Team, array,
		},
	}
	code := cgen.Stmts{declare, invoke}
	return code.Append(to)
}

// Shape describes the dimensions and memory layout of one tensor taking
// part in a softmax.
type Shape struct {
Channels int
Height int
Width int
// ElemBytes is the size of one element in bytes.
ElemBytes int
// Pitch1Bytes is compared against Width*ElemBytes by funcDef.Append.
// NOTE(review): presumably the byte distance between consecutive rows —
// confirm.
Pitch1Bytes int
// Pitch2Bytes is compared against Height*Width*ElemBytes by
// funcDef.Append.
// NOTE(review): presumably the byte distance between consecutive
// channels — confirm.
Pitch2Bytes int
}

// funcDef is a code generator for the definition of one softmax function.
type funcDef struct {
*Ctx
// funcName is the name of the function to define.
funcName string
// shapes describe the tensors. The dimensions are taken from the first
// entry; the pitches of every entry are examined.
shapes []Shape
}

// Append appends the definition of the softmax function to the given
// buffer. It chooses one of three implementations based on the tensor
// layouts:
//
//   - packed: emitted inline in the function body, with the team unused;
//   - linear: a threaded callee over the flattened spatial extent;
//   - planar: a threaded callee over width and height separately.
//
// For linear and planar the callee definition is emitted first, then a
// newline, then the function, whose body hands the callee to the threader.
// f.shapes must not be empty.
func (f *funcDef) Append(to []byte) []byte {
var lanes int
switch f.platform {
case raw.AVX512Float32:
lanes = 16
default:
panic("bug")
}
const (
packed = iota
linear
planar
)
var (
form = packed
shape = &f.shapes[0]
channels = shape.Channels
height = shape.Height
width = shape.Width
spatial = height * width
elem = shape.ElemBytes
nt = len(f.shapes)
pitches1 = make([]cgen.Gen, nt)
pitches2 = make([]cgen.Gen, nt)
)
// Packed is ruled out when the spatial extent exceeds half a vector or
// when there are fewer than two channels.
if spatial*2 > lanes || channels < 2 {
form = linear
}
// Any tensor whose Pitch1Bytes differs from a dense row forces planar.
// Otherwise a tensor whose Pitch2Bytes differs from a dense channel
// demotes packed to linear. Planar, once chosen, is never undone.
for i := range f.shapes {
var (
at = &f.shapes[i]
pitch1 = at.Pitch1Bytes
pitch2 = at.Pitch2Bytes
)
if pitch1 != width*elem {
form = planar
} else if form == packed &&
pitch2 != spatial*elem {
form = linear
}
pitches1[i] = cast(pitch1)
pitches2[i] = cast(pitch2)
}
var (
// gens[0] is the callee definition and gens[1] the newline after
// it (both left nil for packed); gens[2] is the function itself.
gens = make(cgen.Gens, 3)
team = vb(f.name("team"))
tensors = vb(f.name("tensors"))
body cgen.Gen
)
if form == packed {
stmts := make(cgen.Stmts, 2)
// The team parameter is not needed here; it is cast to void.
stmts[0] = cgen.Cast{
Type: cgen.Void, Expr: team,
}
switch f.platform {
case raw.AVX512Float32:
stmts[1] = &m512Packed{
Ctx: f.Ctx,
channels: channels,
spatial: spatial,
tensors: tensors,
nt: nt,
}
}
body = stmts
} else {
var (
callee = f.name(f.funcName + "Callee")
hull []cgen.Gen
)
if form == linear {
switch f.platform {
case raw.AVX512Float32:
gens[0] = &m512Linear{
Ctx: f.Ctx,
funcName: callee,
channels: channels,
spatial: spatial,
pitches: pitches2,
}
}
// One task per vector of spatial positions, rounded up.
hull = []cgen.Gen{
il((spatial + lanes - 1) / lanes),
}
} else {
switch f.platform {
case raw.AVX512Float32:
gens[0] = &m512Planar{
Ctx: f.Ctx,
funcName: callee,
channels: channels,
width: width,
pitches1: pitches1,
pitches2: pitches2,
}
}
// One task per vector of a row (rounded up) and per row.
hull = []cgen.Gen{
il((width + lanes - 1) / lanes),
il(height),
}
}
gens[1] = cgen.Newline
body = &threader.Do{
Ctx: f.tc,
Callee: vb(callee),
Any: tensors,
Hull: hull,
Team: team,
}
}
gens[2] = cgen.StaticFuncDef{
ReturnType: cgen.Void,
Name: f.funcName,
Params: cgen.CommaSpaced{
cgen.Param{
Type: f.tc.PtrTeam,
What: team,
},
cgen.Param{
Type: cgen.PtrPtrChar,
What: tensors,
},
},
Body: body,
}
return gens.Append(to)
}

// m512Planar generates the threader callee used by the planar form: each
// task handles one vector of up to 16 elements within one row, across
// all channels. pitches1 and pitches2 hold the per-tensor row and
// channel pitch expressions.
type m512Planar struct {
*Ctx
funcName string
channels int
width int
pitches1 []cgen.Gen
pitches2 []cgen.Gen
}

// Append emits the planar callee. The callee reads the tensor array from
// the task's "any" pointer and the task coordinates w (pt[0]) and h
// (pt[1]), computes a lane mask and one base pointer per tensor, and
// then runs three passes over the channels: max, exp-and-sum, and
// division by the sum.
//
// Tensor 0 is the one loaded from by the max and exp passes; tensor 1
// receives the exp results and tensors 1.. receive the final values, so
// at least two tensors are required (ptrs[1] is indexed directly).
func (m *m512Planar) Append(to []byte) []byte {
const (
lanes = 16
laneBytes = 4
)
var (
n = len(m.pitches1)
stmts = make(cgen.Stmts, 4+n)
tensors = vb(m.name("tensors"))
w = vb(m.name("w"))
h = vb(m.name("h"))
mask = vb(m.name("mask"))
ptrs = make([]cgen.Gen, n)
max = vb(m.name("max"))
sum = vb(m.name("sum"))
)
callee := &threader.Callee{
Ctx: m.tc,
Name: m.funcName,
Task: vb(m.name("task")),
Pt: vb(m.name("pt")),
}
stmts[0] = cgen.Var{
Type: cgen.PtrPtrChar, What: tensors,
Init: callee.Any(),
}
stmts[1] = cgen.Var{
Type: cgen.PtrdiffT, What: w,
Init: cgen.Elem{Arr: callee.Pt, Idx: cgen.Zero},
}
stmts[2] = cgen.Var{
Type: cgen.PtrdiffT, What: h,
Init: cgen.Elem{Arr: callee.Pt, Idx: cgen.One},
}
// All 16 lanes are active for full vectors; the last, partial vector
// of a row enables only the low width%lanes lanes.
stmts[3] = cgen.Var{
Type: avx.Mmask16, What: mask,
Init: cgen.Ternary{
Cond: cgen.CmpL{
Expr1: w,
Expr2: il(m.width / lanes),
},
Then: il(1<<lanes - 1),
Else: il(1<<uint(m.width%lanes) - 1),
},
}
// Per-tensor base pointer: tensors[i] + pitch1*h + 64*w.
for i := range ptrs {
ptrs[i] = vb(m.name("ptr"))
var (
a1 = cgen.Elem{Arr: tensors, Idx: il(i)}
a2 = addr(a1, m.pitches1[i], h)
a3 = addr(a2, cast(lanes*laneBytes), w)
)
stmts[4+i] = cgen.Var{
Type: cgen.RestrictPtrChar,
What: ptrs[i], Init: a3,
}
}
return callee.Func(cgen.Gens{
stmts,
&m512Max{
Ctx: m.Ctx,
ptr: ptrs[0],
pitch: m.pitches2[0],
loopCnt: m.channels,
loopMask: mask,
max: max,
},
&m512Exp{
Ctx: m.Ctx,
ldPtr: ptrs[0],
ldPitch: m.pitches2[0],
stPtr: ptrs[1],
stPitch: m.pitches2[1],
loopCnt: m.channels,
loopMask: mask,
sub: max,
sum: sum,
},
&m512Denom{
Ctx: m.Ctx,
ptrs: ptrs[1:],
pitches: m.pitches2[1:],
loopCnt: il(m.channels),
loopMask: mask,
divisor: sum,
},
}).Append(to)
}

// m512Linear generates the threader callee used by the linear form: each
// task handles one vector of up to 16 consecutive spatial positions,
// across all channels. pitches holds the per-tensor channel pitch
// expressions.
type m512Linear struct {
*Ctx
funcName string
channels int
spatial int
pitches []cgen.Gen
}

// Append emits the linear callee. The callee reads the tensor array from
// the task's "any" pointer and the task coordinate i (pt[0]), computes a
// lane mask and one base pointer per tensor, and then runs three passes
// over the channels: max, exp-and-sum, and division by the sum.
//
// Tensor 0 is the one loaded from by the max and exp passes; tensor 1
// receives the exp results and tensors 1.. receive the final values, so
// at least two tensors are required (ptrs[1] is indexed directly).
func (m *m512Linear) Append(to []byte) []byte {
const (
lanes = 16
laneBytes = 4
)
var (
n = len(m.pitches)
stmts = make(cgen.Stmts, 3+n)
tensors = vb(m.name("tensors"))
i = vb(m.name("i"))
mask = vb(m.name("mask"))
ptrs = make([]cgen.Gen, n)
max = vb(m.name("max"))
sum = vb(m.name("sum"))
)
callee := &threader.Callee{
Ctx: m.tc,
Name: m.funcName,
Task: vb(m.name("task")),
Pt: vb(m.name("pt")),
}
stmts[0] = cgen.Var{
Type: cgen.PtrPtrChar, What: tensors,
Init: callee.Any(),
}
stmts[1] = cgen.Var{
Type: cgen.PtrdiffT, What: i,
Init: cgen.Elem{Arr: callee.Pt, Idx: cgen.Zero},
}
// All 16 lanes are active for full vectors; the last, partial vector
// enables only the low spatial%lanes lanes.
stmts[2] = cgen.Var{
Type: avx.Mmask16, What: mask,
Init: cgen.Ternary{
Cond: cgen.CmpL{
Expr1: i,
Expr2: il(m.spatial / lanes),
},
Then: il(1<<lanes - 1),
Else: il(1<<uint(m.spatial%lanes) - 1),
},
}
// Per-tensor base pointer: tensors[j] + 64*i.
for j := range ptrs {
ptrs[j] = vb(m.name("ptr"))
var (
a1 = cgen.Elem{Arr: tensors, Idx: il(j)}
a2 = addr(a1, cast(lanes*laneBytes), i)
)
stmts[3+j] = cgen.Var{
Type: cgen.RestrictPtrChar,
What: ptrs[j], Init: a2,
}
}
return callee.Func(cgen.Gens{
stmts,
&m512Max{
Ctx: m.Ctx,
ptr: ptrs[0],
pitch: m.pitches[0],
loopCnt: m.channels,
loopMask: mask,
max: max,
},
&m512Exp{
Ctx: m.Ctx,
ldPtr: ptrs[0],
ldPitch: m.pitches[0],
stPtr: ptrs[1],
stPitch: m.pitches[1],
loopCnt: m.channels,
loopMask: mask,
sub: max,
sum: sum,
},
&m512Denom{
Ctx: m.Ctx,
ptrs: ptrs[1:],
pitches: m.pitches[1:],
loopCnt: il(m.channels),
loopMask: mask,
divisor: sum,
},
}).Append(to)
}

// m512Packed generates the inline body used by the packed form, where
// several whole channels fit into a single 16-lane vector. tensors is
// the expression for the tensor pointer array and nt is the number of
// tensors in it.
type m512Packed struct {
*Ctx
channels int
spatial int
tensors cgen.Gen
nt int
}

// Append emits the packed softmax body. Each vector holds loopChans
// whole channels (as many as fit in 16 lanes, capped at the channel
// count); loopCnt such full vectors are followed by an optional edge
// vector covering the remaining channels. After the max pass and after
// the exp pass, an m512Fold step combines the per-channel groups held in
// one vector and broadcasts the result across the lanes.
//
// The arithmetic assumes 1 <= spatial <= 16, channels >= 1 and nt >= 2;
// other values lead to a division by zero or an out-of-range index.
func (m *m512Packed) Append(to []byte) []byte {
var (
ptrs = make([]cgen.Gen, m.nt)
loadPtrs = make(cgen.Stmts, m.nt)
)
for i := range ptrs {
ptrs[i] = vb(m.name("ptr"))
loadPtrs[i] = cgen.Var{
Type: cgen.RestrictPtrChar,
What: ptrs[i],
Init: cgen.Elem{
Arr: m.tensors,
Idx: il(i),
},
}
}
const (
lanes = 16
laneBytes = 4
)
// Number of whole channels handled per vector.
loopChans := lanes / m.spatial
if loopChans > m.channels {
loopChans = m.channels
}
var (
loopLanes = loopChans * m.spatial
pitch = cast(loopLanes * laneBytes)
loopCnt = m.channels / loopChans
edgeLanes = m.channels % loopChans * m.spatial
loopMask = il(1<<uint(loopLanes) - 1)
edgeMask cgen.Gen
)
// edgeMask stays nil when the channel count is a multiple of
// loopChans; the passes below skip their edge step in that case.
if edgeLanes > 0 {
edgeMask = il(1<<uint(edgeLanes) - 1)
}
var (
max = vb(m.name("max"))
sum = vb(m.name("sum"))
pitches = make([]cgen.Gen, m.nt-1)
)
// Every output tensor is walked with the same packed pitch.
for i := range pitches {
pitches[i] = pitch
}
return cgen.Gens{
loadPtrs,
&m512Max{
Ctx: m.Ctx,
ptr: ptrs[0],
pitch: pitch,
loopCnt: loopCnt,
loopMask: loopMask,
edgeMask: edgeMask,
max: max,
},
&m512Fold{
Ctx: m.Ctx,
dat: max,
cnt: loopChans,
each: m.spatial,
op: foldMax,
},
&m512Exp{
Ctx: m.Ctx,
ldPtr: ptrs[0],
ldPitch: pitch,
stPtr: ptrs[1],
stPitch: pitch,
loopCnt: loopCnt,
loopMask: loopMask,
edgeMask: edgeMask,
sub: max,
sum: sum,
},
&m512Fold{
Ctx: m.Ctx,
dat: sum,
cnt: loopChans,
each: m.spatial,
op: foldAdd,
},
&m512Denom{
Ctx: m.Ctx,
ptrs: ptrs[1:],
pitches: pitches,
loopCnt: il(loopCnt),
loopMask: loopMask,
edgeMask: edgeMask,
divisor: sum,
},
}.Append(to)
}

// m512Max generates the max pass: it reduces loopCnt vectors, loaded
// from ptr at multiples of pitch under loopMask, into the variable max.
// When edgeMask is non-nil, one extra vector at index loopCnt is folded
// in under that mask.
type m512Max struct {
*Ctx
ptr cgen.Gen
pitch cgen.Gen
loopCnt int
loopMask cgen.Gen
edgeMask cgen.Gen
max cgen.Gen
}

// Append emits the max pass. Up to 16 independent partial maxima are
// kept (parts[0] is m.max itself):
//  1. the first nparts vectors initialize the partial maxima,
//  2. a for loop handles further full groups of 16 vectors,
//  3. a straight-line tail handles the remaining vectors,
//  4. the optional edge vector is merged into the last partial maximum
//     with a masked max,
//  5. the partial maxima are combined pairwise into parts[0].
//
// loopCnt must be at least 1, since parts[0] is assigned
// unconditionally.
func (m *m512Max) Append(to []byte) []byte {
const unroll = 16
nparts := m.loopCnt
if nparts > unroll {
nparts = unroll
}
parts := make([]cgen.Gen, nparts)
parts[0] = m.max
for i := 1; i < nparts; i++ {
parts[i] = vb(m.name("max"))
}
first := make(cgen.Stmts, nparts)
for i := range first {
first[i] = cgen.Var{
Type: avx.M512, What: parts[i],
Init: avx.Mm512MaskzLoaduPs{
m.loopMask,
addr(m.ptr, m.pitch, il(i)),
},
}
}
stmts := make(cgen.Stmts, 6)
stmts[0] = first
if remain := m.loopCnt - nparts; remain > 0 {
var (
iters = remain / unroll
after = remain % unroll
)
// fill writes n load statements into s[0:n] and the matching max
// updates into s[unroll:unroll+n]; the vector index is j + unroll*i.
fill := func(s cgen.Stmts, i cgen.Gen, n int) {
for j := 0; j < n; j++ {
var (
dat = vb(m.name("dat"))
part = parts[j]
)
from1 := addr(m.ptr, m.pitch, il(j))
from2 := addr(from1, m.pitch, cgen.Mul{
Expr1: il(unroll),
Expr2: i,
})
s[unroll*0+j] = cgen.Var{
Type: avx.M512, What: dat,
Init: avx.Mm512MaskzLoaduPs{
m.loopMask, from2,
},
}
s[unroll*1+j] = cgen.Assign{
Expr1: part,
Expr2: avx.Mm512MaxPs{
part, dat,
},
}
}
}
if iters > 0 {
var (
body = make(cgen.Stmts, unroll*2)
i = vb(m.name("i"))
)
fill(body, i, unroll)
// The loop starts at 1 because group 0 was consumed by the
// initial loads above.
stmts[1] = cgen.For{
Init: cgen.Var{
Type: cgen.PtrdiffT, What: i,
Init: cgen.One,
},
Cond: cgen.CmpLE{
Expr1: i, Expr2: il(iters),
},
Post: cgen.IncPre{Expr: i},
Body: body,
}
}
var (
tail = make(cgen.Stmts, unroll*2)
i = il(iters + 1)
)
fill(tail, i, after)
stmts[2] = tail
}
if m.edgeMask != nil {
var (
dat = vb(m.name("dat"))
idx = il(m.loopCnt)
part = parts[nparts-1]
)
stmts[3] = cgen.Var{
Type: avx.M512, What: dat,
Init: avx.Mm512MaskzLoaduPs{
m.edgeMask,
addr(m.ptr, m.pitch, idx),
},
}
// Masked max: lanes outside edgeMask keep the value of part.
stmts[4] = cgen.Assign{
Expr1: part,
Expr2: avx.Mm512MaskMaxPs{
part, m.edgeMask,
part, dat,
},
}
}
// Pairwise reduction: each round folds the upper cnt parts into the
// lower ones until only parts[0] remains (nparts-1 statements total).
fold := make(cgen.Stmts, nparts-1)
for i := 0; nparts > 1; {
cnt := nparts >> 1
nparts -= cnt
for j := 0; j < cnt; j++ {
fold[i] = cgen.Assign{
Expr1: parts[j],
Expr2: avx.Mm512MaxPs{
parts[j],
parts[nparts+j],
},
}
i++
}
}
stmts[5] = fold
return stmts.Append(to)
}

// m512Exp generates the exp pass: for each of loopCnt vectors (plus an
// optional edge vector when edgeMask is non-nil) it loads from ldPtr,
// adds the negated sub vector, applies the exp generator, accumulates
// the result into sum and stores it to stPtr.
type m512Exp struct {
*Ctx
ldPtr cgen.Gen
ldPitch cgen.Gen
stPtr cgen.Gen
stPitch cgen.Gen
loopCnt int
loopMask cgen.Gen
edgeMask cgen.Gen
sub cgen.Gen
sum cgen.Gen
}

// Append emits the exp pass. It declares sum as zero and neg as
// sum - sub (that is, the negated sub value, since sum is zero at that
// point), and then processes the vectors from the highest index down:
// first the straight-line remainder (including the edge vector, if
// any), then a for loop that counts down over full groups of 16.
// Within a group the statement slots are also filled in reverse, so the
// emitted code visits vector indices in descending order.
func (m *m512Exp) Append(to []byte) []byte {
const unroll = 16
var (
iters = m.loopCnt / unroll
more = m.loopCnt % unroll
neg = vb(m.name("neg"))
)
// ae builds ptr + pitch*j, plus pitch*(unroll*i) when there is at
// least one full group (otherwise the group offset is always zero).
ae := func(ptr, pitch, i, j cgen.Gen) cgen.Gen {
expr := addr(ptr, pitch, j)
if iters > 0 {
expr = addr(expr, pitch, cgen.Mul{
Expr1: il(unroll),
Expr2: i,
})
}
return expr
}
ld := func(dat, mask, i, j cgen.Gen) cgen.Gen {
return cgen.Var{
Type: avx.M512, What: dat,
Init: avx.Mm512MaskzLoaduPs{
mask,
ae(m.ldPtr, m.ldPitch, i, j),
},
}
}
// add accumulates dat into sum; with the edge mask a masked add is
// used so that only the lanes selected by the mask are updated.
add := func(dat, mask cgen.Gen) cgen.Gen {
if mask == m.loopMask {
return avx.Mm512AddPs{m.sum, dat}
}
return avx.Mm512MaskAddPs{m.sum, mask, m.sum, dat}
}
// op replaces dat with exp(neg + dat) and then adds it to sum.
op := func(dat, mask cgen.Gen) cgen.Gen {
return cgen.Stmts{
cgen.Assign{
Expr1: dat,
Expr2: &exp.Call{
Ctx: m.ec,
Arg: avx.Mm512AddPs{neg, dat},
},
},
cgen.Assign{
Expr1: m.sum,
Expr2: add(dat, mask),
},
}
}
st := func(dat, mask, i, j cgen.Gen) cgen.Gen {
return avx.Mm512MaskStoreuPs{
ae(m.stPtr, m.stPitch, i, j),
mask, dat,
}
}
// fill places the load, op and store for vectors j..j+n-1 into three
// bands of s, each band indexed from its end so that higher j comes
// first in the emitted code.
fill := func(s cgen.Stmts, mask, i cgen.Gen, j, n int) {
for stop := j + n; j < stop; j++ {
var (
dat = vb(m.name("dat"))
jj = il(j)
)
s[unroll*1-1-j] = ld(dat, mask, i, jj)
s[unroll*2-1-j] = op(dat, mask)
s[unroll*3-1-j] = st(dat, mask, i, jj)
}
}
stmts := make(cgen.Stmts, 4)
stmts[0] = cgen.Var{
Type: avx.M512, What: m.sum,
Init: avx.Mm512SetzeroPs,
}
stmts[1] = cgen.Var{
Type: avx.M512, What: neg,
Init: avx.Mm512SubPs{m.sum, m.sub},
}
if iters > 0 {
var (
inner = make(cgen.Stmts, unroll*3)
i = vb(m.name("i"))
)
fill(inner, m.loopMask, i, 0, unroll)
// Emitted after the remainder (slot 3 follows slot 2) and counting
// down from the last full group to group 0.
stmts[3] = cgen.For{
Init: cgen.Var{
Type: cgen.PtrdiffT, What: i,
Init: il(iters - 1),
},
Cond: cgen.CmpGE{
Expr1: i, Expr2: cgen.Zero,
},
Post: cgen.DecPre{Expr: i},
Body: inner,
}
}
var (
outer = make(cgen.Stmts, unroll*3)
i = il(iters)
)
fill(outer, m.loopMask, i, 0, more)
if m.edgeMask != nil {
fill(outer, m.edgeMask, i, more, 1)
}
stmts[2] = outer
return stmts.Append(to)
}

// foldOp selects the combining operation used by m512Fold.
type foldOp int

const (
// foldAdd combines with a masked add.
foldAdd foldOp = iota
// foldMax combines with a masked max.
foldMax
)

// m512Fold generates an in-register reduction of the vector variable
// dat, which is treated as cnt consecutive groups of each lanes. The
// groups are combined with op and the combined group is then repeated
// across the vector.
type m512Fold struct {
*Ctx
dat cgen.Gen
cnt int
each int
op foldOp
}

// Append emits the fold. While more than one group remains, the upper
// half of the groups is permuted down onto the lowest lanes and combined
// into dat under a mask covering fold*each lanes. Finally dat is
// permuted so that lane k receives element k mod each of the reduced
// group.
//
// The index slices passed to avx.Mm512SetEpi32 are filled from the last
// position backwards, so from[15] holds the source index for lane 0
// (assuming the generated call is _mm512_set_epi32, which lists its
// arguments from lane 15 down to lane 0 -- TODO confirm in the avx
// package).
func (m *m512Fold) Append(to []byte) []byte {
var stmts cgen.Stmts
assign := func(a cgen.Gen) {
stmts = append(stmts, cgen.Assign{
Expr1: m.dat, Expr2: a,
})
}
// perm appends the declaration of a fresh index vector to stmts and
// returns the expression that permutes dat by it.
perm := func(a []cgen.Gen) cgen.Gen {
p := vb(m.name("p"))
stmts = append(stmts, cgen.Var{
Type: avx.M512i, What: p,
Init: avx.Mm512SetEpi32(a),
})
return avx.Mm512PermutexvarPs{p, m.dat}
}
call := func(a ...cgen.Gen) cgen.Gen {
switch m.op {
case foldAdd:
return avx.Mm512MaskAddPs(a)
case foldMax:
return avx.Mm512MaskMaxPs(a)
default:
panic("bug")
}
}
const lanes = 16
for have := m.cnt; have > 1; {
stop := have * m.each
fold := have >> 1
have -= fold
// Source elements [have*each, stop) are moved to lanes
// 0..fold*each-1; the remaining lanes read element 0 and are
// masked off in the combine below.
elem := have * m.each
from := make([]cgen.Gen, lanes)
i := lanes - 1
for ; elem < stop; elem++ {
from[i] = il(elem)
i--
}
for ; i >= 0; i-- {
from[i] = cgen.Zero
}
mask := 1<<uint(fold*m.each) - 1
// perm(from) runs first (as an argument), so the index vector
// declaration precedes the assignment in stmts.
assign(call(
m.dat, il(mask),
m.dat, perm(from),
))
}
// Broadcast: lane k reads element k mod each.
elem := 0
from := make([]cgen.Gen, lanes)
for i := lanes - 1; i >= 0; i-- {
from[i] = il(elem)
if elem++; elem == m.each {
elem = 0
}
}
assign(perm(from))
return stmts.Append(to)
}

// m512Denom generates the final pass: every vector is loaded from
// ptrs[0], multiplied by the reciprocal of divisor and stored to each of
// ptrs. pitches[j] is the vector pitch of ptrs[j], loopCnt is the number
// of vectors handled under loopMask, and a non-nil edgeMask adds one
// more vector at index loopCnt handled under that mask.
type m512Denom struct {
*Ctx
ptrs []cgen.Gen
pitches []cgen.Gen
loopCnt cgen.Gen
loopMask cgen.Gen
edgeMask cgen.Gen
divisor cgen.Gen
}

// Append emits the normalization pass: the declaration of the reciprocal
// 1/divisor, a loop over loopCnt vectors that scales each vector loaded
// from ptrs[0] and stores it to every pointer in ptrs, and, when
// edgeMask is set, one more scaled vector at index loopCnt.
func (m *m512Denom) Append(to []byte) []byte {
	rcp := vb(m.name("rcp"))
	i := vb(m.name("i"))
	// scale builds the load/multiply/store sequence for the vector at
	// idx under the given lane mask.
	scale := func(mask, idx cgen.Gen) cgen.Gen {
		dat := vb(m.name("dat"))
		seq := make(cgen.Stmts, 0, 2+len(m.ptrs))
		seq = append(seq,
			cgen.Var{
				Type: avx.M512, What: dat,
				Init: avx.Mm512MaskzLoaduPs{
					mask,
					addr(m.ptrs[0], m.pitches[0], idx),
				},
			},
			cgen.Assign{
				Expr1: dat,
				Expr2: avx.Mm512MulPs{rcp, dat},
			},
		)
		for j := range m.ptrs {
			seq = append(seq, avx.Mm512MaskStoreuPs{
				addr(m.ptrs[j], m.pitches[j], idx),
				mask, dat,
			})
		}
		return seq
	}
	stmts := cgen.Stmts{
		cgen.Var{
			Type: avx.M512, What: rcp,
			Init: avx.Mm512DivPs{
				avx.Mm512Set1PsLit(1),
				m.divisor,
			},
		},
		cgen.For{
			Init: cgen.Var{
				Type: cgen.PtrdiffT, What: i,
				Init: cgen.Zero,
			},
			Cond: cgen.CmpL{Expr1: i, Expr2: m.loopCnt},
			Post: cgen.IncPre{Expr: i},
			Body: scale(m.loopMask, i),
		},
	}
	if m.edgeMask != nil {
		stmts = append(stmts, scale(m.edgeMask, m.loopCnt))
	}
	return stmts.Append(to)
}

Top || internal/compile/author/strider/strider.go

package strider

import (
"NN-512/internal/compile/author/act"
"NN-512/internal/compile/author/avx"
"NN-512/internal/compile/author/bn"
"NN-512/internal/compile/author/cgen"
"NN-512/internal/compile/author/mod"
"NN-512/internal/compile/author/quadfft"
"NN-512/internal/compile/author/sumr"
"NN-512/internal/compile/author/threader"
"NN-512/internal/compile/plan"
"NN-512/internal/nmsrc"
"NN-512/internal/raw"
"fmt"
)

// btoi converts a bool to an int: 1 for true, 0 for false.
func btoi(b bool) int {
	n := 0
	if b {
		n = 1
	}
	return n
}

// min returns the smaller of x and y.
func min(x, y int) int {
	if y < x {
		return y
	}
	return x
}

// max returns the larger of x and y.
func max(x, y int) int {
	if y > x {
		return y
	}
	return x
}

// ceilQuo returns n divided by d, rounded up, for non-negative n and
// positive d. It is computed as (n + d - 1) / d.
func ceilQuo(n, d int) int {
	biased := n + (d - 1)
	return biased / d
}

// vb is shorthand for wrapping s with cgen.Vb so that it can be used
// wherever a cgen.Gen is expected.
func vb(s string) cgen.Gen {
	var g cgen.Gen = cgen.Vb(s)
	return g
}

// il is shorthand for wrapping the integer i with cgen.IntLit so that it
// can be used wherever a cgen.Gen is expected.
func il(i int) cgen.Gen {
	var g cgen.Gen = cgen.IntLit(i)
	return g
}

// loMask returns the integer literal whose low n bits are set, that is
// (1 << n) - 1.
func loMask(n int) cgen.Gen {
	full := int(1) << uint(n)
	return il(full - 1)
}

// addMul builds the expression a + b*c.
func addMul(a, b, c cgen.Gen) cgen.Gen {
	product := cgen.Mul{Expr1: b, Expr2: c}
	return cgen.Add{Expr1: a, Expr2: product}
}

// mix interleaves several statement lists round-robin: first the first
// statement of every list (in list order), then the second statement of
// every list that has one, and so on until all statements are placed.
// A single list is returned as is, without copying.
func mix(a []cgen.Stmts) cgen.Stmts {
	if len(a) == 1 {
		return a[0]
	}
	total := 0
	for _, group := range a {
		total += len(group)
	}
	out := make(cgen.Stmts, 0, total)
	for row := 0; len(out) < total; row++ {
		for _, group := range a {
			if row < len(group) {
				out = append(out, group[row])
			}
		}
	}
	return out
}

// Ctx carries the state shared by the strider generators: the prefix
// for generated function names, the target platform, two cache sizes
// taken from the plan (L1 data cache per thread, and L2 cache per thread
// excluding L1), the name source, the threader, activation and batch
// normalization contexts, and a dedup table. The dedup table stores
// both generated function names (strings) and cached *segments values,
// keyed by signature strings.
type Ctx struct {
prefix string
platform raw.Platform
cacheBytes1 int
cacheBytes2 int
nms nmsrc.Src
tc *threader.Ctx
ac *act.Ctx
bc *bn.Ctx
dedup map[string]interface{}
}

// NewCtx returns a fresh strider context. The name prefix is the plan's
// configured prefix followed by "Strider"; the platform and the two
// cache sizes are copied from the plan's configuration, and the dedup
// table starts out empty.
func NewCtx(pl *plan.Plan, nms nmsrc.Src, tc *threader.Ctx, ac *act.Ctx, bc *bn.Ctx) *Ctx {
	c := new(Ctx)
	c.prefix = pl.Config.Prefix + "Strider"
	c.platform = pl.Config.Platform
	c.cacheBytes1 = pl.Config.L1DataCachePerThread
	c.cacheBytes2 = pl.Config.L2CachePerThreadExL1
	c.nms = nms
	c.tc = tc
	c.ac = ac
	c.bc = bc
	c.dedup = map[string]interface{}{}
	return c
}

// name forwards s to the context's name source and returns whatever
// identifier that source produces.
func (c *Ctx) name(s string) string {
	ident := c.nms.Name(s)
	return ident
}

// Spec describes one convolution for which strider code is generated:
// the input (From), one or more filter sets (Filts), the output (To),
// and the filter height/width, padding, dilation and group count.
// newSegments treats PaddingH/PaddingW as applied on both sides of the
// input and uses 1 + (Filter-1)*Dilation as the filter extent.
type Spec struct {
From SpecFrom
Filts []SpecFilts
To SpecTo
FilterH int
FilterW int
PaddingH int
PaddingW int
DilationH int
DilationW int
Groups int
}

// SpecFrom describes the input side of a Spec: channel count, height
// and width, plus per-tensor byte pitches and a list of mod.Op values
// (presumably operations applied to the input data -- confirm against
// the mod package).
type SpecFrom struct {
Chans int
Height int
Width int
Pitch1Bytes []int
Pitch2Bytes []int
Ops []mod.Op
}

// SpecFilts describes one filter set of a Spec: Cnt filters, with BnPre
// and BnPost batch normalization tensors. kernel1 lays out the tensor
// array as 2 + BnPre + BnPost entries per filter set (weights, biases,
// then the batch normalization tensors).
type SpecFilts struct {
Cnt int
BnPre int
BnPost int
}

// SpecTo describes the output side of a Spec: per-tensor byte pitches
// and a list of mod.Op values (presumably operations applied to the
// output data -- confirm against the mod package).
type SpecTo struct {
Pitch1Bytes []int
Pitch2Bytes []int
Ops []mod.Op
}

// form describes the contents of one 16x16 input block as computed in
// newSegments: padH/padW count the leading rows/columns that fall in
// the padding, datH/datW count the following rows/columns that hold
// input data, and yieldH/yieldW are the block's yield in each dimension
// (at most 1 + (16-filterExtent)/2). A block without any data has
// padH = padW = 16 and datH = datW = 0.
type form struct {
padH int
padW int
datH int
datW int
yieldH int
yieldW int
}

// loopB is a run of consecutive blocks [blkFirst, blkPast) within a
// segment that share one row and one form. fromH/fromW locate the first
// block (relative to the enclosing loopW once stored there) and
// fromStep is the horizontal distance between consecutive blocks of the
// run.
type loopB struct {
fromH int
fromW int
fromStep int
blkFirst int
blkPast int
form
}

// loopW is a run of consecutive segments [segFirst, segPast) on one row
// whose block runs (lbs) are identical. fromH/fromW locate the first
// segment and fromStep is the horizontal distance between consecutive
// segments of the run.
type loopW struct {
fromH int
fromW int
fromStep int
segFirst int
segPast int
lbs []*loopB
}

// loopH is a sequence of loopW entries (lws) that is either executed
// once (fromStep and segStep are 0) or repeated cyclically, advancing by
// fromStep rows and segStep segments per repetition until segPast.
// Inside a loopH, the fromH, segFirst and segPast of each loopW are
// stored relative to the loopH's own fromH and segFirst.
type loopH struct {
fromH int
fromStep int
segFirst int
segStep int
segPast int
lws []*loopW
}

// segments is the result of newSegments: the total number of segments
// (cnt) and the loop nest (lhs) that enumerates them.
type segments struct {
cnt int
lhs []*loopH
}

// newSegments partitions the padded input of spec into 16x16 blocks,
// groups every segBlks consecutive blocks into a segment, and compresses
// the resulting sequence into nested loops (loopH > loopW > loopB).
// Results are cached in ctx.dedup under a key built from the input size,
// filter size, padding, dilation and segBlks, so equal configurations
// share one *segments value.
//
// The work is organized as a pipeline of closures that share the
// variables declared below; each layer pushes completed items to the
// next one. A flush argument of true marks the end of the input and
// forces the pending item through.
//
// It panics with "bug" if the dilated filter extent exceeds 16 in either
// dimension.
func newSegments(ctx *Ctx, spec *Spec, segBlks int) *segments {
var (
segs segments
lb1 loopB
lb2 loopB
lw1 loopW
lw2 loopW
lh1 loopH
idx map[int]int
tie int
at int
)
// layer7 emits the loopW entries collected in lh1 as one or two loopH
// values: the entries before tie form a non-repeating prefix, and the
// entries from tie on form the repeating cycle (with the step and end
// recorded in lh1). tie == -1 means no cycle was found.
layer7 := func() {
lh := func(lws []*loopW) *loopH {
var (
fromH = lws[0].fromH
segFirst = lws[0].segFirst
segPast = lws[len(lws)-1].segPast
)
// Rebase each loopW onto the first one.
for _, lw := range lws {
lw.fromH -= fromH
lw.segFirst -= segFirst
lw.segPast -= segFirst
}
return &loopH{
fromH: fromH,
fromStep: 0,
segFirst: segFirst,
segStep: 0,
segPast: segPast,
lws: lws,
}
}
var (
i = tie
n = len(lh1.lws)
)
if i == -1 {
i = n
}
if i > 0 {
pre := lh(lh1.lws[:i])
segs.lhs = append(
segs.lhs, pre,
)
}
if i < n {
cyc := lh(lh1.lws[i:])
cyc.fromStep = lh1.fromStep
cyc.segStep = lh1.segStep
cyc.segPast = lh1.segPast
segs.lhs = append(
segs.lhs, cyc,
)
}
}
// layer6 receives the completed loopW in lw2 and tries to detect a
// vertical repetition. idx maps a loopW's fromW to its position in
// lh1.lws; tie is the position where the cycle starts (-1 while none
// has been found) and at is the position the next loopW must match.
layer6 := func(flush bool) {
// match reports whether lw2 has the same segment count and the same
// block runs as lh1.lws[i].
match := func(i int) bool {
var (
lw = lh1.lws[i]
n1 = lw.segPast - lw.segFirst
n2 = lw2.segPast - lw2.segFirst
)
if n1 != n2 {
return false
}
if len(lw.lbs) != len(lw2.lbs) {
return false
}
for i, lb := range lw.lbs {
if *lb != *lw2.lbs[i] {
return false
}
}
return true
}
// cut: start a new loopH; pre: append to the current one;
// cyc: lw2 continues the detected cycle.
var (
cut = false
pre = false
cyc = false
)
switch {
case lh1.lws == nil:
cut = true
case tie == -1:
i, ok := idx[lw2.fromW]
if ok && match(i) {
lw := lh1.lws[i]
lh1.fromStep = lw2.fromH - lw.fromH
lh1.segStep = lw2.segFirst - lw.segFirst
tie = i
at = i
cyc = true
} else {
pre = true
}
case match(at):
cyc = true
default:
cut = true
}
switch {
case cut:
if lh1.lws != nil {
layer7()
lh1.lws = nil
}
idx = make(map[int]int)
tie = -1
fallthrough
case pre:
// Store a deep copy: lw2.lbs is reused by the earlier layers.
lw := lw2
lw.lbs = make([]*loopB, len(lw2.lbs))
for i, lb := range lw2.lbs {
lb := *lb
lw.lbs[i] = &lb
}
idx[lw.fromW] = len(lh1.lws)
lh1.lws = append(
lh1.lws, &lw,
)
case cyc:
lh1.segPast = lw2.segPast
if at++; at == len(lh1.lws) {
at = tie
}
}
if flush {
layer7()
}
}
// layer5 receives the completed segment in lw1 and merges it into the
// run lw2 when both are on the same row and have identical block
// runs; otherwise lw2 is passed on and lw1 becomes the new run.
layer5 := func(flush bool) {
split := true
switch {
case lw2.fromH != lw1.fromH:
case len(lw2.lbs) != len(lw1.lbs):
default:
split = false
for i, lb := range lw2.lbs {
if *lb != *lw1.lbs[i] {
split = true
break
}
}
}
switch {
case split:
if lw2.segFirst < lw2.segPast {
layer6(false)
}
// Exchange the lbs slices so that lw1 keeps a buffer it can
// overwrite while lw2 owns the one just filled.
swap := lw2.lbs
lw2 = lw1
lw1.lbs = swap
default:
if lw2.fromStep == 0 {
lw2.fromStep = lw1.fromW - lw2.fromW
}
lw2.segPast = lw1.segPast
}
if flush {
layer6(true)
}
}
// layer4 receives the completed block run in lb2 and appends it to
// the current segment lw1. A run starting at block 0 opens a new
// segment (after passing the previous one on) and increments
// segs.cnt.
layer4 := func(flush bool) {
n := len(lw1.lbs)
if lb2.blkFirst == 0 {
if n > 0 {
layer5(false)
}
lw1.fromH = lb2.fromH
lw1.fromW = lb2.fromW
lw1.segFirst = segs.cnt
lw1.segPast = segs.cnt + 1
segs.cnt++
if lw1.lbs == nil {
lw1.lbs = make([]*loopB, segBlks)
for i := range lw1.lbs {
lw1.lbs[i] = new(loopB)
}
}
n = 0
}
lw1.lbs = lw1.lbs[:n+1]
lb := lw1.lbs[n]
*lb = lb2
// Block positions are stored relative to the segment origin.
lb.fromH -= lw1.fromH
lb.fromW -= lw1.fromW
if flush {
layer5(true)
}
}
// layer3 receives one block in lb1 and merges it into the run lb2
// when it does not start a segment, lies on the same row and has the
// same form; otherwise lb2 is passed on and lb1 becomes the new run.
layer3 := func(flush bool) {
if flush {
layer4(true)
return
}
switch {
case lb1.blkFirst == 0:
case lb2.fromH != lb1.fromH:
case lb2.form != lb1.form:
default:
if lb2.fromStep == 0 {
lb2.fromStep = lb1.fromW - lb2.fromW
}
lb2.blkPast = lb1.blkPast
return
}
if lb2.blkFirst < lb2.blkPast {
layer4(false)
}
lb2 = lb1
}
// layer2 enumerates the blocks in row-major order over the padded
// input. h1/w1 mark where the data starts, h2/w2 where it ends and
// h3/w3 where the trailing padding ends. Consecutive blocks are
// yieldH*2 rows / yieldW*2 columns apart.
layer2 := func() {
var (
h1 = spec.PaddingH
h2 = h1 + spec.From.Height
h3 = h2 + spec.PaddingH
w1 = spec.PaddingW
w2 = w1 + spec.From.Width
w3 = w2 + spec.PaddingW
filtH = 1 + (spec.FilterH-1)*spec.DilationH
filtW = 1 + (spec.FilterW-1)*spec.DilationW
yieldH = 1 + (16-filtH)/2
yieldW = 1 + (16-filtW)/2
blk = 0
)
if filtH > 16 || filtW > 16 {
panic("bug")
}
for h := 0; h+filtH <= h3; h += yieldH * 2 {
for w := 0; w+filtW <= w3; w += yieldW * 2 {
lb1.fromH = h
lb1.fromW = w
lb1.blkFirst = blk
lb1.blkPast = blk + 1
// blk counts blocks within the current segment.
if blk++; blk == segBlks {
blk = 0
}
lb1.padH = min(max(h1-h, 0), 16)
lb1.padW = min(max(w1-w, 0), 16)
lb1.datH = min(max(h2-h, 0), 16) - lb1.padH
lb1.datW = min(max(w2-w, 0), 16) - lb1.padW
// Normalize blocks that contain no data at all.
if lb1.datH == 0 || lb1.datW == 0 {
lb1.padH = 16
lb1.padW = 16
lb1.datH = 0
lb1.datW = 0
}
lb1.yieldH = min(1+(h3-h-filtH)/2, yieldH)
lb1.yieldW = min(1+(w3-w-filtW)/2, yieldW)
layer3(false)
}
}
layer3(true)
}
// layer1 consults the dedup cache and, on a miss, registers segs
// before filling it in.
layer1 := func() *segments {
sig := fmt.Sprint(
"newSegments",
" ",
spec.From.Height,
spec.From.Width,
spec.FilterH,
spec.FilterW,
spec.PaddingH,
spec.PaddingW,
spec.DilationH,
spec.DilationW,
segBlks,
)
if prior, ok := ctx.dedup[sig]; ok {
return prior.(*segments)
}
ctx.dedup[sig] = &segs
layer2()
return &segs
}
return layer1()
}

// layout holds the sizes and counts that describe how the arranged
// data is laid out in memory. All fields are computed by newLayout.
//
// Field name prefixes, as used in newLayout: bf* sizes feed
// bfTotalBytes, wf* sizes feed wfTotalBytes (ArrangeFilts.Bytes returns
// bfTotalBytes + wfTotalBytes), df* sizes feed dfTotalBytes, and sf*
// sizes feed sfTotalBytes. The first group is derived from the bias
// element size and the second from the filter counts; df* is derived
// from the block segments and sf* from products of wf and df fragment
// counts.
//
// For fields that come in pairs, suffix 1 refers to the regular (full)
// case and suffix 2 to the final or remainder case: for example slices1
// is the regular number of slices per epoch and slices2 the number in
// the last epoch, while epochs1 counts regular epochs and epochs2 all
// epochs. Two-digit suffixes combine two such choices.
type layout struct {
segs *segments
blkZones int
zoneFrags int
fromChans int
toChans int
slices1 int
slices2 int
epochs1 int
epochs2 int
alignment int
biasBytes int
bfFragBytes int
bfMeldBytes int
bfGroupBytes int
bfEpochBytes int
bfTotalBytes int
wtBytes int
wfFragBytes int
wfMeldFrags int
wfMeldBytes int
wfSliceFrags1 int
wfSliceFrags2 int
wfSliceMelds1 int
wfSliceMelds2 int
wfSliceBytes1 int
wfSliceBytes2 int
wfCores1 int
wfCores2 int
wfCoreBytes11 int
wfCoreBytes12 int
wfCoreBytes21 int
wfCoreBytes22 int
wfPileBytes1 int
wfPileBytes2 int
wfGroupBytes1 int
wfGroupBytes2 int
wfZoneBytes1 int
wfZoneBytes2 int
wfEpochBytes1 int
wfEpochBytes2 int
wfTotalBytes int
datBytes int
dfFragBytes int
dfMeldFrags int
dfMeldBytes int
dfSliceFrags1 int
dfSliceFrags2 int
dfSliceMelds1 int
dfSliceMelds2 int
dfSliceBytes1 int
dfSliceBytes2 int
dfCores1 int
dfCores2 int
dfCoreBytes11 int
dfCoreBytes12 int
dfCoreBytes21 int
dfCoreBytes22 int
dfPileBytes1 int
dfPileBytes2 int
dfGroupBytes1 int
dfGroupBytes2 int
dfZoneBytes1 int
dfZoneBytes2 int
dfEpochBytes1 int
dfEpochBytes2 int
dfTotalBytes int
sfFragBytes int
sfMeldBytes11 int
sfMeldBytes12 int
sfMeldBytes21 int
sfMeldBytes22 int
sfRowBytes11 int
sfRowBytes12 int
sfRowBytes21 int
sfRowBytes22 int
sfSiteBytes11 int
sfSiteBytes12 int
sfSiteBytes21 int
sfSiteBytes22 int
sfCoreBytes1 int
sfCoreBytes2 int
sfPileBytes int
sfGroupBytes int
sfTotalBytes int
}

// newLayout computes the memory layout for spec on ctx's platform. The
// computation is split into closures that run in order layer1 through
// layer9, each one calling the next after filling in its fields; the
// closures are declared in reverse order so that each can refer to its
// successor.
//
// It panics with "bug" for any platform other than raw.AVX512Float32 and
// when spec has more than one filter set together with more than one
// group.
func newLayout(ctx *Ctx, spec *Spec) *layout {
var (
y layout
)
// layer9: df sizes from slice level up to the total.
layer9 := func() {
y.dfCoreBytes11 = y.slices1 * y.dfSliceBytes1
y.dfCoreBytes12 = y.slices1 * y.dfSliceBytes2
y.dfCoreBytes21 = y.slices2 * y.dfSliceBytes1
y.dfCoreBytes22 = y.slices2 * y.dfSliceBytes2
y.dfPileBytes1 = y.dfCores1*y.dfCoreBytes11 + y.dfCoreBytes12
y.dfPileBytes2 = y.dfCores1*y.dfCoreBytes21 + y.dfCoreBytes22
y.dfGroupBytes1 = y.zoneFrags * y.dfPileBytes1
y.dfGroupBytes2 = y.zoneFrags * y.dfPileBytes2
y.dfZoneBytes1 = spec.Groups * y.dfGroupBytes1
y.dfZoneBytes2 = spec.Groups * y.dfGroupBytes2
y.dfEpochBytes1 = y.blkZones * y.dfZoneBytes1
y.dfEpochBytes2 = y.blkZones * y.dfZoneBytes2
y.dfTotalBytes = y.epochs1*y.dfEpochBytes1 + y.dfEpochBytes2
}
// layer8: wf sizes from slice level up to the total.
layer8 := func() {
y.wfCoreBytes11 = y.slices1 * y.wfSliceBytes1
y.wfCoreBytes12 = y.slices1 * y.wfSliceBytes2
y.wfCoreBytes21 = y.slices2 * y.wfSliceBytes1
y.wfCoreBytes22 = y.slices2 * y.wfSliceBytes2
y.wfPileBytes1 = y.wfCores1*y.wfCoreBytes11 + y.wfCoreBytes12
y.wfPileBytes2 = y.wfCores1*y.wfCoreBytes21 + y.wfCoreBytes22
y.wfGroupBytes1 = y.zoneFrags * y.wfPileBytes1
y.wfGroupBytes2 = y.zoneFrags * y.wfPileBytes2
y.wfZoneBytes1 = spec.Groups * y.wfGroupBytes1
y.wfZoneBytes2 = spec.Groups * y.wfGroupBytes2
y.wfEpochBytes1 = y.blkZones * y.wfZoneBytes1
y.wfEpochBytes2 = y.blkZones * y.wfZoneBytes2
y.wfTotalBytes = y.epochs1*y.wfEpochBytes1 + y.wfEpochBytes2
layer9()
}
// layer7: bf sizes; the total is rounded up to a multiple of the
// alignment (which must be a power of two for the mask to work).
layer7 := func() {
y.bfMeldBytes = y.wfMeldFrags * y.bfFragBytes
y.bfGroupBytes = ceilQuo(y.toChans, y.wfMeldFrags) * y.bfMeldBytes
y.bfEpochBytes = spec.Groups * y.bfGroupBytes
y.bfTotalBytes = y.epochs2 * y.bfEpochBytes
y.bfTotalBytes += y.alignment - 1
y.bfTotalBytes &= -y.alignment
layer8()
}
// layer6: choose how many input channels (slices) go into one epoch,
// based on the cache sizes and the per-slice footprint.
layer6 := func() {
// Use the regular slice size unless there are no regular cores.
wfSliceBytes := y.wfSliceBytes1
if y.wfCores1 == 0 {
wfSliceBytes = y.wfSliceBytes2
}
dfSliceBytes := y.dfSliceBytes1
if y.dfCores1 == 0 {
dfSliceBytes = y.dfSliceBytes2
}
switch ctx.platform {
case raw.AVX512Float32:
var (
sliceBytes = 2*wfSliceBytes + dfSliceBytes
cacheBytes = ctx.cacheBytes1 + ctx.cacheBytes2
)
const (
empirical1 = 4
empirical2 = 256
empirical3 = 4
)
y.slices1 = cacheBytes / empirical1 / sliceBytes
y.slices1 = max(y.slices1, empirical2)
y.slices2 = y.fromChans % y.slices1
y.epochs1 = y.fromChans / y.slices1
y.epochs2 = y.epochs1 + btoi(y.slices2 > 0)
// A small remainder epoch (less than 1/empirical3 of a regular
// one) is merged into the preceding regular epoch.
if y.epochs1 > 0 && y.epochs1 < y.epochs2 {
if y.slices2*empirical3 < y.slices1 {
y.slices2 += y.slices1
y.epochs1--
y.epochs2--
}
}
default:
panic("bug")
}
layer7()
}
// layer5: sf sizes, built from products of wf and df fragment
// counts. Div is the full meld size; Quo and Rem split the remainder
// slice into full melds and a partial meld.
layer5 := func() {
var (
wfDiv = y.wfMeldFrags
wfQuo = y.wfSliceFrags2 / wfDiv
wfRem = y.wfSliceFrags2 % wfDiv
dfDiv = y.dfMeldFrags
dfQuo = y.dfSliceFrags2 / dfDiv
dfRem = y.dfSliceFrags2 % dfDiv
)
y.sfMeldBytes11 = wfDiv * dfDiv * y.sfFragBytes
y.sfMeldBytes12 = wfDiv * dfRem * y.sfFragBytes
y.sfMeldBytes21 = wfRem * dfDiv * y.sfFragBytes
y.sfMeldBytes22 = wfRem * dfRem * y.sfFragBytes
y.sfRowBytes11 = y.dfSliceMelds1 * y.sfMeldBytes11
y.sfRowBytes12 = dfQuo*y.sfMeldBytes11 + y.sfMeldBytes12
y.sfRowBytes21 = y.dfSliceMelds1 * y.sfMeldBytes21
y.sfRowBytes22 = dfQuo*y.sfMeldBytes21 + y.sfMeldBytes22
y.sfSiteBytes11 = y.wfSliceMelds1 * y.sfRowBytes11
y.sfSiteBytes12 = y.wfSliceMelds1 * y.sfRowBytes12
y.sfSiteBytes21 = wfQuo*y.sfRowBytes11 + y.sfRowBytes21
y.sfSiteBytes22 = wfQuo*y.sfRowBytes12 + y.sfRowBytes22
y.sfCoreBytes1 = y.wfCores1*y.sfSiteBytes11 + y.sfSiteBytes21
y.sfCoreBytes2 = y.wfCores1*y.sfSiteBytes12 + y.sfSiteBytes22
y.sfPileBytes = y.dfCores1*y.sfCoreBytes1 + y.sfCoreBytes2
y.sfGroupBytes = y.zoneFrags * y.sfPileBytes
y.sfTotalBytes = spec.Groups * y.sfGroupBytes
layer6()
}
// layer4: df slice geometry. The input is split into segments of
// dfSliceFrags1 blocks; the block count of the very last segment
// gives the remainder slice (0 when the last segment is full).
layer4 := func() {
y.dfMeldBytes = y.dfMeldFrags * y.dfFragBytes
y.dfSliceFrags1 = y.dfSliceMelds1 * y.dfMeldFrags
y.segs = newSegments(ctx, spec, y.dfSliceFrags1)
var (
lh = y.segs.lhs[len(y.segs.lhs)-1]
lw = lh.lws[len(lh.lws)-1]
lb = lw.lbs[len(lw.lbs)-1]
)
y.dfSliceFrags2 = lb.blkPast
if y.dfSliceFrags2 == y.dfSliceFrags1 {
y.dfSliceFrags2 = 0
}
y.dfSliceMelds2 = ceilQuo(y.dfSliceFrags2, y.dfMeldFrags)
y.dfSliceBytes1 = y.dfSliceMelds1 * y.dfMeldBytes
y.dfSliceBytes2 = y.dfSliceMelds2 * y.dfMeldBytes
y.dfCores1 = y.segs.cnt - btoi(y.dfSliceFrags2 > 0)
y.dfCores2 = y.segs.cnt
layer5()
}
// layer3: wf slice geometry. The output channels of a group are
// split into slices of wfSliceFrags1 filters plus a remainder.
layer3 := func() {
y.wfMeldBytes = y.wfMeldFrags * y.wfFragBytes
y.wfSliceFrags1 = y.wfSliceMelds1 * y.wfMeldFrags
y.wfSliceFrags2 = y.toChans % y.wfSliceFrags1
y.wfSliceMelds2 = ceilQuo(y.wfSliceFrags2, y.wfMeldFrags)
y.wfSliceBytes1 = y.wfSliceMelds1 * y.wfMeldBytes
y.wfSliceBytes2 = y.wfSliceMelds2 * y.wfMeldBytes
y.wfCores1 = y.toChans / y.wfSliceFrags1
y.wfCores2 = y.wfCores1 + btoi(y.wfSliceFrags2 > 0)
layer4()
}
// layer2: per-group input and output channel counts.
layer2 := func() {
if len(spec.Filts) > 1 && spec.Groups > 1 {
panic("bug")
}
filts := 0
for i := range spec.Filts {
filts += spec.Filts[i].Cnt
}
y.fromChans = spec.From.Chans / spec.Groups
y.toChans = filts / spec.Groups
layer3()
}
// layer1: platform constants.
layer1 := func() *layout {
switch ctx.platform {
case raw.AVX512Float32:
y.blkZones = 4
y.zoneFrags = 4
y.alignment = 64
y.biasBytes = 4
y.bfFragBytes = 4
y.wtBytes = 4
y.wfFragBytes = 32
y.wfMeldFrags = 2
y.wfSliceMelds1 = 2
y.datBytes = 4
y.dfFragBytes = 64
y.dfMeldFrags = 2
y.dfSliceMelds1 = 3
y.sfFragBytes = 64
default:
panic("bug")
}
layer2()
return &y
}
return layer1()
}

// ArrangeFilts describes one call site of a generated filter
// arrangement function. Team and Tensors are the expressions passed at
// the call site. The embedded layout and callerName are filled in by
// Prep and consumed by Bytes and Append.
type ArrangeFilts struct {
*Ctx
*Spec
Team cgen.Gen
Tensors []cgen.Gen
*layout
callerName string
}

// Prep computes the layout for the spec and resolves the name of the
// function that this call will invoke. Calls whose specs print
// identically share one function: the first such call allocates a name,
// records it in the dedup table and returns the function definition
// followed by a newline; later calls reuse the recorded name and return
// nil.
func (a *ArrangeFilts) Prep() cgen.Gen {
	a.layout = newLayout(a.Ctx, a.Spec)
	const affix = "ArrangeFilts"
	key := fmt.Sprint(affix, " ", a.Spec)
	if cached, hit := a.dedup[key]; hit {
		a.callerName = cached.(string)
		return nil
	}
	caller := a.name(a.prefix + affix)
	a.callerName = caller
	a.dedup[key] = caller
	def := &arrangeFilts{ArrangeFilts: a}
	return cgen.Gens{def, cgen.Newline}
}

// Bytes returns the sum of the layout's bfTotalBytes and wfTotalBytes.
// It reads the layout, so Prep must have run first.
func (a *ArrangeFilts) Bytes() int {
	total := a.bfTotalBytes
	total += a.wfTotalBytes
	return total
}

// Append emits the call site: a char-pointer array variable initialized
// with the tensor expressions, followed by a call to the function chosen
// by Prep with the team and that array as arguments.
func (a *ArrangeFilts) Append(to []byte) []byte {
	arr := vb(a.name("tensors"))
	decl := cgen.Var{
		Type: cgen.PtrChar,
		What: cgen.Elem{Arr: arr},
		Init: cgen.Brace{Inner: cgen.CommaLines(a.Tensors)},
	}
	invoke := cgen.Call{
		Func: vb(a.callerName),
		Args: cgen.CommaSpaced{a.Team, arr},
	}
	return cgen.Stmts{decl, invoke}.Append(to)
}

// arrangeFilts generates the definition of one filter arrangement
// function together with its threader callee. Besides the embedded call
// description, it holds working state that is set step by step while the
// code is generated:
//   - bundle*/group* tiling parameters, computed in Append,
//   - the callee name and the task coordinate variables, set in Append
//     and calleeFunc,
//   - slices, coreBytes, pileBytes, groupBytes, zoneBytes, epochFirst
//     and epochCnt, which calleeFunc sets once for the regular epochs
//     and once for the final epoch,
//   - pointer variables and indices used by the kernel methods.
type arrangeFilts struct {
*ArrangeFilts
bundleFilts int
bundleTile int
bundleTiles int
bundleScrap int
bundleHull int
groupTile int
groupTiles int
groupScrap int
groupHull int
calleeName string
tensors cgen.Gen
bundleCoord cgen.Gen
groupCoord cgen.Gen
epochCoord cgen.Gen
slices int
coreBytes int
pileBytes int
groupBytes int
zoneBytes int
epochFirst int
epochCnt int
bfPtr cgen.Gen
wfPtr cgen.Gen
filtsIdx int
wtPtr cgen.Gen
biasPtr cgen.Gen
bnPtrs []cgen.Gen
groupIdx cgen.Gen
bundleIdx cgen.Gen
bundleLast cgen.Gen
baseFilt int
baseBundle int
filts1 int
filts2 int
repeat bool
}

// Append emits the callee function followed by the caller function,
// which hands the callee to the threader with a three-dimensional hull:
// bundle tiles, group tiles and epochs.
//
// A bundle is a set of bundleFilts filters. The work is tiled so that
// each task covers roughly threadBlks units of work:
//   - if a single bundle is already large enough, every bundle of every
//     group is its own task;
//   - else if a single group is large enough, the bundles of a group are
//     split into tiles, and leftover bundles are merged into the last
//     tile (bundleScrap);
//   - otherwise all bundles of a group form one tile and whole groups
//     are tiled the same way (groupScrap).
//
// It panics with "bug" for any platform other than raw.AVX512Float32.
func (a *arrangeFilts) Append(to []byte) []byte {
var (
threadBlks int
groupBundles int
)
switch a.platform {
case raw.AVX512Float32:
a.bundleFilts = a.wfMeldFrags
threadBlks = 128
default:
panic("bug")
}
// With several filter sets, bundles never span two sets, so each set
// is rounded up separately.
switch len(a.Filts) {
case 1:
groupBundles = ceilQuo(a.toChans, a.bundleFilts)
default:
for i := range a.Filts {
filts := a.Filts[i].Cnt
groupBundles += ceilQuo(filts, a.bundleFilts)
}
}
var (
filtBlks = ceilQuo(a.fromChans, a.epochs2)
bundleBlks = a.bundleFilts * filtBlks
groupBlks = a.toChans * filtBlks
)
switch {
case threadBlks <= bundleBlks:
a.bundleTile = 1
a.bundleTiles = groupBundles
a.bundleScrap = 0
a.bundleHull = groupBundles
a.groupTile = 1
a.groupTiles = a.Groups
a.groupScrap = 0
a.groupHull = a.Groups
case threadBlks <= groupBlks:
var (
tile = ceilQuo(threadBlks, bundleBlks)
tiles = max(groupBundles/tile, 1)
)
a.bundleTile = groupBundles / tiles
a.bundleTiles = tiles
a.bundleScrap = groupBundles - tiles*a.bundleTile
a.bundleHull = tiles
// Leftover bundles enlarge the last tile instead of forming an
// extra, smaller one.
if a.bundleScrap > 0 {
a.bundleTiles--
a.bundleScrap += a.bundleTile
}
a.groupTile = 1
a.groupTiles = a.Groups
a.groupScrap = 0
a.groupHull = a.Groups
default:
a.bundleTile = groupBundles
a.bundleTiles = 1
a.bundleScrap = 0
a.bundleHull = 1
var (
tile = ceilQuo(threadBlks, groupBlks)
tiles = max(a.Groups/tile, 1)
)
a.groupTile = a.Groups / tiles
a.groupTiles = tiles
a.groupScrap = a.Groups - tiles*a.groupTile
a.groupHull = tiles
// Leftover groups enlarge the last tile instead of forming an
// extra, smaller one.
if a.groupScrap > 0 {
a.groupTiles--
a.groupScrap += a.groupTile
}
}
a.calleeName = a.name(a.callerName + "Callee")
var (
team = vb(a.name("team"))
tensors = vb(a.name("tensors"))
)
return cgen.Gens{
a.calleeFunc(),
cgen.Newline,
cgen.StaticFuncDef{
ReturnType: cgen.Void,
Name: a.callerName,
Params: cgen.CommaSpaced{
cgen.Param{
Type: a.tc.PtrTeam,
What: team,
},
cgen.Param{
Type: cgen.PtrPtrChar,
What: tensors,
},
},
Body: &threader.Do{
Ctx: a.tc,
Callee: vb(a.calleeName),
Any: tensors,
Hull: []cgen.Gen{
il(a.bundleHull),
il(a.groupHull),
il(a.epochs2),
},
Team: team,
},
},
}.Append(to)
}

// calleeFunc builds the threader callee. The body has seven slots:
//   - 0: the tensor array, read from the task's "any" pointer,
//   - 1-3: the bundle, group and epoch coordinates,
//   - 4: a cast of the point array to void, emitted only when no
//     coordinate reads it,
//   - 5: the code for the regular epochs (if there are any),
//   - 6: the code for the final epoch (if it differs from the regular
//     ones).
//
// The kernel is generated up to twice, with the receiver's slices,
// *Bytes, epochFirst and epochCnt fields set to the regular-epoch values
// first and to the final-epoch values second.
func (a *arrangeFilts) calleeFunc() cgen.Gen {
callee := &threader.Callee{
Ctx: a.tc,
Name: a.calleeName,
Task: vb(a.name("task")),
Pt: vb(a.name("pt")),
}
var (
body = make(cgen.Stmts, 7)
usedPt = false
)
a.tensors = vb(a.name("tensors"))
body[0] = cgen.Var{
Type: cgen.PtrPtrChar, What: a.tensors,
Init: callee.Any(),
}
// coord declares the coordinate variable for hull dimension i. A
// dimension of extent 1 gets the constant 0 instead of a read from
// the point array.
coord := func(nm string, hull, i int) cgen.Gen {
var (
ret = vb(a.name(nm))
expr cgen.Gen
)
switch hull {
case 1:
expr = il(0)
default:
expr = cgen.Elem{
Arr: callee.Pt, Idx: il(i),
}
usedPt = true
}
body[1+i] = cgen.Var{
Type: cgen.PtrdiffT, What: ret,
Init: expr,
}
return ret
}
a.bundleCoord = coord("b", a.bundleHull, 0)
a.groupCoord = coord("g", a.groupHull, 1)
a.epochCoord = coord("e", a.epochs2, 2)
if !usedPt {
body[4] = cgen.Cast{
Type: cgen.Void,
Expr: callee.Pt,
}
}
// impl generates the kernel for the currently configured epoch
// range. When that range is a single epoch out of several, the epoch
// coordinate is first assigned its only possible value as a literal.
impl := func() cgen.Gen {
var assn cgen.Gen
if a.epochs2 > 1 && a.epochCnt == 1 {
assn = cgen.Assign{
Expr1: a.epochCoord,
Expr2: il(a.epochFirst),
}
}
return cgen.Stmts{
assn,
a.kernel1(),
}
}
if a.epochs1 > 0 {
a.slices = a.slices1
a.coreBytes = a.wfCoreBytes11
a.pileBytes = a.wfPileBytes1
a.groupBytes = a.wfGroupBytes1
a.zoneBytes = a.wfZoneBytes1
a.epochFirst = 0
a.epochCnt = a.epochs1
put := impl()
// When a final epoch follows, guard the regular code by the epoch
// coordinate and return afterwards so the final-epoch code below is
// skipped.
if a.epochs1 < a.epochs2 {
put = cgen.If{
Cond: cgen.CmpL{
Expr1: a.epochCoord,
Expr2: il(a.epochs1),
},
Then: cgen.Stmts{
put,
cgen.Return{},
},
}
}
body[5] = put
}
if a.epochs1 < a.epochs2 {
a.slices = a.slices2
a.coreBytes = a.wfCoreBytes21
a.pileBytes = a.wfPileBytes2
a.groupBytes = a.wfGroupBytes2
a.zoneBytes = a.wfZoneBytes2
a.epochFirst = a.epochs1
a.epochCnt = 1
body[6] = impl()
}
return callee.Func(body)
}

// kernel1 builds the outer part of the filter-arranging kernel for the
// epoch range currently selected on the receiver. From the outside in it
// declares the bfPtr and wfPtr pointers (layer1), optionally hoists the
// per-Filts pointer declarations when there is only one Filts entry
// (layer2), iterates the groups assigned to this task (layer3), sets up
// the bundle range (layer4), and dispatches on which Filts entry the
// current bundle belongs to before calling kernel2 (layer5).
//
// NOTE(review): addMul is defined elsewhere; it is presumably
// base + scale*index — confirm.
func (a *arrangeFilts) kernel1() cgen.Gen {
var (
n = len(a.Filts)
savedFiltsIdx = 0
savedTensorIdx = 0
)
// tensor returns the expression tensors[first+off], where first is the
// index of the first tensor belonging to Filts entry filtsIdx. The most
// recently computed first index is cached in savedTensorIdx.
tensor := func(filtsIdx, off int) cgen.Gen {
if savedFiltsIdx != filtsIdx {
savedFiltsIdx = filtsIdx
at := 0
for x := 0; x < filtsIdx; x++ {
// Each earlier entry occupies 2 + BnPre + BnPost tensor slots
// (offsets 0 and 1 are used below for weights and biases).
at += 2
at += a.Filts[x].BnPre
at += a.Filts[x].BnPost
}
savedTensorIdx = at
}
return cgen.Elem{
Arr: a.tensors,
Idx: il(savedTensorIdx + off),
}
}
// ptrDecls declares the weight, bias and batch-norm pointers for Filts
// entry filtsIdx and records them (and filtsIdx) on the receiver.
ptrDecls := func(filtsIdx int) cgen.Gen {
// wtDecl declares wtPtr from tensor offset 0, advanced by the epoch
// coordinate in units of slices1*FilterH*FilterW*wtBytes.
wtDecl := func() cgen.Gen {
a.wtPtr = vb(a.name("wtPtr"))
filtHW := a.FilterH * a.FilterW
return cgen.Var{
Type: cgen.RestrictPtrChar,
What: a.wtPtr,
Init: addMul(
tensor(filtsIdx, 0),
il(a.slices1*filtHW*a.wtBytes),
a.epochCoord,
),
}
}
// biasDecl declares biasPtr from tensor offset 1, but only when the
// selected epoch range starts at epoch 0; otherwise biasPtr is nil.
biasDecl := func() cgen.Gen {
if a.epochFirst == 0 {
a.biasPtr = vb(a.name("biasPtr"))
return cgen.Var{
Type: cgen.RestrictPtrChar,
What: a.biasPtr,
Init: tensor(filtsIdx, 1),
}
}
a.biasPtr = nil
return nil
}
// bnDecls declares one pointer per batch-norm tensor (offsets 2 and
// up). The first BnPre of them are additionally offset through
// bn.Offset by channel slices1*epochCoord.
bnDecls := func() cgen.Gen {
var (
pre = a.Filts[filtsIdx].BnPre
post = a.Filts[filtsIdx].BnPost
ret = make(cgen.Stmts, pre+post)
)
a.bnPtrs = make([]cgen.Gen, pre+post)
for x := range a.bnPtrs {
var (
bnPtr = vb(a.name("bnPtr"))
expr = tensor(filtsIdx, 2+x)
)
if x < pre {
expr = &bn.Offset{
Ctx: a.bc,
Mas: expr,
Channel: cgen.Mul{
Expr1: il(a.slices1),
Expr2: a.epochCoord,
},
}
}
ret[x] = cgen.Var{
Type: cgen.RestrictPtrChar,
What: bnPtr, Init: expr,
}
a.bnPtrs[x] = bnPtr
}
return ret
}
a.filtsIdx = filtsIdx
return cgen.Stmts{
wtDecl(),
biasDecl(),
bnDecls(),
}
}
// layer5 selects the Filts entry for the current bundle index. With a
// single entry it calls kernel2 directly; otherwise it emits a binary
// decision tree over the cumulative bundle counts.
layer5 := func() cgen.Gen {
if n == 1 {
a.baseFilt = 0
a.baseBundle = 0
return a.kernel2()
}
var (
atFilt = make([]int, n+1)
atBundle = make([]int, n+1)
)
// Prefix sums: atFilt[x] and atBundle[x] are the first filter and the
// first bundle of entry x.
for x := 0; x < n; x++ {
var (
filts = a.Filts[x].Cnt
bundles = ceilQuo(filts, a.bundleFilts)
)
atFilt[x+1] = atFilt[x] + filts
atBundle[x+1] = atBundle[x] + bundles
}
// leaf emits the code for entry x, then (unless x is the last entry)
// sets the bundle index to the first bundle of the next entry so that
// execution continues into the code emitted after it.
leaf := func(x int) cgen.Stmts {
a.baseFilt = atFilt[x]
a.baseBundle = atBundle[x]
var assn cgen.Gen
if x+1 < n {
assn = cgen.Assign{
Expr1: a.bundleIdx,
Expr2: il(atBundle[x+1]),
}
}
return cgen.Stmts{
ptrDecls(x),
a.kernel2(),
assn,
}
}
var tree func(int, int) cgen.Stmts
// tree emits "if (bundleIdx < atBundle[x]) { first..x-1 }" followed by
// the code for x..last, with x chosen near the midpoint of the bundle
// range covered by first..last.
tree = func(first, last int) cgen.Stmts {
if first == last {
return leaf(first)
}
var (
start = atBundle[first]
stop = atBundle[last+1]
split = start + (stop-start)/2
x = first + 1
)
for atBundle[x+1] <= split {
x++
}
return cgen.Stmts{
cgen.If{
Cond: cgen.CmpL{
Expr1: a.bundleIdx,
Expr2: il(atBundle[x]),
},
Then: tree(first, x-1),
},
tree(x, last),
}
}
return tree(0, n-1)
}
// layer4 declares the bundle index j = bundleTile*bundleCoord and, when
// the bundle hull is not 1, the last bundle index jj for this task.
layer4 := func() cgen.Gen {
a.bundleIdx = vb(a.name("j"))
switch a.bundleHull {
case 1:
a.bundleLast = nil
default:
a.bundleLast = vb(a.name("jj"))
}
stmts := make(cgen.Stmts, 3)
stmts[0] = cgen.Var{
Type: cgen.PtrdiffT,
What: a.bundleIdx,
Init: cgen.Mul{
Expr1: il(a.bundleTile),
Expr2: a.bundleCoord,
},
}
if a.bundleLast != nil {
var expr cgen.Gen
// The span is constant when every task has a full tile or every
// task has the scrap size; otherwise it is chosen at run time by
// comparing the coordinate with bundleTiles.
switch a.bundleTiles {
case a.bundleHull:
expr = il(a.bundleTile - 1)
case 0:
expr = il(a.bundleScrap - 1)
default:
expr = cgen.Paren{
Inner: cgen.Ternary{
Cond: cgen.CmpL{
Expr1: a.bundleCoord,
Expr2: il(a.bundleTiles),
},
Then: il(a.bundleTile - 1),
Else: il(a.bundleScrap - 1),
},
}
}
stmts[1] = cgen.Var{
Type: cgen.PtrdiffT,
What: a.bundleLast,
Init: cgen.Add{
Expr1: a.bundleIdx,
Expr2: expr,
},
}
}
stmts[2] = layer5()
return stmts
}
// layer3 declares the group index i = groupTile*groupCoord and loops it
// up to an inclusive last index ii, unless exactly one iteration is
// known at generation time.
layer3 := func() cgen.Gen {
a.groupIdx = vb(a.name("i"))
var (
stmts = make(cgen.Stmts, 3)
iters = 0
)
stmts[0] = cgen.Var{
Type: cgen.PtrdiffT,
What: a.groupIdx,
Init: cgen.Mul{
Expr1: il(a.groupTile),
Expr2: a.groupCoord,
},
}
// iters stays 0 when the iteration count differs between tasks.
switch a.groupTiles {
case a.groupHull:
iters = a.groupTile
case 0:
iters = a.groupScrap
}
switch iters {
case 1:
stmts[2] = layer4()
default:
var (
last = vb(a.name("ii"))
expr cgen.Gen
)
switch iters {
case 0:
expr = cgen.Paren{
Inner: cgen.Ternary{
Cond: cgen.CmpL{
Expr1: a.groupCoord,
Expr2: il(a.groupTiles),
},
Then: il(a.groupTile - 1),
Else: il(a.groupScrap - 1),
},
}
default:
expr = il(iters - 1)
}
stmts[1] = cgen.Var{
Type: cgen.PtrdiffT,
What: last,
Init: cgen.Add{
Expr1: a.groupIdx,
Expr2: expr,
},
}
stmts[2] = cgen.For{
Cond: cgen.CmpLE{
Expr1: a.groupIdx,
Expr2: last,
},
Post: cgen.IncPre{
Expr: a.groupIdx,
},
Body: layer4(),
}
}
return stmts
}
// layer2 hoists the pointer declarations out of the loops when there is
// a single Filts entry (layer5 then skips ptrDecls).
layer2 := func() cgen.Gen {
var decls cgen.Gen
if n == 1 {
decls = ptrDecls(0)
}
return cgen.Gens{
decls,
layer3(),
}
}
// layer1 declares bfPtr and wfPtr. Both start from the tensor that
// follows all per-Filts tensors (tensor(n, 0)); wfPtr is additionally
// advanced by bfTotalBytes. Each is advanced per epoch via addMul.
layer1 := func() cgen.Gen {
a.bfPtr = vb(a.name("bfPtr"))
a.wfPtr = vb(a.name("wfPtr"))
return cgen.Stmts{
cgen.Var{
Type: cgen.RestrictPtrChar,
What: a.bfPtr,
Init: addMul(
tensor(n, 0),
il(a.bfEpochBytes),
a.epochCoord,
),
},
cgen.Var{
Type: cgen.RestrictPtrChar,
What: a.wfPtr,
Init: addMul(
cgen.Add{
Expr1: tensor(n, 0),
Expr2: il(a.bfTotalBytes),
},
il(a.wfEpochBytes1),
a.epochCoord,
),
},
layer2(),
}
}
return layer1()
}

// kernel2 emits the bundle-level code for the Filts entry selected by
// a.filtsIdx, a.baseFilt and a.baseBundle. layer1 works out how many of
// the entry's filters fall before the split point (filts1) and whether
// the final filter must be repeated; layer2 breaks the bundles into up to
// four runs with different (filts1, filts2, repeat) settings; layer3
// emits the platform-specific body for one bundle.
func (a *arrangeFilts) kernel2() cgen.Gen {
var (
filts1 int
filts2 int
repeat bool
)
// layer3 dispatches on the platform; only AVX512Float32 is handled.
layer3 := func() cgen.Gen {
switch a.platform {
case raw.AVX512Float32:
return a.m512()
default:
panic("bug")
}
}
layer2 := func() cgen.Gen {
var (
retIf cgen.Gen
past = a.baseBundle
)
// retIf returns from the task once the bundle index reaches the
// task's last bundle; it is only emitted when bundleLast exists.
if a.bundleLast != nil {
retIf = cgen.If1{
Cond: cgen.CmpGE{
Expr1: a.bundleIdx,
Expr2: a.bundleLast,
},
Then: cgen.Return{},
}
}
// do emits the code for the next `bundles` bundles and advances past
// (a running bundle index) by that amount. A single bundle becomes an
// equality-guarded block that then sets the index to past; several
// bundles become a guarded loop that runs until the index equals past.
// layer3 reads a.filts1, a.filts2 and a.repeat as set by the caller.
do := func(bundles int) cgen.Gen {
past += bundles
if bundles == 1 {
return cgen.If{
Cond: cgen.CmpE{
Expr1: a.bundleIdx,
Expr2: il(past - 1),
},
Then: cgen.Stmts{
layer3(),
retIf,
cgen.Assign{
Expr1: a.bundleIdx,
Expr2: il(past),
},
},
}
}
return cgen.If{
Cond: cgen.CmpL{
Expr1: a.bundleIdx,
Expr2: il(past),
},
Then: cgen.Stmts{
cgen.For{
Cond: cgen.CmpNE{
Expr1: a.bundleIdx,
Expr2: il(past),
},
Post: cgen.IncPre{
Expr: a.bundleIdx,
},
Body: cgen.Stmts{
layer3(),
retIf,
},
},
},
}
}
var (
stmts = make(cgen.Stmts, 4)
// quo1 whole bundles fit inside the first filts1 filters; rem1
// filters of filts1 are left over; tail filters remain after the
// quo1 whole bundles.
quo1 = filts1 / a.bundleFilts
rem1 = filts1 - a.bundleFilts*quo1
tail = filts2 - a.bundleFilts*quo1
)
// Run 0: whole bundles entirely within filts1.
if quo1 > 0 {
a.filts1 = a.bundleFilts
a.filts2 = a.bundleFilts
a.repeat = false
stmts[0] = do(quo1)
}
// Run 1: one bundle that starts within filts1 and may extend past it.
if rem1 > 0 {
a.filts1 = rem1
a.filts2 = min(tail, a.bundleFilts)
tail -= a.filts2
a.repeat = repeat && tail == 0
stmts[1] = do(1)
}
if tail > 0 {
var (
// When repeat is set, one filter is held back from the whole
// bundles so the final bundle is non-empty and carries the repeat.
head = tail - btoi(repeat)
quo2 = head / a.bundleFilts
rem2 = tail - a.bundleFilts*quo2
)
// Run 2: whole bundles entirely past filts1.
if quo2 > 0 {
a.filts1 = 0
a.filts2 = a.bundleFilts
a.repeat = false
stmts[2] = do(quo2)
}
// Run 3: the final bundle.
if rem2 > 0 {
a.filts1 = 0
a.filts2 = rem2
a.repeat = repeat
stmts[3] = do(1)
}
}
return stmts
}
layer1 := func() cgen.Gen {
// With a single Filts entry the filter count is a.toChans; otherwise
// it is the entry's own Cnt.
switch len(a.Filts) {
case 1:
filts2 = a.toChans
default:
filts2 = a.Filts[a.filtsIdx].Cnt
}
var (
past = a.baseFilt + filts2
split = a.toChans - a.wfSliceFrags2
// clamp2 is the number of this entry's filters at or past split,
// clamped to [0, filts2].
clamp1 = max(past-split, 0)
clamp2 = min(clamp1, filts2)
)
filts1 = filts2 - clamp2
// repeat is set when this entry ends at the last channel and that
// count is not a multiple of wfMeldFrags.
repeat = past == a.toChans &&
past%a.wfMeldFrags > 0
return layer2()
}
return layer1()
}

// m512 emits the AVX-512 body for one bundle of filters. The nested
// closures are numbered from the outside in; each sets captured state and
// then calls the next:
//
//	layer1      declare the bf accumulators (only with pre batch-norm)
//	layer2      store the bf vector for the bundle
//	layer3      layer4 followed by the bias sublayers
//	layer4      load the post batch-norm multipliers per filter
//	layer5      loop over slices
//	layer6      load and fold the pre batch-norm mul/add for the slice
//	layer7      iterate the filters of the bundle
//	layer8      load FilterH rows of weights
//	layer9-11   apply post multiplier, pre batch-norm, width dilation
//	layer12     forward transform (quadfft.Fwd)
//	layer13     compute core/meld/frag indices of the filter
//	layer14-17  permute, convert, pack and store the transform outputs
//
// NOTE(review): addMul, loMask and mix are defined elsewhere; addMul is
// presumably base + scale*index — confirm.
func (a *arrangeFilts) m512() cgen.Gen {
var (
bfs []cgen.Gen
preCnt int
postMuls []cgen.Gen
sliceIdx cgen.Gen
preMul1 cgen.Gen
preAdd1 cgen.Gen
filtIdx int
wts []cgen.Gen
fwd *quadfft.Fwd
coreIdx cgen.Gen
meldIdx cgen.Gen
fragIdx cgen.Gen
eo cgen.Gen
pileIdx int
zoneIdx int
wfs cgen.Gen
)
// layer17 stores the packed vector wfs twice, once per side. Side 1 uses
// the upper mask bits and an address moved back by half a fragment. For
// the last filter of a repeating bundle, a shuffled copy is stored with a
// full mask instead.
layer17 := func() cgen.Gen {
emit := func(side int) cgen.Gen {
var (
stmts = make(cgen.Stmts, 2)
to = a.wfPtr
slicePitch = a.wfSliceBytes1
fragPitch = a.wfFragBytes / 2
back = side * fragPitch
mask = 0x0f0f << uint(side*4)
from = wfs
)
// Filters at or past filts1 use the second slice pitch.
if filtIdx >= a.filts1 {
slicePitch = a.wfSliceBytes2
}
if filtIdx == a.filts2-1 && a.repeat {
back = 0
mask = 0xffff
from = vb(a.name("rep"))
// ctrl is the ShuffleI32x4 immediate: two-bit selectors
// (side, side, side+2, side+2) from low to high.
ctrl := (side+2)<<4 | side
ctrl |= ctrl << 2
stmts[0] = cgen.Var{
Type: avx.M512i, What: from,
Init: avx.Mm512ShuffleI32x4{
wfs, wfs, il(ctrl),
},
}
}
to = cgen.Add{
Expr1: to,
Expr2: il(
(zoneIdx+side)*a.zoneBytes +
pileIdx*a.pileBytes -
back,
),
}
to = addMul(to, il(a.groupBytes), a.groupIdx)
to = addMul(to, il(a.coreBytes), coreIdx)
to = addMul(to, il(slicePitch), sliceIdx)
to = addMul(to, il(a.wfMeldBytes), meldIdx)
to = addMul(to, il(fragPitch), fragIdx)
stmts[1] = avx.Mm512MaskStoreuEpi32{
to, il(mask), from,
}
return stmts
}
return cgen.Gens{
emit(0),
emit(1),
}
}
// layer16 takes two consecutive transform outputs for the current pile
// and zone pair, permutes them by eo (except for pile 0), converts each
// with Mm512CvtpsPh and packs both results into the single vector wfs.
layer16 := func() cgen.Gen {
wfs = vb(a.name("wfs"))
var (
x = pileIdx*2 + zoneIdx/2*8
wf1 = fwd.Out[x]
wf2 = fwd.Out[x+1]
)
perm := func(wf cgen.Gen) cgen.Gen {
if pileIdx == 0 {
return nil
}
return cgen.Assign{
Expr1: wf,
Expr2: avx.Mm512PermutexvarPs{
eo, wf,
},
}
}
cvt := func(wf cgen.Gen) cgen.Gen {
return avx.Mm512CvtpsPh{
wf, avx.FroundToNearestIntNoExc,
}
}
return cgen.Stmts{
perm(wf1),
perm(wf2),
cgen.Var{
Type: avx.M512i, What: wfs,
Init: avx.Mm512Castsi256Si512{
cvt(wf1),
},
},
cgen.Assign{
Expr1: wfs,
Expr2: avx.Mm512Inserti64x4{
wfs, cvt(wf2), il(1),
},
},
layer17(),
}
}
// layer15 visits every (pile, zone pair) combination. Piles are visited
// starting from index 1 and wrapping to 0 last.
layer15 := func() cgen.Gen {
var (
n1 = a.zoneFrags
n2 = a.blkZones / 2
gens = make(cgen.Gens, n1*n2)
)
for p := 0; p < n1; p++ {
pileIdx = (p + 1) % n1
for z := 0; z < n2; z++ {
zoneIdx = z * 2
gens[p*n2+z] = layer16()
}
}
return gens
}
// layer14 declares the permutation index vector eo. Element x (counting
// from the low end) is x%8*2 + x/8, i.e. 0,2,...,14,1,3,...,15.
layer14 := func() cgen.Gen {
eo = vb(a.name("eo"))
set := make(avx.Mm512SetEpi32, 16)
for x := 0; x < 16; x++ {
set[15-x] = il(x%8*2 + x/8)
}
return cgen.Stmts{
cgen.Var{
Type: avx.M512i, What: eo,
Init: set,
},
layer15(),
}
}
// layer13 derives the core, meld and fragment indices of the current
// filter from its position relative to the bundle index:
// c = pos / wfSliceFrags1, m = pos % wfSliceFrags1 / wfMeldFrags,
// f = pos % wfMeldFrags, with pos cast to size_t.
layer13 := func() cgen.Gen {
coreIdx = vb(a.name("c"))
meldIdx = vb(a.name("m"))
fragIdx = vb(a.name("f"))
var (
add = a.baseFilt + filtIdx
sub = a.baseBundle * a.bundleFilts
expr = cgen.Cast{
Type: cgen.SizeT,
Expr: cgen.Paren{
Inner: addMul(
il(add-sub),
il(a.bundleFilts),
a.bundleIdx,
),
},
}
)
return cgen.Stmts{
cgen.Var{
Type: cgen.PtrdiffT,
What: coreIdx,
Init: cgen.Quo{
Expr1: expr,
Expr2: il(a.wfSliceFrags1),
},
},
cgen.Var{
Type: cgen.PtrdiffT,
What: meldIdx,
Init: cgen.Quo{
Expr1: cgen.Rem{
Expr1: expr,
Expr2: il(a.wfSliceFrags1),
},
Expr2: il(a.wfMeldFrags),
},
},
cgen.Var{
Type: cgen.PtrdiffT,
What: fragIdx,
Init: cgen.Rem{
Expr1: expr,
Expr2: il(a.wfMeldFrags),
},
},
layer14(),
}
}
// layer12 runs the forward transform. Weight row x is placed at input
// row x*DilationH (other inputs are left unset); every output gets a
// fresh "wf" name.
layer12 := func() cgen.Gen {
fwd = &quadfft.Fwd{
Platform: a.platform,
Nms: a.nms,
}
for x, wt := range wts {
fwd.In[x*a.DilationH] = wt
}
for x := range &fwd.Out {
wf := vb(a.name("wf"))
fwd.Out[x] = wf
}
return cgen.Gens{
fwd,
layer13(),
}
}
// layer11 applies width dilation by permuting each weight row with the
// index vector dw: lane x takes source lane x/DilationW when x is a
// multiple of DilationW, and source lane 15 (the gap value) otherwise.
layer11 := func() cgen.Gen {
if a.DilationW == 1 {
return layer12()
}
var (
last = 1 + len(wts)
stmts = make(cgen.Stmts, last+1)
dw = vb(a.name("dw"))
set = make(avx.Mm512SetEpi32, 16)
gap = il(15)
)
for x := 0; x < 16; x++ {
put := gap
if x%a.DilationW == 0 {
put = il(x / a.DilationW)
}
set[15-x] = put
}
stmts[0] = cgen.Var{
Type: avx.M512i, What: dw,
Init: set,
}
for x, wt := range wts {
stmts[1+x] = cgen.Assign{
Expr1: wt,
Expr2: avx.Mm512PermutexvarPs{
dw, wt,
},
}
}
stmts[last] = layer12()
return stmts
}
// layer10 applies the folded pre batch-norm to each weight row: the
// filter's bf accumulator gets preAdd1*wt added (using the unscaled
// weight), then the weight is scaled by preMul1.
layer10 := func() cgen.Gen {
if preCnt == 0 {
return layer11()
}
var (
last = len(wts) * 2
stmts = make(cgen.Stmts, last+1)
bf = bfs[filtIdx]
)
for x, wt := range wts {
stmts[x*2] = cgen.Assign{
Expr1: bf,
Expr2: avx.Mm512FmaddPs{
preAdd1, wt, bf,
},
}
stmts[x*2+1] = cgen.Assign{
Expr1: wt,
Expr2: avx.Mm512MulPs{
preMul1, wt,
},
}
}
stmts[last] = layer11()
return stmts
}
// layer9 scales each weight row by the filter's post batch-norm
// multiplier, when there is one.
layer9 := func() cgen.Gen {
if postMuls == nil {
return layer10()
}
var (
last = len(wts)
stmts = make(cgen.Stmts, last+1)
)
for x, wt := range wts {
stmts[x] = cgen.Assign{
Expr1: wt,
Expr2: avx.Mm512MulPs{
postMuls[filtIdx],
wt,
},
}
}
stmts[last] = layer10()
return stmts
}
// layer8 loads the FilterH weight rows of the current filter and slice
// with masked loads (mask from loMask(FilterW)). The constant part of the
// address subtracts baseBundle*bundlePitch because bundleIdx is an
// absolute bundle index while wtPtr belongs to the current Filts entry.
layer8 := func() cgen.Gen {
wts = make([]cgen.Gen, a.FilterH)
var (
last = len(wts)
stmts = make(cgen.Stmts, last+1)
)
for h := range wts {
var (
wt = vb(a.name("wt"))
mask = loMask(a.FilterW)
ae = a.wtPtr
hPitch = a.FilterW * a.wtBytes
slicePitch = a.FilterH * hPitch
filtPitch = a.fromChans * slicePitch
bundlePitch = a.bundleFilts * filtPitch
groupPitch = a.toChans * filtPitch
)
ae = cgen.Add{
Expr1: ae,
Expr2: il(
-a.baseBundle*bundlePitch +
filtIdx*filtPitch +
h*hPitch,
),
}
ae = addMul(ae, il(groupPitch), a.groupIdx)
ae = addMul(ae, il(bundlePitch), a.bundleIdx)
ae = addMul(ae, il(slicePitch), sliceIdx)
wts[h] = wt
stmts[h] = cgen.Var{
Type: avx.M512, What: wt,
Init: avx.Mm512MaskzLoaduPs{
mask, ae,
},
}
}
stmts[last] = layer9()
return stmts
}
// layer7 emits layer8 once per filter of the bundle (a.filts2 filters).
layer7 := func() cgen.Gen {
gens := make(cgen.Gens, a.filts2)
for x := range gens {
filtIdx = x
gens[x] = layer8()
}
return gens
}
// layer6 loads the pre batch-norm multiplier and addend for the channel
// sliceIdx + fromChans*groupIdx from each pre pointer and folds them into
// preMul1/preAdd1: mul1 *= mul2 and add1 = add1*mul2 + add2.
layer6 := func() cgen.Gen {
if preCnt == 0 {
preMul1 = nil
preAdd1 = nil
return layer7()
}
var (
last = preCnt * 3
stmts = make(cgen.Stmts, last+1)
)
preCh := cgen.Paren{
Inner: addMul(
sliceIdx,
il(a.fromChans),
a.groupIdx,
),
}
for x, prePtr := range a.bnPtrs[:preCnt] {
var (
preMul2 = vb(a.name("preMul"))
preAdd2 = vb(a.name("preAdd"))
)
stmts[x*3] = &bn.Load{
Ctx: a.bc,
Mas: prePtr,
Channel: preCh,
Mul: preMul2,
Add: preAdd2,
}
// The first pair becomes the running pair; later pairs are folded
// into it.
if x == 0 {
preMul1 = preMul2
preAdd1 = preAdd2
continue
}
stmts[x*3+1] = cgen.Assign{
Expr1: preMul1,
Expr2: avx.Mm512MulPs{
preMul1, preMul2,
},
}
stmts[x*3+2] = cgen.Assign{
Expr1: preAdd1,
Expr2: avx.Mm512FmaddPs{
preAdd1, preMul2,
preAdd2,
},
}
}
stmts[last] = layer7()
return stmts
}
// layer5 loops the slice index k over [0, a.slices).
layer5 := func() cgen.Gen {
sliceIdx = vb(a.name("k"))
return cgen.Stmts{
cgen.For{
Init: cgen.Var{
Type: cgen.PtrdiffT,
What: sliceIdx,
Init: il(0),
},
Cond: cgen.CmpL{
Expr1: sliceIdx,
Expr2: il(a.slices),
},
Post: cgen.IncPre{
Expr: sliceIdx,
},
Body: layer6(),
},
}
}
// layer4 loads, for every filter of the bundle, the product of all post
// batch-norm multipliers into postMuls[f]. With no post batch-norm,
// postMuls is nil and nothing is emitted here. The per-filter statement
// lists are combined by mix.
layer4 := func() cgen.Gen {
var (
postPtrs = a.bnPtrs[preCnt:]
postCnt = len(postPtrs)
)
switch postCnt {
case 0:
postMuls = nil
return layer5()
default:
postMuls = make([]cgen.Gen, a.filts2)
}
toMix := make([]cgen.Stmts, a.filts2)
for f := range toMix {
stmts := make(cgen.Stmts, postCnt*2)
postCh := cgen.Paren{
Inner: addMul(
addMul(
il(f-a.baseBundle*a.bundleFilts),
il(a.toChans),
a.groupIdx,
),
il(a.bundleFilts),
a.bundleIdx,
),
}
for x, postPtr := range postPtrs {
postMul := vb(a.name("postMul"))
stmts[x*2] = &bn.Load{
Ctx: a.bc,
Mas: postPtr,
Channel: postCh,
Mul: postMul,
}
if x == 0 {
postMuls[f] = postMul
continue
}
stmts[x*2+1] = cgen.Assign{
Expr1: postMuls[f],
Expr2: avx.Mm512MulPs{
postMuls[f],
postMul,
},
}
}
toMix[f] = stmts
}
return cgen.Gens{
mix(toMix),
layer5(),
}
}
// layer3 emits the slice loops (layer4) followed by the bias handling
// (sublayer1..5), which produces the final value of bfs[0].
layer3 := func() cgen.Gen {
var (
bias cgen.Gen
)
// scale multiplies bfs[0] by the literal 64.
scale := func() cgen.Gen {
return cgen.Assign{
Expr1: bfs[0],
Expr2: avx.Mm512MulPs{
bfs[0],
avx.Mm512Set1PsLit(64),
},
}
}
// sublayer5 passes the bias through each post batch-norm:
// bias = bias*postMul + postAdd, loading a.filts2 channels at once.
sublayer5 := func() cgen.Gen {
var (
postPtrs = a.bnPtrs[preCnt:]
postCnt = len(postPtrs)
)
if postCnt == 0 {
return nil
}
stmts := make(cgen.Stmts, postCnt*2)
postCh := cgen.Paren{
Inner: addMul(
addMul(
il(-a.baseBundle*a.bundleFilts),
il(a.toChans),
a.groupIdx,
),
il(a.bundleFilts),
a.bundleIdx,
),
}
for x, postPtr := range postPtrs {
var (
postMul = vb(a.name("postMul"))
postAdd = vb(a.name("postAdd"))
)
stmts[x*2] = &bn.Load{
Ctx: a.bc,
Mas: postPtr,
Channel: postCh,
Mul: postMul,
Add: postAdd,
Cnt: a.filts2,
}
stmts[x*2+1] = cgen.Assign{
Expr1: bias,
Expr2: avx.Mm512FmaddPs{
bias, postMul,
postAdd,
},
}
}
return stmts
}
// sublayer4 loads the bundle's biases with a masked load (mask from
// loMask(a.filts2)) and then applies sublayer5.
sublayer4 := func() cgen.Gen {
var (
ae = a.biasPtr
groupPitch = a.toChans * a.biasBytes
bundlePitch = a.bundleFilts * a.biasBytes
mask = loMask(a.filts2)
)
ae = cgen.Sub{
Expr1: ae,
Expr2: il(a.baseBundle * bundlePitch),
}
ae = addMul(ae, il(groupPitch), a.groupIdx)
ae = addMul(ae, il(bundlePitch), a.bundleIdx)
return cgen.Stmts{
cgen.Assign{
Expr1: bias,
Expr2: avx.Mm512MaskzLoaduPs{
mask, ae,
},
},
sublayer5(),
}
}
// sublayer3 declares a zeroed bias vector and, guarded at run time by
// "epoch coordinate is zero", loads the bias and merges it. Without pre
// batch-norm, bias itself becomes bfs[0] and is scaled inside the guard;
// with pre batch-norm, bias is added to the existing bfs[0].
sublayer3 := func() cgen.Gen {
bias = vb(a.name("bias"))
var stmt cgen.Gen
switch preCnt {
case 0:
bfs[0] = bias
stmt = scale()
default:
stmt = cgen.Assign{
Expr1: bfs[0],
Expr2: avx.Mm512AddPs{
bfs[0], bias,
},
}
}
return cgen.Stmts{
cgen.Var{
Type: avx.M512, What: bias,
Init: avx.Mm512SetzeroPs,
},
cgen.If{
Cond: cgen.IsZero{
Expr: a.epochCoord,
},
Then: cgen.Stmts{
sublayer4(),
stmt,
},
},
}
}
// sublayer2 emits bias code only for the epoch range that starts at 0.
sublayer2 := func() cgen.Gen {
if a.epochFirst == 0 {
return sublayer3()
}
return nil
}
// sublayer1: with pre batch-norm, the per-filter accumulators are first
// combined by sumr.Pack, then the bias is merged and the result scaled.
sublayer1 := func() cgen.Gen {
if preCnt == 0 {
return sublayer2()
}
return cgen.Stmts{
&sumr.Pack{
Platform: a.platform,
Nms: a.nms,
Vars: bfs,
},
sublayer2(),
scale(),
}
}
return cgen.Gens{
layer4(),
sublayer1(),
}
}
// layer2 stores bfs[0] for the bundle with a masked store (mask from
// loMask(a.filts2)). layer3 must run first because it may assign bfs[0];
// if bfs[0] is still nil afterwards a zero vector is stored.
layer2 := func() cgen.Gen {
var (
stmts = layer3()
ae = a.bfPtr
bundlePitch = a.bundleFilts * a.bfFragBytes
mask = loMask(a.filts2)
bf = bfs[0]
)
ae = cgen.Sub{
Expr1: ae,
Expr2: il(
a.baseBundle*bundlePitch -
a.baseFilt*a.bfFragBytes,
),
}
ae = addMul(ae, il(a.bfGroupBytes), a.groupIdx)
ae = addMul(ae, il(bundlePitch), a.bundleIdx)
if bf == nil {
bf = avx.Mm512SetzeroPs
}
return cgen.Stmts{
stmts,
avx.Mm512MaskStoreuPs{
ae, mask, bf,
},
}
}
// layer1 sizes bfs to the bundle's filter count and reads the pre
// batch-norm count of the current Filts entry. Zero-initialized "bf"
// accumulators are declared only when there is pre batch-norm; otherwise
// the bfs entries start out nil.
layer1 := func() cgen.Gen {
bfs = make([]cgen.Gen, a.filts2)
preCnt = a.Filts[a.filtsIdx].BnPre
if preCnt == 0 {
return layer2()
}
var (
last = len(bfs)
stmts = make(cgen.Stmts, last+1)
)
for x := range bfs {
bf := vb(a.name("bf"))
bfs[x] = bf
stmts[x] = cgen.Var{
Type: avx.M512, What: bf,
Init: avx.Mm512SetzeroPs,
}
}
stmts[last] = layer2()
return stmts
}
return layer1()
}

// ArrangeDats generates a call to a data-arranging function. Prep
// computes the layout and chooses (or reuses) the caller name, Bytes
// reports dfTotalBytes, and Append emits the call site.
type ArrangeDats struct {
*Ctx
*Spec
// Team is passed as the first argument of the generated call.
Team cgen.Gen
// Tensors are emitted as the initializer of the pointer array that is
// passed as the second argument of the generated call.
Tensors []cgen.Gen
// layout is assigned by Prep from newLayout(a.Ctx, a.Spec).
*layout
// callerName is the generated function's name, assigned by Prep.
callerName string
}

// Prep computes the layout and settles the name of the generated caller
// function. Callers are deduplicated by a signature built from the Spec:
// if one was already registered under the same signature its name is
// reused and nil is returned; otherwise a new name is registered and the
// function definition (followed by a newline) is returned.
func (a *ArrangeDats) Prep() cgen.Gen {
	const affix = "ArrangeDats"
	a.layout = newLayout(a.Ctx, a.Spec)
	sig := fmt.Sprint(affix, " ", a.Spec)
	prior, seen := a.dedup[sig]
	if !seen {
		fresh := a.name(a.prefix + affix)
		a.callerName = fresh
		a.dedup[sig] = fresh
		return cgen.Gens{
			&arrangeDats{ArrangeDats: a},
			cgen.Newline,
		}
	}
	a.callerName = prior.(string)
	return nil
}

// Bytes returns dfTotalBytes.
//
// NOTE(review): dfTotalBytes is presumably promoted from the embedded
// layout, which Prep assigns — confirm that Prep must be called first.
func (a *ArrangeDats) Bytes() int {
	total := a.dfTotalBytes
	return total
}

// Append emits the call site: a char-pointer array initialized with
// a.Tensors, followed by a call of the caller function with a.Team and
// that array as arguments. The generated text is appended to to.
func (a *ArrangeDats) Append(to []byte) []byte {
	arr := vb(a.name("tensors"))
	elems := cgen.CommaLines(a.Tensors)
	decl := cgen.Var{
		Type: cgen.PtrChar,
		What: cgen.Elem{Arr: arr},
		Init: cgen.Brace{Inner: elems},
	}
	call := cgen.Call{
		Func: vb(a.callerName),
		Args: cgen.CommaSpaced{
			a.Team, arr,
		},
	}
	return cgen.Stmts{decl, call}.Append(to)
}

// arrangeDats holds the state used while generating the caller and callee
// functions for one ArrangeDats. Fields are filled in progressively by
// Append, calleeFunc and the kernelN methods.
type arrangeDats struct {
*ArrangeDats
// Tiling chosen by Append. For each dimension: Tile is the per-task
// size, Tiles the number of full tiles, Scrap the size of the odd tile,
// and Hull the extent passed to the threader. The slice dimension has
// separate values for the two epoch variants (suffix 1 and 2).
sliceTile1 int
sliceTile2 int
sliceTiles int
sliceScrap1 int
sliceScrap2 int
sliceHull int
coreTile int
coreTiles int
coreScrap int
coreHull int
groupTile int
groupTiles int
groupScrap int
groupHull int
// Set by Append / calleeFunc: callee name, the tensors variable and the
// four task coordinates.
calleeName string
tensors cgen.Gen
sliceCoord cgen.Gen
coreCoord cgen.Gen
groupCoord cgen.Gen
epochCoord cgen.Gen
// Selected by calleeFunc for the epoch variant being generated.
sliceTile int
sliceScrap int
coreBytes int
pileBytes int
groupBytes int
zoneBytes int
// Pointer variables declared by kernel1.
datPtrs []cgen.Gen
bnPtrs []cgen.Gen
dfPtr cgen.Gen
// Loop variables and per-segment state set by kernel2..kernel6.
groupIdx cgen.Gen
coreIdx cgen.Gen
coreLast cgen.Gen
coreH cgen.Gen
coreW cgen.Gen
lbs []*loopB
sliceIdx cgen.Gen
bnMuls []cgen.Gen
bnAdds []cgen.Gen
lb *loopB
blkIdx cgen.Gen
repeat bool
meldIdx cgen.Gen
fragIdx cgen.Gen
}

// Append decides how the work is tiled across tasks and emits the callee
// function followed by the static caller that passes it to threader.Do.
//
// The amount of work is estimated in blocks and compared against a
// per-task target (threadBlks). If one core of one epoch already meets
// the target, the slices are split as well; else if one group of one
// epoch meets it, the cores are tiled; otherwise whole groups are tiled.
// Only the AVX512Float32 platform is supported.
func (a *arrangeDats) Append(to []byte) []byte {
	if a.platform != raw.AVX512Float32 {
		panic("bug")
	}
	threadBlks := 128
	chanBlks1 := a.dfCores1 * a.dfSliceFrags1
	chanBlks2 := chanBlks1 + a.dfSliceFrags2
	groupBlks1 := a.fromChans * chanBlks2
	groupBlks2 := ceilQuo(groupBlks1, a.epochs2)
	coreBlks := ceilQuo(groupBlks2, a.dfCores2)

	// Defaults: slices are not split and each task handles one group.
	a.sliceTile1, a.sliceTile2 = a.slices1, a.slices2
	a.sliceTiles, a.sliceHull = 1, 1
	a.sliceScrap1, a.sliceScrap2 = 0, 0
	a.groupTile, a.groupTiles = 1, a.Groups
	a.groupScrap, a.groupHull = 0, a.Groups

	// split divides total items into tiles of roughly want items each.
	// When the division leaves a remainder, the last tile absorbs it and
	// is reported as scrap rather than as a full tile.
	split := func(total, want int) (tile, tiles, scrap, hull int) {
		hull = max(total/want, 1)
		tile = total / hull
		scrap = total - hull*tile
		tiles = hull
		if scrap > 0 {
			tiles--
			scrap += tile
		}
		return tile, tiles, scrap, hull
	}

	switch {
	case threadBlks <= coreBlks:
		sliceBlks := ceilQuo(chanBlks2, a.dfCores2)
		minSlices := a.slices1
		if a.epochs1 != a.epochs2 &&
			(a.epochs1 == 0 || a.slices1 > a.slices2) {
			minSlices = a.slices2
		}
		want := ceilQuo(threadBlks, sliceBlks)
		hull := max(minSlices/want, 1)
		a.sliceTile1 = a.slices1 / hull
		a.sliceTile2 = a.slices2 / hull
		a.sliceScrap1 = a.slices1 - hull*a.sliceTile1
		a.sliceScrap2 = a.slices2 - hull*a.sliceTile2
		a.sliceTiles, a.sliceHull = hull, hull
		if a.sliceScrap1 > 0 || a.sliceScrap2 > 0 {
			a.sliceTiles--
			a.sliceScrap1 += a.sliceTile1
			a.sliceScrap2 += a.sliceTile2
		}
		a.coreTile, a.coreTiles = 1, a.dfCores2
		a.coreScrap, a.coreHull = 0, a.dfCores2
	case threadBlks <= groupBlks2:
		want := ceilQuo(threadBlks, coreBlks)
		a.coreTile, a.coreTiles, a.coreScrap, a.coreHull =
			split(a.dfCores2, want)
	default:
		a.coreTile, a.coreTiles = a.dfCores2, 1
		a.coreScrap, a.coreHull = 0, 1
		want := ceilQuo(threadBlks, groupBlks2)
		a.groupTile, a.groupTiles, a.groupScrap, a.groupHull =
			split(a.Groups, want)
	}

	a.calleeName = a.name(a.callerName + "Callee")
	teamArg := vb(a.name("team"))
	tensorsArg := vb(a.name("tensors"))
	return cgen.Gens{
		a.calleeFunc(),
		cgen.Newline,
		cgen.StaticFuncDef{
			ReturnType: cgen.Void,
			Name:       a.callerName,
			Params: cgen.CommaSpaced{
				cgen.Param{
					Type: a.tc.PtrTeam,
					What: teamArg,
				},
				cgen.Param{
					Type: cgen.PtrPtrChar,
					What: tensorsArg,
				},
			},
			Body: &threader.Do{
				Ctx:    a.tc,
				Callee: vb(a.calleeName),
				Any:    tensorsArg,
				Hull: []cgen.Gen{
					il(a.sliceHull),
					il(a.coreHull),
					il(a.groupHull),
					il(a.epochs2),
				},
				Team: teamArg,
			},
		},
	}.Append(to)
}

// calleeFunc builds the per-task function handed to the threader.
//
// The body has eight slots: [0] declares the tensors pointer, [1..4]
// declare the slice, core, group and epoch coordinates, [5] is a void
// cast of the point array (only when no coordinate reads it), [6] is the
// kernel for epochs below epochs1, and [7] is the kernel for the epoch at
// index epochs1 (only when epochs1 < epochs2). Unused slots stay nil.
func (a *arrangeDats) calleeFunc() cgen.Gen {
	callee := &threader.Callee{
		Ctx:  a.tc,
		Name: a.calleeName,
		Task: vb(a.name("task")),
		Pt:   vb(a.name("pt")),
	}
	stmts := make(cgen.Stmts, 8)
	ptRead := false
	a.tensors = vb(a.name("tensors"))
	stmts[0] = cgen.Var{
		Type: cgen.PtrPtrChar,
		What: a.tensors,
		Init: callee.Any(),
	}
	// declCoord declares one coordinate variable in slot 1+slot. A hull
	// extent of 1 means the coordinate is the literal 0; otherwise it is
	// read from the point array at index slot.
	declCoord := func(label string, extent, slot int) cgen.Gen {
		ident := vb(a.name(label))
		var value cgen.Gen
		if extent == 1 {
			value = il(0)
		} else {
			value = cgen.Elem{
				Arr: callee.Pt,
				Idx: il(slot),
			}
			ptRead = true
		}
		stmts[1+slot] = cgen.Var{
			Type: cgen.PtrdiffT,
			What: ident,
			Init: value,
		}
		return ident
	}
	a.sliceCoord = declCoord("s", a.sliceHull, 0)
	a.coreCoord = declCoord("c", a.coreHull, 1)
	a.groupCoord = declCoord("g", a.groupHull, 2)
	a.epochCoord = declCoord("e", a.epochs2, 3)
	if !ptRead {
		stmts[5] = cgen.Cast{
			Type: cgen.Void,
			Expr: callee.Pt,
		}
	}
	// epochBody emits kernel1 for cnt epochs starting at first. When the
	// range is a single epoch out of several, the epoch coordinate is
	// first overwritten with the literal index.
	epochBody := func(first, cnt int) cgen.Gen {
		var pin cgen.Gen
		if a.epochs2 > 1 && cnt == 1 {
			pin = cgen.Assign{
				Expr1: a.epochCoord,
				Expr2: il(first),
			}
		}
		return cgen.Stmts{
			pin,
			a.kernel1(),
		}
	}
	hasLastEpoch := a.epochs1 < a.epochs2
	if a.epochs1 > 0 {
		a.sliceTile = a.sliceTile1
		a.sliceScrap = a.sliceScrap1
		a.coreBytes = a.dfCoreBytes11
		a.pileBytes = a.dfPileBytes1
		a.groupBytes = a.dfGroupBytes1
		a.zoneBytes = a.dfZoneBytes1
		head := epochBody(0, a.epochs1)
		if hasLastEpoch {
			// Guard the first variant and return so the second variant
			// below only runs for the remaining epoch.
			head = cgen.If{
				Cond: cgen.CmpL{
					Expr1: a.epochCoord,
					Expr2: il(a.epochs1),
				},
				Then: cgen.Stmts{
					head,
					cgen.Return{},
				},
			}
		}
		stmts[6] = head
	}
	if hasLastEpoch {
		a.sliceTile = a.sliceTile2
		a.sliceScrap = a.sliceScrap2
		a.coreBytes = a.dfCoreBytes21
		a.pileBytes = a.dfPileBytes2
		a.groupBytes = a.dfGroupBytes2
		a.zoneBytes = a.dfZoneBytes2
		stmts[7] = epochBody(a.epochs1, 1)
	}
	return callee.Func(stmts)
}

// kernel1 declares the pointer variables used by the rest of the kernel
// and then appends kernel2. Elements of the tensors array are consumed in
// order: the first data tensor, then one or more tensors per operation in
// a.From.Ops (Add consumes op.Int data tensors, Bn consumes one batch-norm
// tensor, ReLU consumes none), and finally the destination tensor.
//
// NOTE(review): addMul is defined elsewhere; it is presumably
// base + scale*index — confirm.
func (a *arrangeDats) kernel1() cgen.Gen {
// Reuse the backing arrays across the two epoch variants.
a.datPtrs = a.datPtrs[:0]
a.bnPtrs = a.bnPtrs[:0]
var (
stmts cgen.Stmts
tensorIdx = 0
)
// decl appends "char *restrict ptr = expr" to stmts.
decl := func(ptr, expr cgen.Gen) {
stmts = append(
stmts, cgen.Var{
Type: cgen.RestrictPtrChar,
What: ptr, Init: expr,
},
)
}
// tensor returns tensors[tensorIdx] and advances tensorIdx.
tensor := func() cgen.Gen {
i := tensorIdx
tensorIdx++
return cgen.Elem{
Arr: a.tensors,
Idx: il(i),
}
}
// datPtr declares the next data pointer. The base is moved back by the
// padding (PaddingH rows of pitch1 plus PaddingW elements) and advanced
// per epoch by slices1*pitch2.
datPtr := func() {
var (
ptr = vb(a.name("datPtr"))
i = len(a.datPtrs)
pitch1 = a.From.Pitch1Bytes[i]
pitch2 = a.From.Pitch2Bytes[i]
)
a.datPtrs = append(a.datPtrs, ptr)
decl(
ptr, addMul(
cgen.Sub{
Expr1: tensor(),
Expr2: il(
a.PaddingH*pitch1 +
a.PaddingW*a.datBytes,
),
},
il(a.slices1*pitch2),
a.epochCoord,
),
)
}
datPtrs := func(n int) {
for ; n > 0; n-- {
datPtr()
}
}
// bnPtr declares the next batch-norm pointer, offset through bn.Offset
// by channel slices1*epochCoord.
bnPtr := func() {
ptr := vb(a.name("bnPtr"))
a.bnPtrs = append(a.bnPtrs, ptr)
decl(
ptr, &bn.Offset{
Ctx: a.bc,
Mas: tensor(),
Channel: cgen.Mul{
Expr1: il(a.slices1),
Expr2: a.epochCoord,
},
},
)
}
datPtr()
for op := range a.From.Ops {
op := &a.From.Ops[op]
switch op.Kind {
case mod.Add:
datPtrs(op.Int)
case mod.Bn:
bnPtr()
case mod.ReLU:
default:
panic("bug")
}
}
// The destination pointer comes from the next tensor, advanced per epoch
// by dfEpochBytes1.
a.dfPtr = vb(a.name("dfPtr"))
decl(
a.dfPtr, addMul(
tensor(),
il(a.dfEpochBytes1),
a.epochCoord,
),
)
return append(
stmts,
a.kernel2(),
)
}

// kernel2 emits the group loop. The group index i starts at
// groupTile*groupCoord and runs up to an inclusive last index ii. When
// every task is known to process exactly one group, kernel3 is emitted
// directly without a loop.
func (a *arrangeDats) kernel2() cgen.Gen {
	a.groupIdx = vb(a.name("i"))
	out := make(cgen.Stmts, 3)
	out[0] = cgen.Var{
		Type: cgen.PtrdiffT,
		What: a.groupIdx,
		Init: cgen.Mul{
			Expr1: il(a.groupTile),
			Expr2: a.groupCoord,
		},
	}
	// fixed is the iteration count when it is the same for every task,
	// and 0 when it depends on the task's coordinate.
	fixed := 0
	if a.groupTiles == a.groupHull {
		fixed = a.groupTile
	} else if a.groupTiles == 0 {
		fixed = a.groupScrap
	}
	if fixed == 1 {
		out[2] = a.kernel3()
		return out
	}
	bound := vb(a.name("ii"))
	var span cgen.Gen
	if fixed == 0 {
		// Full tiles come first; tasks at or beyond groupTiles get the
		// scrap size.
		span = cgen.Paren{
			Inner: cgen.Ternary{
				Cond: cgen.CmpL{
					Expr1: a.groupCoord,
					Expr2: il(a.groupTiles),
				},
				Then: il(a.groupTile - 1),
				Else: il(a.groupScrap - 1),
			},
		}
	} else {
		span = il(fixed - 1)
	}
	out[1] = cgen.Var{
		Type: cgen.PtrdiffT,
		What: bound,
		Init: cgen.Add{
			Expr1: a.groupIdx,
			Expr2: span,
		},
	}
	out[2] = cgen.For{
		Cond: cgen.CmpLE{
			Expr1: a.groupIdx,
			Expr2: bound,
		},
		Post: cgen.IncPre{
			Expr: a.groupIdx,
		},
		Body: a.kernel3(),
	}
	return out
}

// kernel3 declares the core index j = coreTile*coreCoord and, when the
// core hull is not 1, the inclusive last core index for this task. It
// then appends kernel4, which walks the cores.
func (a *arrangeDats) kernel3() cgen.Gen {
	a.coreIdx = vb(a.name("j"))
	a.coreLast = nil
	if a.coreHull != 1 {
		a.coreLast = vb(a.name("last"))
	}
	out := make(cgen.Stmts, 3)
	out[0] = cgen.Var{
		Type: cgen.PtrdiffT,
		What: a.coreIdx,
		Init: cgen.Mul{
			Expr1: il(a.coreTile),
			Expr2: a.coreCoord,
		},
	}
	if a.coreLast != nil {
		// The span is a constant when all tasks share one tile size;
		// otherwise it is picked at run time from the coordinate.
		var span cgen.Gen
		switch {
		case a.coreTiles == a.coreHull:
			span = il(a.coreTile - 1)
		case a.coreTiles == 0:
			span = il(a.coreScrap - 1)
		default:
			span = cgen.Paren{
				Inner: cgen.Ternary{
					Cond: cgen.CmpL{
						Expr1: a.coreCoord,
						Expr2: il(a.coreTiles),
					},
					Then: il(a.coreTile - 1),
					Else: il(a.coreScrap - 1),
				},
			}
		}
		out[1] = cgen.Var{
			Type: cgen.PtrdiffT,
			What: a.coreLast,
			Init: cgen.Add{
				Expr1: a.coreIdx,
				Expr2: span,
			},
		}
	}
	out[2] = a.kernel4()
	return out
}

// kernel4 emits the code that walks the core indices assigned to the
// task. Cores are grouped into height segments (a.segs.lhs), each of
// which contains width segments (lh.lws). The nested closures, from the
// outside in:
//
//	layer1  decision tree over the height segments by core index
//	layer2  declare rel (position within the segment period) and base
//	layer3  repeat the width segments for periodic height segments
//	layer4  decision tree over the width segments by rel
//	layer5  declare the h and w coordinates of the core
//	layer6  loop over the cores of a stepping width segment
//	layer7  process one core (kernel5), then advance the core index
//
// NOTE(review): addMul is defined elsewhere; it is presumably
// base + scale*index — confirm.
func (a *arrangeDats) kernel4() cgen.Gen {
var (
lh *loopH
rel cgen.Gen
base cgen.Gen
relBreak int
lw *loopW
)
// layer7 emits kernel5 for the current core, returns from the task when
// the task's last core has been reached, and otherwise increments the
// core index.
layer7 := func() cgen.Gen {
var retIf cgen.Gen
if a.coreLast != nil {
retIf = cgen.If1{
Cond: cgen.CmpGE{
Expr1: a.coreIdx,
Expr2: a.coreLast,
},
Then: cgen.Return{},
}
}
return cgen.Stmts{
a.kernel5(),
retIf,
cgen.IncPre{
Expr: a.coreIdx,
},
}
}
// layer6 emits a single core when the width segment does not step;
// otherwise it loops the core index up to jj (the core index at which rel
// would equal lw.segPast-1), adding lw.fromStep to w after each core.
layer6 := func() cgen.Gen {
if lw.fromStep == 0 {
return layer7()
}
last := vb(a.name("jj"))
return cgen.Stmts{
cgen.Var{
Type: cgen.PtrdiffT,
What: last,
Init: cgen.Add{
Expr1: cgen.Sub{
Expr1: il(lw.segPast - 1),
Expr2: rel,
},
Expr2: a.coreIdx,
},
},
cgen.For{
Cond: cgen.CmpLE{
Expr1: a.coreIdx,
Expr2: last,
},
Post: cgen.AddAssign{
Expr1: a.coreW,
Expr2: il(lw.fromStep),
},
Body: layer7(),
},
}
}
// layer5 declares h = base + lw.fromH and w (a constant for a
// non-stepping segment, otherwise derived from rel). After the width
// segment whose end equals relBreak, a break is emitted that leaves the
// layer3 loop once the core index has reached the end of the height
// segment.
layer5 := func() cgen.Gen {
a.coreH = vb(a.name("h"))
a.coreW = vb(a.name("w"))
a.lbs = lw.lbs
var (
exprW cgen.Gen
breakIf cgen.Gen
)
switch lw.fromStep {
case 0:
exprW = il(lw.fromW)
default:
exprW = addMul(
il(lw.fromW-lw.fromStep*lw.segFirst),
il(lw.fromStep),
rel,
)
}
if lw.segPast == relBreak {
breakIf = cgen.If1{
Cond: cgen.CmpGE{
Expr1: a.coreIdx,
Expr2: il(lh.segPast),
},
Then: cgen.Break,
}
}
return cgen.Stmts{
cgen.Var{
Type: cgen.PtrdiffT,
What: a.coreH,
Init: cgen.Add{
Expr1: base,
Expr2: il(lw.fromH),
},
},
cgen.Var{
Type: cgen.PtrdiffT,
What: a.coreW,
Init: exprW,
},
layer6(),
breakIf,
}
}
// layer4 emits a binary decision tree that enters the width segment
// containing rel and then continues through the segments after it.
layer4 := func() cgen.Gen {
var (
lws = lh.lws
tree func(int, int) cgen.Stmts
)
// leaf emits segment x and, unless it is the last one, sets rel to the
// segment's end so the following segment starts there.
leaf := func(x int) cgen.Stmts {
lw = lws[x]
var assn cgen.Gen
if x+1 < len(lws) {
assn = cgen.Assign{
Expr1: rel,
Expr2: il(lw.segPast),
}
}
return cgen.Stmts{
layer5(),
assn,
}
}
tree = func(first, last int) cgen.Stmts {
if first == last {
return leaf(first)
}
var (
start = lws[first].segFirst
stop = lws[last].segPast
split = start + (stop-start)/2
x = first + 1
)
for lws[x].segPast <= split {
x++
}
return cgen.Stmts{
cgen.If{
Cond: cgen.CmpL{
Expr1: rel,
Expr2: il(lws[x].segFirst),
},
Then: tree(first, x-1),
},
tree(x, last),
}
}
return tree(0, len(lws)-1)
}
// layer3 emits the width segments once for a non-periodic height segment
// (segStep == 0). For a periodic one it wraps them in a loop with no
// condition whose post statement resets rel to 0 and advances base by
// lh.fromStep; relBreak is the rel value at which the final period ends,
// and layer5 emits the break there.
layer3 := func() cgen.Gen {
if lh.segStep == 0 {
relBreak = -1
return layer4()
}
x := lh.segPast - lh.segFirst
relBreak = (x-1)%lh.segStep + 1
return cgen.For{
Post: cgen.CommaSpaced{
cgen.Assign{
Expr1: rel,
Expr2: il(0),
},
cgen.AddAssign{
Expr1: base,
Expr2: il(lh.fromStep),
},
},
Body: layer4(),
}
}
// layer2 declares rel and base for the current height segment. Without
// a period, rel is the core's offset into the segment and base is
// lh.fromH. With a period, the offset (cast to size_t) is split into a
// remainder (rel) and a quotient that scales lh.fromStep into base.
layer2 := func() cgen.Gen {
rel = vb(a.name("rel"))
base = vb(a.name("base"))
var (
relExpr cgen.Gen = cgen.Sub{
Expr1: a.coreIdx,
Expr2: il(lh.segFirst),
}
baseExpr = il(lh.fromH)
)
if lh.segStep != 0 {
var (
numer cgen.Gen = cgen.Cast{
Type: cgen.SizeT,
Expr: cgen.Paren{
Inner: relExpr,
},
}
denom = il(lh.segStep)
)
relExpr = cgen.Rem{
Expr1: numer,
Expr2: denom,
}
baseExpr = addMul(
baseExpr,
cgen.Quo{
Expr1: numer,
Expr2: denom,
},
il(lh.fromStep),
)
}
return cgen.Stmts{
cgen.Var{
Type: cgen.PtrdiffT,
What: rel,
Init: relExpr,
},
cgen.Var{
Type: cgen.PtrdiffT,
What: base,
Init: baseExpr,
},
layer3(),
}
}
// layer1 emits a binary decision tree that enters the height segment
// containing the core index and then continues through the segments
// after it.
layer1 := func() cgen.Gen {
var (
lhs = a.segs.lhs
tree func(int, int) cgen.Stmts
)
// leaf emits segment x and, unless it is the last one, sets the core
// index to the segment's end so the following segment starts there.
leaf := func(x int) cgen.Stmts {
lh = lhs[x]
var assn cgen.Gen
if x+1 < len(lhs) {
assn = cgen.Assign{
Expr1: a.coreIdx,
Expr2: il(lh.segPast),
}
}
return cgen.Stmts{
layer2(),
assn,
}
}
tree = func(first, last int) cgen.Stmts {
if first == last {
return leaf(first)
}
var (
start = lhs[first].segFirst
stop = lhs[last].segPast
split = start + (stop-start)/2
x = first + 1
)
for lhs[x].segPast <= split {
x++
}
return cgen.Stmts{
cgen.If{
Cond: cgen.CmpL{
Expr1: a.coreIdx,
Expr2: il(lhs[x].segFirst),
},
Then: tree(first, x-1),
},
tree(x, last),
}
}
return tree(0, len(lhs)-1)
}
return layer1()
}

// kernel5 emits the slice loop. The slice index k starts at
// sliceTile*sliceCoord and runs up to an inclusive last index kk. When
// every task is known to process exactly one slice, kernel6 is emitted
// directly without a loop.
func (a *arrangeDats) kernel5() cgen.Gen {
	a.sliceIdx = vb(a.name("k"))
	out := make(cgen.Stmts, 3)
	out[0] = cgen.Var{
		Type: cgen.PtrdiffT,
		What: a.sliceIdx,
		Init: cgen.Mul{
			Expr1: il(a.sliceTile),
			Expr2: a.sliceCoord,
		},
	}
	// fixed is the iteration count when it is the same for every task,
	// and 0 when it depends on the task's coordinate.
	fixed := 0
	if a.sliceTiles == a.sliceHull {
		fixed = a.sliceTile
	} else if a.sliceTiles == 0 || a.sliceTile == a.sliceScrap {
		fixed = a.sliceScrap
	}
	if fixed == 1 {
		out[2] = a.kernel6()
		return out
	}
	bound := vb(a.name("kk"))
	var span cgen.Gen
	if fixed == 0 {
		// Full tiles come first; tasks at or beyond sliceTiles get the
		// scrap size.
		span = cgen.Paren{
			Inner: cgen.Ternary{
				Cond: cgen.CmpL{
					Expr1: a.sliceCoord,
					Expr2: il(a.sliceTiles),
				},
				Then: il(a.sliceTile - 1),
				Else: il(a.sliceScrap - 1),
			},
		}
	} else {
		span = il(fixed - 1)
	}
	out[1] = cgen.Var{
		Type: cgen.PtrdiffT,
		What: bound,
		Init: cgen.Add{
			Expr1: a.sliceIdx,
			Expr2: span,
		},
	}
	out[2] = cgen.For{
		Cond: cgen.CmpLE{
			Expr1: a.sliceIdx,
			Expr2: bound,
		},
		Post: cgen.IncPre{
			Expr: a.sliceIdx,
		},
		Body: a.kernel6(),
	}
	return out
}

// kernel6 emits the per-slice work for the current core: it loads the
// batch-norm parameters for the slice's channel (layer1), walks the
// block ranges of the current width segment (layer2, layer3), derives
// the meld and fragment indices of each block (layer4), and emits the
// platform-specific body (layer5).
//
// NOTE(review): addMul is defined elsewhere; it is presumably
// base + scale*index — confirm.
func (a *arrangeDats) kernel6() cgen.Gen {
// layer5 dispatches on the platform; only AVX512Float32 is handled.
layer5 := func() cgen.Gen {
switch a.platform {
case raw.AVX512Float32:
return a.m512()
default:
panic("bug")
}
}
// layer4 splits the block index (cast to size_t) into
// m = b / dfMeldFrags and f = b % dfMeldFrags.
layer4 := func() cgen.Gen {
a.meldIdx = vb(a.name("m"))
a.fragIdx = vb(a.name("f"))
var (
numer cgen.Gen = cgen.Cast{
Type: cgen.SizeT,
Expr: a.blkIdx,
}
denom = il(a.dfMeldFrags)
)
return cgen.Stmts{
cgen.Var{
Type: cgen.PtrdiffT,
What: a.meldIdx,
Init: cgen.Quo{
Expr1: numer,
Expr2: denom,
},
},
cgen.Var{
Type: cgen.PtrdiffT,
What: a.fragIdx,
Init: cgen.Rem{
Expr1: numer,
Expr2: denom,
},
},
layer5(),
}
}
// layer3 emits the blocks of a.lb in two runs: [blkFirst, blkPast-repeat)
// with a.repeat false, then the final `repeat` blocks with a.repeat true.
layer3 := func(repeat int) cgen.Gen {
var (
stmts cgen.Stmts
first = a.lb.blkFirst
past1 = a.lb.blkPast - repeat
past2 = a.lb.blkPast
)
stmt := func(st cgen.Gen) {
stmts = append(stmts, st)
}
// do emits the blocks [start, stop): nothing for an empty range, a
// plain declaration plus body for a single block, and a for loop
// otherwise.
do := func(start, stop int) {
if start == stop {
return
}
a.blkIdx = vb(a.name("b"))
// True only for the second run (which starts at past1); the first
// run returns above when it would start there.
a.repeat = start == past1
decl := cgen.Var{
Type: cgen.PtrdiffT,
What: a.blkIdx,
Init: il(start),
}
switch stop - start {
case 1:
stmt(decl)
stmt(layer4())
default:
stmt(cgen.For{
Init: decl,
Cond: cgen.CmpL{
Expr1: a.blkIdx,
Expr2: il(stop),
},
Post: cgen.IncPre{
Expr: a.blkIdx,
},
Body: layer4(),
})
}
}
do(first, past1)
do(past1, past2)
return stmts
}
// layer2 visits each loopB of the current width segment. The final
// block of the last loopB is marked for repetition when that loopB's
// blkPast is not a multiple of dfMeldFrags.
layer2 := func() cgen.Gen {
var (
n = len(a.lbs)
gens = make(cgen.Gens, n)
)
for x, lb := range a.lbs {
a.lb = lb
repeat := 0
if x == n-1 &&
lb.blkPast%a.dfMeldFrags > 0 {
repeat = 1
}
gens[x] = layer3(repeat)
}
return gens
}
// layer1 loads one multiplier/addend pair per batch-norm pointer for the
// channel sliceIdx + fromChans*groupIdx and records the variables in
// a.bnMuls and a.bnAdds.
layer1 := func() cgen.Gen {
a.bnMuls = a.bnMuls[:0]
a.bnAdds = a.bnAdds[:0]
var (
last = len(a.bnPtrs)
gens = make(cgen.Gens, last+1)
)
ch := cgen.Paren{
Inner: addMul(
a.sliceIdx,
il(a.fromChans),
a.groupIdx,
),
}
for x, bnPtr := range a.bnPtrs {
var (
bnMul = vb(a.name("bnMul"))
bnAdd = vb(a.name("bnAdd"))
)
a.bnMuls = append(a.bnMuls, bnMul)
a.bnAdds = append(a.bnAdds, bnAdd)
gens[x] = &bn.Load{
Ctx: a.bc,
Mas: bnPtr,
Channel: ch,
Mul: bnMul,
Add: bnAdd,
}
}
gens[last] = layer2()
return gens
}
return layer1()
}

// m512 builds the AVX-512 code for one block of input data: it loads the
// rows of the block (applying the a.From.Ops chain), runs them through a
// quadfft.Fwd transform, and stores the transform outputs into the
// arranged buffer addressed through a.dfPtr.
//
// The work is split into nested closures layer1..layer6; layer1 is the
// entry point and each layer calls the next. The closures communicate
// through the captured variables declared below.
func (a *arrangeDats) m512() cgen.Gen {
var (
fwd *quadfft.Fwd
eo cgen.Gen
pileIdx int
zoneIdx int
dfs []cgen.Gen
)
// layer6 emits four masked stores (side 0/1 x part 0/1) of the current
// pair of outputs in dfs.
layer6 := func() cgen.Gen {
var stmts cgen.Stmts
stmt := func(st cgen.Gen) {
stmts = append(stmts, st)
}
emit := func(side, part int) {
var (
to = a.dfPtr
slicePitch = a.dfSliceBytes1
partPitch = a.dfMeldBytes / 2
fragPitch = a.dfFragBytes / 2
back = side * fragPitch
// 0x00ff for side 0, 0xff00 for side 1.
mask = 0x00ff << uint(side*8)
from = dfs[part]
)
// The second slice pitch applies when the last loopB ends exactly at
// dfSliceFrags2.
switch a.lbs[len(a.lbs)-1].blkPast {
case a.dfSliceFrags2:
slicePitch = a.dfSliceBytes2
}
if a.repeat {
// In repeat mode all 16 lanes are stored without the backwards
// offset. ctrl works out to 0x44 for side 0 and 0xEE for side 1.
// NOTE(review): with _mm512_shuffle_f32x4 semantics that duplicates
// the side's 256-bit half into both halves - confirm against the
// intrinsic reference.
back = 0
mask = 0xffff
var (
rep = vb(a.name("rep"))
ctrl = side * 2
)
ctrl |= (ctrl + 1) << 2
ctrl |= ctrl << 4
stmt(cgen.Var{
Type: avx.M512, What: rep,
Init: avx.Mm512ShuffleF32x4{
from, from, il(ctrl),
},
})
from = rep
}
// Constant part of the destination offset first, then the terms that
// depend on generated index variables.
to = cgen.Add{
Expr1: to,
Expr2: il(
(zoneIdx+side)*a.zoneBytes +
pileIdx*a.pileBytes +
part*partPitch -
back,
),
}
to = addMul(to, il(a.groupBytes), a.groupIdx)
to = addMul(to, il(a.coreBytes), a.coreIdx)
to = addMul(to, il(slicePitch), a.sliceIdx)
to = addMul(to, il(a.dfMeldBytes), a.meldIdx)
to = addMul(to, il(fragPitch), a.fragIdx)
stmt(avx.Mm512MaskStoreuPs{
to, il(mask), from,
})
}
for side := 0; side < 2; side++ {
for part := 0; part < 2; part++ {
emit(side, part)
}
}
return stmts
}
// layer5 selects the two transform outputs for the current (pileIdx,
// zoneIdx) and, for every pile except pile 0, permutes them with the eo
// index vector before storing.
layer5 := func() cgen.Gen {
at := pileIdx*2 + zoneIdx/2*8
dfs = fwd.Out[at : at+2]
if pileIdx == 0 {
return layer6()
}
stmts := make(cgen.Stmts, 3)
for x, df := range dfs {
stmts[x] = cgen.Assign{
Expr1: df,
Expr2: avx.Mm512PermutexvarPs{
eo, df,
},
}
}
stmts[2] = layer6()
return stmts
}
// layer4 visits every (pile, zone pair). Piles are visited in the order
// 1, 2, ..., n1-1, 0, so pile 0 is emitted last.
layer4 := func() cgen.Gen {
var (
n1 = a.zoneFrags
n2 = a.blkZones / 2
gens = make(cgen.Gens, n1*n2)
)
for p := 0; p < n1; p++ {
pileIdx = (p + 1) % n1
for z := 0; z < n2; z++ {
zoneIdx = z * 2
gens[p*n2+z] = layer5()
}
}
return gens
}
// layer3 declares the eo permutation index vector and continues with
// layer4. The indices listed are 0,2,...,14,1,3,...,15 (evens then odds).
// NOTE(review): assumes Mm512SetEpi32 lists elements from lane 15 down to
// lane 0, like _mm512_set_epi32 - confirm.
layer3 := func() cgen.Gen {
eo = vb(a.name("eo"))
set := make(avx.Mm512SetEpi32, 16)
for x := 0; x < 16; x++ {
set[15-x] = il(x%8*2 + x/8)
}
return cgen.Stmts{
cgen.Var{
Type: avx.M512i, What: eo,
Init: set,
},
layer4(),
}
}
// layer2 emits the loads for every populated input row, applies the
// a.From.Ops chain to each row, then emits the transform and layer3.
layer2 := func() cgen.Gen {
var (
stmts cgen.Stmts
// datW one-bits, shifted left by padW.
mask1 = 1<<uint(a.lb.datW) - 1
mask2 = il(mask1 << uint(a.lb.padW))
)
stmt := func(st cgen.Gen) {
stmts = append(stmts, st)
}
// load returns a zero-masked load of row h from input tensor x.
load := func(h, x int) cgen.Gen {
var (
ae = a.datPtrs[x]
pitch1 = a.From.Pitch1Bytes[x]
pitch2 = a.From.Pitch2Bytes[x]
groupPitch = a.fromChans * pitch2
blkPitch = a.lb.fromStep * a.datBytes
)
// The blkFirst*blkPitch term is subtracted so that the later
// blkPitch*blkIdx term is relative to the first block of this loop.
ae = cgen.Add{
Expr1: ae,
Expr2: il(
(a.lb.fromH+h)*pitch1 +
a.lb.fromW*a.datBytes -
a.lb.blkFirst*blkPitch,
),
}
ae = addMul(ae, il(groupPitch), a.groupIdx)
ae = addMul(ae, il(pitch2), a.sliceIdx)
ae = addMul(ae, il(pitch1), a.coreH)
ae = addMul(ae, il(a.datBytes), a.coreW)
ae = addMul(ae, il(blkPitch), a.blkIdx)
return avx.Mm512MaskzLoaduPs{
mask2, ae,
}
}
for h, dat := range &fwd.In {
// Rows outside [padH, padH+datH) were left nil by layer1.
if dat == nil {
continue
}
var (
datPtrIdx = 0
bnPtrIdx = 0
)
stmt(cgen.Var{
Type: avx.M512, What: dat,
Init: load(h, datPtrIdx),
})
for op := range a.From.Ops {
op := &a.From.Ops[op]
switch op.Kind {
case mod.Add:
// Each Add consumes op.Int further input tensors.
for n := op.Int; n > 0; n-- {
datPtrIdx++
stmt(cgen.Assign{
Expr1: dat,
Expr2: avx.Mm512AddPs{
dat,
load(h, datPtrIdx),
},
})
}
case mod.Bn:
// Each Bn consumes the next (mul, add) pair from a.bnMuls/a.bnAdds.
stmt(&bn.Apply{
Ctx: a.bc,
Mul: a.bnMuls[bnPtrIdx],
Add: a.bnAdds[bnPtrIdx],
To: dat,
Mask: mask2,
})
bnPtrIdx++
case mod.ReLU:
stmt(&act.ReLU{
Ctx: a.ac,
NegSlope: op.Float,
Var: dat,
})
default:
panic("bug")
}
}
}
stmt(fwd)
stmt(layer3())
return stmts
}
// layer1 creates the transform and names its variables: inputs only for
// rows in [padH, padH+datH), outputs for every output slot.
layer1 := func() cgen.Gen {
fwd = &quadfft.Fwd{
Platform: a.platform,
Nms: a.nms,
}
var (
first = a.lb.padH
past = first + a.lb.datH
)
for x := first; x < past; x++ {
dat := vb(a.name("dat"))
fwd.In[x] = dat
}
for x := range &fwd.Out {
df := vb(a.name("df"))
fwd.Out[x] = df
}
return layer2()
}
return layer1()
}

// ProduceSums is the public generator for the sum-producing stage. Prep
// emits (at most once per distinct Spec) the supporting function
// definitions, and Append emits a call to them.
type ProduceSums struct {
*Ctx
*Spec
// Team is passed as the first argument of the generated call.
Team cgen.Gen
// Tensors become the initializer of the pointer array passed as the
// second argument of the generated call.
Tensors []cgen.Gen
// layout is assigned by Prep.
*layout
// callerName is the name of the generated function; assigned by Prep.
callerName string
}

// Prep computes the layout for p.Spec and decides which generated
// function Append will call.
//
// Functions are deduplicated through p.dedup, keyed by the affix plus the
// printed Spec. On a hit the previously chosen name is reused and nil is
// returned (nothing new to emit). On a miss a fresh name is allocated and
// recorded, and the generator for the function definitions is returned.
func (p *ProduceSums) Prep() cgen.Gen {
	const affix = "ProduceSums"
	p.layout = newLayout(p.Ctx, p.Spec)
	key := fmt.Sprint(affix, " ", p.Spec)
	prior, seen := p.dedup[key]
	if !seen {
		p.callerName = p.name(p.prefix + affix)
		p.dedup[key] = p.callerName
		return cgen.Gens{
			&produceSums{ProduceSums: p},
			cgen.Newline,
		}
	}
	p.callerName = prior.(string)
	return nil
}

// Bytes returns sfTotalBytes.
// NOTE(review): sfTotalBytes is presumably a field of the embedded layout,
// which Prep assigns - if so, call Prep first. Confirm against the layout
// definition.
func (p *ProduceSums) Bytes() int {
return p.sfTotalBytes
}

// Append emits the call site: a freshly named local array initialized with
// p.Tensors, followed by a call of the function chosen by Prep with
// (p.Team, array) as arguments. The generated text is appended to the
// given buffer and the extended buffer is returned.
func (p *ProduceSums) Append(to []byte) []byte {
	array := vb(p.name("tensors"))
	decl := cgen.Var{
		Type: cgen.PtrChar,
		What: cgen.Elem{Arr: array},
		Init: cgen.Brace{
			Inner: cgen.CommaLines(p.Tensors),
		},
	}
	call := cgen.Call{
		Func: vb(p.callerName),
		Args: cgen.CommaSpaced{p.Team, array},
	}
	return cgen.Stmts{decl, call}.Append(to)
}

// produceSums holds the working state used while generating the function
// definitions for a ProduceSums stage. The fields are scratch state: they
// are written by one generation step and read by the steps it calls.
type produceSums struct {
	*ProduceSums

	// Epoch range served by the callee currently being generated, and the
	// byte pitches / slice count selected for that range (set in Append).
	epochFirst, epochCnt, slices                         int
	wfCoreBytes, wfPileBytes, wfGroupBytes, wfZoneBytes  int
	dfCoreBytes, dfPileBytes, dfGroupBytes, dfZoneBytes  int

	// Thread partitioning of each dimension (set in Append). For a
	// dimension X: XHull is the number of tasks, XTile the number of
	// indices in a regular task, XTiles the number of regular tasks, and
	// XScrap the number of indices in the final, irregular task (0 when
	// every task is regular).
	wfTile, wfTiles, wfScrap, wfHull             int
	dfTile, dfTiles, dfScrap, dfHull             int
	pileTile, pileTiles, pileScrap, pileHull     int
	groupTile, groupTiles, groupScrap, groupHull int

	// Name of the callee being generated (set in Append).
	calleeName string

	// Variables declared at the top of the callee (set in calleeFunc).
	tensors                                          cgen.Gen
	epochCoord, zoneCoord                            cgen.Gen
	groupCoord, pileCoord, dfCoord, wfCoord          cgen.Gen

	// True while generating the specialized epoch 0 / zone 0 path
	// (set in kernel1).
	epoch0zone0 bool

	// Base pointers declared by kernel2.
	bfPtr, wfPtr, dfPtr, sfPtr cgen.Gen

	// Loop index variables and the variant flags of the loop nest
	// (set in kernel3 through kernel6).
	groupIdx, pileIdx cgen.Gen
	pile0             bool
	dfIdx             cgen.Gen
	dfShort           bool
	wfIdx             cgen.Gen
	wfShort           bool
}

// Append emits up to two callee functions (one for the first p.epochs1
// epochs, one for the single trailing epoch when p.epochs1 < p.epochs2)
// followed by the static caller function that loops over epochs and zones
// and dispatches each (epoch, zone) to the threader.
func (p *produceSums) Append(to []byte) []byte {
	// Amount of work one task should cover, in the same units as the
	// per-dimension work totals computed below.
	var threadWork int
	if p.platform == raw.AVX512Float32 {
		threadWork = 256
	} else {
		panic("bug")
	}
	// split partitions a dimension of total indices, each carrying the
	// given amount of work, into hull tasks of about threadWork each. Any
	// remainder is folded into the final task, which then covers scrap
	// indices and is no longer counted in tiles.
	split := func(total, work int) (tile, tiles, scrap, hull int) {
		hull = max(total/ceilQuo(threadWork, work), 1)
		tile = total / hull
		tiles = hull
		scrap = total - hull*tile
		if scrap > 0 {
			tiles--
			scrap += tile
		}
		return tile, tiles, scrap, hull
	}
	// callee configures p for the epoch range [first, first+cnt) and
	// returns the generator for the matching callee function.
	callee := func(first, cnt int) cgen.Gen {
		p.epochFirst, p.epochCnt = first, cnt
		if first < p.epochs1 {
			p.slices = p.slices1
			p.wfCoreBytes = p.wfCoreBytes11
			p.wfPileBytes = p.wfPileBytes1
			p.wfGroupBytes = p.wfGroupBytes1
			p.wfZoneBytes = p.wfZoneBytes1
			p.dfCoreBytes = p.dfCoreBytes11
			p.dfPileBytes = p.dfPileBytes1
			p.dfGroupBytes = p.dfGroupBytes1
			p.dfZoneBytes = p.dfZoneBytes1
		} else {
			p.slices = p.slices2
			p.wfCoreBytes = p.wfCoreBytes21
			p.wfPileBytes = p.wfPileBytes2
			p.wfGroupBytes = p.wfGroupBytes2
			p.wfZoneBytes = p.wfZoneBytes2
			p.dfCoreBytes = p.dfCoreBytes21
			p.dfPileBytes = p.dfPileBytes2
			p.dfGroupBytes = p.dfGroupBytes2
			p.dfZoneBytes = p.dfZoneBytes2
		}
		// Work carried by one index of each dimension, innermost first.
		wfWork := p.slices
		dfWork := p.wfCores2 * wfWork
		pileWork := p.dfCores2 * dfWork
		groupWork := p.zoneFrags * pileWork
		// Start from the finest partition: one task per index everywhere.
		p.wfTile, p.wfTiles, p.wfScrap, p.wfHull = 1, p.wfCores2, 0, p.wfCores2
		p.dfTile, p.dfTiles, p.dfScrap, p.dfHull = 1, p.dfCores2, 0, p.dfCores2
		p.pileTile, p.pileTiles, p.pileScrap, p.pileHull = 1, p.zoneFrags, 0, p.zoneFrags
		p.groupTile, p.groupTiles, p.groupScrap, p.groupHull = 1, p.Groups, 0, p.Groups
		// Helpers that collapse a whole dimension into a single task.
		wholeWf := func() {
			p.wfTile, p.wfTiles, p.wfScrap, p.wfHull = p.wfCores2, 1, 0, 1
		}
		wholeDf := func() {
			p.dfTile, p.dfTiles, p.dfScrap, p.dfHull = p.dfCores2, 1, 0, 1
		}
		wholePile := func() {
			p.pileTile, p.pileTiles, p.pileScrap, p.pileHull = p.zoneFrags, 1, 0, 1
		}
		// Find the innermost dimension whose full extent reaches
		// threadWork: that dimension is split, everything inside it is
		// collapsed, everything outside it stays at one task per index.
		switch {
		case threadWork <= wfWork:
			// A single wf index is already enough work for a task.
		case threadWork <= dfWork:
			p.wfTile, p.wfTiles, p.wfScrap, p.wfHull = split(p.wfCores2, wfWork)
		case threadWork <= pileWork:
			wholeWf()
			p.dfTile, p.dfTiles, p.dfScrap, p.dfHull = split(p.dfCores2, dfWork)
		case threadWork <= groupWork:
			wholeWf()
			wholeDf()
			p.pileTile, p.pileTiles, p.pileScrap, p.pileHull = split(p.zoneFrags, pileWork)
		default:
			wholeWf()
			wholeDf()
			wholePile()
			p.groupTile, p.groupTiles, p.groupScrap, p.groupHull = split(p.Groups, groupWork)
		}
		p.calleeName = p.name(
			p.callerName + "Callee",
		)
		return cgen.Gens{
			p.calleeFunc(),
			cgen.Newline,
		}
	}
	team := vb(p.name("team"))
	tensors := vb(p.name("tensors"))
	tuple := vb(p.name("tuple"))
	// slot returns the expression tuple[i].
	slot := func(i int) cgen.Gen {
		return cgen.Elem{Arr: tuple, Idx: il(i)}
	}
	// zoneLoop emits the loop over zones: each iteration publishes the
	// zone number in tuple[2] and dispatches the current callee.
	zoneLoop := func() cgen.Gen {
		z := vb(p.name("z"))
		publish := cgen.Assign{
			Expr1: slot(2),
			Expr2: cgen.Cast{Type: cgen.PtrVoid, Expr: z},
		}
		dispatch := &threader.Do{
			Ctx:    p.tc,
			Callee: vb(p.calleeName),
			Any:    tuple,
			Hull: []cgen.Gen{
				il(p.wfHull),
				il(p.dfHull),
				il(p.pileHull),
				il(p.groupHull),
			},
			Team: team,
		}
		return cgen.For{
			Init: cgen.Var{Type: cgen.PtrdiffT, What: z, Init: il(0)},
			Cond: cgen.CmpL{Expr1: z, Expr2: il(p.blkZones)},
			Post: cgen.IncPre{Expr: z},
			Body: cgen.Stmts{publish, dispatch},
		}
	}
	// epochLoop emits the loop over the epoch range most recently
	// configured by callee: each iteration publishes the epoch number in
	// tuple[1] and runs the zone loop.
	epochLoop := func() cgen.Gen {
		e := vb(p.name("e"))
		publish := cgen.Assign{
			Expr1: slot(1),
			Expr2: cgen.Cast{Type: cgen.PtrVoid, Expr: e},
		}
		return cgen.For{
			Init: cgen.Var{Type: cgen.PtrdiffT, What: e, Init: il(p.epochFirst)},
			Cond: cgen.CmpL{Expr1: e, Expr2: il(p.epochFirst + p.epochCnt)},
			Post: cgen.IncPre{Expr: e},
			Body: cgen.Stmts{publish, zoneLoop()},
		}
	}
	prep := make(cgen.Gens, 2)
	body := make(cgen.Stmts, 4)
	// Caller prologue: declare the 3-element tuple and store the tensor
	// array pointer in tuple[0].
	body[0] = cgen.Var{Type: cgen.PtrVoid, What: slot(3)}
	body[1] = cgen.Assign{Expr1: slot(0), Expr2: tensors}
	// Each epoch loop must be built right after its callee, because the
	// loop reads the state that callee just configured.
	if p.epochs1 > 0 {
		prep[0] = callee(0, p.epochs1)
		body[2] = epochLoop()
	}
	if p.epochs1 < p.epochs2 {
		prep[1] = callee(p.epochs1, 1)
		body[3] = epochLoop()
	}
	caller := cgen.StaticFuncDef{
		ReturnType: cgen.Void,
		Name:       p.callerName,
		Params: cgen.CommaSpaced{
			cgen.Param{Type: p.tc.PtrTeam, What: team},
			cgen.Param{Type: cgen.PtrPtrChar, What: tensors},
		},
		Body: body,
	}
	return cgen.Gens{prep, caller}.Append(to)
}

// calleeFunc builds the threader callee for the currently configured epoch
// range. The function body unpacks the tuple passed by the caller (tensor
// array, epoch, zone), reads the task coordinates of the four partitioned
// dimensions, and then runs kernel1.
func (p *produceSums) calleeFunc() cgen.Gen {
	callee := &threader.Callee{
		Ctx:  p.tc,
		Name: p.calleeName,
		Task: vb(p.name("task")),
		Pt:   vb(p.name("pt")),
	}
	body := make(cgen.Stmts, 10)
	tuple := vb(p.name("tuple"))
	ptUsed := false
	// slot returns the expression tuple[i].
	slot := func(i int) cgen.Gen {
		return cgen.Elem{Arr: tuple, Idx: il(i)}
	}
	body[0] = cgen.Var{
		Type: cgen.PtrPtrVoid,
		What: tuple,
		Init: callee.Any(),
	}
	p.tensors = vb(p.name("tensors"))
	body[1] = cgen.Var{
		Type: cgen.PtrPtrChar,
		What: p.tensors,
		Init: slot(0),
	}
	// A callee that serves exactly one epoch gets that epoch as a literal
	// instead of reading it from the tuple.
	p.epochCoord = vb(p.name("e"))
	var epochInit cgen.Gen
	if p.epochCnt == 1 {
		epochInit = il(p.epochFirst)
	} else {
		epochInit = cgen.Cast{Type: cgen.PtrdiffT, Expr: slot(1)}
	}
	body[2] = cgen.Var{
		Type: cgen.PtrdiffT,
		What: p.epochCoord,
		Init: epochInit,
	}
	p.zoneCoord = vb(p.name("z"))
	body[3] = cgen.Var{
		Type: cgen.PtrdiffT,
		What: p.zoneCoord,
		Init: cgen.Cast{Type: cgen.PtrdiffT, Expr: slot(2)},
	}
	// coord declares the coordinate variable for threader dimension i in
	// body[7-i]. A dimension with a single task is a literal 0; otherwise
	// the coordinate is read from the callee's point array.
	coord := func(nm string, hull, i int) cgen.Gen {
		v := vb(p.name(nm))
		var val cgen.Gen
		if hull == 1 {
			val = il(0)
		} else {
			val = cgen.Elem{Arr: callee.Pt, Idx: il(i)}
			ptUsed = true
		}
		body[7-i] = cgen.Var{
			Type: cgen.PtrdiffT,
			What: v,
			Init: val,
		}
		return v
	}
	p.groupCoord = coord("g", p.groupHull, 3)
	p.pileCoord = coord("p", p.pileHull, 2)
	p.dfCoord = coord("d", p.dfHull, 1)
	p.wfCoord = coord("w", p.wfHull, 0)
	// If no coordinate read the point array, emit a cast of it to void.
	if !ptUsed {
		body[8] = cgen.Cast{Type: cgen.Void, Expr: callee.Pt}
	}
	body[9] = p.kernel1()
	return callee.Func(body)
}

// kernel1 emits the body of the callee in up to two variants.
//
// When the callee's epoch range starts at epoch 0, a specialized copy for
// "epoch 0 and zone 0" is emitted first, guarded by an (unlikely) test that
// both coordinates are zero; inside it the coordinates are pinned to the
// literal 0 and the function returns afterwards. The general variant always
// follows.
func (p *produceSums) kernel1() cgen.Gen {
	variant := func(e0z0 bool) cgen.Gen {
		p.epoch0zone0 = e0z0
		return p.kernel2()
	}
	var special cgen.Gen
	if p.epochFirst <= 0 {
		// (e | z) is zero exactly when both coordinates are zero.
		either := cgen.Paren{
			Inner: cgen.Or{
				Expr1: p.epochCoord,
				Expr2: p.zoneCoord,
			},
		}
		then := cgen.Stmts{
			nil,
			cgen.Assign{Expr1: p.zoneCoord, Expr2: il(0)},
			variant(true),
			cgen.Return{},
		}
		// The epoch coordinate is only reassigned when the callee serves
		// more than one epoch.
		if p.epochCnt > 1 {
			then[0] = cgen.Assign{Expr1: p.epochCoord, Expr2: il(0)}
		}
		special = cgen.If{
			Cond: cgen.Unlikely{
				Cond: cgen.IsZero{Expr: either},
			},
			Then: then,
		}
	}
	general := variant(false)
	return cgen.Stmts{special, general}
}

// kernel2 declares the four base pointers used by the inner kernels and
// then emits kernel3:
//
//	bfPtr = tensors[0] + bfEpochBytes*e
//	wfPtr = tensors[0] + bfTotalBytes + wfEpochBytes1*e + wfZoneBytes*z
//	dfPtr = tensors[1] + dfEpochBytes1*e + dfZoneBytes*z
//	sfPtr = tensors[2]
func (p *produceSums) kernel2() cgen.Gen {
	p.bfPtr = vb(p.name("bfPtr"))
	p.wfPtr = vb(p.name("wfPtr"))
	p.dfPtr = vb(p.name("dfPtr"))
	p.sfPtr = vb(p.name("sfPtr"))
	// tensor returns the expression tensors[i].
	tensor := func(i int) cgen.Gen {
		return cgen.Elem{Arr: p.tensors, Idx: il(i)}
	}
	// ptr declares one restrict-qualified char pointer.
	ptr := func(what, init cgen.Gen) cgen.Gen {
		return cgen.Var{
			Type: cgen.RestrictPtrChar,
			What: what,
			Init: init,
		}
	}
	bf := addMul(tensor(0), il(p.bfEpochBytes), p.epochCoord)
	wf := addMul(
		addMul(
			cgen.Add{Expr1: tensor(0), Expr2: il(p.bfTotalBytes)},
			il(p.wfEpochBytes1),
			p.epochCoord,
		),
		il(p.wfZoneBytes),
		p.zoneCoord,
	)
	df := addMul(
		addMul(tensor(1), il(p.dfEpochBytes1), p.epochCoord),
		il(p.dfZoneBytes),
		p.zoneCoord,
	)
	return cgen.Stmts{
		ptr(p.bfPtr, bf),
		ptr(p.wfPtr, wf),
		ptr(p.dfPtr, df),
		ptr(p.sfPtr, tensor(2)),
		p.kernel3(),
	}
}

// kernel3 emits the loop over the groups assigned to this task.
//
// The group index starts at groupTile*groupCoord. When every task has the
// same, statically known number of groups, that count is used directly
// (and a count of exactly one needs no loop at all); otherwise the count
// is chosen at run time between the regular tile and the final scrap.
func (p *produceSums) kernel3() cgen.Gen {
	p.groupIdx = vb(p.name("i"))
	out := make(cgen.Stmts, 3)
	out[0] = cgen.Var{
		Type: cgen.PtrdiffT,
		What: p.groupIdx,
		Init: cgen.Mul{
			Expr1: il(p.groupTile),
			Expr2: p.groupCoord,
		},
	}
	// iters is the static per-task iteration count; 0 means it is not
	// known statically (regular and scrap tasks coexist).
	iters := 0
	if p.groupTiles == p.groupHull {
		iters = p.groupTile
	} else if p.groupTiles == 0 {
		iters = p.groupScrap
	}
	if iters == 1 {
		out[2] = p.kernel4()
		return out
	}
	last := vb(p.name("ii"))
	var span cgen.Gen
	if iters == 0 {
		span = cgen.Paren{
			Inner: cgen.Ternary{
				Cond: cgen.CmpL{
					Expr1: p.groupCoord,
					Expr2: il(p.groupTiles),
				},
				Then: il(p.groupTile - 1),
				Else: il(p.groupScrap - 1),
			},
		}
	} else {
		span = il(iters - 1)
	}
	out[1] = cgen.Var{
		Type: cgen.PtrdiffT,
		What: last,
		Init: cgen.Add{Expr1: p.groupIdx, Expr2: span},
	}
	out[2] = cgen.For{
		Cond: cgen.CmpLE{Expr1: p.groupIdx, Expr2: last},
		Post: cgen.IncPre{Expr: p.groupIdx},
		Body: p.kernel4(),
	}
	return out
}

// kernel4 emits the loop over the piles assigned to this task.
//
// Pile 0 gets its own specialized copy of the inner kernel (generated with
// p.pile0 set): if the task's first pile is 0, that copy runs once and the
// index is advanced to 1. The remaining piles run through the general
// copy inside a loop up to the task's last pile.
func (p *produceSums) kernel4() cgen.Gen {
	variant := func(first bool) cgen.Gen {
		p.pile0 = first
		return p.kernel5()
	}
	p.pileIdx = vb(p.name("j"))
	last := vb(p.name("jj"))
	// span is the distance from the first to the last pile of the task.
	var span cgen.Gen
	if p.pileTiles == p.pileHull {
		span = il(p.pileTile - 1)
	} else if p.pileTiles == 0 {
		span = il(p.pileScrap - 1)
	} else {
		span = cgen.Paren{
			Inner: cgen.Ternary{
				Cond: cgen.CmpL{
					Expr1: p.pileCoord,
					Expr2: il(p.pileTiles),
				},
				Then: il(p.pileTile - 1),
				Else: il(p.pileScrap - 1),
			},
		}
	}
	start := cgen.Var{
		Type: cgen.PtrdiffT,
		What: p.pileIdx,
		Init: cgen.Mul{
			Expr1: il(p.pileTile),
			Expr2: p.pileCoord,
		},
	}
	end := cgen.Var{
		Type: cgen.PtrdiffT,
		What: last,
		Init: cgen.Add{Expr1: p.pileIdx, Expr2: span},
	}
	// The specialized copy must be generated before the general one.
	special := cgen.If{
		Cond: cgen.Unlikely{
			Cond: cgen.IsZero{Expr: p.pileIdx},
		},
		Then: cgen.Stmts{
			variant(true),
			cgen.Assign{Expr1: p.pileIdx, Expr2: il(1)},
		},
	}
	loop := cgen.For{
		Cond: cgen.CmpLE{Expr1: p.pileIdx, Expr2: last},
		Post: cgen.IncPre{Expr: p.pileIdx},
		Body: variant(false),
	}
	return cgen.Stmts{start, end, special, loop}
}

// kernel5 emits the iteration over the df cores assigned to this task.
//
// Cores below dfCores1 are handled by a loop around the regular inner
// kernel. If a trailing short core exists (dfCores1 < dfCores2), a second
// copy of the inner kernel, generated with p.dfShort set, follows the
// loop. When the dimension is split across tasks, the loop body returns
// from the callee once the task's last core has been processed.
func (p *produceSums) kernel5() cgen.Gen {
	p.dfIdx = vb(p.name("k"))
	out := make(cgen.Stmts, 4)
	out[0] = cgen.Var{
		Type: cgen.PtrdiffT,
		What: p.dfIdx,
		Init: cgen.Mul{
			Expr1: il(p.dfTile),
			Expr2: p.dfCoord,
		},
	}
	// bail stays nil when a single task covers the whole dimension.
	var bail cgen.Gen
	if p.dfHull > 1 {
		last := vb(p.name("kk"))
		var span cgen.Gen
		if p.dfTiles == p.dfHull {
			span = il(p.dfTile - 1)
		} else if p.dfTiles == 0 {
			span = il(p.dfScrap - 1)
		} else {
			span = cgen.Paren{
				Inner: cgen.Ternary{
					Cond: cgen.CmpL{
						Expr1: p.dfCoord,
						Expr2: il(p.dfTiles),
					},
					Then: il(p.dfTile - 1),
					Else: il(p.dfScrap - 1),
				},
			}
		}
		out[1] = cgen.Var{
			Type: cgen.PtrdiffT,
			What: last,
			Init: cgen.Add{Expr1: p.dfIdx, Expr2: span},
		}
		bail = cgen.If1{
			Cond: cgen.CmpGE{Expr1: p.dfIdx, Expr2: last},
			Then: cgen.Return{},
		}
	}
	if p.dfCores1 > 0 {
		p.dfShort = false
		out[2] = cgen.For{
			Cond: cgen.CmpNE{Expr1: p.dfIdx, Expr2: il(p.dfCores1)},
			Post: cgen.IncPre{Expr: p.dfIdx},
			Body: cgen.Stmts{p.kernel6(), bail},
		}
	}
	if p.dfCores1 < p.dfCores2 {
		p.dfShort = true
		out[3] = p.kernel6()
	}
	return out
}

// kernel6 emits the iteration over the wf cores assigned to this task. It
// mirrors kernel5 one level further in: cores below wfCores1 run in a loop
// around the regular kernel, a trailing short core (wfCores1 < wfCores2)
// gets a copy generated with p.wfShort set, and when the dimension is
// split across tasks the loop body returns from the callee after the
// task's last core.
func (p *produceSums) kernel6() cgen.Gen {
	p.wfIdx = vb(p.name("l"))
	out := make(cgen.Stmts, 4)
	out[0] = cgen.Var{
		Type: cgen.PtrdiffT,
		What: p.wfIdx,
		Init: cgen.Mul{
			Expr1: il(p.wfTile),
			Expr2: p.wfCoord,
		},
	}
	// bail stays nil when a single task covers the whole dimension.
	var bail cgen.Gen
	if p.wfHull > 1 {
		last := vb(p.name("ll"))
		var span cgen.Gen
		if p.wfTiles == p.wfHull {
			span = il(p.wfTile - 1)
		} else if p.wfTiles == 0 {
			span = il(p.wfScrap - 1)
		} else {
			span = cgen.Paren{
				Inner: cgen.Ternary{
					Cond: cgen.CmpL{
						Expr1: p.wfCoord,
						Expr2: il(p.wfTiles),
					},
					Then: il(p.wfTile - 1),
					Else: il(p.wfScrap - 1),
				},
			}
		}
		out[1] = cgen.Var{
			Type: cgen.PtrdiffT,
			What: last,
			Init: cgen.Add{Expr1: p.wfIdx, Expr2: span},
		}
		bail = cgen.If1{
			Cond: cgen.CmpGE{Expr1: p.wfIdx, Expr2: last},
			Then: cgen.Return{},
		}
	}
	if p.wfCores1 > 0 {
		p.wfShort = false
		out[2] = cgen.For{
			Cond: cgen.CmpNE{Expr1: p.wfIdx, Expr2: il(p.wfCores1)},
			Post: cgen.IncPre{Expr: p.wfIdx},
			Body: cgen.Stmts{p.kernel7(), bail},
		}
	}
	if p.wfCores1 < p.wfCores2 {
		p.wfShort = true
		out[3] = p.kernel7()
	}
	return out
}

// kernel7 dispatches to the platform-specific innermost kernel. Only
// AVX-512 float32 is supported; any other platform is a programming error.
func (p *produceSums) kernel7() cgen.Gen {
	if p.platform != raw.AVX512Float32 {
		panic("bug")
	}
	return p.m512()
}

// m512 builds the AVX-512 innermost kernel for one (group, pile, df core,
// wf core) combination: it accumulates products of weight and data
// vectors over all slices into a grid of sum registers and writes the sums
// to the buffer addressed through p.sfPtr.
//
// The work is split into nested closures layer1..layer10; layer1 is the
// entry point and each layer calls the next. The closures communicate
// through the captured variables declared below:
//
//	rows1/rows2  number of full / total weight rows (melds) in this core
//	cols1/cols2  number of full / total data columns (melds) in this core
//	sfs          sum variables indexed [row][col][acc][part], part 0 being
//	             named "sfRe" and part 1 "sfIm"
//	sliceIdx     the slice loop variable
//	wfs          per-row weight variables: [0] wfRe, [1] wfIm, [2] wfMx
func (p *produceSums) m512() cgen.Gen {
var (
rows1 int
rows2 int
cols1 int
cols2 int
sfs [][][][2]cgen.Gen
sliceIdx cgen.Gen
wfs [][3]cgen.Gen
)
// layer10 emits, for every data column, the loads of the data vectors and
// the multiply-accumulate statements against every weight row.
layer10 := func() cgen.Gen {
var (
stmts cgen.Stmts
col int
dfs [2]cgen.Gen
)
stmt := func(st cgen.Gen) {
stmts = append(stmts, st)
}
// load declares a vector loaded from part 0 or 1 of the current column
// and records it in dfs[part].
load := func(part int, nm string) {
var (
vec = vb(p.name(nm))
ae = p.dfPtr
slicePitch = p.dfSliceBytes1
)
if p.dfShort {
slicePitch = p.dfSliceBytes2
}
ae = cgen.Add{
Expr1: ae,
Expr2: il(
col*p.dfMeldBytes +
part*(p.dfMeldBytes/2),
),
}
ae = addMul(ae, il(p.dfGroupBytes), p.groupIdx)
ae = addMul(ae, il(p.dfPileBytes), p.pileIdx)
ae = addMul(ae, il(p.dfCoreBytes), p.dfIdx)
ae = addMul(ae, il(slicePitch), sliceIdx)
stmt(cgen.Var{
Type: avx.M512, What: vec,
Init: avx.Mm512LoaduPs{ae},
})
dfs[part] = vec
}
// madd accumulates the current data vectors into accumulator acc of the
// first rows rows. In the general case the four statements compute
//
//	sfRe += wfRe*dfRe + wfIm*dfIm
//	sfIm += wfRe*dfIm - wfIm*dfRe
//
// For pile 0 the wfIm terms are applied under mask 0xfcfc and the
// wfRe*dfIm term uses wfMx in place of wfRe.
madd := func(rows, acc int) {
var (
dfRe = dfs[0]
dfIm = dfs[1]
)
for row := 0; row < rows; row++ {
var (
wfRe = wfs[row][0]
wfIm = wfs[row][1]
sfRe = sfs[row][col][acc][0]
sfIm = sfs[row][col][acc][1]
)
stmt(cgen.Assign{
Expr1: sfRe,
Expr2: avx.Mm512FmaddPs{
wfRe, dfRe, sfRe,
},
})
switch {
case p.pile0:
var (
mask = il(0xfcfc)
wfMx = wfs[row][2]
)
stmt(cgen.Assign{
Expr1: sfRe,
Expr2: avx.Mm512Mask3FmaddPs{
wfIm, dfIm, sfRe, mask,
},
})
stmt(cgen.Assign{
Expr1: sfIm,
Expr2: avx.Mm512FmaddPs{
wfMx, dfIm, sfIm,
},
})
stmt(cgen.Assign{
Expr1: sfIm,
Expr2: avx.Mm512Mask3FnmaddPs{
wfIm, dfRe, sfIm, mask,
},
})
default:
stmt(cgen.Assign{
Expr1: sfRe,
Expr2: avx.Mm512FmaddPs{
wfIm, dfIm, sfRe,
},
})
stmt(cgen.Assign{
Expr1: sfIm,
Expr2: avx.Mm512FmaddPs{
wfRe, dfIm, sfIm,
},
})
stmt(cgen.Assign{
Expr1: sfIm,
Expr2: avx.Mm512FnmaddPs{
wfIm, dfRe, sfIm,
},
})
}
}
}
// exch shuffles a data vector in place with 128-bit lane selectors
// 2,3,0,1.
// NOTE(review): with _mm512_shuffle_f32x4 semantics this swaps the two
// 256-bit halves - confirm against the intrinsic reference.
exch := func(part int) {
var (
vec = dfs[part]
ctrl = 1<<6 | 0<<4 | 3<<2 | 2<<0
)
stmt(cgen.Assign{
Expr1: vec,
Expr2: avx.Mm512ShuffleF32x4{
vec, vec, il(ctrl),
},
})
}
for col = 0; col < cols2; col++ {
load(0, "dfRe")
load(1, "dfIm")
madd(rows2, 0)
// col == cols1 can only happen for a partial last column, which has a
// single accumulator (see layer3).
if col == cols1 {
break
}
// Without any full row there is no second accumulator to update.
if rows1 == 0 {
continue
}
exch(0)
exch(1)
madd(rows1, 1)
}
return stmts
}
// layer9 emits, for every weight row, one 512-bit load that is split into
// two 256-bit halves, each converted with Mm512CvtphPs into wfRe (low
// half) and wfIm (upper half). For pile 0 an extra vector wfMx is formed
// from wfIm and wfRe with a masked move under mask 0xfcfc.
layer9 := func() cgen.Gen {
var stmts cgen.Stmts
stmt := func(st cgen.Gen) {
stmts = append(stmts, st)
}
wfs = make([][3]cgen.Gen, rows2)
for row := range wfs {
var (
ae = p.wfPtr
slicePitch = p.wfSliceBytes1
)
if p.wfShort {
slicePitch = p.wfSliceBytes2
}
ae = cgen.Add{
Expr1: ae,
Expr2: il(row * p.wfMeldBytes),
}
ae = addMul(ae, il(p.wfGroupBytes), p.groupIdx)
ae = addMul(ae, il(p.wfPileBytes), p.pileIdx)
ae = addMul(ae, il(p.wfCoreBytes), p.wfIdx)
ae = addMul(ae, il(slicePitch), sliceIdx)
var (
wfLd = vb(p.name("wfLd"))
wfRe = vb(p.name("wfRe"))
wfIm = vb(p.name("wfIm"))
)
stmt(cgen.Var{
Type: avx.M512i, What: wfLd,
Init: avx.Mm512LoaduSi512{ae},
})
wfs[row][0] = wfRe
stmt(cgen.Var{
Type: avx.M512, What: wfRe,
Init: avx.Mm512CvtphPs{
avx.Mm512Castsi512Si256{
wfLd,
},
},
})
wfs[row][1] = wfIm
stmt(cgen.Var{
Type: avx.M512, What: wfIm,
Init: avx.Mm512CvtphPs{
avx.Mm512Extracti64x4Epi64{
wfLd, il(1),
},
},
})
if p.pile0 {
wfMx := vb(p.name("wfMx"))
wfs[row][2] = wfMx
stmt(cgen.Var{
Type: avx.M512, What: wfMx,
Init: avx.Mm512MaskMovPs{
wfIm, il(0xfcfc),
wfRe,
},
})
}
}
stmt(layer10())
return stmts
}
// layer8 emits the loop over the p.slices slices around layer9.
layer8 := func() cgen.Gen {
sliceIdx = vb(p.name("s"))
return cgen.For{
Init: cgen.Var{
Type: cgen.PtrdiffT,
What: sliceIdx,
Init: il(0),
},
Cond: cgen.CmpL{
Expr1: sliceIdx,
Expr2: il(p.slices),
},
Post: cgen.IncPre{
Expr: sliceIdx,
},
Body: layer9(),
}
}
// layer7 emits the slice loop followed by the write-back of every sum
// variable. Statements are collected in two lists that are emitted one
// after the other: stmts[0] holds the slice loop, the corner shuffle and
// (except on the epoch 0 / zone 0 path) the additions of the previously
// stored sums; stmts[1] holds the stores.
layer7 := func() cgen.Gen {
var stmts [2]cgen.Stmts
stmt := func(x int, st cgen.Gen) {
stmts[x] = append(stmts[x], st)
}
// do emits the write-back of one sum variable.
do := func(row, col, acc, part int) {
var (
vec = sfs[row][col][acc][part]
ae = p.sfPtr
sitePitch = p.sfSiteBytes11
rowPitch = p.sfRowBytes11
colPitch = p.sfMeldBytes11
accPitch = p.wfMeldFrags * p.sfFragBytes
partPitch = accPitch / 2
)
if p.dfShort {
sitePitch = p.sfSiteBytes12
rowPitch = p.sfRowBytes12
}
if row == rows1 {
colPitch = p.sfMeldBytes21
}
ae = cgen.Add{
Expr1: ae,
Expr2: il(
row*rowPitch +
col*colPitch +
acc*accPitch +
part*partPitch,
),
}
ae = addMul(ae, il(p.sfGroupBytes), p.groupIdx)
ae = addMul(ae, il(p.sfPileBytes), p.pileIdx)
ae = addMul(ae, il(p.sfCoreBytes1), p.dfIdx)
ae = addMul(ae, il(sitePitch), p.wfIdx)
// On the epoch 0 / zone 0 path there is nothing stored yet to add to.
if !p.epoch0zone0 {
stmt(0, cgen.Assign{
Expr1: vec,
Expr2: avx.Mm512AddPs{
vec,
avx.Mm512LoaduPs{ae},
},
})
}
stmt(1, avx.Mm512StoreuPs{
ae, vec,
})
}
stmt(0, layer8())
for row := range sfs {
for col := range sfs[row] {
for acc := range sfs[row][col] {
switch {
case row == rows1 && col == cols1:
// Corner of a partial row and a partial column: sfRe and sfIm are
// combined into sfRe with a shuffle and only part 0 is written.
var (
sfRe = sfs[row][col][acc][0]
sfIm = sfs[row][col][acc][1]
ctrl = 1<<6 | 0<<4 | 1<<2 | 0<<0
)
stmt(0, cgen.Assign{
Expr1: sfRe,
Expr2: avx.Mm512ShuffleF32x4{
sfRe, sfIm, il(ctrl),
},
})
do(row, col, acc, 0)
default:
for part := range &sfs[row][col][acc] {
do(row, col, acc, part)
}
}
}
}
}
return cgen.Gens{
stmts[0],
stmts[1],
}
}
// layer6 declares every sum variable other than the [row][0][0] pair,
// initializing each from the [row][0][0] variable of the same part, so
// that all accumulators of a row start from the same value.
layer6 := func() cgen.Gen {
var stmts cgen.Stmts
stmt := func(st cgen.Gen) {
stmts = append(stmts, st)
}
for row := range sfs {
for col := range sfs[row] {
for acc := range sfs[row][col] {
if col == 0 && acc == 0 {
continue
}
for part, vec := range &sfs[row][col][acc] {
stmt(cgen.Var{
Type: avx.M512, What: vec,
Init: sfs[row][0][0][part],
})
}
}
}
}
stmt(layer7())
return stmts
}
// layer5 decides whether values read through p.bfPtr are moved into the
// initial sums, then continues with layer6.
layer5 := func() cgen.Gen {
// bias emits, for every row and every fragment of that row, a masked
// move of a broadcast float read through p.bfPtr into the row's initial
// sfRe variable.
bias := func() cgen.Stmts {
var stmts cgen.Stmts
stmt := func(st cgen.Gen) {
stmts = append(stmts, st)
}
for row := range sfs {
var (
sfRe = sfs[row][0][0][0]
frags = p.wfMeldFrags
)
// A partial last row holds a single fragment.
if row == rows1 {
frags = 1
}
for frag := 0; frag < frags; frag++ {
var (
// One lane per fragment: lane 0, lane 8, ...
mask = 1 << uint(frag*8)
bf = p.bfPtr
corePitch = p.wfSliceMelds1 * p.bfMeldBytes
)
// In a partial row with at least one full column, the lane 8
// positions higher is written as well.
if row == rows1 && cols1 > 0 {
mask |= mask << 8
}
bf = cgen.Add{
Expr1: bf,
Expr2: il(
row*p.bfMeldBytes +
frag*p.bfFragBytes,
),
}
bf = addMul(bf, il(p.bfGroupBytes), p.groupIdx)
bf = addMul(bf, il(corePitch), p.wfIdx)
bf = cgen.At{
Expr: cgen.Cast{
Type: cgen.PtrFloat,
Expr: cgen.Paren{
Inner: bf,
},
},
}
stmt(cgen.Assign{
Expr1: sfRe,
Expr2: avx.Mm512MaskMovPs{
sfRe, il(mask),
avx.Mm512Set1Ps{bf},
},
})
}
}
return stmts
}
return cgen.Stmts{
func() cgen.Gen {
// The bias is only ever applied in the pile 0 variant:
//   - unconditionally on the epoch 0 / zone 0 path;
//   - otherwise, if the callee's epoch range reaches past epoch 0 and
//     some filter has a nonzero BnPre, under a run-time test for
//     zone 0.
// In every other case bfPtr is unused and is cast to void.
if p.pile0 {
if p.epoch0zone0 {
return bias()
}
if p.epochFirst+p.epochCnt > 1 {
for x := range p.Filts {
if p.Filts[x].BnPre == 0 {
continue
}
return cgen.If{
Cond: cgen.Unlikely{
Cond: cgen.IsZero{
Expr: p.zoneCoord,
},
},
Then: bias(),
}
}
}
}
return cgen.Cast{
Type: cgen.Void,
Expr: p.bfPtr,
}
}(),
layer6(),
}
}
// layer4 declares the initial [row][0][0] pair of every row as zero
// vectors, then continues with layer5.
layer4 := func() cgen.Gen {
var (
last = len(sfs) * 2
stmts = make(cgen.Stmts, last+1)
)
for row := range sfs {
for part, vec := range &sfs[row][0][0] {
stmts[row*2+part] = cgen.Var{
Type: avx.M512, What: vec,
Init: avx.Mm512SetzeroPs,
}
}
}
stmts[last] = layer5()
return stmts
}
// layer3 names the sum variables. Each (row, col) cell gets
// p.dfMeldFrags accumulators, except cells in a partial last row or a
// partial last column, which get one.
layer3 := func() cgen.Gen {
sfs = make([][][][2]cgen.Gen, rows2)
for row := range sfs {
sfs[row] = make([][][2]cgen.Gen, cols2)
for col := range sfs[row] {
accs := p.dfMeldFrags
if row == rows1 || col == cols1 {
accs = 1
}
sfs[row][col] = make([][2]cgen.Gen, accs)
for acc := range sfs[row][col] {
var (
sfRe = vb(p.name("sfRe"))
sfIm = vb(p.name("sfIm"))
)
sfs[row][col][acc][0] = sfRe
sfs[row][col][acc][1] = sfIm
}
}
}
return layer4()
}
// layer2 sets the column counts. cols1 < cols2 only for a short df core
// whose fragment count is not a multiple of p.dfMeldFrags.
layer2 := func() cgen.Gen {
switch {
case p.dfShort:
cols2 = p.dfSliceMelds2
cols1 = cols2 - p.dfSliceFrags2%p.dfMeldFrags
default:
cols2 = p.dfSliceMelds1
cols1 = cols2
}
return layer3()
}
// layer1 sets the row counts. rows1 < rows2 only for a short wf core
// whose fragment count is not a multiple of p.wfMeldFrags.
layer1 := func() cgen.Gen {
switch {
case p.wfShort:
rows2 = p.wfSliceMelds2
rows1 = rows2 - p.wfSliceFrags2%p.wfMeldFrags
default:
rows2 = p.wfSliceMelds1
rows1 = rows2
}
return layer2()
}
return layer1()
}

// ConsumeSums is the public generator for the sum-consuming stage. Prep
// emits (at most once per distinct Spec) the supporting function
// definitions, and Append emits a call to them.
type ConsumeSums struct {
*Ctx
*Spec
// Team is passed as the first argument of the generated call.
Team cgen.Gen
// Tensors become the initializer of the pointer array passed as the
// second argument of the generated call.
Tensors []cgen.Gen
// callerName is the name of the generated function; assigned by Prep.
callerName string
}

// Prep decides which generated function Append will call.
//
// Functions are deduplicated through c.dedup, keyed by the affix plus the
// printed Spec. On a hit the previously chosen name is reused and nil is
// returned (nothing new to emit). On a miss a fresh name is allocated and
// recorded, and the generator for the function definitions is returned.
func (c *ConsumeSums) Prep() cgen.Gen {
	const affix = "ConsumeSums"
	key := fmt.Sprint(affix, " ", c.Spec)
	prior, seen := c.dedup[key]
	if !seen {
		c.callerName = c.name(c.prefix + affix)
		c.dedup[key] = c.callerName
		return cgen.Gens{
			&consumeSums{ConsumeSums: c},
			cgen.Newline,
		}
	}
	c.callerName = prior.(string)
	return nil
}

// Append emits the call site: a freshly named local array initialized with
// c.Tensors, followed by a call of the function chosen by Prep with
// (c.Team, array) as arguments. The generated text is appended to the
// given buffer and the extended buffer is returned.
func (c *ConsumeSums) Append(to []byte) []byte {
	array := vb(c.name("tensors"))
	decl := cgen.Var{
		Type: cgen.PtrChar,
		What: cgen.Elem{Arr: array},
		Init: cgen.Brace{
			Inner: cgen.CommaLines(c.Tensors),
		},
	}
	call := cgen.Call{
		Func: vb(c.callerName),
		Args: cgen.CommaSpaced{c.Team, array},
	}
	return cgen.Stmts{decl, call}.Append(to)
}

// consumeSums holds the working state used while generating the function
// definitions for a ConsumeSums stage. The fields are scratch state: they
// are written by one generation step and read by the steps it calls.
type consumeSums struct {
	*ConsumeSums
	*layout

	// Thread partitioning of each dimension (set in Append). For a
	// dimension X: XHull is the number of tasks, XTile the number of
	// indices in a regular task, XTiles the number of regular tasks, and
	// XScrap the number of indices in the final, irregular task (0 when
	// every task is regular).
	wfTile, wfTiles, wfScrap, wfHull             int
	dfTile, dfTiles, dfScrap, dfHull             int
	groupTile, groupTiles, groupScrap, groupHull int

	// Name of the callee being generated (set in Append).
	calleeName string

	// Variables declared at the top of the callee (set in calleeFunc).
	tensors                      cgen.Gen
	wfCoord, dfCoord, groupCoord cgen.Gen

	// Pointer variables declared by kernel1. datSplit is the number of
	// entries in datPtrs that were introduced by Add operations.
	sfPtr    cgen.Gen
	datSplit int
	datPtrs  []cgen.Gen
	bnPtrs   []cgen.Gen

	// Loop state of the kernels further in.
	groupIdx cgen.Gen
	dfIdx    cgen.Gen
	dfLast   cgen.Gen
	toH, toW cgen.Gen
	lbs      []*loopB
	wfIdx    cgen.Gen
	wfShort  bool
}

// Append computes the layout, decides how the wf, df and group dimensions
// are partitioned into threader tasks, and emits the callee function
// followed by the static caller function that dispatches it.
func (c *consumeSums) Append(to []byte) []byte {
	c.layout = newLayout(c.Ctx, c.Spec)
	// Block counts used to size the tasks: per group, per df core and per
	// (df core, wf core) pair.
	regular := c.dfCores1 * c.dfSliceFrags1
	groupBlks := c.toChans * (regular + c.dfSliceFrags2)
	wfBlks := ceilQuo(groupBlks, c.dfCores2*c.wfCores2)
	dfBlks := ceilQuo(groupBlks, c.dfCores2)
	// Number of blocks one task should cover.
	var threadBlks int
	if c.platform == raw.AVX512Float32 {
		threadBlks = 512
	} else {
		panic("bug")
	}
	// split partitions a dimension of total indices, each carrying the
	// given number of blocks, into hull tasks of about threadBlks each.
	// Any remainder is folded into the final task, which then covers scrap
	// indices and is no longer counted in tiles.
	split := func(total, blks int) (tile, tiles, scrap, hull int) {
		hull = max(total/ceilQuo(threadBlks, blks), 1)
		tile = total / hull
		tiles = hull
		scrap = total - hull*tile
		if scrap > 0 {
			tiles--
			scrap += tile
		}
		return tile, tiles, scrap, hull
	}
	// Defaults: the wf dimension is one task, groups are one task each.
	c.wfTile, c.wfTiles, c.wfScrap, c.wfHull = c.wfCores2, 1, 0, 1
	c.groupTile, c.groupTiles, c.groupScrap, c.groupHull = 1, c.Groups, 0, c.Groups
	switch {
	case threadBlks <= dfBlks:
		// One df core is at least a task's worth: split inside it.
		c.wfTile, c.wfTiles, c.wfScrap, c.wfHull = split(c.wfCores2, wfBlks)
		c.dfTile, c.dfTiles, c.dfScrap, c.dfHull = 1, c.dfCores2, 0, c.dfCores2
	case threadBlks <= groupBlks:
		// One group is at least a task's worth: split its df cores.
		c.dfTile, c.dfTiles, c.dfScrap, c.dfHull = split(c.dfCores2, dfBlks)
	default:
		// Groups are small: give each task several whole groups.
		c.dfTile, c.dfTiles, c.dfScrap, c.dfHull = c.dfCores2, 1, 0, 1
		c.groupTile, c.groupTiles, c.groupScrap, c.groupHull = split(c.Groups, groupBlks)
	}
	c.calleeName = c.name(c.callerName + "Callee")
	team := vb(c.name("team"))
	tensors := vb(c.name("tensors"))
	callee := c.calleeFunc()
	caller := cgen.StaticFuncDef{
		ReturnType: cgen.Void,
		Name:       c.callerName,
		Params: cgen.CommaSpaced{
			cgen.Param{Type: c.tc.PtrTeam, What: team},
			cgen.Param{Type: cgen.PtrPtrChar, What: tensors},
		},
		Body: &threader.Do{
			Ctx:    c.tc,
			Callee: vb(c.calleeName),
			Any:    tensors,
			Hull: []cgen.Gen{
				il(c.wfHull),
				il(c.dfHull),
				il(c.groupHull),
			},
			Team: team,
		},
	}
	return cgen.Gens{callee, cgen.Newline, caller}.Append(to)
}

// calleeFunc builds the threader callee: its body fetches the tensor
// array, reads the task coordinates of the three partitioned dimensions,
// and then runs kernel1.
func (c *consumeSums) calleeFunc() cgen.Gen {
	callee := &threader.Callee{
		Ctx:  c.tc,
		Name: c.calleeName,
		Task: vb(c.name("task")),
		Pt:   vb(c.name("pt")),
	}
	body := make(cgen.Stmts, 6)
	ptUsed := false
	c.tensors = vb(c.name("tensors"))
	body[0] = cgen.Var{
		Type: cgen.PtrPtrChar,
		What: c.tensors,
		Init: callee.Any(),
	}
	// coord declares the coordinate variable for threader dimension i in
	// body[1+i]. A dimension with a single task is a literal 0; otherwise
	// the coordinate is read from the callee's point array.
	coord := func(nm string, hull, i int) cgen.Gen {
		v := vb(c.name(nm))
		var val cgen.Gen
		if hull == 1 {
			val = il(0)
		} else {
			val = cgen.Elem{Arr: callee.Pt, Idx: il(i)}
			ptUsed = true
		}
		body[1+i] = cgen.Var{
			Type: cgen.PtrdiffT,
			What: v,
			Init: val,
		}
		return v
	}
	c.wfCoord = coord("w", c.wfHull, 0)
	c.dfCoord = coord("d", c.dfHull, 1)
	c.groupCoord = coord("g", c.groupHull, 2)
	// If no coordinate read the point array, emit a cast of it to void.
	if !ptUsed {
		body[4] = cgen.Cast{Type: cgen.Void, Expr: callee.Pt}
	}
	body[5] = c.kernel1()
	return callee.Func(body)
}

// kernel1 declares one pointer variable per entry of the tensor array, in
// array order, and then emits kernel2.
//
// Entry 0 is the sums pointer. The following entries are consumed by the
// c.To.Ops chain: each Add takes op.Int data pointers and each Bn takes
// one batch-norm pointer. Whatever data pointers are still needed to reach
// len(c.To.Pitch1Bytes) are declared last; c.datSplit records how many
// data pointers came from Add operations.
func (c *consumeSums) kernel1() cgen.Gen {
	c.datPtrs, c.bnPtrs = nil, nil
	var out cgen.Stmts
	next := 0
	// bind declares ptr as the next unclaimed entry of the tensor array.
	bind := func(ptr cgen.Gen) {
		out = append(out, cgen.Var{
			Type: cgen.RestrictPtrChar,
			What: ptr,
			Init: cgen.Elem{
				Arr: c.tensors,
				Idx: il(next),
			},
		})
		next++
	}
	// bindDats declares n further data pointers (none when n <= 0).
	bindDats := func(n int) {
		for i := 0; i < n; i++ {
			ptr := vb(c.name("datPtr"))
			c.datPtrs = append(c.datPtrs, ptr)
			bind(ptr)
		}
	}
	c.sfPtr = vb(c.name("sfPtr"))
	bind(c.sfPtr)
	for i := range c.To.Ops {
		switch op := &c.To.Ops[i]; op.Kind {
		case mod.Add:
			bindDats(op.Int)
		case mod.Bn:
			ptr := vb(c.name("bnPtr"))
			c.bnPtrs = append(c.bnPtrs, ptr)
			bind(ptr)
		case mod.ReLU:
			// ReLU takes no tensor.
		default:
			panic("bug")
		}
	}
	c.datSplit = len(c.datPtrs)
	bindDats(len(c.To.Pitch1Bytes) - c.datSplit)
	out = append(out, c.kernel2())
	return out
}

// kernel2 emits the loop over the groups assigned to this task.
//
// The group index starts at groupTile*groupCoord. When every task has the
// same, statically known number of groups, that count is used directly
// (and a count of exactly one needs no loop at all); otherwise the count
// is chosen at run time between the regular tile and the final scrap.
func (c *consumeSums) kernel2() cgen.Gen {
	c.groupIdx = vb(c.name("i"))
	out := make(cgen.Stmts, 3)
	out[0] = cgen.Var{
		Type: cgen.PtrdiffT,
		What: c.groupIdx,
		Init: cgen.Mul{
			Expr1: il(c.groupTile),
			Expr2: c.groupCoord,
		},
	}
	// iters is the static per-task iteration count; 0 means it is not
	// known statically (regular and scrap tasks coexist).
	iters := 0
	if c.groupTiles == c.groupHull {
		iters = c.groupTile
	} else if c.groupTiles == 0 {
		iters = c.groupScrap
	}
	if iters == 1 {
		out[2] = c.kernel3()
		return out
	}
	last := vb(c.name("ii"))
	var span cgen.Gen
	if iters == 0 {
		span = cgen.Paren{
			Inner: cgen.Ternary{
				Cond: cgen.CmpL{
					Expr1: c.groupCoord,
					Expr2: il(c.groupTiles),
				},
				Then: il(c.groupTile - 1),
				Else: il(c.groupScrap - 1),
			},
		}
	} else {
		span = il(iters - 1)
	}
	out[1] = cgen.Var{
		Type: cgen.PtrdiffT,
		What: last,
		Init: cgen.Add{Expr1: c.groupIdx, Expr2: span},
	}
	out[2] = cgen.For{
		Cond: cgen.CmpLE{Expr1: c.groupIdx, Expr2: last},
		Post: cgen.IncPre{Expr: c.groupIdx},
		Body: c.kernel3(),
	}
	return out
}

// kernel3 declares the df core index for this task (dfTile*dfCoord) and,
// when the df dimension is split across several tasks, the index of the
// task's last df core. It then emits kernel4.
//
// c.dfLast is left nil when a single task covers the whole dimension.
func (c *consumeSums) kernel3() cgen.Gen {
	c.dfIdx = vb(c.name("j"))
	c.dfLast = nil
	if c.dfHull != 1 {
		c.dfLast = vb(c.name("last"))
	}
	out := make(cgen.Stmts, 3)
	out[0] = cgen.Var{
		Type: cgen.PtrdiffT,
		What: c.dfIdx,
		Init: cgen.Mul{
			Expr1: il(c.dfTile),
			Expr2: c.dfCoord,
		},
	}
	if c.dfLast != nil {
		// span is the distance from the first to the last core of the task.
		var span cgen.Gen
		if c.dfTiles == c.dfHull {
			span = il(c.dfTile - 1)
		} else if c.dfTiles == 0 {
			span = il(c.dfScrap - 1)
		} else {
			span = cgen.Paren{
				Inner: cgen.Ternary{
					Cond: cgen.CmpL{
						Expr1: c.dfCoord,
						Expr2: il(c.dfTiles),
					},
					Then: il(c.dfTile - 1),
					Else: il(c.dfScrap - 1),
				},
			}
		}
		out[1] = cgen.Var{
			Type: cgen.PtrdiffT,
			What: c.dfLast,
			Init: cgen.Add{Expr1: c.dfIdx, Expr2: span},
		}
	}
	out[2] = c.kernel4()
	return out
}

// kernel4 builds the code that steers c.dfIdx through the height
// segments in c.segs.lhs and, inside each of them, through the width
// segments in lh.lws, emitting c.kernel5() followed by an increment of
// c.dfIdx at the innermost level.
//
// The work is split into nested closures, layer1 (outermost) through
// layer7 (innermost). They communicate through the variables declared
// in the var block below: each layer fills in the values the deeper
// layers read before invoking them.
func (c *consumeSums) kernel4() cgen.Gen {
var (
// Current height segment and the values derived from it (layer1).
lh *loopH
lhToH int
lhToStep int
// Generated variables declared by layer2.
rel cgen.Gen
base cgen.Gen
// Set by layer3; -1 when the height segment has no step.
relBreak int
// Current width segment and the values derived from it (layer4).
lw *loopW
lwToH int
lwToW int
lwToStep int
)
// layer7 emits the kernel5 body, an optional early return once
// c.dfIdx has reached c.dfLast, and the pre-increment of c.dfIdx.
layer7 := func() cgen.Gen {
var retIf cgen.Gen
if c.dfLast != nil {
retIf = cgen.If1{
Cond: cgen.CmpGE{
Expr1: c.dfIdx,
Expr2: c.dfLast,
},
Then: cgen.Return{},
}
}
return cgen.Stmts{
c.kernel5(),
retIf,
cgen.IncPre{
Expr: c.dfIdx,
},
}
}
// layer6 emits layer7 once when the width segment has no step;
// otherwise it wraps layer7 in a loop that runs while c.dfIdx <= last
// and advances c.toW by lwToStep after each pass.
layer6 := func() cgen.Gen {
if lwToStep == 0 {
return layer7()
}
last := vb(c.name("jj"))
return cgen.Stmts{
cgen.Var{
Type: cgen.PtrdiffT,
What: last,
Init: cgen.Add{
Expr1: cgen.Sub{
Expr1: il(lw.segPast - 1),
Expr2: rel,
},
Expr2: c.dfIdx,
},
},
cgen.For{
Cond: cgen.CmpLE{
Expr1: c.dfIdx,
Expr2: last,
},
Post: cgen.AddAssign{
Expr1: c.toW,
Expr2: il(lwToStep),
},
Body: layer7(),
},
}
}
// layer5 declares the c.toH and c.toW variables for the current width
// segment, publishes lw.lbs through c.lbs for the deeper kernels, and
// appends a break (out of layer3's loop) when this width segment ends
// at relBreak.
layer5 := func() cgen.Gen {
c.toH = vb(c.name("toH"))
c.toW = vb(c.name("toW"))
c.lbs = lw.lbs
var (
exprW cgen.Gen
breakIf cgen.Gen
)
switch lwToStep {
case 0:
exprW = il(lwToW)
default:
exprW = addMul(
il(lwToW-lwToStep*lw.segFirst),
il(lwToStep),
rel,
)
}
if lw.segPast == relBreak {
breakIf = cgen.If1{
Cond: cgen.CmpGE{
Expr1: c.dfIdx,
Expr2: il(lh.segPast),
},
Then: cgen.Break,
}
}
return cgen.Stmts{
cgen.Var{
Type: cgen.PtrdiffT,
What: c.toH,
Init: cgen.Add{
Expr1: base,
Expr2: il(lwToH),
},
},
cgen.Var{
Type: cgen.PtrdiffT,
What: c.toW,
Init: exprW,
},
layer6(),
breakIf,
}
}
// layer4 emits a decision tree over the width segments of lh, keyed on
// rel. Each If guards the lower half of the range and is followed
// (not else'd) by the upper half; every leaf except the last ends by
// assigning rel = lw.segPast.
layer4 := func() cgen.Gen {
var (
lws = lh.lws
tree func(int, int) cgen.Stmts
)
leaf := func(x int) cgen.Stmts {
// Select the width segment before layer5 reads it.
lw = lws[x]
lwToH = lw.fromH / 2
lwToW = lw.fromW / 2
lwToStep = lw.fromStep / 2
var assn cgen.Gen
if x+1 < len(lws) {
assn = cgen.Assign{
Expr1: rel,
Expr2: il(lw.segPast),
}
}
return cgen.Stmts{
layer5(),
assn,
}
}
tree = func(first, last int) cgen.Stmts {
if first == last {
return leaf(first)
}
// Split near the midpoint of the covered range; x becomes the
// first segment whose segPast lies beyond the split.
var (
start = lws[first].segFirst
stop = lws[last].segPast
split = start + (stop-start)/2
x = first + 1
)
for lws[x].segPast <= split {
x++
}
return cgen.Stmts{
cgen.If{
Cond: cgen.CmpL{
Expr1: rel,
Expr2: il(lws[x].segFirst),
},
Then: tree(first, x-1),
},
tree(x, last),
}
}
return tree(0, len(lws)-1)
}
// layer3 emits layer4 directly when the height segment has no step.
// Otherwise it wraps layer4 in a condition-less loop whose post
// statement resets rel to 0 and advances base by lhToStep; the loop is
// left through the break that layer5 emits at relBreak.
layer3 := func() cgen.Gen {
if lh.segStep == 0 {
relBreak = -1
return layer4()
}
var (
last1 = lh.segPast - lh.segFirst - 1
last2 = last1 % lh.segStep
)
relBreak = last2 + 1
return cgen.For{
Post: cgen.CommaSpaced{
cgen.Assign{
Expr1: rel,
Expr2: il(0),
},
cgen.AddAssign{
Expr1: base,
Expr2: il(lhToStep),
},
},
Body: layer4(),
}
}
// layer2 declares rel and base for the current height segment. Without
// a step, rel is c.dfIdx - lh.segFirst and base is lhToH. With a step,
// that difference is split into remainder (rel) and quotient (scaled
// by lhToStep and added to base).
layer2 := func() cgen.Gen {
rel = vb(c.name("rel"))
base = vb(c.name("base"))
var (
relExpr cgen.Gen = cgen.Sub{
Expr1: c.dfIdx,
Expr2: il(lh.segFirst),
}
baseExpr = il(lhToH)
)
if lh.segStep != 0 {
var (
numer cgen.Gen = cgen.Cast{
Type: cgen.SizeT,
Expr: cgen.Paren{
Inner: relExpr,
},
}
denom = il(lh.segStep)
)
relExpr = cgen.Rem{
Expr1: numer,
Expr2: denom,
}
baseExpr = addMul(
baseExpr,
cgen.Quo{
Expr1: numer,
Expr2: denom,
},
il(lhToStep),
)
}
return cgen.Stmts{
cgen.Var{
Type: cgen.PtrdiffT,
What: rel,
Init: relExpr,
},
cgen.Var{
Type: cgen.PtrdiffT,
What: base,
Init: baseExpr,
},
layer3(),
}
}
// layer1 emits a decision tree over the height segments, keyed on
// c.dfIdx, with the same shape as the tree in layer4. Every leaf
// except the last ends by assigning c.dfIdx = lh.segPast.
layer1 := func() cgen.Gen {
var (
lhs = c.segs.lhs
tree func(int, int) cgen.Stmts
)
leaf := func(x int) cgen.Stmts {
// Select the height segment before layer2 reads it.
lh = lhs[x]
lhToH = lh.fromH / 2
lhToStep = lh.fromStep / 2
var assn cgen.Gen
if x+1 < len(lhs) {
assn = cgen.Assign{
Expr1: c.dfIdx,
Expr2: il(lh.segPast),
}
}
return cgen.Stmts{
layer2(),
assn,
}
}
tree = func(first, last int) cgen.Stmts {
if first == last {
return leaf(first)
}
var (
start = lhs[first].segFirst
stop = lhs[last].segPast
split = start + (stop-start)/2
x = first + 1
)
for lhs[x].segPast <= split {
x++
}
return cgen.Stmts{
cgen.If{
Cond: cgen.CmpL{
Expr1: c.dfIdx,
Expr2: il(lhs[x].segFirst),
},
Then: tree(first, x-1),
},
tree(x, last),
}
}
return tree(0, len(lhs)-1)
}
return layer1()
}

// kernel5 builds the statements that drive c.wfIdx and emit c.kernel6().
//
// The returned cgen.Stmts always has four slots; unused slots stay nil:
//
//	[0] declaration of c.wfIdx, initialised to c.wfTile * c.wfCoord
//	[1] declaration of the last index (only when c.wfHull > 1)
//	[2] loop running while c.wfIdx != c.wfCores1, with c.wfShort == false
//	    (only when c.wfCores1 > 0)
//	[3] one more kernel6 with c.wfShort == true
//	    (only when c.wfCores1 < c.wfCores2)
//
// c.wfShort is set before each kernel6 call because kernel6 reads it
// while generating code.
func (c *consumeSums) kernel5() cgen.Gen {
	c.wfIdx = vb(c.name("k"))
	var (
		out      = make(cgen.Stmts, 4)
		earlyOut cgen.Gen
	)
	out[0] = cgen.Var{
		Type: cgen.PtrdiffT,
		What: c.wfIdx,
		Init: cgen.Mul{Expr1: il(c.wfTile), Expr2: c.wfCoord},
	}
	if c.wfHull > 1 {
		final := vb(c.name("kk"))
		// Distance from the first to the last index of this tile: a
		// full tile, the scrap tile, or a runtime choice between them.
		var reach cgen.Gen
		if c.wfTiles == c.wfHull {
			reach = il(c.wfTile - 1)
		} else if c.wfTiles == 0 {
			reach = il(c.wfScrap - 1)
		} else {
			reach = cgen.Paren{
				Inner: cgen.Ternary{
					Cond: cgen.CmpL{Expr1: c.wfCoord, Expr2: il(c.wfTiles)},
					Then: il(c.wfTile - 1),
					Else: il(c.wfScrap - 1),
				},
			}
		}
		out[1] = cgen.Var{
			Type: cgen.PtrdiffT,
			What: final,
			Init: cgen.Add{Expr1: c.wfIdx, Expr2: reach},
		}
		earlyOut = cgen.If1{
			Cond: cgen.CmpGE{Expr1: c.wfIdx, Expr2: final},
			Then: cgen.Return{},
		}
	}
	if c.wfCores1 > 0 {
		c.wfShort = false
		out[2] = cgen.For{
			Cond: cgen.CmpNE{Expr1: c.wfIdx, Expr2: il(c.wfCores1)},
			Post: cgen.IncPre{Expr: c.wfIdx},
			Body: cgen.Stmts{c.kernel6(), earlyOut},
		}
	}
	if c.wfCores1 < c.wfCores2 {
		c.wfShort = true
		out[3] = c.kernel6()
	}
	return out
}

// kernel6 hands off to the platform-specific generator. Only
// raw.AVX512Float32 is supported; any other platform is a programming
// error and panics.
func (c *consumeSums) kernel6() cgen.Gen {
	if c.platform != raw.AVX512Float32 {
		panic("bug")
	}
	return c.m512()
}

// m512 generates the AVX-512 body for one (c.dfIdx, c.wfIdx) step: it
// loads the sf inputs, runs the quadfft.Bwd transform, packs the
// resulting vectors, applies the c.To.Ops chain and stores the results
// through the data pointers from c.datSplit onward.
//
// The work is organised as nested closures, layer1 (outermost) through
// layer13 (innermost). They share the variables declared in the var
// block below; each layer sets what the deeper layers read before
// calling them.
func (c *consumeSums) m512() cgen.Gen {
// Span describes a run of lanes inside a vector and where that run
// belongs in the output.
type Span struct {
// vec is the vector variable holding the lanes.
vec cgen.Gen
// lane is the first lane of the run and lanes is its length.
lane int
lanes int
// relC, relH and relW are the offsets that addr scales by pitch2,
// pitch1 and c.datBytes respectively.
relC int
relH int
relW int
// prior marks a span already processed as the partner of an
// earlier span that shares its vec (see layer12).
prior bool
}
var (
// Per-block lookup built by layer1 from c.lbs.
lbs []*loopB
// Row loop state (layer2).
rowIdx cgen.Gen
rowChans int
// Batch-norm variables indexed [channel][bn op] (layer3).
bnMuls [][]cgen.Gen
bnAdds [][]cgen.Gen
// Current run of blocks and its repeat count (layer4).
blkFirst int
blkCnt int
iters int
// Iteration variable (layer5).
iterIdx cgen.Gen
// Blocks feeding each accumulator (layer6).
accBlks [][]int
// Backward transform and the spans cut from its outputs.
bwd *quadfft.Bwd
spans []*Span
)
// addr builds the byte address expression for span relative to
// c.datPtrs[ptr]. The constant part is rewound by span.lane elements
// so that the span's first lane lands on its own address.
addr := func(span *Span, ptr int) cgen.Gen {
var (
ae = c.datPtrs[ptr]
pitch1 = c.To.Pitch1Bytes[ptr]
pitch2 = c.To.Pitch2Bytes[ptr]
groupPitch = c.toChans * pitch2
corePitch = c.wfSliceFrags1 * pitch2
rowPitch = c.wfMeldFrags * pitch2
toStep = lbs[blkFirst].fromStep / 2
blkPitch = toStep * c.datBytes
iterPitch = blkCnt * blkPitch
)
ae = cgen.Add{
Expr1: ae,
Expr2: il(
span.relC*pitch2 +
span.relH*pitch1 +
span.relW*c.datBytes -
span.lane*c.datBytes,
),
}
ae = addMul(ae, il(groupPitch), c.groupIdx)
ae = addMul(ae, il(corePitch), c.wfIdx)
ae = addMul(ae, il(rowPitch), rowIdx)
ae = addMul(ae, il(pitch1), c.toH)
ae = addMul(ae, il(c.datBytes), c.toW)
ae = addMul(ae, il(iterPitch), iterIdx)
return ae
}
// mask returns a literal with span.lanes consecutive bits set,
// starting at bit span.lane.
mask := func(span *Span) cgen.Gen {
var (
mask1 = 1<<uint(span.lanes) - 1
mask2 = mask1 << uint(span.lane)
)
return il(mask2)
}
// layer13 emits one masked store per remaining span for every data
// pointer from c.datSplit to the end of c.datPtrs.
layer13 := func() cgen.Gen {
var (
stmts cgen.Stmts
ptr = c.datSplit
stop = len(c.datPtrs)
)
for ; ptr < stop; ptr++ {
for _, span := range spans {
if span == nil {
continue
}
stmts = append(
stmts,
avx.Mm512MaskStoreuPs{
addr(span, ptr),
mask(span),
span.vec,
},
)
}
}
return stmts
}
// layer12 applies the c.To.Ops chain once per distinct vector and
// then emits the stores. Spans that share a vec are handled together
// through co, which holds one or two spans.
layer12 := func() cgen.Gen {
var (
stmts cgen.Stmts
co = make([]*Span, 1, 2)
)
stmt := func(st cgen.Gen) {
stmts = append(stmts, st)
}
ops := func() {
var (
vec = co[0].vec
datPtr = 0
bnPtr = 0
)
for op := range c.To.Ops {
op := &c.To.Ops[op]
switch op.Kind {
case mod.Add:
// Add op.Int operands, each read with a masked load per
// span from successive data pointers.
for n := op.Int; n > 0; n-- {
for _, span := range co {
stmt(cgen.Assign{
Expr1: vec,
Expr2: avx.Mm512AddPs{
vec,
avx.Mm512MaskzLoaduPs{
mask(span),
addr(span, datPtr),
},
},
})
}
datPtr++
}
case mod.Bn:
switch {
case len(co) == 1:
// fallthrough enters the next case body without
// evaluating its condition, so co[1] is not touched.
fallthrough
case co[0].relC == co[1].relC:
// One channel for the whole vector: unmasked apply.
ch := co[0].relC
stmt(&bn.Apply{
Ctx: c.bc,
Mul: bnMuls[ch][bnPtr],
Add: bnAdds[ch][bnPtr],
To: vec,
})
default:
// Two channels share the vector: one masked apply each.
for _, span := range co {
ch := span.relC
stmt(&bn.Apply{
Ctx: c.bc,
Mul: bnMuls[ch][bnPtr],
Add: bnAdds[ch][bnPtr],
To: vec,
Mask: mask(span),
})
}
}
bnPtr++
case mod.ReLU:
stmt(&act.ReLU{
Ctx: c.ac,
NegSlope: op.Float,
Var: vec,
})
default:
panic("bug")
}
}
}
for x1, span1 := range spans {
if span1 == nil ||
span1.prior {
continue
}
co[0] = span1
// Pair span1 with the first later span using the same vector and
// mark that partner so the outer loop skips it.
for x2 := x1 + 1; x2 < len(spans); x2++ {
span2 := spans[x2]
if span2 != nil &&
span2.vec == span1.vec {
co = co[:2]
co[1] = span2
span2.prior = true
break
}
}
ops()
co = co[:1]
}
stmt(layer13())
return stmts
}
// layer11 runs only when rowChans == c.wfMeldFrags; otherwise it
// defers straight to layer12. It merges pairs of spans that have the
// same relC and relH and are adjacent in W into one packed vector.
layer11 := func() cgen.Gen {
if rowChans != c.wfMeldFrags {
return layer12()
}
var (
stmts cgen.Stmts
pms [2]cgen.Gen
)
stmt := func(st cgen.Gen) {
stmts = append(stmts, st)
}
// ctrl returns the permutation index vector for the given side,
// declaring it on first use. The vector is cached per side only, so
// later calls reuse it regardless of the lanes argument.
ctrl := func(side, lanes int) cgen.Gen {
pm := pms[side]
if pm == nil {
pm = vb(c.name("pm"))
var (
set = make(avx.Mm512SetEpi32, 16)
base = side * (16 + 8)
off = 0
)
for x := 15; x >= 0; x-- {
set[x] = il(base + off)
if off++; off == lanes {
base ^= 16
off = 0
}
}
stmt(cgen.Var{
Type: avx.M512i, What: pm,
Init: set,
})
pms[side] = pm
}
return pm
}
for x1, span1 := range spans {
if span1 == nil {
continue
}
var (
x2 = x1 + 1
span2 *Span
pm cgen.Gen
)
for ; x2 < len(spans); x2++ {
span2 = spans[x2]
if span2 == nil ||
span2.relC != span1.relC ||
span2.relH != span1.relH {
continue
}
switch {
case span1.relW+span1.lanes == span2.relW:
// span2 continues span1 to the right.
pm = ctrl(0, span1.lanes)
case span2.relW+span2.lanes == span1.relW:
// span2 sits to the left; the merged span starts at its relW.
pm = ctrl(1, span2.lanes)
span1.relW = span2.relW
default:
continue
}
break
}
if x2 == len(spans) {
// No partner found for span1.
continue
}
pack := vb(c.name("pack"))
stmt(cgen.Var{
Type: avx.M512, What: pack,
Init: avx.Mm512Permutex2varPs{
span1.vec, pm,
span2.vec,
},
})
// span1 now describes the merged run; span2 is dropped.
span1.vec = pack
span1.lane = 0
span1.lanes += span2.lanes
spans[x2] = nil
}
stmt(layer12())
return stmts
}
// layer10 runs only when rowChans == 1; otherwise it defers straight
// to layer11. It gathers spans on the same row that continue each
// other in W, as long as they fit in 16 lanes, into one packed vector.
layer10 := func() cgen.Gen {
if rowChans != 1 {
return layer11()
}
var (
stmts cgen.Stmts
pms [2]cgen.Gen
attach []int
)
stmt := func(st cgen.Gen) {
stmts = append(stmts, st)
}
// ctrl returns the permutation index vector for the given side,
// declaring it on first use and caching it per side.
ctrl := func(side, lanes int) cgen.Gen {
pm := pms[side]
if pm == nil {
pm = vb(c.name("pm"))
var (
set = make(avx.Mm512SetEpi32, 16)
base = side * 8
off = 0
)
for x := 15; x >= 0; x-- {
set[x] = il(base + off)
if off++; off == lanes {
base += 8
base %= 32
off = 0
}
}
stmt(cgen.Var{
Type: avx.M512i, What: pm,
Init: set,
})
pms[side] = pm
}
return pm
}
for x1, span1 := range spans {
if span1 == nil {
continue
}
var (
lanes = span1.lanes
nextW = span1.relW + lanes
avail = 16 - lanes
)
// Collect the indices of later spans that extend span1.
for x2 := x1 + 1; x2 < len(spans); x2++ {
span2 := spans[x2]
if span2 != nil &&
span2.relH == span1.relH &&
span2.relW == nextW &&
span2.lanes <= avail {
attach = append(attach, x2)
nextW += span2.lanes
avail -= span2.lanes
}
}
n := len(attach)
if n == 0 {
continue
}
var (
pack = vb(c.name("pack"))
vec1 = span1.vec
vec2 = spans[attach[n-1]].vec
side = btoi(span1.lane != 0)
pm = ctrl(side, lanes)
expr cgen.Gen
)
// A single-source permute suffices when the last attached span
// lives in the same vector as span1.
switch vec1 {
case vec2:
expr = avx.Mm512PermutexvarPs{
pm, vec1,
}
default:
expr = avx.Mm512Permutex2varPs{
vec1, pm, vec2,
}
}
stmt(cgen.Var{
Type: avx.M512, What: pack,
Init: expr,
})
span1.vec = pack
span1.lane = 0
span1.lanes = 16 - avail
for _, x2 := range attach {
spans[x2] = nil
}
// Reuse the backing array for the next span1.
attach = attach[:0]
}
stmt(layer11())
return stmts
}
// layer9 rebuilds spans from the non-nil bwd.Out vectors: one span
// per (row h, accumulator, side) whose block yields that row.
layer9 := func() cgen.Gen {
spans = spans[:0]
for h := 0; h < 8; h++ {
for acc, blks := range accBlks {
dat := bwd.Out[acc*8+h]
if dat == nil {
continue
}
for side, blk := range blks {
lb := lbs[blk]
if lb.yieldH <= h {
continue
}
var (
relBlk = blk - lb.blkFirst
toStep = lb.fromStep / 2
w = relBlk * toStep
)
spans = append(
spans, &Span{
vec: dat,
lane: side * 8,
lanes: lb.yieldW,
relC: side % rowChans,
relH: lb.fromH/2 + h,
relW: lb.fromW/2 + w,
prior: false,
},
)
}
}
}
return layer10()
}
// layer8 declares every bwd.In variable, emits the transform itself
// and then continues with layer9.
layer8 := func() cgen.Gen {
var stmts cgen.Stmts
stmt := func(st cgen.Gen) {
stmts = append(stmts, st)
}
// load builds the unaligned load of one input vector relative to
// c.sfPtr. The pitches depend on whether lbs covers
// c.dfSliceFrags2 blocks and on whether rowChans == 1.
load := func(pile, acc, part int) cgen.Gen {
var (
ae = c.sfPtr
sitePitch = c.sfSiteBytes11
rowPitch = c.sfRowBytes11
meldPitch = c.sfMeldBytes11
)
if len(lbs) == c.dfSliceFrags2 {
sitePitch = c.sfSiteBytes12
rowPitch = c.sfRowBytes12
}
if rowChans == 1 {
meldPitch = c.sfMeldBytes21
}
var (
meldFirst = blkFirst / c.dfMeldFrags
iterMelds = blkCnt / c.dfMeldFrags
iterPitch = iterMelds * meldPitch
accPitch = c.wfMeldFrags * c.sfFragBytes
partPitch = accPitch / 2
)
ae = cgen.Add{
Expr1: ae,
Expr2: il(
pile*c.sfPileBytes +
meldFirst*meldPitch +
acc*accPitch +
part*partPitch,
),
}
ae = addMul(ae, il(c.sfGroupBytes), c.groupIdx)
ae = addMul(ae, il(c.sfCoreBytes1), c.dfIdx)
ae = addMul(ae, il(sitePitch), c.wfIdx)
ae = addMul(ae, il(rowPitch), rowIdx)
ae = addMul(ae, il(iterPitch), iterIdx)
return avx.Mm512LoaduPs{ae}
}
for pile := 0; pile < c.zoneFrags; pile++ {
for acc, blks := range accBlks {
for part := 0; part < 2; part++ {
var (
x1 = acc * c.zoneFrags * 2
x2 = x1 + pile*2 + part
sf1 = bwd.In[x2]
expr cgen.Gen
)
switch {
case part == 1 && len(blks) == 1:
// With a single block in the accumulator the second part
// is derived from the first by a shuffle, not loaded.
var (
sf2 = bwd.In[x2-1]
ctrl = 1<<6 | 0<<4 | 3<<2 | 2<<0
)
expr = avx.Mm512ShuffleF32x4{
sf2, sf2, il(ctrl),
}
default:
expr = load(pile, acc, part)
}
stmt(cgen.Var{
Type: avx.M512, What: sf1,
Init: expr,
})
}
}
}
stmt(bwd)
stmt(layer9())
return stmts
}
// layer7 creates the quadfft.Bwd node and names its variables: In
// alternates between sfRe and sfIm, and Out gets a dat variable for
// each row up to the largest yieldH among the accumulator's blocks.
layer7 := func() cgen.Gen {
bwd = &quadfft.Bwd{
Platform: c.platform,
Nms: c.nms,
}
var (
accs = len(accBlks)
each = c.zoneFrags * 2
)
for x := 0; x < accs*each; x++ {
var sf cgen.Gen
switch x % 2 {
case 0:
sf = vb(c.name("sfRe"))
default:
sf = vb(c.name("sfIm"))
}
bwd.In[x] = sf
}
for acc, blks := range accBlks {
yieldH := 0
for _, blk := range blks {
yieldH = max(
yieldH,
lbs[blk].yieldH,
)
}
for h := 0; h < yieldH; h++ {
var (
x = acc*each + h
dat = vb(c.name("dat"))
)
bwd.Out[x] = dat
}
}
return layer8()
}
// layer6 decides which blocks feed each accumulator for the current
// run of blkCnt blocks starting at blkFirst.
layer6 := func() cgen.Gen {
switch {
case rowChans == 1:
// Two consecutive blocks per accumulator; an odd final block
// gets an accumulator to itself.
var (
quo = blkCnt / 2
rem = blkCnt % 2
)
accBlks = make([][]int, quo+rem)
for acc := range accBlks {
var (
blk1 = blkFirst + acc*2
blk2 = blk1 + 1
blks = []int{blk1, blk2}
)
if acc == quo {
blks = blks[:1]
}
accBlks[acc] = blks
}
case blkCnt == 1:
// The single block appears on both sides of one accumulator.
accBlks = [][]int{
{blkFirst, blkFirst},
}
default:
// Two accumulators with the same two blocks in swapped order.
accBlks = [][]int{
{blkFirst, blkFirst + 1},
{blkFirst + 1, blkFirst},
}
}
return layer7()
}
// layer5 declares iterIdx (starting at 0) and emits layer6 once, or
// inside a counting loop when iters != 1.
layer5 := func() cgen.Gen {
iterIdx = vb(c.name("t"))
return cgen.Stmts{
cgen.Var{
Type: cgen.PtrdiffT,
What: iterIdx,
Init: il(0),
},
func() cgen.Gen {
if iters == 1 {
return layer6()
}
return cgen.For{
Cond: cgen.CmpL{
Expr1: iterIdx,
Expr2: il(iters),
},
Post: cgen.IncPre{
Expr: iterIdx,
},
Body: layer6(),
}
}(),
}
}
// layer4 walks lbs in runs of `each` blocks. Consecutive runs that
// stay within the same *loopB as the first run are folded into one
// layer5 with iters > 1; a final partial run gets its own layer5.
layer4 := func() cgen.Gen {
var gens cgen.Gens
gen := func() {
gens = append(gens, layer5())
}
each := c.dfMeldFrags
if rowChans == 1 {
each *= 2
}
for blk := 0; ; {
rem := len(lbs) - blk
if rem <= each {
if rem > 0 {
blkFirst = blk
blkCnt = rem
iters = 1
gen()
}
break
}
blkFirst = blk
blkCnt = each
iters = 1
for lb := lbs[blk]; ; {
blk += each
if blk+each > len(lbs) ||
lbs[blk+each-1] != lb {
break
}
iters++
}
gen()
}
return gens
}
// layer3 emits the batch-norm loads, one per (bn pointer, channel of
// the row), ahead of layer4. With no bn pointers it is just layer4.
layer3 := func() cgen.Gen {
n := len(c.bnPtrs)
if n == 0 {
return layer4()
}
bnMuls = make([][]cgen.Gen, rowChans)
bnAdds = make([][]cgen.Gen, rowChans)
var (
last = n * rowChans
gens = make(cgen.Gens, last+1)
)
for ch1 := 0; ch1 < rowChans; ch1++ {
var (
muls = make([]cgen.Gen, n)
adds = make([]cgen.Gen, n)
ch2 = il(ch1)
)
// ch2 is the channel index expression passed to bn.Load.
ch2 = addMul(ch2, il(c.toChans), c.groupIdx)
ch2 = addMul(ch2, il(c.wfSliceFrags1), c.wfIdx)
ch2 = addMul(ch2, il(c.wfMeldFrags), rowIdx)
ch2 = cgen.Paren{
Inner: ch2,
}
for x1, ptr := range c.bnPtrs {
var (
bnMul = vb(c.name("bnMul"))
bnAdd = vb(c.name("bnAdd"))
x2 = x1*rowChans + ch1
)
muls[x1] = bnMul
adds[x1] = bnAdd
gens[x2] = &bn.Load{
Ctx: c.bc,
Mas: ptr,
Channel: ch2,
Mul: bnMul,
Add: bnAdd,
}
}
bnMuls[ch1] = muls
bnAdds[ch1] = adds
}
gens[last] = layer4()
return gens
}
// layer2 declares rowIdx and emits a loop over the first rows1 rows
// with rowChans == c.wfMeldFrags, followed, when rows1 < rows2, by one
// more layer3 with rowChans == 1. The row counts depend on c.wfShort.
layer2 := func() cgen.Gen {
rowIdx = vb(c.name("r"))
var (
stmts = make(cgen.Stmts, 3)
rows1 int
rows2 int
)
stmts[0] = cgen.Var{
Type: cgen.PtrdiffT,
What: rowIdx,
Init: il(0),
}
switch {
case c.wfShort:
rows2 = c.wfSliceMelds2
rows1 = rows2 - c.wfSliceFrags2%c.wfMeldFrags
default:
rows2 = c.wfSliceMelds1
rows1 = rows2
}
if rows1 > 0 {
rowChans = c.wfMeldFrags
stmts[1] = cgen.For{
Cond: cgen.CmpNE{
Expr1: rowIdx,
Expr2: il(rows1),
},
Post: cgen.IncPre{
Expr: rowIdx,
},
Body: layer3(),
}
}
if rows1 < rows2 {
rowChans = 1
stmts[2] = layer3()
}
return stmts
}
// layer1 expands c.lbs into lbs, a slice indexed by block number in
// which every block from blkFirst up to blkPast points at its loopB.
layer1 := func() cgen.Gen {
var (
n1 = len(c.lbs)
n2 = c.lbs[n1-1].blkPast
)
lbs = make([]*loopB, n2)
for _, lb := range c.lbs {
blk := lb.blkFirst
for ; blk < lb.blkPast; blk++ {
lbs[blk] = lb
}
}
return layer2()
}
return layer1()
}

Top || internal/compile/author/sumr/sumr.go

package sumr

import (
"NN-512/internal/compile/author/avx"
"NN-512/internal/compile/author/cgen"
"NN-512/internal/nmsrc"
"NN-512/internal/raw"
)

// il wraps a Go int as a cgen integer literal.
func il(i int) cgen.Gen {
	var lit cgen.Gen = cgen.IntLit(i)
	return lit
}

// mix interleaves two statement lists: a[0], b[0], a[1], b[1], and so
// on. Once the shorter list is exhausted the remainder of the longer
// one follows in order. The result is always a fresh, non-nil slice of
// length len(a)+len(b).
func mix(a, b cgen.Stmts) cgen.Stmts {
	longest := len(a)
	if len(b) > longest {
		longest = len(b)
	}
	merged := make(cgen.Stmts, 0, len(a)+len(b))
	for i := 0; i < longest; i++ {
		if i < len(a) {
			merged = append(merged, a[i])
		}
		if i < len(b) {
			merged = append(merged, b[i])
		}
	}
	return merged
}

// Pack is a code generator over a list of vector variables. Its Append
// method emits the platform-specific packing code for Vars.
type Pack struct {
// Platform selects the implementation; only raw.AVX512Float32 is
// handled by Append.
Platform raw.Platform
// Nms supplies the names for generated variables.
Nms nmsrc.Src
// Vars holds the vectors to combine. Append emits nothing when it is
// empty and panics when it has more than 16 entries.
Vars []cgen.Gen
}

// Append writes the generated code for p to the end of to and returns
// the extended buffer. It panics on any platform other than
// raw.AVX512Float32.
func (p *Pack) Append(to []byte) []byte {
	if p.Platform != raw.AVX512Float32 {
		panic("bug")
	}
	var impl cgen.Gen = &m512Pack{Pack: p}
	return impl.Append(to)
}

// name asks p.Nms for a name based on s and returns it as a cgen
// variable reference.
func (p *Pack) name(s string) cgen.Gen {
	ident := p.Nms.Name(s)
	return cgen.Vb(ident)
}

// m512Pack is the AVX-512 implementation behind Pack.
//
// The pm* fields name permutation index vectors. Each stays nil until
// Append or fold needs it; pms then declares only the non-nil ones.
type m512Pack struct {
*Pack
// Indices 0,2,4,... and 1,3,5,... (set in Append when 3 <= n <= 8).
pmEven cgen.Gen
pmOdd cgen.Gen
// Two-source indices used by fold at width 1.
pm1Lo cgen.Gen
pm1Hi cgen.Gen
// Two-source indices used by fold at width 4.
pm4Lo cgen.Gen
pm4Hi cgen.Gen
}

// Append emits the packing code for m.Vars.
//
// Nothing is emitted for an empty list, and more than 16 vectors is a
// programming error. For exactly one vector, or for more than eight,
// the fold starts at width 1. Otherwise it starts at width 2 and a
// final step splits m.Vars[0] into two halves (via shuffles for two
// vectors, via the pmOdd/pmEven permutations otherwise) and adds them.
//
// The permutation declarations are generated last but placed first,
// because only after folding is it known which ones are needed.
func (m *m512Pack) Append(to []byte) []byte {
	count := len(m.Vars)
	if count == 0 {
		return to
	}
	if count > 16 {
		panic("bug")
	}
	parts := make(cgen.Gens, 3)
	if count == 1 || count > 8 {
		parts[1] = m.fold(m.Vars, 1)
	} else {
		parts[1] = m.fold(m.Vars, 2)
		lo := m.Vars[0]
		hi := m.name("upper")
		var hiInit, loInit cgen.Gen
		if count == 2 {
			hiInit = avx.Mm512ShufflePs{
				lo, lo, il(3<<2 | 1),
			}
			loInit = avx.Mm512ShufflePs{
				lo, lo, il(2<<2 | 0),
			}
		} else {
			m.pmOdd = m.name("pmOdd")
			hiInit = avx.Mm512PermutexvarPs{
				m.pmOdd, lo,
			}
			m.pmEven = m.name("pmEven")
			loInit = avx.Mm512PermutexvarPs{
				m.pmEven, lo,
			}
		}
		parts[2] = cgen.Stmts{
			cgen.Var{Type: avx.M512, What: hi, Init: hiInit},
			cgen.Assign{Expr1: lo, Expr2: loInit},
			cgen.Assign{
				Expr1: lo,
				Expr2: avx.Mm512AddPs{lo, hi},
			},
		}
	}
	parts[0] = m.pms()
	return parts.Append(to)
}

// fold emits one level of a reduction over the vectors in vs at width
// w and, while w < 8, first recurses with width w*2. The result always
// has four slots:
//
//	[0] the deeper levels (nil when w == 8)
//	[1] declaration of a fresh "upper" vector
//	[2] reassignment of vs[0] (nil when len(vs) == 1)
//	[3] vs[0] = vs[0] + upper
//
// With a single vector, upper is a shuffle of vs[0] with itself chosen
// by w. With several, the even- and odd-indexed vectors are folded
// separately and interleaved with mix, and then vs[0] and vs[1] are
// recombined into the new vs[0] and upper.
// NOTE(review): this looks like a horizontal-sum tree that ends with
// all partial sums in vs[0]; confirm against the callers.
func (m *m512Pack) fold(vs []cgen.Gen, w int) cgen.Stmts {
var (
stmts = make(cgen.Stmts, 4)
lower = vs[0]
upper = m.name("upper")
)
decl := cgen.Var{
Type: avx.M512, What: upper,
}
if n := len(vs); n == 1 {
if w < 8 {
stmts[0] = m.fold(vs, w*2)
}
// Pick the shuffle that matches the current width.
switch w {
case 1:
decl.Init = avx.Mm512ShufflePs{
lower, lower, il(1),
}
case 2:
decl.Init = avx.Mm512ShufflePs{
lower, lower, il(3<<2 | 2),
}
case 4:
decl.Init = avx.Mm512ShuffleF32x4{
lower, lower, il(1),
}
case 8:
decl.Init = avx.Mm512ShuffleF32x4{
lower, lower, il(3<<2 | 2),
}
}
} else {
if w < 8 {
// Split vs by index parity; vs1 receives the extra element when
// n is odd.
var (
n2 = n >> 1
n1 = n - n2
vs1 = make([]cgen.Gen, n1)
vs2 = make([]cgen.Gen, n2)
)
for i, v := range vs {
if i&1 == 0 {
vs1[i>>1] = v
} else {
vs2[i>>1] = v
}
}
stmts[0] = mix(
m.fold(vs1, w*2),
m.fold(vs2, w*2),
)
}
// Only vs[0] and vs[1] are combined at this level.
v := vs[1]
assn := cgen.Assign{
Expr1: lower,
}
switch w {
case 1:
// The width-1 permutation names are allocated on first use;
// pms declares them later.
if m.pm1Lo == nil {
m.pm1Lo = m.name("pm1Lo")
m.pm1Hi = m.name("pm1Hi")
}
decl.Init = avx.Mm512Permutex2varPs{
lower, m.pm1Hi, v,
}
assn.Expr2 = avx.Mm512Permutex2varPs{
lower, m.pm1Lo, v,
}
case 2:
decl.Init = avx.Mm512ShufflePs{
lower, v, il(3<<6 | 2<<4 | 3<<2 | 2),
}
assn.Expr2 = avx.Mm512ShufflePs{
lower, v, il(1<<6 | 0<<4 | 1<<2 | 0),
}
case 4:
// Same lazy allocation for the width-4 permutations.
if m.pm4Lo == nil {
m.pm4Lo = m.name("pm4Lo")
m.pm4Hi = m.name("pm4Hi")
}
decl.Init = avx.Mm512Permutex2varPs{
lower, m.pm4Hi, v,
}
assn.Expr2 = avx.Mm512Permutex2varPs{
lower, m.pm4Lo, v,
}
case 8:
decl.Init = avx.Mm512ShuffleF32x4{
lower, v, il(3<<6 | 2<<4 | 3<<2 | 2),
}
assn.Expr2 = avx.Mm512ShuffleF32x4{
lower, v, il(1<<6 | 0<<4 | 1<<2 | 0),
}
}
stmts[2] = assn
}
stmts[1] = decl
stmts[3] = cgen.Assign{
Expr1: lower,
Expr2: avx.Mm512AddPs{
lower, upper,
},
}
return stmts
}

// pms declares the permutation index vectors that were requested while
// generating the fold. The result has one slot per possible vector, in
// a fixed order (even, odd, 1Lo, 1Hi, 4Lo, 4Hi); the slot of a vector
// that was never requested stays nil.
//
// Each vector has 16 entries. Entry i, counted from the last argument
// of the set expression, is computed by the lane function of its row.
func (m *m512Pack) pms() cgen.Gen {
	table := []struct {
		pm   cgen.Gen
		lane func(i int) int
	}{
		{m.pmEven, func(i int) int { return 2 * i }},
		{m.pmOdd, func(i int) int { return 2*i + 1 }},
		{m.pm1Lo, func(i int) int { return (i &^ 1) + (i&1)*16 }},
		{m.pm1Hi, func(i int) int { return (i | 1) + (i&1)*16 }},
		{m.pm4Lo, func(i int) int { return (i &^ 4) + (i&4)*4 }},
		{m.pm4Hi, func(i int) int { return (i | 4) + (i&4)*4 }},
	}
	decls := make(cgen.Stmts, len(table))
	for slot, row := range table {
		if row.pm == nil {
			continue
		}
		idx := make(avx.Mm512SetEpi32, 16)
		for i := 0; i < 16; i++ {
			idx[15-i] = il(row.lane(i))
		}
		decls[slot] = cgen.Var{
			Type: avx.M512i, What: row.pm,
			Init: idx,
		}
	}
	return decls
}

Top || internal/compile/author/threader/threader.go

package threader

import (
"NN-512/internal/compile/author/cgen"
"NN-512/internal/compile/author/errmsg"
"NN-512/internal/compile/plan"
"NN-512/internal/nmsrc"
"NN-512/internal/raw"
)

// maxNd is the length of the fixed-size hull and pt arrays declared in
// the generated task and node structs, and of the local pt array in
// the generated worker code.
const maxNd = 4

// vb turns a name into a cgen variable reference.
func vb(a string) cgen.Gen {
	var ref cgen.Gen = cgen.Vb(a)
	return ref
}

// Ctx holds the names and settings shared by the threader code
// generators. All fields are filled in by NewCtx.
type Ctx struct {
// prefix is the configured prefix followed by "Threader"; nameP
// prepends it to the names it requests.
prefix string
// cacheLine is used as the alignment of the generated node struct.
cacheLine int
nms nmsrc.Src
emc *errmsg.Ctx
// Name of the generated task struct and of its fields.
taskType string
taskCallee string
taskAny string
taskNd string
taskHull string
// Names of the generated team struct and of generated functions.
teamType string
destroy string
create string
pthreadT string
do string
// Pointer-to-task and pointer-to-team type expressions.
ptrTask cgen.Gen
PtrTeam cgen.Gen
}

// NewCtx creates the threader context for plan pl, requesting every
// shared name from nms. It panics when the plan's platform is not
// raw.AVX512Float32.
//
// The names are requested in a fixed sequence (task type, task fields,
// team type, then the function names), matching the declaration order
// of the corresponding Ctx fields.
func NewCtx(pl *plan.Plan, nms nmsrc.Src, emc *errmsg.Ctx) *Ctx {
	if pl.Config.Platform != raw.AVX512Float32 {
		panic("bug")
	}
	c := new(Ctx)
	c.prefix = pl.Config.Prefix + "Threader"
	c.cacheLine = 1 << 6
	c.nms = nms
	c.emc = emc
	c.taskType = nms.Name(c.prefix + "Task")
	c.taskCallee = nms.Name("callee")
	c.taskAny = nms.Name("any")
	c.taskNd = nms.Name("nd")
	c.taskHull = nms.Name("hull")
	c.teamType = nms.Name(c.prefix + "Team")
	c.destroy = nms.Name(c.prefix + "Destroy")
	c.create = nms.Name(c.prefix + "Create")
	c.pthreadT = nms.Name(c.prefix + "PthreadT")
	c.do = nms.Name(c.prefix + "Do")
	c.ptrTask = cgen.Ptr{Type: vb(c.taskType)}
	c.PtrTeam = cgen.Ptr{Type: vb(c.teamType)}
	return c
}

// name requests a name based on a from the context's name source.
func (c *Ctx) name(a string) string {
	ident := c.nms.Name(a)
	return ident
}

// nameP is like name but puts the context's prefix in front of a.
func (c *Ctx) nameP(a string) string {
	prefixed := c.prefix + a
	return c.name(prefixed)
}

// must is a call emitted as the condition of an otherwise empty loop,
// wrapped in cgen.Unlikely.
type must cgen.Call

// Append writes the loop around the call to the end of to.
func (m must) Append(to []byte) []byte {
	call := cgen.Call(m)
	loop := cgen.For{
		Cond: cgen.Unlikely{Cond: call},
	}
	return loop.Append(to)
}

// lock emits a pthread mutex lock on the field named mut of the struct
// that ptr points to.
type lock struct {
	ptr cgen.Gen
	mut string
}

// Append writes the lock call, wrapped in must, to the end of to.
func (l lock) Append(to []byte) []byte {
	target := cgen.AddrArrow{Expr: l.ptr, Name: l.mut}
	call := must{Func: cgen.PthreadMutexLock, Args: target}
	return call.Append(to)
}

// unlock emits a pthread mutex unlock on the field named mut of the
// struct that ptr points to.
type unlock struct {
	ptr cgen.Gen
	mut string
}

// Append writes the unlock call, wrapped in must, to the end of to.
func (u unlock) Append(to []byte) []byte {
	target := cgen.AddrArrow{Expr: u.ptr, Name: u.mut}
	call := must{Func: cgen.PthreadMutexUnlock, Args: target}
	return call.Append(to)
}

// wait emits a pthread condition wait on the fields named cond and mut
// of the struct that ptr points to.
type wait struct {
	ptr  cgen.Gen
	cond string
	mut  string
}

// Append writes the wait call, wrapped in must, to the end of to.
func (w wait) Append(to []byte) []byte {
	args := cgen.CommaSpaced{
		cgen.AddrArrow{Expr: w.ptr, Name: w.cond},
		cgen.AddrArrow{Expr: w.ptr, Name: w.mut},
	}
	call := must{Func: cgen.PthreadCondWait, Args: args}
	return call.Append(to)
}

// signal emits a pthread condition signal on the field named cond of
// the struct that ptr points to.
type signal struct {
	ptr  cgen.Gen
	cond string
}

// Append writes the signal call, wrapped in must, to the end of to.
func (s signal) Append(to []byte) []byte {
	target := cgen.AddrArrow{Expr: s.ptr, Name: s.cond}
	call := must{Func: cgen.PthreadCondSignal, Args: target}
	return call.Append(to)
}

// join emits a pthread join on the thread stored in the field named
// thr of the struct that ptr points to. The second argument of the
// call is cgen.Zero.
type join struct {
	ptr cgen.Gen
	thr string
}

// Append writes the join call, wrapped in must, to the end of to.
func (j join) Append(to []byte) []byte {
	args := cgen.CommaSpaced{
		cgen.Arrow{Expr: j.ptr, Name: j.thr},
		cgen.Zero,
	}
	call := must{Func: cgen.PthreadJoin, Args: args}
	return call.Append(to)
}

// destroyMut emits a pthread mutex destroy on the field named mut of
// the struct that ptr points to.
type destroyMut struct {
	ptr cgen.Gen
	mut string
}

// Append writes the destroy call, wrapped in must, to the end of to.
func (d destroyMut) Append(to []byte) []byte {
	target := cgen.AddrArrow{Expr: d.ptr, Name: d.mut}
	call := must{Func: cgen.PthreadMutexDestroy, Args: target}
	return call.Append(to)
}

// destroyCond emits a pthread condition destroy on the field named
// cond of the struct that ptr points to.
type destroyCond struct {
	ptr  cgen.Gen
	cond string
}

// Append writes the destroy call, wrapped in must, to the end of to.
func (d destroyCond) Append(to []byte) []byte {
	target := cgen.AddrArrow{Expr: d.ptr, Name: d.cond}
	call := must{Func: cgen.PthreadCondDestroy, Args: target}
	return call.Append(to)
}

// round emits an expression of the form (expr + (pow2-1)) & -pow2.
// NOTE(review): presumably pow2 is a power of two, making this a
// round-up to a multiple of pow2; confirm at the call sites.
type round struct {
	expr cgen.Gen
	pow2 cgen.IntLit
}

// Append writes the expression to the end of to.
func (r round) Append(to []byte) []byte {
	bumped := cgen.Paren{
		Inner: cgen.Add{
			Expr1: r.expr,
			Expr2: r.pow2 - 1,
		},
	}
	masked := cgen.And{
		Expr1: bumped,
		Expr2: -r.pow2,
	}
	return masked.Append(to)
}

// Prep generates the threader's preparatory code: its Append method
// runs stage1 through stage9 in order, each appending to the output
// buffer. All fields below the embedded Ctx are filled in by stage1.
type Prep struct {
*Ctx
// to is the output buffer the stages append to.
to []byte
// maxNd is the maxNd constant as an integer literal.
maxNd cgen.Gen
// calleeType is the function-pointer typedef for task callees.
calleeType cgen.Gen
// Name of the hub struct and of its fields.
hubType string
hubMut string
hubCond string
hubPending string
hubOffset string
hubMask string
hubStatus string
// Name of the node struct and of its fields.
nodeType string
nodeMut string
nodeNp string
nodePt string
nodeTask string
nodeCond string
nodeTeam string
nodeThr string
// Name of the unwind struct and of its fields.
unwindType string
unwindJoin string
unwindNodeConds string
unwindNodeMuts string
unwindHubCond string
unwindHubMut string
unwindNodes string
unwindHub string
// Names of the team struct's fields.
teamNt string
teamHub string
teamNodes string
teamUnwind string
// Pointer-to-hub and pointer-to-node type expressions.
ptrHub cgen.Gen
ptrNode cgen.Gen
// Names of the generated static helper functions.
inc string
put string
add string
main string
}

// Append runs the nine generation stages in order, each appending its
// output after to, and returns the resulting buffer.
func (p *Prep) Append(to []byte) []byte {
	p.to = to
	stages := []func(){
		p.stage1, p.stage2, p.stage3,
		p.stage4, p.stage5, p.stage6,
		p.stage7, p.stage8, p.stage9,
	}
	for _, run := range stages {
		run()
	}
	return p.to
}

// newline appends cgen.Newline to the output buffer.
func (p *Prep) newline() {
	out := cgen.Newline.Append(p.to)
	p.to = out
}

// stage1 emits no code. It requests every name the later stages use
// and stores it in p: struct type names and function names through
// nameP (which adds the context prefix), field names through name.
//
// NOTE(review): the requests are issued in a fixed sequence, and some
// base names ("mut", "cond", "nodes", "hub") are requested more than
// once; whether the name source returns distinct results for those is
// not visible here.
func (p *Prep) stage1() {
p.maxNd = cgen.IntLit(maxNd)
p.calleeType = vb(p.nameP("Callee"))
// Hub struct and its fields.
p.hubType = p.nameP("Hub")
p.hubMut = p.name("mut")
p.hubCond = p.name("cond")
p.hubPending = p.name("pending")
p.hubOffset = p.name("offset")
p.hubMask = p.name("mask")
p.hubStatus = p.name("status")
// Node struct and its fields.
p.nodeType = p.nameP("Node")
p.nodeMut = p.name("mut")
p.nodeNp = p.name("np")
p.nodePt = p.name("pt")
p.nodeTask = p.name("task")
p.nodeCond = p.name("cond")
p.nodeTeam = p.name("team")
p.nodeThr = p.name("thr")
// Unwind struct and its fields.
p.unwindType = p.nameP("Unwind")
p.unwindJoin = p.name("join")
p.unwindNodeConds = p.name("nodeConds")
p.unwindNodeMuts = p.name("nodeMuts")
p.unwindHubCond = p.name("hubCond")
p.unwindHubMut = p.name("hubMut")
p.unwindNodes = p.name("nodes")
p.unwindHub = p.name("hub")
// Team struct fields (the team type name itself lives in Ctx).
p.teamNt = p.name("nt")
p.teamHub = p.name("hub")
p.teamNodes = p.name("nodes")
p.teamUnwind = p.name("unwind")
p.ptrHub = cgen.Ptr{Type: vb(p.hubType)}
p.ptrNode = cgen.Ptr{Type: vb(p.nodeType)}
// Static helper functions.
p.inc = p.nameP("Inc")
p.put = p.nameP("Put")
p.add = p.nameP("Add")
p.main = p.nameP("Main")
}

// stage2 emits the forward declarations: the task struct, the callee
// function-pointer typedef (returning void and taking a task pointer
// and an int64 pointer), then the hub, node, unwind and team structs,
// followed by a blank line.
func (p *Prep) stage2() {
	calleeSig := cgen.TypedefPtrFunc{
		ReturnType: cgen.Void,
		What:       p.calleeType,
		Params:     cgen.CommaSpaced{p.ptrTask, cgen.PtrInt64T},
	}
	decls := cgen.Gens{
		cgen.StructFwd(p.taskType),
		calleeSig,
		cgen.StructFwd(p.hubType),
		cgen.StructFwd(p.nodeType),
		cgen.StructFwd(p.unwindType),
		cgen.StructFwd(p.teamType),
		cgen.Newline,
	}
	p.to = decls.Append(p.to)
}

// stage3 emits the definitions of the five generated structs (task,
// hub, node, unwind and team), each followed by a blank line. The node
// struct is the only one with an attribute: it is aligned to the
// context's cache line size.
func (p *Prep) stage3() {
	// field declares a plain member; array declares a member that is an
	// array of maxNd elements.
	field := func(typ cgen.Gen, name string) cgen.Gen {
		return cgen.Field{Type: typ, What: vb(name)}
	}
	array := func(typ cgen.Gen, name string) cgen.Gen {
		return cgen.Field{
			Type: typ,
			What: cgen.Elem{Arr: vb(name), Idx: p.maxNd},
		}
	}
	task := cgen.StructDef{
		Name: p.taskType,
		Fields: cgen.Stmts{
			field(p.calleeType, p.taskCallee),
			field(cgen.PtrVoid, p.taskAny),
			field(cgen.PtrdiffT, p.taskNd),
			array(cgen.Int64T, p.taskHull),
		},
	}
	hub := cgen.StructDef{
		Name: p.hubType,
		Fields: cgen.Stmts{
			field(cgen.PthreadMutexT, p.hubMut),
			field(cgen.PthreadCondT, p.hubCond),
			field(cgen.PtrdiffT, p.hubPending),
			field(cgen.PtrdiffT, p.hubOffset),
			field(cgen.Long, p.hubMask),
			// The status member is an array with no index given.
			cgen.Field{
				Type: cgen.Long,
				What: cgen.Elem{Arr: vb(p.hubStatus)},
			},
		},
	}
	node := cgen.StructDef{
		Name: p.nodeType,
		Fields: cgen.Stmts{
			field(cgen.PthreadMutexT, p.nodeMut),
			field(cgen.Int64T, p.nodeNp),
			array(cgen.Int64T, p.nodePt),
			field(p.ptrTask, p.nodeTask),
			field(cgen.PthreadCondT, p.nodeCond),
			field(p.PtrTeam, p.nodeTeam),
			field(cgen.PthreadT, p.nodeThr),
		},
		Attrs: cgen.Aligned(p.cacheLine),
	}
	unwind := cgen.StructDef{
		Name: p.unwindType,
		Fields: cgen.Stmts{
			field(cgen.PtrdiffT, p.unwindJoin),
			field(cgen.PtrdiffT, p.unwindNodeConds),
			field(cgen.PtrdiffT, p.unwindNodeMuts),
			field(cgen.PtrdiffT, p.unwindHubCond),
			field(cgen.PtrdiffT, p.unwindHubMut),
			field(cgen.PtrVoid, p.unwindNodes),
			field(cgen.PtrVoid, p.unwindHub),
		},
	}
	team := cgen.StructDef{
		Name: p.teamType,
		Fields: cgen.Stmts{
			field(cgen.PtrdiffT, p.teamNt),
			field(p.ptrHub, p.teamHub),
			field(p.ptrNode, p.teamNodes),
			field(vb(p.unwindType), p.teamUnwind),
		},
	}
	defs := cgen.Gens{
		task, cgen.Newline,
		hub, cgen.Newline,
		node, cgen.Newline,
		unwind, cgen.Newline,
		team, cgen.Newline,
	}
	p.to = defs.Append(p.to)
}

// stage4Inc emits the static helper named p.inc, taking nd, hull and
// pt. The generated loop walks i from 0 up to nd: it pre-increments a
// copy of pt[i]; if the result equals hull[i] it stores zero in pt[i]
// and moves on to the next i, otherwise it stores the incremented
// value and breaks.
func (p *Prep) stage4Inc() {
	// Names are requested in a fixed order.
	nd := vb(p.name("nd"))
	hull := vb(p.name("hull"))
	pt := vb(p.name("pt"))
	i := vb(p.name("i"))
	elem := vb(p.name("elem"))

	digit := cgen.Elem{Arr: pt, Idx: i}
	wrapped := cgen.CmpE{
		Expr1: cgen.IncPre{Expr: elem},
		Expr2: cgen.Elem{Arr: hull, Idx: i},
	}
	step := cgen.If{
		Cond: wrapped,
		Then: cgen.Stmts{cgen.Assign{Expr1: digit, Expr2: cgen.Zero}},
		Else: cgen.Stmts{
			cgen.Assign{Expr1: digit, Expr2: elem},
			cgen.Break,
		},
	}
	loop := cgen.For{
		Init: cgen.Var{Type: cgen.PtrdiffT, What: i, Init: cgen.Zero},
		Cond: cgen.CmpL{Expr1: i, Expr2: nd},
		Post: cgen.IncPre{Expr: i},
		Body: cgen.Stmts{
			cgen.Var{Type: cgen.Int64T, What: elem, Init: digit},
			step,
		},
	}
	def := cgen.StaticFuncDef{
		ReturnType: cgen.Void,
		Name:       p.inc,
		Params: cgen.CommaLines{
			cgen.Param{Type: cgen.PtrdiffT, What: nd},
			cgen.Param{Type: cgen.RestrictPtrInt64T, What: hull},
			cgen.Param{Type: cgen.RestrictPtrInt64T, What: pt},
		},
		Body: cgen.Stmts{loop},
	}
	p.to = def.Append(p.to)
}

// stage4Put emits the static helper named p.put, taking nd, hull, pt
// and val. The first generated loop runs while i < nd and val is
// non-zero: it stores val minus (val / hull[i]) * hull[i] into pt[i],
// advances i, and continues with the quotient as the new val. The
// second loop stores zero into every remaining pt[i] up to nd.
func (p *Prep) stage4Put() {
	// Names are requested in a fixed order.
	nd := vb(p.name("nd"))
	hull := vb(p.name("hull"))
	pt := vb(p.name("pt"))
	val := vb(p.name("val"))
	i := vb(p.name("i"))
	inRange := cgen.CmpL{Expr1: i, Expr2: nd}
	wrap := vb(p.name("wrap"))
	carry := vb(p.name("carry"))

	// The store target post-increments i, so each use also advances
	// the index.
	slot := cgen.Elem{
		Arr: pt,
		Idx: cgen.IncPost{Expr: i},
	}
	split := cgen.For{
		Cond: cgen.Land{Expr1: inRange, Expr2: val},
		Body: cgen.Stmts{
			cgen.Var{
				Type: cgen.Int64T,
				What: wrap,
				Init: cgen.Elem{Arr: hull, Idx: i},
			},
			cgen.Var{
				Type: cgen.Int64T,
				What: carry,
				Init: cgen.Quo{Expr1: val, Expr2: wrap},
			},
			cgen.Assign{
				Expr1: slot,
				Expr2: cgen.Sub{
					Expr1: val,
					Expr2: cgen.Mul{Expr1: carry, Expr2: wrap},
				},
			},
			cgen.Assign{Expr1: val, Expr2: carry},
		},
	}
	zeroRest := cgen.For{
		Cond: inRange,
		Post: cgen.Assign{Expr1: slot, Expr2: cgen.Zero},
	}
	def := cgen.StaticFuncDef{
		ReturnType: cgen.Void,
		Name:       p.put,
		Params: cgen.CommaLines{
			cgen.Param{Type: cgen.PtrdiffT, What: nd},
			cgen.Param{Type: cgen.RestrictPtrInt64T, What: hull},
			cgen.Param{Type: cgen.RestrictPtrInt64T, What: pt},
			cgen.Param{Type: cgen.Int64T, What: val},
		},
		Body: cgen.Stmts{
			cgen.Var{Type: cgen.PtrdiffT, What: i, Init: cgen.Zero},
			split,
			zeroRest,
		},
	}
	p.to = def.Append(p.to)
}

// stage4Add emits the static helper named p.add, taking nd, hull, pt,
// plus and carry. For each i below nd the generated loop computes
// pt[i] + plus[i] + carry. If that sum is less than hull[i] it is
// stored in pt[i] and carry becomes zero; otherwise the sum minus
// hull[i] is stored and carry becomes one.
func (p *Prep) stage4Add() {
	// Names are requested in a fixed order.
	nd := vb(p.name("nd"))
	hull := vb(p.name("hull"))
	pt := vb(p.name("pt"))
	plus := vb(p.name("plus"))
	carry := vb(p.name("carry"))
	i := vb(p.name("i"))
	wrap := vb(p.name("wrap"))
	sum := vb(p.name("sum"))

	digit := cgen.Elem{Arr: pt, Idx: i}
	total := cgen.Add{
		Expr1: cgen.Add{
			Expr1: digit,
			Expr2: cgen.Elem{Arr: plus, Idx: i},
		},
		Expr2: carry,
	}
	settle := cgen.If{
		Cond: cgen.CmpL{Expr1: sum, Expr2: wrap},
		Then: cgen.Stmts{
			cgen.Assign{Expr1: digit, Expr2: sum},
			cgen.Assign{Expr1: carry, Expr2: cgen.Zero},
		},
		Else: cgen.Stmts{
			cgen.Assign{
				Expr1: digit,
				Expr2: cgen.Sub{Expr1: sum, Expr2: wrap},
			},
			cgen.Assign{Expr1: carry, Expr2: cgen.One},
		},
	}
	loop := cgen.For{
		Init: cgen.Var{Type: cgen.PtrdiffT, What: i, Init: cgen.Zero},
		Cond: cgen.CmpL{Expr1: i, Expr2: nd},
		Post: cgen.IncPre{Expr: i},
		Body: cgen.Stmts{
			cgen.Var{
				Type: cgen.Int64T,
				What: wrap,
				Init: cgen.Elem{Arr: hull, Idx: i},
			},
			cgen.Var{
				Type: cgen.Int64T,
				What: sum,
				Init: total,
			},
			settle,
		},
	}
	def := cgen.StaticFuncDef{
		ReturnType: cgen.Void,
		Name:       p.add,
		Params: cgen.CommaLines{
			cgen.Param{Type: cgen.PtrdiffT, What: nd},
			cgen.Param{Type: cgen.RestrictPtrInt64T, What: hull},
			cgen.Param{Type: cgen.RestrictPtrInt64T, What: pt},
			cgen.Param{Type: cgen.RestrictPtrInt64T, What: plus},
			cgen.Param{Type: cgen.Int64T, What: carry},
		},
		Body: cgen.Stmts{loop},
	}
	p.to = def.Append(p.to)
}

// stage4 emits the three static helpers (inc, put, add), each followed
// by a blank line.
func (p *Prep) stage4() {
	helpers := []func(){p.stage4Inc, p.stage4Put, p.stage4Add}
	for _, emit := range helpers {
		emit()
		p.newline()
	}
}

// stage5 carries the cgen expressions shared by the methods (local,
// steal, nonlocal, fn) that build the generated function named
// p.main. The fields are populated outside the code visible here.
//
// Fields with a 1 or 2 suffix come in pairs: local uses the 1 variants
// together with node1, steal uses the 2 variants together with node2.
type stage5 struct {
*Prep
arg cgen.Gen
callee cgen.Gen
hand cgen.Gen
hub cgen.Gen
hullField cgen.Gen
lockHub cgen.Gen
lockNode1 cgen.Gen
lockNode2 cgen.Gen
mask cgen.Gen
maskField cgen.Gen
nd cgen.Gen
node1 cgen.Gen
node2 cgen.Gen
nodes cgen.Gen
np cgen.Gen
npField1 cgen.Gen
npField2 cgen.Gen
nt cgen.Gen
offset cgen.Gen
offsetField cgen.Gen
pending cgen.Gen
pt cgen.Gen
ptField1 cgen.Gen
ptField2 cgen.Gen
role cgen.Gen
statusField cgen.Gen
target cgen.Gen
task cgen.Gen
taskField cgen.Gen
team cgen.Gen
unlockHub cgen.Gen
unlockNode1 cgen.Gen
unlockNode2 cgen.Gen
wrapped cgen.Gen
}

// local builds the part of the generated function that works through
// the first node's own points.
//
// It declares the callee and nd locals from the task plus a pt array of
// maxNd elements, then loops while s.np is non-zero. Each pass copies
// s.ptField1 into pt, stores s.np - 1 in s.npField1, advances
// s.ptField1 with the inc helper, releases the node lock for the
// duration of the callee call and retakes it afterwards. The loop's
// post statement reloads s.np from s.npField1.
func (s *stage5) local() cgen.Gen {
	declCallee := cgen.Var{
		Type: s.calleeType,
		What: s.callee,
		Init: cgen.Arrow{Expr: s.task, Name: s.taskCallee},
	}
	declNd := cgen.Var{
		Type: cgen.PtrdiffT,
		What: s.nd,
		Init: cgen.Arrow{Expr: s.task, Name: s.taskNd},
	}
	declPt := cgen.Var{
		Type: cgen.Int64T,
		What: cgen.Elem{Arr: s.pt, Idx: s.maxNd},
	}
	snapshot := cgen.Call{
		Func: cgen.Memcpy,
		Args: cgen.CommaSpaced{
			s.pt, s.ptField1, cgen.Sizeof{What: s.pt},
		},
	}
	claim := cgen.Assign{
		Expr1: s.npField1,
		Expr2: cgen.Sub{Expr1: s.np, Expr2: cgen.One},
	}
	advance := cgen.Call{
		Func: vb(s.inc),
		Args: cgen.CommaSpaced{s.nd, s.hullField, s.ptField1},
	}
	invoke := cgen.Call{
		Func: s.callee,
		Args: cgen.CommaSpaced{s.task, s.pt},
	}
	drain := cgen.For{
		Cond: s.np,
		Post: cgen.Assign{Expr1: s.np, Expr2: s.npField1},
		Body: cgen.Stmts{
			snapshot,
			claim,
			advance,
			s.unlockNode1,
			invoke,
			s.lockNode1,
		},
	}
	return cgen.Stmts{declCallee, declNd, declPt, drain}
}

// steal builds the part of the generated function that works through
// another node's points.
//
// It declares s.node2 as s.nodes + s.target, takes that node's lock,
// and then runs the same copy / decrement / advance / unlock / call /
// relock sequence as local, using the second node's fields. s.np is
// loaded from s.npField2 both before the loop and after every pass.
// The node lock is released once the loop ends.
func (s *stage5) steal() cgen.Gen {
	declVictim := cgen.Var{
		Type: s.ptrNode,
		What: s.node2,
		Init: cgen.Add{Expr1: s.nodes, Expr2: s.target},
	}
	reload := cgen.Assign{Expr1: s.np, Expr2: s.npField2}
	snapshot := cgen.Call{
		Func: cgen.Memcpy,
		Args: cgen.CommaSpaced{
			s.pt, s.ptField2, cgen.Sizeof{What: s.pt},
		},
	}
	claim := cgen.Assign{
		Expr1: s.npField2,
		Expr2: cgen.Sub{Expr1: s.np, Expr2: cgen.One},
	}
	advance := cgen.Call{
		Func: vb(s.inc),
		Args: cgen.CommaSpaced{s.nd, s.hullField, s.ptField2},
	}
	invoke := cgen.Call{
		Func: s.callee,
		Args: cgen.CommaSpaced{s.task, s.pt},
	}
	drain := cgen.For{
		Init: reload,
		Cond: s.np,
		Post: reload,
		Body: cgen.Stmts{
			snapshot,
			claim,
			advance,
			s.unlockNode2,
			invoke,
			s.lockNode2,
		},
	}
	return cgen.Stmts{
		declVictim,
		s.lockNode2,
		drain,
		s.unlockNode2,
	}
}

// nonlocal builds the statements for the phase that follows local(): fn()
// emits it between lockHub and the pending decrement, so the generated code
// starts with the hub lock held. It scans the hub's status words for set
// bits, each of which selects a target node index, and calls steal() for
// each target with the hub unlocked.
func (s *stage5) nonlocal() cgen.Gen {
return cgen.Stmts{
// Clear bit (role % BitsPerLong) of status word (role / BitsPerLong),
// i.e. the bit belonging to this worker.
// NOTE(review): assumes ShiftHigh renders "<<" and Not renders "~" — confirm in cgen.
cgen.AndAssign{
Expr1: cgen.Elem{
Arr: s.statusField,
Idx: cgen.Quo{Expr1: s.role, Expr2: cgen.BitsPerLong},
},
Expr2: cgen.Not{Expr: cgen.Paren{
Inner: cgen.ShiftHigh{
Expr1: cgen.Cast{Type: cgen.Long, Expr: cgen.One},
Expr2: cgen.Rem{Expr1: s.role, Expr2: cgen.BitsPerLong},
},
}},
},
// Resume the scan from the position saved in the hub.
cgen.Var{Type: cgen.PtrdiffT, What: s.offset, Init: s.offsetField},
cgen.Var{Type: cgen.Long, What: s.mask, Init: s.maskField},
cgen.Var{Type: cgen.PtrdiffT, What: s.wrapped, Init: cgen.Zero},
// Loop without Init/Cond/Post; it is left only through the Break below.
cgen.For{Body: cgen.Stmts{
// hand = candidate bits of the current status word.
cgen.Var{
Type: cgen.Long,
What: s.hand,
Init: cgen.And{
Expr1: cgen.Elem{Arr: s.statusField, Idx: s.offset},
Expr2: s.mask,
},
},
// No candidate in this word: advance to the next word with a full mask.
cgen.If{
Cond: cgen.IsZero{Expr: s.hand},
Then: cgen.Stmts{
cgen.IncPre{Expr: s.offset},
cgen.Assign{Expr1: s.mask, Expr2: cgen.NegOne},
cgen.Continue,
},
},
// target = offset*BitsPerLong + ctzl(hand): index of the lowest candidate bit.
cgen.Var{
Type: cgen.PtrdiffT,
What: s.target,
Init: cgen.Add{
Expr1: cgen.Mul{Expr1: s.offset, Expr2: cgen.BitsPerLong},
Expr2: cgen.Call{Func: cgen.Ctzl, Args: s.hand},
},
},
// target == nt marks the end of the bitmap: stop if the scan already
// wrapped once, otherwise restart from word zero.
cgen.If{
Cond: cgen.CmpE{Expr1: s.target, Expr2: s.nt},
Then: cgen.Stmts{
cgen.If1{Cond: s.wrapped, Then: cgen.Break},
cgen.Assign{Expr1: s.offset, Expr2: cgen.Zero},
cgen.Assign{Expr1: s.mask, Expr2: cgen.NegOne},
cgen.Assign{Expr1: s.wrapped, Expr2: cgen.One},
cgen.Continue,
},
},
// hand &= -hand keeps only the lowest set bit (assuming Neg is unary minus).
cgen.AndAssign{
Expr1: s.hand,
Expr2: cgen.Neg{Expr: s.hand},
},
// Publish the scan position (offset, mask - hand) to the hub before
// releasing the hub lock.
cgen.Assign{Expr1: s.offsetField, Expr2: s.offset},
cgen.Assign{
Expr1: s.maskField,
Expr2: cgen.Sub{Expr1: s.mask, Expr2: s.hand},
},
s.unlockHub,
s.steal(),
s.lockHub,
// After steal() returns, clear the target's bit in the status word.
cgen.AndAssign{
Expr1: cgen.Elem{Arr: s.statusField, Idx: s.offset},
Expr2: cgen.Not{Expr: s.hand},
},
// Reload the shared scan position and reset the wrap flag.
cgen.Assign{Expr1: s.offset, Expr2: s.offsetField},
cgen.Assign{Expr1: s.mask, Expr2: s.maskField},
cgen.Assign{Expr1: s.wrapped, Expr2: cgen.Zero},
}},
}
}

// fn appends to s.to the static function named s.main, taking and returning
// a void pointer (stage7Up4 passes it to PthreadCreate with the node as its
// argument). The generated body derives team, nt, hub, nodes and role from
// the node argument, locks the node, and then loops: wait for a task, run
// local() and nonlocal(), decrement the hub's pending count, and signal the
// hub when that count reaches zero.
func (s *stage5) fn() {
s.to = cgen.StaticFuncDef{
ReturnType: cgen.PtrVoid,
Name: s.main,
Params: cgen.Param{Type: cgen.PtrVoid, What: s.arg},
Body: cgen.Stmts{
cgen.Var{Type: s.ptrNode, What: s.node1, Init: s.arg},
cgen.Var{
Type: s.PtrTeam,
What: s.team,
Init: cgen.Arrow{Expr: s.node1, Name: s.nodeTeam},
},
cgen.Var{
Type: cgen.PtrdiffT,
What: s.nt,
Init: cgen.Arrow{Expr: s.team, Name: s.teamNt},
},
cgen.Var{
Type: s.ptrHub,
What: s.hub,
Init: cgen.Arrow{Expr: s.team, Name: s.teamHub},
},
cgen.Var{
Type: s.ptrNode,
What: s.nodes,
Init: cgen.Arrow{Expr: s.team, Name: s.teamNodes},
},
// role = node1 - nodes: this worker's index within the node array.
cgen.Var{
Type: cgen.SizeT,
What: s.role,
Init: cgen.Sub{Expr1: s.node1, Expr2: s.nodes},
},
s.lockNode1,
// Loop without a condition; the only exit is the Return below.
cgen.For{Body: cgen.Stmts{
cgen.Var{Type: s.ptrTask, What: s.task, Init: s.taskField},
// No task yet: wait on the node's condition variable and re-check.
cgen.If{
Cond: cgen.IsZero{Expr: s.task},
Then: cgen.Stmts{
wait{s.node1, s.nodeCond, s.nodeMut},
cgen.Continue,
},
},
cgen.Var{Type: cgen.Int64T, What: s.np, Init: s.npField1},
// A negative np ends the thread (the destroy function built by
// stage6 stores NegOne into np).
cgen.If{
Cond: cgen.CmpL{Expr1: s.np, Expr2: cgen.Zero},
Then: cgen.Stmts{
s.unlockNode1,
cgen.Return{Expr: cgen.Zero},
},
},
// Clear the node's task field before running it.
cgen.Assign{Expr1: s.taskField, Expr2: cgen.Zero},
s.local(),
s.unlockNode1,
s.lockHub,
s.nonlocal(),
// pending = --hub->pending, evaluated with the hub lock held.
cgen.Var{
Type: cgen.PtrdiffT,
What: s.pending,
Init: cgen.DecPre{
Expr: cgen.Arrow{Expr: s.hub, Name: s.hubPending},
},
},
s.unlockHub,
// When pending reached zero, signal the hub's condition variable.
cgen.If1{
Cond: cgen.IsZero{Expr: s.pending},
Then: signal{s.hub, s.hubCond},
},
s.lockNode1,
}},
},
}.Append(s.to)
}

// stage5 allocates the identifiers used by the worker-thread function,
// derives the field accessors and lock/unlock statements built on top of
// them, and then emits the function (via fn) followed by a newline.
func (p *Prep) stage5() {
	// The p.name calls run in the order written here.
	s := &stage5{
		Prep:    p,
		arg:     vb(p.name("arg")),
		callee:  vb(p.name("callee")),
		hand:    vb(p.name("hand")),
		hub:     vb(p.name("hub")),
		mask:    vb(p.name("mask")),
		nd:      vb(p.name("nd")),
		node1:   vb(p.name("node")),
		node2:   vb(p.name("node")),
		nodes:   vb(p.name("nodes")),
		np:      vb(p.name("np")),
		nt:      vb(p.name("nt")),
		offset:  vb(p.name("offset")),
		pending: vb(p.name("pending")),
		pt:      vb(p.name("pt")),
		role:    vb(p.name("role")),
		target:  vb(p.name("target")),
		task:    vb(p.name("task")),
		team:    vb(p.name("team")),
		wrapped: vb(p.name("wrapped")),
	}

	// Task accessor.
	s.hullField = cgen.Arrow{Expr: s.task, Name: s.taskHull}

	// Hub accessors and locking.
	s.maskField = cgen.Arrow{Expr: s.hub, Name: s.hubMask}
	s.offsetField = cgen.Arrow{Expr: s.hub, Name: s.hubOffset}
	s.statusField = cgen.Arrow{Expr: s.hub, Name: s.hubStatus}
	s.lockHub = lock{s.hub, s.hubMut}
	s.unlockHub = unlock{s.hub, s.hubMut}

	// Accessors and locking for the worker's own node.
	s.npField1 = cgen.Arrow{Expr: s.node1, Name: s.nodeNp}
	s.ptField1 = cgen.Arrow{Expr: s.node1, Name: s.nodePt}
	s.taskField = cgen.Arrow{Expr: s.node1, Name: s.nodeTask}
	s.lockNode1 = lock{s.node1, s.nodeMut}
	s.unlockNode1 = unlock{s.node1, s.nodeMut}

	// Accessors and locking for the node being stolen from.
	s.npField2 = cgen.Arrow{Expr: s.node2, Name: s.nodeNp}
	s.ptField2 = cgen.Arrow{Expr: s.node2, Name: s.nodePt}
	s.lockNode2 = lock{s.node2, s.nodeMut}
	s.unlockNode2 = unlock{s.node2, s.nodeMut}

	s.fn()
	s.newline()
}

// stage6 appends to p.to the static function named p.destroy, which takes a
// team pointer and returns void. The generated body returns immediately for
// a zero team; otherwise it tells worker threads to exit, joins them,
// destroys node and hub synchronization objects, and frees the node array,
// the hub and the team. How far each step goes is driven by the counts and
// flags stored in the team's unwind record (written by the stage7 helpers).
func (p *Prep) stage6() {
var (
team = vb(p.name("team"))
nodes = vb(p.name("nodes"))
node = vb(p.name("node"))
stop = vb(p.name("stop"))
unwind = cgen.Arrow{Expr: team, Name: p.teamUnwind}
hub = vb(p.name("hub"))
)
// field selects a member of the team's unwind record.
field := func(a string) cgen.Gen {
return cgen.Dot{Expr: unwind, Name: a}
}
// loop iterates node over [nodes, stop) with the given body.
loop := func(a ...cgen.Gen) cgen.Gen {
return cgen.For{
Init: cgen.Var{Type: p.ptrNode, What: node, Init: nodes},
Cond: cgen.CmpNE{Expr1: node, Expr2: stop},
Post: cgen.IncPre{Expr: node},
Body: cgen.Stmts(a),
}
}
// reps moves stop to nodes + a, bounding the loops that follow it.
reps := func(a cgen.Gen) cgen.Gen {
return cgen.Assign{
Expr1: stop,
Expr2: cgen.Add{Expr1: nodes, Expr2: a},
}
}
free := func(a cgen.Gen) cgen.Gen {
return cgen.Call{Func: cgen.Free, Args: a}
}
p.to = cgen.StaticFuncDef{
ReturnType: cgen.Void,
Name: p.destroy,
Params: cgen.Param{Type: p.PtrTeam, What: team},
Body: cgen.Stmts{
cgen.If1{
Cond: cgen.IsZero{Expr: team},
Then: cgen.Return{},
},
cgen.Var{
Type: p.ptrNode,
What: nodes,
Init: cgen.Arrow{Expr: team, Name: p.teamNodes},
},
// stop initially bounds the nodes counted by the unwind join field.
cgen.Var{
Type: p.ptrNode,
What: stop,
Init: cgen.Add{Expr1: nodes, Expr2: field(p.unwindJoin)},
},
// Under each node's lock, set np to -1 and store a nonzero task
// pointer, then signal the node; the worker function built by
// stage5 returns when it sees a task with a negative np.
loop(
lock{node, p.nodeMut},
cgen.Assign{
Expr1: cgen.Arrow{Expr: node, Name: p.nodeNp},
Expr2: cgen.NegOne,
},
cgen.Assign{
Expr1: cgen.Arrow{Expr: node, Name: p.nodeTask},
Expr2: cgen.Cast{Type: p.ptrTask, Expr: cgen.One},
},
unlock{node, p.nodeMut},
signal{node, p.nodeCond},
),
loop(join{node, p.nodeThr}),
reps(field(p.unwindNodeConds)),
loop(destroyCond{node, p.nodeCond}),
reps(field(p.unwindNodeMuts)),
loop(destroyMut{node, p.nodeMut}),
cgen.Var{
Type: p.ptrHub,
What: hub,
Init: cgen.Arrow{Expr: team, Name: p.teamHub},
},
// The hub's cond and mutex are destroyed only when their unwind
// flags are nonzero.
cgen.If{
Cond: field(p.unwindHubCond),
Then: cgen.Stmts{destroyCond{hub, p.hubCond}},
},
cgen.If{
Cond: field(p.unwindHubMut),
Then: cgen.Stmts{destroyMut{hub, p.hubMut}},
},
// Free the addresses recorded in the unwind record, then the team.
free(field(p.unwindNodes)),
free(field(p.unwindHub)),
free(team),
},
}.Append(p.to)
p.newline()
}

// stage7Up4 appends to p.to the static function named self, taking
// (team, nt) and returning a char pointer. For each of the nt nodes the
// generated code initializes the node's mutex, zeroes its task, initializes
// its condition variable, stores the team pointer and creates its thread
// with p.main as the start routine. Every failure path first records in the
// unwind record how many mutexes, conds and threads exist so far; on
// success all three counts are set to nt and zero is returned.
//
// The root parameter is not referenced in this function.
func (p *Prep) stage7Up4(root, self string) {
var (
team = vb(p.name("team"))
nt = vb(p.name("nt"))
nodes = vb(p.name("nodes"))
node = vb(p.name("node"))
cnt = cgen.Sub{Expr1: node, Expr2: nodes}
cntOne = cgen.Add{Expr1: cnt, Expr2: cgen.One}
unwind = cgen.Arrow{Expr: team, Name: p.teamUnwind}
muts = cgen.Dot{Expr: unwind, Name: p.unwindNodeMuts}
conds = cgen.Dot{Expr: unwind, Name: p.unwindNodeConds}
join = cgen.Dot{Expr: unwind, Name: p.unwindJoin}
)
// set stores the three unwind counts: mutexes, conds, threads to join.
set := func(nm, nc, nj cgen.Gen) cgen.Gen {
return cgen.Stmts{
cgen.Assign{Expr1: muts, Expr2: nm},
cgen.Assign{Expr1: conds, Expr2: nc},
cgen.Assign{Expr1: join, Expr2: nj},
}
}
p.to = cgen.StaticFuncDef{
ReturnType: cgen.PtrChar,
Name: self,
Params: cgen.CommaSpaced{
cgen.Param{Type: p.PtrTeam, What: team},
cgen.Param{Type: cgen.PtrdiffT, What: nt},
},
Body: cgen.Stmts{
cgen.Var{
Type: p.ptrNode,
What: nodes,
Init: cgen.Arrow{Expr: team, Name: p.teamNodes},
},
cgen.For{
Init: cgen.Var{Type: p.ptrNode, What: node, Init: nodes},
Cond: cgen.CmpNE{
Expr1: node,
Expr2: cgen.Add{Expr1: nodes, Expr2: nt},
},
Post: cgen.IncPre{Expr: node},
Body: cgen.Stmts{
// Mutex init failed: cnt nodes are fully set up.
&errmsg.ReturnedErrnoIf{
Ctx: p.emc,
Call: cgen.Call{
Func: cgen.PthreadMutexInit,
Args: cgen.CommaSpaced{
cgen.AddrArrow{Expr: node, Name: p.nodeMut},
cgen.Zero,
},
},
Unwind: set(cnt, cnt, cnt),
},
cgen.Assign{
Expr1: cgen.Arrow{Expr: node, Name: p.nodeTask},
Expr2: cgen.Zero,
},
// Cond init failed: this node's mutex exists too (cnt + 1).
&errmsg.ReturnedErrnoIf{
Ctx: p.emc,
Call: cgen.Call{
Func: cgen.PthreadCondInit,
Args: cgen.CommaSpaced{
cgen.AddrArrow{Expr: node, Name: p.nodeCond},
cgen.Zero,
},
},
Unwind: set(cntOne, cnt, cnt),
},
cgen.Assign{
Expr1: cgen.Arrow{Expr: node, Name: p.nodeTeam},
Expr2: team,
},
// Thread creation failed: this node's mutex and cond exist,
// but only cnt threads need joining.
&errmsg.ReturnedErrnoIf{
Ctx: p.emc,
Call: cgen.Call{
Func: cgen.PthreadCreate,
Args: cgen.CommaSpaced{
cgen.AddrArrow{Expr: node, Name: p.nodeThr},
cgen.Zero, vb(p.main), node,
},
},
Unwind: set(cntOne, cntOne, cnt),
},
},
},
set(nt, nt, nt),
cgen.Return{Expr: cgen.Zero},
},
}.Append(p.to)
p.newline()
}

// stage7Up3 first emits the next helper in the chain (stage7Up4, under a
// fresh name derived from root) and then appends to p.to the static
// function named self. The generated code initializes the hub's mutex and
// condition variable, sets the matching unwind flag to one after each
// successful init, and returns the result of calling the next helper with
// (team, nt).
func (p *Prep) stage7Up3(root, self string) {
var (
up = p.name(root)
team = vb(p.name("team"))
nt = vb(p.name("nt"))
hub = vb(p.name("hub"))
unwind = cgen.Arrow{Expr: team, Name: p.teamUnwind}
)
p.stage7Up4(root, up)
p.to = cgen.StaticFuncDef{
ReturnType: cgen.PtrChar,
Name: self,
Params: cgen.CommaSpaced{
cgen.Param{Type: p.PtrTeam, What: team},
cgen.Param{Type: cgen.PtrdiffT, What: nt},
},
Body: cgen.Stmts{
cgen.Var{
Type: p.ptrHub,
What: hub,
Init: cgen.Arrow{Expr: team, Name: p.teamHub},
},
&errmsg.ReturnedErrnoIf{
Ctx: p.emc,
Call: cgen.Call{
Func: cgen.PthreadMutexInit,
Args: cgen.CommaSpaced{
cgen.AddrArrow{Expr: hub, Name: p.hubMut},
cgen.Zero,
},
},
},
// Record that the hub mutex exists (stage6 destroys it only if set).
cgen.Assign{
Expr1: cgen.Dot{Expr: unwind, Name: p.unwindHubMut},
Expr2: cgen.One,
},
&errmsg.ReturnedErrnoIf{
Ctx: p.emc,
Call: cgen.Call{
Func: cgen.PthreadCondInit,
Args: cgen.CommaSpaced{
cgen.AddrArrow{Expr: hub, Name: p.hubCond},
cgen.Zero,
},
},
},
// Record that the hub cond exists.
cgen.Assign{
Expr1: cgen.Dot{Expr: unwind, Name: p.unwindHubCond},
Expr2: cgen.One,
},
cgen.Return{Expr: cgen.Call{
Func: vb(up),
Args: cgen.CommaSpaced{team, nt},
}},
},
}.Append(p.to)
p.newline()
}

// stage7Up2 first emits the next helper in the chain (stage7Up3) and then
// appends to p.to the static function named self. The generated code
// allocates the node array: it computes nt * sizeof(node), reports
// "too many threads" if dividing the product back does not yield nt,
// mallocs size + cacheLine - 1 bytes, records the raw address in the unwind
// record, stores the address passed through the round helper into the
// team's nodes field, and returns the result of the next helper.
// NOTE(review): round is defined outside this view; presumably it rounds up
// to a multiple of the cache line — confirm.
func (p *Prep) stage7Up2(root, self string) {
var (
up = p.name(root)
team = vb(p.name("team"))
nt = vb(p.name("nt"))
size = vb(p.name("size"))
each = cgen.Sizeof{What: vb(p.nodeType)}
addr = vb(p.name("addr"))
line = cgen.IntLit(p.cacheLine)
)
p.stage7Up3(root, up)
p.to = cgen.StaticFuncDef{
ReturnType: cgen.PtrChar,
Name: self,
Params: cgen.CommaSpaced{
cgen.Param{Type: p.PtrTeam, What: team},
cgen.Param{Type: cgen.PtrdiffT, What: nt},
},
Body: cgen.Stmts{
cgen.Var{
Type: cgen.SizeT,
What: size,
Init: cgen.Mul{Expr1: nt, Expr2: each},
},
// Multiplication overflow check: size / each must give back nt.
&errmsg.FormatIf{
Ctx: p.emc,
Cond: cgen.CmpNE{
Expr1: cgen.Quo{Expr1: size, Expr2: each},
Expr2: cgen.Cast{Type: cgen.SizeT, Expr: nt},
},
Format: "too many threads",
},
// Over-allocate by line - 1 bytes, leaving slack for the alignment below.
cgen.Var{
Type: cgen.PtrVoid,
What: addr,
Init: cgen.Call{
Func: cgen.Malloc,
Args: cgen.Add{Expr1: size, Expr2: line - 1},
},
},
&errmsg.ErrnoIf{
Ctx: p.emc,
Cond: cgen.IsZero{Expr: addr},
},
// Keep the unaligned address; stage6 passes this field to free.
cgen.Assign{
Expr1: cgen.Dot{
Expr: cgen.Arrow{Expr: team, Name: p.teamUnwind},
Name: p.unwindNodes,
},
Expr2: addr,
},
cgen.Assign{
Expr1: cgen.Arrow{Expr: team, Name: p.teamNodes},
Expr2: cgen.Cast{
Type: cgen.PtrVoid,
Expr: cgen.Paren{Inner: round{
cgen.Cast{Type: cgen.SizeT, Expr: addr},
line,
}},
},
},
cgen.Return{Expr: cgen.Call{
Func: vb(up),
Args: cgen.CommaSpaced{team, nt},
}},
},
}.Append(p.to)
p.newline()
}

// stage7Up1 first emits the next helper in the chain (stage7Up2) and then
// appends to p.to the static function named self. The generated code stores
// nt in the team and allocates the hub: the size is sizeof(hub) plus
// sizeof(long) * ((size_t)nt / BitsPerLong + 1), passed through the round
// helper with the cache line; it mallocs size + cacheLine - 1 bytes,
// records the raw address in the unwind record, stores the rounded address
// in the team's hub field, and returns the result of the next helper.
// NOTE(review): round is defined outside this view; presumably it rounds up
// to a multiple of the cache line — confirm.
func (p *Prep) stage7Up1(root, self string) {
var (
up = p.name(root)
team = vb(p.name("team"))
nt = vb(p.name("nt"))
size = vb(p.name("size"))
line = cgen.IntLit(p.cacheLine)
addr = vb(p.name("addr"))
)
p.stage7Up2(root, up)
p.to = cgen.StaticFuncDef{
ReturnType: cgen.PtrChar,
Name: self,
Params: cgen.CommaSpaced{
cgen.Param{Type: p.PtrTeam, What: team},
cgen.Param{Type: cgen.PtrdiffT, What: nt},
},
Body: cgen.Stmts{
cgen.Assign{
Expr1: cgen.Arrow{Expr: team, Name: p.teamNt},
Expr2: nt,
},
cgen.Var{
Type: cgen.SizeT,
What: size,
Init: cgen.Sizeof{What: vb(p.hubType)},
},
// Add room for nt / BitsPerLong + 1 longs (the stage9 function
// fills status words at exactly those indices).
cgen.AddAssign{
Expr1: size,
Expr2: cgen.Mul{
Expr1: cgen.Sizeof{What: cgen.Long},
Expr2: cgen.Paren{Inner: cgen.Add{
Expr1: cgen.Quo{
Expr1: cgen.Cast{Type: cgen.SizeT, Expr: nt},
Expr2: cgen.BitsPerLong,
},
Expr2: cgen.One,
}},
},
},
cgen.Assign{
Expr1: size,
Expr2: round{size, line},
},
// Over-allocate by line - 1 bytes, leaving slack for the alignment below.
cgen.Var{
Type: cgen.PtrVoid,
What: addr,
Init: cgen.Call{
Func: cgen.Malloc,
Args: cgen.Add{Expr1: size, Expr2: line - 1},
},
},
&errmsg.ErrnoIf{
Ctx: p.emc,
Cond: cgen.IsZero{Expr: addr},
},
// Keep the unaligned address; stage6 passes this field to free.
cgen.Assign{
Expr1: cgen.Dot{
Expr: cgen.Arrow{Expr: team, Name: p.teamUnwind},
Name: p.unwindHub,
},
Expr2: addr,
},
cgen.Assign{
Expr1: cgen.Arrow{Expr: team, Name: p.teamHub},
Expr2: cgen.Cast{
Type: cgen.PtrVoid,
Expr: cgen.Paren{Inner: round{
cgen.Cast{Type: cgen.SizeT, Expr: addr},
line,
}},
},
},
cgen.Return{Expr: cgen.Call{
Func: vb(up),
Args: cgen.CommaSpaced{team, nt},
}},
},
}.Append(p.to)
p.newline()
}

// stage7 emits the chain of setup helpers (through stage7Up1) and then
// appends to p.to the static function named p.create, taking a pointer to
// a team pointer and a thread count nt. The generated code rejects nt < 1
// with "too few threads", callocs one team record, and runs the helper
// chain on it. On a nonzero error it calls the destroy function on the
// record; otherwise it stores the record through the team parameter.
// The error value is returned in both cases.
func (p *Prep) stage7() {
var (
root = p.create + "Up"
up = p.name(root)
team = vb(p.name("team"))
nt = vb(p.name("nt"))
addr = vb(p.name("addr"))
err = vb(p.name("err"))
)
p.stage7Up1(root, up)
p.to = cgen.StaticFuncDef{
ReturnType: cgen.PtrChar,
Name: p.create,
Params: cgen.CommaSpaced{
cgen.Param{
Type: cgen.Ptr{Type: p.PtrTeam},
What: team,
},
cgen.Param{Type: cgen.PtrdiffT, What: nt},
},
Body: cgen.Stmts{
&errmsg.FormatIf{
Ctx: p.emc,
Cond: cgen.CmpL{Expr1: nt, Expr2: cgen.One},
Format: "too few threads",
},
// Calloc is used, so the team record (including its unwind
// fields) starts zeroed.
cgen.Var{
Type: cgen.PtrVoid,
What: addr,
Init: cgen.Call{
Func: cgen.Calloc,
Args: cgen.CommaSpaced{
cgen.One,
cgen.Sizeof{What: vb(p.teamType)},
},
},
},
&errmsg.ErrnoIf{
Ctx: p.emc,
Cond: cgen.IsZero{Expr: addr},
},
cgen.Var{
Type: cgen.PtrChar,
What: err,
Init: cgen.Call{
Func: vb(up),
Args: cgen.CommaSpaced{addr, nt},
},
},
// Failure: tear down whatever was built. Success: publish the team.
cgen.If{
Cond: cgen.Unlikely{
Cond: cgen.IsNonzero{Expr: err},
},
Then: cgen.Stmts{cgen.Call{
Func: vb(p.destroy), Args: addr,
}},
Else: cgen.Stmts{cgen.Assign{
Expr1: cgen.At{Expr: team},
Expr2: addr,
}},
},
cgen.Return{Expr: err},
},
}.Append(p.to)
p.newline()
}

// stage8 appends to p.to the static function named p.pthreadT, taking
// (thr, team, idx) and returning a char pointer. The generated code reports
// "bad thread idx" unless 0 <= idx < team->nt, otherwise stores the thread
// field of node idx through thr and returns zero.
func (p *Prep) stage8() {
	thr := vb(p.name("thr"))
	team := vb(p.name("team"))
	idx := vb(p.name("idx"))

	// idx < 0 || idx >= team->nt
	outOfRange := cgen.Lor{
		Expr1: cgen.CmpL{Expr1: idx, Expr2: cgen.Zero},
		Expr2: cgen.CmpGE{
			Expr1: idx,
			Expr2: cgen.Arrow{Expr: team, Name: p.teamNt},
		},
	}
	// team->nodes[idx].thr
	nodeThr := cgen.Dot{
		Expr: cgen.Elem{
			Arr: cgen.Arrow{Expr: team, Name: p.teamNodes},
			Idx: idx,
		},
		Name: p.nodeThr,
	}
	def := cgen.StaticFuncDef{
		ReturnType: cgen.PtrChar,
		Name:       p.pthreadT,
		Params: cgen.CommaLines{
			cgen.Param{Type: cgen.PtrPthreadT, What: thr},
			cgen.Param{Type: p.PtrTeam, What: team},
			cgen.Param{Type: cgen.PtrdiffT, What: idx},
		},
		Body: cgen.Stmts{
			&errmsg.FormatIf{
				Ctx:    p.emc,
				Cond:   outOfRange,
				Format: "bad thread idx",
			},
			cgen.Assign{
				Expr1: cgen.At{Expr: thr},
				Expr2: nodeThr,
			},
			cgen.Return{Expr: cgen.Zero},
		},
	}
	p.to = def.Append(p.to)
	p.newline()
}

// stage9 holds the generated identifiers and derived expressions used while
// emitting the function named s.do (see the fn method). All fields except
// hull and pending are fresh names created in Prep.stage9; hull and pending
// are field accesses on task and hub respectively.
type stage9 struct {
*Prep
team cgen.Gen
task cgen.Gen
nd cgen.Gen
tot cgen.Gen
hull cgen.Gen
i cgen.Gen
nt cgen.Gen
each cgen.Gen
more cgen.Gen
plus cgen.Gen
pt cgen.Gen
hub cgen.Gen
node cgen.Gen
carry cgen.Gen
pending cgen.Gen
}

// divide builds the opening statements of the s.do function. The generated
// code reads nd from the task and returns if nd < 1, multiplies the first
// nd hull entries into tot, splits tot across the nt threads as
// each = tot / nt and more = tot % nt, calls the put helper with
// (nd, hull, plus, each), and declares the pt array initialized from
// cgen.Zeros.
func (s *stage9) divide() cgen.Gen {
return cgen.Stmts{
cgen.Var{
Type: cgen.PtrdiffT,
What: s.nd,
Init: cgen.Arrow{Expr: s.task, Name: s.taskNd},
},
cgen.If1{
Cond: cgen.CmpL{Expr1: s.nd, Expr2: cgen.One},
Then: cgen.Return{},
},
cgen.Var{
Type: cgen.Int64T,
What: s.tot,
Init: cgen.Elem{Arr: s.hull, Idx: cgen.Zero},
},
// Body-less loop: the Post clause does tot *= hull[i++] for i in [1, nd).
cgen.For{
Init: cgen.Var{Type: cgen.PtrdiffT, What: s.i, Init: cgen.One},
Cond: cgen.CmpL{Expr1: s.i, Expr2: s.nd},
Post: cgen.MulAssign{
Expr1: s.tot,
Expr2: cgen.Elem{
Arr: s.hull,
Idx: cgen.IncPost{Expr: s.i},
},
},
},
cgen.Var{
Type: cgen.PtrdiffT,
What: s.nt,
Init: cgen.Arrow{Expr: s.team, Name: s.teamNt},
},
cgen.Var{
Type: cgen.Int64T,
What: s.each,
Init: cgen.Quo{Expr1: s.tot, Expr2: s.nt},
},
cgen.Var{
Type: cgen.PtrdiffT,
What: s.more,
Init: cgen.Rem{Expr1: s.tot, Expr2: s.nt},
},
// plus is declared without an initializer; the put call below receives it.
// NOTE(review): put is defined outside this view; presumably it writes
// each into plus as a coordinate within hull — confirm.
cgen.Var{
Type: cgen.Int64T,
What: cgen.Elem{Arr: s.plus, Idx: s.maxNd},
},
cgen.Call{
Func: vb(s.put),
Args: cgen.CommaSpaced{s.nd, s.hull, s.plus, s.each},
},
cgen.Var{
Type: cgen.Int64T,
What: cgen.Elem{Arr: s.pt, Idx: s.maxNd},
Init: cgen.Zeros,
},
}
}

// launch builds the loop that hands work to every node. For node i the
// generated code, under the node's lock, sets np to each + (i < more),
// copies pt into the node, and stores the task pointer; it then unlocks and
// signals the node. After the last node (++i == nt) the loop breaks;
// otherwise the add helper is called with (nd, hull, pt, plus, carry)
// before moving to the next node.
func (s *stage9) launch() cgen.Gen {
return cgen.Stmts{
cgen.Var{
Type: s.ptrNode,
What: s.node,
Init: cgen.Arrow{Expr: s.team, Name: s.teamNodes},
},
// No Cond: the loop exits only through the Break below.
cgen.For{
Init: cgen.Var{Type: cgen.PtrdiffT, What: s.i, Init: cgen.Zero},
Post: cgen.IncPre{Expr: s.node},
Body: cgen.Stmts{
lock{s.node, s.nodeMut},
// carry is the comparison result i < more, so the first `more`
// nodes receive one extra point.
cgen.Var{
Type: cgen.Int64T,
What: s.carry,
Init: cgen.CmpL{Expr1: s.i, Expr2: s.more},
},
cgen.Assign{
Expr1: cgen.Arrow{Expr: s.node, Name: s.nodeNp},
Expr2: cgen.Add{Expr1: s.each, Expr2: s.carry},
},
cgen.Call{
Func: cgen.Memcpy,
Args: cgen.CommaSpaced{
cgen.Arrow{Expr: s.node, Name: s.nodePt},
s.pt, cgen.Sizeof{What: s.pt},
},
},
cgen.Assign{
Expr1: cgen.Arrow{Expr: s.node, Name: s.nodeTask},
Expr2: s.task,
},
unlock{s.node, s.nodeMut},
signal{s.node, s.nodeCond},
cgen.If1{
Cond: cgen.CmpE{
Expr1: cgen.IncPre{Expr: s.i},
Expr2: s.nt,
},
Then: cgen.Break,
},
// NOTE(review): add is defined outside this view; presumably it
// advances pt by plus (+ carry) within hull — confirm.
cgen.Call{
Func: vb(s.add),
Args: cgen.CommaSpaced{s.nd, s.hull, s.pt, s.plus, s.carry},
},
},
},
}
}

// fn appends to s.to the static void function named s.do, taking
// (team, task). The generated code runs divide(), locks the hub, runs
// launch(), resets the hub's offset to zero and mask to -1, sets every
// status word at indices (size_t)nt / BitsPerLong down to 0 to -1, sets
// pending to nt, waits on the hub's condition variable until pending is
// zero, and unlocks the hub.
func (s *stage9) fn() {
s.to = cgen.StaticFuncDef{
ReturnType: cgen.Void,
Name: s.do,
Params: cgen.CommaSpaced{
cgen.Param{Type: s.PtrTeam, What: s.team},
cgen.Param{Type: s.ptrTask, What: s.task},
},
Body: cgen.Stmts{
s.divide(),
cgen.Var{
Type: s.ptrHub,
What: s.hub,
Init: cgen.Arrow{Expr: s.team, Name: s.teamHub},
},
lock{s.hub, s.hubMut},
s.launch(),
cgen.Assign{
Expr1: cgen.Arrow{Expr: s.hub, Name: s.hubOffset},
Expr2: cgen.Zero,
},
cgen.Assign{
Expr1: cgen.Arrow{Expr: s.hub, Name: s.hubMask},
Expr2: cgen.NegOne,
},
// Fill status[i] = -1 from the highest word down; the decrement is in
// the body (status[i--]), so the For has no Post clause.
cgen.For{
Init: cgen.Var{
Type: cgen.PtrdiffT,
What: s.i,
Init: cgen.Quo{
Expr1: cgen.Cast{Type: cgen.SizeT, Expr: s.nt},
Expr2: cgen.BitsPerLong,
},
},
Cond: cgen.CmpGE{Expr1: s.i, Expr2: cgen.Zero},
Body: cgen.Stmts{cgen.Assign{
Expr1: cgen.Elem{
Arr: cgen.Arrow{Expr: s.hub, Name: s.hubStatus},
Idx: cgen.DecPost{Expr: s.i},
},
Expr2: cgen.NegOne,
}},
},
// pending = nt; wait while pending is nonzero (the worker function
// built by stage5 decrements pending and signals at zero).
cgen.For{
Init: cgen.Assign{Expr1: s.pending, Expr2: s.nt},
Cond: s.pending,
Body: cgen.Stmts{wait{s.hub, s.hubCond, s.hubMut}},
},
unlock{s.hub, s.hubMut},
},
}.Append(s.to)
}

// stage9 allocates the identifiers used by the s.do function, derives the
// hull and pending field accessors, and emits the function via fn.
func (p *Prep) stage9() {
	s := &stage9{Prep: p}

	// Fresh names, requested in a fixed order.
	s.team = vb(p.name("team"))
	s.task = vb(p.name("task"))
	s.nd = vb(p.name("nd"))
	s.tot = vb(p.name("tot"))
	s.i = vb(p.name("i"))
	s.nt = vb(p.name("nt"))
	s.each = vb(p.name("each"))
	s.more = vb(p.name("more"))
	s.plus = vb(p.name("plus"))
	s.pt = vb(p.name("pt"))
	s.hub = vb(p.name("hub"))
	s.node = vb(p.name("node"))
	s.carry = vb(p.name("carry"))

	// Derived field accesses.
	s.hull = cgen.Arrow{Expr: s.task, Name: s.taskHull}
	s.pending = cgen.Arrow{Expr: s.hub, Name: s.hubPending}

	s.fn()
}

// Destroy generates a call to the context's destroy function with Team as
// its argument (see Append).
type Destroy struct {
*Ctx
Team cgen.Gen
}

// Append renders a single statement that calls the destroy function on
// d.Team, appends it to the given buffer and returns the result.
func (d *Destroy) Append(to []byte) []byte {
	call := cgen.Call{
		Func: vb(d.destroy),
		Args: d.Team,
	}
	return cgen.Stmts{call}.Append(to)
}

// Create generates a call to the context's create function with (Team, Nt)
// and an early return of the resulting error string when it is nonzero.
// Unwind, when non-nil, is emitted before that return (see Append).
type Create struct {
*Ctx
Team cgen.Gen
Nt cgen.Gen
Unwind cgen.Gen
}

// Append renders "err = create(Team, Nt)" followed by a check that returns
// err when it is nonzero. With a non-nil Unwind the check runs Unwind
// before returning; without one it is a bare If1 around the return. The
// rendered text is appended to the given buffer.
func (c *Create) Append(to []byte) []byte {
	err := vb(c.name("err"))
	bail := cgen.Return{Expr: err}
	failed := cgen.Unlikely{
		Cond: cgen.IsNonzero{Expr: err},
	}

	// Default to the plain check; upgrade it when cleanup code was supplied.
	var check cgen.Gen = cgen.If1{Cond: failed, Then: bail}
	if c.Unwind != nil {
		check = cgen.If{
			Cond: failed,
			Then: cgen.Stmts{c.Unwind, bail},
		}
	}

	decl := cgen.Var{
		Type: cgen.PtrChar,
		What: err,
		Init: cgen.Call{
			Func: vb(c.create),
			Args: cgen.CommaSpaced{c.Team, c.Nt},
		},
	}
	return cgen.Stmts{decl, check}.Append(to)
}

// PthreadT generates a return statement whose value is a call to the
// context's pthreadT function with (Thr, Team, Idx) (see Append).
type PthreadT struct {
*Ctx
Thr cgen.Gen
Team cgen.Gen
Idx cgen.Gen
}

// Append renders "return pthreadT(Thr, Team, Idx)" as a single statement,
// appends it to the given buffer and returns the result.
func (p *PthreadT) Append(to []byte) []byte {
	lookup := cgen.Call{
		Func: vb(p.pthreadT),
		Args: cgen.CommaSpaced{p.Thr, p.Team, p.Idx},
	}
	stmt := cgen.Return{Expr: lookup}
	return cgen.Stmts{stmt}.Append(to)
}

// Callee describes a generated task function: Name is the function's name,
// Task and Pt are the expressions used for its two parameters (see Func).
// The Any, Nd and Hull methods produce accesses to the task's fields.
type Callee struct {
*Ctx
Name string
Task cgen.Gen
Pt cgen.Gen
}

// Any returns the expression that reads the task's "any" field through the
// task pointer.
func (c *Callee) Any() cgen.Gen {
	field := cgen.Arrow{Expr: c.Task, Name: c.taskAny}
	return field
}

// Nd returns the expression that reads the task's nd field through the
// task pointer.
func (c *Callee) Nd() cgen.Gen {
	field := cgen.Arrow{Expr: c.Task, Name: c.taskNd}
	return field
}

// Hull returns the expression that reads the task's hull field through the
// task pointer.
func (c *Callee) Hull() cgen.Gen {
	field := cgen.Arrow{Expr: c.Task, Name: c.taskHull}
	return field
}

// Func wraps body in a static void function definition named c.Name whose
// parameters are the task pointer (c.Task) and an int64 pointer (c.Pt).
func (c *Callee) Func(body cgen.Gen) cgen.Gen {
	params := cgen.CommaSpaced{
		cgen.Param{Type: c.ptrTask, What: c.Task},
		cgen.Param{Type: cgen.PtrInt64T, What: c.Pt},
	}
	return cgen.StaticFuncDef{
		ReturnType: cgen.Void,
		Name:       c.Name,
		Params:     params,
		Body:       body,
	}
}

// Do generates the code that fills in a task record and submits it to a
// team: Callee and Any are stored in the task's callee and "any" fields,
// len(Hull) becomes the task's nd, each Hull expression is stored into the
// task's hull array, and Team is the first argument of the call to the do
// function (see Append).
type Do struct {
*Ctx
Callee cgen.Gen
Any cgen.Gen
Hull []cgen.Gen
Team cgen.Gen
}

// Append renders the declaration of a task variable, the assignments of its
// callee, any, nd and hull[i] fields, and the call do(Team, &task), then
// appends the text to the given buffer. It panics with "bug" when Hull has
// more than maxNd entries.
func (d *Do) Append(to []byte) []byte {
	task := vb(d.name("task"))
	nd := cgen.IntLit(len(d.Hull))
	if nd > maxNd {
		panic("bug")
	}
	// member selects a field of the task variable.
	member := func(name string) cgen.Gen {
		return cgen.Dot{Expr: task, Name: name}
	}
	// Four fixed statements up front, one per hull entry, one call at the end.
	out := make(cgen.Stmts, 0, 4+len(d.Hull)+1)
	out = append(out,
		cgen.Var{Type: vb(d.taskType), What: task},
		cgen.Assign{Expr1: member(d.taskCallee), Expr2: d.Callee},
		cgen.Assign{Expr1: member(d.taskAny), Expr2: d.Any},
		cgen.Assign{Expr1: member(d.taskNd), Expr2: nd},
	)
	hull := member(d.taskHull)
	for dim := 0; dim < len(d.Hull); dim++ {
		slot := cgen.Elem{
			Arr: hull,
			Idx: cgen.IntLit(dim),
		}
		out = append(out, cgen.Assign{
			Expr1: slot,
			Expr2: d.Hull[dim],
		})
	}
	submit := cgen.Call{
		Func: vb(d.do),
		Args: cgen.CommaSpaced{
			d.Team, cgen.Addr{Expr: task},
		},
	}
	out = append(out, submit)
	return out.Append(to)
}

Top || internal/compile/author/three/three.go

package three

import (
"NN-512/internal/compile/author/act"
"NN-512/internal/compile/author/avx"
"NN-512/internal/compile/author/bn"
"NN-512/internal/compile/author/cgen"
"NN-512/internal/compile/author/mod"
"NN-512/internal/compile/author/sumr"
"NN-512/internal/compile/author/threader"
"NN-512/internal/compile/author/wct"
"NN-512/internal/compile/plan"
"NN-512/internal/nmsrc"
"NN-512/internal/raw"
"fmt"
)

// btoi converts a boolean to an integer: 1 for true, 0 for false.
func btoi(b bool) int {
	n := 0
	if b {
		n = 1
	}
	return n
}

// min returns the smaller of x and y.
func min(x, y int) int {
	if y < x {
		return y
	}
	return x
}

// max returns the larger of x and y.
func max(x, y int) int {
	if y > x {
		return y
	}
	return x
}

// ceilQuo computes (n + d - 1) / d using Go's truncating integer division,
// which equals the ceiling of n/d for non-negative n and positive d.
func ceilQuo(n, d int) int {
	biased := n + d - 1
	return biased / d
}

// vb is shorthand for cgen.Vb: it turns the string s into a cgen.Gen.
func vb(s string) cgen.Gen {
	var g cgen.Gen = cgen.Vb(s)
	return g
}

// il is shorthand for cgen.IntLit: it turns the integer i into a cgen.Gen.
func il(i int) cgen.Gen {
	var g cgen.Gen = cgen.IntLit(i)
	return g
}

// loMask returns the integer literal (1 << n) - 1, i.e. a value whose n
// lowest bits are set.
func loMask(n int) cgen.Gen {
	bit := int(1) << uint(n)
	return il(bit - 1)
}

// void wraps the expression a in a cast to cgen.Void.
func void(a cgen.Gen) cgen.Gen {
	cast := cgen.Cast{Type: cgen.Void, Expr: a}
	return cast
}

// addMul builds the expression a + b*c.
func addMul(a, b, c cgen.Gen) cgen.Gen {
	product := cgen.Mul{Expr1: b, Expr2: c}
	return cgen.Add{Expr1: a, Expr2: product}
}

// mix interleaves several statement lists round-robin: first every list's
// statement 0 (in list order), then every list's statement 1, and so on,
// skipping lists that have run out. A single list is returned as-is.
func mix(a []cgen.Stmts) cgen.Stmts {
	if len(a) == 1 {
		return a[0]
	}
	total, longest := 0, 0
	for _, list := range a {
		total += len(list)
		if len(list) > longest {
			longest = len(list)
		}
	}
	out := make(cgen.Stmts, 0, total)
	for pos := 0; pos < longest; pos++ {
		for _, list := range a {
			if pos < len(list) {
				out = append(out, list[pos])
			}
		}
	}
	return out
}

// Ctx carries the state shared by the generators in this package. NewCtx
// fills it from the plan's configuration and from the name source and
// sub-contexts passed in by the caller.
type Ctx struct {
// prefix is the plan's configured prefix followed by "Three".
prefix string
platform raw.Platform
// cacheBytes1 and cacheBytes2 come from the plan's L1DataCachePerThread
// and L2CachePerThreadExL1 settings.
cacheBytes1 int
cacheBytes2 int
// nms supplies names (see the name method).
nms nmsrc.Src
tc *threader.Ctx
ac *act.Ctx
bc *bn.Ctx
// dedup caches previously built results keyed by a signature string
// (used by newSegments and ArrangeFilts.Prep).
dedup map[string]interface{}
}

// NewCtx builds a Ctx from the plan's configuration (prefix, platform and
// the two per-thread cache sizes) together with the supplied name source
// and threader/activation/batchnorm contexts. The dedup cache starts empty.
func NewCtx(pl *plan.Plan, nms nmsrc.Src, tc *threader.Ctx, ac *act.Ctx, bc *bn.Ctx) *Ctx {
	ctx := new(Ctx)
	ctx.prefix = pl.Config.Prefix + "Three"
	ctx.platform = pl.Config.Platform
	ctx.cacheBytes1 = pl.Config.L1DataCachePerThread
	ctx.cacheBytes2 = pl.Config.L2CachePerThreadExL1
	ctx.nms = nms
	ctx.tc = tc
	ctx.ac = ac
	ctx.bc = bc
	ctx.dedup = make(map[string]interface{})
	return ctx
}

// name asks the context's name source for a name derived from s.
func (c *Ctx) name(s string) string {
	fresh := c.nms.Name(s)
	return fresh
}

// Spec describes one operation handled by this package: its input (From),
// its filter sets (Filts), its output (To) and the stride, padding and
// group parameters. ArrangeFilts.Prep formats the whole Spec with
// fmt.Sprint to build its dedup key, so field order and contents affect
// that key.
type Spec struct {
From SpecFrom
Filts []SpecFilts
To SpecTo
StrideH int
StrideW int
PaddingH int
PaddingW int
Groups int
}

// SpecFrom describes the input side of a Spec. newLayout divides Chans by
// Spec.Groups; newSegments tiles Height and Width (plus padding) into
// blocks.
type SpecFrom struct {
Chans int
Height int
Width int
Pitch1Bytes []int
Pitch2Bytes []int
Ops []mod.Op
}

// SpecFilts describes one set of filters in a Spec. newLayout sums Cnt
// over all sets to obtain the total filter count.
// NOTE(review): BnPre and BnPost presumably count batch-normalization
// stages applied before/after — they are not used in the visible code.
type SpecFilts struct {
Cnt int
BnPre int
BnPost int
}

// SpecTo describes the output side of a Spec: byte pitches and the list of
// ops attached to the output.
type SpecTo struct {
Pitch1Bytes []int
Pitch2Bytes []int
Ops []mod.Op
}

// block is one tile produced by newSegments. fromH/fromW are the tile's
// position (made relative to the first block of its segment by commit1);
// padH/padW and datH/datW split each axis of the tile into padding and
// data extents (clamped to 8 in newSegments); yieldH/yieldW are capped at
// 6. Blocks are compared by value, so every field takes part in equality.
type block struct {
fromH int
fromW int
padH int
padW int
datH int
datW int
yieldH int
yieldW int
}

// loopW is a run of consecutive segments that share the same starting row
// and the same block pattern. segFirst and segPast bound the run as a
// half-open range of segment indices, fromStep is the column distance
// between the first two segments of the run (zero for a single segment),
// and blks is the shared block pattern.
type loopW struct {
fromH int
fromW int
fromStep int
segFirst int
segPast int
blks []*block
}

// loopH groups loopW runs. newSegments emits either a non-repeating group
// (fromStep and segStep zero) or a repeating group whose fromStep and
// segStep are the row and segment-index distances between repetitions.
// The fromH, segFirst and segPast of the contained loopWs are stored
// relative to this loopH's fromH and segFirst.
type loopH struct {
fromH int
fromStep int
segFirst int
segStep int
segPast int
lws []*loopW
}

// segments is the result of newSegments: cnt is the total number of
// segments and lhs is the list of loopH groups describing them.
type segments struct {
cnt int
lhs []*loopH
}

// newSegments tiles the padded input plane of spec into blocks, packs the
// blocks into segments of up to segBlks blocks, and compresses the segment
// sequence into loopW runs and loopH groups. Results are cached in
// ctx.dedup, keyed by the input height/width, the paddings and segBlks.
//
// The work is organized as nested closures: layer1 handles the cache,
// layer2..layer4 allocate scratch state, layer5 enumerates blocks, and
// commit1/commit2/commit3 fold blocks into segments, segments into loopWs,
// and loopWs into loopHs respectively.
func newSegments(ctx *Ctx, spec *Spec, segBlks int) *segments {
var (
// segs is the result being built.
segs segments
// blks holds the blocks of the segment currently being filled.
blks []*block
// lw1 is the loopW run currently being extended.
lw1 loopW
// idx maps a loopW's fromW to its index in lws.
idx map[int]int
// lws holds finished loopWs not yet emitted as loopHs.
lws []*loopW
// tie is the index in lws where a repetition was detected, or -1.
tie int
fromStep int
segStep int
segPast int
// at is the position within lws[tie:] expected to match next.
at int
)
// equal reports whether two block lists have the same length and
// element-wise equal block values.
equal := func(seg1, seg2 []*block) bool {
if len(seg1) != len(seg2) {
return false
}
for i := range seg1 {
if *seg1[i] != *seg2[i] {
return false
}
}
return true
}
// commit3 emits the pending loopWs as loopHs: lws[:n2] as a non-repeating
// group (steps zero) and, if a tie was found, lws[n2:] as a repeating
// group using the recorded fromStep/segStep/segPast. The loopWs' fromH,
// segFirst and segPast are rebased onto their loopH.
commit3 := func() {
n1 := len(lws)
if n1 == 0 {
return
}
n2 := tie
if n2 == -1 {
n2 = n1
}
if n2 > 0 {
lh := &loopH{
fromH: lws[0].fromH,
fromStep: 0,
segFirst: lws[0].segFirst,
segStep: 0,
segPast: lws[n2-1].segPast,
lws: make([]*loopW, n2),
}
for i, lw := range lws[:n2] {
lw.fromH -= lh.fromH
lw.segFirst -= lh.segFirst
lw.segPast -= lh.segFirst
lh.lws[i] = lw
}
segs.lhs = append(
segs.lhs, lh,
)
}
if n2 == n1 {
return
}
lh := &loopH{
fromH: lws[n2].fromH,
fromStep: fromStep,
segFirst: lws[n2].segFirst,
segStep: segStep,
segPast: segPast,
lws: make([]*loopW, n1-n2),
}
for i, lw := range lws[n2:] {
lw.fromH -= lh.fromH
lw.segFirst -= lh.segFirst
lw.segPast -= lh.segFirst
lh.lws[i] = lw
}
segs.lhs = append(
segs.lhs, lh,
)
}
// commit2 finishes the current run lw1. It either recognizes lw1 as a
// repetition of an already recorded loopW (absorbing it into the repeating
// group) or records a copy of it in lws. With flush set, commit3 is called
// before returning.
commit2 := func(flush bool) {
n := len(lw1.blks)
if n == 0 {
if flush {
commit3()
}
return
}
// match: same number of segments and the same block pattern as lw1.
match := func(lw2 *loopW) bool {
var (
iters1 = lw1.segPast - lw1.segFirst
iters2 = lw2.segPast - lw2.segFirst
)
if iters1 != iters2 {
return false
}
return equal(lw1.blks, lw2.blks)
}
if tie == -1 {
// No repetition yet: look for an earlier loopW with the same fromW.
if i, ok := idx[lw1.fromW]; ok {
lw2 := lws[i]
if match(lw2) {
tie = i
fromStep = lw1.fromH - lw2.fromH
segStep = lw1.segFirst - lw2.segFirst
segPast = lw1.segPast
at = i
if flush {
commit3()
}
return
}
}
} else {
// Inside a repetition: advance the cursor cyclically through lws[tie:].
if at++; at == len(lws) {
at = tie
}
if match(lws[at]) {
segPast = lw1.segPast
if flush {
commit3()
}
return
}
// The repetition broke: emit what was collected and start over.
commit3()
idx = make(map[int]int)
lws = lws[:0]
tie = -1
}
// Record a deep copy of lw1, because lw1 and its blocks are reused.
lw2 := lw1
lw2.blks = make([]*block, n)
for i, blk1 := range lw1.blks {
blk2 := *blk1
lw2.blks[i] = &blk2
}
idx[lw2.fromW] = len(lws)
lws = append(lws, &lw2)
if flush {
commit3()
}
}
// commit1 turns the blocks in blks into the next segment. Block positions
// are rebased onto the first block. If the segment continues the current
// run (same row, same block pattern) lw1 is extended; otherwise lw1 is
// passed to commit2 and a new run is started from this segment. With
// flush set, commit2(true) is called before returning.
commit1 := func(flush bool) {
n := len(blks)
if n == 0 {
if flush {
commit2(true)
}
return
}
i := segs.cnt
segs.cnt = i + 1
var (
h = blks[0].fromH
w = blks[0].fromW
)
for _, blk := range blks {
blk.fromH -= h
blk.fromW -= w
}
if len(lw1.blks) > 0 {
if lw1.fromH == h &&
equal(lw1.blks, blks) {
// The step is fixed by the second segment of the run.
if lw1.segFirst == i-1 {
lw1.fromStep = w - lw1.fromW
}
lw1.segPast = i + 1
if flush {
commit2(true)
}
return
}
commit2(false)
}
lw1.fromH = h
lw1.fromW = w
lw1.fromStep = 0
lw1.segFirst = i
lw1.segPast = i + 1
// Copy block values into lw1's preallocated blocks (see layer3).
lw1.blks = lw1.blks[:n]
for j, blk := range blks {
*lw1.blks[j] = *blk
}
if flush {
commit2(true)
}
}
// layer5 enumerates block origins in steps of 6 over the padded plane
// (rows while h+3 <= h3, columns while w+3 <= w3) and fills in each
// block. When blks is full, the segment is committed and refilled from
// index 0.
layer5 := func() {
var (
// h1/w1: end of leading padding; h2/w2: end of data;
// h3/w3: end of trailing padding.
h1 = spec.PaddingH
h2 = h1 + spec.From.Height
h3 = h2 + spec.PaddingH
w1 = spec.PaddingW
w2 = w1 + spec.From.Width
w3 = w2 + spec.PaddingW
)
for h := 0; h+3 <= h3; h += 6 {
for w := 0; w+3 <= w3; w += 6 {
i := len(blks)
if i == cap(blks) {
commit1(false)
i = 0
}
blks = blks[:i+1]
blk := blks[i]
blk.fromH = h
blk.fromW = w
// Leading padding and data extents within an 8-wide window.
blk.padH = min(max(h1-h, 0), 8)
blk.padW = min(max(w1-w, 0), 8)
blk.datH = min(max(h2-h, 0), 8) - blk.padH
blk.datW = min(max(w2-w, 0), 8) - blk.padW
blk.yieldH = min(h3-h-2, 6)
blk.yieldW = min(w3-w-2, 6)
// A block without data in either axis is normalized to all
// padding, so such blocks compare equal.
if blk.datH == 0 || blk.datW == 0 {
blk.padH = 8
blk.padW = 8
blk.datH = 0
blk.datW = 0
}
}
}
commit1(true)
}
// layer4 initializes the loopW bookkeeping.
layer4 := func() {
idx = make(map[int]int)
tie = -1
layer5()
}
// layer3 preallocates segBlks blocks for lw1 and resets its length to 0.
layer3 := func() {
lw1.blks = make([]*block, segBlks)
for i := range lw1.blks {
lw1.blks[i] = new(block)
}
lw1.blks = lw1.blks[:0]
layer4()
}
// layer2 preallocates segBlks blocks for blks and resets its length to 0,
// leaving its capacity at segBlks.
layer2 := func() {
blks = make([]*block, segBlks)
for i := range blks {
blks[i] = new(block)
}
blks = blks[:0]
layer3()
}
// layer1 returns a cached result when the same signature was seen before;
// otherwise it registers &segs in the cache and builds it.
layer1 := func() *segments {
sig := fmt.Sprint(
"newSegments",
" ",
spec.From.Height,
spec.From.Width,
spec.PaddingH,
spec.PaddingW,
segBlks,
)
if prior, ok := ctx.dedup[sig]; ok {
return prior.(*segments)
}
ctx.dedup[sig] = &segs
layer2()
return &segs
}
return layer1()
}

// layout holds the sizes and counts computed by newLayout for one Spec.
// Field prefixes group the values: bf, wf, df and sf each have their own
// per-element byte size (bfBytes, wfBytes, dfBytes, sfBytes) from which the
// larger byte totals are derived. In newLayout the "2" variant of a pair is
// derived from a remainder or tail (for example wfSliceWfs2 is
// toChans % wfSliceWfs1 and slices2 is fromChans % slices1), while the "1"
// variant is the full-size case. Two-digit suffixes on CoreBytes/SumBytes
// fields combine two such choices.
// NOTE(review): the expansions of bf/wf/df/sf are not stated in the visible
// code — confirm before documenting them further.
type layout struct {
segs *segments
blkFrags int
fromChans int
toChans int
slices1 int
slices2 int
epochs1 int
epochs2 int
alignment int
biasBytes int
bfBytes int
bfEpochBytes int
bfTotalBytes int
wtBytes int
wfBytes int
wfSliceWfs1 int
wfSliceWfs2 int
wfSliceBytes1 int
wfSliceBytes2 int
wfCores1 int
wfCores2 int
wfCoreBytes11 int
wfCoreBytes12 int
wfCoreBytes21 int
wfCoreBytes22 int
wfFragBytes1 int
wfFragBytes2 int
wfGroupBytes1 int
wfGroupBytes2 int
wfEpochBytes1 int
wfEpochBytes2 int
wfTotalBytes int
datBytes int
dfBytes int
dfSliceDfs1 int
dfSliceDfs2 int
dfSliceBytes1 int
dfSliceBytes2 int
dfCores1 int
dfCores2 int
dfCoreBytes11 int
dfCoreBytes12 int
dfCoreBytes21 int
dfCoreBytes22 int
dfFragBytes1 int
dfFragBytes2 int
dfGroupBytes1 int
dfGroupBytes2 int
dfEpochBytes1 int
dfEpochBytes2 int
dfTotalBytes int
sfBytes int
sfSumBytes11 int
sfSumBytes12 int
sfSumBytes21 int
sfSumBytes22 int
sfCoreBytes1 int
sfCoreBytes2 int
sfFragBytes int
sfGroupBytes int
sfTotalBytes int
}

// newLayout computes the layout for spec on ctx's platform. The
// computation is a chain of closures, each filling in some fields and
// calling the next: layer1 (platform constants) -> layer2 (channels per
// group) -> layer3 (wf slices/cores) -> layer4 (segments and df
// slices/cores) -> layer5 (sf sizes) -> layer6 (slices and epochs from the
// cache sizes) -> layer7 (bf sizes) -> layer8 (wf sizes) -> layer9
// (df sizes). It panics with "bug" for an unsupported platform or when
// spec has both multiple filter sets and multiple groups.
func newLayout(ctx *Ctx, spec *Spec) *layout {
var (
y layout
)
// pad rounds n up to a multiple of y.alignment; the mask -y.alignment
// requires alignment to be a power of two (layer1 sets 64).
pad := func(n int) int {
n += y.alignment - 1
n &= -y.alignment
return n
}
// layer9: df byte sizes per core, fragment, group and epoch, and the total.
layer9 := func() *layout {
y.dfCoreBytes11 = y.slices1 * y.dfSliceBytes1
y.dfCoreBytes12 = y.slices1 * y.dfSliceBytes2
y.dfCoreBytes21 = y.slices2 * y.dfSliceBytes1
y.dfCoreBytes22 = y.slices2 * y.dfSliceBytes2
y.dfFragBytes1 = y.dfCores1 * y.dfCoreBytes11
y.dfFragBytes2 = y.dfCores1 * y.dfCoreBytes21
// A partial trailing core exists when dfCores1 < dfCores2.
if y.dfCores1 < y.dfCores2 {
y.dfFragBytes1 += y.dfCoreBytes12
y.dfFragBytes2 += y.dfCoreBytes22
}
y.dfGroupBytes1 = y.blkFrags * y.dfFragBytes1
y.dfGroupBytes2 = y.blkFrags * y.dfFragBytes2
y.dfEpochBytes1 = spec.Groups * y.dfGroupBytes1
y.dfEpochBytes2 = spec.Groups * y.dfGroupBytes2
y.dfTotalBytes = y.epochs1 * y.dfEpochBytes1
// A trailing epoch of slices2 slices exists when epochs1 < epochs2.
if y.epochs1 < y.epochs2 {
y.dfTotalBytes += y.dfEpochBytes2
}
return &y
}
// layer8: wf byte sizes, like layer9 but with each core size padded to
// the alignment.
layer8 := func() *layout {
y.wfCoreBytes11 = pad(y.slices1 * y.wfSliceBytes1)
y.wfCoreBytes12 = pad(y.slices1 * y.wfSliceBytes2)
y.wfCoreBytes21 = pad(y.slices2 * y.wfSliceBytes1)
y.wfCoreBytes22 = pad(y.slices2 * y.wfSliceBytes2)
y.wfFragBytes1 = y.wfCores1 * y.wfCoreBytes11
y.wfFragBytes2 = y.wfCores1 * y.wfCoreBytes21
if y.wfCores1 < y.wfCores2 {
y.wfFragBytes1 += y.wfCoreBytes12
y.wfFragBytes2 += y.wfCoreBytes22
}
y.wfGroupBytes1 = y.blkFrags * y.wfFragBytes1
y.wfGroupBytes2 = y.blkFrags * y.wfFragBytes2
y.wfEpochBytes1 = spec.Groups * y.wfGroupBytes1
y.wfEpochBytes2 = spec.Groups * y.wfGroupBytes2
y.wfTotalBytes = y.epochs1 * y.wfEpochBytes1
if y.epochs1 < y.epochs2 {
y.wfTotalBytes += y.wfEpochBytes2
}
return layer9()
}
// layer7: bf bytes per epoch (one bfBytes entry per output channel per
// group) and the padded total over epochs2 epochs.
layer7 := func() *layout {
y.bfEpochBytes = spec.Groups * y.toChans * y.bfBytes
y.bfTotalBytes = pad(y.epochs2 * y.bfEpochBytes)
return layer8()
}
// layer6: choose slices1 (input channels per epoch) from the cache sizes
// and split fromChans into epochs.
layer6 := func() *layout {
// Use the full-size slice bytes unless there are no full-size cores.
wfSliceBytes := y.wfSliceBytes1
if y.wfCores1 == 0 {
wfSliceBytes = y.wfSliceBytes2
}
dfSliceBytes := y.dfSliceBytes1
if y.dfCores1 == 0 {
dfSliceBytes = y.dfSliceBytes2
}
switch ctx.platform {
case raw.AVX512Float32:
var (
sliceBytes = 2*wfSliceBytes + dfSliceBytes
cacheBytes = ctx.cacheBytes1 + ctx.cacheBytes2
)
const (
empirical1 = 4
empirical2 = 256
empirical3 = 4
)
// slices1 = cacheBytes / empirical1 / sliceBytes, but at least empirical2.
y.slices1 = cacheBytes / empirical1 / sliceBytes
y.slices1 = max(y.slices1, empirical2)
y.slices2 = y.fromChans % y.slices1
y.epochs1 = y.fromChans / y.slices1
y.epochs2 = y.epochs1 + btoi(y.slices2 > 0)
// If the trailing epoch is small (slices2*empirical3 < slices1), merge
// it with the last full epoch instead of keeping it separate.
if y.epochs1 > 0 && y.epochs1 < y.epochs2 {
if y.slices2*empirical3 < y.slices1 {
y.slices2 += y.slices1
y.epochs1--
y.epochs2--
}
}
default:
panic("bug")
}
return layer7()
}
// layer5: sf byte sizes for each df/wf slice-size combination, per core,
// fragment and group, and the total over all groups.
layer5 := func() *layout {
var (
sums11 = y.dfSliceDfs1 * y.wfSliceWfs1
sums12 = y.dfSliceDfs1 * y.wfSliceWfs2
sums21 = y.dfSliceDfs2 * y.wfSliceWfs1
sums22 = y.dfSliceDfs2 * y.wfSliceWfs2
)
y.sfSumBytes11 = sums11 * y.sfBytes
y.sfSumBytes12 = sums12 * y.sfBytes
y.sfSumBytes21 = sums21 * y.sfBytes
y.sfSumBytes22 = sums22 * y.sfBytes
y.sfCoreBytes1 = y.wfCores1 * y.sfSumBytes11
y.sfCoreBytes2 = y.wfCores1 * y.sfSumBytes21
if y.wfCores1 < y.wfCores2 {
y.sfCoreBytes1 += y.sfSumBytes12
y.sfCoreBytes2 += y.sfSumBytes22
}
y.sfFragBytes = y.dfCores1 * y.sfCoreBytes1
if y.dfCores1 < y.dfCores2 {
y.sfFragBytes += y.sfCoreBytes2
}
y.sfGroupBytes = y.blkFrags * y.sfFragBytes
y.sfTotalBytes = spec.Groups * y.sfGroupBytes
return layer6()
}
// layer4: build the segments with dfSliceDfs1 blocks per segment. The
// block count of the last loopW of the last loopH gives dfSliceDfs2
// (zero when it equals dfSliceDfs1).
layer4 := func() *layout {
y.segs = newSegments(ctx, spec, y.dfSliceDfs1)
var (
lh = y.segs.lhs[len(y.segs.lhs)-1]
lw = lh.lws[len(lh.lws)-1]
)
y.dfSliceDfs2 = len(lw.blks)
if y.dfSliceDfs2 == y.dfSliceDfs1 {
y.dfSliceDfs2 = 0
}
y.dfSliceBytes1 = y.dfSliceDfs1 * y.dfBytes
y.dfSliceBytes2 = y.dfSliceDfs2 * y.dfBytes
// One df core per segment; the last one is partial if dfSliceDfs2 > 0.
y.dfCores1 = y.segs.cnt - btoi(y.dfSliceDfs2 > 0)
y.dfCores2 = y.segs.cnt
return layer5()
}
// layer3: split toChans into wf cores of wfSliceWfs1 each, plus an
// optional partial core of wfSliceWfs2.
layer3 := func() *layout {
y.wfSliceWfs2 = y.toChans % y.wfSliceWfs1
y.wfSliceBytes1 = y.wfSliceWfs1 * y.wfBytes
y.wfSliceBytes2 = y.wfSliceWfs2 * y.wfBytes
y.wfCores1 = y.toChans / y.wfSliceWfs1
y.wfCores2 = y.wfCores1 + btoi(y.wfSliceWfs2 > 0)
return layer4()
}
// layer2: input and output channels per group. Multiple filter sets
// combined with multiple groups is rejected.
layer2 := func() *layout {
if len(spec.Filts) > 1 && spec.Groups > 1 {
panic("bug")
}
filts := 0
for i := range spec.Filts {
filts += spec.Filts[i].Cnt
}
y.fromChans = spec.From.Chans / spec.Groups
y.toChans = filts / spec.Groups
return layer3()
}
// layer1: per-platform constants.
layer1 := func() *layout {
switch ctx.platform {
case raw.AVX512Float32:
y.blkFrags = 4
y.alignment = 64
y.biasBytes = 4
y.bfBytes = 4
y.wtBytes = 4
y.wfBytes = 32
y.wfSliceWfs1 = 4
y.datBytes = 4
y.dfBytes = 64
y.dfSliceDfs1 = 6
y.sfBytes = 64
default:
panic("bug")
}
return layer2()
}
return layer1()
}

// ArrangeFilts is the exported entry point for this generator. Prep computes
// the layout and creates (or reuses) the generated function, Bytes reports
// the sum of the layout's bf and wf totals, and Append emits the call.
type ArrangeFilts struct {
*Ctx
*Spec
// Team is passed as the first argument of the call emitted by Append.
Team cgen.Gen
// Tensors become the initializers of the pointer array that Append
// declares and passes as the second argument of the call.
Tensors []cgen.Gen
// layout is assigned by Prep from newLayout(Ctx, Spec).
*layout
// callerName is the name of the generated function that Append calls.
// Prep either allocates a fresh name or reuses the one recorded in dedup
// for an identical signature.
callerName string
}

// Prep computes the layout for this spec and returns the generator of the
// arrangement function. When a function with the same signature string is
// already recorded in dedup, its name is reused and nil is returned so that
// nothing is emitted twice.
func (a *ArrangeFilts) Prep() cgen.Gen {
	const affix = "ArrangeFilts"
	a.layout = newLayout(a.Ctx, a.Spec)
	key := fmt.Sprint(affix, " ", a.Spec)
	prior, seen := a.dedup[key]
	if seen {
		a.callerName = prior.(string)
		return nil
	}
	caller := a.name(a.prefix + affix)
	a.callerName = caller
	a.dedup[key] = caller
	return cgen.Gens{
		&arrangeFilts{ArrangeFilts: a},
		cgen.Newline,
	}
}

// Bytes returns the sum of the layout's bfTotalBytes and wfTotalBytes.
// The layout must have been set (see Prep) before this is called.
func (a *ArrangeFilts) Bytes() int {
	total := a.bfTotalBytes
	total += a.wfTotalBytes
	return total
}

// Append emits two statements onto to: the declaration of a pointer array
// initialized from a.Tensors, and a call of the function named by
// a.callerName with a.Team and that array as arguments.
func (a *ArrangeFilts) Append(to []byte) []byte {
	arr := vb(a.name("tensors"))
	decl := cgen.Var{
		Type: cgen.PtrChar,
		What: cgen.Elem{Arr: arr},
		Init: cgen.Brace{Inner: cgen.CommaLines(a.Tensors)},
	}
	call := cgen.Call{
		Func: vb(a.callerName),
		Args: cgen.CommaSpaced{a.Team, arr},
	}
	return cgen.Stmts{decl, call}.Append(to)
}

// arrangeFilts holds the working state used while emitting the function that
// ArrangeFilts.Prep registered. Its Append method assigns bundleFilts, the
// bundle/group tiling fields and calleeName; the remaining fields are not
// assigned in the code visible here and are presumably filled in by
// calleeFunc and the helpers it calls.
type arrangeFilts struct {
*ArrangeFilts
// bundleFilts is set per platform in Append and is the divisor used when
// counting bundles.
bundleFilts int
// bundleTile is the regular tile size along the bundle axis, bundleTiles
// the number of tiles of that size, bundleScrap the size of the final
// enlarged tile (zero when the split is exact), and bundleHull the total
// tile count handed to threader.Do.
bundleTile int
bundleTiles int
bundleScrap int
bundleHull int
// groupTile, groupTiles, groupScrap and groupHull are the same four
// quantities along the group axis.
groupTile int
groupTiles int
groupScrap int
groupHull int
// calleeName is derived from callerName with a "Callee" suffix.
calleeName string
tensors cgen.Gen
bundleCoord cgen.Gen
groupCoord cgen.Gen
epochCoord cgen.Gen
slices int
coreBytes int
fragBytes int
groupBytes int
epochFirst int
epochCnt int
bfPtr cgen.Gen
wfPtr cgen.Gen
filtsIdx int
wtPtr cgen.Gen
biasPtr cgen.Gen
bnPtrs []cgen.Gen
groupIdx cgen.Gen
bundleIdx cgen.Gen
bundleLast cgen.Gen
baseFilt int
baseBundle int
filts1 int
filts2 int
coreIdx cgen.Gen
coreCut cgen.Gen
}

// Append chooses the bundle and group tiling for the current platform, then
// emits the callee followed by the static caller function, whose body hands
// the callee to threader.Do with a hull of bundle tiles, group tiles and
// epochs2.
func (a *arrangeFilts) Append(to []byte) []byte {
	var threadBlks int
	switch a.platform {
	case raw.AVX512Float32:
		a.bundleFilts = 4
		threadBlks = 512
	default:
		panic("bug")
	}
	// Bundles per group: with exactly one filter set the count comes from
	// toChans; otherwise each filter set is rounded up on its own.
	groupBundles := 0
	if len(a.Filts) == 1 {
		groupBundles = ceilQuo(a.toChans, a.bundleFilts)
	} else {
		for i := range a.Filts {
			groupBundles += ceilQuo(a.Filts[i].Cnt, a.bundleFilts)
		}
	}
	filtBlks := ceilQuo(a.fromChans, a.epochs2)
	bundleBlks := a.bundleFilts * filtBlks
	groupBlks := a.toChans * filtBlks
	// split partitions total items into tiles of roughly want items each.
	// It returns the regular tile size, the number of regular tiles, the
	// size of the final enlarged tile (zero when the division is exact)
	// and the overall tile count.
	split := func(total, want int) (int, int, int, int) {
		hull := max(total/want, 1)
		tile := total / hull
		tiles := hull
		scrap := total - hull*tile
		if scrap > 0 {
			tiles--
			scrap += tile
		}
		return tile, tiles, scrap, hull
	}
	switch {
	case threadBlks <= bundleBlks:
		// A single bundle already meets the threshold: one task per
		// bundle and per group.
		a.bundleTile, a.bundleTiles = 1, groupBundles
		a.bundleScrap, a.bundleHull = 0, groupBundles
		a.groupTile, a.groupTiles = 1, a.Groups
		a.groupScrap, a.groupHull = 0, a.Groups
	case threadBlks <= groupBlks:
		// A whole group meets the threshold: tile the bundles inside
		// each group.
		want := ceilQuo(threadBlks, bundleBlks)
		a.bundleTile, a.bundleTiles, a.bundleScrap, a.bundleHull = split(groupBundles, want)
		a.groupTile, a.groupTiles = 1, a.Groups
		a.groupScrap, a.groupHull = 0, a.Groups
	default:
		// Even a whole group is below the threshold: keep every bundle
		// of a group together and tile across groups.
		a.bundleTile, a.bundleTiles = groupBundles, 1
		a.bundleScrap, a.bundleHull = 0, 1
		want := ceilQuo(threadBlks, groupBlks)
		a.groupTile, a.groupTiles, a.groupScrap, a.groupHull = split(a.Groups, want)
	}
	a.calleeName = a.name(a.callerName + "Callee")
	team := vb(a.name("team"))
	tensors := vb(a.name("tensors"))
	// The callee is generated before the caller definition is built,
	// matching the order in which the two are emitted.
	callee := a.calleeFunc()
	params := cgen.CommaSpaced{
		cgen.Param{
			Type: a.tc.PtrTeam,
			What: team,
		},
		cgen.Param{
			Type: cgen.PtrPtrChar,
			What: tensors,
		},
	}
	body := &threader.Do{
		Ctx:    a.tc,
		Callee: vb(a.calleeName),
		Any:    tensors,
		Hull: []cgen.Gen{
			il(a.bundleHull),
			il(a.groupHull),
			il(a.epochs2),
		},
		Team: team,
	}
	caller := cgen.StaticFuncDef{
		ReturnType: cgen.Void,
		Name:       a.callerName,
		Params:     params,
		Body:       body,
	}
	return cgen.Gens{
		callee,
		cgen.Newline,
		caller,
	}.Append(to)
}

func (a *arrangeFilts) calleeFunc() cgen.Gen {
callee := &threader.Callee{
Ctx: a.tc,
Name: a.calleeName,
Task: vb(a.name("task