summaryrefslogtreecommitdiff
path: root/booth_multiplier.py
blob: afd6710d57af1bda331a6701cb3cec5f86dd2717 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
#!/usr/bin/env python3
from tabulate import tabulate
import matplotlib.pyplot as plt

# Load the operand file once at import time, split on '\n'; the __main__
# loop skips empty entries and lines beginning with '#'.
with open('input.txt') as f:
  input_string = f.read().split(sep='\n')

def twos_comp(num, length):
  """Return the `length`-bit two's-complement pattern obtained by invert-and-add-one.

  For 0 < num < 2**length this is 2**length - num, i.e. the encoding of -num.
  Because of the final abs(), a negative num in (-2**length, 0) instead
  yields num + 2**length, i.e. the `length`-bit encoding of num itself.
  """
  # Zero is special-cased: inverting it and adding one would carry out to
  # 1 << length instead of wrapping back to 0.
  if num == 0:
    return 0
  inverted = num ^ ((1 << length) - 1)
  return abs(inverted + 1)

def arithmatic_shiftr(num, length, times):
  """Arithmetic right shift of a `length`-bit pattern by `times` bits.

  Each step shifts right by one and copies the current sign bit
  (bit length - 1) back into place, so the sign fills the vacated bits.
  """
  for _ in range(times):
    shifted = num >> 1
    num = shifted | (num & (1 << (length - 1)))
  return num

def arithmatic_shiftl(num, length):
  """Shift left by one, forcing bit (length - 1) of the result to the input's sign.

  Bits at position `length` and above are not masked off; callers that need
  a fixed width must mask the result themselves.
  """
  sign_bit = 1 << (length - 1)
  doubled = num << 1
  return (doubled | sign_bit) if num & sign_bit else (doubled & ~sign_bit)


def twoscomp_to_int(num, length):
  """Decode a `length`-bit two's-complement pattern into a signed Python int."""
  sign_bit = 1 << (length - 1)
  if not num & sign_bit:
    # Non-negative: just keep the low `length` bits.
    return num & ((1 << length) - 1)
  # Negative: invert the low bits and add one to get the magnitude.
  magnitude = abs((num ^ ((1 << length) - 1)) + 1)
  return -magnitude

def debug(results):
  """Print a table checking each Booth result against Python's own product.

  Each entry of `results` is [multiplicand_bin, multiplier_bin, result_booth,
  result_booth_mod, length], where the operands are `length`-bit
  two's-complement patterns and the results are 2*length-bit patterns.
  A result column shows "PASS" when it matches, otherwise the result in binary.
  """
  headers = ['multiplicand bin', 'multiplier bin', 'multiplicand dec', 'multiplier dec', 'expected bin', 'expected dec', 'booth if correct', 'booth mod if correct']
  rows = []
  for multiplicand_bin, multiplier_bin, result_booth, result_booth_mod, length in results:
    multiplicand = twoscomp_to_int(multiplicand_bin, length)
    multiplier = twoscomp_to_int(multiplier_bin, length)
    expected = multiplicand * multiplier
    # Encode the signed product as a 2*length-bit pattern; twos_comp maps a
    # negative value to its own encoding (and 0 to 0).
    expected_bin = expected if expected > 0 else twos_comp(expected, length * 2)
    booth_status = "PASS" if result_booth == expected_bin else bin(result_booth)
    booth_mod_status = "PASS" if result_booth_mod == expected_bin else bin(result_booth_mod)
    rows.append([
      bin(multiplicand_bin), bin(multiplier_bin),
      multiplicand, multiplier,
      bin(expected_bin), expected,
      booth_status, booth_mod_status,
    ])
  print("\nCHECKS: \n", tabulate(rows, headers), "\n")


  
def booth(multiplier, multiplicand, length):
  """Multiply two `length`-bit two's-complement numbers with Booth's algorithm.

  The register is laid out (LSB first) as: 1 overlap bit | `length`
  multiplier bits | `length + 1` accumulator bits. Each step inspects the two
  low bits, adds +M or -M for 01 / 10, then arithmetic-shifts right by one.

  Args:
    multiplier: two's-complement bit pattern; only the low `length` bits are used.
    multiplicand: two's-complement bit pattern; only the low `length` bits are used.
    length: operand width in bits (must be >= 1).

  Returns:
    (product, operations): the product as a 2*length-bit two's-complement
    pattern, and the number of additions/subtractions performed.

  Raises:
    ValueError: if `length` < 1.
  """
  if length < 1:
    raise ValueError("length must be at least 1")
  operations = 0
  operand_mask = (1 << length) - 1
  # One guard bit above the operand width: subtracting the most negative
  # multiplicand (-2**(length-1)) yields +2**(length-1), which does not fit
  # in `length` signed bits and would otherwise flip the sign.
  acc_bits = length + 1
  acc_mask = (1 << acc_bits) - 1
  width = acc_bits + length + 1
  width_mask = (1 << width) - 1
  sign_bit = 1 << (width - 1)
  field = length + 1  # bit offset of the accumulator inside the register

  m = multiplicand & operand_mask
  if m & (1 << (length - 1)):
    m |= 1 << length  # sign-extend into the guard bit
  minus_m = -m & acc_mask  # -M as an acc_bits-wide pattern

  result = (multiplier & operand_mask) << 1  # overlap bit starts at 0
  for _ in range(length):
    op = result & 0b11
    if op == 0b01:
      operations += 1
      result += m << field
    elif op == 0b10:
      operations += 1
      result += minus_m << field
    result &= width_mask  # drop carries out of the accumulator
    # arithmetic shift right by one: keep the sign bit in place
    result = (result >> 1) | (result & sign_bit)
  # Drop the overlap bit; the low 2*length bits of accumulator:multiplier are
  # the product (the guard bit is only a redundant sign bit).
  return ((result >> 1) & ((1 << (2 * length)) - 1), operations)

def booth_mod(multiplier, multiplicand, length):
  operations = 0
  multiplicand |= ((1 << length - 1) & multiplicand) << 1 # extend multiplicand sign to prevent overflow when mult/sub by 2
  multiplicand_twos_comp = twos_comp(multiplicand, length + 1)
  result = multiplier << 1 # extended bit
  for i in range(int((length) / 2)):
    op = result & 0b111
    match op:
      case 0b010 | 0b001: # add
        print("add")
        result += multiplicand << (length + 1)
        operations += 1
      case 0b011:         # add * 2
        print("add * 2")
        result += arithmatic_shiftl(multiplicand, length + 1) << (length + 1) 
        operations += 1
      case 0b100:         # sub * 2
        print("sub * 2")
        result += arithmatic_shiftl(multiplicand_twos_comp, length + 1) << (length + 1)
        operations += 1
      case 0b101 | 0b110: # sub
        print("sub ")
        result += multiplicand_twos_comp << (length + 1)
        operations += 1
    result &= (1 << ((length * 2) + 2)) - 1 # get rid of any overflows
    result = arithmatic_shiftr(result, (length * 2) + 2, 2)
  # *barfs on your dog*
  result = ((result | ((1 << ((length * 2) + 2)) >> 1)) & ((1 << ((length * 2) + 1)) - 1)) >> 1
  return (result, operations)

if __name__ == "__main__":
  # Each non-comment input line holds two binary operands separated by
  # whitespace; the operand width is taken from the multiplicand's digit count.
  headers = ['multiplicand', 'multiplier', 'result (bin)', 'result (hex)']
  table = []
  lengths = [] # for matplotlib plot
  ops_booth = []
  ops_mod_booth = []

  debug_results = []

  for operation in input_string:
    # strip() tolerates CRLF endings and indentation; without it a '\r'-only
    # or indented line reaches int(..., 2) and crashes.
    operation = operation.strip()
    if not operation or operation.startswith('#'):
      continue
    # split() with no argument tolerates runs of spaces/tabs between operands.
    multiplicand_str, multiplier_str = operation.split()[:2]
    length = len(multiplicand_str)
    multiplicand = int(multiplicand_str, 2)
    multiplier = int(multiplier_str, 2)

    result_booth = booth(multiplier, multiplicand, length)
    result_mod_booth = booth_mod(multiplier, multiplicand, length)

    # gather data for matplotlib
    ops_booth.append(result_booth[1])
    ops_mod_booth.append(result_mod_booth[1])
    lengths.append(length)

    table.append([bin(multiplicand), bin(multiplier), bin(result_booth[0]), hex(result_booth[0])])
    debug_results.append([multiplicand, multiplier, result_booth[0], result_mod_booth[0], length])

  debug(debug_results)
  print(tabulate(table, headers))

  # generate graph
  plt.plot(lengths, ops_booth, '^--m', label="booth's algorithm")
  plt.plot(lengths, ops_mod_booth, 'v--c', label="modified booth's algorithm")
  plt.gca().set_xlabel("Length of Operands")
  plt.gca().set_ylabel("Number of Additions and Subtractions")
  plt.legend(loc='upper left')
  plt.show()