RMUL2025/lib/cmsis_5/CMSIS/DSP/Testing/PatternGeneration/ComplexMaths.py

128 lines
3.8 KiB
Python
Executable File

import os.path
import numpy as np
import itertools
import Tools
# These patterns are used for both tests and benchmarks.
# For tests, saturation cases still need to be added.
def randComplex(nb):
    """Return ``nb`` random complex128 samples.

    ``2 * nb`` standard-normal values are drawn with ``np.random.randn``,
    passed through ``Tools.normalize``, and the resulting buffer is
    reinterpreted (no copy) as consecutive (real, imag) float64 pairs.
    Assumes ``Tools.normalize`` returns a contiguous float64 ndarray --
    confirm in Tools.
    """
    samples = Tools.normalize(np.random.randn(2 * nb))
    return samples.view(dtype=np.complex128)
def asReal(a):
    """Flatten ``a`` and reinterpret its buffer as float64 values.

    For a complex128 input, each element becomes two consecutive float64
    entries (real part, then imaginary part). The result is always 1-D,
    so a 0-d input (e.g. the array wrapping a dot-product result) yields
    a length-2 array.
    """
    flat = a.reshape(-1)
    return flat.view(dtype=np.float64)
def _scaleForFixedPoint(ref, format, divisor):
    """Divide ``ref`` by ``divisor`` for the q31/q15 formats, else return it unchanged.

    The fixed-point kernels return results with extra integer bits
    (e.g. 2.30 / 3.29 for q31), so the float reference is pre-divided
    to match that layout.
    """
    if format == 31 or format == 15:
        return ref / divisor
    return ref


def _writeDotProductRef(config, format, patternId, a, b):
    """Write the reference for the complex dot product ``np.dot(a, b)`` (no conjugation).

    q31: the result is a Q63 value in 16.48 format, hence the /2**15.
    q15: the result is a Q31 value in 8.24 format, hence the /2**7.
    Any other format writes the unscaled float reference.
    """
    ref = np.array(np.dot(a, b))
    if format == 31:
        config.writeReferenceQ63(patternId, asReal(ref / 2**15))
    elif format == 15:
        config.writeReferenceQ31(patternId, asReal(ref / 2**7))
    else:
        config.writeReference(patternId, asReal(ref))


def writeTests(config, format):
    """Generate the ComplexMaths input and reference patterns for one data type.

    config -- pattern writer (``writeInput``, ``writeReference``,
              ``writeReferenceQ31`` and ``writeReferenceQ63`` are called on it).
    format -- format code: 31 selects q31 scaling, 15 selects q15 scaling;
              any other value writes unscaled float references.

    Inputs: 1 and 2 are complex vectors (interleaved re/im), 3 is a real vector.
    References: 1 conjugate, 2-4 dot products over loop-specific lengths,
    5 magnitude, 6 squared magnitude, 7 complex*complex,
    8 complex*real, 9 full-length dot product.
    """
    NBSAMPLES = 256
    # Draw order matters: it fixes the random stream used for each input.
    data1 = randComplex(NBSAMPLES)
    data2 = randComplex(NBSAMPLES)
    data3 = Tools.normalize(np.random.randn(NBSAMPLES))

    config.writeInput(1, asReal(data1))
    config.writeInput(2, asReal(data2))
    config.writeInput(3, data3)

    # 1: complex conjugate
    config.writeReference(1, asReal(np.conj(data1)))

    # 2-4: dot product on the prefix lengths chosen by Tools.loopnb for the
    # tail-only, body-only and body+tail cases of this format.
    for patternId, loopKind in ((2, Tools.TAILONLY),
                                (3, Tools.BODYONLY),
                                (4, Tools.BODYANDTAIL)):
        nb = Tools.loopnb(format, loopKind)
        _writeDotProductRef(config, format, patternId, data1[0:nb], data2[0:nb])

    # 5: magnitude -- CMSIS output is 2.30 (q31) / 2.14 (q15), hence /2
    config.writeReference(5, _scaleForFixedPoint(np.absolute(data1), format, 2))

    # 6: squared magnitude -- CMSIS output is 3.29 (q31) / 3.13 (q15), hence /4
    config.writeReference(6, _scaleForFixedPoint(np.absolute(data1)**2, format, 4))

    # 7: complex * complex -- CMSIS output is 3.29 (q31) / 3.13 (q15), hence /4
    config.writeReference(7, asReal(_scaleForFixedPoint(data1 * data2, format, 4)))

    # 8: complex * real, no rescaling
    config.writeReference(8, asReal(data1 * data3))

    # 9: dot product over the full vectors
    _writeDotProductRef(config, format, 9, data1, data2)
def generatePatterns():
    """Generate the ComplexMaths patterns for f64, f32, f16, q31 and q15."""
    patternDir = os.path.join("Patterns", "DSP", "ComplexMaths", "ComplexMaths")
    paramDir = os.path.join("Parameters", "DSP", "ComplexMaths", "ComplexMaths")

    # (type suffix used by Tools.Config, format code passed to writeTests)
    targets = [
        ("f64", Tools.F64),
        ("f32", 0),
        ("f16", 16),
        ("q31", 31),
        ("q15", 15),
    ]

    # Create every config up front, as before any patterns are written.
    configs = [Tools.Config(patternDir, paramDir, suffix) for suffix, _ in targets]

    # setOverwrite(False) is applied to every type except f64 (presumably so
    # existing patterns for those types are preserved -- confirm in Tools.Config).
    for cfg in configs[1:]:
        cfg.setOverwrite(False)

    for cfg, (_, fmt) in zip(configs, targets):
        writeTests(cfg, fmt)
if __name__ == "__main__":
    # Running this file as a script regenerates every ComplexMaths pattern set.
    generatePatterns()