DeepMind's recent Neural Arithmetic Logic Unit (NALU) is a very neat idea: a simple module that gives neural nets a sense of numeracy. Contrary to popular belief, neural nets are not very good at arithmetic or counting (if at all). If you train a network to add numbers between 0 and 10, it will do okay on 3 + 5 but won't extrapolate: it will fail miserably on 1000 + 3000. Similarly, a net trained to count up to 10 won't be able to count to 20. The NALU is able to track time, perform arithmetic, translate numerical language into scalars, execute computer code, and count objects in images.

The central idea behind the Neural ALU is a differentiable weight that converges to 0, -1 or 1, rendering the concept of addition and subtraction trainable. The beauty also lies in the simplicity of the formula: `tanh(m) * sigmoid(w)`, made of fundamental building blocks. If you think about it: tanh saturates to -1 or 1 and sigmoid saturates to 0 or 1, so the product of the two tends toward one of -1, 0, or 1.
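
To see the saturation concretely, here is a tiny sketch (plain Python; the sample points are mine, purely for illustration):

```
import math

def weight(m, w):
    # the NALU weight: tanh saturates to -1/1, sigmoid to 0/1
    return math.tanh(m) * (1 / (1 + math.exp(-w)))

# large |m| and |w| push the product toward -1, 0 or 1
for m, w in [(10.0, 10.0), (-10.0, 10.0), (10.0, -10.0), (0.0, 10.0)]:
    print(m, w, round(weight(m, w), 4))   # -> 1.0, -1.0, 0.0, 0.0
```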

(Figure: the plot of `tanh(m) * sigmoid(w)`.)

(Figure: the NAC architecture.)
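
In the paper's notation, the NAC (Neural Accumulator) is a linear layer `a = Wx` whose weight matrix is built from two learnable matrices through that same product: `W = tanh(W_hat) * sigmoid(M_hat)`. A minimal NumPy sketch (names follow the paper; an illustration, not a reference implementation):

```
import numpy as np

def nac(x, W_hat, M_hat):
    # effective weights are driven toward -1, 0 or 1, so the layer learns
    # signed selections of its inputs instead of arbitrary scalings
    W = np.tanh(W_hat) * (1 / (1 + np.exp(-M_hat)))
    return W @ x
```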

The NALU is just the NAC cast into log space and back (so that addition becomes multiplication), blended with the plain NAC through a learned gate:

(Figure: the NALU architecture.)
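
From the paper, the gate is `g = sigmoid(Gx)` and the multiplicative path applies the NAC weights to `log(|x| + eps)` before exponentiating back. A hedged NumPy sketch of the whole cell (again following the paper's names):

```
import numpy as np

def nalu(x, W_hat, M_hat, G, eps=1e-7):
    sigmoid = lambda z: 1 / (1 + np.exp(-z))
    W = np.tanh(W_hat) * sigmoid(M_hat)       # NAC weights, pushed toward {-1, 0, 1}
    a = W @ x                                 # additive path: the plain NAC
    m = np.exp(W @ np.log(np.abs(x) + eps))   # multiplicative path: NAC in log space
    g = sigmoid(G @ x)                        # learned gate
    return g * a + (1 - g) * m                # blend the two paths
```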

The cost function is

```
J = 0.5 * (y_hat - y) ** 2
```

where `y_hat = x0 * tanh(m_0) * sigmoid(w_0) + x1 * tanh(m_1) * sigmoid(w_1)`, so the partial derivative `dJ/dm_0` is

```
dJ/dm_0 = (y_hat - y) * dy_hat/dm_0
        = (y_hat - y) * d(x0 * tanh(m_0) * sigmoid(w_0))/dm_0
        = (y_hat - y) * x0 * dtanh(m_0) * sigmoid(w_0)
```

with `dtanh(x) = 1 - tanh(x) ** 2`.
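
It is easy to sanity-check this derivative with a central finite difference; a quick sketch (the constants are arbitrary):

```
import math

def sigmoid(x):
    return 1 / (1 + math.exp(-x))

def y_hat(m0, w0, x0):
    return x0 * math.tanh(m0) * sigmoid(w0)

m0, w0, x0, y = 0.3, -0.2, 0.7, 0.5

# analytic: (y_hat - y) * x0 * dtanh(m0) * sigmoid(w0)
analytic = (y_hat(m0, w0, x0) - y) * x0 * (1 - math.tanh(m0) ** 2) * sigmoid(w0)

# numeric: central finite difference of J = 0.5 * (y_hat - y) ** 2
h = 1e-6
Jp = 0.5 * (y_hat(m0 + h, w0, x0) - y) ** 2
Jm = 0.5 * (y_hat(m0 - h, w0, x0) - y) ** 2
print(analytic, (Jp - Jm) / (2 * h))  # the two should agree to several decimal places
```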

Here is a toy NALU implemented in x86-64 (with SSE!) that runs on a real Intel FPU's ALUs.

```
; Neural ALU implementation in x86_64
;
; 	nasm -felf64 nalu.s
; 	gcc -no-pie nalu.o -o nalu -g
;

%define USE_SUB 1
%define EPOCH 1_000_000

global main
extern printf

section .data
first_fmt: db "first weight: %f, ", 0
second_fmt: db "second weight: %f", 0xA, 0

rand_seed: dd 1
rand_max: dd -2147483648     ; -2^31

section .bss
result: resq 2              ; reserve 2 qwords (scratch)
PRN: resq 2

w_hats: resq 2
m_hats: resq 2

xs: resd 2
tanhs: resd 2
sigms: resd 2

tmp1: resq 2
tmp2: resq 2

weights: resq 2             ; 16 bytes: the movdqu below stores a full xmm register
err: resq 2

section .text

main:
sub rsp, 8                  ; align the stack once for the printf calls below

mov ebx, EPOCH

.calc:
cmp ebx, 0
je .exit
dec ebx

.gen_inputs:
call rand
fstp dword [xs]
call rand
fstp dword [xs+4]

.tanhs_and_sigmoids:
;; first calculate tanhs and put those in tanhs
finit
fld dword [m_hats]
call tanh
fstp dword [tanhs]
finit
fld dword [m_hats+4]
call tanh
fstp dword [tanhs+4]

;; calculate sigmoids
finit
fld dword [w_hats]
call sigmoid
fstp dword [sigms]
finit
fld dword [w_hats+4]
call sigmoid
fstp dword [sigms+4]

.forward_pass:
movdqu xmm0, [tanhs]        ; move 128 bits
movdqu xmm1, [sigms]
movq xmm2, [xs]             ; move 64 bits

mulps xmm0, xmm1            ; tanh * sigmoid

movdqu [weights], xmm0

mulps xmm0, xmm2            ; tanh * sigmoid * xs
haddps xmm0, xmm0           ; y_hat = l0 * x0 + l1 * x1 (lane 0)

%if USE_SUB
hsubps xmm2, xmm2           ; y = x0 - x1
hsubps xmm2, xmm2
%else
haddps xmm2, xmm2           ; y = x0 + x1
%endif

.calc_error:
subps xmm0, xmm2            ; xmm0 <- y_hat - y
extractps eax, xmm0, 0      ; lane 0 holds y_hat - y
mov [err], eax

.backpropagate:

finit
;; m[0] -= err * x0 * sigm0 * dtanh(m[0]);
fld dword [m_hats]          ; dtanh(m0)
call dtanh
fld dword [xs]              ; x0
fmul
fld dword [err]             ; err
fmul
fld dword [sigms]           ; sigm0
fmul
fld dword [m_hats]          ; dtanh(m0)
fsubr
fstp dword [m_hats]

finit
;; m[1] -= err * x1 * sigm1 * dtanh(m[1]);
fld dword [m_hats+4]        ; dtanh(m1)
call dtanh
fld dword [xs+4]            ; x1
fmul
fld dword [err]             ; err
fmul
fld dword [sigms+4]         ; sigm1
fmul
fld dword [m_hats+4]        ; dtanh(m1)
fsubr
fstp dword [m_hats+4]

finit
;; w[0] -= err * x0 * dsigmoid(w[0]) * tanh0;
fld dword [w_hats]
call dsigmoid
fld dword [xs]
fmul
fld dword [err]
fmul
fld dword [tanhs]
fmul
fld dword [w_hats]
fsubr
fstp dword [w_hats]

finit
;; w[1] -= err * x1 * dsigmoid(w[1]) * tanh1;
fld dword [w_hats+4]
call dsigmoid
fld dword [xs+4]
fmul
fld dword [err]
fmul
fld dword [tanhs+4]
fmul
fld dword [w_hats+4]
fsubr
fstp dword [w_hats+4]

.print:
movd xmm0, [weights]        ; pass result to printf via xmm0
cvtps2pd xmm0, xmm0         ; convert float to double
mov rdi, first_fmt          ; printf format string
mov rax, 1                  ; one vector register is used for varargs
call printf                 ; call printf

movd xmm0, [weights+4]      ; pass result to printf via xmm0
cvtps2pd xmm0, xmm0         ; convert float to double
mov rdi, second_fmt         ; printf format string
mov rax, 1                  ; one vector register is used for varargs
call printf                 ; call printf

jmp .calc

.exit:
mov eax, 60                 ; sys_exit
xor edi, edi
syscall

tanh:                           ; (exp(x) - exp(-x)) / (exp(x) + exp(-x))
fst dword [tmp1]            ; tmp1 <- x
call exp                    ; exp(x)
fst dword [tmp2]            ; tmp2 <- exp(x)
fld dword [tmp1]
fchs
call exp
fst dword [tmp1]            ; tmp1 <- exp(-x)
fld dword [tmp2]
fsubr                       ; exp(x) - exp(-x)
fld dword [tmp2]            ; load exp(x) and exp(-x)
fld dword [tmp1]
fadd                        ; exp(x) + exp(-x)
fdiv                        ; tanh(x) = numerator / denominator
ret

dtanh:                          ; 1. - pow(tanh(x), 2.)
call tanh
fst dword [tmp1]            ; duplicate tanh on the stack
fld dword [tmp1]
fmul                        ; tanh(x) * tanh(x)
fld1
fsubr                       ; 1 - tanh(x) ** 2
ret

sigmoid:                        ; 1 / (1 + exp(-x))
fchs                        ; -x
call exp                    ; exp(-x)
fld1
fadd                        ; 1 + exp(-x)
fld1
fdivr                       ; 1 / (1 + exp(-x))
ret

dsigmoid:                       ; sigmoid(x) * (1. - sigmoid(x))
call sigmoid
fst dword [tmp1]            ; tmp1 <- sigmoid(x)
fchs
fld1
fadd                        ; 1 - sigmoid(x)
fld dword [tmp1]            ; st(0) <- sigmoid(x)
fmul
ret

exp:
fldl2e
fmulp st1,st0               ; st0 = x*log2(e) = tmp1
fld1
fscale                      ; st0 = 2^int(tmp1), st1=tmp1
fxch
fld1
fxch                        ; st0 = tmp1, st1=1, st2=2^int(tmp1)
fprem                       ; st0 = fract(tmp1) = tmp2
f2xm1                       ; st0 = 2^(tmp2) - 1 = tmp3
faddp st1,st0               ; st0 = tmp3+1, st1 = 2^int(tmp1)
fmulp st1,st0               ; st0 = 2^int(tmp1) * 2^fract(tmp1) = 2^(x*log2(e)) = exp(x)
ret

rand:
imul eax, dword [rand_seed], 16807 ; RandSeed *= 16807
mov dword [rand_seed], eax
fild dword [rand_seed]             ; load RandSeed as an integer
fidiv dword [rand_max]             ; div by max int value (absolute) = eax / (-2^31)
ret
```

If you run this, the first `tanh * sigmoid` product goes to 1 and the second one goes to -1.

```
Epoch           l0                  l1
0               0.0                 0.0
50000           0.987506901824      -0.987548950867
100000          0.991264033674      -0.991189817923
150000          0.992845113954      -0.992861588357
200000          0.993821244128      -0.993813140853
250000          0.994479531604      -0.994470005826
300000          0.994956870738      -0.994965214447
350000          0.995335580972      -0.995335751094
400000          0.995641550629      -0.995639510579
450000          0.99588903762       -0.995888041575
500000          0.996102719885      -0.996098271471
550000          0.996282859485      -0.996286010814
600000          0.996444518075      -0.996441767134
650000          0.996583070776      -0.996582158171
700000          0.996711963875      -0.99670336452
750000          0.996820796932      -0.996818826574
800000          0.996921023282      -0.9969240341
850000          0.997012684359      -0.997014549213
900000          0.997100144072      -0.997097107772
950000          0.997177851616      -0.99717492668
```

Here is the same NAC toy as a runnable example in Python:

```
from random import random
import math

def tanh(x):
    return math.tanh(x)

def dtanh(x):
    return 1. - math.tanh(x) ** 2

def sigmoid(x):
    return 1 / (1 + math.exp(-x))

def dsigmoid(x):
    return sigmoid(x) * (1 - sigmoid(x))

m0 = m1 = w0 = w1 = 0.0

for i in range(1000000):
    x0 = random()
    x1 = random()
    y = x0 - x1

    # forward pass
    l0 = tanh(m0) * sigmoid(w0)
    l1 = tanh(m1) * sigmoid(w1)
    y_h = l0 * x0 + l1 * x1

    # calculate error
    e = y_h - y

    # backpropagation
    m0 -= e * x0 * sigmoid(w0) * dtanh(m0)
    m1 -= e * x1 * sigmoid(w1) * dtanh(m1)
    w0 -= e * x0 * dsigmoid(w0) * tanh(m0)
    w1 -= e * x1 * dsigmoid(w1) * tanh(m1)

    if not i % 50000:
        print(i, l0, l1)
```

You should see the neural net converge almost immediately; by the first printout at epoch 50000 the weights are already close to 1 and -1.

# Source

• Neural Arithmetic Logic Units (NAC/NALU): arXiv:1808.00508 [cs.NE]

• image source