Hard Autograd for Algebraic Expressions

author:王子轩

Date: 2024-4-6

Chapter 1: Introduction

Background

The application of automatic differentiation technology in frameworks such as torch and tensorflow has greatly facilitated people’s implementation and training of deep learning algorithms based on backpropagation.

Aim

Now, we hope to implement an automatic differentiation program for algebraic expressions.

What we need to do

first build a binary tree according to the given expression, then differentiate it with a recursive algorithm, then use a simplifying function to process the differentiated tree, and finally output the result of the derivation with a middle-order traversal.

Why do this

binary tree data structure can effectively organize the relationship between the numbers and letters, using its non-linear properties can effectively represent the priority between different operators, and then recursive algorithms can effectively clear the head and ideas, and finally also through the middle order of the tree traversal for output.

Chapter 2: Algorithm Specification

Main Function

sketch

1
2
3
4
5
6
7
graph TD;
F(binary tree construction)
F-->l
l(lexicographical order by bubble sorting)
l --> G(Derivation of a binary tree)
G --> H(simplified expression)
H --> I(Middle order traversal output)

Pseudocode

1
2
3
4
5
6
7
8
9
10
11
12
node root_node = build(); // binary tree construction
let var[] in the lexicographical order; // bubble sorting
for each variable in var {
if variable has not appeared before {
node newroot = diff(root_node, variable); // Derivation on variables
print "Expressions before simplification" + variable + ":";
printInfix(newroot); // Output the result before simplification
simplify(newroot); // streamlining
print "Simplified expressions." + variable + ":";
printInfix(newroot); // Output the simplified result
}
}

Build :Used to build expression trees

sketch

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
graph TD
Start --> Initialize
Initialize --> ReadCharacters
ReadCharacters --> |Number| UpdateNumber
ReadCharacters --> |Operator| ProcessOperator
ReadCharacters --> |Special Character| HandleSpecialCases
UpdateNumber --> |Previous number| UpdateNumber
UpdateNumber --> |New number| CreateNode
CreateNode --> StoreNumber
ProcessOperator --> ProcessPriority
HandleSpecialCases --> |Mathematical function| PushOperator
HandleSpecialCases --> HandleCases
ProcessPriority --> |Higher priority| PushStack
ProcessPriority --> |Lower priority| PopBuildNodes
PopBuildNodes --> CheckPriority
CheckPriority --> |Satisfied| End
CheckPriority --> |Not satisfied| PopBuildNodes
End --> Stop

Pseudocode

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
build():
Initialize variables and arrays
Loop to read characters
If it is a number:
If the previous one is a number:
Update the number
Else:
Create a new node and store the number
If it is an operator:
Process nodes and operators based on priority
If it is a special character (e.g., parentheses):
Handle special cases
Return the constructed binary tree

process operator priority():
If the current operator has higher priority than the top of the stack:
Push onto the stack
Else:
Pop and build nodes until the priority is satisfied

handle special cases():
If it is a mathematical function:
Push it as an operator onto the stack
Else:
Handle special cases like parentheses

Diff :Used to differentiate functions

sketch

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
graph TD
A[Start] --> B{Root is NULL?}
B -- No --> C{Check root data}
C -- '+' or '-' --> D{Create new node for derivative}
D --> E[Recursively differentiate left subtree]
E --> F[Recursively differentiate right subtree]
F --> G[Attach derivatives to new node]
C -- '*' --> H{Create new node for derivative}
H --> I[Recursively differentiate left subtree]
I --> J[Recursively differentiate right subtree]
J --> K[Attach derivatives to new node]
C -- '/' --> L{Create new node for derivative}
L --> M[Recursively differentiate left subtree]
M --> N[Recursively differentiate right subtree]
N --> O[Attach derivatives to new node]
C -- '^' --> P{Create new node for derivative}
P --> Q[Calculate derivative using power rule]
Q --> R[Attach derivatives to new node]
C -- Default --> S{Check if constant or variable}
S -- Constant --> T[Create node with derivative 0]
S -- Variable --> U{Check if variable matches}
U -- Yes --> V[Create node with derivative 1]
U -- No --> W[Create node with derivative 0]
A --> B

Pseudocode

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
function diff(root, variable)
if root is NULL
return NULL
newRoot = NULL
switch (root.data)
case '+':
case '-':
Handle case '+' and '-'
case '*':
Handle case '*'
case '/':
Handle case '/'
case '^':
Handle case '^'
case -10:
Handle case "ln"
case -9:
Handle case "log"
case -8:
Handle case "cos"
case -7:
Handle case "sin"
case -6:
Handle case "tan"
case -5:
Handle case "pow"
case -4:
Handle case "exp"
default:
if root.data != 0
newRoot = createNode(0)
else
if root.Var == variable
newRoot = createNode(1)
else
newRoot = createNode(0)
return newRoot

Simplify:Used to simplify expressions

sketch

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
graph TD
B{Is the root node empty?}
B -- yes --> C[return]
B -- no --> D[Simplified left subtree]
D --> E[Simplified right subtree]
E --> F{Operator Type}
F -- addition operator --> G[Perform additive simplification]
F -- subtraction operator--> H[Implementation of subtractive simplification]
F -- multiplication operator --> I[Perform multiplicative simplification]
F -- division operator --> J[Perform division simplification]
F -- else --> K[Implementation of other simplifications]
G --> L[return]
H --> L
I --> L
J --> L
K --> L

Pseudocode

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
function simplify(root)
if root is NULL
return
simplify(root.left)
simplify(root.right)

switch (root.data)
case '+':
Handle case '+'
case '-':
Handle case '-'
case '*':
Handle case '*'
case '/':
Handle case '/'
case '^':
Handle case '^'
return

PrintInfix:Used to output mid-range expressions

sketch

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
graph TD
A[Start] --> B{root == NULL?}
B -- No --> C{first?}
C -- Yes --> D{qu = 1}
D --> E{qu?}
E -- Yes --> F(No output of brackets)
E -- No --> G{root->left != NULL or root->right != NULL?}
G -- Yes --> H{Print left bracket}
H --> I{Call printInfix root->left}
I --> J{Print root->Var or root->data}
J --> K{Call printInfix root->right}
K --> L{Print right bracket}
L --> M
G -- No --> M
C -- No --> M
B -- Yes --> M
M[return]

Pseudocode

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
if root is NULL
return
if first
qu = 1
first = 0
if qu
do nothing
else if root's left child is not NULL or right child is not NULL
print "("
Call printInfix with root's left child
if root's Var is not empty
print root's Var
else if root's data is an operator
print root's data
else if root's data is 'L'
print "ln"
else if root's data is less than 0
print dic_math[root's data + 10]
else
print root's data
Call printInfix with root's right child
if qu
do nothing
else if root's left child is not NULL or right child is not NULL
print ")"

Bonus

Not using complex containers provided in STL

done

Support for expressions containing mathematical functions

done

Simplify an algebraic expression to reduce the length of the result by applying at least two rules.

done

Rules used in simplification:
  1. addition:
    1. If an item is 0, delete it.
  2. subtractive:
    1. Delete the previous item if it is zero.
    2. Delete the minus sign and the last item if the last item is zero.
  3. subtraction:
    1. Returns 0 if one of the items is 0.
    2. If one item is 1, return the other.
  4. division :
    1. If the numerator is 0, return 0.
    2. Return the numerator if the denominator is 1
  5. calculate the square:
    1. If the index is 0, return 1

Chapter 3: Testing Results

test case aim Expected results Actual results
a+b^c*d Test addition multiplication multiplication compound operations a: 1
b: c*b^(c-1)*d
c: ln(b)*b^c*d
d: b^c
a:1
b:(((b^c)*(c/b))*d)
c:(((b^c)*(lnb))*d)
d:(b^c)
a*10*b+2^a/a Test addition multiplication multiplication division compound operations a: 10*b-2^a/a^2+2^a*ln(2)/a
b: a*10
a:(10*b)+(((((2^a)*(ln2))*a)-(2^a))/(a^2))
b:(a*10)
xx^2/xy*xy+a^a Testing arithmetic on multi-character variables a: a^a*(1+ln(a))
xx: 2*xx
xy: 0
a:(a^a)*((lna)+(a/a))
xx:(((((xx^2)*(2/xx))*xy)/(xy^2))*xy)
xy:((((-(xx^2))/(xy^2))*xy)+((xx^2)/xy))
x*ln(y) Testing the performance of the lnx function x:lny
y:x*(1/y)
x:lny
y:(x*(1/y))
x*ln(x*y)+y*cos(x)+y*sin(2*x) Testing the performance of lnx, trigonometric functions x:(ln(x*y)+(x*(y/(x*y))))+y*(1/(cosx)^2+y*2/(cos(2*x)^2)) y:x*x/(x*y)+cosx+y*sin(2*x)+y*sin(2*x) x:(ln(x*y)+(x*(y/(x*y))))+y*(1/(cosx)^2+y*2/(cos(2*x)^2)) y:x*x/(x*y)+cosx+y*sin(2*x)+y*sin(2*x)
log(a,b)/log(c,a) Testing the performance of the log function a:(-((log(a,b))*((((1/a)*(lnc))-((1/a)*(lna)))/((lnc)^2))))/((log(c,a))^2)
b:((((((1/b)*(lna))-((1/b)*(lnb)))/((lna)^2))*(log(c,a)))/((log(c,a))^2))
c:0
a:(-((log(a,b))*((((1/a)*(lnc))-((1/a)*(lna)))/((lnc)^2))))/((log(c,a))^2)
b:((((((1/b)*(lna))-((1/b)*(lnb)))/((lna)^2))*(log(c,a)))/((log(c,a))^2))
c:0
a^x^a^x^a Testing Multiplier Composite Extreme Sample Performance a:(a^(x^(a^(x^a))))*((((x^(a^(x^a)))*(((a^(x^a))*((((x^a)*(lnx))*(lna))+((x^a)/a)))*(lnx)))*(lna))+((x^(a^(x^a)))/a))

x:((a^(x^(a^(x^a))))*(((x^(a^(x^a)))*((((a^(x^a))*(((x^a)*(a/x))*(lna)))*(lnx))+((a^(x^a))/x)))*(lna)))
a:(a^(x^(a^(x^a))))*((((x^(a^(x^a)))*(((a^(x^a))*((((x^a)*(lnx))*(lna))+((x^a)/a)))*(lnx)))*(lna))+((x^(a^(x^a)))/a))

x:((a^(x^(a^(x^a))))*(((x^(a^(x^a)))*((((a^(x^a))*(((x^a)*(a/x))*(lna)))*(lnx))+((a^(x^a))/x)))*(lna)))
log(log(a,b),log(c,a)) Testing extreme sample performance of logarithmic functions a:(((((((1/a)*(lnc))-((1/a)*(lna)))/((lnc)^2))/(log(c,a)))*(ln(log(a,b))))-((((((1/a)*(lnc))-((1/a)*(lna)))/((lnc)^2))/(log(c,a)))*(ln(log(c,a)))))/((ln(log(a,b)))^2)
b:0
c:0
a:(((((((1/a)*(lnc))-((1/a)*(lna)))/((lnc)^2))/(log(c,a)))*(ln(log(a,b))))-((((((1/a)*(lnc))-((1/a)*(lna)))/((lnc)^2))/(log(c,a)))*(ln(log(c,a)))))/((ln(log(a,b)))^2)
b:0
c:0

Chapter 4: Analysis and Comments

Time Complexity Analysis:

  • The process of constructing a binary tree involves traversing the medial expression with a time complexity of O(n), where n is the length of the medial expression.

  • The process of derivation and simplification of different variables involves traversing the binary tree nodes with a time complexity of O(n), where n is the number of binary tree nodes.

  • There is no significant loop nesting in the algorithm and the overall time complexity is O(n).

Space complexity analysis:

  • The algorithm uses arrays and structures to store data structures such as midfix expressions, variables, operator stacks, node stacks, etc. The space complexity depends on the size of these data structures and is O(n).

  • Recursive calls to the derivation function may take up additional stack space, but since the recursion depth will not be large, the space complexity is negligible.

Possible further improvements:

  1. Consider optimizing the algorithm, e.g., using more efficient data structures to store the medial expression to reduce unnecessary memory usage.
  2. For complex mathematical functions, the processing logic can be further optimized to improve the efficiency of the algorithm.
  3. Consider introducing an error handling mechanism to increase the validation of input expressions to prevent the program from crashing or outputting incorrect results.
  4. More mathematical function support can be added to make the program more general.
  5. For the case of large-scale input, parallelization of processing can be considered to improve the running efficiency of the program.

Appendix: Source Code (in C)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
#include <stdio.h>
#include <stdlib.h>
#include <string.h> //Import related libraries
char var[100][100] = {'\0'}; // Used to record variables
char dic_math[10][10] = {"ln", "log", "cos", "sin", "tan", "pow", "exp"}; // Dictionary of math functions
int var_cnt = 0, var_len = 0; // Define variable counters and variable lengths
typedef struct node // Define the binary tree node structure
{
int data; // Store symbols and numbers
char Var[10]; // Store variables
struct node *left; // Pointer to the left node
struct node *right; // pointer to the right node
} BTnode; // Rename the type names of binary tree nodes for ease of writing
int first = 1;
// Related Functions
BTnode *createNode(int data); // Functions for new nodes
BTnode *build(); // Build the binary tree of the original medial expression.
BTnode *diff(BTnode *root, char *variable); // Functions used to differentiate binary trees
void printInfix(BTnode *root); // Functions used for middle-order traversal
void simplify(BTnode *root); // Functions used to simplify derivative expressions
int main() // Main function
{
BTnode *root_node = build(); // Build the binary tree such that its middle-order traversal results in the given medial expression
for (int i = 1; i < var_cnt;i++){//lexicographical order by bubble sorting
for(int j = 0; j < var_cnt - i; j++){
if(strcmp(var[j],var[j+1])>0){
char tmp[100] = {0};
strcpy(tmp,var[j+1]);
strcpy(var[j+1],var[j]);
strcpy(var[j],tmp);
}
}
}
for (int i = 0; i < var_cnt; i++) // Perform a traversal with separate derivatives for different variables
{
int flg = 1; // Use flg to note if the variable has appeared before
for (int j = 0; j < i; j++)
{
if (!strcmp(var[j], var[i]))
flg = 0;
}
if (flg)
{
BTnode *newroot = diff(root_node, var[i]); // Derivation
simplify(newroot); // Simplify
printf("(simplified)%s:", var[i]);
printInfix(newroot); // Output the simplified result
printf("\n");
}
}
return 0;
}

BTnode *createNode(int data)
{ // Used to create a new node for later calls and to prevent duplicate writing.
BTnode *newNode = (BTnode *)malloc(sizeof(BTnode));
newNode->data = data;
(newNode->Var)[0] = '\0'; // Initialize to empty string
newNode->left = NULL; // Initialize to null pointer
newNode->right = NULL; // Initialize to null pointer
return newNode; // Return the new pointer
}
BTnode *build() // Build the binary tree such that its middle-order traversal results in the given medial expression
{
char ope[100] = {'\0'}; // Use arrays to simulate a stack for storing operators.
char dic_ope[10] = {'^', '/', '*', '-', '+', ',', '('}; // Operator dictionary to define priorities
int num[100] = {0}; // Use arrays to simulate a stack for storing operators.
BTnode *node_stack = (BTnode *)malloc(sizeof(BTnode) * 500); // Open up a space for a stack that holds nodes
for (int i = 0; i < 500; i++) // Initialize the node stack to empty
{
(node_stack + i)->Var[0] = '\0';
}
int ope_len = 0, num_len = 0, node_len = 0; // Defines the length of the respective stacks, corresponding to the index that is the last bit of the top element.
char tmp; // Define a temporary variable to store the characters for each read
int pre_pos = 7; // Initialize a pre-position to store the position of the previous operator for comparing priorities
int pre_isnum = 0, pre_isvar = 0; // The type used to mark the previous one
int math = 0; // Used to mark where in the dictionary the identified math function is located
tmp = getchar(); // Take one out first, determine if it's a negative sign, and handle it separately
if (tmp == '-')
{
tmp = '0'; // If it is a negative sign, add a 0 before it to normalize it.
node_stack[node_len].data = tmp - '0'; // Store numbers
node_stack[node_len].Var[0] = 0; // Variable set to empty string
node_stack[node_len].left = NULL; // Left node pointer initialized to null pointer
node_stack[node_len].right = NULL; // Right node pointer initialized to null pointer
node_len++; // Length setback
pre_isnum = 1; // Mark the previous one as a number
pre_isvar = 0; // Mark the previous one as not a variable
var_len = 0; // Initialize variable length to 0
tmp = '-'; // Turn the variable back into a negative sign
}
do // Begin cycle of reading in characters
{
if ('0' <= tmp && tmp <= '9')
{
if (pre_isnum) // Determine if the previous one is a number
{
node_stack[node_len - 1].data = node_stack[node_len - 1].data * 10 + tmp - '0'; // Update the number if the previous one was a number
}
else
{
node_stack[node_len].data = tmp - '0'; // Store numbers
node_stack[node_len].Var[0] = 0; // Variable set to empty string
node_stack[node_len].left = NULL; // Left node pointer initialized to null pointer
node_stack[node_len].right = NULL; // Right node pointer initialized to null pointer
node_len++; // Length setback
pre_isnum = 1; // Mark the previous one as a number
pre_isvar = 0; // Mark the previous one as not a variable
var_len = 0; // Initialize variable length to 0
}
}
else if (tmp == '(' || tmp == ')' || tmp == '+' || tmp == '-' || tmp == '*' || tmp == '/' || tmp == '^' || tmp == ',')
{
if (tmp == ')')
{
if (math) // Separate handling of mathematical formulas
{
while (ope[ope_len - 1] != '(') // Start by working inside the parentheses of the math equation.
{
node_stack[node_len].data = ope[ope_len - 1];
node_stack[node_len].Var[0] = 0;
node_stack[node_len].left = (BTnode *)malloc(sizeof(BTnode)); // Allocate new memory space to the left and right child nodes
node_stack[node_len].left->Var[0] = '\0'; // Initialize the left node variable to the empty string
*node_stack[node_len].left = node_stack[node_len - 2]; // out of the stack as a left node
node_stack[node_len].right = (BTnode *)malloc(sizeof(BTnode));
node_stack[node_len].right->Var[0] = '\0'; // Initialize the right node's variable to the empty string
*node_stack[node_len].right = node_stack[node_len - 1]; // Out of the stack as a right node
node_stack[node_len - 2] = node_stack[node_len]; // Putting the combined new node back
node_len--; // Node stack length reduction
ope_len--; // Symbol stack length reduction
}
node_stack[node_len].data = ope[ope_len - 2]; // Substitute the math equation
node_stack[node_len].Var[0] = 0;
node_stack[node_len].left = NULL;
node_stack[node_len].right = (BTnode *)malloc(sizeof(BTnode));
*node_stack[node_len].right = node_stack[node_len - 1]; // Store the internals of the formula in the right node of the operator.
node_stack[node_len - 1] = node_stack[node_len];
ope_len -= 2; // Symbol stack out of two
pre_isnum = 0; // Record that the previous one was not a number
pre_isvar = 0; // Record that the previous one was not a variable
var_len = 0; // Variable length reset
}
else // Not a mathematical formula
{
while (ope[ope_len - 1] != '(') // The special judgment is not a left bracket
{
node_stack[node_len].data = ope[ope_len - 1];
node_stack[node_len].Var[0] = 0;
node_stack[node_len].left = (BTnode *)malloc(sizeof(BTnode)); // Allocate new memory space to the left and right child nodes
node_stack[node_len].left->Var[0] = '\0'; // Initialize the left node variable to the empty string
*node_stack[node_len].left = node_stack[node_len - 2]; // out of the stack as a left node
node_stack[node_len].right = (BTnode *)malloc(sizeof(BTnode));
node_stack[node_len].right->Var[0] = '\0'; // Initialize the right node's variable to the empty string
*node_stack[node_len].right = node_stack[node_len - 1]; // out of the stack as a right node
node_stack[node_len - 2] = node_stack[node_len]; // Put the combined new node back
node_len--; // node stack length reduction
ope_len--; // Symbol stack length reduction
}
ope_len--;
pre_isnum = 0; // Record that the previous one was not a number
pre_isvar = 0; // Record that the previous one was not a variable
var_len = 0; // Variable length reset
}
}
else
{
int i = 0; // Loop through the dictionary to find the position of the operator.
for (; i < 7; i++)
{
if (dic_ope[i] == tmp)
break;
}
if (pre_pos >= i || tmp == '(') // Ensure that priorities on the stack are ascending from bottom to top
{
ope[ope_len] = tmp; // Into the stack
ope_len++; // More stack length
pre_isnum = 0;
pre_isvar = 0;
var_len = 0;
}
else // If the priority has been descended
{
while (pre_pos <= i)
{
node_stack[node_len].data = ope[ope_len - 1]; // Out of the stack
node_stack[node_len].Var[0] = 0; // Initialized to 0
node_stack[node_len].left = (BTnode *)malloc(sizeof(BTnode)); // Allocation of space
node_stack[node_len].left->Var[0] = '\0'; // Initialized to 0
*node_stack[node_len].left = node_stack[node_len - 2]; // out of the stack to the left node
node_stack[node_len].right = (BTnode *)malloc(sizeof(BTnode)); // Allocation of space
node_stack[node_len].right->Var[0] = '\0'; // Initialized to 0
*node_stack[node_len].right = node_stack[node_len - 1]; // out of the stack to the right node
node_stack[node_len - 2] = node_stack[node_len]; // Put the new node on the stack
node_len--;
ope_len--;
for (pre_pos = 0; pre_pos < 7; pre_pos++) // Find the operator's position in the dictionary again.
{
if (dic_ope[pre_pos] == ope[ope_len - 1])
break;
}
}
ope[ope_len] = tmp; // Into the stack
ope_len++; // More stack length
pre_isnum = 0;
pre_isvar = 0;
var_len = 0;
}
pre_pos = i;
}
}
else
{
if (pre_isvar) // The previous one was a variable
{
node_stack[node_len - 1].Var[var_len] = tmp; // Consecutive to the previous character
var[var_cnt - 1][var_len] = tmp; // Record variables
var_len++;
node_stack[node_len - 1].Var[var_len] = '\0';
// {"ln", "log", "cos", "sin", "tan", "pow", "exp"}
int math_pos = 0;
for (math_pos = 0; math_pos < 7; math_pos++) // Compare in a dictionary of math formulas
{
if (!strcmp(dic_math[math_pos], node_stack[node_len - 1].Var))
{
break;
}
}
if (math_pos != 7) // If it's a math equation, put it on the symbol stack as a symbol.
{
ope[ope_len] = math_pos - 10; // Marked with the corresponding index minus 10 to prevent confusion with numbers
ope_len++; // Stack length becomes more
node_stack[node_len - 1].Var[0] = 0; // Clear the variables
node_stack[node_len - 1].Var[1] = 0; // Clear the variables
node_stack[node_len - 1].Var[2] = 0; // Clear the variables
node_stack[node_len - 1].Var[3] = 0; // Clear the variables
node_stack[node_len - 1].Var[4] = 0; // Clear the variables
var[var_cnt - 1][0] = 0; // Clear the variables
var[var_cnt - 1][1] = 0; // Clear the variables
var[var_cnt - 1][2] = 0; // Clear the variables
var[var_cnt - 1][3] = 0; // Clear the variables
var[var_cnt - 1][4] = 0; // Clear the variables
node_len--;
pre_isnum = 0;
pre_isvar = 0;
var_len = 0;
var_cnt--;
math = 1;
}
}
else
{
node_stack[node_len].Var[var_len] = tmp; // Stored on the node stack
node_stack[node_len].left = NULL; // Initialize nodes
node_stack[node_len].right = NULL; // Initialize nodes
node_stack[node_len].data = 0; // Initialize nodes
node_len++;
var[var_cnt][var_len] = tmp; // Stored on the variable stack
var_cnt++;
var_len = 1;
pre_isvar = 1;
node_stack[node_len - 1].Var[var_len] = '\0';
pre_isnum = 0;
}
}
} while ((tmp = getchar()) != '\n');
while (node_len != 1) // take any elements that have not yet been combined off the stack
{
node_stack[node_len].data = ope[ope_len - 1]; // Out of the symbol stack
node_stack[node_len].Var[0] = 0;
node_stack[node_len].left = (BTnode *)malloc(sizeof(BTnode));
*node_stack[node_len].left = node_stack[node_len - 2]; // Out of the left node
node_stack[node_len].right = (BTnode *)malloc(sizeof(BTnode));
*node_stack[node_len].right = node_stack[node_len - 1]; // Out right node
node_stack[node_len - 2] = node_stack[node_len]; // Put the new node on the stack
node_len--;
ope_len--;
}
BTnode *root = (BTnode *)malloc(sizeof(BTnode));
root->Var[0] = '\0';
*root = node_stack[0];
return root;
}

void printInfix(BTnode *root) // Middle-order traversal output
{

if (root == NULL)
{
return;
}
int qu = 0;
if (first) // Outer brackets are not output
{
qu = 1;
first = 0;
}
if (qu)
;
else if (root->left != NULL || root->right != NULL) // Wrap in parentheses as long as it's not a leaf node
{
printf("(");
}
printInfix(root->left); // Recursively output the left node
if (root->Var[0] != 0) // Output variables
{
printf("%s", root->Var);
}
else if (root->data == '^' || root->data == '+' || root->data == '-' || root->data == '/' || root->data == '*' || root->data == ',')
{
printf("%c", root->data); // Output symbols
}
else if (root->data == 'L')
{
printf("ln"); // Output ln
}
else if (root->data < 0)
{
printf("%s", dic_math[root->data + 10]); // Output math functions
}
else
{
printf("%d", root->data); // Output digital
}
printInfix(root->right); // Recursively output the right node
if (qu) // Outer brackets are not output
;
else if (root->left != NULL || root->right != NULL)
{
printf(")"); // Wrap in parentheses as long as it's not a leaf node
}
return;
}

BTnode *diff(BTnode *root, char *variable)
{
if (root == NULL) // Recursion boundaries
{
return NULL;
}

BTnode *newRoot = NULL; // Define a new node pointer for return
switch (root->data)
{
case '+':
case '-': // Recursive tree building using derivation formulas for addition and subtraction
newRoot = createNode(root->data);
newRoot->left = diff(root->left, variable);
newRoot->right = diff(root->right, variable);
break;
case '*': // Recursive tree building using the derivation formula for multiplication
newRoot = createNode('+');
newRoot->left = createNode('*');
newRoot->left->left = diff(root->left, variable);
newRoot->left->right = root->right;
newRoot->right = createNode('*');
newRoot->right->left = root->left;
newRoot->right->right = diff(root->right, variable);
break;
case '/': // Recursive tree building using the derivation formula for division
newRoot = createNode('/');
newRoot->left = createNode('-');
newRoot->left->left = createNode('*');
newRoot->left->left->left = diff(root->left, variable);
newRoot->left->left->right = root->right;
newRoot->left->right = createNode('*');
newRoot->left->right->left = root->left;
newRoot->left->right->right = diff(root->right, variable);
newRoot->right = createNode('^');
newRoot->right->left = root->right;
newRoot->right->right = createNode(2);
break;
case '^': // Recursive tree building using the derivation formula for multiplication
newRoot = createNode('*');
newRoot->left = createNode('^');
newRoot->left->left = root->left;
newRoot->left->right = root->right;
newRoot->right = createNode('+');
newRoot->right->left = createNode('*');
newRoot->right->left->left = diff(root->right, variable);
newRoot->right->left->right = createNode('L');
newRoot->right->left->right->left = NULL;
newRoot->right->left->right->right = root->left;
newRoot->right->right = createNode('/');
newRoot->right->right->left = createNode('*');
newRoot->right->right->left->left = root->right;
newRoot->right->right->left->right = diff(root->left, variable);
newRoot->right->right->right = root->left;
break;
case -10: // Recursive tree building using derivation formulas for math equations
newRoot = createNode('/');
newRoot->left = diff(root->right, variable);
newRoot->right = root->right;
break;
case -9: // Recursive tree building using derivation formulas for math equations
newRoot = createNode('/');
newRoot->right = createNode('^');
newRoot->right->left = createNode(-10);
newRoot->right->left->right = root->right->left;
newRoot->right->right = createNode(2);
newRoot->left = createNode('-');
newRoot->left->left = createNode('*');
newRoot->left->right = createNode('*');
newRoot->left->left->left = createNode('/');
newRoot->left->left->right = createNode(-10);
newRoot->left->left->left->left = diff(root->right->right, variable);
newRoot->left->left->left->right = root->right->right;
newRoot->left->left->right->right = root->right->left;
newRoot->left->right->right = createNode(-10);
newRoot->left->right->right->right = root->right->right;
newRoot->left->right->left = createNode('/');
newRoot->left->right->left->left = diff(root->right->right, variable);
newRoot->left->right->left->right = root->right->right;
break;
case -8: // recursively builds a tree using the derivation formula of the math equation
newRoot = createNode('*');
newRoot->left = createNode('-');
newRoot->left->left = createNode(0);
newRoot->left->right = createNode(-7);
newRoot->left->right->right = root->right;
newRoot->right = diff(root->right, variable);
case -7: // Recursive tree building using derivation formulas for math equations
newRoot = createNode('*');
newRoot->left = createNode(-8);
newRoot->left->right = root->right;
newRoot->right = diff(root->right, variable);
case -6: // Recursive tree building using derivation formulas for math equations
newRoot = createNode('/');
newRoot->left = diff(root->right, variable);
newRoot->right = createNode('^');
newRoot->right->right = createNode(2);
newRoot->right->left = createNode(-8);
newRoot->right->left->right = root->right;
break;
case -5: // Recursive tree building using derivation formulas for math equations
newRoot = createNode('*');
newRoot->left = createNode('^');
newRoot->left->left = root->right->left;
newRoot->left->right = root->right->right;
newRoot->right = createNode('+');
newRoot->right->left = createNode('*');
newRoot->right->left->left = diff(root->right->right, variable);
newRoot->right->left->right = createNode('L');
newRoot->right->left->right->left = NULL;
newRoot->right->left->right->right = root->right->left;
newRoot->right->right = createNode('/');
newRoot->right->right->left = createNode('*');
newRoot->right->right->left->left = root->right->right;
newRoot->right->right->left->right = diff(root->right->left, variable);
newRoot->right->right->right = root->right->left;
break;
case -4: // Recursive tree building using derivation formulas for math equations
newRoot = createNode('*');
newRoot->left = createNode(-4);
newRoot->left->right = root->right;
newRoot->right = diff(root->right, variable);
break;
default:
if (root->data != 0) // Constant derivative is 0
{
newRoot = createNode(0);
}
else // Alphabetic derivatives are 1 only for the corresponding variable, and 0 for all other parameters.
{
if (!strcmp(root->Var, variable))
{
newRoot = createNode(1);
}
else
{
newRoot = createNode(0);
}
}
break;
}
return newRoot; // Return to new node
}

void simplify(BTnode *root) // Simplifying function
{
if (root == NULL) // Recursive function boundaries
return;
simplify(root->left); // Recursively simplify the left subtree
simplify(root->right); // Recursively simplify the right subtree
switch (root->data) // Discussion based on notation, specific simplification
{
case '+': // If an item is 0, delete it.
if ((root->left->Var[0] == 0) && (root->left->data == 0))
{
*root = *root->right;
}
else if ((root->right->Var[0] == 0) && (root->right->data == 0))
{
*root = *root->left;
}
break;
case '-': // Delete if previous item is zero
if ((root->right->Var[0] == 0) && (root->right->data == 0))
{
*root = *root->left;
} // Delete the minus sign at the same time if the last item is zero.
else if ((root->left->Var[0] == 0) && (root->left->data == 0))
{
root->left = NULL;
}
break;
case '*': // Returns 0 if one of the items is 0
if (((root->left->Var[0] == 0) && (root->left->data == 0)) || (root->right->Var[0] == 0) && (root->right->data == 0))
{
*root = *createNode(0);
} // If one item is 1, return the other
else if (root->left->data == 1)
{
*root = *root->right;
}
else if (root->right->data == 1)
{
*root = *root->left;
}
break;
case '/': // If the numerator is 0, return 0
if ((root->left->Var[0] == 0) && (root->left->data == 0))
{
*root = *createNode(0);
}
else if (root->right->data == 1)
{ // Return the numerator if the denominator is 1
*root = *root->left;
}
break;
case '^': // If the index is 0, return 1
if ((root->right->Var[0] == 0) && (root->right->data == 0))
{
*root = *createNode(1);
}
break;
default:
break;
}
return;
}

Declaration

I hereby declare that all the work done in this project titled “Autograd for Algebraic Expressions” is of my independent effort