Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make DspComplex multiplies use growing addition. #157

Merged
merged 1 commit into from
Aug 11, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,30 @@
package dsptools.numbers

import chisel3._
import chisel3.experimental.FixedPoint
import dsptools.hasContext
import implicits._
import chisel3.util.ShiftRegister
import dsptools.DspException

class DspComplexRing[T <: Data:Ring] extends Ring[DspComplex[T]] with hasContext {
abstract class DspComplexRing[T <: Data:Ring] extends Ring[DspComplex[T]] with hasContext {
def plus(f: DspComplex[T], g: DspComplex[T]): DspComplex[T] = {
DspComplex.wire(f.real + g.real, f.imag + g.imag)
}
def plusContext(f: DspComplex[T], g: DspComplex[T]): DspComplex[T] = {
DspComplex.wire(f.real context_+ g.real, f.imag context_+ g.imag)
}

/**
* The builtin times calls +. Ideally we'd like to use growing addition, but we're relying on typeclasses and the
* default + for UInt, SInt, etc. is wrapping. Thus, we're making an escape hatch just for the default (non-context)
* complex multiply.
* @param l
* @param r
* @return the sum of l and r, preferrably growing
*/
protected def plusForTimes(l: T, r: T): T

def times(f: DspComplex[T], g: DspComplex[T]): DspComplex[T] = {
val c_p_d = g.real + g.imag
val a_p_b = f.real + f.imag
Expand Down Expand Up @@ -59,6 +71,22 @@ class DspComplexRing[T <: Data:Ring] extends Ring[DspComplex[T]] with hasContext
}
}

class DspComplexRingUInt extends DspComplexRing[UInt] {
override def plusForTimes(l: UInt, r: UInt): UInt = l +& r
}

class DspComplexRingSInt extends DspComplexRing[SInt] {
override def plusForTimes(l: SInt, r: SInt): SInt = l +& r
}

class DspComplexRingFixed extends DspComplexRing[FixedPoint] {
override def plusForTimes(l: FixedPoint, r: FixedPoint): FixedPoint = l +& r
}

class DspComplexRingData[T <: Data : Ring] extends DspComplexRing[T] {
override protected def plusForTimes(l: T, r: T): T = l + r
}

class DspComplexEq[T <: Data:Eq] extends Eq[DspComplex[T]] with hasContext {
override def eqv(x: DspComplex[T], y: DspComplex[T]): Bool = {
Eq[T].eqv(x.real, y.real) && Eq[T].eqv(x.imag, y.imag)
Expand All @@ -81,9 +109,15 @@ class DspComplexBinaryRepresentation[T <: Data:Ring:BinaryRepresentation] extend
DspComplex.wire(BinaryRepresentation[T].trimBinary(a.real, n), BinaryRepresentation[T].trimBinary(a.imag, n))
}

trait DspComplexImpl {
implicit def DspComplexRingImpl[T<: Data:Ring] = new DspComplexRing[T]()
trait GenericDspComplexImpl {
implicit def DspComplexRingDataImpl[T<: Data:Ring] = new DspComplexRingData[T]()
implicit def DspComplexEq[T <: Data:Eq] = new DspComplexEq[T]()
implicit def DspComplexBinaryRepresentation[T <: Data:Ring:BinaryRepresentation] =
new DspComplexBinaryRepresentation[T]()
}

trait DspComplexImpl extends GenericDspComplexImpl {
implicit def DspComplexRingUIntImpl = new DspComplexRingUInt
implicit def DspComplexRingSIntImpl = new DspComplexRingSInt
implicit def DspComplexRingFixedImpl = new DspComplexRingFixed
}
4 changes: 2 additions & 2 deletions src/test/scala/dsptools/numbers/DspComplexSpec.scala
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
// See LICENSE for license details.

package dsptools.numbers
package testing.dsptools.numbers

import chisel3._
import chisel3.iotesters.ChiselPropSpec
import chisel3.testers.BasicTester
import dsptools.numbers.implicits._
import dsptools.numbers._

//scalastyle:off magic.number
class DspComplexExamples extends Module {
Expand Down