Skip to content

Improviser

Notochord MIDI co-improviser server. Notochord plays different instruments along with the player.

Authors

Victor Shepardson, Intelligent Instruments Lab, 2023

NotoTUI

Bases: TUI

Source code in src/notochord/app/improviser.py
class NotoTUI(TUI):
    """Textual front-end for the improviser: logs, prediction view, controls."""
    CSS_PATH = 'improviser.css'

    # (key, action name, description) for each keyboard shortcut
    BINDINGS = [
        ("m", "mute", "Mute Notochord"),
        ("s", "sustain", "Mute without ending notes"),
        ("q", "query", "Re-query Notochord"),
        ("r", "reset", "Reset Notochord")]

    def compose(self):
        """Create child widgets for the app, in order."""
        # zero-argument factories keep each widget's construction lazy,
        # happening just before it is yielded, exactly as with sequential yields
        widget_factories = (
            Header,
            lambda: self.std_log,
            lambda: NotoLog(id='note'),
            lambda: NotoPrediction(id='prediction'),
            NotoControl,
            Footer,
        )
        for make_widget in widget_factories:
            yield make_widget()

compose()

Create child widgets for the app.

Source code in src/notochord/app/improviser.py
def compose(self):
    """Create child widgets for the app, in order."""
    # zero-argument factories keep each widget's construction lazy,
    # happening just before it is yielded, exactly as with sequential yields
    widget_factories = (
        Header,
        lambda: self.std_log,
        lambda: NotoLog(id='note'),
        lambda: NotoPrediction(id='prediction'),
        NotoControl,
        Footer,
    )
    for make_widget in widget_factories:
        yield make_widget()

main(checkpoint='notochord-latest.ckpt', player_config=None, noto_config=None, initial_mute=False, initial_query=False, midi_in=None, midi_out=None, thru=False, send_pc=False, dump_midi=False, balance_sample=False, n_recent=64, n_margin=8, max_note_len=5, max_time=None, nominal_time=False, osc_port=None, osc_host='', use_tui=True, predict_player=True, auto_query=True, testing=False)

Parameters:

Name Type Description Default
checkpoint

path to notochord model checkpoint.

'notochord-latest.ckpt'
player_config Dict[int, int]

mapping from MIDI channels to MIDI instruments controlled by the player.

None
noto_config Dict[int, int]

Mapping from MIDI channels to MIDI instruments controlled by Notochord. Instruments should be different from the player instruments. Channels should be different unless different ports are used. MIDI channels and General MIDI instruments are both indexed from 1.

None
initial_mute

start Notochord muted so it won't play with input.

False
initial_query

query Notochord immediately so it plays even without input.

False
midi_in Optional[str]

MIDI ports for player input. The default is to use all input ports. Can be a comma-separated list of ports.

None
midi_out Optional[str]

MIDI ports for Notochord output. The default is to use only the virtual 'From iipyper' port. Can be a comma-separated list of ports.

None
thru

if True, copy incoming MIDI to output ports. only makes sense if input and output ports are different.

False
send_pc

if True, send MIDI program change messages to set the General MIDI instrument on each channel according to player_config and noto_config. useful when using a General MIDI synthesizer like fluidsynth.

False
dump_midi

if True, print all incoming MIDI for debugging purposes

False
balance_sample

Choose instruments which have played less recently; this ensures that all configured instruments will play.

False
n_recent

Number of recent note-on events to consider for balance_sample.

64
n_margin

amount of 'slack' in the balance_sample calculation

8
max_note_len

time in seconds after which to force-release sustained notochord notes.

5
max_time

maximum time in seconds between predicted events. default is the Notochord model's maximum (usually 10 seconds).

None
nominal_time

if True, feed Notochord with its own predicted times instead of the actual elapsed time. May make Notochord more likely to play chords.

False
osc_port

Optional. If supplied, listen for OSC messages on this port to set controls.

None
osc_host

hostname or IP of OSC sender. leave this as empty string to get all traffic on the port

''
use_tui

run textual UI.

True
predict_player

If True, forecasted next events can be for the player. Generally should be True; use balance_sample to force Notochord to play.

True
Source code in src/notochord/app/improviser.py
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
def main(
    checkpoint="notochord-latest.ckpt", # Notochord checkpoint
    player_config:Dict[int,int]=None, # map MIDI channel : GM instrument
    noto_config:Dict[int,int]=None, # map MIDI channel : GM instrument

    initial_mute=False, # start with Notochord muted
    initial_query=False, # let Notochord start playing immediately

    midi_in:Optional[str]=None, # MIDI port for player input
    midi_out:Optional[str]=None, # MIDI port for Notochord output
    thru=False, # copy player input to output
    send_pc=False, # send program change messages
    dump_midi=False, # print all incoming MIDI

    balance_sample=False, # choose instruments which have played less recently
    n_recent=64, # number of recent note-on events to consider for above
    n_margin=8, # amount of 'slack' in the balance_sample calculation

    max_note_len=5, # in seconds, to auto-release stuck Notochord notes
    max_time=None, # max time between events
    nominal_time=False, #feed Notochord with nominal dt instead of actual

    osc_port=None, # if supplied, listen for OSC to set controls on this port
    osc_host='', # leave this as empty string to get all traffic on the port

    use_tui=True, # run textual UI
    predict_player=True, # forecasted next events can be for player (preserves model distribution, but can lead to Notochord deciding not to play)
    auto_query=True, # query notochord whenever it is unmuted and there is no pending event. generally should be True except for testing purposes.
    testing=False
    ):
    """
    Run the Notochord co-improviser: listen for player MIDI, feed it to
    Notochord, and play Notochord's predicted events on its own channels.

    Args:
        checkpoint: path to notochord model checkpoint.

        player_config: mapping from MIDI channels to MIDI instruments controlled
            by the player.
        noto_config: mapping from MIDI channels to MIDI instruments controlled
            by notochord.
            instruments should be different from the player instruments.
            channels should be different unless different ports are used.
            MIDI channels and General MIDI instruments are indexed from 1.

        initial_mute: start Notochord muted so it won't play with input.
        initial_query: query Notochord immediately so it plays even without input.

        midi_in: MIDI ports for player input. 
            default is to use all input ports.
            can be comma-separated list of ports.
        midi_out: MIDI ports for Notochord output. 
            default is to use only virtual 'From iipyper' port.
            can be comma-separated list of ports.
        thru: if True, copy incoming MIDI to output ports.
            only makes sense if input and output ports are different.
        send_pc: if True, send MIDI program change messages to set the General MIDI
            instrument on each channel according to player_config and noto_config.
            useful when using a General MIDI synthesizer like fluidsynth.
        dump_midi: if True, print all incoming MIDI for debugging purposes

        balance_sample: choose instruments which have played less recently;
            ensures that all configured instruments will play.
        n_recent: number of recent note-on events to consider for balance_sample
        n_margin: amount of 'slack' in the balance_sample calculation

        max_note_len: time in seconds after which to force-release sustained
            notochord notes.
        max_time: maximum time in seconds between predicted events.
            default is the Notochord model's maximum (usually 10 seconds).
        nominal_time: if True, feed Notochord with its own predicted times
            instead of the actual elapsed time.
            May make Notochord more likely to play chords.

        osc_port: optional. if supplied, listen for OSC to set controls
        osc_host: hostname or IP of OSC sender.
            leave this as empty string to get all traffic on the port

        use_tui: run textual UI.
        predict_player: forecasted next events can be for player.
            generally should be true, use balance_sample to force Notochord to
            play.
        auto_query: if True, re-query Notochord from the polling loop after
            each pending prediction's time has passed.
            generally should be True unless debugging.
        testing: if True, the polling loop never realizes pending predictions
            (for testing purposes).

    Raises:
        NotImplementedError: if player and Notochord instruments overlap.
    """
    if osc_port is not None:
        osc = OSC(osc_host, osc_port)
    midi = MIDI(midi_in, midi_out)

    ### Textual UI
    tui = NotoTUI()
    print = tui.print
    ###

    # default channel:instrument mappings
    if player_config is None:
        player_config = {1:1} # channel 1: grand piano
    if noto_config is None:
        noto_config = {2:257} # channel 2: anon

    # convert 1-indexed MIDI channels to 0-indexed here
    player_map = MIDIConfig({k-1:v for k,v in player_config.items()})
    noto_map = MIDIConfig({k-1:v for k,v in noto_config.items()})

    overlap = player_map.insts & noto_map.insts
    if len(overlap):
        print("WARNING: Notochord and Player instruments shouldn't overlap")
        # TODO: set to anon insts without changing mel/drum
        # respecting anon insts selected for player
        raise NotImplementedError(
            f"Notochord and player instruments overlap ({sorted(overlap)}); "
            "automatic remapping to anonymous instruments is not implemented")
    # TODO:
    # check for repeated insts/channels

    def warn_inst(i):
        """warn about instruments a General MIDI synth may not render as expected"""
        if i > 128:
            if i < 257:
                print(f"WARNING: drum instrument {i} selected, be sure to select a drum bank in your synthesizer")
            else:
                print(f"WARNING: instrument {i} is not General MIDI")

    if send_pc:
        for c,i in (player_map | noto_map).items():
            warn_inst(i)
            # instruments are 1-indexed; MIDI programs are 0-127
            midi.program_change(channel=c, program=(i-1)%128)

    # TODO: add arguments for this,
    # and sensible defaults for drums etc
    default_pitches = range(128)
    inst_pitch_map = {i: default_pitches for i in noto_map.insts | player_map.insts}

    # load notochord model
    try:
        noto = Notochord.from_checkpoint(checkpoint)
        noto.eval()
        noto.reset()
    except Exception:
        print("""error loading notochord model""")
        raise

    # main stopwatch to track time difference between MIDI events
    stopwatch = Stopwatch()

    # simple class to hold pending event prediction
    class Prediction:
        def __init__(self):
            self.event = None
            # gate open means Notochord is allowed to play
            self.gate = not initial_mute
    pending = Prediction()

    # query parameters controlled via MIDI / OSC
    controls = {}

    # tracks held notes, recently played instruments, etc
    history = NotoPerformance()

    def display_event(tag, memo, inst, pitch, vel, channel, **kw):
        """print an event to the terminal"""
        if tag is None:
            return
        s = f'{tag}:\t {inst=:4d}    {pitch=:4d}    {vel=:4d}    {channel=:3d}'
        if memo is not None:
            s += f'    ({memo})'
        tui(note=s)

    def play_event(event, channel, feed=True, send=True, tag=None, memo=None):
        """realize an event as MIDI, terminal display, and Notochord update"""
        # normalize values
        vel = event['vel'] = round(event['vel'])
        dt = stopwatch.punch()
        if 'time' not in event or not nominal_time:
            event['time'] = dt

        # send out as MIDI
        if send:
            midi.send(
                'note_on' if vel > 0 else 'note_off', 
                note=event['pitch'], velocity=vel, channel=channel)

        # feed to NotoPerformance
        # put a stopwatch in the held_note_data field for tracking note length
        history.feed(held_note_data=Stopwatch(), channel=channel, **event)

        # print
        display_event(tag, memo=memo, channel=channel, **event)

        # feed to Notochord
        if feed:
            noto.feed(**event)

    # @lock
    def noto_reset():
        """reset Notochord and end all of its held notes"""
        print('RESET')

        # cancel pending predictions
        pending.event = None
        tui(prediction=pending.event)

        # end Notochord held notes
        for (chan,inst,pitch) in history.note_triples:
            if inst in noto_map.insts:
                play_event(
                    dict(inst=inst, pitch=pitch, vel=0),
                    channel=chan, 
                    feed=False, # skip feeding Notochord since we are resetting it
                    tag='NOTO', memo='reset')
        # reset stopwatch
        stopwatch.punch()
        # reset notochord state
        noto.reset()
        # reset history
        history.push()
        # query the fresh notochord for a new prediction
        if pending.gate:
            noto_query()

    # @lock
    def noto_mute(sustain=False):
        """toggle Notochord's gate.

        with sustain=False, muting also ends Notochord's held notes;
        with sustain=True they are left sounding, and un-muting re-queries.
        """
        # the UI widgets only exist when the textual UI is in use
        if use_tui:
            tui.query_one('#mute').label = 'UNMUTE' if pending.gate else 'MUTE'
            tui.query_one('#sustain').label = 'END SUSTAIN' if pending.gate else 'SUSTAIN'

        pending.gate = not pending.gate

        if sustain:
            print('END SUSTAIN' if pending.gate else 'SUSTAIN')
        else:
            print('UNMUTE' if pending.gate else 'MUTE')
        # if unmuting, we're done
        if pending.gate:
            if sustain:
                noto_query()
            return
        # cancel pending predictions
        pending.event = None
        tui(prediction=pending.event)

        if sustain:
            return

        # end+feed all held notes
        for (chan,inst,pitch) in history.note_triples:
            if chan in noto_map:
                play_event(
                    dict(inst=inst, pitch=pitch, vel=0), 
                    channel=chan, tag='NOTO', memo='mute')

    # query Notochord for a new next event
    # @lock
    def noto_query():
        """set pending.event to a fresh Notochord prediction"""
        # check for stuck notes
        # and prioritize ending those
        for (_, inst, pitch), sw in history.note_data.items():
            if (
                inst in noto_map.insts 
                and sw.read() > max_note_len*(.1+controls.get('steer_duration', 1))
                ):
                # query for the end of a note with flexible timing
                t = stopwatch.read()
                pending.event = noto.query(
                    next_inst=inst, next_pitch=pitch,
                    next_vel=0, min_time=t, max_time=t+0.5)
                print(f'END STUCK NOTE {inst=},{pitch=}')
                tui(prediction=pending.event)
                return

        counts = history.inst_counts(
            n=n_recent, insts=noto_map.insts | player_map.insts)
        print(counts)

        all_insts = noto_map.insts 
        if predict_player:
            all_insts = all_insts | player_map.insts

        held_notes = history.held_inst_pitch_map(all_insts)

        steer_time = 1-controls.get('steer_rate', 0.5)
        steer_pitch = controls.get('steer_pitch', 0.5)
        steer_density = controls.get('steer_density', 0.5)

        tqt = (max(0,steer_time-0.5), min(1, steer_time+0.5))
        tqp = (max(0,steer_pitch-0.5), min(1, steer_pitch+0.5))

        # if using nominal time,
        # *subtract* estimated feed latency to min_time; (TODO: really should
        #   set no min time when querying, use stopwatch when re-querying...)
        # if using actual time, *add* estimated query latency
        time_offset = -5e-3 if nominal_time else 10e-3
        min_time = stopwatch.read()+time_offset

        # balance_sample: note-ons only from instruments which have played less.
        # restrict to all_insts so player instruments are never chosen
        # when predict_player is False
        bal_insts = set(
            counts.index[counts <= counts.min()+n_margin]) & set(all_insts)
        if balance_sample and len(bal_insts)>0:
            insts = bal_insts
        else:
            insts = all_insts

        # VTIP is better for time interventions,
        # VIPT is better for instrument interventions
        # could decide probabilistically based on value of controls + insts...
        if insts==all_insts:
            query_method = noto.query_vtip
        else:
            query_method = noto.query_vipt

        # use only currently selected instruments;
        # instruments set later via program change fall back to the default pitches
        note_on_map = {
            i: set(inst_pitch_map.get(i, default_pitches))-set(held_notes[i]) # exclude held notes
            for i in insts
        }
        # use any predictable instruments which are currently holding notes
        note_off_map = {
            i: set(inst_pitch_map.get(i, default_pitches))&set(held_notes[i]) # only held notes
            for i in all_insts
        }

        max_t = None if max_time is None else max(max_time, min_time+0.2)

        pending.event = query_method(
            note_on_map, note_off_map,
            min_time=min_time, max_time=max_t,
            truncate_quantile_time=tqt,
            truncate_quantile_pitch=tqp,
            steer_density=steer_density,
        )

        # display the predicted event
        tui(prediction=pending.event)

    #### MIDI handling

    # print all incoming MIDI for debugging
    if dump_midi:
        @midi.handle
        def _(msg):
            print(msg)

    @midi.handle(type='program_change')
    def _(msg):
        """Program change events set instruments"""
        # MIDI programs are 0-indexed, instruments here are 1-indexed
        # (inverse of the (i-1)%128 conversion used for send_pc)
        inst = msg.program + 1
        if msg.channel in player_map:
            player_map[msg.channel] = inst
        if msg.channel in noto_map:
            noto_map[msg.channel] = inst

    @midi.handle(type='pitchwheel')
    def _(msg):
        # map the pitchwheel value to [0, 1)
        controls['steer_pitch'] = (msg.pitch+8192)/16384

    # very basic CC handling for controls
    @midi.handle(type='control_change')
    def _(msg):
        """CC messages on any channel"""

        if msg.control==1:
            controls['steer_pitch'] = msg.value/127
            print(f"{controls['steer_pitch']=}")
        if msg.control==2:
            controls['steer_density'] = msg.value/127
            print(f"{controls['steer_density']=}")
        if msg.control==3:
            controls['steer_rate'] = msg.value/127
            print(f"{controls['steer_rate']=}")

        if msg.control==4:
            noto_reset()
        if msg.control==5:
            noto_query()
        if msg.control==6:
            noto_mute()

    # very basic OSC handling for controls
    if osc_port is not None:
        from numbers import Number

        @osc.args('/notochord/improviser/*')
        def _(route, *a):
            print('OSC:', route, *a)
            # '/notochord/improviser/<ctrl>'.split('/')
            #   -> ['', 'notochord', 'improviser', '<ctrl>']
            ctrl = route.split('/')[3]
            if ctrl=='reset':
                noto_reset()
            elif ctrl=='query':
                noto_query()
            elif ctrl=='mute':
                noto_mute()
            else:
                # any other route sets a query control to a single number
                if len(a)!=1 or not isinstance(a[0], Number):
                    print(f'WARNING: OSC control {ctrl!r} expects one numeric argument, got {a}')
                    return
                controls[ctrl] = a[0]
                print(controls)

    @midi.handle(type=('note_on', 'note_off'))
    def _(msg):
        """MIDI NoteOn events from the player"""
        if msg.channel not in player_map.channels:
            return

        inst = player_map[msg.channel]
        pitch = msg.note
        vel = msg.velocity if msg.type=='note_on' else 0

        # feed event to Notochord
        play_event(
            {'inst':inst, 'pitch':pitch, 'vel':vel}, 
            channel=msg.channel, send=thru, tag='PLAYER')

        # query for new prediction
        noto_query()

    def noto_event():
        """play the pending prediction, or re-query if it is invalid"""
        event = pending.event
        inst, pitch, vel = event['inst'], event['pitch'], round(event['vel'])

        # note on which is already playing or note off which is not
        if (vel>0) == ((inst, pitch) in history.note_pairs): 
            print(f're-query for invalid {vel=}, {inst=}, {pitch=}')
            noto_query()
            return

        chan = noto_map.inv(inst)
        play_event(event, channel=chan, tag='NOTO')

    @repeat(1e-3, lock=True)
    def _():
        """Loop, checking if predicted next event happens"""
        # check if current prediction has passed
        if (
            not testing and
            pending.gate and
            pending.event is not None and
            stopwatch.read() > pending.event['time']
            ):
            # if so, check if it is a notochord-controlled instrument
            if pending.event['inst'] in noto_map.insts:
                # prediction happens
                noto_event()
            # query for new prediction
            if auto_query:
                noto_query()

    @cleanup
    def _():
        """end any remaining notes"""
        for (chan,inst,pitch) in history.note_triples:
            if inst in noto_map.insts:
                midi.note_on(note=pitch, velocity=0, channel=chan)

    @tui.set_action
    def mute():
        noto_mute()

    @tui.set_action
    def sustain():
        noto_mute(sustain=True)

    @tui.set_action
    def reset():
        noto_reset()

    @tui.set_action
    def query():
        noto_query()

    if initial_query:
        noto_query()

    if use_tui:
        tui.run()